OSDN Git Service

Kill sockets when a VPN comes up.
authorLorenzo Colitti <lorenzo@google.com>
Wed, 13 Apr 2016 15:56:01 +0000 (00:56 +0900)
committerLorenzo Colitti <lorenzo@google.com>
Thu, 14 Apr 2016 07:56:11 +0000 (16:56 +0900)
1. Change the SockDiag callback function to be a filter that
   returns a bool instead of a function that optionally kills a
   socket. All existing callbacks basically only existed to kill
   sockets under certain conditions, and making them return a
   boolean allows reusing the same callback function signature
   to filter sockets as well.
2. Add a new SockDiag method to kill sockets based on a UidRanges
   object (which contains a number of UID ranges) and a list of
   users to skip.
3. Add a new UIDRANGE mode to SockDiagTest to test the above.
4. When UID ranges are added or removed from the VPN, kill
   sockets in those UID ranges unless the socket UIDs are in
   mProtectableUsers and thus their creator might have set the
   protect bit on their mark.  Short of actually being
   able to see the socket mark on each socket and basing our
   decision on that, this is the best we can do.

Bug: 26976388
Change-Id: I53a30df3feb63254a6451a29fa6041c9b679f9bb

server/Android.mk
server/NetworkController.cpp
server/SockDiag.cpp
server/SockDiag.h
server/SockDiagTest.cpp
server/UidRanges.cpp
server/UidRanges.h
server/VirtualNetwork.cpp
server/VirtualNetwork.h

index 1135c5a..f09355a 100644 (file)
@@ -129,7 +129,9 @@ LOCAL_SRC_FILES := \
         BandwidthController.cpp BandwidthControllerTest.cpp \
         FirewallControllerTest.cpp FirewallController.cpp \
         SockDiagTest.cpp SockDiag.cpp \
-        StrictController.cpp StrictControllerTest.cpp
+        StrictController.cpp StrictControllerTest.cpp \
+        UidRanges.cpp \
+
 LOCAL_MODULE_TAGS := tests
 LOCAL_SHARED_LIBRARIES := liblog libbase libcutils liblogwrap
 include $(BUILD_NATIVE_TEST)
index 8ae0324..7c2a826 100644 (file)
@@ -459,7 +459,7 @@ int NetworkController::addUsersToNetwork(unsigned netId, const UidRanges& uidRan
         ALOGE("cannot add users to non-virtual network with netId %u", netId);
         return -EINVAL;
     }
-    if (int ret = static_cast<VirtualNetwork*>(network)->addUsers(uidRanges)) {
+    if (int ret = static_cast<VirtualNetwork*>(network)->addUsers(uidRanges, mProtectableUsers)) {
         return ret;
     }
     return 0;
@@ -476,7 +476,8 @@ int NetworkController::removeUsersFromNetwork(unsigned netId, const UidRanges& u
         ALOGE("cannot remove users from non-virtual network with netId %u", netId);
         return -EINVAL;
     }
-    if (int ret = static_cast<VirtualNetwork*>(network)->removeUsers(uidRanges)) {
+    if (int ret = static_cast<VirtualNetwork*>(network)->removeUsers(uidRanges,
+                                                                     mProtectableUsers)) {
         return ret;
     }
     return 0;
index 57ba19c..6a39997 100644 (file)
@@ -28,6 +28,7 @@
 
 #define LOG_TAG "Netd"
 
+#include <android-base/strings.h>
 #include <cutils/log.h>
 
 #include "NetdConstants.h"
@@ -239,7 +240,9 @@ int SockDiag::readDiagMsg(uint8_t proto, SockDiag::DumpCallback callback) {
               }
               default:
                 inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
-                callback(proto, msg);
+                if (callback(proto, msg)) {
+                    sockDestroy(proto, msg);
+                }
             }
         }
     } while (bytesread > 0);
@@ -284,11 +287,9 @@ int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
         return ret;
     }
 
-    auto destroy = [this] (uint8_t proto, const inet_diag_msg *msg) {
-        return this->sockDestroy(proto, msg);
-    };
+    auto destroyAll = [] (uint8_t, const inet_diag_msg*) { return true; };
 
