OSDN Git Service

Don't close loopback sockets when a VPN connects or entering doze.
authorLorenzo Colitti <lorenzo@google.com>
Tue, 26 Jul 2016 08:53:50 +0000 (17:53 +0900)
committerLorenzo Colitti <lorenzo@google.com>
Thu, 28 Jul 2016 09:43:22 +0000 (18:43 +0900)
Bug: 30186506
Change-Id: I8bae7b004c3bb9f6e9e0db99774a6ff6505578b4

server/NetdNativeService.cpp
server/SockDiag.cpp
server/SockDiag.h
server/SockDiagTest.cpp
server/VirtualNetwork.cpp

index 10629ef..5e5b8fd 100644 (file)
@@ -164,7 +164,8 @@ binder::Status NetdNativeService::socketDestroy(const std::vector<UidRange>& uid
     }
 
     UidRanges uidRanges(uids);
-    int err = sd.destroySockets(uidRanges, std::set<uid_t>(skipUids.begin(), skipUids.end()));
+    int err = sd.destroySockets(uidRanges, std::set<uid_t>(skipUids.begin(), skipUids.end()),
+                                true /* excludeLoopback */);
 
     if (err) {
         return binder::Status::fromServiceSpecificError(-err,
index 48b8eae..41e92c2 100644 (file)
@@ -236,6 +236,28 @@ int SockDiag::readDiagMsg(uint8_t proto, SockDiag::DumpCallback callback) {
     return 0;
 }
 
+// Determines whether a socket is a loopback socket. Does not check socket state.
+bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
+    switch (msg->idiag_family) {
+        case AF_INET:
+            // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
+            return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
+                   IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
+                   msg->id.idiag_src[0] == msg->id.idiag_dst[0];
+
+        case AF_INET6: {
+            const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
+            const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
+            return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
+                   (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
+                   IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
+                   !memcmp(src, dst, sizeof(*src));
+        }
+        default:
+            return false;
+    }
+}
+
 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
     if (msg == nullptr) {
        return 0;
@@ -319,12 +341,14 @@ int SockDiag::destroyLiveSockets(DumpCallback destroyFilter) {
     return 0;
 }
 
-int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
+int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
     mSocketsDestroyed = 0;
     Stopwatch s;
 
-    auto shouldDestroy = [uid] (uint8_t, const inet_diag_msg *msg) {
-        return (msg != nullptr && msg->idiag_uid == uid);
+    auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
+        return msg != nullptr &&
+               msg->idiag_uid == uid &&
+               !(excludeLoopback && isLoopbackSocket(msg));
     };
 
     for (const int family : {AF_INET, AF_INET6}) {
@@ -347,14 +371,16 @@ int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
     return 0;
 }
 
-int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids) {
+int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
+                             bool excludeLoopback) {
     mSocketsDestroyed = 0;
     Stopwatch s;
 
     auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
         return msg != nullptr &&
                uidRanges.hasUid(msg->idiag_uid) &&
-               skipUids.find(msg->idiag_uid) == skipUids.end();
+               skipUids.find(msg->idiag_uid) == skipUids.end() &&
+               !(excludeLoopback && isLoopbackSocket(msg));
     };
 
     if (int ret = destroyLiveSockets(shouldDestroy)) {
index 6a4c703..5dc77c1 100644 (file)
@@ -51,11 +51,16 @@ class SockDiag {
     int sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr);
     int readDiagMsg(uint8_t proto, DumpCallback callback);
     int sockDestroy(uint8_t proto, const inet_diag_msg *);
+    // Destroys all sockets on the given IPv4 or IPv6 address.
     int destroySockets(const char *addrstr);
-    int destroySockets(uint8_t proto, uid_t uid);
-    int destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids);
+    // Destroys all sockets for the given protocol and UID.
+    int destroySockets(uint8_t proto, uid_t uid, bool excludeLoopback);
+    // Destroys all "live" (CONNECTED, SYN_SENT, SYN_RECV) TCP sockets for the given UID ranges.
+    int destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
+                       bool excludeLoopback);
 
   private:
+    friend class SockDiagTest;
     int mSock;
     int mWriteSock;
     int mSocketsDestroyed;
@@ -64,4 +69,5 @@ class SockDiag {
     int destroyLiveSockets(DumpCallback destroy);
     bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
     void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
+    static bool isLoopbackSocket(const inet_diag_msg *msg);
 };
index 2b1bf02..f9353f3 100644 (file)
  * sock_diag_test.cpp - unit tests for SockDiag.cpp
  */
 
+#include <sys/socket.h>
+#include <netdb.h>
 #include <arpa/inet.h>
 #include <netinet/in.h>
+#include <netinet/tcp.h>
 #include <linux/inet_diag.h>
 
 #include <gtest/gtest.h>
 #include "UidRanges.h"
 
 class SockDiagTest : public ::testing::Test {
+protected:
+    static bool isLoopbackSocket(const inet_diag_msg *msg) {
+        return SockDiag::isLoopbackSocket(msg);
+    };
 };
 
 uint16_t bindAndListen(int s) {
@@ -110,6 +117,9 @@ TEST_F(SockDiagTest, TestDump) {
                 src, htons(msg->id.idiag_sport),
                 dst, htons(msg->id.idiag_dport),
                 tcpStateName(msg->idiag_state));
+        if (msg->idiag_state == TCP_ESTABLISHED) {
+            EXPECT_TRUE(isLoopbackSocket(msg));
+        }
         return false;
     };
 
@@ -136,6 +146,9 @@ TEST_F(SockDiagTest, TestDump) {
                 src, htons(msg->id.idiag_sport),
                 dst, htons(msg->id.idiag_dport),
                 tcpStateName(msg->idiag_state));
+        if (msg->idiag_state == TCP_ESTABLISHED) {
+            EXPECT_TRUE(isLoopbackSocket(msg));
+        }
         return false;
     };
 
