OSDN Git Service

Don't close loopback sockets when a VPN connects or entering doze.
[android-x86/system-netd.git] / server / SockDiagTest.cpp
index 2b1bf02..f9353f3 100644 (file)
  * sock_diag_test.cpp - unit tests for SockDiag.cpp
  */
 
+#include <sys/socket.h>
+#include <netdb.h>
 #include <arpa/inet.h>
 #include <netinet/in.h>
+#include <netinet/tcp.h>
 #include <linux/inet_diag.h>
 
 #include <gtest/gtest.h>
 #include "UidRanges.h"
 
 class SockDiagTest : public ::testing::Test {
+protected:
+    static bool isLoopbackSocket(const inet_diag_msg *msg) {
+        return SockDiag::isLoopbackSocket(msg);
+    };
 };
 
 uint16_t bindAndListen(int s) {
@@ -110,6 +117,9 @@ TEST_F(SockDiagTest, TestDump) {
                 src, htons(msg->id.idiag_sport),
                 dst, htons(msg->id.idiag_dport),
                 tcpStateName(msg->idiag_state));
+        if (msg->idiag_state == TCP_ESTABLISHED) {
+            EXPECT_TRUE(isLoopbackSocket(msg));
+        }
         return false;
     };
 
@@ -136,6 +146,9 @@ TEST_F(SockDiagTest, TestDump) {
                 src, htons(msg->id.idiag_sport),
                 dst, htons(msg->id.idiag_dport),
                 tcpStateName(msg->idiag_state));
+        if (msg->idiag_state == TCP_ESTABLISHED) {
+            EXPECT_TRUE(isLoopbackSocket(msg));
+        }
         return false;
     };
 
@@ -175,10 +188,96 @@ TEST_F(SockDiagTest, TestDump) {
     close(accepted6);
 }
 