-    return readDiagMsg(proto, destroy);
+    return readDiagMsg(proto, destroyAll);
 }
 
 int SockDiag::destroySockets(const char *addrstr) {
@@ -313,16 +314,31 @@ int SockDiag::destroySockets(const char *addrstr) {
     return mSocketsDestroyed;
 }
 
+int SockDiag::destroyLiveSockets(DumpCallback destroyFilter) {
+    int proto = IPPROTO_TCP;
+
+    for (const int family : {AF_INET, AF_INET6}) {
+        const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
+        uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
+        if (int ret = sendDumpRequest(proto, family, states)) {
+            ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
+            return ret;
+        }
+        if (int ret = readDiagMsg(proto, destroyFilter)) {
+            ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
+            return ret;
+        }
+    }
+
+    return 0;
+}
+
 int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
     mSocketsDestroyed = 0;
     Stopwatch s;
 
-    auto destroy = [this, uid] (uint8_t proto, const inet_diag_msg *msg) {
-        if (msg != nullptr && msg->idiag_uid == uid) {
-            return this->sockDestroy(proto, msg);
-        } else {
-            return 0;
-        }
+    auto shouldDestroy = [uid] (uint8_t, const inet_diag_msg *msg) {
+        return (msg != nullptr && msg->idiag_uid == uid);
     };
 
     for (const int family : {AF_INET, AF_INET6}) {
@@ -332,7 +348,7 @@ int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
             ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
             return ret;
         }
-        if (int ret = readDiagMsg(proto, destroy)) {
+        if (int ret = readDiagMsg(proto, shouldDestroy)) {
             ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
             return ret;
         }
@@ -344,3 +360,32 @@ int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
 
     return 0;
 }
+
+int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids) {
+    mSocketsDestroyed = 0;
+    Stopwatch s;
+
+    auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
+        return msg != nullptr &&
+               uidRanges.hasUid(msg->idiag_uid) &&
+               skipUids.find(msg->idiag_uid) == skipUids.end();
+    };
+
+    if (int ret = destroyLiveSockets(shouldDestroy)) {
+        return ret;
+    }
+
+    std::vector<uid_t> skipUidStrings;
+    for (uid_t uid : skipUids) {
+        skipUidStrings.push_back(uid);
+    }
+    std::sort(skipUidStrings.begin(), skipUidStrings.end());
+
+    if (mSocketsDestroyed > 0) {
+        ALOGI("Destroyed %d sockets for %s skip={%s} in %.1f ms",
+              mSocketsDestroyed, uidRanges.toString().c_str(),
+              android::base::Join(skipUidStrings, " ").c_str(), s.timeTaken());
+    }
+
+    return 0;
+}
index 059a11c..6a4c703 100644 (file)
  * limitations under the License.
  */
 
-#include <functional>
+#include <unistd.h>
+#include <sys/socket.h>
 
 #include <linux/netlink.h>
 #include <linux/sock_diag.h>
 #include <linux/inet_diag.h>
 
+#include <functional>
+#include <set>
+
+#include "UidRanges.h"
+
 struct inet_diag_msg;
 class SockDiagTest;
 
