OSDN Git Service

Support destroying sockets for UIDs.
authorLorenzo Colitti <lorenzo@google.com>
Thu, 24 Mar 2016 07:47:12 +0000 (16:47 +0900)
committerLorenzo Colitti <lorenzo@google.com>
Fri, 25 Mar 2016 05:17:11 +0000 (14:17 +0900)
Bug: 27824851
Change-Id: Iab5ebfd1c3d463d60d3dbd3a271737c8bc824298

server/SockDiag.cpp
server/SockDiag.h
server/SockDiagTest.cpp

index b9f69cd..57ba19c 100644 (file)
@@ -47,6 +47,20 @@ struct AddrinfoDeleter {
 
 typedef std::unique_ptr<addrinfo, AddrinfoDeleter> ScopedAddrinfo;
 
+class Stopwatch {
+public:
+    Stopwatch(): mStart(std::chrono::steady_clock::now()) {}
+    float timeTaken() {
+        using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
+        return (std::chrono::duration_cast<ms>(
+                std::chrono::steady_clock::now() - mStart)).count();
+    }
+
+private:
+    std::chrono::time_point<std::chrono::steady_clock> mStart;
+    std::string mName;
+};
+
 int checkError(int fd) {
     struct {
         nlmsghdr h;
@@ -90,6 +104,45 @@ bool SockDiag::open() {
     return true;
 }
 
+int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states,
+                              iovec *iov, int iovcnt) {
+    struct {
+        nlmsghdr nlh;
+        inet_diag_req_v2 req;
+    } __attribute__((__packed__)) request = {
+        .nlh = {
+            .nlmsg_type = SOCK_DIAG_BY_FAMILY,
+            .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
+        },
+        .req = {
+            .sdiag_family = family,
+            .sdiag_protocol = proto,
+            .idiag_states = states,
+        },
+    };
+
+    size_t len = 0;
+    iov[0].iov_base = &request;
+    iov[0].iov_len = sizeof(request);
+    for (int i = 0; i < iovcnt; i++) {
+        len += iov[i].iov_len;
+    }
+    request.nlh.nlmsg_len = len;
+
+    if (writev(mSock, iov, iovcnt) != (ssize_t) len) {
+        return -errno;
+    }
+
+    return checkError(mSock);
+}
+
+int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
+    iovec iov[] = {
+        { nullptr, 0 },
+    };
+    return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
+}
+
 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
     addrinfo hints = { .ai_flags = AI_NUMERICHOST };
     addrinfo *res;
@@ -127,24 +180,12 @@ int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr
     uint8_t prefixlen = addrlen * 8;
     uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
     uint8_t nojump = yesjump + 4;
-    uint32_t states = ~(1 << TCP_TIME_WAIT);
 
     struct {
-        nlmsghdr nlh;
-        inet_diag_req_v2 req;
         nlattr nla;
         inet_diag_bc_op op;
         inet_diag_hostcond cond;
-    } __attribute__((__packed__)) request = {
-        .nlh = {
-            .nlmsg_type = SOCK_DIAG_BY_FAMILY,
-            .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
-        },
-        .req = {
-            .sdiag_family = family,
-            .sdiag_protocol = proto,
-            .idiag_states = states,
-        },
+    } __attribute__((__packed__)) attrs = {
         .nla = {
             .nla_type = INET_DIAG_REQ_BYTECODE,
         },
@@ -161,19 +202,16 @@ int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr
         },
     };
 
-    request.nlh.nlmsg_len = sizeof(request) + addrlen;
-    request.nla.nla_len = sizeof(request.nla) + sizeof(request.op) + sizeof(request.cond) + addrlen;
+    attrs.nla.nla_len = sizeof(attrs) + addrlen;
 
-    struct iovec iov[] = {
-        { &request, sizeof(request) },
+    iovec iov[] = {
+        { nullptr, 0 },
+        { &attrs, sizeof(attrs) },
         { addr, addrlen },
     };
 
-    if (writev(mSock, iov, ARRAY_SIZE(iov)) != (int) request.nlh.nlmsg_len) {
-        return -errno;
-    }
-
-    return checkError(mSock);
+    uint32_t states = ~(1 << TCP_TIME_WAIT);
+    return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
 }
 
 int SockDiag::readDiagMsg(uint8_t proto, SockDiag::DumpCallback callback) {
@@ -254,10 +292,9 @@ int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
 }
 
 int SockDiag::destroySockets(const char *addrstr) {
-    using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
-
+    Stopwatch s;
     mSocketsDestroyed = 0;
-    const auto start = std::chrono::steady_clock::now();
+
     if (!strchr(addrstr, ':')) {
         if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
             ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
@@ -268,11 +305,42 @@ int SockDiag::destroySockets(const char *addrstr) {
         ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
         return ret;
     }
-    auto elapsed = std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start);
 
     if (mSocketsDestroyed > 0) {
-        ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, elapsed.count());
+        ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, s.timeTaken());
     }
 
     return mSocketsDestroyed;
 }
