OSDN Git Service

Add full support for UIDs in VPNs.
authorSreeram Ramachandran <sreeram@google.com>
Sun, 6 Jul 2014 00:15:14 +0000 (17:15 -0700)
committerSreeram Ramachandran <sreeram@google.com>
Mon, 7 Jul 2014 23:20:18 +0000 (16:20 -0700)
Major:
+ Implement the functions mentioned in http://go/android-multinetwork-routing
  correctly, including handling accept(), connect(), setNetworkForSocket()
  and protect() and supporting functions like canUserSelectNetwork().
+ Eliminate the old code path of getting/setting UID ranges through
  SecondaryTableController (which is currently unused) and mUidMap.

Minor:
+ Rename some methods/variables for clarity and consistency.
+ Moved some methods in .cpp files to match declaration order in the .h files.

Bug: 15409918
Change-Id: Ic6ce3646c58cf645db0d9a53cbeefdd7ffafff93

19 files changed:
server/ClatdController.cpp
server/CommandListener.cpp
server/DnsProxyListener.cpp
server/DnsProxyListener.h
server/FwmarkServer.cpp
server/FwmarkServer.h
server/NatController.cpp
server/Network.cpp
server/Network.h
server/NetworkController.cpp
server/NetworkController.h
server/PhysicalNetwork.h
server/RouteController.cpp
server/SecondaryTableController.cpp
server/SecondaryTableController.h
server/UidRanges.cpp
server/UidRanges.h
server/VirtualNetwork.cpp
server/VirtualNetwork.h

