OSDN Git Service

Listen to xt_quota2 kobject event for quota
[android-x86/system-netd.git] / server / SockDiag.cpp
index b9f69cd..48b8eae 100644 (file)
@@ -28,6 +28,7 @@
 
 #define LOG_TAG "Netd"
 
+#include <android-base/strings.h>
 #include <cutils/log.h>
 
 #include "NetdConstants.h"
@@ -90,6 +91,45 @@ bool SockDiag::open() {
     return true;
 }
 
+int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states,
+                              iovec *iov, int iovcnt) {
+    struct {
+        nlmsghdr nlh;
+        inet_diag_req_v2 req;
+    } __attribute__((__packed__)) request = {
+        .nlh = {
+            .nlmsg_type = SOCK_DIAG_BY_FAMILY,
+            .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
+        },
+        .req = {
+            .sdiag_family = family,
+            .sdiag_protocol = proto,
+            .idiag_states = states,
+        },
+    };
+
+    size_t len = 0;
+    iov[0].iov_base = &request;
+    iov[0].iov_len = sizeof(request);
+    for (int i = 0; i < iovcnt; i++) {
+        len += iov[i].iov_len;
+    }
+    request.nlh.nlmsg_len = len;
+
+    if (writev(mSock, iov, iovcnt) != (ssize_t) len) {
+        return -errno;
+    }
+
+    return checkError(mSock);
+}
+
+int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
+    iovec iov[] = {
+        { nullptr, 0 },
+    };
+    return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
+}
+
 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
     addrinfo hints = { .ai_flags = AI_NUMERICHOST };
     addrinfo *res;
@@ -127,24 +167,12 @@ int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr
     uint8_t prefixlen = addrlen * 8;
     uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
     uint8_t nojump = yesjump + 4;
-    uint32_t states = ~(1 << TCP_TIME_WAIT);
 
     struct {
-        nlmsghdr nlh;
-        inet_diag_req_v2 req;
         nlattr nla;
         inet_diag_bc_op op;
         inet_diag_hostcond cond;
-    } __attribute__((__packed__)) request = {
-        .nlh = {
-            .nlmsg_type = SOCK_DIAG_BY_FAMILY,
-            .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
-        },
-        .req = {
-            .sdiag_family = family,
-            .sdiag_protocol = proto,
-            .idiag_states = states,
-        },
+    } __attribute__((__packed__)) attrs = {
         .nla = {
             .nla_type = INET_DIAG_REQ_BYTECODE,
         },
@@ -161,19 +189,16 @@ int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr
         },
     };
 
-    request.nlh.nlmsg_len = sizeof(request) + addrlen;
-    request.nla.nla_len = sizeof(request.nla) + sizeof(request.op) + sizeof(request.cond) + addrlen;
+    attrs.nla.nla_len = sizeof(attrs) + addrlen;
 
-    struct iovec iov[] = {
-        { &request, sizeof(request) },
+    iovec iov[] = {
+        { nullptr, 0 },
+        { &attrs, sizeof(attrs) },
         { addr, addrlen },
     };
 
-    if (writev(mSock, iov, ARRAY_SIZE(iov)) != (int) request.nlh.nlmsg_len) {
-        return -errno;
-    }
-
-    return checkError(mSock);
+    uint32_t states = ~(1 << TCP_TIME_WAIT);
+    return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
 }
 
 int SockDiag::readDiagMsg(uint8_t proto, SockDiag::DumpCallback callback) {
@@ -201,7 +226,9 @@ int SockDiag::readDiagMsg(uint8_t proto, SockDiag::DumpCallback callback) {
               }
               default:
                 inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
-                callback(proto, msg);
+                if (callback(proto, msg)) {
+                    sockDestroy(proto, msg);
+                }
             }
         }
     } while (bytesread > 0);
@@ -246,18 +273,15 @@ int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
         return ret;
     }
 
-    auto destroy = [this] (uint8_t proto, const inet_diag_msg *msg) {
-        return this->sockDestroy(proto, msg);
-    };
+    auto destroyAll = [] (uint8_t, const inet_diag_msg*) { return true; };
 
-    return readDiagMsg(proto, destroy);
+    return readDiagMsg(proto, destroyAll);
 }
 
 int SockDiag::destroySockets(const char *addrstr) {
-    using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
-
+    Stopwatch s;
     mSocketsDestroyed = 0;
-    const auto start = std::chrono::steady_clock::now();
+
     if (!strchr(addrstr, ':')) {
         if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
             ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
@@ -268,11 +292,86 @@ int SockDiag::destroySockets(const char *addrstr) {
         ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
         return ret;
     }
-    auto elapsed = std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start);
 
     if (mSocketsDestroyed > 0) {
-        ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, elapsed.count());
+        ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, s.timeTaken());
     }
 
     return mSocketsDestroyed;
 }
+
+int SockDiag::destroyLiveSockets(DumpCallback destroyFilter) {
+    int proto = IPPROTO_TCP;
+
+    for (const int family : {AF_INET, AF_INET6}) {
+        const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
+        uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
+        if (int ret = sendDumpRequest(proto, family, states)) {
+            ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
+            return ret;
+        }
+        if (int ret = readDiagMsg(proto, destroyFilter)) {
+            ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
+            return ret;
+        }
+    }
+
+    return 0;
+}
+
+int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
+    mSocketsDestroyed = 0;
+    Stopwatch s;
+
+    auto shouldDestroy = [uid] (uint8_t, const inet_diag_msg *msg) {
+        return (msg != nullptr && msg->idiag_uid == uid);
+    };
+
+    for (const int family : {AF_INET, AF_INET6}) {
+        const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
+        uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
+        if (int ret = sendDumpRequest(proto, family, states)) {
+            ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
+            return ret;
+        }
+        if (int ret = readDiagMsg(proto, shouldDestroy)) {
+            ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
+            return ret;
+        }
+    }
+
+    if (mSocketsDestroyed > 0) {
+        ALOGI("Destroyed %d sockets for UID in %.1f ms", mSocketsDestroyed, s.timeTaken());
+    }
+
+    return 0;
+}
+
+int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids) {
+    mSocketsDestroyed = 0;
+    Stopwatch s;
+
+    auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
+        return msg != nullptr &&
+               uidRanges.hasUid(msg->idiag_uid) &&
+               skipUids.find(msg->idiag_uid) == skipUids.end();
+    };
+
+    if (int ret = destroyLiveSockets(shouldDestroy)) {
+        return ret;
+    }
+
+    std::vector<uid_t> skipUidStrings;
+    for (uid_t uid : skipUids) {
+        skipUidStrings.push_back(uid);
+    }
+    std::sort(skipUidStrings.begin(), skipUidStrings.end());
+
+    if (mSocketsDestroyed > 0) {
+        ALOGI("Destroyed %d sockets for %s skip={%s} in %.1f ms",
+              mSocketsDestroyed, uidRanges.toString().c_str(),
+              android::base::Join(skipUidStrings, " ").c_str(), s.timeTaken());
+    }
+
+    return 0;
+}