OSDN Git Service

Make SSL network I/O interruptible
authorBrian Carlstrom <bdc@google.com>
Tue, 21 Sep 2010 09:09:29 +0000 (02:09 -0700)
committerBrian Carlstrom <bdc@google.com>
Tue, 21 Sep 2010 19:13:43 +0000 (12:13 -0700)
- Changed NativeCrypto code to hold onto java.io.FileDescriptor so it
  can see observe when another thread calls Socket.close and sets the
  FileDescriptor's fd to -1. Changed AppData::setEnv to check
  NetFd::isClosed, it was already being used before each SSL I/O
  operation.

- Changed sslSelect to no longer take an int fd, it now uses the
  AppData to get access the FileDescriptor. Within sslSelect, the
  select call is now protected with AsynchronousSocketCloseMonitor.
  The select call is now retried on EINTR, checking for socket close
  similar to NET_FAILURE_RETRY. sslSelect now returns
  THROWN_SOCKETEXCEPTION to indicate that NetFd::isClosed has already
  thrown.

- sslRead and sslWrite now similarly returns THROWN_SOCKETEXCEPTION to
  indicate that Net::isClosed detected a closed FileDescriptor.

luni/src/main/native/NativeCrypto.cpp

Moved NetFd from OSNetworkSystem.cpp to new NetFd.h for reuse by NativeCrypto

luni/src/main/native/NetFd.h
luni/src/main/native/org_apache_harmony_luni_platform_OSNetworkSystem.cpp

Added test of 4 Socket/SSLSocket interrupt cases

    1.) read    Socket / close    Socket (redundant with AsynchronousCloseExceptionTest)
    2.) read    Socket / close SSLSocket
    3.) read SSLSocket / close    Socket
    4.) read SSLSocket / close SSLSocket

luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java

Bug: 2973020
Change-Id: I9037738dd1d1c09c03c99e3403e086366aa25109

luni/src/main/native/NativeCrypto.cpp
luni/src/main/native/NetFd.h [new file with mode: 0644]
luni/src/main/native/org_apache_harmony_luni_platform_OSNetworkSystem.cpp
luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java

index 6b278f5..8c8cda7 100644 (file)
 #include <openssl/rsa.h>
 #include <openssl/ssl.h>
 
+#include "AsynchronousSocketCloseMonitor.h"
 #include "JNIHelp.h"
 #include "JniConstants.h"
 #include "JniException.h"
 #include "LocalArray.h"
+#include "NetFd.h"
 #include "NetworkUtilities.h"
 #include "ScopedLocalRef.h"
 #include "ScopedPrimitiveArray.h"
@@ -189,7 +191,6 @@ static int throwExceptionIfNecessary(JNIEnv* env, const char* location  __attrib
     return result;
 }
 
-
 /**
  * Throws an SocketTimeoutException with the given string as a message.
  */
