BandwidthController.cpp BandwidthControllerTest.cpp \
FirewallControllerTest.cpp FirewallController.cpp \
SockDiagTest.cpp SockDiag.cpp \
- StrictController.cpp StrictControllerTest.cpp
+ StrictController.cpp StrictControllerTest.cpp \
+ UidRanges.cpp \
+
LOCAL_MODULE_TAGS := tests
LOCAL_SHARED_LIBRARIES := liblog libbase libcutils liblogwrap
include $(BUILD_NATIVE_TEST)
ALOGE("cannot add users to non-virtual network with netId %u", netId);
return -EINVAL;
}
- if (int ret = static_cast<VirtualNetwork*>(network)->addUsers(uidRanges)) {
+ if (int ret = static_cast<VirtualNetwork*>(network)->addUsers(uidRanges, mProtectableUsers)) {
return ret;
}
return 0;
ALOGE("cannot remove users from non-virtual network with netId %u", netId);
return -EINVAL;
}
- if (int ret = static_cast<VirtualNetwork*>(network)->removeUsers(uidRanges)) {
+ if (int ret = static_cast<VirtualNetwork*>(network)->removeUsers(uidRanges,
+ mProtectableUsers)) {
return ret;
}
return 0;
#define LOG_TAG "Netd"
+#include <android-base/strings.h>
#include <cutils/log.h>
#include "NetdConstants.h"
}
default:
inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
- callback(proto, msg);
+ if (callback(proto, msg)) {
+ sockDestroy(proto, msg);
+ }
}
}
} while (bytesread > 0);
return ret;
}
- auto destroy = [this] (uint8_t proto, const inet_diag_msg *msg) {
- return this->sockDestroy(proto, msg);
- };
+ auto destroyAll = [] (uint8_t, const inet_diag_msg*) { return true; };
- return readDiagMsg(proto, destroy);
+ return readDiagMsg(proto, destroyAll);
}
int SockDiag::destroySockets(const char *addrstr) {
return mSocketsDestroyed;
}
+int SockDiag::destroyLiveSockets(DumpCallback destroyFilter) {
+ int proto = IPPROTO_TCP;
+
+ for (const int family : {AF_INET, AF_INET6}) {
+ const char *familyName = (family == AF_INET) ? "IPv4" : "IPv6";
+ uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
+ if (int ret = sendDumpRequest(proto, family, states)) {
+ ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
+ return ret;
+ }
+ if (int ret = readDiagMsg(proto, destroyFilter)) {
+ ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
+ return ret;
+ }
+ }
+
+ return 0;
+}
+
int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
mSocketsDestroyed = 0;
Stopwatch s;
- auto destroy = [this, uid] (uint8_t proto, const inet_diag_msg *msg) {
- if (msg != nullptr && msg->idiag_uid == uid) {
- return this->sockDestroy(proto, msg);
- } else {
- return 0;
- }
+ auto shouldDestroy = [uid] (uint8_t, const inet_diag_msg *msg) {
+ return (msg != nullptr && msg->idiag_uid == uid);
};
for (const int family : {AF_INET, AF_INET6}) {
ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
return ret;
}
- if (int ret = readDiagMsg(proto, destroy)) {
+ if (int ret = readDiagMsg(proto, shouldDestroy)) {
ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
return ret;
}
return 0;
}
+
+int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids) {
+ mSocketsDestroyed = 0;
+ Stopwatch s;
+
+ auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
+ return msg != nullptr &&
+ uidRanges.hasUid(msg->idiag_uid) &&
+ skipUids.find(msg->idiag_uid) == skipUids.end();
+ };
+
+ if (int ret = destroyLiveSockets(shouldDestroy)) {
+ return ret;
+ }
+
+ std::vector<uid_t> skipUidStrings;
+ for (uid_t uid : skipUids) {
+ skipUidStrings.push_back(uid);
+ }
+ std::sort(skipUidStrings.begin(), skipUidStrings.end());
+
+ if (mSocketsDestroyed > 0) {
+ ALOGI("Destroyed %d sockets for %s skip={%s} in %.1f ms",
+ mSocketsDestroyed, uidRanges.toString().c_str(),
+ android::base::Join(skipUidStrings, " ").c_str(), s.timeTaken());
+ }
+
+ return 0;
+}
* limitations under the License.
*/
-#include <functional>
+#include <unistd.h>
+#include <sys/socket.h>
#include <linux/netlink.h>
#include <linux/sock_diag.h>
#include <linux/inet_diag.h>
+#include <functional>
+#include <set>
+
+#include "UidRanges.h"
+
struct inet_diag_msg;
class SockDiagTest;
public:
static const int kBufferSize = 4096;
- typedef std::function<int(uint8_t proto, const inet_diag_msg *)> DumpCallback;
+
+ // Callback function that is called once for every socket in the dump. A return value of true
+ // means destroy the socket.
+ typedef std::function<bool(uint8_t proto, const inet_diag_msg *)> DumpCallback;
struct DestroyRequest {
nlmsghdr nlh;
int sockDestroy(uint8_t proto, const inet_diag_msg *);
int destroySockets(const char *addrstr);
int destroySockets(uint8_t proto, uid_t uid);
+ int destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids);
private:
int mSock;
int mSocketsDestroyed;
int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states, iovec *iov, int iovcnt);
int destroySockets(uint8_t proto, int family, const char *addrstr);
+ int destroyLiveSockets(DumpCallback destroy);
bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
};
#include "NetdConstants.h"
#include "SockDiag.h"
-
-
-#define NUM_SOCKETS 500
-#define START_UID 8000 // START_UID + NUM_SOCKETS must be <= 9999.
-#define CLOSE_UID (START_UID + NUM_SOCKETS - 42) // Close to the end
-
+#include "UidRanges.h"
class SockDiagTest : public ::testing::Test {
};
if (msg == nullptr) {
EXPECT_FALSE(seenNull);
seenNull = true;
- return 0;
+ return false;
}
EXPECT_EQ(htonl(INADDR_LOOPBACK), msg->id.idiag_src[0]);
v4SocketsSeen++;
src, htons(msg->id.idiag_sport),
dst, htons(msg->id.idiag_dport),
tcpStateName(msg->idiag_state));
- return 0;
+ return false;
};
int v6SocketsSeen = 0;
if (msg == nullptr) {
EXPECT_FALSE(seenNull);
seenNull = true;
- return 0;
+ return false;
}
struct in6_addr *saddr = (struct in6_addr *) msg->id.idiag_src;
EXPECT_TRUE(
src, htons(msg->id.idiag_sport),
dst, htons(msg->id.idiag_dport),
tcpStateName(msg->idiag_state));
- return 0;
+ return false;
};
SockDiag sd;
enum MicroBenchmarkTestType {
ADDRESS,
UID,
+ UIDRANGE,
};
const char *testTypeName(MicroBenchmarkTestType mode) {
switch((mode)) {
TO_STRING_TYPE(ADDRESS);
TO_STRING_TYPE(UID);
+ TO_STRING_TYPE(UIDRANGE);
}
#undef TO_STRING_TYPE
}
protected:
SockDiag mSd;
+ constexpr static int MAX_SOCKETS = 500;
+ constexpr static int ADDRESS_SOCKETS = 500;
+ constexpr static int UID_SOCKETS = 100;
+ constexpr static uid_t START_UID = 8000; // START_UID + number of sockets must be <= 9999.
+ constexpr static int CLOSE_UID = START_UID + UID_SOCKETS - 42; // Close to the end
+ static_assert(START_UID + MAX_SOCKETS < 9999, "Too many sockets");
+
+ int howManySockets() {
+ MicroBenchmarkTestType mode = GetParam();
+ switch (mode) {
+ case ADDRESS:
+ return 500;
+ case UID:
+ case UIDRANGE:
+ return 50;
+ }
+ }
+
int destroySockets() {
MicroBenchmarkTestType mode = GetParam();
int ret;
- if (mode == ADDRESS) {
- ret = mSd.destroySockets("::1");
- EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
- } else {
- ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
- EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
- strerror(-ret);
+ switch (mode) {
+ case ADDRESS:
+ ret = mSd.destroySockets("::1");
+ EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
+ break;
+ case UID:
+ ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
+ EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
+ strerror(-ret);
+ break;
+ case UIDRANGE: {
+ const char *uidRangeStrings[] = { "8005-8012", "8042", "8043", "8090-8099" };
+ std::set<uid_t> skipUids { 8007, 8043, 8098, 8099 };
+ UidRanges uidRanges;
+ uidRanges.parseFrom(ARRAY_SIZE(uidRangeStrings), (char **) uidRangeStrings);
+ ret = mSd.destroySockets(uidRanges, skipUids);
+ }
}
return ret;
}
bool shouldHaveClosedSocket(int i) {
MicroBenchmarkTestType mode = GetParam();
switch (mode) {
- case ADDRESS:
- return true;
- case UID:
- return i == CLOSE_UID - START_UID;
+ case ADDRESS:
+ return true;
+ case UID:
+ return i == CLOSE_UID - START_UID;
+ case UIDRANGE: {
+ uid_t uid = i + START_UID;
+ // Skip UIDs in skipUids.
+ if (uid == 8007 || uid == 8043 || uid == 8098 || uid == 8099) {
+ return false;
+ }
+ // Include UIDs in uidRanges.
+ if ((8005 <= uid && uid <= 8012) || uid == 8042 || (8090 <= uid && uid <= 8099)) {
+ return true;
+ }
+ return false;
+ }
}
}
TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
MicroBenchmarkTestType mode = GetParam();
+ int numSockets = howManySockets();
+
fprintf(stderr, "Benchmarking closing %d sockets based on %s\n",
- NUM_SOCKETS, testTypeName(mode));
+ numSockets, testTypeName(mode));
int listensocket = socket(AF_INET6, SOCK_STREAM, 0);
ASSERT_NE(-1, listensocket) << "Failed to open listen socket";
using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
- int clientsockets[NUM_SOCKETS], serversockets[NUM_SOCKETS];
- uint16_t clientports[NUM_SOCKETS];
+ int clientsockets[MAX_SOCKETS], serversockets[MAX_SOCKETS];
+ uint16_t clientports[MAX_SOCKETS];
sockaddr_in6 client;
socklen_t clientlen;
auto start = std::chrono::steady_clock::now();
- for (int i = 0; i < NUM_SOCKETS; i++) {
+ for (int i = 0; i < numSockets; i++) {
int s = socket(AF_INET6, SOCK_STREAM, 0);
uid_t uid = START_UID + i;
ASSERT_EQ(0, fchown(s, uid, -1));
std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count());
start = std::chrono::steady_clock::now();
- for (int i = 0; i < NUM_SOCKETS; i++) {
+ for (int i = 0; i < numSockets; i++) {
checkSocketState(i, clientsockets[i], "Client socket");
checkSocketState(i, serversockets[i], "Server socket");
}
std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count());
start = std::chrono::steady_clock::now();
- for (int i = 0; i < NUM_SOCKETS; i++) {
+ for (int i = 0; i < numSockets; i++) {
close(clientsockets[i]);
close(serversockets[i]);
}
close(listensocket);
}
-INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest, testing::Values(ADDRESS, UID));
+// "SockDiagTest.cpp:232: error: undefined reference to 'SockDiagMicroBenchmarkTest::CLOSE_UID'".
+constexpr int SockDiagMicroBenchmarkTest::CLOSE_UID;
+
+INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest,
+ testing::Values(ADDRESS, UID, UIDRANGE));
#include <stdlib.h>
+#include <android-base/stringprintf.h>
+
+using android::base::StringAppendF;
+
bool UidRanges::hasUid(uid_t uid) const {
auto iter = std::lower_bound(mRanges.begin(), mRanges.end(), Range(uid, uid));
return (iter != mRanges.end() && iter->first == uid) ||
other.mRanges.end(), mRanges.begin());
mRanges.erase(end, mRanges.end());
}
+
+std::string UidRanges::toString() const {
+ std::string s("UidRanges{ ");
+ for (Range range : mRanges) {
+ if (range.first != range.second) {
+ StringAppendF(&s, "%u-%u ", range.first, range.second);
+ } else {
+ StringAppendF(&s, "%u ", range.first);
+ }
+ }
+ StringAppendF(&s, "}");
+ return s;
+}
const std::vector<Range>& getRanges() const;
bool parseFrom(int argc, char* argv[]);
+ std::string toString() const;
void add(const UidRanges& other);
void remove(const UidRanges& other);
* limitations under the License.
*/
+#include <set>
#include "VirtualNetwork.h"
+#include "SockDiag.h"
#include "RouteController.h"
#define LOG_TAG "Netd"
return mUidRanges.hasUid(uid);
}
-int VirtualNetwork::addUsers(const UidRanges& uidRanges) {
+
+int VirtualNetwork::maybeCloseSockets(bool add, const UidRanges& uidRanges,
+ const std::set<uid_t>& protectableUsers) {
+ if (!mSecure) {
+ return 0;
+ }
+
+ SockDiag sd;
+ if (!sd.open()) {
+ return -EBADFD;
+ }
+
+ if (int ret = sd.destroySockets(uidRanges, protectableUsers)) {
+ ALOGE("Failed to close sockets while %s %s to network %d: %s",
+ add ? "adding" : "removing", uidRanges.toString().c_str(), mNetId, strerror(-ret));
+ return ret;
+ }
+
+ return 0;
+}
+
+int VirtualNetwork::addUsers(const UidRanges& uidRanges, const std::set<uid_t>& protectableUsers) {
+ maybeCloseSockets(true, uidRanges, protectableUsers);
+
for (const std::string& interface : mInterfaces) {
if (int ret = RouteController::addUsersToVirtualNetwork(mNetId, interface.c_str(), mSecure,
uidRanges)) {
return 0;
}
-int VirtualNetwork::removeUsers(const UidRanges& uidRanges) {
+int VirtualNetwork::removeUsers(const UidRanges& uidRanges,
+ const std::set<uid_t>& protectableUsers) {
+ maybeCloseSockets(false, uidRanges, protectableUsers);
+
for (const std::string& interface : mInterfaces) {
if (int ret = RouteController::removeUsersFromVirtualNetwork(mNetId, interface.c_str(),
mSecure, uidRanges)) {
#ifndef NETD_SERVER_VIRTUAL_NETWORK_H
#define NETD_SERVER_VIRTUAL_NETWORK_H
+#include <set>
+
#include "Network.h"
#include "UidRanges.h"
bool isSecure() const;
bool appliesToUser(uid_t uid) const;
- int addUsers(const UidRanges& uidRanges) WARN_UNUSED_RESULT;
- int removeUsers(const UidRanges& uidRanges) WARN_UNUSED_RESULT;
+ int addUsers(const UidRanges& uidRanges,
+ const std::set<uid_t>& protectableUsers) WARN_UNUSED_RESULT;
+ int removeUsers(const UidRanges& uidRanges,
+ const std::set<uid_t>& protectableUsers) WARN_UNUSED_RESULT;
private:
Type getType() const override;
int addInterface(const std::string& interface) override WARN_UNUSED_RESULT;
int removeInterface(const std::string& interface) override WARN_UNUSED_RESULT;
+ int maybeCloseSockets(bool add, const UidRanges& uidRanges,
+ const std::set<uid_t>& protectableUsers);
const bool mHasDns;
const bool mSecure;