index ca6908c..bcb01ba 100644 (file)
@@ -57,7 +57,7 @@ int ClatdController::startClatd(char *interface) {
 
     if (!pid) {
         // Pass in the interface, a netid to use for DNS lookups, and a fwmark for outgoing packets.
-        unsigned netId = mNetCtrl->getNetworkId(interface);
+        unsigned netId = mNetCtrl->getNetworkForInterface(interface);
         char netIdString[UINT32_STRLEN];
         snprintf(netIdString, sizeof(netIdString), "%u", netId);
 
index df2df46..b7723d4 100644 (file)
@@ -299,7 +299,6 @@ int CommandListener::InterfaceCmd::runCommand(SocketClient *cli,
         // interface route add/remove iface default/secondary dest    prefix gateway
         // interface fwmark  rule  add/remove    iface
         // interface fwmark  route add/remove    iface        dest    prefix
-        // interface fwmark  uid   add/remove    iface      uid_start uid_end forward_dns
         // interface fwmark exempt add/remove    dest
         // interface fwmark  get     protect
         // interface fwmark  get     mark        uid
@@ -357,33 +356,6 @@ int CommandListener::InterfaceCmd::runCommand(SocketClient *cli,
                             false);
                 }
                 return 0;
-
-            } else if (!strcmp(argv[2], "uid")) {
-                if (argc < 8) {
-                    cli->sendMsg(ResponseCode::CommandSyntaxError, "Missing argument", false);
-                    return 0;
-                }
-                if (!strcmp(argv[3], "add")) {
-                    if (!sSecondaryTableCtrl->addUidRule(argv[4], atoi(argv[5]), atoi(argv[6]),
-                            atoi(argv[7]))) {
-                        cli->sendMsg(ResponseCode::CommandOkay, "uid rule successfully added",
-                                false);
-                    } else {
-                        cli->sendMsg(ResponseCode::OperationFailed, "Failed to add uid rule", true);
-                    }
-                } else if (!strcmp(argv[3], "remove")) {
-                    if (!sSecondaryTableCtrl->removeUidRule(argv[4],
-                                atoi(argv[5]), atoi(argv[6]))) {
-                        cli->sendMsg(ResponseCode::CommandOkay, "uid rule successfully removed",
-                                false);
-                    } else {
-                        cli->sendMsg(ResponseCode::OperationFailed, "Failed to remove uid rule",
-                                true);
-                    }
-                } else {
-                    cli->sendMsg(ResponseCode::CommandSyntaxError, "Unknown uid cmd", false);
-                }
-                return 0;
             } else if (!strcmp(argv[2], "exempt")) {
                 if (argc < 5) {
                     cli->sendMsg(ResponseCode::CommandSyntaxError, "Missing argument", false);
@@ -1640,17 +1612,18 @@ int CommandListener::NetworkCommand::runCommand(SocketClient* client, int argc,
     //    0      1       2         3
     // network create <netId> [permission]
     //
-    //    0      1       2     3
-    // network create <netId> vpn
+    //    0      1       2     3     4
+    // network create <netId> vpn <hasDns>
     if (!strcmp(argv[1], "create")) {
         if (argc < 3) {
             return syntaxError(client, "Missing argument");
         }
         // strtoul() returns 0 on errors, which is fine because 0 is an invalid netId.
         unsigned netId = strtoul(argv[2], NULL, 0);
-        if (argc == 4 && !strcmp(argv[3], "vpn")) {
-            if (int ret = sNetCtrl->createVpn(netId)) {
-                return operationError(client, "createVpn() failed", ret);
+        if (argc == 5 && !strcmp(argv[3], "vpn")) {
+            bool hasDns = atoi(argv[4]);
+            if (int ret = sNetCtrl->createVirtualNetwork(netId, hasDns)) {
+                return operationError(client, "createVirtualNetwork() failed", ret);
             }
         } else if (argc > 4) {
             return syntaxError(client, "Unknown trailing argument(s)");
@@ -1662,8 +1635,8 @@ int CommandListener::NetworkCommand::runCommand(SocketClient* client, int argc,
                     return syntaxError(client, "Unknown permission");
                 }
             }
-            if (int ret = sNetCtrl->createNetwork(netId, permission)) {
-                return operationError(client, "createNetwork() failed", ret);
+            if (int ret = sNetCtrl->createPhysicalNetwork(netId, permission)) {
+                return operationError(client, "createPhysicalNetwork() failed", ret);
             }
         }
         return success(client);
index 3fcb5bd..c88e788 100644 (file)
@@ -48,13 +48,11 @@ DnsProxyListener::DnsProxyListener(const NetworkController* netCtrl) :
     registerCmd(new GetHostByNameCmd(this));
 }
 
-uint32_t DnsProxyListener::calcMark(SocketClient *c, unsigned netId) const {
+uint32_t DnsProxyListener::calcMark(unsigned netId) const {
     Fwmark fwmark;
     fwmark.netId = netId;
-    // If netd's UID is forced into a VPN that isn't the intended network,
-    // use VPN protect bit to force it into the desired network.
-    fwmark.protectedFromVpn = mNetCtrl->getNetwork(getuid(), netId, true) != netId;
-    fwmark.permission = mNetCtrl->getPermissionForUser(c->getUid());
+    fwmark.protectedFromVpn = true;
+    fwmark.permission = PERMISSION_SYSTEM;
     return fwmark.intValue;
 }
 
@@ -204,8 +202,8 @@ int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
     unsigned netId = strtoul(argv[7], NULL, 10);
     uid_t uid = cli->getUid();
 
-    netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true);
-    uint32_t mark = mDnsProxyListener->calcMark(cli, netId);
+    netId = mDnsProxyListener->mNetCtrl->getNetworkForUser(uid, netId, true);
+    uint32_t mark = mDnsProxyListener->calcMark(netId);
 
     if (ai_flags != -1 || ai_family != -1 ||
         ai_socktype != -1 || ai_protocol != -1) {
@@ -273,8 +271,8 @@ int DnsProxyListener::GetHostByNameCmd::runCommand(SocketClient *cli,
         name = strdup(name);
     }
 
-    netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true);
-    uint32_t mark = mDnsProxyListener->calcMark(cli, netId);
+    netId = mDnsProxyListener->mNetCtrl->getNetworkForUser(uid, netId, true);
+    uint32_t mark = mDnsProxyListener->calcMark(netId);
 
     cli->incRef();
     DnsProxyListener::GetHostByNameHandler* handler =
@@ -389,8 +387,8 @@ int DnsProxyListener::GetHostByAddrCmd::runCommand(SocketClient *cli,
         return -1;
     }
 
-    netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true);
-    uint32_t mark = mDnsProxyListener->calcMark(cli, netId);
+    netId = mDnsProxyListener->mNetCtrl->getNetworkForUser(uid, netId, true);
+    uint32_t mark = mDnsProxyListener->calcMark(netId);
 
     cli->incRef();
     DnsProxyListener::GetHostByAddrHandler* handler =
index f5624e8..5862ac7 100644 (file)
@@ -126,7 +126,7 @@ private:
     };
 
     // Calculate the socket mark to use for a DNS resolution.
-    uint32_t calcMark(SocketClient *c, unsigned netId) const;
+    uint32_t calcMark(unsigned netId) const;
 };
 
 #endif
index e2d2079..3a540bd 100644 (file)
@@ -29,10 +29,10 @@ FwmarkServer::FwmarkServer(NetworkController* networkController) :
 }
 
 bool FwmarkServer::onDataAvailable(SocketClient* client) {
-    int fd = -1;
-    int error = processClient(client, &fd);
-    if (fd >= 0) {
-        close(fd);
+    int socketFd = -1;
+    int error = processClient(client, &socketFd);
+    if (socketFd >= 0) {
+        close(socketFd);
     }
 
     // Always send a response even if there were connection errors or read errors, so that we don't
@@ -45,7 +45,7 @@ bool FwmarkServer::onDataAvailable(SocketClient* client) {
     return false;
 }
 
-int FwmarkServer::processClient(SocketClient* client, int* fd) {
+int FwmarkServer::processClient(SocketClient* client, int* socketFd) {
     FwmarkCommand command;
 
     iovec iov;
@@ -59,7 +59,7 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) {
 
     union {
         cmsghdr cmh;
-        char cmsg[CMSG_SPACE(sizeof(*fd))];
+        char cmsg[CMSG_SPACE(sizeof(*socketFd))];
     } cmsgu;
 
     memset(cmsgu.cmsg, 0, sizeof(cmsgu.cmsg));
@@ -77,17 +77,17 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) {
 
     cmsghdr* const cmsgh = CMSG_FIRSTHDR(&message);
     if (cmsgh && cmsgh->cmsg_level == SOL_SOCKET && cmsgh->cmsg_type == SCM_RIGHTS &&
-        cmsgh->cmsg_len == CMSG_LEN(sizeof(*fd))) {
-        memcpy(fd, CMSG_DATA(cmsgh), sizeof(*fd));
+        cmsgh->cmsg_len == CMSG_LEN(sizeof(*socketFd))) {
+        memcpy(socketFd, CMSG_DATA(cmsgh), sizeof(*socketFd));
     }
 
-    if (*fd < 0) {
+    if (*socketFd < 0) {
         return -EBADF;
     }
 
     Fwmark fwmark;
     socklen_t fwmarkLen = sizeof(fwmark.intValue);
-    if (getsockopt(*fd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
+    if (getsockopt(*socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue, &fwmarkLen) == -1) {
         return -errno;
     }
 
@@ -114,27 +114,23 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) {
             fwmark.netId = command.netId;
             if (command.netId == NETID_UNSET) {
                 fwmark.explicitlySelected = false;
-            } else {
+                fwmark.protectedFromVpn = false;
+                permission = PERMISSION_NONE;
+            } else if (mNetworkController->canUserSelectNetwork(client->getUid(), command.netId)) {
                 fwmark.explicitlySelected = true;
-                // If the socket already has the protectedFromVpn bit set, don't reset it, because
-                // non-system apps (e.g.: VpnService) may also protect sockets.
-                if ((permission & PERMISSION_SYSTEM) == PERMISSION_SYSTEM) {
-                    fwmark.protectedFromVpn = true;
-                }
-                if (!mNetworkController->isValidNetwork(command.netId)) {
-                    return -ENONET;
-                }
-                if (!mNetworkController->isUserPermittedOnNetwork(client->getUid(),
-                                                                  command.netId)) {
-                    return -EPERM;
-                }
+                fwmark.protectedFromVpn = mNetworkController->canProtect(client->getUid());
+            } else {
+                return -EPERM;
             }
             break;
         }
 
         case FwmarkCommand::PROTECT_FROM_VPN: {
-            // set vpn protect
-            // TODO
+            if (!mNetworkController->canProtect(client->getUid())) {
+                return -EPERM;
+            }
+            fwmark.protectedFromVpn = true;
+            permission = static_cast<Permission>(permission | fwmark.permission);
             break;
         }
 
@@ -146,7 +142,8 @@ int FwmarkServer::processClient(SocketClient* client, int* fd) {
 
     fwmark.permission = permission;
 
-    if (setsockopt(*fd, SOL_SOCKET, SO_MARK, &fwmark.intValue, sizeof(fwmark.intValue)) == -1) {
+    if (setsockopt(*socketFd, SOL_SOCKET, SO_MARK, &fwmark.intValue,
+                   sizeof(fwmark.intValue)) == -1) {
         return -errno;
     }
 
index 54cbc74..12096be 100644 (file)
@@ -17,7 +17,7 @@
 #ifndef NETD_SERVER_FWMARK_SERVER_H
 #define NETD_SERVER_FWMARK_SERVER_H
 
-#include <sysutils/SocketListener.h>
+#include "sysutils/SocketListener.h"
 
 class NetworkController;
 
@@ -30,7 +30,7 @@ private:
     bool onDataAvailable(SocketClient* client);
 
     // Returns 0 on success or a negative errno value on failure.
-    int processClient(SocketClient* client, int* fd);
+    int processClient(SocketClient* client, int* socketFd);
 
     NetworkController* const mNetworkController;
 };
index 44b8b4a..6c066f8 100644 (file)
@@ -135,7 +135,7 @@ int NatController::setDefaults() {
 }
 
 int NatController::routesOp(bool add, const char *intIface, const char *extIface, char **argv, int addrCount) {
-    unsigned netId = mNetCtrl->getNetworkId(extIface);
+    unsigned netId = mNetCtrl->getNetworkForInterface(extIface);
     int ret = 0;
 
     for (int i = 0; i < addrCount; i++) {
index d22f42d..5104de2 100644 (file)
@@ -25,6 +25,10 @@ Network::~Network() {
     }
 }
 
+unsigned Network::getNetId() const {
+    return mNetId;
+}
+
 bool Network::hasInterface(const std::string& interface) const {
     return mInterfaces.find(interface) != mInterfaces.end();
 }
index b10cb17..f72cebb 100644 (file)
@@ -36,6 +36,7 @@ public:
     virtual ~Network();
 
     virtual Type getType() const = 0;
+    unsigned getNetId() const;
 
     bool hasInterface(const std::string& interface) const;
 
index 03c22be..1487b72 100644 (file)
@@ -90,57 +90,17 @@ int NetworkController::setDefaultNetwork(unsigned netId) {
     return 0;
 }
 
-bool NetworkController::setNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId,
-                                              bool forwardDns) {
-    if (uidStart > uidEnd || !isValidNetwork(netId)) {
-        errno = EINVAL;
-        return false;
-    }
-
-    android::RWLock::AutoWLock lock(mRWLock);
-    for (UidEntry& entry : mUidMap) {
-        if (entry.uidStart == uidStart && entry.uidEnd == uidEnd && entry.netId == netId) {
-            entry.forwardDns = forwardDns;
-            return true;
-        }
-    }
-
-    mUidMap.push_front(UidEntry(uidStart, uidEnd, netId, forwardDns));
-    return true;
-}
-
-bool NetworkController::clearNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId) {
-    if (uidStart > uidEnd || !isValidNetwork(netId)) {
-        errno = EINVAL;
-        return false;
-    }
-
-    android::RWLock::AutoWLock lock(mRWLock);
-    for (auto iter = mUidMap.begin(); iter != mUidMap.end(); ++iter) {
-        if (iter->uidStart == uidStart && iter->uidEnd == uidEnd && iter->netId == netId) {
-            mUidMap.erase(iter);
-            return true;
-        }
-    }
-
-    errno = ENOENT;
-    return false;
-}
-
-unsigned NetworkController::getNetwork(uid_t uid, unsigned requestedNetId, bool forDns) const {
+unsigned NetworkController::getNetworkForUser(uid_t uid, unsigned requestedNetId,
+                                              bool forDns) const {
     android::RWLock::AutoRLock lock(mRWLock);
-    for (const UidEntry& entry : mUidMap) {
-        if (entry.uidStart <= uid && uid <= entry.uidEnd) {
-            if (forDns && !entry.forwardDns) {
-                break;
-            }
-            return entry.netId;
-        }
+    VirtualNetwork* virtualNetwork = getVirtualNetworkForUserLocked(uid);
+    if (virtualNetwork && (!forDns || virtualNetwork->getHasDns())) {
+        return virtualNetwork->getNetId();
     }
     return getNetworkLocked(requestedNetId) ? requestedNetId : mDefaultNetId;
 }
 
-unsigned NetworkController::getNetworkId(const char* interface) const {
+unsigned NetworkController::getNetworkForInterface(const char* interface) const {
     android::RWLock::AutoRLock lock(mRWLock);
     for (const auto& entry : mNetworks) {
         if (entry.second->hasInterface(interface)) {
@@ -150,12 +110,7 @@ unsigned NetworkController::getNetworkId(const char* interface) const {
     return NETID_UNSET;
 }
 
-bool NetworkController::isValidNetwork(unsigned netId) const {
-    android::RWLock::AutoRLock lock(mRWLock);
-    return getNetworkLocked(netId);
-}
-
-int NetworkController::createNetwork(unsigned netId, Permission permission) {
+int NetworkController::createPhysicalNetwork(unsigned netId, Permission permission) {
     if (netId < MIN_NET_ID || netId > MAX_NET_ID) {
         ALOGE("invalid netId %u", netId);
         return -EINVAL;
@@ -178,7 +133,7 @@ int NetworkController::createNetwork(unsigned netId, Permission permission) {
     return 0;
 }
 
-int NetworkController::createVpn(unsigned netId) {
+int NetworkController::createVirtualNetwork(unsigned netId, bool hasDns) {
     if (netId < MIN_NET_ID || netId > MAX_NET_ID) {
         ALOGE("invalid netId %u", netId);
         return -EINVAL;
@@ -190,7 +145,7 @@ int NetworkController::createVpn(unsigned netId) {
     }
 
     android::RWLock::AutoWLock lock(mRWLock);
-    mNetworks[netId] = new VirtualNetwork(netId);
+    mNetworks[netId] = new VirtualNetwork(netId, hasDns);
     return 0;
 }
 
@@ -226,7 +181,7 @@ int NetworkController::addInterfaceToNetwork(unsigned netId, const char* interfa
         return -EINVAL;
     }
 
-    unsigned existingNetId = getNetworkId(interface);
+    unsigned existingNetId = getNetworkForInterface(interface);
     if (existingNetId != NETID_UNSET && existingNetId != netId) {
         ALOGE("interface %s already assigned to netId %u", interface, existingNetId);
         return -EBUSY;
@@ -259,18 +214,23 @@ void NetworkController::setPermissionForUsers(Permission permission,
     }
 }
 
-// TODO: Handle VPNs.
-bool NetworkController::isUserPermittedOnNetwork(uid_t uid, unsigned netId) const {
-    if (uid == INVALID_UID || netId == NETID_UNSET) {
-        return false;
-    }
-
+bool NetworkController::canUserSelectNetwork(uid_t uid, unsigned netId) const {
     android::RWLock::AutoRLock lock(mRWLock);
     Network* network = getNetworkLocked(netId);
-    if (!network || network->getType() != Network::PHYSICAL) {
+    if (!network || uid == INVALID_UID) {
         return false;
     }
     Permission userPermission = getPermissionForUserLocked(uid);
+    if ((userPermission & PERMISSION_SYSTEM) == PERMISSION_SYSTEM) {
+        return true;
+    }
+    if (network->getType() == Network::VIRTUAL) {
+        return static_cast<VirtualNetwork*>(network)->appliesToUser(uid);
+    }
+    VirtualNetwork* virtualNetwork = getVirtualNetworkForUserLocked(uid);
+    if (virtualNetwork && mProtectableUsers.find(uid) == mProtectableUsers.end()) {
+        return false;
+    }
     Permission networkPermission = static_cast<PhysicalNetwork*>(network)->getPermission();
     return (userPermission & networkPermission) == networkPermission;
 }
@@ -330,6 +290,12 @@ int NetworkController::removeRoute(unsigned netId, const char* interface, const
     return modifyRoute(netId, interface, destination, nexthop, false, legacy, uid);
 }
 
+bool NetworkController::canProtect(uid_t uid) const {
+    android::RWLock::AutoRLock lock(mRWLock);
+    return ((getPermissionForUserLocked(uid) & PERMISSION_SYSTEM) == PERMISSION_SYSTEM) ||
+           mProtectableUsers.find(uid) != mProtectableUsers.end();
+}
+
 void NetworkController::allowProtect(const std::vector<uid_t>& uids) {
     android::RWLock::AutoWLock lock(mRWLock);
     mProtectableUsers.insert(uids.begin(), uids.end());
@@ -342,11 +308,28 @@ void NetworkController::denyProtect(const std::vector<uid_t>& uids) {
     }
 }
 
+bool NetworkController::isValidNetwork(unsigned netId) const {
+    android::RWLock::AutoRLock lock(mRWLock);
+    return getNetworkLocked(netId);
+}
+
 Network* NetworkController::getNetworkLocked(unsigned netId) const {
     auto iter = mNetworks.find(netId);
     return iter == mNetworks.end() ? NULL : iter->second;
 }
 
+VirtualNetwork* NetworkController::getVirtualNetworkForUserLocked(uid_t uid) const {
+    for (const auto& entry : mNetworks) {
+        if (entry.second->getType() == Network::VIRTUAL) {
+            VirtualNetwork* virtualNetwork = static_cast<VirtualNetwork*>(entry.second);
+            if (virtualNetwork->appliesToUser(uid)) {
+                return virtualNetwork;
+            }
+        }
+    }
+    return NULL;
+}
+
 Permission NetworkController::getPermissionForUserLocked(uid_t uid) const {
     auto iter = mUsers.find(uid);
     if (iter != mUsers.end()) {
@@ -357,7 +340,7 @@ Permission NetworkController::getPermissionForUserLocked(uid_t uid) const {
 
 int NetworkController::modifyRoute(unsigned netId, const char* interface, const char* destination,
                                    const char* nexthop, bool add, bool legacy, uid_t uid) {
-    unsigned existingNetId = getNetworkId(interface);
+    unsigned existingNetId = getNetworkForInterface(interface);
     if (netId == NETID_UNSET || existingNetId != netId) {
         ALOGE("interface %s assigned to netId %u, not %u", interface, existingNetId, netId);
         return -ENOENT;
@@ -377,8 +360,3 @@ int NetworkController::modifyRoute(unsigned netId, const char* interface, const
     return add ? RouteController::addRoute(interface, destination, nexthop, tableType) :
                  RouteController::removeRoute(interface, destination, nexthop, tableType);
 }
-
-NetworkController::UidEntry::UidEntry(uid_t uidStart, uid_t uidEnd, unsigned netId,
-                                      bool forwardDns) :
-        uidStart(uidStart), uidEnd(uidEnd), netId(netId), forwardDns(forwardDns) {
-}
index 0418f96..217dfbc 100644 (file)
@@ -30,6 +30,7 @@
 
 class Network;
 class UidRanges;
+class VirtualNetwork;
 
 /*
  * Keeps track of default, per-pid, and per-uid-range network selection, as
@@ -44,19 +45,15 @@ public:
     unsigned getDefaultNetwork() const;
     int setDefaultNetwork(unsigned netId) WARN_UNUSED_RESULT;
 
-    bool setNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId, bool forwardDns);
-    bool clearNetworkForUidRange(uid_t uidStart, uid_t uidEnd, unsigned netId);
-
     // Order of preference: UID-specific, requestedNetId, default.
     // Specify NETID_UNSET for requestedNetId if the default network is preferred.
     // forDns indicates if we're querying the netId for a DNS request. This avoids sending DNS
     // requests to VPNs without DNS servers.
-    unsigned getNetwork(uid_t uid, unsigned requestedNetId, bool forDns) const;
-    unsigned getNetworkId(const char* interface) const;
-    bool isValidNetwork(unsigned netId) const;
+    unsigned getNetworkForUser(uid_t uid, unsigned requestedNetId, bool forDns) const;
+    unsigned getNetworkForInterface(const char* interface) const;
 
-    int createNetwork(unsigned netId, Permission permission) WARN_UNUSED_RESULT;
-    int createVpn(unsigned netId) WARN_UNUSED_RESULT;
+    int createPhysicalNetwork(unsigned netId, Permission permission) WARN_UNUSED_RESULT;
+    int createVirtualNetwork(unsigned netId, bool hasDns) WARN_UNUSED_RESULT;
     int destroyNetwork(unsigned netId) WARN_UNUSED_RESULT;
 
     int addInterfaceToNetwork(unsigned netId, const char* interface) WARN_UNUSED_RESULT;
@@ -64,7 +61,7 @@ public:
 
     Permission getPermissionForUser(uid_t uid) const;
     void setPermissionForUsers(Permission permission, const std::vector<uid_t>& uids);
-    bool isUserPermittedOnNetwork(uid_t uid, unsigned netId) const;
+    bool canUserSelectNetwork(uid_t uid, unsigned netId) const;
     int setPermissionForNetworks(Permission permission,
                                  const std::vector<unsigned>& netIds) WARN_UNUSED_RESULT;
 
@@ -78,29 +75,21 @@ public:
     int removeRoute(unsigned netId, const char* interface, const char* destination,
                     const char* nexthop, bool legacy, uid_t uid) WARN_UNUSED_RESULT;
 
+    bool canProtect(uid_t uid) const;
     void allowProtect(const std::vector<uid_t>& uids);
     void denyProtect(const std::vector<uid_t>& uids);
 
 private:
+    bool isValidNetwork(unsigned netId) const;
     Network* getNetworkLocked(unsigned netId) const;
+    VirtualNetwork* getVirtualNetworkForUserLocked(uid_t uid) const;
     Permission getPermissionForUserLocked(uid_t uid) const;
 
     int modifyRoute(unsigned netId, const char* interface, const char* destination,
                     const char* nexthop, bool add, bool legacy, uid_t uid) WARN_UNUSED_RESULT;
 
-    struct UidEntry {
-        const uid_t uidStart;
-        const uid_t uidEnd;
-        const unsigned netId;
-        bool forwardDns;
-
-        UidEntry(uid_t uidStart, uid_t uidEnd, unsigned netId, bool forwardDns);
-    };
-
-    // mRWLock guards all accesses to mUidMap, mDefaultNetId, mNetworks, mUsers and
-    // mProtectableUsers.
+    // mRWLock guards all accesses to mDefaultNetId, mNetworks, mUsers and mProtectableUsers.
     mutable android::RWLock mRWLock;
-    std::list<UidEntry> mUidMap;
     unsigned mDefaultNetId;
     std::map<unsigned, Network*> mNetworks;  // Map keys are NetIds.
     std::map<uid_t, Permission> mUsers;
index 3bfb61a..6ee118b 100644 (file)
@@ -32,9 +32,8 @@ public:
     int addAsDefault() WARN_UNUSED_RESULT;
     int removeAsDefault() WARN_UNUSED_RESULT;
 
-    Type getType() const override;
-
 private:
+    Type getType() const override;
     int addInterface(const std::string& interface) override WARN_UNUSED_RESULT;
     int removeInterface(const std::string& interface) override WARN_UNUSED_RESULT;
 
index d090bef..bc50dc4 100644 (file)
@@ -524,7 +524,7 @@ WARN_UNUSED_RESULT int modifyVirtualNetwork(unsigned netId, const char* interfac
         return -ESRCH;
     }
 
-    for (const std::pair<uid_t, uid_t>& range : uidRanges.getRanges()) {
+    for (const UidRanges::Range& range : uidRanges.getRanges()) {
         if (int ret = modifyExplicitNetworkRule(netId, table, PERMISSION_NONE, range.first,
                                                 range.second, add)) {
             return ret;
index 398edd1..87fa4fe 100644 (file)
@@ -89,7 +89,8 @@ int SecondaryTableController::setupIptablesHooks() {
 
 int SecondaryTableController::addRoute(SocketClient *cli, char *iface, char *dest, int prefix,
         char *gateway) {
-    return modifyRoute(cli, ADD, iface, dest, prefix, gateway, mNetCtrl->getNetworkId(iface));
+    return modifyRoute(cli, ADD, iface, dest, prefix, gateway,
+                       mNetCtrl->getNetworkForInterface(iface));
 }
 
 int SecondaryTableController::modifyRoute(SocketClient *cli, const char *action, char *iface,
@@ -175,7 +176,8 @@ IptablesTarget SecondaryTableController::getIptablesTarget(const char *addr) {
 
 int SecondaryTableController::removeRoute(SocketClient *cli, char *iface, char *dest, int prefix,
         char *gateway) {
-    return modifyRoute(cli, DEL, iface, dest, prefix, gateway, mNetCtrl->getNetworkId(iface));
+    return modifyRoute(cli, DEL, iface, dest, prefix, gateway,
+                       mNetCtrl->getNetworkForInterface(iface));
 }
 
 int SecondaryTableController::modifyFromRule(unsigned netId, const char *action,
@@ -234,7 +236,7 @@ int SecondaryTableController::setFwmarkRule(const char *iface, bool add) {
         return -1;
     }
 
-    unsigned netId = mNetCtrl->getNetworkId(iface);
+    unsigned netId = mNetCtrl->getNetworkForInterface(iface);
 
     // Fail fast if any rules already exist for this interface
     if (mNetIdRuleCount.count(netId) > 0) {
@@ -396,7 +398,7 @@ int SecondaryTableController::setFwmarkRoute(const char* iface, const char *dest
         return -1;
     }
 
-    unsigned netId = mNetCtrl->getNetworkId(iface);
+    unsigned netId = mNetCtrl->getNetworkForInterface(iface);
     char mark_str[11] = {0};
     char dest_str[44]; // enough to store an IPv6 address + 3 character bitmask
 
@@ -419,50 +421,6 @@ int SecondaryTableController::setFwmarkRoute(const char* iface, const char *dest
     return runCmd(ARRAY_SIZE(rule_cmd), rule_cmd);
 }
 
-int SecondaryTableController::addUidRule(const char *iface, int uid_start, int uid_end,
-        bool forward_dns) {
-    return setUidRule(iface, uid_start, uid_end, true, forward_dns);
-}
-
-int SecondaryTableController::removeUidRule(const char *iface, int uid_start, int uid_end) {
-    return setUidRule(iface, uid_start, uid_end, false, false);
-}
-
-int SecondaryTableController::setUidRule(const char *iface, int uid_start, int uid_end, bool add,
-        bool forward_dns) {
-    unsigned netId = mNetCtrl->getNetworkId(iface);
-    if (add) {
-        if (!mNetCtrl->setNetworkForUidRange(uid_start, uid_end, netId, forward_dns)) {
-            // errno is set by setNetworkForUidRange.
-            return -1;
-        }
-    } else {
-        if (!mNetCtrl->clearNetworkForUidRange(uid_start, uid_end, netId)) {
-            // errno is set by clearNetworkForUidRange.
-            return -1;
-        }
-    }
-
-    char uid_str[24] = {0};
-    snprintf(uid_str, sizeof(uid_str), "%d-%d", uid_start, uid_end);
-    char mark_str[11] = {0};
-    snprintf(mark_str, sizeof(mark_str), "%u", netId + BASE_TABLE_NUMBER);
-    return execIptables(V4V6,
-            "-t",
-            "mangle",
-            add ? "-A" : "-D",
-            LOCAL_MANGLE_OUTPUT,
-            "-m",
-            "owner",
-            "--uid-owner",
-            uid_str,
-            "-j",
-            "MARK",
-            "--set-mark",
-            mark_str,
-            NULL);
-}
-
 int SecondaryTableController::addHostExemption(const char *host) {
     return setHostExemption(host, true);
 }
@@ -488,7 +446,7 @@ int SecondaryTableController::setHostExemption(const char *host, bool add) {
 }
 
 void SecondaryTableController::getUidMark(SocketClient *cli, int uid) {
-    unsigned netId = mNetCtrl->getNetwork(uid, NETID_UNSET, false);
+    unsigned netId = mNetCtrl->getNetworkForUser(uid, NETID_UNSET, false);
     char mark_str[11];
     snprintf(mark_str, sizeof(mark_str), "%u", netId + BASE_TABLE_NUMBER);
     cli->sendMsg(ResponseCode::GetMarkResult, mark_str, false);
index 9278bb3..b2cc36a 100644 (file)
@@ -48,11 +48,6 @@ public:
     int modifyFromRule(unsigned netId, const char *action, const char *addr);
     int modifyLocalRoute(unsigned netId, const char *action, const char *iface, const char *addr);
 
-    // Add/remove rules to force packets in a particular range of UIDs over a particular interface.
-    // This is accomplished with a rule specifying these UIDs use the interface's routing chain.
-    int addUidRule(const char *iface, int uid_start, int uid_end, bool forward_dns);
-    int removeUidRule(const char *iface, int uid_start, int uid_end);
-
     // Add/remove rules and chains so packets intended for a particular interface use that
     // interface.
     int addFwmarkRule(const char *iface);
@@ -85,7 +80,6 @@ public:
 private:
     NetworkController *mNetCtrl;
 
-    int setUidRule(const char* iface, int uid_start, int uid_end, bool add, bool foward_dns);
     int setFwmarkRule(const char *iface, bool add);
     int setFwmarkRoute(const char* iface, const char *dest, int prefix, bool add);
     int setHostExemption(const char *host, bool add);
index d752cbf..10e445a 100644 (file)
 
 #include <stdlib.h>
 
-const std::vector<std::pair<uid_t, uid_t>>& UidRanges::getRanges() const {
+bool UidRanges::hasUid(uid_t uid) const {
+    auto iter = std::lower_bound(mRanges.begin(), mRanges.end(), Range(uid, uid));
+    return (iter != mRanges.end() && iter->first == uid) ||
+           (iter != mRanges.begin() && (--iter)->second >= uid);
+}
+
+const std::vector<UidRanges::Range>& UidRanges::getRanges() const {
     return mRanges;
 }
 
@@ -59,7 +65,7 @@ bool UidRanges::parseFrom(int argc, char* argv[]) {
             // Invalid UIDs.
             return false;
         }
-        mRanges.push_back(std::pair<uid_t, uid_t>(uidStart, uidEnd));
+        mRanges.push_back(Range(uidStart, uidEnd));
     }
     std::sort(mRanges.begin(), mRanges.end());
     return true;
index 88685b4..044a8f9 100644 (file)
 
 class UidRanges {
 public:
-    const std::vector<std::pair<uid_t, uid_t>>& getRanges() const;
+    typedef std::pair<uid_t, uid_t> Range;
+
+    bool hasUid(uid_t uid) const;
+    const std::vector<Range>& getRanges() const;
 
     bool parseFrom(int argc, char* argv[]);
 
@@ -31,7 +34,7 @@ public:
     void remove(const UidRanges& other);
 
 private:
-    std::vector<std::pair<uid_t, uid_t>> mRanges;
+    std::vector<Range> mRanges;
 };
 
 #endif  // NETD_SERVER_UID_RANGES_H
index 024d2cf..565bd55 100644 (file)
 #define LOG_TAG "Netd"
 #include "log/log.h"
 
-VirtualNetwork::VirtualNetwork(unsigned netId): Network(netId) {
+VirtualNetwork::VirtualNetwork(unsigned netId, bool hasDns): Network(netId), mHasDns(hasDns) {
 }
 
 VirtualNetwork::~VirtualNetwork() {
 }
 
-int VirtualNetwork::addInterface(const std::string& interface) {
-    if (hasInterface(interface)) {
-        return 0;
-    }
-    if (int ret = RouteController::addInterfaceToVirtualNetwork(mNetId, interface.c_str(),
-                                                                mUidRanges)) {
-        ALOGE("failed to add interface %s to VPN netId %u", interface.c_str(), mNetId);
-        return ret;
-    }
-    mInterfaces.insert(interface);
-    return 0;
-}
-
-int VirtualNetwork::removeInterface(const std::string& interface) {
-    if (!hasInterface(interface)) {
-        return 0;
-    }
-    if (int ret = RouteController::removeInterfaceFromVirtualNetwork(mNetId, interface.c_str(),
-                                                                     mUidRanges)) {
-        ALOGE("failed to remove interface %s from VPN netId %u", interface.c_str(), mNetId);
-        return ret;
-    }
-    mInterfaces.erase(interface);
-    return 0;
+bool VirtualNetwork::getHasDns() const {
+    return mHasDns;
 }
 
-Network::Type VirtualNetwork::getType() const {
-    return VIRTUAL;
+bool VirtualNetwork::appliesToUser(uid_t uid) const {
+    return mUidRanges.hasUid(uid);
 }
 
 int VirtualNetwork::addUsers(const UidRanges& uidRanges) {
@@ -80,3 +58,33 @@ int VirtualNetwork::removeUsers(const UidRanges& uidRanges) {
     mUidRanges.remove(uidRanges);
     return 0;
 }
+
+Network::Type VirtualNetwork::getType() const {
+    return VIRTUAL;
+}
+
+int VirtualNetwork::addInterface(const std::string& interface) {
+    if (hasInterface(interface)) {
+        return 0;
+    }
+    if (int ret = RouteController::addInterfaceToVirtualNetwork(mNetId, interface.c_str(),
+                                                                mUidRanges)) {
+        ALOGE("failed to add interface %s to VPN netId %u", interface.c_str(), mNetId);
+        return ret;
+    }
+    mInterfaces.insert(interface);
+    return 0;
+}
+
+int VirtualNetwork::removeInterface(const std::string& interface) {
+    if (!hasInterface(interface)) {
+        return 0;
+    }
+    if (int ret = RouteController::removeInterfaceFromVirtualNetwork(mNetId, interface.c_str(),
+                                                                     mUidRanges)) {
+        ALOGE("failed to remove interface %s from VPN netId %u", interface.c_str(), mNetId);
+        return ret;
+    }
+    mInterfaces.erase(interface);
+    return 0;
+}
index 54b4926..92a1b0e 100644 (file)
 
 class VirtualNetwork : public Network {
 public:
-    explicit VirtualNetwork(unsigned netId);
+    VirtualNetwork(unsigned netId, bool hasDns);
     virtual ~VirtualNetwork();
 
+    bool getHasDns() const;
+    bool appliesToUser(uid_t uid) const;
+
     int addUsers(const UidRanges& uidRanges) WARN_UNUSED_RESULT;
     int removeUsers(const UidRanges& uidRanges) WARN_UNUSED_RESULT;
 
-    Type getType() const override;
-
 private:
+    Type getType() const override;
     int addInterface(const std::string& interface) override WARN_UNUSED_RESULT;
     int removeInterface(const std::string& interface) override WARN_UNUSED_RESULT;
 
+    const bool mHasDns;
     UidRanges mUidRanges;
 };