@@ -407,6 +408,7 @@ static BIGNUM* arrayToBignum(JNIEnv* env, jbyteArray source) {
 #define THREAD_ID pthread_self()
 #define THROW_EXCEPTION (-2)
 #define THROW_SOCKETTIMEOUTEXCEPTION (-3)
+#define THROWN_SOCKETEXCEPTION (-4)
 
 static MUTEX_TYPE* mutex_buf = NULL;
 
@@ -1232,7 +1234,7 @@ static jobjectArray getPrincipalBytes(JNIEnv* env, const STACK_OF(X509_NAME)* na
  * accesses to that field inside a lock/unlock sequence of our mutex, but
  * currently this seems a bit like overkill. Marking volatile at the very least.
  *
- * During handshaking, two additional fields are used to up-call into
+ * During handshaking, additional fields are used to up-call into
  * Java to perform certificate verification and handshake
  * completion. These are also used in any renegotiation.
  *
@@ -1240,6 +1242,8 @@ static jobjectArray getPrincipalBytes(JNIEnv* env, const STACK_OF(X509_NAME)* na
  *
  * (6) a NativeCrypto.SSLHandshakeCallbacks instance for callbacks from native to Java
  *
+ * (7) a java.io.FileDescriptor wrapper to check for socket close
+ *
  * Because renegotiation can be requested by the peer at any time,
  * care should be taken to maintain an appropriate JNIEnv on any
  * downcall to openssl since it could result in an upcall to Java. The
@@ -1249,7 +1253,7 @@ static jobjectArray getPrincipalBytes(JNIEnv* env, const STACK_OF(X509_NAME)* na
  *
  * Finally, we have one other piece of state setup by OpenSSL callbacks:
  *
- * (7) a set of ephemeral RSA keys that is lazily generated if a peer
+ * (8) a set of ephemeral RSA keys that is lazily generated if a peer
  * wants to use an exportable RSA cipher suite.
  *
  */
@@ -1261,6 +1265,7 @@ class AppData {
     MUTEX_TYPE mutex;
     JNIEnv* env;
     jobject sslHandshakeCallbacks;
+    jobject fileDescriptor;
     Unique_RSA ephemeralRsa;
 
     /**
@@ -1268,13 +1273,18 @@ class AppData {
      *
      * @param env The JNIEnv
      * @param shc The SSLHandshakeCallbacks
+     * @param fd The FileDescriptor
      */
   public:
     static AppData* create(JNIEnv* env,
-                           jobject shc) {
+                           jobject shc,
+                           jobject fd) {
         if (shc == NULL) {
             return NULL;
         }
+        if (fd == NULL) {
+            return NULL;
+        }
         AppData* appData = new AppData(env);
         if (pipe(appData->fdsEmergency) == -1) {
             destroy(env, appData);
@@ -1289,6 +1299,11 @@ class AppData {
             destroy(env, appData);
             return NULL;
         }
+        appData->fileDescriptor = env->NewGlobalRef(fd);
+        if (appData->fileDescriptor == NULL) {
+            destroy(env, appData);
+            return NULL;
+        }
         return appData;
     }
 
@@ -1301,13 +1316,12 @@ class AppData {
     }
 
   private:
-    AppData(JNIEnv* env) :
+    AppData(JNIEnv* e) :
             aliveAndKicking(1),
             waitingThreads(0),
-            env(NULL),
+            env(e),
             sslHandshakeCallbacks(NULL),
             ephemeralRsa(NULL) {
-        setEnv(env);
         fdsEmergency[0] = -1;
         fdsEmergency[1] = -1;
     }
@@ -1331,17 +1345,27 @@ class AppData {
             env->DeleteGlobalRef(sslHandshakeCallbacks);
             sslHandshakeCallbacks = NULL;
         }
+        if (fileDescriptor != NULL) {
+            env->DeleteGlobalRef(fileDescriptor);
+            fileDescriptor = NULL;
+        }
         clearEnv();
     }
 
   public:
-    void setEnv(JNIEnv* e) {
+    bool setEnv(JNIEnv* e) {
+        NetFd fd(e, fileDescriptor);
+        if (fd.isClosed()) {
+            return false;
+        }
         env = e;
+        return true;
     }
 
     void clearEnv() {
         env = NULL;
     }
+
 };
 
 /**
@@ -1354,60 +1378,88 @@ class AppData {
  * to be passed either SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE, since we
  * only need to wait in case one of these problems occurs.
  *
+ * @param env
  * @param type Either SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE
- * @param fd The file descriptor to wait for (the underlying socket)
- * @param data The application data structure with mutex info etc.
+ * @param appData The application data structure with mutex info etc.
  * @param timeout The timeout value for select call, with the special value
  *                0 meaning no timeout at all (wait indefinitely). Note: This is
  *                the Java semantics of the timeout value, not the usual
  *                select() semantics.
- * @return The result of the inner select() call, -1 on additional errors
- */
-static int sslSelect(int type, int fd, AppData* appData, int timeout) {
+ * @return The result of the inner select() call,
+ * THROW_SOCKETEXCEPTION if a SocketException was thrown, -1 on
+ * additional errors
+ */
+static int sslSelect(JNIEnv* env, int type, AppData* appData, int timeout) {
+    // This loop is an expanded version of the NET_FAILURE_RETRY
+    // macro. It cannot simply be used in this case because select
+    // cannot be restarted without recreating the fd_sets and timeout
+    // structure.
+    int result;
     fd_set rfds;
     fd_set wfds;
+    do {
+        NetFd fd(env, appData->fileDescriptor);
+        if (fd.isClosed()) {
+            result = THROWN_SOCKETEXCEPTION;
+            break;
+        }
+        int intFd = fd.get();
+        JNI_TRACE("sslSelect type=%s fd=%d appData=%p timeout=%d",
+                  (type == SSL_ERROR_WANT_READ) ? "READ" : "WRITE", intFd, appData, timeout);
 
-    FD_ZERO(&rfds);
-    FD_ZERO(&wfds);
+        FD_ZERO(&rfds);
+        FD_ZERO(&wfds);
 
-    if (type == SSL_ERROR_WANT_READ) {
-        FD_SET(fd, &rfds);
-    } else {
-        FD_SET(fd, &wfds);
-    }
+        if (type == SSL_ERROR_WANT_READ) {
+            FD_SET(intFd, &rfds);
+        } else {
+            FD_SET(intFd, &wfds);
+        }
 
-    FD_SET(appData->fdsEmergency[0], &rfds);
+        FD_SET(appData->fdsEmergency[0], &rfds);
 
-    int max = fd > appData->fdsEmergency[0] ? fd : appData->fdsEmergency[0];
+        int max = intFd > appData->fdsEmergency[0] ? intFd : appData->fdsEmergency[0];
 
-    // Build a struct for the timeout data if we actually want a timeout.
-    timeval tv;
-    timeval* ptv;
-    if (timeout > 0) {
-        tv.tv_sec = timeout / 1000;
-        tv.tv_usec = 0;
-        ptv = &tv;
-    } else {
-        ptv = NULL;
-    }
+        // Build a struct for the timeout data if we actually want a timeout.
+        timeval tv;
+        timeval* ptv;
+        if (timeout > 0) {
+            tv.tv_sec = timeout / 1000;
+            tv.tv_usec = 0;
+            ptv = &tv;
+        } else {
+            ptv = NULL;
+        }
 
-    // LOGD("Doing select() for SSL_ERROR_WANT_%s...",
-    //      type == SSL_ERROR_WANT_READ ? "READ" : "WRITE");
-    int result = select(max + 1, &rfds, &wfds, NULL, ptv);
-    // LOGD("Returned from select(), result is %d", result);
+        {
+            AsynchronousSocketCloseMonitor monitor(intFd);
+            result = select(max + 1, &rfds, &wfds, NULL, ptv);
+            if (result == -1) {
+                if (fd.isClosed()) {
+                    result = THROWN_SOCKETEXCEPTION;
+                    break;
+                }
+                if (errno != EINTR) {
+                    break;
+                }
+            }
+        }
+    } while (result == -1);
 
     // Lock
     if (MUTEX_LOCK(appData->mutex) == -1) {
         return -1;
     }
 
-    // If we have been woken up by the emergency pipe, there must be a token in
-    // it. Thus we can safely read it (even in a blocking way).
-    if (FD_ISSET(appData->fdsEmergency[0], &rfds)) {
-        char token;
-        do {
-            read(appData->fdsEmergency[0], &token, 1);
-        } while (errno == EINTR);
+    if (result > 0) {
+        // If we have been woken up by the emergency pipe, there must be a token in
+        // it. Thus we can safely read it (even in a blocking way).
+        if (FD_ISSET(appData->fdsEmergency[0], &rfds)) {
+            char token;
+            do {
+                read(appData->fdsEmergency[0], &token, 1);
+            } while (errno == EINTR);
+        }
     }
 
     // Tell the world that there is now one thread less waiting for the
@@ -1416,7 +1468,9 @@ static int sslSelect(int type, int fd, AppData* appData, int timeout) {
 
     // Unlock
     MUTEX_UNLOCK(appData->mutex);
-    // LOGD("leave sslSelect");
+
+    JNI_TRACE("sslSelect %s fd=%d appData=%p timeout=%d => %d",
+              (type == SSL_ERROR_WANT_READ) ? "READ" : "WRITE", intFd, appData, timeout, result);
     return result;
 }
 
@@ -2360,15 +2414,15 @@ static jint NativeCrypto_SSL_do_handshake(JNIEnv* env, jclass,
         return 0;
     }
 
-    int fd = jniGetFDFromFileDescriptor(env, fdObject);
-    if (fd == -1) {
-        throwSSLExceptionStr(env, "Invalid file descriptor");
+    NetFd fd(env, fdObject);
+    if (fd.isClosed()) {
+        // SocketException thrown by NetFd.isClosed
         SSL_clear(ssl);
         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
         return 0;
     }
 
-    int ret = SSL_set_fd(ssl, fd);
+    int ret = SSL_set_fd(ssl, fd.get());
     JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake s=%d", ssl, fd);
 
     if (ret != 1) {
@@ -2383,7 +2437,7 @@ static jint NativeCrypto_SSL_do_handshake(JNIEnv* env, jclass,
      * Make socket non-blocking, so SSL_connect SSL_read() and SSL_write() don't hang
      * forever and we can use select() to find out if the socket is ready.
      */
-    if (!setBlocking(fd, false)) {
+    if (!setBlocking(fd.get(), false)) {
         throwSSLExceptionStr(env, "Unable to make socket non blocking");
         SSL_clear(ssl);
         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
@@ -2393,7 +2447,7 @@ static jint NativeCrypto_SSL_do_handshake(JNIEnv* env, jclass,
     /*
      * Create our special application data.
      */
-    AppData* appData = AppData::create(env, shc);
+    AppData* appData = AppData::create(env, shc, fdObject);
     if (appData == NULL) {
         throwSSLExceptionStr(env, "Unable to create application data");
         SSL_clear(ssl);
@@ -2409,16 +2463,22 @@ static jint NativeCrypto_SSL_do_handshake(JNIEnv* env, jclass,
         SSL_set_accept_state(ssl);
     }
 
+    ret = 0;
     while (appData->aliveAndKicking) {
         errno = 0;
-        appData->setEnv(env);
+        if (!appData->setEnv(env)) {
+            // SocketException thrown by NetFd.isClosed
+            SSL_clear(ssl);
+            JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
+            return 0;
+        }
         ret = SSL_do_handshake(ssl);
         appData->clearEnv();
         // cert_verify_callback threw exception
         if (env->ExceptionCheck()) {
-          SSL_clear(ssl);
-          JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
-          return 0;
+            SSL_clear(ssl);
+            JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
+            return 0;
         }
         // success case
         if (ret == 1) {
@@ -2429,11 +2489,12 @@ static jint NativeCrypto_SSL_do_handshake(JNIEnv* env, jclass,
             continue;
         }
         // error case
-        // LOGD("SSL_connect: result %d, errno %d, timeout %d", ret, errno, timeout);
         int sslError = SSL_get_error(ssl, ret);
+        JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake ret=%d errno=%d sslError=%d timeout=%d",
+                  ssl, ret, errno, sslError, timeout);
 
         /*
-         * If SSL_connect doesn't succeed due to the socket being
+         * If SSL_do_handshake doesn't succeed due to the socket being
          * either unreadable or unwritable, we use sslSelect to
          * wait for it to become ready. If that doesn't happen
          * before the specified timeout or an error occurs, we
@@ -2442,8 +2503,14 @@ static jint NativeCrypto_SSL_do_handshake(JNIEnv* env, jclass,
          */
         if (sslError == SSL_ERROR_WANT_READ || sslError == SSL_ERROR_WANT_WRITE) {
             appData->waitingThreads++;
-            int selectResult = sslSelect(sslError, fd, appData, timeout);
+            int selectResult = sslSelect(env, sslError, appData, timeout);
 
+            if (selectResult == THROWN_SOCKETEXCEPTION) {
+                // SocketException thrown by NetFd.isClosed
+                SSL_clear(ssl);
+                JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
+                return 0;
+            }
             if (selectResult == -1) {
                 throwSSLExceptionWithSslErrors(env, ssl, SSL_ERROR_SYSCALL, "handshake error");
                 SSL_clear(ssl);
@@ -2474,7 +2541,7 @@ static jint NativeCrypto_SSL_do_handshake(JNIEnv* env, jclass,
         if (sslError == SSL_ERROR_NONE || (sslError == SSL_ERROR_SYSCALL && errno == 0)) {
             throwSSLExceptionStr(env, "Connection closed by peer");
         } else {
-            throwSSLExceptionWithSslErrors(env, ssl, sslError, "Trouble with SSL handshake");
+            throwSSLExceptionWithSslErrors(env, ssl, sslError, "SSL handshake terminated");
         }
         SSL_clear(ssl);
         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
@@ -2488,7 +2555,7 @@ static jint NativeCrypto_SSL_do_handshake(JNIEnv* env, jclass,
          * at this point.
          */
         int sslError = SSL_get_error(ssl, ret);
-        throwSSLExceptionWithSslErrors(env, ssl, sslError, "Trouble with SSL handshake");
+        throwSSLExceptionWithSslErrors(env, ssl, sslError, "SSL handshake aborted");
         SSL_clear(ssl);
         JNI_TRACE("ssl=%p NativeCrypto_SSL_do_handshake => 0", ssl);
         return 0;
@@ -2621,7 +2688,6 @@ static int sslRead(JNIEnv* env, SSL* ssl, char* buf, jint len, int* sslReturnCod
         return 0;
     }
 
-    int fd = SSL_get_fd(ssl);
     BIO* bio = SSL_get_rbio(ssl);
 
     AppData* appData = toAppData(ssl);
@@ -2640,7 +2706,10 @@ static int sslRead(JNIEnv* env, SSL* ssl, char* buf, jint len, int* sslReturnCod
         unsigned int bytesMoved = BIO_number_read(bio) + BIO_number_written(bio);
 
         // LOGD("Doing SSL_Read()");
-        appData->setEnv(env);
+        if (!appData->setEnv(env)) {
+            MUTEX_UNLOCK(appData->mutex);
+            return THROWN_SOCKETEXCEPTION;
+        }
         int result = SSL_read(ssl, buf, len);
         appData->clearEnv();
         int sslError = SSL_ERROR_NONE;
@@ -2681,7 +2750,10 @@ static int sslRead(JNIEnv* env, SSL* ssl, char* buf, jint len, int* sslReturnCod
             // Need to wait for availability of underlying layer, then retry.
             case SSL_ERROR_WANT_READ:
             case SSL_ERROR_WANT_WRITE: {
-                int selectResult = sslSelect(sslError, fd, appData, timeout);
+                int selectResult = sslSelect(env, sslError, appData, timeout);
+                if (selectResult == THROWN_SOCKETEXCEPTION) {
+                    return THROWN_SOCKETEXCEPTION;
+                }
                 if (selectResult == -1) {
                     *sslReturnCode = -1;
                     *sslErrorCode = sslError;
@@ -2753,6 +2825,10 @@ static jint NativeCrypto_SSL_read_byte(JNIEnv* env, jclass, jint ssl_address, ji
             throwSocketTimeoutException(env, "Read timed out");
             result = -1;
             break;
+        case THROWN_SOCKETEXCEPTION:
+            // SocketException thrown by NetFd.isClosed
+            result = -1;
+            break;
         case -1:
             // Propagate EOF upwards.
             result = -1;
@@ -2792,15 +2868,23 @@ static jint NativeCrypto_SSL_read(JNIEnv* env, jclass, jint
                       &returnCode, &sslErrorCode, timeout);
 
     int result;
-    if (ret == THROW_EXCEPTION) {
-        // See sslRead() regarding improper failure to handle normal cases.
-        throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "Read error");
-        result = -1;
-    } else if (ret == THROW_SOCKETTIMEOUTEXCEPTION) {
-        throwSocketTimeoutException(env, "Read timed out");
-        result = -1;
-    } else {
-        result = ret;
+    switch (ret) {
+        case THROW_EXCEPTION:
+            // See sslRead() regarding improper failure to handle normal cases.
+            throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "Read error");
+            result = -1;
+            break;
+        case THROW_SOCKETTIMEOUTEXCEPTION:
+            throwSocketTimeoutException(env, "Read timed out");
+            result = -1;
+            break;
+        case THROWN_SOCKETEXCEPTION:
+            // SocketException thrown by NetFd.isClosed
+            result = -1;
+            break;
+        default:
+            result = ret;
+            break;
     }
 
     JNI_TRACE("ssl=%p NativeCrypto_SSL_read => %d", ssl, result);
@@ -2829,7 +2913,6 @@ static int sslWrite(JNIEnv* env, SSL* ssl, const char* buf, jint len, int* sslRe
         return 0;
     }
 
-    int fd = SSL_get_fd(ssl);
     BIO* bio = SSL_get_wbio(ssl);
 
     AppData* appData = toAppData(ssl);
@@ -2848,7 +2931,10 @@ static int sslWrite(JNIEnv* env, SSL* ssl, const char* buf, jint len, int* sslRe
         unsigned int bytesMoved = BIO_number_read(bio) + BIO_number_written(bio);
 
         // LOGD("Doing SSL_write() with %d bytes to go", len);
-        appData->setEnv(env);
+        if (!appData->setEnv(env)) {
+            MUTEX_UNLOCK(appData->mutex);
+            return THROWN_SOCKETEXCEPTION;
+        }
         int result = SSL_write(ssl, buf, len);
         appData->clearEnv();
         int sslError = SSL_ERROR_NONE;
@@ -2892,7 +2978,10 @@ static int sslWrite(JNIEnv* env, SSL* ssl, const char* buf, jint len, int* sslRe
             // it's also not standard Java behavior, so we wait forever here.
             case SSL_ERROR_WANT_READ:
             case SSL_ERROR_WANT_WRITE: {
-                int selectResult = sslSelect(sslError, fd, appData, 0);
+                int selectResult = sslSelect(env, sslError, appData, 0);
+                if (selectResult == THROWN_SOCKETEXCEPTION) {
+                    return THROWN_SOCKETEXCEPTION;
+                }
                 if (selectResult == -1) {
                     *sslReturnCode = -1;
                     *sslErrorCode = sslError;
@@ -2952,11 +3041,19 @@ static void NativeCrypto_SSL_write_byte(JNIEnv* env, jclass, jint ssl_address, j
     char buf[1] = { (char) b };
     int ret = sslWrite(env, ssl, buf, 1, &returnCode, &sslErrorCode);
 
-    if (ret == THROW_EXCEPTION) {
-        // See sslWrite() regarding improper failure to handle normal cases.
-        throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "Write error");
-    } else if (ret == THROW_SOCKETTIMEOUTEXCEPTION) {
-        throwSocketTimeoutException(env, "Write timed out");
+    switch (ret) {
+        case THROW_EXCEPTION:
+            // See sslWrite() regarding improper failure to handle normal cases.
+            throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "Write error");
+            break;
+        case THROW_SOCKETTIMEOUTEXCEPTION:
+            throwSocketTimeoutException(env, "Write timed out");
+            break;
+        case THROWN_SOCKETEXCEPTION:
+            // SocketException thrown by NetFd.isClosed
+            break;
+        default:
+            break;
     }
 }
 
@@ -2982,11 +3079,19 @@ static void NativeCrypto_SSL_write(JNIEnv* env, jclass,
     int ret = sslWrite(env, ssl, reinterpret_cast<const char*>(bytes.get() + offset), len,
             &returnCode, &sslErrorCode);
 
-    if (ret == THROW_EXCEPTION) {
-        // See sslWrite() regarding improper failure to handle normal cases.
-        throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "Write error");
-    } else if (ret == THROW_SOCKETTIMEOUTEXCEPTION) {
-        throwSocketTimeoutException(env, "Write timed out");
+    switch (ret) {
+        case THROW_EXCEPTION:
+            // See sslWrite() regarding improper failure to handle normal cases.
+            throwSSLExceptionWithSslErrors(env, ssl, sslErrorCode, "Write error");
+            break;
+        case THROW_SOCKETTIMEOUTEXCEPTION:
+            throwSocketTimeoutException(env, "Write timed out");
+            break;
+        case THROWN_SOCKETEXCEPTION:
+            // SocketException thrown by NetFd.isClosed
+            break;
+        default:
+            break;
     }
 }
 
@@ -3024,50 +3129,55 @@ static void NativeCrypto_SSL_shutdown(JNIEnv* env, jclass, jint ssl_address) {
     if (ssl == NULL) {
         return;
     }
-    /*
-     * Try to make socket blocking again. OpenSSL literature recommends this.
-     */
-    int fd = SSL_get_fd(ssl);
-    JNI_TRACE("ssl=%p NativeCrypto_SSL_shutdown s=%d", ssl, fd);
-    if (fd != -1) {
-        setBlocking(fd, true);
-    }
-
     AppData* appData = toAppData(ssl);
     if (appData != NULL) {
-        appData->setEnv(env);
-    }
-    int ret = SSL_shutdown(ssl);
-    if (appData != NULL) {
+        if (!appData->setEnv(env)) {
+            // SocketException thrown by NetFd.isClosed
+            SSL_clear(ssl);
+            freeSslErrorState();
+            return;
+        }
+
+        /*
+         * Try to make socket blocking again. OpenSSL literature recommends this.
+         */
+        int fd = SSL_get_fd(ssl);
+        JNI_TRACE("ssl=%p NativeCrypto_SSL_shutdown s=%d", ssl, fd);
+        if (fd != -1) {
+            setBlocking(fd, true);
+        }
+
+        int ret = SSL_shutdown(ssl);
+        switch (ret) {
+            case 0:
+                /*
+                 * Shutdown was not successful (yet), but there also
+                 * is no error. Since we can't know whether the remote
+                 * server is actually still there, and we don't want to
+                 * get stuck forever in a second SSL_shutdown() call, we
+                 * simply return. This is not security a problem as long
+                 * as we close the underlying socket, which we actually
+                 * do, because that's where we are just coming from.
+                 */
+                break;
+            case 1:
+                /*
+                 * Shutdown was successful. We can safely return. Hooray!
+                 */
+                break;
+            default:
+                /*
+                 * Everything else is a real error condition. We should
+                 * let the Java layer know about this by throwing an
+                 * exception.
+                 */
+                int sslError = SSL_get_error(ssl, ret);
+                throwSSLExceptionWithSslErrors(env, ssl, sslError, "SSL shutdown failed");
+                break;
+        }
         appData->clearEnv();
     }
-    switch (ret) {
-        case 0:
-            /*
-             * Shutdown was not successful (yet), but there also
-             * is no error. Since we can't know whether the remote
-             * server is actually still there, and we don't want to
-             * get stuck forever in a second SSL_shutdown() call, we
-             * simply return. This is not security a problem as long
-             * as we close the underlying socket, which we actually
-             * do, because that's where we are just coming from.
-             */
-            break;
-        case 1:
-            /*
-             * Shutdown was successful. We can safely return. Hooray!
-             */
-            break;
-        default:
-            /*
-             * Everything else is a real error condition. We should
-             * let the Java layer know about this by throwing an
-             * exception.
-             */
-            int sslError = SSL_get_error(ssl, ret);
-            throwSSLExceptionWithSslErrors(env, ssl, sslError, "SSL shutdown failed");
-            break;
-    }
+
     SSL_clear(ssl);
     freeSslErrorState();
 }
diff --git a/luni/src/main/native/NetFd.h b/luni/src/main/native/NetFd.h
new file mode 100644 (file)
index 0000000..235b057
--- /dev/null
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2010 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef NET_FD_H_included
+#define NET_FD_H_included
+
+/**
+ * Wraps access to the int inside a java.io.FileDescriptor, taking care of throwing exceptions.
+ */
+class NetFd {
+public:
+    NetFd(JNIEnv* env, jobject fileDescriptor)
+        : mEnv(env), mFileDescriptor(fileDescriptor), mFd(-1)
+    {
+    }
+
+    bool isClosed() {
+        mFd = jniGetFDFromFileDescriptor(mEnv, mFileDescriptor);
+        bool closed = (mFd == -1);
+        if (closed) {
+            jniThrowException(mEnv, "java/net/SocketException", "Socket closed");
+        }
+        return closed;
+    }
+
+    int get() const {
+        return mFd;
+    }
+
+private:
+    JNIEnv* mEnv;
+    jobject mFileDescriptor;
+    int mFd;
+
+    // Disallow copy and assignment.
+    NetFd(const NetFd&);
+    void operator=(const NetFd&);
+};
+
+/**
+ * Used to retry syscalls that can return EINTR. This differs from TEMP_FAILURE_RETRY in that
+ * it also considers the case where the reason for failure is that another thread called
+ * Socket.close.
+ */
+#define NET_FAILURE_RETRY(fd, exp) ({               \
+    typeof (exp) _rc;                               \
+    do {                                            \
+        _rc = (exp);                                \
+        if (_rc == -1) {                            \
+            if (fd.isClosed() || errno != EINTR) {  \
+                break;                              \
+            }                                       \
+        }                                           \
+    } while (_rc == -1);                            \
+    _rc; })
+
+#endif // NET_FD_H_included
index b22678f..73bc402 100644 (file)
@@ -21,6 +21,7 @@
 #include "JniConstants.h"
 #include "JniException.h"
 #include "LocalArray.h"
+#include "NetFd.h"
 #include "NetworkUtilities.h"
 #include "ScopedPrimitiveArray.h"
 #include "jni.h"
@@ -93,56 +94,6 @@ static struct CachedFields {
 } gCachedFields;
 
 /**
- * Wraps access to the int inside a java.io.FileDescriptor, taking care of throwing exceptions.
- */
-class NetFd {
-public:
-    NetFd(JNIEnv* env, jobject fileDescriptor)
-        : mEnv(env), mFileDescriptor(fileDescriptor), mFd(-1)
-    {
-    }
-
-    bool isClosed() {
-        mFd = jniGetFDFromFileDescriptor(mEnv, mFileDescriptor);
-        bool closed = (mFd == -1);
-        if (closed) {
-            jniThrowException(mEnv, "java/net/SocketException", "Socket closed");
-        }
-        return closed;
-    }
-
-    int get() const {
-        return mFd;
-    }
-
-private:
-    JNIEnv* mEnv;
-    jobject mFileDescriptor;
-    int mFd;
-
-    // Disallow copy and assignment.
-    NetFd(const NetFd&);
-    void operator=(const NetFd&);
-};
-
-/**
- * Used to retry syscalls that can return EINTR. This differs from TEMP_FAILURE_RETRY in that
- * it also considers the case where the reason for failure is that another thread called
- * Socket.close.
- */
-#define NET_FAILURE_RETRY(fd, exp) ({               \
-    typeof (exp) _rc;                               \
-    do {                                            \
-        _rc = (exp);                                \
-        if (_rc == -1) {                            \
-            if (fd.isClosed() || errno != EINTR) {  \
-                break;                              \
-            }                                       \
-        }                                           \
-    } while (_rc == -1);                            \
-    _rc; })
-
-/**
  * Returns the port number in a sockaddr_storage structure.
  *
  * @param address the sockaddr_storage structure to get the port from
index dbd9a2a..5e9567a 100644 (file)
@@ -816,6 +816,62 @@ public class SSLSocketTest extends TestCase {
         listening.close();
     }
 
+    public void test_SSLSocket_interrupt() throws Exception {
+        ServerSocket listening = new ServerSocket(0);
+
+        for (int i = 0; i < 3; i++) {
+            Socket underlying = new Socket(listening.getInetAddress(), listening.getLocalPort());
+            Socket server = listening.accept();
+
+            SSLSocketFactory sf = (SSLSocketFactory) SSLSocketFactory.getDefault();
+            Socket clientWrapping = sf.createSocket(underlying, null, -1, true);
+
+            switch (i) {
+                case 0:
+                    test_SSLSocket_interrupt_case(underlying, underlying);
+                    break;
+                case 1:
+                    test_SSLSocket_interrupt_case(underlying, clientWrapping);
+                    break;
+                case 2:
+                    test_SSLSocket_interrupt_case(clientWrapping, underlying);
+                    break;
+                case 3:
+                    test_SSLSocket_interrupt_case(clientWrapping, clientWrapping);
+                    break;
+                default:
+                    fail();
+            }
+
+            server.close();
+            underlying.close();
+        }
+        listening.close();
+    }
+
+    private void test_SSLSocket_interrupt_case(Socket toRead, final Socket toClose)
+            throws Exception {
+        new Thread() {
+            @Override
+            public void run() {
+                try {
+                    Thread.sleep(1 * 1000);
+                    toClose.close();
+                } catch (Exception e) {
+                    throw new RuntimeException(e);
+                }
+            }
+        }.start();
+        try {
+            toRead.setSoTimeout(5 * 1000);
+            toRead.getInputStream().read();
+            fail();
+        } catch (SocketTimeoutException e) {
+            throw e;
+        } catch (SocketException expected) {
+        }
+    }
+
     public void test_TestSSLSocketPair_create() {
         TestSSLSocketPair test = TestSSLSocketPair.create();
         assertNotNull(test.c);