@@ -27,7 +33,10 @@ class SockDiag {
 
   public:
     static const int kBufferSize = 4096;
-    typedef std::function<int(uint8_t proto, const inet_diag_msg *)> DumpCallback;
+
+    // Callback function that is called once for every socket in the dump. A return value of true
+    // means destroy the socket.
+    typedef std::function<bool(uint8_t proto, const inet_diag_msg *)> DumpCallback;
 
     struct DestroyRequest {
         nlmsghdr nlh;
@@ -44,6 +53,7 @@ class SockDiag {
     int sockDestroy(uint8_t proto, const inet_diag_msg *);
     int destroySockets(const char *addrstr);
     int destroySockets(uint8_t proto, uid_t uid);
+    int destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids);
 
   private:
     int mSock;
@@ -51,6 +61,7 @@ class SockDiag {
     int mSocketsDestroyed;
     int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states, iovec *iov, int iovcnt);
     int destroySockets(uint8_t proto, int family, const char *addrstr);
+    int destroyLiveSockets(DumpCallback destroy);
     bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
     void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
 };
index 6425c67..2061a3b 100644 (file)
 
 #include "NetdConstants.h"
 #include "SockDiag.h"
-
-
-#define NUM_SOCKETS 500
-#define START_UID 8000  // START_UID + NUM_SOCKETS must be <= 9999.
-#define CLOSE_UID (START_UID + NUM_SOCKETS - 42) // Close to the end
-
+#include "UidRanges.h"
 
 class SockDiagTest : public ::testing::Test {
 };
@@ -104,7 +99,7 @@ TEST_F(SockDiagTest, TestDump) {
         if (msg == nullptr) {
             EXPECT_FALSE(seenNull);
             seenNull = true;
-            return 0;
+            return false;
         }
         EXPECT_EQ(htonl(INADDR_LOOPBACK), msg->id.idiag_src[0]);
         v4SocketsSeen++;
@@ -115,7 +110,7 @@ TEST_F(SockDiagTest, TestDump) {
                 src, htons(msg->id.idiag_sport),
                 dst, htons(msg->id.idiag_dport),
                 tcpStateName(msg->idiag_state));
-        return 0;
+        return false;
     };
 
     int v6SocketsSeen = 0;
@@ -125,7 +120,7 @@ TEST_F(SockDiagTest, TestDump) {
         if (msg == nullptr) {
             EXPECT_FALSE(seenNull);
             seenNull = true;
-            return 0;
+            return false;
         }
         struct in6_addr *saddr = (struct in6_addr *) msg->id.idiag_src;
         EXPECT_TRUE(
@@ -141,7 +136,7 @@ TEST_F(SockDiagTest, TestDump) {
                 src, htons(msg->id.idiag_sport),
                 dst, htons(msg->id.idiag_dport),
                 tcpStateName(msg->idiag_state));
-        return 0;
+        return false;
     };
 
     SockDiag sd;
@@ -183,6 +178,7 @@ TEST_F(SockDiagTest, TestDump) {
 enum MicroBenchmarkTestType {
     ADDRESS,
     UID,
+    UIDRANGE,
 };
 
 const char *testTypeName(MicroBenchmarkTestType mode) {
@@ -190,6 +186,7 @@ const char *testTypeName(MicroBenchmarkTestType mode) {
     switch((mode)) {
         TO_STRING_TYPE(ADDRESS);
         TO_STRING_TYPE(UID);
+        TO_STRING_TYPE(UIDRANGE);
     }
 #undef TO_STRING_TYPE
 }
@@ -204,16 +201,44 @@ public:
 protected:
     SockDiag mSd;
 
+    constexpr static int MAX_SOCKETS = 500;
+    constexpr static int ADDRESS_SOCKETS = 500;
+    constexpr static int UID_SOCKETS = 100;
+    constexpr static uid_t START_UID = 8000;  // START_UID + number of sockets must be <= 9999.
+    constexpr static int CLOSE_UID = START_UID + UID_SOCKETS - 42;  // Close to the end
+    static_assert(START_UID + MAX_SOCKETS < 9999, "Too many sockets");
+
+    int howManySockets() {
+        MicroBenchmarkTestType mode = GetParam();
+        switch (mode) {
+        case ADDRESS:
+            return 500;
+        case UID:
+        case UIDRANGE:
+            return 50;
+        }
+    }
+
     int destroySockets() {
         MicroBenchmarkTestType mode = GetParam();
         int ret;
-        if (mode == ADDRESS) {
-            ret = mSd.destroySockets("::1");
-            EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
-        } else {
-            ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
-            EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
-                    strerror(-ret);
+        switch (mode) {
+            case ADDRESS:
+                ret = mSd.destroySockets("::1");
+                EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
+                break;
+            case UID:
+                ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
+                EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
+                        strerror(-ret);
+                break;
+            case UIDRANGE: {
+                const char *uidRangeStrings[] = { "8005-8012", "8042", "8043", "8090-8099" };
+                std::set<uid_t> skipUids { 8007, 8043, 8098, 8099 };
+                UidRanges uidRanges;
+                uidRanges.parseFrom(ARRAY_SIZE(uidRangeStrings), (char **) uidRangeStrings);
+                ret = mSd.destroySockets(uidRanges, skipUids);
+            }
         }
         return ret;
     }
@@ -221,10 +246,22 @@ protected:
     bool shouldHaveClosedSocket(int i) {
         MicroBenchmarkTestType mode = GetParam();
         switch (mode) {
-        case ADDRESS:
-            return true;
-        case UID:
-            return i == CLOSE_UID - START_UID;
+            case ADDRESS:
+                return true;
+            case UID:
+                return i == CLOSE_UID - START_UID;
+            case UIDRANGE: {
+                uid_t uid = i + START_UID;
+                // Skip UIDs in skipUids.
+                if (uid == 8007 || uid == 8043 || uid == 8098 || uid == 8099) {
+                    return false;
+                }
+                // Include UIDs in uidRanges.
+                if ((8005 <= uid && uid <= 8012) || uid == 8042 || (8090 <= uid && uid <= 8099)) {
+                    return true;
+                }
+                return false;
+            }
         }
     }
 
@@ -251,8 +288,10 @@ protected:
 TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
     MicroBenchmarkTestType mode = GetParam();
 
+    int numSockets = howManySockets();
+
     fprintf(stderr, "Benchmarking closing %d sockets based on %s\n",
-            NUM_SOCKETS, testTypeName(mode));
+            numSockets, testTypeName(mode));
 
     int listensocket = socket(AF_INET6, SOCK_STREAM, 0);
     ASSERT_NE(-1, listensocket) << "Failed to open listen socket";
@@ -263,13 +302,13 @@ TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
 
     using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
 
-    int clientsockets[NUM_SOCKETS], serversockets[NUM_SOCKETS];
-    uint16_t clientports[NUM_SOCKETS];
+    int clientsockets[MAX_SOCKETS], serversockets[MAX_SOCKETS];
+    uint16_t clientports[MAX_SOCKETS];
     sockaddr_in6 client;
     socklen_t clientlen;
 
     auto start = std::chrono::steady_clock::now();
-    for (int i = 0; i < NUM_SOCKETS; i++) {
+    for (int i = 0; i < numSockets; i++) {
         int s = socket(AF_INET6, SOCK_STREAM, 0);
         uid_t uid = START_UID + i;
         ASSERT_EQ(0, fchown(s, uid, -1));
@@ -291,7 +330,7 @@ TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
             std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count());
 
     start = std::chrono::steady_clock::now();
-    for (int i = 0; i < NUM_SOCKETS; i++) {
+    for (int i = 0; i < numSockets; i++) {
         checkSocketState(i, clientsockets[i], "Client socket");
         checkSocketState(i, serversockets[i], "Server socket");
     }
@@ -299,7 +338,7 @@ TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
             std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count());
 
     start = std::chrono::steady_clock::now();
-    for (int i = 0; i < NUM_SOCKETS; i++) {
+    for (int i = 0; i < numSockets; i++) {
         close(clientsockets[i]);
         close(serversockets[i]);
     }
@@ -309,4 +348,8 @@ TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
     close(listensocket);
 }
 
-INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest, testing::Values(ADDRESS, UID));
+// "SockDiagTest.cpp:232: error: undefined reference to 'SockDiagMicroBenchmarkTest::CLOSE_UID'".
+constexpr int SockDiagMicroBenchmarkTest::CLOSE_UID;
+
+INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest,
+                        testing::Values(ADDRESS, UID, UIDRANGE));
index 10e445a..64c1b45 100644 (file)
 
 #include <stdlib.h>
 
+#include <android-base/stringprintf.h>
+
+using android::base::StringAppendF;
+
 bool UidRanges::hasUid(uid_t uid) const {
     auto iter = std::lower_bound(mRanges.begin(), mRanges.end(), Range(uid, uid));
     return (iter != mRanges.end() && iter->first == uid) ||
@@ -81,3 +85,16 @@ void UidRanges::remove(const UidRanges& other) {
                                    other.mRanges.end(), mRanges.begin());
     mRanges.erase(end, mRanges.end());
 }
+
+std::string UidRanges::toString() const {
+    std::string s("UidRanges{ ");
+    for (Range range : mRanges) {
+        if (range.first != range.second) {
+            StringAppendF(&s, "%u-%u ", range.first, range.second);
+        } else {
+            StringAppendF(&s, "%u ", range.first);
+        }
+    }
+    StringAppendF(&s, "}");
+    return s;
+}
index 044a8f9..2a39953 100644 (file)
@@ -29,6 +29,7 @@ public:
     const std::vector<Range>& getRanges() const;
 
     bool parseFrom(int argc, char* argv[]);
+    std::string toString() const;
 
     void add(const UidRanges& other);
     void remove(const UidRanges& other);
index 5db3645..3d83703 100644 (file)
  * limitations under the License.
  */
 
+#include <set>
 #include "VirtualNetwork.h"
 
+#include "SockDiag.h"
 #include "RouteController.h"
 
 #define LOG_TAG "Netd"
@@ -40,7 +42,30 @@ bool VirtualNetwork::appliesToUser(uid_t uid) const {
     return mUidRanges.hasUid(uid);
 }
 
-int VirtualNetwork::addUsers(const UidRanges& uidRanges) {
+
+int VirtualNetwork::maybeCloseSockets(bool add, const UidRanges& uidRanges,
+                                      const std::set<uid_t>& protectableUsers) {
+    if (!mSecure) {
+        return 0;
+    }
+
+    SockDiag sd;
+    if (!sd.open()) {
+        return -EBADFD;
+    }
+
+    if (int ret = sd.destroySockets(uidRanges, protectableUsers)) {
+        ALOGE("Failed to close sockets while %s %s to network %d: %s",
+              add ? "adding" : "removing", uidRanges.toString().c_str(), mNetId, strerror(-ret));
+        return ret;
+    }
+
+    return 0;
+}
+
+int VirtualNetwork::addUsers(const UidRanges& uidRanges, const std::set<uid_t>& protectableUsers) {
+    maybeCloseSockets(true, uidRanges, protectableUsers);
+
     for (const std::string& interface : mInterfaces) {
         if (int ret = RouteController::addUsersToVirtualNetwork(mNetId, interface.c_str(), mSecure,
                                                                 uidRanges)) {
@@ -52,7 +77,10 @@ int VirtualNetwork::addUsers(const UidRanges& uidRanges) {
     return 0;
 }
 
-int VirtualNetwork::removeUsers(const UidRanges& uidRanges) {
+int VirtualNetwork::removeUsers(const UidRanges& uidRanges,
+                                const std::set<uid_t>& protectableUsers) {
+    maybeCloseSockets(false, uidRanges, protectableUsers);
+
     for (const std::string& interface : mInterfaces) {
         if (int ret = RouteController::removeUsersFromVirtualNetwork(mNetId, interface.c_str(),
                                                                      mSecure, uidRanges)) {
index d315f97..1a6a136 100644 (file)
@@ -17,6 +17,8 @@
 #ifndef NETD_SERVER_VIRTUAL_NETWORK_H
 #define NETD_SERVER_VIRTUAL_NETWORK_H
 
+#include <set>
+
 #include "Network.h"
 #include "UidRanges.h"
 
@@ -36,13 +38,17 @@ public:
     bool isSecure() const;
     bool appliesToUser(uid_t uid) const;
 
-    int addUsers(const UidRanges& uidRanges) WARN_UNUSED_RESULT;
-    int removeUsers(const UidRanges& uidRanges) WARN_UNUSED_RESULT;
+    int addUsers(const UidRanges& uidRanges,
+                 const std::set<uid_t>& protectableUsers) WARN_UNUSED_RESULT;
+    int removeUsers(const UidRanges& uidRanges,
+                    const std::set<uid_t>& protectableUsers) WARN_UNUSED_RESULT;
 
 private:
     Type getType() const override;
     int addInterface(const std::string& interface) override WARN_UNUSED_RESULT;
     int removeInterface(const std::string& interface) override WARN_UNUSED_RESULT;
+    int maybeCloseSockets(bool add, const UidRanges& uidRanges,
+                          const std::set<uid_t>& protectableUsers);
 
     const bool mHasDns;
     const bool mSecure;