return 0;
}
+// Determines whether a socket is a loopback socket. Does not check socket state.
+bool SockDiag::isLoopbackSocket(const inet_diag_msg *msg) {
+ switch (msg->idiag_family) {
+ case AF_INET:
+ // Old kernels only copy the IPv4 address and leave the other 12 bytes uninitialized.
+ return IN_LOOPBACK(htonl(msg->id.idiag_src[0])) ||
+ IN_LOOPBACK(htonl(msg->id.idiag_dst[0])) ||
+ msg->id.idiag_src[0] == msg->id.idiag_dst[0];
+
+ case AF_INET6: {
+ const struct in6_addr *src = (const struct in6_addr *) &msg->id.idiag_src;
+ const struct in6_addr *dst = (const struct in6_addr *) &msg->id.idiag_dst;
+ return (IN6_IS_ADDR_V4MAPPED(src) && IN_LOOPBACK(src->s6_addr32[3])) ||
+ (IN6_IS_ADDR_V4MAPPED(dst) && IN_LOOPBACK(dst->s6_addr32[3])) ||
+ IN6_IS_ADDR_LOOPBACK(src) || IN6_IS_ADDR_LOOPBACK(dst) ||
+ !memcmp(src, dst, sizeof(*src));
+ }
+ default:
+ return false;
+ }
+}
+
int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
if (msg == nullptr) {
return 0;
return 0;
}
-int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
+int SockDiag::destroySockets(uint8_t proto, const uid_t uid, bool excludeLoopback) {
mSocketsDestroyed = 0;
Stopwatch s;
- auto shouldDestroy = [uid] (uint8_t, const inet_diag_msg *msg) {
- return (msg != nullptr && msg->idiag_uid == uid);
+ auto shouldDestroy = [uid, excludeLoopback] (uint8_t, const inet_diag_msg *msg) {
+ return msg != nullptr &&
+ msg->idiag_uid == uid &&
+ !(excludeLoopback && isLoopbackSocket(msg));
};
for (const int family : {AF_INET, AF_INET6}) {
return 0;
}
-int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids) {
+int SockDiag::destroySockets(const UidRanges& uidRanges, const std::set<uid_t>& skipUids,
+ bool excludeLoopback) {
mSocketsDestroyed = 0;
Stopwatch s;
auto shouldDestroy = [&] (uint8_t, const inet_diag_msg *msg) {
return msg != nullptr &&
uidRanges.hasUid(msg->idiag_uid) &&
- skipUids.find(msg->idiag_uid) == skipUids.end();
+ skipUids.find(msg->idiag_uid) == skipUids.end() &&
+ !(excludeLoopback && isLoopbackSocket(msg));
};
if (int ret = destroyLiveSockets(shouldDestroy)) {
* sock_diag_test.cpp - unit tests for SockDiag.cpp
*/
+#include <sys/socket.h>
+#include <netdb.h>
#include <arpa/inet.h>
#include <netinet/in.h>
+#include <netinet/tcp.h>
#include <linux/inet_diag.h>
#include <gtest/gtest.h>
#include "UidRanges.h"
class SockDiagTest : public ::testing::Test {
+protected:
+ static bool isLoopbackSocket(const inet_diag_msg *msg) {
+ return SockDiag::isLoopbackSocket(msg);
+ };
};
uint16_t bindAndListen(int s) {
src, htons(msg->id.idiag_sport),
dst, htons(msg->id.idiag_dport),
tcpStateName(msg->idiag_state));
+ if (msg->idiag_state == TCP_ESTABLISHED) {
+ EXPECT_TRUE(isLoopbackSocket(msg));
+ }
return false;
};
src, htons(msg->id.idiag_sport),
dst, htons(msg->id.idiag_dport),
tcpStateName(msg->idiag_state));
+ if (msg->idiag_state == TCP_ESTABLISHED) {
+ EXPECT_TRUE(isLoopbackSocket(msg));
+ }
return false;
};
close(accepted6);
}
+bool fillDiagAddr(__be32 addr[4], const sockaddr *sa) {
+ switch (sa->sa_family) {
+ case AF_INET: {
+ sockaddr_in *sin = (sockaddr_in *) sa;
+ memcpy(addr, &sin->sin_addr, sizeof(sin->sin_addr));
+ return true;
+ }
+ case AF_INET6: {
+ sockaddr_in6 *sin6 = (sockaddr_in6 *) sa;
+ memcpy(addr, &sin6->sin6_addr, sizeof(sin6->sin6_addr));
+ return true;
+ }
+ default:
+ return false;
+ }
+}
+
+inet_diag_msg makeDiagMessage(__u8 family, const sockaddr *src, const sockaddr *dst) {
+ inet_diag_msg msg = {
+ .idiag_family = family,
+ .idiag_state = TCP_ESTABLISHED,
+ .idiag_uid = AID_APP + 123,
+ .idiag_inode = 123456789,
+ .id = {
+ .idiag_sport = 1234,
+ .idiag_dport = 4321,
+ }
+ };
+ EXPECT_TRUE(fillDiagAddr(msg.id.idiag_src, src));
+ EXPECT_TRUE(fillDiagAddr(msg.id.idiag_dst, dst));
+ return msg;
+}
+
+inet_diag_msg makeDiagMessage(const char* srcstr, const char* dststr) {
+ addrinfo hints = { .ai_flags = AI_NUMERICHOST }, *src, *dst;
+ EXPECT_EQ(0, getaddrinfo(srcstr, NULL, &hints, &src));
+ EXPECT_EQ(0, getaddrinfo(dststr, NULL, &hints, &dst));
+ EXPECT_EQ(src->ai_addr->sa_family, dst->ai_addr->sa_family);
+ inet_diag_msg msg = makeDiagMessage(src->ai_addr->sa_family, src->ai_addr, dst->ai_addr);
+ freeaddrinfo(src);
+ freeaddrinfo(dst);
+ return msg;
+}
+
+TEST_F(SockDiagTest, TestIsLoopbackSocket) {
+ inet_diag_msg msg;
+
+ msg = makeDiagMessage("127.0.0.1", "127.0.0.1");
+ EXPECT_TRUE(isLoopbackSocket(&msg));
+
+ msg = makeDiagMessage("::1", "::1");
+ EXPECT_TRUE(isLoopbackSocket(&msg));
+
+ msg = makeDiagMessage("::1", "::ffff:127.0.0.1");
+ EXPECT_TRUE(isLoopbackSocket(&msg));
+
+ msg = makeDiagMessage("192.0.2.1", "192.0.2.1");
+ EXPECT_TRUE(isLoopbackSocket(&msg));
+
+ msg = makeDiagMessage("192.0.2.1", "8.8.8.8");
+ EXPECT_FALSE(isLoopbackSocket(&msg));
+
+ msg = makeDiagMessage("192.0.2.1", "127.0.0.1");
+ EXPECT_TRUE(isLoopbackSocket(&msg));
+
+ msg = makeDiagMessage("2001:db8::1", "2001:db8::1");
+ EXPECT_TRUE(isLoopbackSocket(&msg));
+
+ msg = makeDiagMessage("2001:db8::1", "2001:4860:4860::6464");
+ EXPECT_FALSE(isLoopbackSocket(&msg));
+
+ // While isLoopbackSocket returns true on these sockets, we usually don't want to close them
+ // because they aren't specific to any particular network and thus don't become unusable when
+ // an app's routing changes or its network access is removed.
+ //
+ // This isn't a problem, as anything that calls destroyLiveSockets will skip them because
+ // destroyLiveSockets only enumerates ESTABLISHED, SYN_SENT, and SYN_RECV sockets.
+ msg = makeDiagMessage("127.0.0.1", "0.0.0.0");
+ EXPECT_TRUE(isLoopbackSocket(&msg));
+
+ msg = makeDiagMessage("::1", "::");
+ EXPECT_TRUE(isLoopbackSocket(&msg));
+}
+
enum MicroBenchmarkTestType {
ADDRESS,
UID,
+ UID_EXCLUDE_LOOPBACK,
UIDRANGE,
+ UIDRANGE_EXCLUDE_LOOPBACK,
};
const char *testTypeName(MicroBenchmarkTestType mode) {
switch((mode)) {
TO_STRING_TYPE(ADDRESS);
TO_STRING_TYPE(UID);
+ TO_STRING_TYPE(UID_EXCLUDE_LOOPBACK);
TO_STRING_TYPE(UIDRANGE);
+ TO_STRING_TYPE(UIDRANGE_EXCLUDE_LOOPBACK);
}
#undef TO_STRING_TYPE
}
case ADDRESS:
return ADDRESS_SOCKETS;
case UID:
+ case UID_EXCLUDE_LOOPBACK:
case UIDRANGE:
+ case UIDRANGE_EXCLUDE_LOOPBACK:
return UID_SOCKETS;
}
}
EXPECT_LE(0, ret) << ": Failed to destroy sockets on ::1: " << strerror(-ret);
break;
case UID:
- ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID);
+ case UID_EXCLUDE_LOOPBACK: {
+ bool excludeLoopback = (mode == UID_EXCLUDE_LOOPBACK);
+ ret = mSd.destroySockets(IPPROTO_TCP, CLOSE_UID, excludeLoopback);
EXPECT_LE(0, ret) << ": Failed to destroy sockets for UID " << CLOSE_UID << ": " <<
strerror(-ret);
break;
- case UIDRANGE: {
+ }
+ case UIDRANGE:
+ case UIDRANGE_EXCLUDE_LOOPBACK: {
+ bool excludeLoopback = (mode == UIDRANGE_EXCLUDE_LOOPBACK);
const char *uidRangeStrings[] = { "8005-8012", "8042", "8043", "8090-8099" };
std::set<uid_t> skipUids { 8007, 8043, 8098, 8099 };
UidRanges uidRanges;
uidRanges.parseFrom(ARRAY_SIZE(uidRangeStrings), (char **) uidRangeStrings);
- ret = mSd.destroySockets(uidRanges, skipUids);
+ ret = mSd.destroySockets(uidRanges, skipUids, excludeLoopback);
}
}
return ret;
}
return false;
}
+ case UID_EXCLUDE_LOOPBACK:
+ case UIDRANGE_EXCLUDE_LOOPBACK:
+ return false;
}
}
fprintf(stderr, " Verifying: %6.1f ms (%d sockets destroyed)\n",
std::chrono::duration_cast<ms>(std::chrono::steady_clock::now() - start).count(),
socketsClosed);
- EXPECT_GT(socketsClosed, 0); // Just in case there's a bug in the test.
+ if (strstr(testTypeName(mode), "_EXCLUDE_LOOPBACK") == nullptr) {
+ EXPECT_GT(socketsClosed, 0); // Just in case there's a bug in the test.
+ }
start = std::chrono::steady_clock::now();
for (int i = 0; i < numSockets; i++) {
constexpr int SockDiagMicroBenchmarkTest::CLOSE_UID;
INSTANTIATE_TEST_CASE_P(Address, SockDiagMicroBenchmarkTest,
- testing::Values(ADDRESS, UID, UIDRANGE));
+ testing::Values(ADDRESS, UID, UIDRANGE,
+ UID_EXCLUDE_LOOPBACK, UIDRANGE_EXCLUDE_LOOPBACK));