@@ -175,10 +188,96 @@ TEST_F(SockDiagTest, TestDump) {
     close(accepted6);
 }
 
+bool fillDiagAddr(__be32 addr[4], const sockaddr *sa) {
+    switch (sa->sa_family) {
+        case AF_INET: {
+            sockaddr_in *sin = (sockaddr_in *) sa;
+            memcpy(addr, &sin->sin_addr, sizeof(sin->sin_addr));
+            return true;
+        }
+        case AF_INET6: {
+            sockaddr_in6 *sin6 = (sockaddr_in6 *) sa;
+            memcpy(addr, &sin6->sin6_addr, sizeof(sin6->sin6_addr));
+            return true;
+        }
+        default:
+            return false;
+    }
+}
+
+inet_diag_msg makeDiagMessage(__u8 family,  const sockaddr *src, const sockaddr *dst) {
+    inet_diag_msg msg = {
+        .idiag_family = family,
+        .idiag_state = TCP_ESTABLISHED,
+        .idiag_uid = AID_APP + 123,
+        .idiag_inode = 123456789,
+        .id = {
+            .idiag_sport = 1234,
+            .idiag_dport = 4321,
+        }
+    };
+    EXPECT_TRUE(fillDiagAddr(msg.id.idiag_src, src));
+    EXPECT_TRUE(fillDiagAddr(msg.id.idiag_dst, dst));
+    return msg;
+}
+
+inet_diag_msg makeDiagMessage(const char* srcstr, const char* dststr) {
+    addrinfo hints = { .ai_flags = AI_NUMERICHOST }, *src, *dst;
+    EXPECT_EQ(0, getaddrinfo(srcstr, NULL, &hints, &src));
+    EXPECT_EQ(0, getaddrinfo(dststr, NULL, &hints, &dst));
+    EXPECT_EQ(src->ai_addr->sa_family, dst->ai_addr->sa_family);
+    inet_diag_msg msg = makeDiagMessage(src->ai_addr->sa_family, src->ai_addr, dst->ai_addr);
+    freeaddrinfo(src);
+    freeaddrinfo(dst);
+    return msg;
+}
+
+TEST_F(SockDiagTest, TestIsLoopbackSocket) {
+    inet_diag_msg msg;
+
+    msg = makeDiagMessage("127.0.0.1", "127.0.0.1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("::1", "::1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("::1", "::ffff:127.0.0.1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("192.0.2.1", "192.0.2.1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("192.0.2.1", "8.8.8.8");
+    EXPECT_FALSE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("192.0.2.1", "127.0.0.1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("2001:db8::1", "2001:db8::1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("2001:db8::1", "2001:4860:4860::6464");
+    EXPECT_FALSE(isLoopbackSocket(&msg));
+
+    // While isLoopbackSocket returns true on these sockets, we usually don't want to close them
+    // because they aren't specific to any particular network and thus don't become unusable when
+    // an app's routing changes or its network access is removed.
+    //
+    // This isn't a problem, as anything that calls destroyLiveSockets will skip them because
+    // destroyLiveSockets only enumerates ESTABLISHED, SYN_SENT, and SYN_RECV sockets.
+    msg = makeDiagMessage("127.0.0.1", "0.0.0.0");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("::1", "::");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+}
+
 enum MicroBenchmarkTestType {
     ADDRESS,
     UID,
+    UID_EXCLUDE_LOOPBACK,
     UIDRANGE,
+    UIDRANGE_EXCLUDE_LOOPBACK,
 };
 
 const char *testTypeName(MicroBenchmarkTestType mode) {
@@ -186,7 +285,9 @@ const char *testTypeName(MicroBenchmarkTestType mode) {
     switch((mode)) {
         TO_STRING_TYPE(ADDRESS);
         TO_STRING_TYPE(UID);
+        TO_STRING_TYPE(UID_EXCLUDE_LOOPBACK);
         TO_STRING_TYPE(UIDRANGE);
+        TO_STRING_TYPE(UIDRANGE_EXCLUDE_LOOPBACK);
     }
 #undef TO_STRING_TYPE
 }
@@ -214,7 +315,9 @@ protected:
         case ADDRESS:
             return ADDRESS_SOCKETS;
         case UID:
+        case UID_EXCLUDE_LOOPBACK:
         case UIDRANGE:
+        case UIDRANGE_EXCLUDE_LOOPBACK:
             return UID_SOCKETS;
         }
     }
@@ -228,16 +331,21 @@ protected:
                 EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
                 break;
             case UID:
-                ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
+            case UID_EXCLUDE_LOOPBACK: {
+                bool excludeLoopback = (mode == UID_EXCLUDE_LOOPBACK);
+                ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID, excludeLoopback);
                 EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
                         strerror(-ret);
                 break;
-            case UIDRANGE: {
+            }
+            case UIDRANGE:
+            case UIDRANGE_EXCLUDE_LOOPBACK: {
+                bool excludeLoopback = (mode == UIDRANGE_EXCLUDE_LOOPBACK);
                 const char *uidRangeStrings[] = { "8005-8012", "8042", "8043", "8090-8099" };
                 std::set<uid_t> skipUids { 8007, 8043, 8098, 8099 };
                 UidRanges uidRanges;
                 uidRanges.parseFrom(ARRAY_SIZE(uidRangeStrings), (char **) uidRangeStrings);
-                ret = mSd.destroySockets(uidRanges, skipUids);
+                ret = mSd.destroySockets(uidRanges, skipUids, excludeLoopback);
             }
         }
         return ret;
@@ -262,6 +370,9 @@ protected:
                 }
                 return false;
             }
