OSDN Git Service

Switch LocalSocket to android::base::{Send,Receive}FileDescriptorVector.
authorJosh Gao <jmgao@google.com>
Mon, 11 Feb 2019 22:37:21 +0000 (14:37 -0800)
committerJosh Gao <jmgao@google.com>
Tue, 26 Feb 2019 07:21:23 +0000 (23:21 -0800)
The previous implementation allocated an array of size
CMSG_SPACE(count) to store CMSG_LEN(count * sizeof(int)) elements, which
leads to bad things happening for values of count greater than 1 on
32-bit, and 2 on 64-bit.

Test: atest android.net.LocalSocketTest
Test: atest android.net.cts.LocalSocketTest
Change-Id: I0a9502c3358d8fa92d2d20e344c6270d6baedc07

core/jni/android_net_LocalSocketImpl.cpp

index a1f2377..1163b86 100644 (file)
 #include <unistd.h>
 #include <sys/ioctl.h>
 
+#include <android-base/cmsg.h>
+#include <android-base/macros.h>
 #include <cutils/sockets.h>
 #include <netinet/tcp.h>
 #include <nativehelper/ScopedUtfChars.h>
 
-namespace android {
+using android::base::ReceiveFileDescriptorVector;
+using android::base::SendFileDescriptorVector;
 
-template <typename T>
-void UNUSED(T t) {}
+namespace android {
 
 static jfieldID field_inboundFileDescriptors;
 static jfieldID field_outboundFileDescriptors;
@@ -118,67 +120,6 @@ socket_bind_local (JNIEnv *env, jobject object, jobject fileDescriptor,
 }
 
 /**
- * Processes ancillary data, handling only
- * SCM_RIGHTS. Creates appropriate objects and sets appropriate
- * fields in the LocalSocketImpl object. Returns 0 on success
- * or -1 if an exception was thrown.
- */
-static int socket_process_cmsg(JNIEnv *env, jobject thisJ, struct msghdr * pMsg)
-{
-    struct cmsghdr *cmsgptr;
-
-    for (cmsgptr = CMSG_FIRSTHDR(pMsg);
-            cmsgptr != NULL; cmsgptr = CMSG_NXTHDR(pMsg, cmsgptr)) {
-
-        if (cmsgptr->cmsg_level != SOL_SOCKET) {
-            continue;
-        }
-
-        if (cmsgptr->cmsg_type == SCM_RIGHTS) {
-            int *pDescriptors = (int *)CMSG_DATA(cmsgptr);
-            jobjectArray fdArray;
-            int count
-                = ((cmsgptr->cmsg_len - CMSG_LEN(0)) / sizeof(int));
-
-            if (count < 0) {
-                jniThrowException(env, "java/io/IOException",
-                    "invalid cmsg length");
-                return -1;
-            }
-
-            fdArray = env->NewObjectArray(count, class_FileDescriptor, NULL);
-
-            if (fdArray == NULL) {
-                return -1;
-            }
-
-            for (int i = 0; i < count; i++) {
-                jobject fdObject
-                        = jniCreateFileDescriptor(env, pDescriptors[i]);
-
-                if (env->ExceptionCheck()) {
-                    return -1;
-                }
-
-                env->SetObjectArrayElement(fdArray, i, fdObject);
-
-                if (env->ExceptionCheck()) {
-                    return -1;
-                }
-            }
-
-            env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
-
-            if (env->ExceptionCheck()) {
-                return -1;
-            }
-        }
-    }
-
-    return 0;
-}
-
-/**
  * Reads data from a socket into buf, processing any ancillary data
  * and adding it to thisJ.
  *
@@ -189,47 +130,48 @@ static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
         void *buffer, size_t len)
 {
     ssize_t ret;
-    struct msghdr msg;
-    struct iovec iv;
-    unsigned char *buf = (unsigned char *)buffer;
-    // Enough buffer for a pile of fd's. We throw an exception if
-    // this buffer is too small.
-    struct cmsghdr cmsgbuf[2*sizeof(cmsghdr) + 0x100];
-
-    memset(&msg, 0, sizeof(msg));
-    memset(&iv, 0, sizeof(iv));
-
-    iv.iov_base = buf;
-    iv.iov_len = len;
+    std::vector<android::base::unique_fd> received_fds;
 
-    msg.msg_iov = &iv;
-    msg.msg_iovlen = 1;
-    msg.msg_control = cmsgbuf;
-    msg.msg_controllen = sizeof(cmsgbuf);
-
-    ret = TEMP_FAILURE_RETRY(recvmsg(fd, &msg, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC));
-
-    if (ret < 0 && errno == EPIPE) {
-        // Treat this as an end of stream
-        return 0;
-    }
+    ret = ReceiveFileDescriptorVector(fd, buffer, len, 64, &received_fds);
 
     if (ret < 0) {
+        if (errno == EPIPE) {
+            // Treat this as an end of stream
+            return 0;
+        }
+
         jniThrowIOException(env, errno);
         return -1;
     }
 
-    if ((msg.msg_flags & (MSG_CTRUNC | MSG_OOB | MSG_ERRQUEUE)) != 0) {
-        // To us, any of the above flags are a fatal error
+    if (received_fds.size() > 0) {
+        jobjectArray fdArray = env->NewObjectArray(received_fds.size(), class_FileDescriptor, NULL);
+
+        if (fdArray == NULL) {
+            // NewObjectArray has thrown.
+            return -1;
+        }
 
-        jniThrowException(env, "java/io/IOException",
-                "Unexpected error or truncation during recvmsg()");
+        for (size_t i = 0; i < received_fds.size(); i++) {
+            jobject fdObject = jniCreateFileDescriptor(env, received_fds[i].get());
 
-        return -1;
-    }
+            if (env->ExceptionCheck()) {
+                return -1;
+            }
+
+            env->SetObjectArrayElement(fdArray, i, fdObject);
+
+            if (env->ExceptionCheck()) {
+                return -1;
+            }
+        }
 
-    if (ret >= 0) {
-        socket_process_cmsg(env, thisJ, &msg);
+        for (auto &fd : received_fds) {
+            // The fds are stored in java.io.FileDescriptors now.
+            static_cast<void>(fd.release());
+        }
+
+        env->SetObjectField(thisJ, field_inboundFileDescriptors, fdArray);
     }
 
     return ret;
@@ -243,7 +185,6 @@ static ssize_t socket_read_all(JNIEnv *env, jobject thisJ, int fd,
 static int socket_write_all(JNIEnv *env, jobject object, int fd,
         void *buf, size_t len)
 {
-    ssize_t ret;
     struct msghdr msg;
     unsigned char *buffer = (unsigned char *)buf;
     memset(&msg, 0, sizeof(msg));
@@ -256,14 +197,11 @@ static int socket_write_all(JNIEnv *env, jobject object, int fd,
         return -1;
     }
 
-    struct cmsghdr *cmsg;
     int countFds = outboundFds == NULL ? 0 : env->GetArrayLength(outboundFds);
-    int fds[countFds];
-    char msgbuf[CMSG_SPACE(countFds)];
+    std::vector<int> fds;
 
     // Add any pending outbound file descriptors to the message
     if (outboundFds != NULL) {
-
         if (env->ExceptionCheck()) {
             return -1;
         }
@@ -274,47 +212,25 @@ static int socket_write_all(JNIEnv *env, jobject object, int fd,
                 return -1;
             }
 
-            fds[i] = jniGetFDFromFileDescriptor(env, fdObject);
+            fds.push_back(jniGetFDFromFileDescriptor(env, fdObject));
             if (env->ExceptionCheck()) {
                 return -1;
             }
         }
-
-        // See "man cmsg" really
-        msg.msg_control = msgbuf;
-        msg.msg_controllen = sizeof msgbuf;
-        cmsg = CMSG_FIRSTHDR(&msg);
-        cmsg->cmsg_level = SOL_SOCKET;
-        cmsg->cmsg_type = SCM_RIGHTS;
-        cmsg->cmsg_len = CMSG_LEN(sizeof fds);
-        memcpy(CMSG_DATA(cmsg), fds, sizeof fds);
     }
 
-    // We only write our msg_control during the first write
-    while (len > 0) {
-        struct iovec iv;
-        memset(&iv, 0, sizeof(iv));
-
-        iv.iov_base = buffer;
-        iv.iov_len = len;
-
-        msg.msg_iov = &iv;
-        msg.msg_iovlen = 1;
-
-        do {
-            ret = sendmsg(fd, &msg, MSG_NOSIGNAL);
-        } while (ret < 0 && errno == EINTR);
+    ssize_t rc = SendFileDescriptorVector(fd, buffer, len, fds);
 
-        if (ret < 0) {
+    while (rc != len) {
+        if (rc == -1) {
             jniThrowIOException(env, errno);
             return -1;
         }
 
-        buffer += ret;
-        len -= ret;
+        buffer += rc;
+        len -= rc;
 
-        // Wipes out any msg_control too
-        memset(&msg, 0, sizeof(msg));
+        rc = send(fd, buffer, len, MSG_NOSIGNAL);
     }
 
     return 0;