OSDN Git Service

Add per user VPN support
authorChad Brubaker <cbrubaker@google.com>
Fri, 14 Jun 2013 18:16:51 +0000 (11:16 -0700)
committerChad Brubaker <cbrubaker@google.com>
Sat, 13 Jul 2013 03:51:03 +0000 (20:51 -0700)
VPNs are now per user instead of global. A VPN set by user A routes only
user A's traffic and no other user can access it.

Change-Id: Ia66463637b6bd088b05768076a1db897fe95c46c

core/java/android/net/VpnService.java
core/java/com/android/internal/net/VpnConfig.java
services/java/com/android/server/ConnectivityService.java
services/java/com/android/server/connectivity/Vpn.java
services/java/com/android/server/net/LockdownVpnTracker.java

index 733de94..d7dc7f5 100644 (file)
@@ -36,6 +36,7 @@ import java.net.Inet6Address;
 import java.net.InetAddress;
 import java.net.Socket;
 import java.util.ArrayList;
+import java.util.List;
 
 /**
  * VpnService is a base class for applications to extend and build their
@@ -253,8 +254,8 @@ public class VpnService extends Service {
     public class Builder {
 
         private final VpnConfig mConfig = new VpnConfig();
-        private final StringBuilder mAddresses = new StringBuilder();
-        private final StringBuilder mRoutes = new StringBuilder();
+        private final List<LinkAddress> mAddresses = new ArrayList<LinkAddress>();
+        private final List<RouteInfo> mRoutes = new ArrayList<RouteInfo>();
 
         public Builder() {
             mConfig.user = VpnService.this.getClass().getName();
@@ -328,9 +329,7 @@ public class VpnService extends Service {
             if (address.isAnyLocalAddress()) {
                 throw new IllegalArgumentException("Bad address");
             }
-
-            mAddresses.append(' ')
-                    .append(address.getHostAddress()).append('/').append(prefixLength);
+            mAddresses.add(new LinkAddress(address, prefixLength));
             return this;
         }
 
@@ -364,8 +363,7 @@ public class VpnService extends Service {
                     }
                 }
             }
-
-            mRoutes.append(' ').append(address.getHostAddress()).append('/').append(prefixLength);
+            mRoutes.add(new RouteInfo(new LinkAddress(address, prefixLength), null));
             return this;
         }
 
@@ -466,8 +464,8 @@ public class VpnService extends Service {
          * @see VpnService
          */
         public ParcelFileDescriptor establish() {
-            mConfig.addresses = mAddresses.toString();
-            mConfig.routes = mRoutes.toString();
+            mConfig.addresses = mAddresses;
+            mConfig.routes = mRoutes;
 
             try {
                 return getService().establishVpn(mConfig);
index 956653b..abf99a3 100644 (file)
@@ -21,10 +21,14 @@ import android.content.Context;
 import android.content.Intent;
 import android.os.Parcel;
 import android.os.Parcelable;
+import android.net.RouteInfo;
+import android.net.LinkAddress;
 
 import com.android.internal.util.Preconditions;
 
+import java.net.InetAddress;
 import java.util.List;
+import java.util.ArrayList;
 
 /**
  * A simple container used to carry information in VpnBuilder, VpnDialogs,
@@ -61,14 +65,42 @@ public class VpnConfig implements Parcelable {
     public String interfaze;
     public String session;
     public int mtu = -1;
-    public String addresses;
-    public String routes;
+    public List<LinkAddress> addresses = new ArrayList<LinkAddress>();
+    public List<RouteInfo> routes = new ArrayList<RouteInfo>();
     public List<String> dnsServers;
     public List<String> searchDomains;
     public PendingIntent configureIntent;
     public long startTime = -1;
     public boolean legacy;
 
+    public void addLegacyRoutes(String routesStr) {
+        if (routesStr.trim().equals("")) {
+            return;
+        }
+        String[] routes = routesStr.trim().split(" ");
+        for (String route : routes) {
+            //each route is ip/prefix
+            String[] split = route.split("/");
+            RouteInfo info = new RouteInfo(new LinkAddress
+                    (InetAddress.parseNumericAddress(split[0]), Integer.parseInt(split[1])), null);
+            this.routes.add(info);
+        }
+    }
+
+    public void addLegacyAddresses(String addressesStr) {
+        if (addressesStr.trim().equals("")) {
+            return;
+        }
+        String[] addresses = addressesStr.trim().split(" ");
+        for (String address : addresses) {
+            //each address is ip/prefix
+            String[] split = address.split("/");
+            LinkAddress addr = new LinkAddress(InetAddress.parseNumericAddress(split[0]),
+                    Integer.parseInt(split[1]));
+            this.addresses.add(addr);
+        }
+    }
+
     @Override
     public int describeContents() {
         return 0;
@@ -80,8 +112,8 @@ public class VpnConfig implements Parcelable {
         out.writeString(interfaze);
         out.writeString(session);
         out.writeInt(mtu);
-        out.writeString(addresses);
-        out.writeString(routes);
+        out.writeTypedList(addresses);
+        out.writeTypedList(routes);
         out.writeStringList(dnsServers);
         out.writeStringList(searchDomains);
         out.writeParcelable(configureIntent, flags);
@@ -98,8 +130,8 @@ public class VpnConfig implements Parcelable {
             config.interfaze = in.readString();
             config.session = in.readString();
             config.mtu = in.readInt();
-            config.addresses = in.readString();
-            config.routes = in.readString();
+            in.readTypedList(config.addresses, LinkAddress.CREATOR);
+            in.readTypedList(config.routes, RouteInfo.CREATOR);
             config.dnsServers = in.createStringArrayList();
             config.searchDomains = in.createStringArrayList();
             config.configureIntent = in.readParcelable(null);
index b148b91..a6344ca 100644 (file)
@@ -97,6 +97,7 @@ import android.telephony.TelephonyManager;
 import android.text.TextUtils;
 import android.util.Slog;
 import android.util.SparseIntArray;
+import android.util.SparseArray;
 
 import com.android.internal.R;
 import com.android.internal.net.LegacyVpnInfo;
@@ -116,6 +117,8 @@ import com.android.server.net.LockdownVpnTracker;
 import com.google.android.collect.Lists;
 import com.google.android.collect.Sets;
 
+import com.android.internal.annotations.GuardedBy;
+
 import dalvik.system.DexClassLoader;
 
 import java.io.FileDescriptor;
@@ -171,7 +174,8 @@ public class ConnectivityService extends IConnectivityManager.Stub {
 
     private KeyStore mKeyStore;
 
-    private Vpn mVpn;
+    @GuardedBy("mVpns")
+    private final SparseArray<Vpn> mVpns = new SparseArray<Vpn>();
     private VpnCallback mVpnCallback = new VpnCallback();
 
     private boolean mLockdownEnabled;
@@ -583,10 +587,13 @@ public class ConnectivityService extends IConnectivityManager.Stub {
                                   mTethering.getTetherableWifiRegexs().length != 0 ||
                                   mTethering.getTetherableBluetoothRegexs().length != 0) &&
                                  mTethering.getUpstreamIfaceTypes().length != 0);
+        //set up the listener for user state for creating user VPNs
 
-        mVpn = new Vpn(mContext, mVpnCallback, mNetd, this);
-        mVpn.startMonitoring(mContext, mTrackerHandler);
-
+        IntentFilter intentFilter = new IntentFilter();
+        intentFilter.addAction(Intent.ACTION_USER_STARTING);
+        intentFilter.addAction(Intent.ACTION_USER_STOPPING);
+        mContext.registerReceiverAsUser(
+                mUserIntentReceiver, UserHandle.ALL, intentFilter, null, null);
         mClat = new Nat464Xlat(mContext, mNetd, this, mTrackerHandler);
 
         try {
@@ -2313,7 +2320,11 @@ public class ConnectivityService extends IConnectivityManager.Stub {
                             // Tell VPN the interface is down. It is a temporary
                             // but effective fix to make VPN aware of the change.
                             if ((resetMask & NetworkUtils.RESET_IPV4_ADDRESSES) != 0) {
-                                mVpn.interfaceStatusChanged(iface, false);
+                                synchronized(mVpns) {
+                                    for (int i = 0; i < mVpns.size(); i++) {
+                                        mVpns.valueAt(i).interfaceStatusChanged(iface, false);
+                                    }
+                                }
                             }
                         }
                         if (resetDns) {
@@ -2570,7 +2581,6 @@ public class ConnectivityService extends IConnectivityManager.Stub {
 
         try {
             mNetd.setDnsServersForInterface(iface, NetworkUtils.makeStrings(dnses), domains);
-            mNetd.setDefaultInterfaceForDns(iface);
             for (InetAddress dns : dnses) {
                 ++last;
                 String key = "net.dns" + last;
@@ -3305,8 +3315,12 @@ public class ConnectivityService extends IConnectivityManager.Stub {
         throwIfLockdownEnabled();
         try {
             int type = mActiveDefaultNetwork;
+            int user = UserHandle.getUserId(Binder.getCallingUid());
             if (ConnectivityManager.isNetworkTypeValid(type) && mNetTrackers[type] != null) {
-                mVpn.protect(socket, mNetTrackers[type].getLinkProperties().getInterfaceName());
+                synchronized(mVpns) {
+                    mVpns.get(user).protect(socket,
+                            mNetTrackers[type].getLinkProperties().getInterfaceName());
+                }
                 return true;
             }
         } catch (Exception e) {
@@ -3330,7 +3344,10 @@ public class ConnectivityService extends IConnectivityManager.Stub {
     @Override
     public boolean prepareVpn(String oldPackage, String newPackage) {
         throwIfLockdownEnabled();
-        return mVpn.prepare(oldPackage, newPackage);
+        int user = UserHandle.getUserId(Binder.getCallingUid());
+        synchronized(mVpns) {
+            return mVpns.get(user).prepare(oldPackage, newPackage);
+        }
     }
 
     /**
@@ -3343,7 +3360,10 @@ public class ConnectivityService extends IConnectivityManager.Stub {
     @Override
     public ParcelFileDescriptor establishVpn(VpnConfig config) {
         throwIfLockdownEnabled();
-        return mVpn.establish(config);
+        int user = UserHandle.getUserId(Binder.getCallingUid());
+        synchronized(mVpns) {
+            return mVpns.get(user).establish(config);
+        }
     }
 
     /**
@@ -3357,7 +3377,10 @@ public class ConnectivityService extends IConnectivityManager.Stub {
         if (egress == null) {
             throw new IllegalStateException("Missing active network connection");
         }
-        mVpn.startLegacyVpn(profile, mKeyStore, egress);
+        int user = UserHandle.getUserId(Binder.getCallingUid());
+        synchronized(mVpns) {
+            mVpns.get(user).startLegacyVpn(profile, mKeyStore, egress);
+        }
     }
 
     /**
@@ -3369,7 +3392,10 @@ public class ConnectivityService extends IConnectivityManager.Stub {
     @Override
     public LegacyVpnInfo getLegacyVpnInfo() {
         throwIfLockdownEnabled();
-        return mVpn.getLegacyVpnInfo();
+        int user = UserHandle.getUserId(Binder.getCallingUid());
+        synchronized(mVpns) {
+            return mVpns.get(user).getLegacyVpnInfo();
+        }
     }
 
     /**
@@ -3390,7 +3416,7 @@ public class ConnectivityService extends IConnectivityManager.Stub {
             mHandler.obtainMessage(EVENT_VPN_STATE_CHANGED, info).sendToTarget();
         }
 
-        public void override(List<String> dnsServers, List<String> searchDomains) {
+        public void override(String iface, List<String> dnsServers, List<String> searchDomains) {
             if (dnsServers == null) {
                 restore();
                 return;
@@ -3422,7 +3448,7 @@ public class ConnectivityService extends IConnectivityManager.Stub {
 
             // Apply DNS changes.
             synchronized (mDnsLock) {
-                updateDnsLocked("VPN", "VPN", addresses, domains);
+                updateDnsLocked("VPN", iface, addresses, domains);
                 mDnsOverridden = true;
             }
 
@@ -3451,6 +3477,67 @@ public class ConnectivityService extends IConnectivityManager.Stub {
                 }
             }
         }
+
+        public void protect(ParcelFileDescriptor socket) {
+            try {
+                final int mark = mNetd.getMarkForProtect();
+                NetworkUtils.markSocket(socket.getFd(), mark);
+            } catch (RemoteException e) {
+            }
+        }
+
+        public void setRoutes(String interfaze, List<RouteInfo> routes) {
+            for (RouteInfo route : routes) {
+                try {
+                    mNetd.setMarkedForwardingRoute(interfaze, route);
+                } catch (RemoteException e) {
+                }
+            }
+        }
+
+        public void setMarkedForwarding(String interfaze) {
+            try {
+                mNetd.setMarkedForwarding(interfaze);
+            } catch (RemoteException e) {
+            }
+        }
+
+        public void clearMarkedForwarding(String interfaze) {
+            try {
+                mNetd.clearMarkedForwarding(interfaze);
+            } catch (RemoteException e) {
+            }
+        }
+
+        public void addUserForwarding(String interfaze, int uid) {
+            int uidStart = uid * UserHandle.PER_USER_RANGE;
+            int uidEnd = uidStart + UserHandle.PER_USER_RANGE - 1;
+            addUidForwarding(interfaze, uidStart, uidEnd);
+        }
+
+        public void clearUserForwarding(String interfaze, int uid) {
+            int uidStart = uid * UserHandle.PER_USER_RANGE;
+            int uidEnd = uidStart + UserHandle.PER_USER_RANGE - 1;
+            clearUidForwarding(interfaze, uidStart, uidEnd);
+        }
+
+        public void addUidForwarding(String interfaze, int uidStart, int uidEnd) {
+            try {
+                mNetd.setUidRangeRoute(interfaze,uidStart, uidEnd);
+                mNetd.setDnsInterfaceForUidRange(interfaze, uidStart, uidEnd);
+            } catch (RemoteException e) {
+            }
+
+        }
+
+        public void clearUidForwarding(String interfaze, int uidStart, int uidEnd) {
+            try {
+                mNetd.clearUidRangeRoute(interfaze, uidStart, uidEnd);
+                mNetd.clearDnsInterfaceForUidRange(uidStart, uidEnd);
+            } catch (RemoteException e) {
+            }
+
+        }
     }
 
     @Override
@@ -3471,7 +3558,11 @@ public class ConnectivityService extends IConnectivityManager.Stub {
             final String profileName = new String(mKeyStore.get(Credentials.LOCKDOWN_VPN));
             final VpnProfile profile = VpnProfile.decode(
                     profileName, mKeyStore.get(Credentials.VPN + profileName));
-            setLockdownTracker(new LockdownVpnTracker(mContext, mNetd, this, mVpn, profile));
+            int user = UserHandle.getUserId(Binder.getCallingUid());
+            synchronized(mVpns) {
+                setLockdownTracker(new LockdownVpnTracker(mContext, mNetd, this, mVpns.get(user),
+                            profile));
+            }
         } else {
             setLockdownTracker(null);
         }
@@ -4002,4 +4093,43 @@ public class ConnectivityService extends IConnectivityManager.Stub {
 
         return url;
     }
+
+    private void onUserStart(int userId) {
+        synchronized(mVpns) {
+            Vpn userVpn = mVpns.get(userId);
+            if (userVpn != null) {
+                loge("Starting user already has a VPN");
+                return;
+            }
+            userVpn = new Vpn(mContext, mVpnCallback, mNetd, this, userId);
+            mVpns.put(userId, userVpn);
+            userVpn.startMonitoring(mContext, mTrackerHandler);
+        }
+    }
+
+    private void onUserStop(int userId) {
+        synchronized(mVpns) {
+            Vpn userVpn = mVpns.get(userId);
+            if (userVpn == null) {
+                loge("Stopping user has no VPN");
+                return;
+            }
+            mVpns.delete(userId);
+        }
+    }
+
+    private BroadcastReceiver mUserIntentReceiver = new BroadcastReceiver() {
+        @Override
+        public void onReceive(Context context, Intent intent) {
+            final String action = intent.getAction();
+            final int userId = intent.getIntExtra(Intent.EXTRA_USER_HANDLE, UserHandle.USER_NULL);
+            if (userId == UserHandle.USER_NULL) return;
+
+            if (Intent.ACTION_USER_STARTING.equals(action)) {
+                onUserStart(userId);
+            } else if (Intent.ACTION_USER_STOPPING.equals(action)) {
+                onUserStop(userId);
+            }
+        }
+    };
 }
index 63d3958..efa1c8e 100644 (file)
@@ -18,6 +18,7 @@ package com.android.server.connectivity;
 
 import static android.Manifest.permission.BIND_VPN_SERVICE;
 
+import android.app.AppGlobals;
 import android.app.Notification;
 import android.app.NotificationManager;
 import android.app.PendingIntent;
@@ -28,6 +29,7 @@ import android.content.Intent;
 import android.content.IntentFilter;
 import android.content.ServiceConnection;
 import android.content.pm.ApplicationInfo;
+import android.content.pm.IPackageManager;
 import android.content.pm.PackageManager;
 import android.content.pm.ResolveInfo;
 import android.graphics.Bitmap;
@@ -37,6 +39,7 @@ import android.net.BaseNetworkStateTracker;
 import android.net.ConnectivityManager;
 import android.net.IConnectivityManager;
 import android.net.INetworkManagementEventObserver;
+import android.net.LinkAddress;
 import android.net.LinkProperties;
 import android.net.LocalSocket;
 import android.net.LocalSocketAddress;
@@ -54,6 +57,7 @@ import android.os.RemoteException;
 import android.os.SystemClock;
 import android.os.SystemService;
 import android.os.UserHandle;
+import android.os.RemoteException;
 import android.security.Credentials;
 import android.security.KeyStore;
 import android.util.Log;
@@ -74,6 +78,7 @@ import java.net.Inet4Address;
 import java.net.InetAddress;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
+import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import libcore.io.IoUtils;
@@ -98,14 +103,16 @@ public class Vpn extends BaseNetworkStateTracker {
     private volatile boolean mEnableNotif = true;
     private volatile boolean mEnableTeardown = true;
     private final IConnectivityManager mConnService;
+    private final int mUserId;
 
     public Vpn(Context context, VpnCallback callback, INetworkManagementService netService,
-            IConnectivityManager connService) {
+            IConnectivityManager connService, int userId) {
         // TODO: create dedicated TYPE_VPN network type
         super(ConnectivityManager.TYPE_DUMMY);
         mContext = context;
         mCallback = callback;
         mConnService = connService;
+        mUserId = userId;
 
         try {
             netService.registerObserver(mObserver);
@@ -197,14 +204,17 @@ public class Vpn extends BaseNetworkStateTracker {
 
         // Reset the interface and hide the notification.
         if (mInterface != null) {
-            jniReset(mInterface);
             final long token = Binder.clearCallingIdentity();
             try {
                 mCallback.restore();
+                mCallback.clearUserForwarding(mInterface, mUserId);
+
+                mCallback.clearMarkedForwarding(mInterface);
                 hideNotification();
             } finally {
                 Binder.restoreCallingIdentity(token);
             }
+            jniReset(mInterface);
             mInterface = null;
         }
 
@@ -237,12 +247,22 @@ public class Vpn extends BaseNetworkStateTracker {
      * @param interfaze The name of the interface.
      */
     public void protect(ParcelFileDescriptor socket, String interfaze) throws Exception {
+
         PackageManager pm = mContext.getPackageManager();
-        ApplicationInfo app = pm.getApplicationInfo(mPackage, 0);
-        if (Binder.getCallingUid() != app.uid) {
+        int appUid = pm.getPackageUid(mPackage, mUserId);
+        if (Binder.getCallingUid() != appUid) {
             throw new SecurityException("Unauthorized Caller");
         }
+        //protect the socket from routing rules
+        final long token = Binder.clearCallingIdentity();
+        try {
+            mCallback.protect(socket);
+        } finally {
+            Binder.restoreCallingIdentity(token);
+        }
+        //bind the socket to the interface
         jniProtect(socket.getFd(), interfaze);
+
     }
 
     /**
@@ -258,23 +278,31 @@ public class Vpn extends BaseNetworkStateTracker {
         PackageManager pm = mContext.getPackageManager();
         ApplicationInfo app = null;
         try {
-            app = pm.getApplicationInfo(mPackage, 0);
+            app = AppGlobals.getPackageManager().getApplicationInfo(mPackage, 0, mUserId);
+            if (Binder.getCallingUid() != app.uid) {
+                return null;
+            }
         } catch (Exception e) {
             return null;
         }
-        if (Binder.getCallingUid() != app.uid) {
-            return null;
-        }
 
         // Check if the service is properly declared.
         Intent intent = new Intent(VpnConfig.SERVICE_INTERFACE);
         intent.setClassName(mPackage, config.user);
-        ResolveInfo info = pm.resolveService(intent, 0);
-        if (info == null) {
-            throw new SecurityException("Cannot find " + config.user);
-        }
-        if (!BIND_VPN_SERVICE.equals(info.serviceInfo.permission)) {
-            throw new SecurityException(config.user + " does not require " + BIND_VPN_SERVICE);
+        long token = Binder.clearCallingIdentity();
+        try {
+            ResolveInfo info = AppGlobals.getPackageManager().resolveService(intent,
+                                                                        null, 0, mUserId);
+            if (info == null) {
+                throw new SecurityException("Cannot find " + config.user);
+            }
+            if (!BIND_VPN_SERVICE.equals(info.serviceInfo.permission)) {
+                throw new SecurityException(config.user + " does not require " + BIND_VPN_SERVICE);
+            }
+        } catch (RemoteException e) {
+                throw new SecurityException("Cannot find " + config.user);
+        } finally {
+            Binder.restoreCallingIdentity(token);
         }
 
         // Load the label.
@@ -300,14 +328,18 @@ public class Vpn extends BaseNetworkStateTracker {
         try {
             updateState(DetailedState.CONNECTING, "establish");
             String interfaze = jniGetName(tun.getFd());
-            if (jniSetAddresses(interfaze, config.addresses) < 1) {
-                throw new IllegalArgumentException("At least one address must be specified");
+
+            //TEMP use the old jni calls until there is support for netd address setting
+            StringBuilder builder = new StringBuilder();
+            for (LinkAddress address : config.addresses) {
+                builder.append(" " + address);
             }
-            if (config.routes != null) {
-                jniSetRoutes(interfaze, config.routes);
+            if (jniSetAddresses(interfaze, builder.toString()) < 1) {
+                throw new IllegalArgumentException("At least one address must be specified");
             }
             Connection connection = new Connection();
-            if (!mContext.bindService(intent, connection, Context.BIND_AUTO_CREATE)) {
+            if (!mContext.bindServiceAsUser(intent, connection, Context.BIND_AUTO_CREATE,
+                        new UserHandle(mUserId))) {
                 throw new IllegalStateException("Cannot bind " + config.user);
             }
             if (mConnection != null) {
@@ -318,25 +350,37 @@ public class Vpn extends BaseNetworkStateTracker {
             }
             mConnection = connection;
             mInterface = interfaze;
+
+            // Fill more values.
+            config.user = mPackage;
+            config.interfaze = mInterface;
+            // Set up forwarding and DNS rules.
+            token = Binder.clearCallingIdentity();
+            try {
+                mCallback.setMarkedForwarding(mInterface);
+                mCallback.setRoutes(interfaze, config.routes);
+                mCallback.override(mInterface, config.dnsServers, config.searchDomains);
+                mCallback.addUserForwarding(mInterface, mUserId);
+                showNotification(config, label, bitmap);
+
+            } finally {
+                Binder.restoreCallingIdentity(token);
+            }
+
         } catch (RuntimeException e) {
             updateState(DetailedState.FAILED, "establish");
             IoUtils.closeQuietly(tun);
+            //make sure marked forwarding is cleared if it was set
+            try {
+                mCallback.clearMarkedForwarding(mInterface);
+            } catch (Exception ingored) {
+                //ignored
+            }
             throw e;
         }
         Log.i(TAG, "Established by " + config.user + " on " + mInterface);
 
-        // Fill more values.
-        config.user = mPackage;
-        config.interfaze = mInterface;
 
-        // Override DNS servers and show the notification.
-        final long token = Binder.clearCallingIdentity();
-        try {
-            mCallback.override(config.dnsServers, config.searchDomains);
-            showNotification(config, label, bitmap);
-        } finally {
-            Binder.restoreCallingIdentity(token);
-        }
         // TODO: ensure that contract class eventually marks as connected
         updateState(DetailedState.AUTHENTICATING, "establish");
         return tun;
@@ -367,6 +411,9 @@ public class Vpn extends BaseNetworkStateTracker {
                 if (interfaze.equals(mInterface) && jniCheck(interfaze) == 0) {
                     final long token = Binder.clearCallingIdentity();
                     try {
+                        mCallback.clearUserForwarding(mInterface, mUserId);
+                        mCallback.clearMarkedForwarding(mInterface);
+
                         mCallback.restore();
                         hideNotification();
                     } finally {
@@ -391,16 +438,19 @@ public class Vpn extends BaseNetworkStateTracker {
         if (Binder.getCallingUid() == Process.SYSTEM_UID) {
             return;
         }
-
+        int appId = UserHandle.getAppId(Binder.getCallingUid());
+        final long token = Binder.clearCallingIdentity();
         try {
             // System dialogs are also allowed to control VPN.
             PackageManager pm = mContext.getPackageManager();
             ApplicationInfo app = pm.getApplicationInfo(VpnConfig.DIALOGS_PACKAGE, 0);
-            if (Binder.getCallingUid() == app.uid) {
+            if (appId == app.uid) {
                 return;
             }
         } catch (Exception e) {
             // ignore
+        } finally {
+            Binder.restoreCallingIdentity(token);
         }
 
         throw new SecurityException("Unauthorized Caller");
@@ -443,7 +493,7 @@ public class Vpn extends BaseNetworkStateTracker {
                     .setDefaults(0)
                     .setOngoing(true)
                     .build();
-            nm.notifyAsUser(null, R.drawable.vpn_connected, notification, UserHandle.ALL);
+            nm.notifyAsUser(null, R.drawable.vpn_connected, notification,new UserHandle(mUserId));
         }
     }
 
@@ -455,7 +505,7 @@ public class Vpn extends BaseNetworkStateTracker {
                 mContext.getSystemService(Context.NOTIFICATION_SERVICE);
 
         if (nm != null) {
-            nm.cancelAsUser(null, R.drawable.vpn_connected, UserHandle.ALL);
+            nm.cancelAsUser(null, R.drawable.vpn_connected, new UserHandle(mUserId));
         }
     }
 
@@ -577,7 +627,8 @@ public class Vpn extends BaseNetworkStateTracker {
         config.user = profile.key;
         config.interfaze = iface;
         config.session = profile.name;
-        config.routes = profile.routes;
+
+        config.addLegacyRoutes(profile.routes);
         if (!profile.dnsServers.isEmpty()) {
             config.dnsServers = Arrays.asList(profile.dnsServers.split(" +"));
         }
@@ -691,7 +742,7 @@ public class Vpn extends BaseNetworkStateTracker {
             // mConfig.interfaze will change to point to OUR
             // internal interface soon. TODO - add inner/outer to mconfig
             // TODO - we have a race - if the outer iface goes away/disconnects before we hit this
-            // we will leave the VPN up.  We should check that it's still there/connected after 
+            // we will leave the VPN up.  We should check that it's still there/connected after
             // registering
             mOuterInterface = mConfig.interfaze;
 
@@ -867,11 +918,11 @@ public class Vpn extends BaseNetworkStateTracker {
 
                 // Set the interface and the addresses in the config.
                 mConfig.interfaze = parameters[0].trim();
-                mConfig.addresses = parameters[1].trim();
 
+                mConfig.addLegacyAddresses(parameters[1]);
                 // Set the routes if they are not set in the config.
                 if (mConfig.routes == null || mConfig.routes.isEmpty()) {
-                    mConfig.routes = parameters[2].trim();
+                    mConfig.addLegacyRoutes(parameters[2]);
                 }
 
                 // Set the DNS servers if they are not set in the config.
@@ -891,7 +942,13 @@ public class Vpn extends BaseNetworkStateTracker {
                 }
 
                 // Set the routes.
-                jniSetRoutes(mConfig.interfaze, mConfig.routes);
+                long token = Binder.clearCallingIdentity();
+                try {
+                    mCallback.setMarkedForwarding(mConfig.interfaze);
+                    mCallback.setRoutes(mConfig.interfaze, mConfig.routes);
+                } finally {
+                    Binder.restoreCallingIdentity(token);
+                }
 
                 // Here is the last step and it must be done synchronously.
                 synchronized (Vpn.this) {
@@ -905,14 +962,26 @@ public class Vpn extends BaseNetworkStateTracker {
 
                     // Now INetworkManagementEventObserver is watching our back.
                     mInterface = mConfig.interfaze;
-                    mCallback.override(mConfig.dnsServers, mConfig.searchDomains);
-                    showNotification(mConfig, null, null);
+
+                    token = Binder.clearCallingIdentity();
+                    try {
+                        mCallback.override(mInterface, mConfig.dnsServers, mConfig.searchDomains);
+                        mCallback.addUserForwarding(mInterface, mUserId);
+                        showNotification(mConfig, null, null);
+                    } finally {
+                        Binder.restoreCallingIdentity(token);
+                    }
 
                     Log.i(TAG, "Connected!");
                     updateState(DetailedState.CONNECTED, "execute");
                 }
             } catch (Exception e) {
                 Log.i(TAG, "Aborting", e);
+                //make sure the routing is cleared
+                try {
+                    mCallback.clearMarkedForwarding(mConfig.interfaze);
+                } catch (Exception ignored) {
+                }
                 exit();
             } finally {
                 // Kill the daemons if they fail to stop.
index 13e400f..a2e9d67 100644 (file)
@@ -26,6 +26,7 @@ import android.content.Context;
 import android.content.Intent;
 import android.content.IntentFilter;
 import android.net.LinkProperties;
+import android.net.LinkAddress;
 import android.net.NetworkInfo;
 import android.net.NetworkInfo.DetailedState;
 import android.net.NetworkInfo.State;
@@ -44,6 +45,8 @@ import com.android.server.ConnectivityService;
 import com.android.server.EventLogTags;
 import com.android.server.connectivity.Vpn;
 
+import java.util.List;
+
 /**
  * State tracker for lockdown mode. Watches for normal {@link NetworkInfo} to be
  * connected and kicks off VPN connection, managing any required {@code netd}
@@ -73,7 +76,7 @@ public class LockdownVpnTracker {
 
     private String mAcceptedEgressIface;
     private String mAcceptedIface;
-    private String mAcceptedSourceAddr;
+    private List<LinkAddress> mAcceptedSourceAddr;
 
     private int mErrorCount;
 
@@ -162,14 +165,15 @@ public class LockdownVpnTracker {
 
         } else if (vpnInfo.isConnected() && vpnConfig != null) {
             final String iface = vpnConfig.interfaze;
-            final String sourceAddr = vpnConfig.addresses;
+            final List<LinkAddress> sourceAddrs = vpnConfig.addresses;
 
             if (TextUtils.equals(iface, mAcceptedIface)
-                    && TextUtils.equals(sourceAddr, mAcceptedSourceAddr)) {
+                  && sourceAddrs.equals(mAcceptedSourceAddr)) {
                 return;
             }
 
-            Slog.d(TAG, "VPN connected using iface=" + iface + ", sourceAddr=" + sourceAddr);
+            Slog.d(TAG, "VPN connected using iface=" + iface +
+                    ", sourceAddr=" + sourceAddrs.toString());
             EventLogTags.writeLockdownVpnConnected(egressType);
             showNotification(R.string.vpn_lockdown_connected, R.drawable.vpn_connected);
 
@@ -177,11 +181,13 @@ public class LockdownVpnTracker {
                 clearSourceRulesLocked();
 
                 mNetService.setFirewallInterfaceRule(iface, true);
-                mNetService.setFirewallEgressSourceRule(sourceAddr, true);
+                for (LinkAddress addr : sourceAddrs) {
+                    mNetService.setFirewallEgressSourceRule(addr.toString(), true);
+                }
 
                 mErrorCount = 0;
                 mAcceptedIface = iface;
-                mAcceptedSourceAddr = sourceAddr;
+                mAcceptedSourceAddr = sourceAddrs;
             } catch (RemoteException e) {
                 throw new RuntimeException("Problem setting firewall rules", e);
             }
@@ -263,7 +269,9 @@ public class LockdownVpnTracker {
                 mAcceptedIface = null;
             }
             if (mAcceptedSourceAddr != null) {
-                mNetService.setFirewallEgressSourceRule(mAcceptedSourceAddr, false);
+                for (LinkAddress addr : mAcceptedSourceAddr) {
+                    mNetService.setFirewallEgressSourceRule(addr.toString(), false);
+                }
                 mAcceptedSourceAddr = null;
             }
         } catch (RemoteException e) {