+            case UID_EXCLUDE_LOOPBACK:
+            case UIDRANGE_EXCLUDE_LOOPBACK:
+                return false;
         }
     }
 
@@ -341,7 +452,9 @@ TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
     fprintf(stderr, "   Verifying: %6.1f ms (%d sockets destroyed)\n",
             std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count(),
             socketsClosed);
-    EXPECT_GT(socketsClosed, 0);  // Just in case there's a bug in the test.
+    if (strstr(testTypeName(mode), "_EXCLUDE_LOOPBACK") == nullptr) {
+        EXPECT_GT(socketsClosed, 0);  // Just in case there's a bug in the test.
+    }
 
     start = std::chrono::steady_clock::now();
     for (int i = 0; i < numSockets; i++) {
@@ -358,4 +471,5 @@ TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
 constexpr int SockDiagMicroBenchmarkTest::CLOSE_UID;
 
 INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest,
-                        testing::Values(ADDRESS, UID, UIDRANGE));
+                        testing::Values(ADDRESS, UID, UIDRANGE,
+                                        UID_EXCLUDE_LOOPBACK, UIDRANGE_EXCLUDE_LOOPBACK));
index 3d83703..6daa50d 100644 (file)
@@ -54,7 +54,7 @@ int VirtualNetwork::maybeCloseSockets(bool add, const UidRanges& uidRanges,
         return -EBADFD;
     }
 
-    if (int ret = sd.destroySockets(uidRanges, protectableUsers)) {
+    if (int ret = sd.destroySockets(uidRanges, protectableUsers, true /* excludeLoopback */)) {
         ALOGE("Failed to close sockets while %s %s to network %d: %s",
               add ? "adding" : "removing", uidRanges.toString().c_str(), mNetId, strerror(-ret));
         return ret;