OSDN Git Service

Increase the DNS TTL to 5s to fix netd_test.
[android-x86/system-netd.git] / tests / dns_responder.cpp
index e7baeca..6094ca1 100644 (file)
 #include <netdb.h>
 #include <stdarg.h>
 #include <stdio.h>
+#include <stdlib.h>
 #include <string.h>
 #include <sys/epoll.h>
 #include <sys/socket.h>
 #include <sys/types.h>
 #include <unistd.h>
 
+#include <iostream>
 #include <vector>
 
 #include <log/log.h>
@@ -365,7 +367,7 @@ struct DNSHeader {
     bool qr;
     uint8_t opcode;
     bool aa;
-    bool tc;
+    bool tr;
     bool rd;
     std::vector<DNSQuestion> questions;
     std::vector<DNSRecord> answers;
@@ -378,8 +380,8 @@ struct DNSHeader {
 private:
     struct Header {
         uint16_t id;
-        uint8_t rcode;
-        uint8_t op;
+        uint8_t flags0;
+        uint8_t flags1;
         uint16_t qdcount;
         uint16_t ancount;
         uint16_t nscount;
@@ -451,9 +453,13 @@ char* DNSHeader::write(char* buffer, const char* buffer_end) const {
         return nullptr;
     }
     Header& header = *reinterpret_cast<Header*>(buffer);
+    // bytes 0-1
     header.id = htons(id);
-    header.rcode = (rcode << 4) | ra;
-    header.op = (rd << 7) | (tc << 6) | (aa << 5) | (opcode << 1) | qr;
+    // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
+    header.flags0 = (qr << 7) | (opcode << 3) | (aa << 2) | (tr << 1) | rd;
+    // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
+    header.flags1 = rcode;
+    // rest of header
     header.qdcount = htons(questions.size());
     header.ancount = htons(answers.size());
     header.nscount = htons(authorities.size());
@@ -489,14 +495,18 @@ const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
     if (buffer + sizeof(Header) > buffer_end)
         return 0;
     const auto& header = *reinterpret_cast<const Header*>(buffer);
+    // bytes 0-1
     id = ntohs(header.id);
-    ra = header.rcode & 1;
-    rcode = header.rcode >> 4;
-    qr = header.op & 1;
-    opcode = (header.op >> 1) & 0x0F;
-    aa = (header.op >> 5) & 1;
-    tc = (header.op >> 6) & 1;
-    rd = header.op >> 7;
+    // byte 2: 7:qr, 3-6:opcode, 2:aa, 1:tr, 0:rd
+    qr = header.flags0 >> 7;
+    opcode = (header.flags0 >> 3) & 0x0F;
+    aa = (header.flags0 >> 2) & 1;
+    tr = (header.flags0 >> 1) & 1;
+    rd = header.flags0 & 1;
+    // byte 3: 7:ra, 6:zero, 5:ad, 4:cd, 0-3:rcode
+    ra = header.flags1 >> 7;
+    rcode = header.flags1 & 0xF;
+    // rest of header
     *qdcount = ntohs(header.qdcount);
     *ancount = ntohs(header.ancount);
     *nscount = ntohs(header.nscount);
@@ -506,11 +516,12 @@ const char* DNSHeader::readHeader(const char* buffer, const char* buffer_end,
 
 /* DNS responder */
 
-DNSResponder::DNSResponder(const char* listen_address,
-                           const char* listen_service, int poll_timeout_ms,
-                           uint16_t error_rcode) :
-    listen_address_(listen_address), listen_service_(listen_service),
+DNSResponder::DNSResponder(std::string listen_address,
+                           std::string listen_service, int poll_timeout_ms,
+                           uint16_t error_rcode, double response_probability) :
+    listen_address_(std::move(listen_address)), listen_service_(std::move(listen_service)),
     poll_timeout_ms_(poll_timeout_ms), error_rcode_(error_rcode),
+    response_probability_(response_probability),
     socket_(-1), epoll_fd_(-1), terminate_(false) { }
 
 DNSResponder::~DNSResponder() {
@@ -544,6 +555,10 @@ void DNSResponder::removeMapping(const char* name, ns_type type) {
     mappings_.erase(it);
 }
 
+void DNSResponder::setResponseProbability(double response_probability) {
+    response_probability_ = response_probability;
+}
+
 bool DNSResponder::running() const {
     return socket_ != -1;
 }
@@ -742,6 +757,15 @@ bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
                                          ns_type(question.qtype)));
         }
     }
+
+    // Ignore requests with the preset probability.
+    auto constexpr bound = std::numeric_limits<unsigned>::max();
+    if (arc4random_uniform(bound) > bound*response_probability_) {
+        ALOGI("returning SRVFAIL in accordance with probability distribution");
+        return makeErrorResponse(&header, ns_rcode::ns_r_servfail, response,
+                                 response_len);
+    }
+
     for (const DNSQuestion& question : header.questions) {
         if (question.qclass != ns_class::ns_c_in &&
             question.qclass != ns_class::ns_c_any) {
@@ -754,6 +778,7 @@ bool DNSResponder::handleDNSRequest(const char* buffer, ssize_t len,
                                      response_len);
         }
     }
+    header.qr = true;
     char* response_cur = header.write(response, response + *response_len);
     if (response_cur == nullptr) {
         return false;
@@ -777,7 +802,7 @@ bool DNSResponder::addAnswerRecords(const DNSQuestion& question,
     record.name = question.qname;
     record.rtype = question.qtype;
     record.rclass = ns_class::ns_c_in;
-    record.ttl = 1;
+    record.ttl = 5;  // seconds
     if (question.qtype == ns_type::ns_t_a) {
         record.rdata.resize(4);
         if (inet_pton(AF_INET, it->second.c_str(), record.rdata.data()) != 1) {
@@ -805,6 +830,7 @@ bool DNSResponder::makeErrorResponse(DNSHeader* header, ns_rcode rcode,
     header->authorities.clear();
     header->additionals.clear();
     header->rcode = rcode;
+    header->qr = true;
     char* response_cur = header->write(response, response + *response_len);
     if (response_cur == nullptr) return false;
     *response_len = response_cur - response;