OSDN Git Service

Increase the DNS TTL to 5s to fix netd_test.
[android-x86/system-netd.git] / tests / dns_responder.cpp
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "dns_responder.h"
18
19 #include <arpa/inet.h>
20 #include <fcntl.h>
21 #include <netdb.h>
22 #include <stdarg.h>
23 #include <stdio.h>
24 #include <stdlib.h>
25 #include <string.h>
26 #include <sys/epoll.h>
27 #include <sys/socket.h>
28 #include <sys/types.h>
29 #include <unistd.h>
30
31 #include <iostream>
32 #include <vector>
33
34 #include <log/log.h>
35
36 namespace test {
37
38 std::string errno2str() {
39     char error_msg[512] = { 0 };
40     if (strerror_r(errno, error_msg, sizeof(error_msg)))
41         return std::string();
42     return std::string(error_msg);
43 }
44
45 #define APLOGI(fmt, ...) ALOGI(fmt ": [%d] %s", __VA_ARGS__, errno, errno2str().c_str())
46
47 std::string str2hex(const char* buffer, size_t len) {
48     std::string str(len*2, '\0');
49     for (size_t i = 0 ; i < len ; ++i) {
50         static const char* hex = "0123456789ABCDEF";
51         uint8_t c = buffer[i];
52         str[i*2] = hex[c >> 4];
53         str[i*2 + 1] = hex[c & 0x0F];
54     }
55     return str;
56 }
57
58 std::string addr2str(const sockaddr* sa, socklen_t sa_len) {
59     char host_str[NI_MAXHOST] = { 0 };
60     int rv = getnameinfo(sa, sa_len, host_str, sizeof(host_str), nullptr, 0,
61                          NI_NUMERICHOST);
62     if (rv == 0) return std::string(host_str);
63     return std::string();
64 }
65
66 /* DNS struct helpers */
67
68 const char* dnstype2str(unsigned dnstype) {
69     static std::unordered_map<unsigned, const char*> kTypeStrs = {
70         { ns_type::ns_t_a, "A" },
71         { ns_type::ns_t_ns, "NS" },
72         { ns_type::ns_t_md, "MD" },
73         { ns_type::ns_t_mf, "MF" },
74         { ns_type::ns_t_cname, "CNAME" },
75         { ns_type::ns_t_soa, "SOA" },
76         { ns_type::ns_t_mb, "MB" },
77         { ns_type::ns_t_mb, "MG" },
78         { ns_type::ns_t_mr, "MR" },
79         { ns_type::ns_t_null, "NULL" },
80         { ns_type::ns_t_wks, "WKS" },
81         { ns_type::ns_t_ptr, "PTR" },
82         { ns_type::ns_t_hinfo, "HINFO" },
83         { ns_type::ns_t_minfo, "MINFO" },
84         { ns_type::ns_t_mx, "MX" },
85         { ns_type::ns_t_txt, "TXT" },
86         { ns_type::ns_t_rp, "RP" },
87         { ns_type::ns_t_afsdb, "AFSDB" },
88         { ns_type::ns_t_x25, "X25" },
89         { ns_type::ns_t_isdn, "ISDN" },
90         { ns_type::ns_t_rt, "RT" },
91         { ns_type::ns_t_nsap, "NSAP" },
92         { ns_type::ns_t_nsap_ptr, "NSAP-PTR" },
93         { ns_type::ns_t_sig, "SIG" },
94         { ns_type::ns_t_key, "KEY" },
95         { ns_type::ns_t_px, "PX" },
96         { ns_type::ns_t_gpos, "GPOS" },
97         { ns_type::ns_t_aaaa, "AAAA" },
98         { ns_type::ns_t_loc, "LOC" },
99         { ns_type::ns_t_nxt, "NXT" },
100         { ns_type::ns_t_eid, "EID" },
101         { ns_type::ns_t_nimloc, "NIMLOC" },
102         { ns_type::ns_t_srv, "SRV" },
103         { ns_type::ns_t_naptr, "NAPTR" },
104         { ns_type::ns_t_kx, "KX" },
105         { ns_type::ns_t_cert, "CERT" },
106         { ns_type::ns_t_a6, "A6" },
107         { ns_type::ns_t_dname, "DNAME" },
108         { ns_type::ns_t_sink, "SINK" },
109         { ns_type::ns_t_opt, "OPT" },
110         { ns_type::ns_t_apl, "APL" },
111         { ns_type::ns_t_tkey, "TKEY" },
112         { ns_type::ns_t_tsig, "TSIG" },
113         { ns_type::ns_t_ixfr, "IXFR" },
114         { ns_type::ns_t_axfr, "AXFR" },
115         { ns_type::ns_t_mailb, "MAILB" },
116         { ns_type::ns_t_maila, "MAILA" },
117         { ns_type::ns_t_any, "ANY" },
118         { ns_type::ns_t_zxfr, "ZXFR" },
119     };
120     auto it = kTypeStrs.find(dnstype);
121     static const char* kUnknownStr{ "UNKNOWN" };
122     if (it == kTypeStrs.end()) return kUnknownStr;
123     return it->second;
124 }
125
126 const char* dnsclass2str(unsigned dnsclass) {
127     static std::unordered_map<unsigned, const char*> kClassStrs = {
128         { ns_class::ns_c_in , "Internet" },
129         { 2, "CSNet" },
130         { ns_class::ns_c_chaos, "ChaosNet" },
131         { ns_class::ns_c_hs, "Hesiod" },
132         { ns_class::ns_c_none, "none" },
133         { ns_class::ns_c_any, "any" }
134     };
135     auto it = kClassStrs.find(dnsclass);
136     static const char* kUnknownStr{ "UNKNOWN" };
137     if (it == kClassStrs.end()) return kUnknownStr;
138     return it->second;
139     return "unknown";
140 }
141
142 struct DNSName {
143     std::string name;
144     const char* read(const char* buffer, const char* buffer_end);
145     char* write(char* buffer, const char* buffer_end) const;
146     const char* toString() const;
147 private:
148     const char* parseField(const char* buffer, const char* buffer_end,
149                            bool* last);
150 };
151
152 const char* DNSName::toString() const {
153     return name.c_str();
154 }
155
156 const char* DNSName::read(const char* buffer, const char* buffer_end) {
157     const char* cur = buffer;
158     bool last = false;
159     do {
160         cur = parseField(cur, buffer_end, &last);
161         if (cur == nullptr) {
162             ALOGI("parsing failed at line %d", __LINE__);
163             return nullptr;
164         }
165     } while (!last);
166     return cur;
167 }
168
169 char* DNSName::write(char* buffer, const char* buffer_end) const {
170     char* buffer_cur = buffer;
171     for (size_t pos = 0 ; pos < name.size() ; ) {
172         size_t dot_pos = name.find('.', pos);
173         if (dot_pos == std::string::npos) {
174             // Sanity check, should never happen unless parseField is broken.
175             ALOGI("logic error: all names are expected to end with a '.'");
176             return nullptr;
177         }
178         size_t len = dot_pos - pos;
179         if (len >= 256) {
180             ALOGI("name component '%s' is %zu long, but max is 255",
181                     name.substr(pos, dot_pos - pos).c_str(), len);
182             return nullptr;
183         }
184         if (buffer_cur + sizeof(uint8_t) + len > buffer_end) {
185             ALOGI("buffer overflow at line %d", __LINE__);
186             return nullptr;
187         }
188         *buffer_cur++ = len;
189         buffer_cur = std::copy(std::next(name.begin(), pos),
190                                std::next(name.begin(), dot_pos),
191                                buffer_cur);
192         pos = dot_pos + 1;
193     }
194     // Write final zero.
195     *buffer_cur++ = 0;
196     return buffer_cur;
197 }
198
199 const char* DNSName::parseField(const char* buffer, const char* buffer_end,
200                                 bool* last) {
201     if (buffer + sizeof(uint8_t) > buffer_end) {
202         ALOGI("parsing failed at line %d", __LINE__);
203         return nullptr;
204     }
205     unsigned field_type = *buffer >> 6;
206     unsigned ofs = *buffer & 0x3F;
207     const char* cur = buffer + sizeof(uint8_t);
208     if (field_type == 0) {
209         // length + name component
210         if (ofs == 0) {
211             *last = true;
212             return cur;
213         }
214         if (cur + ofs > buffer_end) {
215             ALOGI("parsing failed at line %d", __LINE__);
216             return nullptr;
217         }
218         name.append(cur, ofs);
219         name.push_back('.');
220         return cur + ofs;
221     } else if (field_type == 3) {
222         ALOGI("name compression not implemented");
223         return nullptr;
224     }
225     ALOGI("invalid name field type");
226     return nullptr;
227 }
228
229 struct DNSQuestion {
230     DNSName qname;
231     unsigned qtype;
232     unsigned qclass;
233     const char* read(const char* buffer, const char* buffer_end);
234     char* write(char* buffer, const char* buffer_end) const;
235     std::string toString() const;
236 };
237
238 const char* DNSQuestion::read(const char* buffer, const char* buffer_end) {
239     const char* cur = qname.read(buffer, buffer_end);
240     if (cur == nullptr) {
241         ALOGI("parsing failed at line %d", __LINE__);
242         return nullptr;
243     }
244     if (cur + 2*sizeof(uint16_t) > buffer_end) {
245         ALOGI("parsing failed at line %d", __LINE__);
246         return nullptr;
247     }
248     qtype = ntohs(*reinterpret_cast<const uint16_t*>(cur));
249     qclass = ntohs(*reinterpret_cast<const uint16_t*>(cur + sizeof(uint16_t)));
250     return cur + 2*sizeof(uint16_t);
251 }
252
253 char* DNSQuestion::write(char* buffer, const char* buffer_end) const {
254     char* buffer_cur = qname.write(buffer, buffer_end);
255     if (buffer_cur == nullptr) return nullptr;
256     if (buffer_cur + 2*sizeof(uint16_t) > buffer_end) {
257         ALOGI("buffer overflow on line %d", __LINE__);
258         return nullptr;
259     }
260     *reinterpret_cast<uint16_t*>(buffer_cur) = htons(qtype);
261     *reinterpret_cast<uint16_t*>(buffer_cur + sizeof(uint16_t)) =
262             htons(qclass);
263     return buffer_cur + 2*sizeof(uint16_t);
264 }
265
266 std::string DNSQuestion::toString() const {
267     char buffer[4096];
268     int len = snprintf(buffer, sizeof(buffer), "Q<%s,%s,%s>", qname.toString(),
269                        dnstype2str(qtype), dnsclass2str(qclass));
270     return std::string(buffer, len);
271 }
272
273 struct DNSRecord {
274     DNSName name;
275     unsigned rtype;
276     unsigned rclass;
277     unsigned ttl;
278     std::vector<char> rdata;
279     const char* read(const char* buffer, const char* buffer_end);
280     char* write(char* buffer, const char* buffer_end) const;
281     std::string toString() const;
282 private:
283     struct IntFields {
284         uint16_t rtype;
285         uint16_t rclass;
286         uint32_t ttl;
287         uint16_t rdlen;
288     } __attribute__((__packed__));
289
290     const char* readIntFields(const char* buffer, const char* buffer_end,
291             unsigned* rdlen);
292     char* writeIntFields(unsigned rdlen, char* buffer,
293                          const char* buffer_end) const;
294 };
295
296 const char* DNSRecord::read(const char* buffer, const char* buffer_end) {
297     const char* cur = name.read(buffer, buffer_end);
298     if (cur == nullptr) {
299         ALOGI("parsing failed at line %d", __LINE__);
300         return nullptr;
301     }
302     unsigned rdlen = 0;
303     cur = readIntFields(cur, buffer_end, &rdlen);
304     if (cur == nullptr) {
305         ALOGI("parsing failed at line %d", __LINE__);
306         return nullptr;
307     }
308     if (cur + rdlen > buffer_end) {
309         ALOGI("parsing failed at line %d", __LINE__);
310         return nullptr;
311     }
312     rdata.assign(cur, cur + rdlen);
313     return cur + rdlen;
314 }
315
316 char* DNSRecord::write(char* buffer, const char* buffer_end) const {
317     char* buffer_cur = name.write(buffer, buffer_end);
318     if (buffer_cur == nullptr) return nullptr;
319     buffer_cur = writeIntFields(rdata.size(), buffer_cur, buffer_end);
320     if (buffer_cur == nullptr) return nullptr;
321     if (buffer_cur + rdata.size() > buffer_end) {
322         ALOGI("buffer overflow on line %d", __LINE__);
323         return nullptr;
324     }
325     return std::copy(rdata.begin(), rdata.end(), buffer_cur);
326 }
327
328 std::string DNSRecord::toString() const {
329     char buffer[4096];
330     int len = snprintf(buffer, sizeof(buffer), "R<%s,%s,%s>", name.toString(),
331                        dnstype2str(rtype), dnsclass2str(rclass));
332     return std::string(buffer, len);
333 }
334
335 const char* DNSRecord::readIntFields(const char* buffer, const char* buffer_end,
336                                      unsigned* rdlen) {
337     if (buffer + sizeof(IntFields) > buffer_end ) {
338         ALOGI("parsing failed at line %d", __LINE__);
339         return nullptr;
340     }
341     const auto& intfields = *reinterpret_cast<const IntFields*>(buffer);
342     rtype = ntohs(intfields.rtype);
343     rclass = ntohs(intfields.rclass);
344     ttl = ntohl(intfields.ttl);
345     *rdlen = ntohs(intfields.rdlen);
346     return buffer + sizeof(IntFields);
347 }
348
349 char* DNSRecord::writeIntFields(unsigned rdlen, char* buffer,
350                                 const char* buffer_end) const {
351     if (buffer + sizeof(IntFields) > buffer_end ) {
352         ALOGI("buffer overflow on line %d", __LINE__);
353         return nullptr;
354     }
355     auto& intfields = *reinterpret_cast<IntFields*>(buffer);
356     intfields.rtype = htons(rtype);
357     intfields.rclass = htons(rclass);
358     intfields.ttl = htonl(ttl);
359     intfields.rdlen = htons(rdlen);
360     return buffer + sizeof(IntFields);
361 }
362
363 struct DNSHeader {
364     unsigned id;
365     bool ra;
366     uint8_t rcode;
367     bool qr;
368     uint8_t opcode;
369     bool aa;
370     bool tr;
371     bool rd;
372     std::vector<DNSQuestion> questions;
373     std::vector<DNSRecord> answers;
374     std::vector<DNSRecord> authorities;
375     std::vector<DNSRecord> additionals;
376     const char* read(const char* buffer, const char* buffer_end);
377     char* write(char* buffer, const char* buffer_end) const;
378     std::string toString() const;
379
380 private:
381     struct Header {
382         uint16_t id;
383         uint8_t flags0;
384         uint8_t flags1;
385         uint16_t qdcount;
386         uint16_t ancount;
387         uint16_t nscount;
388         uint16_t arcount;
389     } __attribute__((__packed__));
390
391     const char* readHeader(const char* buffer, const char* buffer_end,
392                            unsigned* qdcount, unsigned* ancount,
393                            unsigned* nscount, unsigned* arcount);
394 };
395
396 const char* DNSHeader::read(const char* buffer, const char* buffer_end) {
397     unsigned qdcount;
398     unsigned ancount;
399     unsigned nscount;
400     unsigned arcount;
401     const char* cur = readHeader(buffer, buffer_end, &qdcount, &ancount,
402                                  &nscount, &arcount);
403     if (cur == nullptr) {
404         ALOGI("parsing failed at line %d", __LINE__);
405         return nullptr;
406     }
407     if (qdcount) {
408         questions.resize(qdcount);
409         for (unsigned i = 0 ; i < qdcount ; ++i) {
410             cur = questions[i].read(cur, buffer_end);
411             if (cur == nullptr) {
412                 ALOGI("parsing failed at line %d", __LINE__);
413                 return nullptr;
414             }
415         }
416     }
417     if (ancount) {
418         answers.resize(ancount);
419         for (unsigned i = 0 ; i < ancount ; ++i) {
420             cur = answers[i].read(cur, buffer_end);
421             if (cur == nullptr) {
422                 ALOGI("parsing failed at line %d", __LINE__);
423                 return nullptr;
424             }
425         }
426     }
427     if (nscount) {
428         authorities.resize(nscount);
429         for (unsigned i = 0 ; i < nscount ; ++i) {
430             cur = authorities[i].read(cur, buffer_end);
431             if (cur == nullptr) {
432                 ALOGI("parsing failed at line %d", __LINE__);
433                 return nullptr;
434             }
435         }
436     }
437     if (arcount) {
438         additionals.resize(arcount);
439         for (unsigned i = 0 ; i < arcount ; ++i) {
440             cur = additionals[i].read(cur, buffer_end);
441             if (cur == nullptr) {
442                 ALOGI("parsing failed at line %d", __LINE__);
443                 return nullptr;
444             }
445         }
446     }
447     return cur;
448 }
449
450 char* DNSHeader::write(char* buffer, const char* buffer_end) const {
451     if (buffer + sizeof(Header) > buffer_end) {
452         ALOGI("buffer overflow on line %d", __LINE__);
453         return nullptr;
454     }
455     Header& header = *reinterpret_cast<Header*>(buffer);
456     // bytes 0-1
457     header.id = htons(id);
458     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
459     header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
460     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
461     header.flags1 = rcode;
462     // rest of header
463     header.qdcount = htons(questions.size());
464     header.ancount = htons(answers.size());
465     header.nscount = htons(authorities.size());
466     header.arcount = htons(additionals.size());
467     char* buffer_cur = buffer + sizeof(Header);
468     for (const DNSQuestion& question : questions) {
469         buffer_cur = question.write(buffer_cur, buffer_end);
470         if (buffer_cur == nullptr) return nullptr;
471     }
472     for (const DNSRecord& answer : answers) {
473         buffer_cur = answer.write(buffer_cur, buffer_end);
474         if (buffer_cur == nullptr) return nullptr;
475     }
476     for (const DNSRecord& authority : authorities) {
477         buffer_cur = authority.write(buffer_cur, buffer_end);
478         if (buffer_cur == nullptr) return nullptr;
479     }
480     for (const DNSRecord& additional : additionals) {
481         buffer_cur = additional.write(buffer_cur, buffer_end);
482         if (buffer_cur == nullptr) return nullptr;
483     }
484     return buffer_cur;
485 }
486
487 std::string DNSHeader::toString() const {
488     // TODO
489     return std::string();
490 }
491
492 const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
493                                   unsigned* qdcount, unsigned* ancount,
494                                   unsigned* nscount, unsigned* arcount) {
495     if (buffer + sizeof(Header) > buffer_end)
496         return 0;
497     const auto& header = *reinterpret_cast<const Header*>(buffer);
498     // bytes 0-1
499     id = ntohs(header.id);
500     // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
501     qr = header.flags0 >> 7;
502     opcode = (header.flags0 >> 3) & 0x0F;
503     aa = (header.flags0 >> 2) & 1;
504     tr = (header.flags0 >> 1) & 1;
505     rd = header.flags0 & 1;
506     // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
507     ra = header.flags1 >> 7;
508     rcode = header.flags1 & 0xF;
509     // rest of header
510     *qdcount = ntohs(header.qdcount);
511     *ancount = ntohs(header.ancount);
512     *nscount = ntohs(header.nscount);
513     *arcount = ntohs(header.arcount);
514     return buffer + sizeof(Header);
515 }
516
517 /* DNS responder */
518
519 DNSResponder::DNSResponder(std::string listen_address,
520                            std::string listen_service, int poll_timeout_ms,
521                            uint16_t error_rcode, double response_probability) :
522     listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)),
523     poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
524     response_probability_(response_probability),
525     socket_(-1), epoll_fd_(-1), terminate_(false) { }
526
527 DNSResponder::~DNSResponder() {
528     stopServer();
529 }
530
531 void DNSResponder::addMapping(const char* name, ns_type type,
532         const char* addr) {
533     std::lock_guard<std::mutex> lock(mappings_mutex_);
534     auto it = mappings_.find(QueryKey(name, type));
535     if (it != mappings_.end()) {
536         ALOGI("Overwriting mapping for (%s, %s), previous address %s, new "
537             "address %s", name, dnstype2str(type), it->second.c_str(),
538             addr);
539         it->second = addr;
540         return;
541     }
542     mappings_.emplace(std::piecewise_construct,
543                       std::forward_as_tuple(name, type),
544                       std::forward_as_tuple(addr));
545 }
546
547 void DNSResponder::removeMapping(const char* name, ns_type type) {
548     std::lock_guard<std::mutex> lock(mappings_mutex_);
549     auto it = mappings_.find(QueryKey(name, type));
550     if (it != mappings_.end()) {
551         ALOGI("Cannot remove mapping mapping from (%s, %s), not present", name,
552             dnstype2str(type));
553         return;
554     }
555     mappings_.erase(it);
556 }
557
558 void DNSResponder::setResponseProbability(double response_probability) {
559     response_probability_ = response_probability;
560 }
561
562 bool DNSResponder::running() const {
563     return socket_ != -1;
564 }
565
566 bool DNSResponder::startServer() {
567     if (running()) {
568         ALOGI("server already running");
569         return false;
570     }
571     addrinfo ai_hints{
572         .ai_family = AF_UNSPEC,
573         .ai_socktype = SOCK_DGRAM,
574         .ai_flags = AI_PASSIVE
575     };
576     addrinfo* ai_res;
577     int rv = getaddrinfo(listen_address_.c_str(), listen_service_.c_str(),
578                          &ai_hints, &ai_res);
579     if (rv) {
580         ALOGI("getaddrinfo(%s, %s) failed: %s", listen_address_.c_str(),
581             listen_service_.c_str(), gai_strerror(rv));
582         return false;
583     }
584     int s = -1;
585     for (const addrinfo* ai = ai_res ; ai ; ai = ai->ai_next) {
586         s = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
587         if (s < 0) continue;
588         const int one = 1;
589         setsockopt(s, SOL_SOCKET, SO_REUSEPORT, &one, sizeof(one));
590         if (bind(s, ai->ai_addr, ai->ai_addrlen)) {
591             APLOGI("bind failed for socket %d", s);
592             close(s);
593             s = -1;
594             continue;
595         }
596         std::string host_str = addr2str(ai->ai_addr, ai->ai_addrlen);
597         ALOGI("bound to UDP %s:%s", host_str.c_str(), listen_service_.c_str());
598         break;
599     }
600     freeaddrinfo(ai_res);
601     if (s < 0) {
602         ALOGI("bind() failed");
603         return false;
604     }
605
606     int flags = fcntl(s, F_GETFL, 0);
607     if (flags < 0) flags = 0;
608     if (fcntl(s, F_SETFL, flags | O_NONBLOCK) < 0) {
609         APLOGI("fcntl(F_SETFL) failed for socket %d", s);
610         close(s);
611         return false;
612     }
613
614     int ep_fd = epoll_create(1);
615     if (ep_fd < 0) {
616         char error_msg[512] = { 0 };
617         if (strerror_r(errno, error_msg, sizeof(error_msg)))
618             strncpy(error_msg, "UNKNOWN", sizeof(error_msg));
619         APLOGI("epoll_create() failed: %s", error_msg);
620         close(s);
621         return false;
622     }
623     epoll_event ev;
624     ev.events = EPOLLIN;
625     ev.data.fd = s;
626     if (epoll_ctl(ep_fd, EPOLL_CTL_ADD, s, &ev) < 0) {
627         APLOGI("epoll_ctl() failed for socket %d", s);
628         close(ep_fd);
629         close(s);
630         return false;
631     }
632
633     epoll_fd_ = ep_fd;
634     socket_ = s;
635     {
636         std::lock_guard<std::mutex> lock(update_mutex_);
637         handler_thread_ = std::thread(&DNSResponder::requestHandler, this);
638     }
639     ALOGI("server started successfully");
640     return true;
641 }
642
643 bool DNSResponder::stopServer() {
644     std::lock_guard<std::mutex> lock(update_mutex_);
645     if (!running()) {
646         ALOGI("server not running");
647         return false;
648     }
649     if (terminate_) {
650         ALOGI("LOGIC ERROR");
651         return false;
652     }
653     ALOGI("stopping server");
654     terminate_ = true;
655     handler_thread_.join();
656     close(epoll_fd_);
657     close(socket_);
658     terminate_ = false;
659     socket_ = -1;
660     ALOGI("server stopped successfully");
661     return true;
662 }
663
664 std::vector<std::pair<std::string, ns_type >> DNSResponder::queries() const {
665     std::lock_guard<std::mutex> lock(queries_mutex_);
666     return queries_;
667 }
668
669 void DNSResponder::clearQueries() {
670     std::lock_guard<std::mutex> lock(queries_mutex_);
671     queries_.clear();
672 }
673
674 void DNSResponder::requestHandler() {
675     epoll_event evs[1];
676     while (!terminate_) {
677         int n = epoll_wait(epoll_fd_, evs, 1, poll_timeout_ms_);
678         if (n == 0) continue;
679         if (n < 0) {
680             ALOGI("epoll_wait() failed");
681             // TODO(imaipi): terminate on error.
682             return;
683         }
684         char buffer[4096];
685         sockaddr_storage sa;
686         socklen_t sa_len = sizeof(sa);
687         ssize_t len;
688         do {
689             len = recvfrom(socket_, buffer, sizeof(buffer), 0,
690                            (sockaddr*) &sa, &sa_len);
691         } while (len < 0 && (errno == EAGAIN || errno == EINTR));
692         if (len <= 0) {
693             ALOGI("recvfrom() failed");
694             continue;
695         }
696         ALOGI("read %zd bytes", len);
697         char response[4096];
698         size_t response_len = sizeof(response);
699         if (handleDNSRequest(buffer, len, response, &response_len) &&
700             response_len > 0) {
701             len = sendto(socket_, response, response_len, 0,
702                          reinterpret_cast<const sockaddr*>(&sa), sa_len);
703             std::string host_str =
704                 addr2str(reinterpret_cast<const sockaddr*>(&sa), sa_len);
705             if (len > 0) {
706                 ALOGI("sent %zu bytes to %s", len, host_str.c_str());
707             } else {
708                 APLOGI("sendto() failed for %s", host_str.c_str());
709             }
710             // Test that the response is actually a correct DNS message.
711             const char* response_end = response + len;
712             DNSHeader header;
713             const char* cur = header.read(response, response_end);
714             if (cur == nullptr) ALOGI("response is flawed");
715
716         } else {
717             ALOGI("not responding");
718         }
719     }
720 }
721
722 bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
723                                     char* response, size_t* response_len)
724                                     const {
725     ALOGI("request: '%s'", str2hex(buffer, len).c_str());
726     const char* buffer_end = buffer + len;
727     DNSHeader header;
728     const char* cur = header.read(buffer, buffer_end);
729     // TODO(imaipi): for now, unparsable messages are silently dropped, fix.
730     if (cur == nullptr) {
731         ALOGI("failed to parse query");
732         return false;
733     }
734     if (header.qr) {
735         ALOGI("response received instead of a query");
736         return false;
737     }
738     if (header.opcode != ns_opcode::ns_o_query) {
739         ALOGI("unsupported request opcode received");
740         return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
741                                  response_len);
742     }
743     if (header.questions.empty()) {
744         ALOGI("no questions present");
745         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
746                                  response_len);
747     }
748     if (!header.answers.empty()) {
749         ALOGI("already %zu answers present in query", header.answers.size());
750         return makeErrorResponse(&header, ns_rcode::ns_r_formerr, response,
751                                  response_len);
752     }
753     {
754         std::lock_guard<std::mutex> lock(queries_mutex_);
755         for (const DNSQuestion& question : header.questions) {
756             queries_.push_back(make_pair(question.qname.name,
757                                          ns_type(question.qtype)));
758         }
759     }
760
761     // Ignore requests with the preset probability.
762     auto constexpr bound = std::numeric_limits<unsigned>::max();
763     if (arc4random_uniform(bound) > bound*response_probability_) {
764         ALOGI("returning SRVFAIL in accordance with probability distribution");
765         return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
766                                  response_len);
767     }
768
769     for (const DNSQuestion& question : header.questions) {
770         if (question.qclass != ns_class::ns_c_in &&
771             question.qclass != ns_class::ns_c_any) {
772             ALOGI("unsupported question class %u", question.qclass);
773             return makeErrorResponse(&header, ns_rcode::ns_r_notimpl, response,
774                                      response_len);
775         }
776         if (!addAnswerRecords(question, &header.answers)) {
777             return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
778                                      response_len);
779         }
780     }
781     header.qr = true;
782     char* response_cur = header.write(response, response + *response_len);
783     if (response_cur == nullptr) {
784         return false;
785     }
786     *response_len = response_cur - response;
787     return true;
788 }
789
790 bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
791                                     std::vector<DNSRecord>* answers) const {
792     auto it = mappings_.find(QueryKey(question.qname.name, question.qtype));
793     if (it == mappings_.end()) {
794         // TODO(imaipi): handle correctly
795         ALOGI("no mapping found for %s %s, lazily refusing to add an answer",
796             question.qname.name.c_str(), dnstype2str(question.qtype));
797         return true;
798     }
799     ALOGI("mapping found for %s %s: %s", question.qname.name.c_str(),
800         dnstype2str(question.qtype), it->second.c_str());
801     DNSRecord record;
802     record.name = question.qname;
803     record.rtype = question.qtype;
804     record.rclass = ns_class::ns_c_in;
805     record.ttl = 5;  // seconds
806     if (question.qtype == ns_type::ns_t_a) {
807         record.rdata.resize(4);
808         if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
809             ALOGI("inet_pton(AF_INET, %s) failed", it->second.c_str());
810             return false;
811         }
812     } else if (question.qtype == ns_type::ns_t_aaaa) {
813         record.rdata.resize(16);
814         if (inet_pton(AF_INET6, it->second.c_str(), record.rdata.data()) != 1) {
815             ALOGI("inet_pton(AF_INET6, %s) failed", it->second.c_str());
816             return false;
817         }
818     } else {
819         ALOGI("unhandled qtype %s", dnstype2str(question.qtype));
820         return false;
821     }
822     answers->push_back(std::move(record));
823     return true;
824 }
825
826 bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
827                                      char* response, size_t* response_len)
828                                      const {
829     header->answers.clear();
830     header->authorities.clear();
831     header->additionals.clear();
832     header->rcode = rcode;
833     header->qr = true;
834     char* response_cur = header->write(response, response + *response_len);
835     if (response_cur == nullptr) return false;
836     *response_len = response_cur - response;
837     return true;
838 }
839
840 }  // namespace test
841