OSDN Git Service

netd: Replace iface with opaque netid in resolver.
authorSzymon Jakubczak <szym@google.com>
Fri, 14 Feb 2014 22:09:43 +0000 (17:09 -0500)
committerPaul Jensen <pauljensen@google.com>
Thu, 27 Mar 2014 19:36:40 +0000 (15:36 -0400)
Also ensure that UID mapping (for VPN) cannot be overridden by
android_getaddrinfofornet or per-PID preference.

Change-Id: I9ccfda2902cc0943e87c9bc346ad9a2578accdab

16 files changed:
Android.mk
CommandListener.cpp
CommandListener.h
DnsProxyListener.cpp
DnsProxyListener.h
NatController.cpp
NatController.h
NetworkController.cpp [new file with mode: 0644]
NetworkController.h [new file with mode: 0644]
ResolverController.cpp
ResolverController.h
SecondaryTableController.cpp
SecondaryTableController.h
UidMarkMap.cpp [deleted file]
UidMarkMap.h [deleted file]
main.cpp

index 359831a..0c6210c 100644 (file)
@@ -16,13 +16,13 @@ LOCAL_SRC_FILES:=                                      \
                   NetdConstants.cpp                    \
                   NetlinkHandler.cpp                   \
                   NetlinkManager.cpp                   \
+                  NetworkController.cpp                \
                   PppController.cpp                    \
                   ResolverController.cpp               \
                   SecondaryTableController.cpp         \
                   SoftapController.cpp                 \
                   TetherController.cpp                 \
                   oem_iptables_hook.cpp                \
-                  UidMarkMap.cpp                       \
                   main.cpp                             \
 
 
index 7ec3860..c89c836 100644 (file)
@@ -25,6 +25,7 @@
 #include <errno.h>
 #include <string.h>
 #include <linux/if.h>
+#include <resolv_netid.h>
 
 #define __STDC_FORMAT_MACROS 1
 #include <inttypes.h>
@@ -44,6 +45,7 @@
 #include "NetdConstants.h"
 #include "FirewallController.h"
 
+NetworkController *CommandListener::sNetCtrl = NULL;
 TetherController *CommandListener::sTetherCtrl = NULL;
 NatController *CommandListener::sNatCtrl = NULL;
 PppController *CommandListener::sPppCtrl = NULL;
@@ -133,7 +135,7 @@ static void createChildChains(IptablesTarget target, const char* table, const ch
     } while (*(++childChain) != NULL);
 }
 
