OSDN Git Service

Use proper socket mark for DNS resolution.
authorPaul Jensen <pauljensen@google.com>
Thu, 17 Apr 2014 21:25:43 +0000 (17:25 -0400)
committerLorenzo Colitti <lorenzo@google.com>
Wed, 14 May 2014 09:42:00 +0000 (02:42 -0700)
Change-Id: I531ff0fbce6e7172b82bc2d4c7406a324603eb8a

CommandListener.h
DnsProxyListener.cpp
DnsProxyListener.h
main.cpp

index ad24d87..ada79eb 100644 (file)
@@ -48,11 +48,11 @@ class CommandListener : public FrameworkListener {
     static SecondaryTableController *sSecondaryTableCtrl;
     static FirewallController *sFirewallCtrl;
     static ClatdController *sClatdCtrl;
-    static PermissionsController* sPermissionsController;
     static RouteController* sRouteController;
 
 public:
     static NetworkController *sNetCtrl;
+    static PermissionsController* sPermissionsController;
 
     CommandListener();
     virtual ~CommandListener() {}
index 9f97d29..7379788 100644 (file)
 #include <cutils/log.h>
 #include <sysutils/SocketClient.h>
 
-#include "NetdConstants.h"
+#include "Fwmark.h"
 #include "DnsProxyListener.h"
+#include "NetdConstants.h"
+#include "NetworkController.h"
+#include "PermissionsController.h"
 #include "ResponseCode.h"
 
-DnsProxyListener::DnsProxyListener(const NetworkController* controller) :
+DnsProxyListener::DnsProxyListener(const NetworkController* netCtrl,
+        const PermissionsController* permCtrl) :
                  FrameworkListener("dnsproxyd"),
-                 mNetCtrl(controller) {
-    registerCmd(new GetAddrInfoCmd(controller));
-    registerCmd(new GetHostByAddrCmd(controller));
-    registerCmd(new GetHostByNameCmd(controller));
+                 mNetCtrl(netCtrl),
+                 mPermCtrl(permCtrl) {
+    registerCmd(new GetAddrInfoCmd(this));
+    registerCmd(new GetHostByAddrCmd(this));
+    registerCmd(new GetHostByNameCmd(this));
+}
+
+uint32_t DnsProxyListener::calcMark(SocketClient *c, unsigned netId) const {
+    // If netd's UID is forced into a VPN that isn't the intended network,
+    // use VPN protect bit to force it into the desired network.
+    bool vpnProtect = mNetCtrl->getNetwork(getuid(), netId, true) != netId;
+    return getFwmark(netId, false, vpnProtect, mPermCtrl->getPermissionForUser(c->getUid()));
 }
 
 DnsProxyListener::GetAddrInfoHandler::GetAddrInfoHandler(SocketClient *c,
                                                          char* host,
                                                          char* service,
                                                          struct addrinfo* hints,
-                                                         unsigned netId)
+                                                         unsigned netId,
+                                                         uint32_t mark)
         : mClient(c),
           mHost(host),
           mService(service),
           mHints(hints),
-          mNetId(netId) {
+          mNetId(netId),
+          mMark(mark) {
 }
 
 DnsProxyListener::GetAddrInfoHandler::~GetAddrInfoHandler() {
@@ -118,11 +132,11 @@ static bool sendhostent(SocketClient *c, struct hostent *hp) {
 
 void DnsProxyListener::GetAddrInfoHandler::run() {
     if (DBG) {
-        ALOGD("GetAddrInfoHandler, now for %s / %s / %u", mHost, mService, mNetId);
+        ALOGD("GetAddrInfoHandler, now for %s / %s / %u / %u", mHost, mService, mNetId, mMark);
     }
 
     struct addrinfo* result = NULL;
-    uint32_t rv = android_getaddrinfofornet(mHost, mService, mHints, mNetId, 0, &result);
+    uint32_t rv = android_getaddrinfofornet(mHost, mService, mHints, mNetId, mMark, &result);
     if (rv) {
         // getaddrinfo failed
         mClient->sendBinaryMsg(ResponseCode::DnsProxyOperationFailed, &rv, sizeof(rv));
@@ -148,9 +162,9 @@ void DnsProxyListener::GetAddrInfoHandler::run() {
     mClient->decRef();
 }
 
-DnsProxyListener::GetAddrInfoCmd::GetAddrInfoCmd(const NetworkController* controller) :
+DnsProxyListener::GetAddrInfoCmd::GetAddrInfoCmd(const DnsProxyListener* dnsProxyListener) :
     NetdCommand("getaddrinfo"),
-    mNetCtrl(controller) {
+    mDnsProxyListener(dnsProxyListener) {
 }
 
 int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
@@ -191,7 +205,8 @@ int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
     unsigned netId = strtoul(argv[7], NULL, 10);
     uid_t uid = cli->getUid();
 
-    netId = mNetCtrl->getNetwork(uid, netId, true);
+    netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true);
+    uint32_t mark = mDnsProxyListener->calcMark(cli, netId);
 
     if (ai_flags != -1 || ai_family != -1 ||
         ai_socktype != -1 || ai_protocol != -1) {
@@ -203,21 +218,22 @@ int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
 
         // Only implement AI_ADDRCONFIG if application is using default network since our
         // implementation only works on the default network.
-        if ((hints->ai_flags & AI_ADDRCONFIG) && netId != mNetCtrl->getDefaultNetwork()) {
+        if ((hints->ai_flags & AI_ADDRCONFIG) &&
+                netId != mDnsProxyListener->mNetCtrl->getDefaultNetwork()) {
             hints->ai_flags &= ~AI_ADDRCONFIG;
         }
     }
 
     if (DBG) {
-        ALOGD("GetAddrInfoHandler for %s / %s / %u / %d",
+        ALOGD("GetAddrInfoHandler for %s / %s / %u / %d / %u",
              name ? name : "[nullhost]",
              service ? service : "[nullservice]",
-             netId, uid);
+             netId, uid, mark);
     }
 
     cli->incRef();
     DnsProxyListener::GetAddrInfoHandler* handler =
-            new DnsProxyListener::GetAddrInfoHandler(cli, name, service, hints, netId);
+            new DnsProxyListener::GetAddrInfoHandler(cli, name, service, hints, netId, mark);
     handler->start();
 
     return 0;
@@ -226,9 +242,9 @@ int DnsProxyListener::GetAddrInfoCmd::runCommand(SocketClient *cli,
 /*******************************************************
  *                  GetHostByName                      *
  *******************************************************/
-DnsProxyListener::GetHostByNameCmd::GetHostByNameCmd(const NetworkController* controller) :
+DnsProxyListener::GetHostByNameCmd::GetHostByNameCmd(const DnsProxyListener* dnsProxyListener) :
       NetdCommand("gethostbyname"),
-      mNetCtrl(controller) {
+      mDnsProxyListener(dnsProxyListener) {
 }
 
 int DnsProxyListener::GetHostByNameCmd::runCommand(SocketClient *cli,
@@ -258,11 +274,12 @@ int DnsProxyListener::GetHostByNameCmd::runCommand(SocketClient *cli,
         name = strdup(name);
     }
 
-    netId = mNetCtrl->getNetwork(uid, netId, true);
+    netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true);
+    uint32_t mark = mDnsProxyListener->calcMark(cli, netId);
 
     cli->incRef();
     DnsProxyListener::GetHostByNameHandler* handler =
-            new DnsProxyListener::GetHostByNameHandler(cli, name, af, netId);
+            new DnsProxyListener::GetHostByNameHandler(cli, name, af, netId, mark);
     handler->start();
 
     return 0;
@@ -271,11 +288,13 @@ int DnsProxyListener::GetHostByNameCmd::runCommand(SocketClient *cli,
 DnsProxyListener::GetHostByNameHandler::GetHostByNameHandler(SocketClient* c,
                                                              char* name,
                                                              int af,
-                                                             unsigned netId)
+                                                             unsigned netId,
+                                                             uint32_t mark)
         : mClient(c),
           mName(name),
           mAf(af),
-          mNetId(netId) {
+          mNetId(netId),
+          mMark(mark) {
 }
 
 DnsProxyListener::GetHostByNameHandler::~GetHostByNameHandler() {
@@ -304,7 +323,7 @@ void DnsProxyListener::GetHostByNameHandler::run() {
 
     struct hostent* hp;
 
-    hp = android_gethostbynamefornet(mName, mAf, mNetId, 0);
+    hp = android_gethostbynamefornet(mName, mAf, mNetId, mMark);
 
     if (DBG) {
         ALOGD("GetHostByNameHandler::run gethostbyname errno: %s hp->h_name = %s, name_len = %zu\n",
@@ -331,9 +350,9 @@ void DnsProxyListener::GetHostByNameHandler::run() {
 /*******************************************************
  *                  GetHostByAddr                      *
  *******************************************************/
-DnsProxyListener::GetHostByAddrCmd::GetHostByAddrCmd(const NetworkController* controller) :
+DnsProxyListener::GetHostByAddrCmd::GetHostByAddrCmd(const DnsProxyListener* dnsProxyListener) :
         NetdCommand("gethostbyaddr"),
-        mNetCtrl(controller) {
+        mDnsProxyListener(dnsProxyListener) {
 }
 
 int DnsProxyListener::GetHostByAddrCmd::runCommand(SocketClient *cli,
@@ -371,11 +390,12 @@ int DnsProxyListener::GetHostByAddrCmd::runCommand(SocketClient *cli,
         return -1;
     }
 
-    netId = mNetCtrl->getNetwork(uid, netId, true);
+    netId = mDnsProxyListener->mNetCtrl->getNetwork(uid, netId, true);
+    uint32_t mark = mDnsProxyListener->calcMark(cli, netId);
 
     cli->incRef();
     DnsProxyListener::GetHostByAddrHandler* handler =
-            new DnsProxyListener::GetHostByAddrHandler(cli, addr, addrLen, addrFamily, netId);
+            new DnsProxyListener::GetHostByAddrHandler(cli, addr, addrLen, addrFamily, netId, mark);
     handler->start();
 
     return 0;
@@ -385,12 +405,14 @@ DnsProxyListener::GetHostByAddrHandler::GetHostByAddrHandler(SocketClient* c,
                                                              void* address,
                                                              int   addressLen,
                                                              int   addressFamily,
-                                                             unsigned netId)
+                                                             unsigned netId,
+                                                             uint32_t mark)
         : mClient(c),
           mAddress(address),
           mAddressLen(addressLen),
           mAddressFamily(addressFamily),
-          mNetId(netId) {
+          mNetId(netId),
+          mMark(mark) {
 }
 
 DnsProxyListener::GetHostByAddrHandler::~GetHostByAddrHandler() {
@@ -419,7 +441,7 @@ void DnsProxyListener::GetHostByAddrHandler::run() {
     struct hostent* hp;
 
     // NOTE gethostbyaddr should take a void* but bionic thinks it should be char*
-    hp = android_gethostbyaddrfornet((char*)mAddress, mAddressLen, mAddressFamily, mNetId, 0);
+    hp = android_gethostbyaddrfornet((char*)mAddress, mAddressLen, mAddressFamily, mNetId, mMark);
 
     if (DBG) {
         ALOGD("GetHostByAddrHandler::run gethostbyaddr errno: %s hp->h_name = %s, name_len = %zu\n",
index 345928f..936d6aa 100644 (file)
 #include <sysutils/FrameworkListener.h>
 
 #include "NetdCommand.h"
-#include "NetworkController.h"
+
+class NetworkController;
+class PermissionsController;
 
 class DnsProxyListener : public FrameworkListener {
 public:
-    DnsProxyListener(const NetworkController* controller);
+    DnsProxyListener(const NetworkController* netCtrl, const PermissionsController* permCtrl);
     virtual ~DnsProxyListener() {}
 
 private:
     const NetworkController *mNetCtrl;
+    const PermissionsController *mPermCtrl;
     class GetAddrInfoCmd : public NetdCommand {
     public:
-        GetAddrInfoCmd(const NetworkController* controller);
+        GetAddrInfoCmd(const DnsProxyListener* dnsProxyListener);
         virtual ~GetAddrInfoCmd() {}
         int runCommand(SocketClient *c, int argc, char** argv);
     private:
-        const NetworkController* mNetCtrl;
+        const DnsProxyListener* mDnsProxyListener;
     };
 
     class GetAddrInfoHandler {
@@ -45,7 +48,8 @@ private:
                            char* host,
                            char* service,
                            struct addrinfo* hints,
-                           unsigned netId);
+                           unsigned netId,
+                           uint32_t mark);
         ~GetAddrInfoHandler();
 
         static void* threadStart(void* handler);
@@ -58,16 +62,17 @@ private:
         char* mService; // owned
         struct addrinfo* mHints;  // owned
         unsigned mNetId;
+        uint32_t mMark;
     };
 
     /* ------ gethostbyname ------*/
     class GetHostByNameCmd : public NetdCommand {
     public:
-        GetHostByNameCmd(const NetworkController* controller);
+        GetHostByNameCmd(const DnsProxyListener* dnsProxyListener);
         virtual ~GetHostByNameCmd() {}
         int runCommand(SocketClient *c, int argc, char** argv);
     private:
-        const NetworkController* mNetCtrl;
+        const DnsProxyListener* mDnsProxyListener;
     };
 
     class GetHostByNameHandler {
@@ -75,7 +80,8 @@ private:
         GetHostByNameHandler(SocketClient *c,
                             char *name,
                             int af,
-                            unsigned netId);
+                            unsigned netId,
+                            uint32_t mark);
         ~GetHostByNameHandler();
         static void* threadStart(void* handler);
         void start();
@@ -85,16 +91,17 @@ private:
         char* mName; // owned
         int mAf;
         unsigned mNetId;
+        uint32_t mMark;
     };
 
     /* ------ gethostbyaddr ------*/
     class GetHostByAddrCmd : public NetdCommand {
     public:
-        GetHostByAddrCmd(const NetworkController* controller);
+        GetHostByAddrCmd(const DnsProxyListener* dnsProxyListener);
         virtual ~GetHostByAddrCmd() {}
         int runCommand(SocketClient *c, int argc, char** argv);
     private:
-        const NetworkController* mNetCtrl;
+        const DnsProxyListener* mDnsProxyListener;
     };
 
     class GetHostByAddrHandler {
@@ -103,7 +110,8 @@ private:
                             void* address,
                             int addressLen,
                             int addressFamily,
-                            unsigned netId);
+                            unsigned netId,
+                            uint32_t mark);
         ~GetHostByAddrHandler();
 
         static void* threadStart(void* handler);
@@ -116,7 +124,11 @@ private:
         int mAddressLen; // length of address to look up
         int mAddressFamily;  // address family
         unsigned mNetId;
+        uint32_t mMark;
     };
+
+    // Calculate the socket mark to use for a DNS resolution.
+    uint32_t calcMark(SocketClient *c, unsigned netId) const;
 };
 
 #endif
index 3a356d4..90f1e6f 100644 (file)
--- a/main.cpp
+++ b/main.cpp
@@ -67,7 +67,7 @@ int main() {
     // Set local DNS mode, to prevent bionic from proxying
     // back to this service, recursively.
     setenv("ANDROID_DNS_MODE", "local", 1);
-    dpl = new DnsProxyListener(CommandListener::sNetCtrl);
+    dpl = new DnsProxyListener(CommandListener::sNetCtrl, CommandListener::sPermissionsController);
     if (dpl->startListener()) {
         ALOGE("Unable to start DnsProxyListener (%s)", strerror(errno));
         exit(1);