typedef std::unique_ptr<addrinfo, AddrinfoDeleter> ScopedAddrinfo;
+class Stopwatch {
+public:
+ Stopwatch(): mStart(std::chrono::steady_clock::now()) {}
+ float timeTaken() {
+ using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
+ return (std::chrono::duration_cast<ms>(
+ std::chrono::steady_clock::now() - mStart)).count();
+ }
+
+private:
+ std::chrono::time_point<std::chrono::steady_clock> mStart;
+ std::string mName;
+};
+
int checkError(int fd) {
struct {
nlmsghdr h;
return true;
}
+int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states,
+ iovec *iov, int iovcnt) {
+ struct {
+ nlmsghdr nlh;
+ inet_diag_req_v2 req;
+ } __attribute__((__packed__)) request = {
+ .nlh = {
+ .nlmsg_type = SOCK_DIAG_BY_FAMILY,
+ .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
+ },
+ .req = {
+ .sdiag_family = family,
+ .sdiag_protocol = proto,
+ .idiag_states = states,
+ },
+ };
+
+ size_t len = 0;
+ iov[0].iov_base = &request;
+ iov[0].iov_len = sizeof(request);
+ for (int i = 0; i < iovcnt; i++) {
+ len += iov[i].iov_len;
+ }
+ request.nlh.nlmsg_len = len;
+
+ if (writev(mSock, iov, iovcnt) != (ssize_t) len) {
+ return -errno;
+ }
+
+ return checkError(mSock);
+}
+
+int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
+ iovec iov[] = {
+ { nullptr, 0 },
+ };
+ return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
+}
+
int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
addrinfo hints = { .ai_flags = AI_NUMERICHOST };
addrinfo *res;
uint8_t prefixlen = addrlen * 8;
uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
uint8_t nojump = yesjump + 4;
- uint32_t states = ~(1 << TCP_TIME_WAIT);
struct {
- nlmsghdr nlh;
- inet_diag_req_v2 req;
nlattr nla;
inet_diag_bc_op op;
inet_diag_hostcond cond;
- } __attribute__((__packed__)) request = {
- .nlh = {
- .nlmsg_type = SOCK_DIAG_BY_FAMILY,
- .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
- },
- .req = {
- .sdiag_family = family,
- .sdiag_protocol = proto,
- .idiag_states = states,
- },
+ } __attribute__((__packed__)) attrs = {
.nla = {
.nla_type = INET_DIAG_REQ_BYTECODE,
},
},
};
- request.nlh.nlmsg_len = sizeof(request) + addrlen;
- request.nla.nla_len = sizeof(request.nla) + sizeof(request.op) + sizeof(request.cond) + addrlen;
+ attrs.nla.nla_len = sizeof(attrs) + addrlen;
- struct iovec iov[] = {
- { &request, sizeof(request) },
+ iovec iov[] = {
+ { nullptr, 0 },
+ { &attrs, sizeof(attrs) },
{ addr, addrlen },
};
- if (writev(mSock, iov, ARRAY_SIZE(iov)) != (int) request.nlh.nlmsg_len) {
- return -errno;
- }
-
- return checkError(mSock);
+ uint32_t states = ~(1 << TCP_TIME_WAIT);
+ return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
}
int SockDiag::readDiagMsg(uint8_t proto, SockDiag::DumpCallback callback) {
}
int SockDiag::destroySockets(const char *addrstr) {
- using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
-
+ Stopwatch s;
mSocketsDestroyed = 0;
- const auto start = std::chrono::steady_clock::now();
+
if (!strchr(addrstr, ':')) {
if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
return ret;
}
- auto elapsed = std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start);
if (mSocketsDestroyed > 0) {
- ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, elapsed.count());
+ ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, s.timeTaken());
}
return mSocketsDestroyed;
}
+
+int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
+ mSocketsDestroyed = 0;
+ Stopwatch s;
+
+ auto destroy = [this, uid] (uint8_t proto, const inet_diag_msg *msg) {
+ if (msg != nullptr && msg->idiag_uid == uid) {
+ return this->sockDestroy(proto, msg);
+ } else {
+ return 0;
+ }
+ };
+
+ for (const int family : {AF_INET, AF_INET6}) {
+ const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
+ uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
+ if (int ret = sendDumpRequest(proto, family, states)) {
+ ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
+ return ret;
+ }
+ if (int ret = readDiagMsg(proto, destroy)) {
+ ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
+ return ret;
+ }
+ }
+
+ if (mSocketsDestroyed > 0) {
+ ALOGI("Destroyed %d sockets for UID in %.1f ms", mSocketsDestroyed, s.timeTaken());
+ }
+
+ return 0;
+}
+/*
+ * Copyright (C) 2016 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
#include <functional>
#include <linux/netlink.h>
bool open();
virtual ~SockDiag() { closeSocks(); }
+ int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states);
int sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr);
int readDiagMsg(uint8_t proto, DumpCallback callback);
int sockDestroy(uint8_t proto, const inet_diag_msg *);
int destroySockets(const char *addrstr);
+ int destroySockets(uint8_t proto, uid_t uid);
private:
int mSock;
int mWriteSock;
int mSocketsDestroyed;
+ int sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states, iovec *iov, int iovcnt);
int destroySockets(uint8_t proto, int family, const char *addrstr);
bool hasSocks() { return mSock != -1 && mWriteSock != -1; }
void closeSocks() { close(mSock); close(mWriteSock); mSock = mWriteSock = -1; }
#define NUM_SOCKETS 500
+#define START_UID 8000 // START_UID + NUM_SOCKETS must be <= 9999.
+#define CLOSE_UID (START_UID + NUM_SOCKETS - 42) // Close to the end
class SockDiagTest : public ::testing::Test {
close(accepted6);
}
+enum MicroBenchmarkTestType {
+ ADDRESS,
+ UID,
+};
-class SockDiagMicroBenchmarkTest : public ::testing::Test {
+const char *testTypeName(MicroBenchmarkTestType mode) {
+#define TO_STRING_TYPE(x) case ((x)): return #x;
+ switch((mode)) {
+ TO_STRING_TYPE(ADDRESS);
+ TO_STRING_TYPE(UID);
+ }
+#undef TO_STRING_TYPE
+}
+
+class SockDiagMicroBenchmarkTest : public ::testing::TestWithParam<MicroBenchmarkTestType> {
public:
void SetUp() {
SockDiag mSd;
int destroySockets() {
- const int ret = mSd.destroySockets("::1");
- EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
+ MicroBenchmarkTestType mode = GetParam();
+ int ret;
+ if (mode == ADDRESS) {
+ ret = mSd.destroySockets("::1");
+ EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
+ } else {
+ ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
+ EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
+ strerror(-ret);
+ }
return ret;
}
- bool shouldHaveClosedSocket(int) {
- return true;
+ bool shouldHaveClosedSocket(int i) {
+ MicroBenchmarkTestType mode = GetParam();
+ switch (mode) {
+ case ADDRESS:
+ return true;
+ case UID:
+ return i == CLOSE_UID - START_UID;
+ }
}
void checkSocketState(int i, int sock, const char *msg) {
}
};
-TEST_F(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
- fprintf(stderr, "Benchmarking closing %d sockets\n", NUM_SOCKETS);
+TEST_P(SockDiagMicroBenchmarkTest, TestMicroBenchmark) {
+ MicroBenchmarkTestType mode = GetParam();
+
+ fprintf(stderr, "Benchmarking closing %d sockets based on %s\n",
+ NUM_SOCKETS, testTypeName(mode));
int listensocket = socket(AF_INET6, SOCK_STREAM, 0);
ASSERT_NE(-1, listensocket) << "Failed to open listen socket";
auto start = std::chrono::steady_clock::now();
for (int i = 0; i < NUM_SOCKETS; i++) {
int s = socket(AF_INET6, SOCK_STREAM, 0);
+ uid_t uid = START_UID + i;
+ ASSERT_EQ(0, fchown(s, uid, -1));
clientlen = sizeof(client);
ASSERT_EQ(0, connect(s, (sockaddr *) &server, sizeof(server)))
<< "Connecting socket " << i << " failed " << strerror(errno);
close(listensocket);
}
+
+INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest, testing::Values(ADDRESS, UID));