-CommandListener::CommandListener(UidMarkMap *map) :
+CommandListener::CommandListener() :
                  FrameworkListener("netd", true) {
     registerCmd(new InterfaceCmd());
     registerCmd(new IpFwdCmd());
@@ -148,12 +150,14 @@ CommandListener::CommandListener(UidMarkMap *map) :
     registerCmd(new FirewallCmd());
     registerCmd(new ClatdCmd());
 
+    if (!sNetCtrl)
+        sNetCtrl = new NetworkController();
     if (!sSecondaryTableCtrl)
-        sSecondaryTableCtrl = new SecondaryTableController(map);
+        sSecondaryTableCtrl = new SecondaryTableController(sNetCtrl);
     if (!sTetherCtrl)
         sTetherCtrl = new TetherController();
     if (!sNatCtrl)
-        sNatCtrl = new NatController(sSecondaryTableCtrl);
+        sNatCtrl = new NatController(sSecondaryTableCtrl, sNetCtrl);
     if (!sPppCtrl)
         sPppCtrl = new PppController();
     if (!sSoftapCtrl)
@@ -953,7 +957,8 @@ int CommandListener::ResolverCmd::runCommand(SocketClient *cli, int argc, char *
 
     if (!strcmp(argv[1], "setdefaultif")) { // "resolver setdefaultif <iface>"
         if (argc == 3) {
-            rc = sResolverCtrl->setDefaultInterface(argv[2]);
+            unsigned netId = sNetCtrl->getNetworkId(argv[2]);
+            sNetCtrl->setDefaultNetwork(netId);
         } else {
             cli->sendMsg(ResponseCode::CommandSyntaxError,
                     "Wrong number of arguments to resolver setdefaultif", false);
@@ -962,25 +967,16 @@ int CommandListener::ResolverCmd::runCommand(SocketClient *cli, int argc, char *
     } else if (!strcmp(argv[1], "setifdns")) {
         // "resolver setifdns <iface> <domains> <dns1> <dns2> ..."
         if (argc >= 5) {
-            rc = sResolverCtrl->setInterfaceDnsServers(argv[2], argv[3], &argv[4], argc - 4);
+            unsigned netId = sNetCtrl->getNetworkId(argv[2]);
+            rc = sResolverCtrl->setDnsServers(netId, argv[3], &argv[4], argc - 4);
         } else {
             cli->sendMsg(ResponseCode::CommandSyntaxError,
                     "Wrong number of arguments to resolver setifdns", false);
             return 0;
         }
-
-        // set the address of the interface to which the name servers
-        // are bound. Required in order to bind to right interface when
-        // doing the dns query.
-        if (!rc) {
-            ifc_init();
-            ifc_get_info(argv[2], &addr.s_addr, NULL, 0);
-
-            rc = sResolverCtrl->setInterfaceAddress(argv[2], &addr);
-        }
     } else if (!strcmp(argv[1], "flushdefaultif")) { // "resolver flushdefaultif"
         if (argc == 2) {
-            rc = sResolverCtrl->flushDefaultDnsCache();
+            rc = sResolverCtrl->flushDnsCache(sNetCtrl->getDefaultNetwork());
         } else {
             cli->sendMsg(ResponseCode::CommandSyntaxError,
                     "Wrong number of arguments to resolver flushdefaultif", false);
@@ -988,7 +984,8 @@ int CommandListener::ResolverCmd::runCommand(SocketClient *cli, int argc, char *
         }
     } else if (!strcmp(argv[1], "flushif")) { // "resolver flushif <iface>"
         if (argc == 3) {
-            rc = sResolverCtrl->flushInterfaceDnsCache(argv[2]);
+            unsigned netId = sNetCtrl->getNetworkId(argv[2]);
+            rc = sResolverCtrl->flushDnsCache(netId);
         } else {
             cli->sendMsg(ResponseCode::CommandSyntaxError,
                     "Wrong number of arguments to resolver setdefaultif", false);
@@ -996,7 +993,8 @@ int CommandListener::ResolverCmd::runCommand(SocketClient *cli, int argc, char *
         }
     } else if (!strcmp(argv[1], "setifaceforpid")) { // resolver setifaceforpid <iface> <pid>
         if (argc == 4) {
-            rc = sResolverCtrl->setDnsInterfaceForPid(argv[2], atoi(argv[3]));
+            unsigned netId = sNetCtrl->getNetworkId(argv[2]);
+            sNetCtrl->setNetworkForPid(atoi(argv[3]), netId);
         } else {
             cli->sendMsg(ResponseCode::CommandSyntaxError,
                     "Wrong number of arguments to resolver setifaceforpid", false);
@@ -1004,15 +1002,17 @@ int CommandListener::ResolverCmd::runCommand(SocketClient *cli, int argc, char *
         }
     } else if (!strcmp(argv[1], "clearifaceforpid")) { // resolver clearifaceforpid <pid>
         if (argc == 3) {
-            rc = sResolverCtrl->clearDnsInterfaceForPid(atoi(argv[2]));
+            sNetCtrl->setNetworkForPid(atoi(argv[2]), 0);
         } else {
             cli->sendMsg(ResponseCode::CommandSyntaxError,
                     "Wrong number of arguments to resolver clearifaceforpid", false);
             return 0;
         }
     } else if (!strcmp(argv[1], "setifaceforuidrange")) { // resolver setifaceforuid <iface> <l> <h>
+        // TODO: Merge this command with "interface fwmark uid add/remove iface uid_start uid_end
         if (argc == 5) {
-            rc = sResolverCtrl->setDnsInterfaceForUidRange(argv[2], atoi(argv[3]), atoi(argv[4]));
+            unsigned netId = sNetCtrl->getNetworkId(argv[2]);
+            rc = !sNetCtrl->setNetworkForUidRange(atoi(argv[3]), atoi(argv[4]), netId, true);
         } else {
             cli->sendMsg(ResponseCode::CommandSyntaxError,
                     "Wrong number of arguments to resolver setifaceforuid", false);
@@ -1020,7 +1020,7 @@ int CommandListener::ResolverCmd::runCommand(SocketClient *cli, int argc, char *
         }
     } else if (!strcmp(argv[1], "clearifaceforuidrange")) { // resolver clearifaceforuid <l> <h>
         if (argc == 4) {
-            rc = sResolverCtrl->clearDnsInterfaceForUidRange(atoi(argv[2]), atoi(argv[3]));
+            rc = !sNetCtrl->setNetworkForUidRange(atoi(argv[2]), atoi(argv[3]), NETID_UNSET, false);
         } else {
             cli->sendMsg(ResponseCode::CommandSyntaxError,
                     "Wrong number of arguments to resolver clearifaceforuid", false);
@@ -1028,7 +1028,7 @@ int CommandListener::ResolverCmd::runCommand(SocketClient *cli, int argc, char *
         }
     } else if (!strcmp(argv[1], "clearifacemapping")) {
         if (argc == 2) {
-            rc = sResolverCtrl->clearDnsInterfaceMappings();
+            sNetCtrl->clearNetworkPreference();
         } else {
             cli->sendMsg(ResponseCode::CommandSyntaxError,
                     "Wrong number of arugments to resolver clearifacemapping", false);
index 23b8dd1..d737270 100644 (file)
@@ -20,6 +20,7 @@
 #include <sysutils/FrameworkListener.h>
 
 #include "NetdCommand.h"
+#include "NetworkController.h"
 #include "TetherController.h"
 #include "NatController.h"
 #include "PppController.h"
@@ -31,7 +32,6 @@
 #include "SecondaryTableController.h"
 #include "FirewallController.h"
 #include "ClatdController.h"
-#include "UidMarkMap.h"
 
 class CommandListener : public FrameworkListener {
     static TetherController *sTetherCtrl;
@@ -47,7 +47,9 @@ class CommandListener : public FrameworkListener {
     static ClatdController *sClatdCtrl;
 
 public:
-    CommandListener(UidMarkMap *map);
+    static NetworkController *sNetCtrl;
+
+    CommandListener();
     virtual ~CommandListener() {}
 
 private:
index a544259..9b5c283 100644 (file)
@@ -25,7 +25,7 @@
 #include <sys/types.h>
 #include <string.h>
 #include <pthread.h>
-#include <resolv_iface.h>
+#include <resolv_netid.h>
 #include <net/if.h>
 
 #define LOG_TAG "DnsProxyListener"
 #include "DnsProxyListener.h"
 #include "ResponseCode.h"
 
-DnsProxyListener::DnsProxyListener(UidMarkMap *map) :
-                 FrameworkListener("dnsproxyd") {
-    registerCmd(new GetAddrInfoCmd(map));
-    registerCmd(new GetHostByAddrCmd(map));
-    registerCmd(new GetHostByNameCmd(map));
-    mUidMarkMap = map;
+DnsProxyListener::DnsProxyListener(const NetworkController* controller) :
+                 FrameworkListener("dnsproxyd"),
+                 mNetCtrl(controller) {
+    registerCmd(new GetAddrInfoCmd(controller));
+    registerCmd(new GetHostByAddrCmd(controller));
+    registerCmd(new GetHostByNameCmd(controller));
 }
 
 DnsProxyListener::GetAddrInfoHandler::GetAddrInfoHandler(SocketClient *c,
                                                          char* host,
                                                          char* service,
                                                          struct addrinfo* hints,
-                                                         char* iface,
-                                                         pid_t pid,
-                                                         uid_t uid,
-                                                         int mark)
+                                                         unsigned netId)
         : mClient(c),
           mHost(host),
           mService(service),
           mHints(hints),
-          mIface(iface),
-          mPid(pid),
-          mUid(uid),
-          mMark(mark) {
+          mNetId(netId) {
 }
 
 DnsProxyListener::GetAddrInfoHandler::~GetAddrInfoHandler() {
     free(mHost);
     free(mService);
     free(mHints);
-    free(mIface);
 }
 
 void DnsProxyListener::GetAddrInfoHandler::start() {
@@ -125,21 +118,11 @@ static bool sendhostent(SocketClient *c, struct hostent *hp) {
 
 void DnsProxyListener::GetAddrInfoHandler::run() {
     if (DBG) {
-        ALOGD("GetAddrInfoHandler, now for %s / %s / %s", mHost, mService, mIface);
-    }
-
-    char tmp[IF_NAMESIZE + 1];
-    int mark = mMark;
-    if (mIface == NULL) {
-        //fall back to the per uid interface if no per pid interface exists
-        if(!_resolv_get_pids_associated_interface(mPid, tmp, sizeof(tmp)))
-            if(!_resolv_get_uids_associated_interface(mUid, tmp, sizeof(tmp)))
-                mark = -1; // if we don't have a targeted iface don't use a mark
+        ALOGD("GetAddrInfoHandler, now for %s / %s / %u", mHost, mService, mNetId);
     }
 
     struct addrinfo* result = NULL;
-    uint32_t rv = android_getaddrinfoforiface(mHost, mService, mHints, mIface ? mIface : tmp,
-            mark, &result);
+    uint32_t rv = android_getaddrinfofornet(mHost, mService, mHints, mNetId, 0, &result);
     if (rv) {
         // getaddrinfo failed
         mClient->sendBinaryMsg(ResponseCode::DnsProxyOperationFailed, &rv, sizeof(rv));
@@ -165,9 +148,9 @@ void DnsProxyListener::GetAddrInfoHandler::run() {
     mClient->decRef();
 }
 
-DnsProxyListener::GetAddrInfoCmd::GetAddrInfoCmd(UidMarkMap *uidMarkMap) :
-    NetdCommand("getaddrinfo") {
-        mUidMarkMap = uidMarkMap;
+DnsProxyListener::GetAddrInfoCmd::GetAddrInfoCmd(const NetworkController* controller) :
+    NetdCommand("getaddrinfo"),
+    mNetCtrl(controller) {
 }
 
 int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
@@ -200,18 +183,12 @@ int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
         service = strdup(service);
     }
 
-    char* iface = argv[7];
-    if (strcmp(iface, "^") == 0) {
-        iface = NULL;
-    } else {
-        iface = strdup(iface);
-    }
-
     struct addrinfo* hints = NULL;
     int ai_flags = atoi(argv[3]);
     int ai_family = atoi(argv[4]);
     int ai_socktype = atoi(argv[5]);
     int ai_protocol = atoi(argv[6]);
+    unsigned netId = strtoul(argv[7], NULL, 10);
     pid_t pid = cli->getPid();
     uid_t uid = cli->getUid();
 
@@ -222,20 +199,26 @@ int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
         hints->ai_family = ai_family;
         hints->ai_socktype = ai_socktype;
         hints->ai_protocol = ai_protocol;
+
+        // Only implement AI_ADDRCONFIG if application is using default network since our
+        // implementation only works on the default network.
+        if ((hints->ai_flags & AI_ADDRCONFIG) && netId != mNetCtrl->getDefaultNetwork()) {
+            hints->ai_flags &= ~AI_ADDRCONFIG;
+        }
     }
 
     if (DBG) {
-        ALOGD("GetAddrInfoHandler for %s / %s / %s / %d / %d",
+        ALOGD("GetAddrInfoHandler for %s / %s / %u / %d / %d",
              name ? name : "[nullhost]",
              service ? service : "[nullservice]",
-             iface ? iface : "[nulliface]",
-             pid, uid);
+             netId, pid, uid);
     }
 
+    netId = mNetCtrl->getNetwork(uid, netId, pid, true);
+
     cli->incRef();
     DnsProxyListener::GetAddrInfoHandler* handler =
-        new DnsProxyListener::GetAddrInfoHandler(cli, name, service, hints, iface, pid, uid,
-                                    mUidMarkMap->getMark(uid));
+            new DnsProxyListener::GetAddrInfoHandler(cli, name, service, hints, netId);
     handler->start();
 
     return 0;
@@ -244,9 +227,9 @@ int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
 /*******************************************************
  *                  GetHostByName                      *
  *******************************************************/
-DnsProxyListener::GetHostByNameCmd::GetHostByNameCmd(UidMarkMap *uidMarkMap) :
-        NetdCommand("gethostbyname") {
-            mUidMarkMap = uidMarkMap;
+DnsProxyListener::GetHostByNameCmd::GetHostByNameCmd(const NetworkController* controller) :
+      NetdCommand("gethostbyname"),
+      mNetCtrl(controller) {
 }
 
 int DnsProxyListener::GetHostByNameCmd::runCommand(SocketClient *cli,
@@ -267,49 +250,37 @@ int DnsProxyListener::GetHostByNameCmd::runCommand(SocketClient *cli,
 
     pid_t pid = cli->getPid();
     uid_t uid = cli->getUid();
-    char* iface = argv[1];
+    unsigned netId = strtoul(argv[1], NULL, 10);
     char* name = argv[2];
     int af = atoi(argv[3]);
 
-    if (strcmp(iface, "^") == 0) {
-        iface = NULL;
-    } else {
-        iface = strdup(iface);
-    }
-
     if (strcmp(name, "^") == 0) {
         name = NULL;
     } else {
         name = strdup(name);
     }
 
+    netId = mNetCtrl->getNetwork(uid, netId, pid, true);
+
     cli->incRef();
     DnsProxyListener::GetHostByNameHandler* handler =
-            new DnsProxyListener::GetHostByNameHandler(cli, pid, uid, iface, name, af,
-                    mUidMarkMap->getMark(uid));
+            new DnsProxyListener::GetHostByNameHandler(cli, name, af, netId);
     handler->start();
 
     return 0;
 }
 
 DnsProxyListener::GetHostByNameHandler::GetHostByNameHandler(SocketClient* c,
-                                                             pid_t pid,
-                                                             uid_t uid,
-                                                             char* iface,
                                                              char* name,
                                                              int af,
-                                                             int mark)
+                                                             unsigned netId)
         : mClient(c),
-          mPid(pid),
-          mUid(uid),
-          mIface(iface),
           mName(name),
           mAf(af),
-          mMark(mark) {
+          mNetId(netId) {
 }
 
 DnsProxyListener::GetHostByNameHandler::~GetHostByNameHandler() {
-    free(mIface);
     free(mName);
 }
 
@@ -333,16 +304,9 @@ void DnsProxyListener::GetHostByNameHandler::run() {
         ALOGD("DnsProxyListener::GetHostByNameHandler::run\n");
     }
 
-    char iface[IF_NAMESIZE + 1];
-    if (mIface == NULL) {
-        //fall back to the per uid interface if no per pid interface exists
-        if(!_resolv_get_pids_associated_interface(mPid, iface, sizeof(iface)))
-            _resolv_get_uids_associated_interface(mUid, iface, sizeof(iface));
-    }
-
     struct hostent* hp;
 
-    hp = android_gethostbynameforiface(mName, mAf, mIface ? mIface : iface, mMark);
+    hp = android_gethostbynamefornet(mName, mAf, mNetId, 0);
 
     if (DBG) {
         ALOGD("GetHostByNameHandler::run gethostbyname errno: %s hp->h_name = %s, name_len = %zu\n",
@@ -369,9 +333,9 @@ void DnsProxyListener::GetHostByNameHandler::run() {
 /*******************************************************
  *                  GetHostByAddr                      *
  *******************************************************/
-DnsProxyListener::GetHostByAddrCmd::GetHostByAddrCmd(UidMarkMap *uidMarkMap) :
-        NetdCommand("gethostbyaddr") {
-        mUidMarkMap = uidMarkMap;
+DnsProxyListener::GetHostByAddrCmd::GetHostByAddrCmd(const NetworkController* controller) :
+        NetdCommand("gethostbyaddr"),
+        mNetCtrl(controller) {
 }
 
 int DnsProxyListener::GetHostByAddrCmd::runCommand(SocketClient *cli,
@@ -395,13 +359,7 @@ int DnsProxyListener::GetHostByAddrCmd::runCommand(SocketClient *cli,
     int addrFamily = atoi(argv[3]);
     pid_t pid = cli->getPid();
     uid_t uid = cli->getUid();
-    char* iface = argv[4];
-
-    if (strcmp(iface, "^") == 0) {
-        iface = NULL;
-    } else {
-        iface = strdup(iface);
-    }
+    unsigned netId = strtoul(argv[4], NULL, 10);
 
     void* addr = malloc(sizeof(struct in6_addr));
     errno = 0;
@@ -416,10 +374,11 @@ int DnsProxyListener::GetHostByAddrCmd::runCommand(SocketClient *cli,
         return -1;
     }
 
+    netId = mNetCtrl->getNetwork(uid, netId, pid, true);
+
     cli->incRef();
     DnsProxyListener::GetHostByAddrHandler* handler =
-            new DnsProxyListener::GetHostByAddrHandler(cli, addr, addrLen, addrFamily, iface, pid,
-                    uid, mUidMarkMap->getMark(uid));
+            new DnsProxyListener::GetHostByAddrHandler(cli, addr, addrLen, addrFamily, netId);
     handler->start();
 
     return 0;
@@ -429,23 +388,16 @@ DnsProxyListener::GetHostByAddrHandler::GetHostByAddrHandler(SocketClient* c,
                                                              void* address,
                                                              int   addressLen,
                                                              int   addressFamily,
-                                                             char* iface,
-                                                             pid_t pid,
-                                                             uid_t uid,
-                                                             int mark)
+                                                             unsigned netId)
         : mClient(c),
           mAddress(address),
           mAddressLen(addressLen),
           mAddressFamily(addressFamily),
-          mIface(iface),
-          mPid(pid),
-          mUid(uid),
-          mMark(mark) {
+          mNetId(netId) {
 }
 
 DnsProxyListener::GetHostByAddrHandler::~GetHostByAddrHandler() {
     free(mAddress);
-    free(mIface);
 }
 
 void DnsProxyListener::GetHostByAddrHandler::start() {
@@ -467,20 +419,10 @@ void DnsProxyListener::GetHostByAddrHandler::run() {
     if (DBG) {
         ALOGD("DnsProxyListener::GetHostByAddrHandler::run\n");
     }
-
-    char tmp[IF_NAMESIZE + 1];
-    int mark = mMark;
-    if (mIface == NULL) {
-        //fall back to the per uid interface if no per pid interface exists
-        if(!_resolv_get_pids_associated_interface(mPid, tmp, sizeof(tmp)))
-            if(!_resolv_get_uids_associated_interface(mUid, tmp, sizeof(tmp)))
-                mark = -1;
-    }
     struct hostent* hp;
 
     // NOTE gethostbyaddr should take a void* but bionic thinks it should be char*
-    hp = android_gethostbyaddrforiface((char*)mAddress, mAddressLen, mAddressFamily,
-            mIface ? mIface : tmp, mark);
+    hp = android_gethostbyaddrfornet((char*)mAddress, mAddressLen, mAddressFamily, mNetId, 0);
 
     if (DBG) {
         ALOGD("GetHostByAddrHandler::run gethostbyaddr errno: %s hp->h_name = %s, name_len = %zu\n",
index 2061d71..345928f 100644 (file)
 #include <sysutils/FrameworkListener.h>
 
 #include "NetdCommand.h"
-#include "UidMarkMap.h"
+#include "NetworkController.h"
 
 class DnsProxyListener : public FrameworkListener {
 public:
-    DnsProxyListener(UidMarkMap *map);
+    DnsProxyListener(const NetworkController* controller);
     virtual ~DnsProxyListener() {}
 
 private:
-    UidMarkMap *mUidMarkMap;
+    const NetworkController *mNetCtrl;
     class GetAddrInfoCmd : public NetdCommand {
     public:
-        GetAddrInfoCmd(UidMarkMap *uidMarkMap);
+        GetAddrInfoCmd(const NetworkController* controller);
         virtual ~GetAddrInfoCmd() {}
         int runCommand(SocketClient *c, int argc, char** argv);
     private:
-        UidMarkMap *mUidMarkMap;
+        const NetworkController* mNetCtrl;
     };
 
     class GetAddrInfoHandler {
@@ -45,10 +45,7 @@ private:
                            char* host,
                            char* service,
                            struct addrinfo* hints,
-                           char* iface,
-                           pid_t pid,
-                           uid_t uid,
-                           int mark);
+                           unsigned netId);
         ~GetAddrInfoHandler();
 
         static void* threadStart(void* handler);
@@ -60,65 +57,53 @@ private:
         char* mHost;    // owned
         char* mService; // owned
         struct addrinfo* mHints;  // owned
-        char* mIface; // owned
-        pid_t mPid;
-        uid_t mUid;
-        int mMark;
+        unsigned mNetId;
     };
 
     /* ------ gethostbyname ------*/
     class GetHostByNameCmd : public NetdCommand {
     public:
-        GetHostByNameCmd(UidMarkMap *uidMarkMap);
+        GetHostByNameCmd(const NetworkController* controller);
         virtual ~GetHostByNameCmd() {}
         int runCommand(SocketClient *c, int argc, char** argv);
     private:
-        UidMarkMap *mUidMarkMap;
+        const NetworkController* mNetCtrl;
     };
 
     class GetHostByNameHandler {
     public:
         GetHostByNameHandler(SocketClient *c,
-                            pid_t pid,
-                            uid_t uid,
-                            char *iface,
                             char *name,
                             int af,
-                            int mark);
+                            unsigned netId);
         ~GetHostByNameHandler();
         static void* threadStart(void* handler);
         void start();
     private:
         void run();
         SocketClient* mClient; //ref counted
-        pid_t mPid;
-        uid_t mUid;
-        char* mIface; // owned
         char* mName; // owned
         int mAf;
-        int mMark;
+        unsigned mNetId;
     };
 
     /* ------ gethostbyaddr ------*/
     class GetHostByAddrCmd : public NetdCommand {
     public:
-        GetHostByAddrCmd(UidMarkMap *uidMarkMap);
+        GetHostByAddrCmd(const NetworkController* controller);
         virtual ~GetHostByAddrCmd() {}
         int runCommand(SocketClient *c, int argc, char** argv);
     private:
-        UidMarkMap *mUidMarkMap;
+        const NetworkController* mNetCtrl;
     };
 
     class GetHostByAddrHandler {
     public:
         GetHostByAddrHandler(SocketClient *c,
                             void* address,
-                            int   addressLen,
-                            int   addressFamily,
-                            char* iface,
-                            pid_t pid,
-                            uid_t uid,
-                            int mark);
+                            int addressLen,
+                            int addressFamily,
+                            unsigned netId);
         ~GetHostByAddrHandler();
 
         static void* threadStart(void* handler);
@@ -128,12 +113,9 @@ private:
         void run();
         SocketClient* mClient;  // ref counted
         void* mAddress;    // address to lookup; owned
-        int   mAddressLen; // length of address to look up
-        int   mAddressFamily;  // address family
-        char* mIface; // owned
-        pid_t mPid;
-        uid_t mUid;
-        int   mMark;
+        int mAddressLen; // length of address to look up
+        int mAddressFamily;  // address family
+        unsigned mNetId;
     };
 };
 
index b2a0e64..dd5316a 100644 (file)
@@ -32,6 +32,7 @@
 #include <logwrap/logwrap.h>
 
 #include "NatController.h"
+#include "NetworkController.h"
 #include "SecondaryTableController.h"
 #include "NetdConstants.h"
 
@@ -39,8 +40,8 @@ const char* NatController::LOCAL_FORWARD = "natctrl_FORWARD";
 const char* NatController::LOCAL_NAT_POSTROUTING = "natctrl_nat_POSTROUTING";
 const char* NatController::LOCAL_TETHER_COUNTERS_CHAIN = "natctrl_tether_counters";
 
-NatController::NatController(SecondaryTableController *ctrl) {
-    secondaryTableCtrl = ctrl;
+NatController::NatController(SecondaryTableController *table_ctrl, NetworkController* net_ctrl) :
+        mSecondaryTableCtrl(table_ctrl), mNetCtrl(net_ctrl) {
 }
 
 NatController::~NatController() {
@@ -138,27 +139,25 @@ bool NatController::checkInterface(const char *iface) {
 }
 
 int NatController::routesOp(bool add, const char *intIface, const char *extIface, char **argv, int addrCount) {
-    int tableNumber = secondaryTableCtrl->findTableNumber(extIface);
+    unsigned netId = mNetCtrl->getNetworkId(extIface);
     int ret = 0;
 
-    if (tableNumber != -1) {
-        for (int i = 0; i < addrCount; i++) {
-            if (add) {
-                ret |= secondaryTableCtrl->modifyFromRule(tableNumber, ADD, argv[5+i]);
-                ret |= secondaryTableCtrl->modifyLocalRoute(tableNumber, ADD, intIface, argv[5+i]);
-            } else {
-                ret |= secondaryTableCtrl->modifyLocalRoute(tableNumber, DEL, intIface, argv[5+i]);
-                ret |= secondaryTableCtrl->modifyFromRule(tableNumber, DEL, argv[5+i]);
-            }
+    for (int i = 0; i < addrCount; i++) {
+        if (add) {
+            ret |= mSecondaryTableCtrl->modifyFromRule(netId, ADD, argv[5+i]);
+            ret |= mSecondaryTableCtrl->modifyLocalRoute(netId, ADD, intIface, argv[5+i]);
+        } else {
+            ret |= mSecondaryTableCtrl->modifyLocalRoute(netId, DEL, intIface, argv[5+i]);
+            ret |= mSecondaryTableCtrl->modifyFromRule(netId, DEL, argv[5+i]);
         }
-        const char *cmd[] = {
-                IP_PATH,
-                "route",
-                "flush",
-                "cache"
-        };
-        runCmd(ARRAY_SIZE(cmd), cmd);
     }
+    const char *cmd[] = {
+            IP_PATH,
+            "route",
+            "flush",
+            "cache"
+    };
+    runCmd(ARRAY_SIZE(cmd), cmd);
     return ret;
 }
 
index 525ca02..5f45376 100644 (file)
 
 #include <linux/in.h>
 
-#include "SecondaryTableController.h"
+class NetworkController;
+class SecondaryTableController;
 
 class NatController {
 
 public:
-    NatController(SecondaryTableController *ctrl);
+    NatController(SecondaryTableController *table_ctrl, NetworkController* net_ctrl);
     virtual ~NatController();
 
     int enableNat(const int argc, char **argv);
@@ -37,7 +38,8 @@ public:
 
 private:
     int natCount;
-    SecondaryTableController *secondaryTableCtrl;
+    SecondaryTableController *mSecondaryTableCtrl;
+    NetworkController *mNetCtrl;
 
     int setDefaults();
     int runCmd(int argc, const char **argv);
diff --git a/NetworkController.cpp b/NetworkController.cpp
new file mode 100644 (file)
index 0000000..a1f1535
--- /dev/null
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2014 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <resolv_netid.h>
+
+#define LOG_TAG "NetworkController"
+#include <cutils/log.h>
+
+#include "NetworkController.h"
+
+// Mark 1 is reserved for SecondaryTableController::PROTECT_MARK.
+NetworkController::NetworkController() : mNextFreeNetId(10) {}
+
+void NetworkController::clearNetworkPreference() {
+    android::RWLock::AutoWLock lock(mRWLock);
+    mUidMap.clear();
+    mPidMap.clear();
+}
+
+unsigned NetworkController::getDefaultNetwork() const {
+    return mDefaultNetId;
+}
+
+void NetworkController::setDefaultNetwork(unsigned netId) {
+    android::RWLock::AutoWLock lock(mRWLock);
+    mDefaultNetId = netId;
+}
+
+void NetworkController::setNetworkForPid(int pid, unsigned netId) {
+    android::RWLock::AutoWLock lock(mRWLock);
+    if (netId == 0) {
+        mPidMap.erase(pid);
+    } else {
+        mPidMap[pid] = netId;
+    }
+}
+
+bool NetworkController::setNetworkForUidRange(int uid_start, int uid_end, unsigned netId,
+        bool forward_dns) {
+    android::RWLock::AutoWLock lock(mRWLock);
+    if (uid_start > uid_end)
+        return false;
+
+    for (std::list<UidEntry>::iterator it = mUidMap.begin(); it != mUidMap.end(); ++it) {
+        if (it->uid_start > uid_end || uid_start > it->uid_end)
+            continue;
+        /* Overlapping or identical range. */
+        if (it->uid_start != uid_start || it->uid_end != uid_end) {
+            ALOGE("Overlapping but not identical uid range detected.");
+            return false;
+        }
+
+        if (netId == NETID_UNSET) {
+            mUidMap.erase(it);
+        } else {
+            it->netId = netId;
+            it->forward_dns = forward_dns;
+        }
+        return true;
+    }
+
+    mUidMap.push_back(UidEntry(uid_start, uid_end, netId, forward_dns));
+    return true;
+}
+
+unsigned NetworkController::getNetwork(int uid, unsigned requested_netId, int pid,
+        bool for_dns) const {
+    android::RWLock::AutoRLock lock(mRWLock);
+    for (std::list<UidEntry>::const_iterator it = mUidMap.begin(); it != mUidMap.end(); ++it) {
+        if (uid < it->uid_start || it->uid_end < uid)
+            continue;
+        if (for_dns && !it->forward_dns)
+            break;
+        return it->netId;
+    }
+    if (requested_netId != NETID_UNSET)
+        return requested_netId;
+    if (pid != PID_UNSPECIFIED) {
+        std::map<int, unsigned>::const_iterator it = mPidMap.find(pid);
+        if (it != mPidMap.end())
+            return it->second;
+    }
+    return mDefaultNetId;
+}
+
+unsigned NetworkController::getNetworkId(const char* interface) {
+    std::map<std::string, unsigned>::const_iterator it = mIfaceNetidMap.find(interface);
+    if (it != mIfaceNetidMap.end())
+        return it->second;
+
+    unsigned netId = mNextFreeNetId++;
+    mIfaceNetidMap[interface] = netId;
+    return netId;
+}
+
+NetworkController::UidEntry::UidEntry(
+    int start, int end, unsigned netId, bool forward_dns)
+      : uid_start(start),
+        uid_end(end),
+        netId(netId),
+        forward_dns(forward_dns) {
+}
diff --git a/NetworkController.h b/NetworkController.h
new file mode 100644 (file)
index 0000000..52ab7f4
--- /dev/null
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2014 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _NETD_NETWORKCONTROLLER_H
+#define _NETD_NETWORKCONTROLLER_H
+
+#include <list>
+#include <map>
+#include <string>
+
+#include <stddef.h>
+#include <stdint.h>
+#include <utils/RWLock.h>
+
+/*
+ * Keeps track of default, per-pid, and per-uid-range network selection, as
+ * well as the mark associated with each network. Networks are identified
+ * by netid. In all set* commands netid == 0 means "unspecified" and is
+ * equivalent to clearing the mapping.
+ */
+class NetworkController {
+public:
+    enum {
+        // For use with getNetwork().
+        PID_UNSPECIFIED = 0,
+    };
+
+    NetworkController();
+
+    void clearNetworkPreference();
+    unsigned getDefaultNetwork() const;
+    void setDefaultNetwork(unsigned netId);
+    void setNetworkForPid(int pid, unsigned netId);
+    // Returns false if a partially overlapping range exists.
+    // Specify NETID_UNSET for netId to clear a mapping.
+    bool setNetworkForUidRange(int uid_start, int uid_end, unsigned netId, bool forward_dns);
+
+    // Order of preference: UID-specific, requested_netId, PID-specific, default.
+    // Specify NETID_UNSET for requested_netId if the default network is preferred.
+    // Specify PID_UNSPECIFIED for pid to ignore PID-specific overrides.
+    // for_dns indicates if we're querrying the netId for a DNS request.  This avoids sending DNS
+    // requests to VPNs without DNS servers.
+    unsigned getNetwork(int uid, unsigned requested_netId, int pid, bool for_dns) const;
+
+    unsigned getNetworkId(const char* interface);
+
+private:
+    struct UidEntry {
+        int uid_start;
+        int uid_end;
+        unsigned netId;
+        bool forward_dns;
+        UidEntry(int uid_start, int uid_end, unsigned netId, bool forward_dns);
+    };
+
+    mutable android::RWLock mRWLock;
+    std::list<UidEntry> mUidMap;
+    std::map<int, unsigned> mPidMap;
+    unsigned mDefaultNetId;
+
+    std::map<std::string, unsigned> mIfaceNetidMap;
+    unsigned mNextFreeNetId;
+};
+
+#endif
index e61fae7..db00da6 100644 (file)
 
 #include <net/if.h>
 
-// NOTE: <resolv_iface.h> is a private C library header that provides
-//       declarations for _resolv_set_default_iface() and others.
-#include <resolv_iface.h>
+// NOTE: <resolv_netid.h> is a private C library header that provides
+//       declarations for _resolv_set_nameservers_for_net and
+//       _resolv_flush_cache_for_net
+#include <resolv_netid.h>
 
 #include "ResolverController.h"
 
-int ResolverController::setDefaultInterface(const char* iface) {
-    if (DBG) {
-        ALOGD("setDefaultInterface iface = %s\n", iface);
-    }
-
-    _resolv_set_default_iface(iface);
-
-    return 0;
-}
-
-int ResolverController::setInterfaceDnsServers(const char* iface, const char* domains,
+int ResolverController::setDnsServers(unsigned netId, const char* domains,
         const char** servers, int numservers) {
     if (DBG) {
-        ALOGD("setInterfaceDnsServers iface = %s\n", iface);
-    }
-    _resolv_set_nameservers_for_iface(iface, servers, numservers, domains);
-
-    return 0;
-}
-
-int ResolverController::setInterfaceAddress(const char* iface, struct in_addr* addr) {
-    if (DBG) {
-        ALOGD("setInterfaceAddress iface = %s\n", iface);
-    }
-
-    _resolv_set_addr_of_iface(iface, addr);
-
-    return 0;
-}
-
-int ResolverController::flushDefaultDnsCache() {
-    if (DBG) {
-        ALOGD("flushDefaultDnsCache\n");
-    }
-
-    _resolv_flush_cache_for_default_iface();
-
-    return 0;
-}
-
-int ResolverController::flushInterfaceDnsCache(const char* iface) {
-    if (DBG) {
-        ALOGD("flushInterfaceDnsCache iface = %s\n", iface);
+        ALOGD("setDnsServers netId = %u\n", netId);
     }
-
-    _resolv_flush_cache_for_iface(iface);
+    _resolv_set_nameservers_for_net(netId, servers, numservers, domains);
 
     return 0;
 }
 
-int ResolverController::setDnsInterfaceForPid(const char* iface, int pid) {
+int ResolverController::flushDnsCache(unsigned netId) {
     if (DBG) {
-        ALOGD("setDnsIfaceForPid iface = %s, pid = %d\n", iface, pid);
+        ALOGD("flushDnsCache netId = %u\n", netId);
     }
 
-    _resolv_set_iface_for_pid(iface, pid);
+    _resolv_flush_cache_for_net(netId);
 
     return 0;
 }
 
-int ResolverController::clearDnsInterfaceForPid(int pid) {
-    if (DBG) {
-        ALOGD("clearDnsIfaceForPid pid = %d\n", pid);
-    }
-
-    _resolv_clear_iface_for_pid(pid);
-
-    return 0;
-}
-
-int ResolverController::setDnsInterfaceForUidRange(const char* iface, int uid_start, int uid_end) {
-    if (DBG) {
-        ALOGD("setDnsIfaceForUidRange iface = %s, range = [%d,%d]\n", iface, uid_start, uid_end);
-    }
-
-    return _resolv_set_iface_for_uid_range(iface, uid_start, uid_end);
-}
-
-int ResolverController::clearDnsInterfaceForUidRange(int uid_start, int uid_end) {
-    if (DBG) {
-        ALOGD("clearDnsIfaceForUidRange range = [%d,%d]\n", uid_start, uid_end);
-    }
-
-    return _resolv_clear_iface_for_uid_range(uid_start, uid_end);
-}
-
-int ResolverController::clearDnsInterfaceMappings()
-{
-    if (DBG) {
-        ALOGD("clearInterfaceMappings\n");
-    }
-    _resolv_clear_iface_uid_range_mapping();
-    _resolv_clear_iface_pid_mapping();
-
-    return 0;
-}
index e705c8f..0c245d7 100644 (file)
@@ -25,17 +25,10 @@ public:
     ResolverController() {};
     virtual ~ResolverController() {};
 
-    int setDefaultInterface(const char* iface);
-    int setInterfaceDnsServers(const char* iface, const char * domains, const char** servers,
+    int setDnsServers(unsigned netid, const char * domains, const char** servers,
             int numservers);
-    int setInterfaceAddress(const char* iface, struct in_addr* addr);
-    int flushDefaultDnsCache();
-    int flushInterfaceDnsCache(const char* iface);
-    int setDnsInterfaceForPid(const char* iface, int pid);
-    int clearDnsInterfaceForPid(int pid);
-    int setDnsInterfaceForUidRange(const char* iface, int uid_start, int uid_end);
-    int clearDnsInterfaceForUidRange(int uid_start, int uid_end);
-    int clearDnsInterfaceMappings();
+    int flushDnsCache(unsigned netid);
+    // TODO: Add deleteDnsCache(unsigned netId)
 };
 
 #endif /* _RESOLVER_CONTROLLER_H_ */
index d12f4c8..dba801e 100644 (file)
@@ -26,6 +26,7 @@
 
 #include <netinet/in.h>
 #include <arpa/inet.h>
+#include <resolv_netid.h>
 
 #define LOG_TAG "SecondaryTablController"
 #include <cutils/log.h>
@@ -42,13 +43,8 @@ const char* SecondaryTableController::LOCAL_MANGLE_IFACE_FORMAT = "st_mangle_%s_
 const char* SecondaryTableController::LOCAL_NAT_POSTROUTING = "st_nat_POSTROUTING";
 const char* SecondaryTableController::LOCAL_FILTER_OUTPUT = "st_filter_OUTPUT";
 
-SecondaryTableController::SecondaryTableController(UidMarkMap *map) : mUidMarkMap(map) {
-    int i;
-    for (i=0; i < INTERFACES_TRACKED; i++) {
-        mInterfaceTable[i][0] = 0;
-        // TODO - use a hashtable or other prebuilt container class
-        mInterfaceRuleCount[i] = 0;
-    }
+SecondaryTableController::SecondaryTableController(NetworkController* controller) :
+        mNetCtrl(controller) {
 }
 
 SecondaryTableController::~SecondaryTableController() {
@@ -100,45 +96,20 @@ int SecondaryTableController::setupIptablesHooks() {
     return res;
 }
 
-int SecondaryTableController::findTableNumber(const char *iface) {
-    int i;
-    for (i = 0; i < INTERFACES_TRACKED; i++) {
-        // compare through the final null, hence +1
-        if (strncmp(iface, mInterfaceTable[i], IFNAMSIZ + 1) == 0) {
-            return i;
-        }
-    }
-    return -1;
-}
-
 int SecondaryTableController::addRoute(SocketClient *cli, char *iface, char *dest, int prefix,
         char *gateway) {
-    int tableIndex = findTableNumber(iface);
-    if (tableIndex == -1) {
-        tableIndex = findTableNumber(""); // look for an empty slot
-        if (tableIndex == -1) {
-            ALOGE("Max number of NATed interfaces reached");
-            errno = ENODEV;
-            cli->sendMsg(ResponseCode::OperationFailed, "Max number NATed", true);
-            return -1;
-        }
-        strncpy(mInterfaceTable[tableIndex], iface, IFNAMSIZ);
-        // Ensure null termination even if truncation happened
-        mInterfaceTable[tableIndex][IFNAMSIZ] = 0;
-    }
-
-    return modifyRoute(cli, ADD, iface, dest, prefix, gateway, tableIndex);
+    return modifyRoute(cli, ADD, iface, dest, prefix, gateway, mNetCtrl->getNetworkId(iface));
 }
 
 int SecondaryTableController::modifyRoute(SocketClient *cli, const char *action, char *iface,
-        char *dest, int prefix, char *gateway, int tableIndex) {
+        char *dest, int prefix, char *gateway, unsigned netId) {
     char dest_str[44]; // enough to store an IPv6 address + 3 character bitmask
     char tableIndex_str[11];
     int ret;
 
     //  IP tool doesn't like "::" - the equiv of 0.0.0.0 that it accepts for ipv4
     snprintf(dest_str, sizeof(dest_str), "%s/%d", dest, prefix);
-    snprintf(tableIndex_str, sizeof(tableIndex_str), "%d", tableIndex + BASE_TABLE_NUMBER);
+    snprintf(tableIndex_str, sizeof(tableIndex_str), "%u", netId + BASE_TABLE_NUMBER);
 
     if (strcmp("::", gateway) == 0) {
         const char *cmd[] = {
@@ -169,47 +140,32 @@ int SecondaryTableController::modifyRoute(SocketClient *cli, const char *action,
     }
 
     if (ret) {
-        ALOGE("ip route %s failed: %s route %s %s/%d via %s dev %s table %d", action,
-                IP_PATH, action, dest, prefix, gateway, iface, tableIndex+BASE_TABLE_NUMBER);
+        ALOGE("ip route %s failed: %s route %s %s/%d via %s dev %s table %u", action,
+                IP_PATH, action, dest, prefix, gateway, iface, netId + BASE_TABLE_NUMBER);
         errno = ENODEV;
         cli->sendMsg(ResponseCode::OperationFailed, "ip route modification failed", true);
         return -1;
     }
 
-    if (strcmp(action, ADD) == 0) {
-        mInterfaceRuleCount[tableIndex]++;
-    } else {
-        if (--mInterfaceRuleCount[tableIndex] < 1) {
-            mInterfaceRuleCount[tableIndex] = 0;
-            mInterfaceTable[tableIndex][0] = 0;
-        }
-    }
-    modifyRuleCount(tableIndex, action);
+    modifyRuleCount(netId, action);
     cli->sendMsg(ResponseCode::CommandOkay, "Route modified", false);
     return 0;
 }
 
-void SecondaryTableController::modifyRuleCount(int tableIndex, const char *action) {
+void SecondaryTableController::modifyRuleCount(unsigned netId, const char *action) {
     if (strcmp(action, ADD) == 0) {
-        mInterfaceRuleCount[tableIndex]++;
+        if (mNetIdRuleCount.count(netId) == 0)
+            mNetIdRuleCount[netId] = 0;
+        mNetIdRuleCount[netId]++;
     } else {
-        if (--mInterfaceRuleCount[tableIndex] < 1) {
-            mInterfaceRuleCount[tableIndex] = 0;
-            mInterfaceTable[tableIndex][0] = 0;
+        if (mNetIdRuleCount.count(netId) > 0) {
+            if (--mNetIdRuleCount[netId] < 1) {
+                mNetIdRuleCount.erase(mNetIdRuleCount.find(netId));
+            }
         }
     }
 }
 
-int SecondaryTableController::verifyTableIndex(int tableIndex) {
-    if ((tableIndex < 0) ||
-            (tableIndex >= INTERFACES_TRACKED) ||
-            (mInterfaceTable[tableIndex][0] == 0)) {
-        return -1;
-    } else {
-        return 0;
-    }
-}
-
 const char *SecondaryTableController::getVersion(const char *addr) {
     if (strchr(addr, ':') != NULL) {
         return "-6";
@@ -228,27 +184,14 @@ IptablesTarget SecondaryTableController::getIptablesTarget(const char *addr) {
 
 int SecondaryTableController::removeRoute(SocketClient *cli, char *iface, char *dest, int prefix,
         char *gateway) {
-    int tableIndex = findTableNumber(iface);
-    if (tableIndex == -1) {
-        ALOGE("Interface not found");
-        errno = ENODEV;
-        cli->sendMsg(ResponseCode::OperationFailed, "Interface not found", true);
-        return -1;
-    }
-
-    return modifyRoute(cli, DEL, iface, dest, prefix, gateway, tableIndex);
+    return modifyRoute(cli, DEL, iface, dest, prefix, gateway, mNetCtrl->getNetworkId(iface));
 }
 
-int SecondaryTableController::modifyFromRule(int tableIndex, const char *action,
+int SecondaryTableController::modifyFromRule(unsigned netId, const char *action,
         const char *addr) {
     char tableIndex_str[11];
 
-    if (verifyTableIndex(tableIndex)) {
-        return -1;
-    }
-
-    snprintf(tableIndex_str, sizeof(tableIndex_str), "%d", tableIndex +
-            BASE_TABLE_NUMBER);
+    snprintf(tableIndex_str, sizeof(tableIndex_str), "%u", netId + BASE_TABLE_NUMBER);
     const char *cmd[] = {
             IP_PATH,
             getVersion(addr),
@@ -263,22 +206,16 @@ int SecondaryTableController::modifyFromRule(int tableIndex, const char *action,
         return -1;
     }
 
-    modifyRuleCount(tableIndex, action);
+    modifyRuleCount(netId, action);
     return 0;
 }
 
-int SecondaryTableController::modifyLocalRoute(int tableIndex, const char *action,
+int SecondaryTableController::modifyLocalRoute(unsigned netId, const char *action,
         const char *iface, const char *addr) {
     char tableIndex_str[11];
 
-    if (verifyTableIndex(tableIndex)) {
-        return -1;
-    }
-
-    modifyRuleCount(tableIndex, action); // some del's will fail as the iface is already gone.
-
-    snprintf(tableIndex_str, sizeof(tableIndex_str), "%d", tableIndex +
-            BASE_TABLE_NUMBER);
+    modifyRuleCount(netId, action); // some del's will fail as the iface is already gone.
+    snprintf(tableIndex_str, sizeof(tableIndex_str), "%u", netId + BASE_TABLE_NUMBER);
     const char *cmd[] = {
             IP_PATH,
             "route",
@@ -301,29 +238,17 @@ int SecondaryTableController::removeFwmarkRule(const char *iface) {
 }
 
 int SecondaryTableController::setFwmarkRule(const char *iface, bool add) {
-    int tableIndex = findTableNumber(iface);
-    if (tableIndex == -1) {
-        tableIndex = findTableNumber(""); // look for an empty slot
-        if (tableIndex == -1) {
-            ALOGE("Max number of NATed interfaces reached");
-            errno = ENODEV;
-            return -1;
-        }
-        strncpy(mInterfaceTable[tableIndex], iface, IFNAMSIZ);
-        // Ensure null termination even if truncation happened
-        mInterfaceTable[tableIndex][IFNAMSIZ] = 0;
-    }
-    int mark = tableIndex + BASE_TABLE_NUMBER;
-    char mark_str[11];
-    int ret;
+    unsigned netId = mNetCtrl->getNetworkId(iface);
 
-    //fail fast if any rules already exist for this interface
-    if (mUidMarkMap->anyRulesForMark(mark)) {
+    // Fail fast if any rules already exist for this interface
+    if (mNetIdRuleCount.count(netId) > 0) {
         errno = EBUSY;
         return -1;
     }
 
-    snprintf(mark_str, sizeof(mark_str), "%d", mark);
+    int ret;
+    char mark_str[11];
+    snprintf(mark_str, sizeof(mark_str), "%u", netId + BASE_TABLE_NUMBER);
     //add the catch all route to the tun. Route rules will make sure the right packets hit the table
     const char *route_cmd[] = {
         IP_PATH,
@@ -510,22 +435,17 @@ int SecondaryTableController::addFwmarkRoute(const char* iface, const char *dest
 }
 
 int SecondaryTableController::removeFwmarkRoute(const char* iface, const char *dest, int prefix) {
-    return setFwmarkRoute(iface, dest, prefix, true);
+    return setFwmarkRoute(iface, dest, prefix, false);
 }
 
 int SecondaryTableController::setFwmarkRoute(const char* iface, const char *dest, int prefix,
                                              bool add) {
-    int tableIndex = findTableNumber(iface);
-    if (tableIndex == -1) {
-        errno = EINVAL;
-        return -1;
-    }
-    int mark = tableIndex + BASE_TABLE_NUMBER;
+    unsigned netId = mNetCtrl->getNetworkId(iface);
     char mark_str[11] = {0};
     char chain_str[IFNAMSIZ + 18];
     char dest_str[44]; // enough to store an IPv6 address + 3 character bitmask
 
-    snprintf(mark_str, sizeof(mark_str), "%d", mark);
+    snprintf(mark_str, sizeof(mark_str), "%u", netId + BASE_TABLE_NUMBER);
     snprintf(chain_str, sizeof(chain_str), LOCAL_MANGLE_IFACE_FORMAT, iface);
     snprintf(dest_str, sizeof(dest_str), "%s/%d", dest, prefix);
     return execIptables(getIptablesTarget(dest),
@@ -551,23 +471,12 @@ int SecondaryTableController::removeUidRule(const char *iface, int uid_start, in
 }
 
 int SecondaryTableController::setUidRule(const char *iface, int uid_start, int uid_end, bool add) {
-    int tableIndex = findTableNumber(iface);
-    if (tableIndex == -1) {
+    unsigned netId = mNetCtrl->getNetworkId(iface);
+    if (!mNetCtrl->setNetworkForUidRange(uid_start, uid_end, add ? netId : 0, false)) {
         errno = EINVAL;
         return -1;
     }
-    int mark = tableIndex + BASE_TABLE_NUMBER;
-    if (add) {
-        if (!mUidMarkMap->add(uid_start, uid_end, mark)) {
-            errno = EINVAL;
-            return -1;
-        }
-    } else {
-        if (!mUidMarkMap->remove(uid_start, uid_end, mark)) {
-            errno = EINVAL;
-            return -1;
-        }
-    }
+
     char uid_str[24] = {0};
     char chain_str[IFNAMSIZ + 18];
     snprintf(uid_str, sizeof(uid_str), "%d-%d", uid_start, uid_end);
@@ -627,9 +536,10 @@ int SecondaryTableController::setHostExemption(const char *host, bool add) {
 }
 
 void SecondaryTableController::getUidMark(SocketClient *cli, int uid) {
-    int mark = mUidMarkMap->getMark(uid);
+    unsigned netId = mNetCtrl->getNetwork(uid, NETID_UNSET, NetworkController::PID_UNSPECIFIED,
+            false);
     char mark_str[11];
-    snprintf(mark_str, sizeof(mark_str), "%d", mark);
+    snprintf(mark_str, sizeof(mark_str), "%u", netId + BASE_TABLE_NUMBER);
     cli->sendMsg(ResponseCode::GetMarkResult, mark_str, false);
 }
 
index 81bb863..091f95e 100644 (file)
 #ifndef _SECONDARY_TABLE_CONTROLLER_H
 #define _SECONDARY_TABLE_CONTROLLER_H
 
+#include <map>
+
 #include <sysutils/FrameworkListener.h>
 
 #include <net/if.h>
-#include "UidMarkMap.h"
 #include "NetdConstants.h"
+#include "NetworkController.h"
 
 #ifndef IFNAMSIZ
 #define IFNAMSIZ 16
 #endif
 
-static const int INTERFACES_TRACKED = 10;
 static const int BASE_TABLE_NUMBER = 60;
-static int MAX_TABLE_NUMBER = BASE_TABLE_NUMBER + INTERFACES_TRACKED;
 static const int PROTECT_MARK = 0x1;
 static const char *EXEMPT_PRIO = "99";
 static const char *RULE_PRIO = "100";
 
+// SecondaryTableController is responsible for maintaining the "secondary" routing tables, where
+// "secondary" means not the main table.  The "secondary" tables are used for VPNs.
 class SecondaryTableController {
 
 public:
-    SecondaryTableController(UidMarkMap *map);
+    SecondaryTableController(NetworkController* controller);
     virtual ~SecondaryTableController();
 
+    // Add/remove a particular route in a particular interface's table.
     int addRoute(SocketClient *cli, char *iface, char *dest, int prefixLen, char *gateway);
     int removeRoute(SocketClient *cli, char *iface, char *dest, int prefixLen, char *gateway);
-    int findTableNumber(const char *iface);
-    int modifyFromRule(int tableIndex, const char *action, const char *addr);
-    int modifyLocalRoute(int tableIndex, const char *action, const char *iface, const char *addr);
+
+    int modifyFromRule(unsigned netId, const char *action, const char *addr);
+    int modifyLocalRoute(unsigned netId, const char *action, const char *iface, const char *addr);
+
+    // Add/remove rules to force packets in a particular range of UIDs over a particular interface.
+    // This is accomplished with a rule specifying these UIDs use the interface's routing chain.
     int addUidRule(const char *iface, int uid_start, int uid_end);
     int removeUidRule(const char *iface, int uid_start, int uid_end);
+
+    // Add/remove rules and chains so packets intended for a particular interface use that
+    // interface.
     int addFwmarkRule(const char *iface);
     int removeFwmarkRule(const char *iface);
+
+    // Add/remove rules so packets going to a particular range of IPs use a particular interface.
+    // This is accomplished by adding/removeing a rule to/from an interface’s chain to mark packets
+    // destined for the IP address range with the mark for the interface’s table.
     int addFwmarkRoute(const char* iface, const char *dest, int prefix);
     int removeFwmarkRoute(const char* iface, const char *dest, int prefix);
+
+    // Add/remove rules so packets going to a particular IP address use the main table (i.e. not
+    // the VPN tables).  This is used in conjunction with adding a specific route to the main
+    // table.  This is to support requestRouteToHost().
+    // This is accomplished by marking these packets with the protect mark and adding a rule to
+    // use the main table.
     int addHostExemption(const char *host);
     int removeHostExemption(const char *host);
+
     void getUidMark(SocketClient *cli, int uid);
     void getProtectMark(SocketClient *cli);
 
@@ -66,19 +86,17 @@ public:
 
 
 private:
-    UidMarkMap *mUidMarkMap;
+    NetworkController *mNetCtrl;
 
     int setUidRule(const char* iface, int uid_start, int uid_end, bool add);
     int setFwmarkRule(const char *iface, bool add);
     int setFwmarkRoute(const char* iface, const char *dest, int prefix, bool add);
     int setHostExemption(const char *host, bool add);
     int modifyRoute(SocketClient *cli, const char *action, char *iface, char *dest, int prefix,
-            char *gateway, int tableIndex);
+            char *gateway, unsigned netId);
 
-    char mInterfaceTable[INTERFACES_TRACKED][IFNAMSIZ + 1];
-    int mInterfaceRuleCount[INTERFACES_TRACKED];
-    void modifyRuleCount(int tableIndex, const char *action);
-    int verifyTableIndex(int tableIndex);
+    std::map<unsigned, int> mNetIdRuleCount;
+    void modifyRuleCount(unsigned netId, const char *action);
     const char *getVersion(const char *addr);
     IptablesTarget getIptablesTarget(const char *addr);
 
diff --git a/UidMarkMap.cpp b/UidMarkMap.cpp
deleted file mode 100644 (file)
index d30ac53..0000000
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Copyright (C) 2013 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "UidMarkMap.h"
-
-UidMarkMap::UidMarkEntry::UidMarkEntry(int start, int end, int new_mark) :
-                                            uid_start(start),
-                                            uid_end(end),
-                                            mark(new_mark) {
-};
-
-bool UidMarkMap::add(int uid_start, int uid_end, int mark) {
-    android::RWLock::AutoWLock lock(mRWLock);
-    if (uid_start > uid_end) {
-        return false;
-    }
-    android::netd::List<UidMarkEntry*>::iterator it;
-    for (it = mMap.begin(); it != mMap.end(); it++) {
-        UidMarkEntry *entry = *it;
-        if (entry->uid_start <= uid_end && uid_start <= entry->uid_end) {
-            return false;
-        }
-    }
-
-    UidMarkEntry *e = new UidMarkEntry(uid_start, uid_end, mark);
-    mMap.push_back(e);
-    return true;
-};
-
-bool UidMarkMap::remove(int uid_start, int uid_end, int mark) {
-    android::RWLock::AutoWLock lock(mRWLock);
-    android::netd::List<UidMarkEntry*>::iterator it;
-    for (it = mMap.begin(); it != mMap.end(); it++) {
-        UidMarkEntry *entry = *it;
-        if (entry->uid_start == uid_start && entry->uid_end == uid_end && entry->mark == mark) {
-            mMap.erase(it);
-            delete entry;
-            return true;
-        }
-    }
-    return false;
-};
-
-int UidMarkMap::getMark(int uid) {
-    android::RWLock::AutoRLock lock(mRWLock);
-    android::netd::List<UidMarkEntry*>::iterator it;
-    for (it = mMap.begin(); it != mMap.end(); it++) {
-        UidMarkEntry *entry = *it;
-        if (entry->uid_start <= uid && entry->uid_end >= uid) {
-            return entry->mark;
-        }
-    }
-    return -1;
-};
-
-bool UidMarkMap::anyRulesForMark(int mark) {
-    android::RWLock::AutoRLock lock(mRWLock);
-    android::netd::List<UidMarkEntry*>::iterator it;
-    for (it = mMap.begin(); it != mMap.end(); it++) {
-        UidMarkEntry *entry = *it;
-        if (entry->mark == mark) {
-            return true;
-        }
-    }
-    return false;
-}
diff --git a/UidMarkMap.h b/UidMarkMap.h
deleted file mode 100644 (file)
index 43881be..0000000
+++ /dev/null
@@ -1,43 +0,0 @@
-/*
- * Copyright (C) 2013 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef _NETD_UIDMARKMAP_H
-#define _NETD_UIDMARKMAP_H
-
-#include <stddef.h>
-#include <stdint.h>
-#include <List.h>
-#include <utils/RWLock.h>
-
-class UidMarkMap {
-public:
-    bool add(int uid_start, int uid_end, int mark);
-    bool remove(int uid_start, int uid_end, int mark);
-    int getMark(int uid);
-    bool anyRulesForMark(int mark);
-
-private:
-    struct UidMarkEntry {
-        int uid_start;
-        int uid_end;
-        int mark;
-        UidMarkEntry(int uid_start, int uid_end, int mark);
-    };
-
-    android::RWLock mRWLock;
-    android::netd::List<UidMarkEntry*> mMap;
-};
-#endif
index 104ebe1..3a356d4 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -34,7 +34,6 @@
 #include "NetlinkManager.h"
 #include "DnsProxyListener.h"
 #include "MDnsSdListener.h"
-#include "UidMarkMap.h"
 
 static void coldboot(const char *path);
 static void sigchld_handler(int sig);
@@ -57,9 +56,7 @@ int main() {
         exit(1);
     };
 
-    UidMarkMap *rangeMap = new UidMarkMap();
-
-    cl = new CommandListener(rangeMap);
+    cl = new CommandListener();
     nm->setBroadcaster((SocketListener *) cl);
 
     if (nm->start()) {
@@ -70,7 +67,7 @@ int main() {
     // Set local DNS mode, to prevent bionic from proxying
     // back to this service, recursively.
     setenv("ANDROID_DNS_MODE", "local", 1);
-    dpl = new DnsProxyListener(rangeMap);
+    dpl = new DnsProxyListener(CommandListener::sNetCtrl);
     if (dpl->startListener()) {
         ALOGE("Unable to start DnsProxyListener (%s)", strerror(errno));
         exit(1);