+
+int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
+    mSocketsDestroyed = 0;
+    Stopwatch s;
+
+    auto destroy = [this, uid] (uint8_t proto, const inet_diag_msg *msg) {
+        if (msg != nullptr && msg->idiag_uid == uid) {
+            return this->sockDestroy(proto, msg);
+        } else {
+            return 0;
+        }
+    };
+
+    for (const int family : {AF_INET, AF_INET6}) {
+        const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
+        uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
+        if (int ret = sendDumpRequest(proto, family, states)) {
+            ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
+            return ret;
+        }
+        if (int ret = readDiagMsg(proto, destroy)) {
+            ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
+            return ret;
+        }
+    }
+
+    if (mSocketsDestroyed > 0) {
+        ALOGI("Destroyed %d sockets for UID in %.1f ms", mSocketsDestroyed, s.timeTaken());
+    }
+
+    return 0;
+}
index 56acbdb..059a11c 100644 (file)
@@ -1,3 +1,19 @@
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 #include <functional>
 
 #include <linux/netlink.h>
@@ -22,15 +38,18 @@ class SockDiag {
     bool open();
     virtual ~SockDiag() { closeSocks(); }
 
+    int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states);
     int sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr);
     int readDiagMsg(uint8_t proto, DumpCallback callback);
     int sockDestroy(uint8_t proto, const inet_diag_msg *);
     int destroySockets(const char *addrstr);
+    int destroySockets(uint8_t proto, uid_t uid);
 
   private:
     int mSock;
     int mWriteSock;
     int mSocketsDestroyed;
+    int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states, iovec *iov, int iovcnt);
     int destroySockets(uint8_t proto, int family, const char *addrstr);
     bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
     void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
index 70e7bcf..6425c67 100644 (file)
@@ -27,6 +27,8 @@
 
 
 #define NUM_SOCKETS 500
+#define START_UID 8000  // START_UID + NUM_SOCKETS must be <= 9999.
+#define CLOSE_UID (START_UID + NUM_SOCKETS - 42) // Close to the end
 
 
 class SockDiagTest : public ::testing::Test {
@@ -178,8 +180,21 @@ TEST_F(SockDiagTest, TestDump) {
     close(accepted6);
 }
 
+enum MicroBenchmarkTestType {
+    ADDRESS,
+    UID,
+};
 
-class SockDiagMicroBenchmarkTest : public ::testing::Test {
+const char *testTypeName(MicroBenchmarkTestType mode) {
+#define TO_STRING_TYPE(x) case ((x)): return #x;
+    switch((mode)) {
+        TO_STRING_TYPE(ADDRESS);
+        TO_STRING_TYPE(UID);
+    }
+#undef TO_STRING_TYPE
+}
+
+class SockDiagMicroBenchmarkTest : public ::testing::TestWithParam<MicroBenchmarkTestType> {
 
 public:
     void SetUp() {
@@ -190,13 +205,27 @@ protected:
     SockDiag mSd;
 
     int destroySockets() {
-        const int ret = mSd.destroySockets("::1");
-        EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
+        MicroBenchmarkTestType mode = GetParam();
+        int ret;
+        if (mode == ADDRESS) {
+            ret = mSd.destroySockets("::1");
+            EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
+        } else {
+            ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
+            EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
+                    strerror(-ret);
+        }
         return ret;
     }
 
-    bool shouldHaveClosedSocket(int) {
-        return true;
+    bool shouldHaveClosedSocket(int i) {
+        MicroBenchmarkTestType mode = GetParam();
+        switch (mode) {
+        case ADDRESS:
+            return true;
+        case UID:
+            return i == CLOSE_UID - START_UID;
+        }
     }
 
     void checkSocketState(int i, int sock, const char *msg) {
@@ -219,8 +248,11 @@ protected:
     }
 };
 
-TEST_F(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
-    fprintf(stderr, "Benchmarking closing %d sockets\n", NUM_SOCKETS);
+TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
+    MicroBenchmarkTestType mode = GetParam();
+
+    fprintf(stderr, "Benchmarking closing %d sockets based on %s\n",
+            NUM_SOCKETS, testTypeName(mode));
 
     int listensocket = socket(AF_INET6, SOCK_STREAM, 0);
     ASSERT_NE(-1, listensocket) << "Failed to open listen socket";
@@ -239,6 +271,8 @@ TEST_F(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
     auto start = std::chrono::steady_clock::now();
     for (int i = 0; i < NUM_SOCKETS; i++) {
         int s = socket(AF_INET6, SOCK_STREAM, 0);
+        uid_t uid = START_UID + i;
+        ASSERT_EQ(0, fchown(s, uid, -1));
         clientlen = sizeof(client);
         ASSERT_EQ(0, connect(s, (sockaddr *) &server, sizeof(server)))
             << "Connecting socket " << i << " failed " << strerror(errno);
@@ -274,3 +308,5 @@ TEST_F(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
 
     close(listensocket);
 }
+
+INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest, testing::Values(ADDRESS, UID));