OSDN Git Service

Fix heap corruption in nio select(2) code.
authorElliott Hughes <enh@google.com>
Fri, 18 Sep 2009 01:32:07 +0000 (18:32 -0700)
committerElliott Hughes <enh@google.com>
Fri, 18 Sep 2009 18:52:23 +0000 (11:52 -0700)
The active ingredient in this change is that we now test that the fd isn't -1,
used to represent an invalid fd. There's a race condition where a socket can be
closed between SelectorImpl.prepareChannels and the native code. This caused us
to write to the -1th element of a heap-allocated structure, leading to SIGSEGV.

I've also removed the check for an empty fd_set. It was broken before and will
never have fired, but I don't think it makes sense to fix it, given this race
condition.

The race can't be fixed because the implementation is documented to close the
socket channel and *then* cancel the selection key.

This patch also removes various dead functions and tidies up timeval usage.

Bug: 2093094

libcore/luni/src/main/native/org_apache_harmony_luni_platform_OSNetworkSystem.cpp

index 1b42a45..afa0dfc 100644 (file)
@@ -170,8 +170,6 @@ struct CachedFields {
     jclass byte_class;
     jmethodID byte_class_init;
     jfieldID byte_class_value;
-    jclass string_class;
-    jmethodID string_class_init;
     jclass socketimpl_class;
     jfieldID socketimpl_address;
     jfieldID socketimpl_port;
@@ -424,22 +422,17 @@ jobject newJavaLangInteger(JNIEnv * env, jint anInt) {
     return env->NewObject(tempClass, tempMethod, anInt);
 }
 
-/**
- * Answer a new java.lang.String object.
- *
- * @param env   pointer to the JNI library
- * @param anInt the byte[] constructor argument
- *
- * @return  the new String
- */
-
-jobject newJavaLangString(JNIEnv * env, jbyteArray bytes) {
-    jclass tempClass;
-    jmethodID tempMethod;
+// Converts a number of milliseconds to a timeval.
+static timeval toTimeval(long ms) {
+    timeval tv;
+    tv.tv_sec = ms / 1000;
+    tv.tv_usec = (ms - tv.tv_sec*1000) * 1000;
+    return tv;
+}
 
-    tempClass = gCachedFields.string_class;
-    tempMethod = gCachedFields.string_class_init;
-    return env->NewObject(tempClass, tempMethod, (jbyteArray) bytes);
+// Converts a timeval to a number of milliseconds.
+static long toMs(const timeval& tv) {
+    return tv.tv_sec * 1000 + tv.tv_usec / 1000;
 }
 
 /**
@@ -456,11 +449,10 @@ jobject newJavaLangString(JNIEnv * env, jbyteArray bytes) {
  */
 
 static int time_msec_clock() {
-    struct timeval tp;
+    timeval tp;
     struct timezone tzp;
-
     gettimeofday(&tp, &tzp);
-    return (tp.tv_sec * 1000) + (tp.tv_usec / 1000);
+    return toMs(tp);
 }
 
 /**
@@ -871,53 +863,6 @@ static int getSocketAddressFamily(int socket) {
 }
 
 /**
- * A helper method, to set the connect context to a Long object.
- *
- * @param env  pointer to the JNI library
- * @param longclass Java Long Object
- */
-void setConnectContext(JNIEnv *env,jobject longclass,jbyte * context) {
-    jclass descriptorCLS;
-    jfieldID descriptorFID;
-    descriptorCLS = env->FindClass("java/lang/Long");
-    descriptorFID = env->GetFieldID(descriptorCLS, "value", "J");
-    env->SetLongField(longclass, descriptorFID, (jlong)((jint)context));
-};
-
-/**
- * A helper method, to get the connect context.
- *
- * @param env  pointer to the JNI library
- * @param longclass Java Long Object
- */
-jbyte *getConnectContext(JNIEnv *env, jobject longclass) {
-    jclass descriptorCLS;
-    jfieldID descriptorFID;
-    descriptorCLS = env->FindClass("java/lang/Long");
-    descriptorFID = env->GetFieldID(descriptorCLS, "value", "J");
-    return (jbyte*) ((jint)env->GetLongField(longclass, descriptorFID));
-};
-
-// typical ip checksum
-unsigned short ip_checksum(unsigned short* buffer, int size) {
-    register unsigned short * buf = buffer;
-    register int bufleft = size;
-    register unsigned long sum = 0;
-
-    while (bufleft > 1) {
-        sum = sum + (*buf++);
-        bufleft = bufleft - sizeof(unsigned short );
-    }
-    if (bufleft) {
-        sum = sum + (*(unsigned char*)buf);
-    }
-    sum = (sum >> 16) + (sum & 0xffff);
-    sum += (sum >> 16);
-
-    return (unsigned short )(~sum);
-}
-
-/**
  * Converts an IPv4 address to an IPv4-mapped IPv6 address. Performs no error
  * checking.
  *
@@ -1010,9 +955,8 @@ static int doBind(int socket, struct sockaddr_storage *socketAddress) {
  * @return 0, if no errors occurred, otherwise the (negative) error code.
  */
 static int sockConnectWithTimeout(int handle, struct sockaddr_storage addr,
-        unsigned int timeout, unsigned int step, jbyte *ctxt) {
+                                  int timeout, unsigned int step, jbyte *ctxt) {
     int rc = 0;
-    struct timeval passedTimeout;
     int errorVal;
     socklen_t errorValLen = sizeof(int);
     struct selectFDSet *context = NULL;
@@ -1072,13 +1016,13 @@ static int sockConnectWithTimeout(int handle, struct sockaddr_storage addr,
          * set the timeout value to be used. Because on some unix platforms we
          * don't get notified when a socket is closed we only sleep for 100ms
          * at a time
+         * 
+         * TODO: is this relevant for Android?
          */
-        passedTimeout.tv_sec = 0;
         if (timeout > 100) {
-            passedTimeout.tv_usec = 100 * 1000;
-        } else if ((int)timeout >= 0) {
-          passedTimeout.tv_usec = timeout * 1000;
+            timeout = 100;
         }
+        timeval passedTimeout(toTimeval(timeout));
 
         /* initialize the FD sets for the select */
         FD_ZERO(&(context->exceptionSet));
@@ -1092,7 +1036,7 @@ static int sockConnectWithTimeout(int handle, struct sockaddr_storage addr,
                    &(context->readSet),
                    &(context->writeSet),
                    &(context->exceptionSet),
-                   (int)timeout >= 0 ? &passedTimeout : NULL);
+                   timeout >= 0 ? &passedTimeout : NULL);
 
         /* if there is at least one descriptor ready to be checked */
         if (0 < rc) {
@@ -1399,7 +1343,6 @@ static void osNetworkSystem_oneTimeInitializationImpl(JNIEnv* env, jobject obj,
         {&c->integer_class, "java/lang/Integer"},
         {&c->boolean_class, "java/lang/Boolean"},
         {&c->byte_class, "java/lang/Byte"},
-        {&c->string_class, "java/lang/String"},
         {&c->socketimpl_class, "java/net/SocketImpl"},
         {&c->dpack_class, "java/net/DatagramPacket"}
     };
@@ -1421,7 +1364,6 @@ static void osNetworkSystem_oneTimeInitializationImpl(JNIEnv* env, jobject obj,
         {&c->integer_class_init, c->integer_class, "<init>", "(I)V", false},
         {&c->boolean_class_init, c->boolean_class, "<init>", "(Z)V", false},
         {&c->byte_class_init, c->byte_class, "<init>", "(B)V", false},
-        {&c->string_class_init, c->string_class, "<init>", "([B)V", false},
         {&c->iaddr_getbyaddress, c->iaddr_class, "getByAddress",
                     "([B)Ljava/net/InetAddress;", true}
     };
@@ -2474,9 +2416,7 @@ static jint osNetworkSystem_receiveStreamImpl(JNIEnv* env, jclass clazz,
     jbyte* body = env->GetByteArrayElements(data, NULL);
 
     // set timeout
-    struct timeval tv;
-    tv.tv_sec = timeout / 1000;
-    tv.tv_usec = (timeout % 1000) * 1000;
+    timeval tv(toTimeval(timeout));
     setsockopt(handle, SOL_SOCKET, SO_RCVTIMEO, (struct timeval *)&tv,
                sizeof(struct timeval));
 
@@ -2660,110 +2600,87 @@ static jint osNetworkSystem_sendDatagramImpl2(JNIEnv* env, jclass clazz,
     return sent;
 }
 
-static jint osNetworkSystem_selectImpl(JNIEnv* env, jclass clazz,
-        jobjectArray readFDArray, jobjectArray writeFDArray, jint countReadC,
-        jint countWriteC, jintArray outFlags, jlong timeout) {
-    // LOGD("ENTER selectImpl");
-
-    struct timeval timeP;
-    int result = 0;
-    int size = 0;
-    jobject gotFD;
-    fd_set *fdset_read,*fdset_write;
-    int handle;
-    jint *flagArray;
-    int val;
-    unsigned int time_sec = (unsigned int)timeout/1000;
-    unsigned int time_msec = (unsigned int)(timeout%1000);
-
-    fdset_read = (fd_set *)malloc(sizeof(fd_set));
-    fdset_write = (fd_set *)malloc(sizeof(fd_set));
-
-    FD_ZERO(fdset_read);
-    FD_ZERO(fdset_write);
-
-    for (val = 0; val<countReadC; val++) {
-
-        gotFD = env->GetObjectArrayElement(readFDArray,val);
-
-        handle = jniGetFDFromFileDescriptor(env, gotFD);
-
-        FD_SET(handle, fdset_read);
-
-        if (0 > (size - handle)) {
-            size = handle;
+static bool initFdSet(JNIEnv* env, jobjectArray fdArray, jint count, fd_set* fdSet, int* maxFd) {
+    for (int i = 0; i < count; ++i) {
+        jobject fileDescriptor = env->GetObjectArrayElement(fdArray, i);
+        if (fileDescriptor == NULL) {
+            return false;
         }
-    }
-
-    for (val = 0; val<countWriteC; val++) {
-
-        gotFD = env->GetObjectArrayElement(writeFDArray,val);
-
-        handle = jniGetFDFromFileDescriptor(env, gotFD);
-
-        FD_SET(handle, fdset_write);
-
-        if (0 > (size - handle)) {
-            size = handle;
+        
+        const int fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
+        if (fd < 0 || fd > 1024) {
+            LOGE("selectImpl: invalid fd %i", fd);
+            continue;
         }
-    }
-
-    /* the size is the max_fd + 1 */
-    size =size + 1;
-
-    if (0 > size) {
-        result = SOCKERR_FDSET_SIZEBAD;
-    } else {
-      /* only set when timeout >= 0 (non-block)*/
-        if (0 <= timeout) {
-
-            timeP.tv_sec = time_sec;
-            timeP.tv_usec = time_msec*1000;
-
-            result = sockSelect(size, fdset_read, fdset_write, NULL, &timeP);
-
-        } else {
-            result = sockSelect(size, fdset_read, fdset_write, NULL, NULL);
+        
+        FD_SET(fd, fdSet);
+        
+        if (fd > *maxFd) {
+            *maxFd = fd;
         }
     }
+    return true;
+}
 
-    if (0 < result) {
-        /*output the result to a int array*/
-        flagArray = env->GetIntArrayElements(outFlags, NULL);
-
-        for (val=0; val<countReadC; val++) {
-            gotFD = env->GetObjectArrayElement(readFDArray,val);
-
-            handle = jniGetFDFromFileDescriptor(env, gotFD);
-
-            if (FD_ISSET(handle,fdset_read)) {
-                flagArray[val] = SOCKET_OP_READ;
-            } else {
-                flagArray[val] = SOCKET_OP_NONE;
-            }
+static bool translateFdSet(JNIEnv* env, jobjectArray fdArray, jint count, const fd_set& fdSet, jint* flagArray, size_t offset, jint op) {
+    for (int i = 0; i < count; ++i) {
+        jobject fileDescriptor = env->GetObjectArrayElement(fdArray, i);
+        if (fileDescriptor == NULL) {
+            return false;
         }
-
-        for (val=0; val<countWriteC; val++) {
-
-            gotFD = env->GetObjectArrayElement(writeFDArray,val);
-
-            handle = jniGetFDFromFileDescriptor(env, gotFD);
-
-            if (FD_ISSET(handle,fdset_write)) {
-                flagArray[val+countReadC] = SOCKET_OP_WRITE;
-            } else {
-                flagArray[val+countReadC] = SOCKET_OP_NONE;
-            }
+        
+        const int fd = jniGetFDFromFileDescriptor(env, fileDescriptor);
+        const bool valid = fd >= 0 && fd < 1024;
+        
+        if (valid && FD_ISSET(fd, &fdSet)) {
+            flagArray[i + offset] = op;
+        } else {
+            flagArray[i + offset] = SOCKET_OP_NONE;
         }
-
-        env->ReleaseIntArrayElements(outFlags, flagArray, 0);
     }
+    return true;
+}
 
-    free(fdset_write);
-    free(fdset_read);
-
-    /* return both correct and error result, let java handle the exception*/
-    return result;
+static jint osNetworkSystem_selectImpl(JNIEnv* env, jclass clazz,
+        jobjectArray readFDArray, jobjectArray writeFDArray, jint countReadC,
+        jint countWriteC, jintArray outFlags, jlong timeoutMs) {
+    // LOGD("ENTER selectImpl");
+    
+    // Initialize the fd_sets.
+    int maxFd = -1;
+    fd_set readFds;
+    fd_set writeFds;
+    FD_ZERO(&readFds);
+    FD_ZERO(&writeFds);
+    bool initialized = initFdSet(env, readFDArray, countReadC, &readFds, &maxFd) &&
+                       initFdSet(env, writeFDArray, countWriteC, &writeFds, &maxFd);
+    if (!initialized) {
+        return -1;
+    }
+    
+    // Initialize the timeout, if any.
+    timeval tv;
+    timeval* tvp = NULL;
+    if (timeoutMs >= 0) {
+        tv = toTimeval(timeoutMs);
+        tvp = &tv;
+    }
+    
+    // Perform the select.
+    int result = sockSelect(maxFd + 1, &readFds, &writeFds, NULL, tvp);
+    if (result < 0) {
+        return result;
+    }
+    
+    // Translate the result into the int[] we're supposed to fill in.
+    jint* flagArray = env->GetIntArrayElements(outFlags, NULL);
+    if (flagArray == NULL) {
+        return -1;
+    }
+    bool okay = translateFdSet(env, readFDArray, countReadC, readFds, flagArray, 0, SOCKET_OP_READ) &&
+                translateFdSet(env, writeFDArray, countWriteC, writeFds, flagArray, countReadC, SOCKET_OP_WRITE);
+    env->ReleaseIntArrayElements(outFlags, flagArray, 0);
+    return okay ? 0 : -1;
 }
 
 static jobject osNetworkSystem_getSocketLocalAddressImpl(JNIEnv* env,
@@ -2995,7 +2912,7 @@ static jobject osNetworkSystem_getSocketOptionImpl(JNIEnv* env, jclass clazz,
                 throwSocketException(env, convertError(errno));
                 return NULL;
             }
-            return newJavaLangInteger(env, timeout.tv_sec * 1000 + timeout.tv_usec/1000);
+            return newJavaLangInteger(env, toMs(timeout));
         }
         default: {
             throwSocketException(env, SOCKERR_OPTUNSUPP);
@@ -3237,9 +3154,7 @@ static void osNetworkSystem_setSocketOptionImpl(JNIEnv* env, jclass clazz,
         }
 
         case JAVASOCKOPT_SO_RCVTIMEOUT: {
-            struct timeval timeout;
-            timeout.tv_sec = intVal / 1000;
-            timeout.tv_usec = (intVal % 1000) * 1000;
+            timeval timeout(toTimeval(intVal));
             result = setsockopt(handle, SOL_SOCKET, SO_RCVTIMEO, &timeout,
                     sizeof(struct timeval));
             if (0 != result) {