OSDN Git Service

Support destroying sockets for UIDs.
[android-x86/system-netd.git] / server / SockDiag.cpp
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include <errno.h>
18 #include <netdb.h>
19 #include <string.h>
20 #include <netinet/in.h>
21 #include <netinet/tcp.h>
22 #include <sys/socket.h>
23 #include <sys/uio.h>
24
25 #include <linux/netlink.h>
26 #include <linux/sock_diag.h>
27 #include <linux/inet_diag.h>
28
29 #define LOG_TAG "Netd"
30
31 #include <cutils/log.h>
32
33 #include "NetdConstants.h"
34 #include "SockDiag.h"
35
36 #include <chrono>
37
38 #ifndef SOCK_DESTROY
39 #define SOCK_DESTROY 21
40 #endif
41
42 namespace {
43
44 struct AddrinfoDeleter {
45   void operator()(addrinfo *a) { if (a) freeaddrinfo(a); }
46 };
47
48 typedef std::unique_ptr<addrinfo, AddrinfoDeleter> ScopedAddrinfo;
49
50 class Stopwatch {
51 public:
52     Stopwatch(): mStart(std::chrono::steady_clock::now()) {}
53     float timeTaken() {
54         using ms = std::chrono::duration<float, std::ratio<1, 1000>>;
55         return (std::chrono::duration_cast<ms>(
56                 std::chrono::steady_clock::now() - mStart)).count();
57     }
58
59 private:
60     std::chrono::time_point<std::chrono::steady_clock> mStart;
61     std::string mName;
62 };
63
64 int checkError(int fd) {
65     struct {
66         nlmsghdr h;
67         nlmsgerr err;
68     } __attribute__((__packed__)) ack;
69     ssize_t bytesread = recv(fd, &ack, sizeof(ack), MSG_DONTWAIT | MSG_PEEK);
70     if (bytesread == -1) {
71        // Read failed (error), or nothing to read (good).
72        return (errno == EAGAIN) ? 0 : -errno;
73     } else if (bytesread == (ssize_t) sizeof(ack) && ack.h.nlmsg_type == NLMSG_ERROR) {
74         // We got an error. Consume it.
75         recv(fd, &ack, sizeof(ack), 0);
76         return ack.err.error;
77     } else {
78         // The kernel replied with something. Leave it to the caller.
79         return 0;
80     }
81 }
82
83 }  // namespace
84
85 bool SockDiag::open() {
86     if (hasSocks()) {
87         return false;
88     }
89
90     mSock = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_INET_DIAG);
91     mWriteSock = socket(PF_NETLINK, SOCK_DGRAM, NETLINK_INET_DIAG);
92     if (!hasSocks()) {
93         closeSocks();
94         return false;
95     }
96
97     sockaddr_nl nl = { .nl_family = AF_NETLINK };
98     if ((connect(mSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1) ||
99         (connect(mWriteSock, reinterpret_cast<sockaddr *>(&nl), sizeof(nl)) == -1)) {
100         closeSocks();
101         return false;
102     }
103
104     return true;
105 }
106
107 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states,
108                               iovec *iov, int iovcnt) {
109     struct {
110         nlmsghdr nlh;
111         inet_diag_req_v2 req;
112     } __attribute__((__packed__)) request = {
113         .nlh = {
114             .nlmsg_type = SOCK_DIAG_BY_FAMILY,
115             .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
116         },
117         .req = {
118             .sdiag_family = family,
119             .sdiag_protocol = proto,
120             .idiag_states = states,
121         },
122     };
123
124     size_t len = 0;
125     iov[0].iov_base = &request;
126     iov[0].iov_len = sizeof(request);
127     for (int i = 0; i < iovcnt; i++) {
128         len += iov[i].iov_len;
129     }
130     request.nlh.nlmsg_len = len;
131
132     if (writev(mSock, iov, iovcnt) != (ssize_t) len) {
133         return -errno;
134     }
135
136     return checkError(mSock);
137 }
138
139 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, uint32_t states) {
140     iovec iov[] = {
141         { nullptr, 0 },
142     };
143     return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
144 }
145
146 int SockDiag::sendDumpRequest(uint8_t proto, uint8_t family, const char *addrstr) {
147     addrinfo hints = { .ai_flags = AI_NUMERICHOST };
148     addrinfo *res;
149     in6_addr mapped = { .s6_addr32 = { 0, 0, htonl(0xffff), 0 } };
150     int ret;
151
152     // TODO: refactor the netlink parsing code out of system/core, bring it into netd, and stop
153     // doing string conversions when they're not necessary.
154     if ((ret = getaddrinfo(addrstr, nullptr, &hints, &res)) != 0) {
155         return -EINVAL;
156     }
157
158     // So we don't have to call freeaddrinfo on every failure path.
159     ScopedAddrinfo resP(res);
160
161     void *addr;
162     uint8_t addrlen;
163     if (res->ai_family == AF_INET && family == AF_INET) {
164         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
165         addr = &ina;
166         addrlen = sizeof(ina);
167     } else if (res->ai_family == AF_INET && family == AF_INET6) {
168         in_addr& ina = reinterpret_cast<sockaddr_in*>(res->ai_addr)->sin_addr;
169         mapped.s6_addr32[3] = ina.s_addr;
170         addr = &mapped;
171         addrlen = sizeof(mapped);
172     } else if (res->ai_family == AF_INET6 && family == AF_INET6) {
173         in6_addr& in6a = reinterpret_cast<sockaddr_in6*>(res->ai_addr)->sin6_addr;
174         addr = &in6a;
175         addrlen = sizeof(in6a);
176     } else {
177         return -EAFNOSUPPORT;
178     }
179
180     uint8_t prefixlen = addrlen * 8;
181     uint8_t yesjump = sizeof(inet_diag_bc_op) + sizeof(inet_diag_hostcond) + addrlen;
182     uint8_t nojump = yesjump + 4;
183
184     struct {
185         nlattr nla;
186         inet_diag_bc_op op;
187         inet_diag_hostcond cond;
188     } __attribute__((__packed__)) attrs = {
189         .nla = {
190             .nla_type = INET_DIAG_REQ_BYTECODE,
191         },
192         .op = {
193             INET_DIAG_BC_S_COND,
194             yesjump,
195             nojump,
196         },
197         .cond = {
198             family,
199             prefixlen,
200             -1,
201             {}
202         },
203     };
204
205     attrs.nla.nla_len = sizeof(attrs) + addrlen;
206
207     iovec iov[] = {
208         { nullptr, 0 },
209         { &attrs, sizeof(attrs) },
210         { addr, addrlen },
211     };
212
213     uint32_t states = ~(1 << TCP_TIME_WAIT);
214     return sendDumpRequest(proto, family, states, iov, ARRAY_SIZE(iov));
215 }
216
217 int SockDiag::readDiagMsg(uint8_t proto, SockDiag::DumpCallback callback) {
218     char buf[kBufferSize];
219
220     ssize_t bytesread;
221     do {
222         bytesread = read(mSock, buf, sizeof(buf));
223
224         if (bytesread < 0) {
225             return -errno;
226         }
227
228         uint32_t len = bytesread;
229         for (nlmsghdr *nlh = reinterpret_cast<nlmsghdr *>(buf);
230              NLMSG_OK(nlh, len);
231              nlh = NLMSG_NEXT(nlh, len)) {
232             switch (nlh->nlmsg_type) {
233               case NLMSG_DONE:
234                 callback(proto, NULL);
235                 return 0;
236               case NLMSG_ERROR: {
237                 nlmsgerr *err = reinterpret_cast<nlmsgerr *>(NLMSG_DATA(nlh));
238                 return err->error;
239               }
240               default:
241                 inet_diag_msg *msg = reinterpret_cast<inet_diag_msg *>(NLMSG_DATA(nlh));
242                 callback(proto, msg);
243             }
244         }
245     } while (bytesread > 0);
246
247     return 0;
248 }
249
250 int SockDiag::sockDestroy(uint8_t proto, const inet_diag_msg *msg) {
251     if (msg == nullptr) {
252        return 0;
253     }
254
255     DestroyRequest request = {
256         .nlh = {
257             .nlmsg_type = SOCK_DESTROY,
258             .nlmsg_flags = NLM_F_REQUEST,
259         },
260         .req = {
261             .sdiag_family = msg->idiag_family,
262             .sdiag_protocol = proto,
263             .idiag_states = (uint32_t) (1 << msg->idiag_state),
264             .id = msg->id,
265         },
266     };
267     request.nlh.nlmsg_len = sizeof(request);
268
269     if (write(mWriteSock, &request, sizeof(request)) < (ssize_t) sizeof(request)) {
270         return -errno;
271     }
272
273     int ret = checkError(mWriteSock);
274     if (!ret) mSocketsDestroyed++;
275     return ret;
276 }
277
278 int SockDiag::destroySockets(uint8_t proto, int family, const char *addrstr) {
279     if (!hasSocks()) {
280         return -EBADFD;
281     }
282
283     if (int ret = sendDumpRequest(proto, family, addrstr)) {
284         return ret;
285     }
286
287     auto destroy = [this] (uint8_t proto, const inet_diag_msg *msg) {
288         return this->sockDestroy(proto, msg);
289     };
290
291     return readDiagMsg(proto, destroy);
292 }
293
294 int SockDiag::destroySockets(const char *addrstr) {
295     Stopwatch s;
296     mSocketsDestroyed = 0;
297
298     if (!strchr(addrstr, ':')) {
299         if (int ret = destroySockets(IPPROTO_TCP, AF_INET, addrstr)) {
300             ALOGE("Failed to destroy IPv4 sockets on %s: %s", addrstr, strerror(-ret));
301             return ret;
302         }
303     }
304     if (int ret = destroySockets(IPPROTO_TCP, AF_INET6, addrstr)) {
305         ALOGE("Failed to destroy IPv6 sockets on %s: %s", addrstr, strerror(-ret));
306         return ret;
307     }
308
309     if (mSocketsDestroyed > 0) {
310         ALOGI("Destroyed %d sockets on %s in %.1f ms", mSocketsDestroyed, addrstr, s.timeTaken());
311     }
312
313     return mSocketsDestroyed;
314 }
315
316 int SockDiag::destroySockets(uint8_t proto, const uid_t uid) {
317     mSocketsDestroyed = 0;
318     Stopwatch s;
319
320     auto destroy = [this, uid] (uint8_t proto, const inet_diag_msg *msg) {
321         if (msg != nullptr && msg->idiag_uid == uid) {
322             return this->sockDestroy(proto, msg);
323         } else {
324             return 0;
325         }
326     };
327
328     for (const int family : {AF_INET, AF_INET6}) {
329         const char *familyName = family == AF_INET ? "IPv4" : "IPv6";
330         uint32_t states = (1 << TCP_ESTABLISHED) | (1 << TCP_SYN_SENT) | (1 << TCP_SYN_RECV);
331         if (int ret = sendDumpRequest(proto, family, states)) {
332             ALOGE("Failed to dump %s sockets for UID: %s", familyName, strerror(-ret));
333             return ret;
334         }
335         if (int ret = readDiagMsg(proto, destroy)) {
336             ALOGE("Failed to destroy %s sockets for UID: %s", familyName, strerror(-ret));
337             return ret;
338         }
339     }
340
341     if (mSocketsDestroyed > 0) {
342         ALOGI("Destroyed %d sockets for UID in %.1f ms", mSocketsDestroyed, s.timeTaken());
343     }
344
345     return 0;
346 }