+bool fillDiagAddr(__be32 addr[4], const sockaddr *sa) {
+    switch (sa->sa_family) {
+        case AF_INET: {
+            sockaddr_in *sin = (sockaddr_in *) sa;
+            memcpy(addr, &sin->sin_addr, sizeof(sin->sin_addr));
+            return true;
+        }
+        case AF_INET6: {
+            sockaddr_in6 *sin6 = (sockaddr_in6 *) sa;
+            memcpy(addr, &sin6->sin6_addr, sizeof(sin6->sin6_addr));
+            return true;
+        }
+        default:
+            return false;
+    }
+}
+
+inet_diag_msg makeDiagMessage(__u8 family,  const sockaddr *src, const sockaddr *dst) {
+    inet_diag_msg msg = {
+        .idiag_family = family,
+        .idiag_state = TCP_ESTABLISHED,
+        .idiag_uid = AID_APP + 123,
+        .idiag_inode = 123456789,
+        .id = {
+            .idiag_sport = 1234,
+            .idiag_dport = 4321,
+        }
+    };
+    EXPECT_TRUE(fillDiagAddr(msg.id.idiag_src, src));
+    EXPECT_TRUE(fillDiagAddr(msg.id.idiag_dst, dst));
+    return msg;
+}
+
+inet_diag_msg makeDiagMessage(const char* srcstr, const char* dststr) {
+    addrinfo hints = { .ai_flags = AI_NUMERICHOST }, *src, *dst;
+    EXPECT_EQ(0, getaddrinfo(srcstr, NULL, &hints, &src));
+    EXPECT_EQ(0, getaddrinfo(dststr, NULL, &hints, &dst));
+    EXPECT_EQ(src->ai_addr->sa_family, dst->ai_addr->sa_family);
+    inet_diag_msg msg = makeDiagMessage(src->ai_addr->sa_family, src->ai_addr, dst->ai_addr);
+    freeaddrinfo(src);
+    freeaddrinfo(dst);
+    return msg;
+}
+
+TEST_F(SockDiagTest, TestIsLoopbackSocket) {
+    inet_diag_msg msg;
+
+    msg = makeDiagMessage("127.0.0.1", "127.0.0.1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("::1", "::1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("::1", "::ffff:127.0.0.1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("192.0.2.1", "192.0.2.1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("192.0.2.1", "8.8.8.8");
+    EXPECT_FALSE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("192.0.2.1", "127.0.0.1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("2001:db8::1", "2001:db8::1");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("2001:db8::1", "2001:4860:4860::6464");
+    EXPECT_FALSE(isLoopbackSocket(&msg));
+
+    // While isLoopbackSocket returns true on these sockets, we usually don't want to close them
+    // because they aren't specific to any particular network and thus don't become unusable when
+    // an app's routing changes or its network access is removed.
+    //
+    // This isn't a problem, as anything that calls destroyLiveSockets will skip them because
+    // destroyLiveSockets only enumerates ESTABLISHED, SYN_SENT, and SYN_RECV sockets.
+    msg = makeDiagMessage("127.0.0.1", "0.0.0.0");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+
+    msg = makeDiagMessage("::1", "::");
+    EXPECT_TRUE(isLoopbackSocket(&msg));
+}
+
 enum MicroBenchmarkTestType {
     ADDRESS,
     UID,
+    UID_EXCLUDE_LOOPBACK,
     UIDRANGE,
+    UIDRANGE_EXCLUDE_LOOPBACK,
 };
 
 const char *testTypeName(MicroBenchmarkTestType mode) {
@@ -186,7 +285,9 @@ const char *testTypeName(MicroBenchmarkTestType mode) {
     switch((mode)) {
         TO_STRING_TYPE(ADDRESS);
         TO_STRING_TYPE(UID);
+        TO_STRING_TYPE(UID_EXCLUDE_LOOPBACK);
         TO_STRING_TYPE(UIDRANGE);
+        TO_STRING_TYPE(UIDRANGE_EXCLUDE_LOOPBACK);
     }
 #undef TO_STRING_TYPE
 }
@@ -214,7 +315,9 @@ protected:
         case ADDRESS:
             return ADDRESS_SOCKETS;
         case UID:
+        case UID_EXCLUDE_LOOPBACK:
         case UIDRANGE:
+        case UIDRANGE_EXCLUDE_LOOPBACK:
             return UID_SOCKETS;
         }
     }
@@ -228,16 +331,21 @@ protected:
                 EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
                 break;
             case UID:
-                ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
+            case UID_EXCLUDE_LOOPBACK: {
+                bool excludeLoopback = (mode == UID_EXCLUDE_LOOPBACK);
+                ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID, excludeLoopback);
                 EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
                         strerror(-ret);
                 break;
-            case UIDRANGE: {
+            }
+            case UIDRANGE:
+            case UIDRANGE_EXCLUDE_LOOPBACK: {
+                bool excludeLoopback = (mode == UIDRANGE_EXCLUDE_LOOPBACK);
                 const char *uidRangeStrings[] = { "8005-8012", "8042", "8043", "8090-8099" };
                 std::set<uid_t> skipUids { 8007, 8043, 8098, 8099 };
                 UidRanges uidRanges;
                 uidRanges.parseFrom(ARRAY_SIZE(uidRangeStrings), (char **) uidRangeStrings);
-                ret = mSd.destroySockets(uidRanges, skipUids);
+                ret = mSd.destroySockets(uidRanges, skipUids, excludeLoopback);
             }
         }
         return ret;
@@ -262,6 +370,9 @@ protected:
                 }
                 return false;
             }
+            case UID_EXCLUDE_LOOPBACK:
+            case UIDRANGE_EXCLUDE_LOOPBACK:
+                return false;
         }
     }
 
@@ -341,7 +452,9 @@ TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
     fprintf(stderr, "   Verifying: %6.1f ms (%d sockets destroyed)\n",
             std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count(),
             socketsClosed);
-    EXPECT_GT(socketsClosed, 0);  // Just in case there's a bug in the test.
+    if (strstr(testTypeName(mode), "_EXCLUDE_LOOPBACK") == nullptr) {
+        EXPECT_GT(socketsClosed, 0);  // Just in case there's a bug in the test.
+    }
 
     start = std::chrono::steady_clock::now();
     for (int i = 0; i < numSockets; i++) {
@@ -358,4 +471,5 @@ TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
 constexpr int SockDiagMicroBenchmarkTest::CLOSE_UID;
 
 INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest,
-                        testing::Values(ADDRESS, UID, UIDRANGE));
+                        testing::Values(ADDRESS, UID, UIDRANGE,
+                                        UID_EXCLUDE_LOOPBACK, UIDRANGE_EXCLUDE_LOOPBACK));