OSDN Git Service

a4e6bd0af0413ae4b6c01b11c9f9ab913271ab18
[hmh/hhml.git] / lib / util_tcp.cc
1 #include "util_tcp.h"
2 #include "config.h"
3 #include "http.h"
4 #include "util_const.h"
5 #include "util_base64.h"
6 #include "util_check.h"
7 #include "util_string.h"
8 #include "ustring.h"
9 #include <boost/regex.hpp>
10 #include <sys/uio.h>
11 #include <sys/types.h>
12 #include <sys/socket.h>
13 #include <sys/socket.h>
14 #include <netdb.h>
15 #include <string.h>
16 #include <unistd.h>
17 #include <iostream>
18 #include <locale>
19
20 #define SSL_NAME_LEN    256
21
22 #if 0
23 static int  case_diffs (const char *s, const char *t) {
24     unsigned char  x;
25     unsigned char  y;
26
27     for (;;) {
28         x = *s++ - 'A';
29         if (x <= 'Z' - 'A')
30             x += 'a';
31         else
32             x += 'A';
33         y = *t++ - 'A';
34         if (y <= 'Z' - 'A')
35             y += 'a';
36         else
37             y += 'A';
38         if (x != y)
39             break;
40         if (!x)
41             break;
42     }
43     return ((int)(unsigned int) x) - ((int)(unsigned int) y);
44 }
45 #endif
46
47 //============================================================
48 bool  TcpBuf::empty () {
49     return start == tail;
50 }
51
52 bool  TcpBuf::fill (TcpClient& tc) {
53     ssize_t  s;
54
55     if (start != buf.begin ()) {
56         if (tail == start) {
57             start = tail = buf.begin ();
58         } else {
59             memmove (&*buf.begin (), &*start, tail - start);
60             tail -= start - buf.begin ();
61             start = buf.begin ();
62         }
63     }
64     s = tc.read (&*tail, buf.end () - tail);
65     if (s <= 0)
66         return false;
67     tail += s;
68
69     return true;
70 }
71
72 bool  TcpBuf::getln (TcpClient& tc, ustring& ans) {
73     boost::match_results<ustring::iterator>  m;
74     static uregex  re_crlf ("\\r\\n");
75
76     if (empty ())
77         fill (tc);
78     if (regex_search (start, tail, m, re_crlf, boost::regex_constants::match_single_line)) {
79         ans.assign (start, m[0].first);
80         start = m[0].second;
81         return true;
82     } else if (fill (tc)) {
83         if (regex_search (start, tail, m, re_crlf, boost::regex_constants::match_single_line)) {
84             ans.assign (start, m[0].first);
85             start = m[0].second;
86             return true;
87         }
88     }
89     return false;
90 }
91
92 bool  TcpBuf::getln2 (TcpClient& tc, ustring& ans) {
93     boost::match_results<ustring::iterator>  m;
94     static uregex  re_crlf ("\\r\\n");
95
96     if (regex_search (start, tail, m, re_crlf, boost::regex_constants::match_single_line)) {
97         ans.assign (start, m[0].first);
98         start = m[0].second;
99         return true;
100     }
101     return false;
102 }
103
104 size_t  TcpBuf::size () {
105     return tail - start;
106 }
107
108 void  TcpBuf::consume () {
109     start = tail = buf.begin ();
110 }
111
112 //============================================================
113 bool  TcpClient::connect (const HostSpec* host) {
114     char pbuf[10];
115     const char*  bindaddr = NULL;
116     struct addrinfo  hints;
117     struct addrinfo*  res;
118     struct addrinfo*  res0;
119     bool  rc;
120
121     switch (host->ipv) {
122     case HostSpec::IPV4:
123         af = AF_INET;
124         break;
125     case HostSpec::IPV6:
126         af = AF_INET6;
127         break;
128     }
129
130     snprintf (pbuf, sizeof(pbuf), "%d", host->port);
131     memset (&hints, 0, sizeof (hints));
132     hints.ai_family = af;
133     hints.ai_socktype = SOCK_STREAM;
134     hints.ai_protocol = 0;
135     if ((err = getaddrinfo (host->host.c_str (), pbuf, &hints, &res0)) != 0) {
136 //      seterr (err);
137         return false;
138     }
139     for (sd = -1, res = res0; res; res = res->ai_next) {
140         if ((sd = socket (res->ai_family, res->ai_socktype, res->ai_protocol)) == -1)
141             continue;
142         if (bindaddr != NULL && *bindaddr != '\0' && ! bind (bindaddr)) {
143 //          failed to bind to 'bindaddr'
144             ::close (sd);
145             sd = -1;
146             continue;
147         }
148         if (::connect (sd, res->ai_addr, res->ai_addrlen) == 0)
149             break;
150         ::close (sd);
151         sd = -1;
152     }
153     freeaddrinfo (res0);
154     if (sd == -1) {
155 //      syserr
156         return false;
157     }
158     
159 #if 0
160     if (! reopen ()) {
161 //      syserr
162         ::close(sd);
163     }
164 #endif
165
166     rc = connect2 ();
167     if (! rc) {
168         close ();
169         return false;
170     }
171
172     noPush ();
173     return true;
174 }
175
176 bool  TcpClient::connect2 () {
177     // none
178     return true;
179 }
180
181 bool  TcpClient::bind (const char* addr) {
182     struct addrinfo  hints;
183     struct addrinfo*  res;
184     struct addrinfo*  res0;
185
186     memset (&hints, 0, sizeof (hints));
187     hints.ai_family = af;
188     hints.ai_socktype = SOCK_STREAM;
189     hints.ai_protocol = 0;
190     if ((err = getaddrinfo (addr, NULL, &hints, &res0)) != 0)
191         return false;
192     for (res = res0; res; res = res->ai_next)
193         if (::bind (sd, res->ai_addr, res->ai_addrlen) == 0)
194             return true;
195     return false;
196 }
197
198 void  TcpClient::noPush () {
199     int  val = 1;
200
201     setsockopt(sd, IPPROTO_TCP, TCP_NOPUSH, &val, sizeof(val));
202 }
203
204 void  TcpClient::close () {
205     if (sd >= 0) {
206         ::close (sd);
207         sd = -1;
208     }
209 }
210
211 ssize_t  TcpClient::read (void* buf, size_t nbytes) {
212     struct timeval  now;
213     struct timeval  timeout;
214     struct timeval  delta;
215     fd_set rfds;
216     ssize_t rlen, total;
217     int r;
218
219     if (timeLimit) {
220         FD_ZERO (&rfds);
221         gettimeofday (&timeout, NULL);
222         timeout.tv_sec += timeLimit;
223     }
224
225     total = 0;
226     while (nbytes > 0) {
227         while (timeLimit && ! FD_ISSET (sd, &rfds)) {
228             FD_SET (sd, &rfds);
229             gettimeofday (&now, NULL);
230             delta.tv_sec = timeout.tv_sec - now.tv_sec;
231             delta.tv_usec = timeout.tv_usec - now.tv_usec;
232             if (delta.tv_usec < 0) {
233                 delta.tv_usec += 1000000;
234                 delta.tv_sec --;
235             }
236             if (delta.tv_sec < 0) {
237                 errno = ETIMEDOUT;
238 //              syserr();
239                 return -1;
240             }
241             errno = 0;
242             r = select (sd + 1, &rfds, NULL, NULL, &delta);
243             if (r == -1) {
244                 if (errno == EINTR)
245                     continue;
246 //              syserr();
247                 return -1;
248             }
249         }
250         rlen = read2 (buf, nbytes);
251         if (rlen == 0)
252             break;
253         if (rlen < 0) {
254             if (errno == EINTR)
255                 continue;
256             return -1;
257         }
258         nbytes -= rlen;
259         buf = (char*)buf + rlen;
260         total += rlen;
261     }
262     return total;
263 }
264
265 ssize_t  TcpClient::write (const void* buf, size_t nbytes) {
266     struct iovec iov;
267
268     iov.iov_base = __DECONST(char *, buf);
269     iov.iov_len = nbytes;
270     return write (&iov, 1);
271 }
272
273 void  TcpClient::flush_write () {
274     int  val;
275
276     val = 0;
277     setsockopt(sd, IPPROTO_TCP, TCP_NOPUSH, &val, sizeof(val));
278     val = 1;
279     setsockopt(sd, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val));
280 }
281
282 ssize_t  TcpClient::write (struct iovec* iov, int iovcnt) {
283     struct timeval  now;
284     struct timeval  timeout;
285     struct timeval  delta;
286     fd_set writefds;
287     ssize_t wlen, total;
288     int r;
289
290     if (timeLimit) {
291         FD_ZERO (&writefds);
292         gettimeofday (&timeout, NULL);
293         timeout.tv_sec += timeLimit;
294     }
295
296     total = 0;
297     while (iovcnt > 0) {
298         while (timeLimit && ! FD_ISSET (sd, &writefds)) {
299             FD_SET (sd, &writefds);
300             gettimeofday (&now, NULL);
301             delta.tv_sec = timeout.tv_sec - now.tv_sec;
302             delta.tv_usec = timeout.tv_usec - now.tv_usec;
303             if (delta.tv_usec < 0) {
304                 delta.tv_usec += 1000000;
305                 delta.tv_sec --;
306             }
307             if (delta.tv_sec < 0) {
308                 errno = ETIMEDOUT;
309 //              syserr();
310                 return -1;
311             }
312             errno = 0;
313             r = select (sd + 1, NULL, &writefds, NULL, &delta);
314             if (r == -1) {
315                 if (errno == EINTR)
316                     continue;
317                 return -1;
318             }
319         }
320         errno = 0;
321         wlen = write2 (iov, iovcnt);
322         if (wlen == 0) {
323             /* we consider a short write a failure */
324             errno = EPIPE;
325 //          syserr();
326             return -1;
327         }
328         if (wlen < 0) {
329             if (errno == EINTR)
330                 continue;
331             return -1;
332         }
333         total += wlen;
334         while (iovcnt > 0 && wlen >= (ssize_t)iov->iov_len) {
335             wlen -= iov->iov_len;
336             iov++;
337             iovcnt--;
338         }
339         if (iovcnt > 0) {
340             iov->iov_len -= wlen;
341             iov->iov_base = __DECONST(char *, iov->iov_base) + wlen;
342         }
343     }
344     return total;
345 }
346
347 ssize_t  TcpClient::write2 (struct iovec* iov, int iovcnt) {
348     return writev (sd, iov, iovcnt);
349 }
350
351 ssize_t  TcpClient::read2 (void* buf, size_t nbytes) {
352     return ::read (sd, buf, nbytes);
353 }
354
355 //============================================================
356 bool  SslClient::connect (const HostSpec* conhost, const HostSpec* _ephost) {
357     bool  rc;
358
359     ephost = _ephost;
360     rc = TcpClient::connect (conhost);
361     return rc;
362 }
363
364 bool  SslClient::connect2 () {
365     return sslOpen ();
366 }
367
368 bool  SslClient::sslOpen () {
369     if (!SSL_library_init ()){
370         close ();
371         throw (ustring (CharConst ("SSL library init failed\n")));
372         return false;
373     }
374
375     SSL_load_error_strings ();
376
377     ssl_meth = SSLv23_client_method ();
378     ssl_ctx.reset (SSL_CTX_new (ssl_meth));
379     SSL_CTX_set_mode (ssl_ctx.get (), SSL_MODE_AUTO_RETRY);
380     SSL_CTX_set_options (ssl_ctx.get (), SSL_OP_ALL | SSL_OP_NO_TICKET | SSL_OP_NO_SSLv2);
381     if (! setupPeerVerification (kCERTFILE)) {
382         close ();
383         throw (ustring (CharConst ("SSL certificate error\n")));
384         return false;
385     }
386
387     ssl.reset (SSL_new (ssl_ctx.get ()));
388     if (ssl.get () == NULL){
389         close ();
390         throw (ustring (CharConst ("SSL context creation failed\n")));
391         return false;
392     }
393     SSL_set_fd (ssl.get (), sd);
394 //    if (SSL_connect (ssl.get ()) == -1){
395 #if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT)
396     if (! SSL_set_tlsext_host_name (ssl.get (), ephost->host.c_str ())) {
397         fprintf (stderr, "TLS server name indication extension failed for host %s\n", ephost->host.c_str ());
398         close ();
399         return false;
400     }
401 #endif
402
403     int  rc;
404     while ((rc = SSL_connect (ssl.get ())) == -1) {
405         int  ssl_err = SSL_get_error (ssl.get (), rc);
406         if (ssl_err != SSL_ERROR_WANT_READ && ssl_err != SSL_ERROR_WANT_WRITE) {
407             ERR_print_errors_fp (stderr);
408             close ();
409             return false;
410         }
411     }
412     if (fnoverify || verifyCA ()) {
413         sslmode = true;
414         return true;
415     } else {
416         close ();
417         return false;
418     }
419 }
420
421 void  SslClient::loadCAFile (const char* certfile, int depth) {
422     if (!SSL_CTX_load_verify_locations (ssl_ctx.get (), certfile, NULL))
423         return;
424     SSL_CTX_set_verify_depth (ssl_ctx.get (), depth);
425 }
426
427 /*
428  * Callback for SSL certificate verification, this is called on server
429  * cert verification. It takes no decision, but informs the user in case
430  * verification failed.
431  */
432 static int  fetch_ssl_cb_verify_crt (int verified, X509_STORE_CTX *ctx) {
433     X509*  crt;
434     X509_NAME*  name;
435     char*  str;
436
437     str = NULL;
438     if (! verified) {
439         if ((crt = X509_STORE_CTX_get_current_cert (ctx)) != NULL &&
440             (name = X509_get_subject_name (crt)) != NULL)
441             str = X509_NAME_oneline (name, 0, 0);
442         fprintf (stderr, "Certificate verification failed for %s\n",
443                  str != NULL ? str : "no relevant certificate");
444         OPENSSL_free (str);
445     }
446     return verified;
447 }
448
449 bool  SslClient::setupPeerVerification (const char* certfile, int depth) {
450     if (! fnoverify) {
451         SSL_CTX_set_verify (ssl_ctx.get (), SSL_VERIFY_PEER, fetch_ssl_cb_verify_crt);
452         SSL_CTX_load_verify_locations (ssl_ctx.get (), certfile, NULL);
453     }
454     return true;
455 }
456
457 bool  SslClient::verifyCA () {
458     X509_autoptr  cert;
459     long  rc;
460     ustring  lhost;
461     bool  ans;
462
463     if ((rc = SSL_get_verify_result (ssl.get ())) != X509_V_OK) {
464         switch (rc) {
465         case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY:
466             throw (ustring (CharConst ("unable to get issuer cert locally.")));
467         default:
468             std::cerr << rc << ": X509 failed\n";
469         }
470         return false;
471     }
472     cert.reset (SSL_get_peer_certificate (ssl.get ()));
473     if (!cert.get ())
474         return false;
475
476     lhost.resize (ephost->host.length ());
477     std::transform (ephost->host.begin (), ephost->host.end (), lhost.begin (), ::tolower);
478     ans = verifyCN (lhost, cert.get ()) || verifyAltName (lhost, cert.get ());
479     if (! ans) {
480         std::cerr << lhost << ": hostname not match\n";
481     }
482     return ans;
483 }
484
485 bool  SslClient::verifyCN (const ustring& lhost, X509* cert) {
486     char  buf[SSL_NAME_LEN];
487     long  rc;
488     int  len;
489     ustring  cn;
490
491     len = X509_NAME_get_text_by_NID (X509_get_subject_name (cert), NID_commonName, buf, sizeof (buf));
492
493     if (len > 0 && lhost.length () > 0) {
494         if (len >= sizeof (buf)) {
495             std::cerr << "Too long common name of the server certificate\n";
496             return false;
497         }
498         cn.assign (buf, len);
499         std::transform (cn.begin (), cn.end (), cn.begin (), ::tolower);
500         return globMatch (lhost, cn);
501     }
502     return false;
503 }
504
505 bool  SslClient::verifyAltName (const ustring& lhost, X509* cert) {
506     int  pos;
507     X509_EXTENSION*  ext = NULL;
508     GENERAL_NAMES*  names;
509     GENERAL_NAME*  name;
510     int  i, n;
511     unsigned char*  dns;
512
513     pos = X509_get_ext_by_NID (cert, NID_subject_alt_name, -1);
514     if (pos >= 0) {
515         ext = X509_get_ext (cert, pos);
516         if (ext) {
517             names = (GENERAL_NAMES*)X509V3_EXT_d2i (ext);
518             n = sk_GENERAL_NAME_num (names);
519             for (i = 0; i < n; ++ i) {
520                 name = sk_GENERAL_NAME_value (names, i);
521                 if (name->type == GEN_DNS) {
522                     ASN1_STRING_to_UTF8 (&dns, name->d.dNSName);
523 //                  std::cerr << "dns:" << dns << "\n";
524                     ustring  name (char_type (dns));
525                     std::transform (name.begin (), name.end (), name.begin (), ::tolower);
526                     if (globMatch (lhost, ustring (char_type (dns)))) {
527                         OPENSSL_free (dns);
528                         return true;
529                     } else {
530                         OPENSSL_free (dns);
531                     }
532                 }
533             }
534         }
535     }
536     return false;
537 }
538
539 bool  SslClient::globMatch (const ustring& lhost, const ustring& name) {
540     if (matchHead (name, CharConst ("*."))) {
541         // XXX RFC2459(= HTTP over TLS)では、*は一つのサブドメインのみにマッチする。
542         ustring::size_type  p = lhost.find ('.');
543         if (p != ustring::npos) {
544             return ustring (lhost.begin () + p, lhost.end ()) == ustring (name.begin () + 1, name.end ());
545         } else {
546             return false;
547         }
548     } else {
549         return lhost == name;
550     }
551 }
552
553 ssize_t  SslClient::write2 (struct iovec* iov, int iovcnt) {
554     if (sslmode) {
555         return SSL_write(ssl.get (), iov->iov_base, iov->iov_len);
556     } else {
557         return TcpClient::write2 (iov, iovcnt);
558     }
559 }
560
561 ssize_t  SslClient::read2 (void* buf, size_t nbytes) {
562     if (sslmode) {
563         return SSL_read (ssl.get (), buf, nbytes);
564     } else {
565         return TcpClient::read2 (buf, nbytes);
566     }
567 }
568
569 //============================================================
570 bool  ProxySslClient::connect2 () {
571     ustring  msg;
572
573     if (checkHostname (ephost->host)) {
574         msg.assign (CharConst ("CONNECT "));
575         msg.append (ephost->host).append (uColon).append (to_ustring (ephost->port)).append (CharConst (" HTTP/1.0" kCRLF));
576         msg.append (CharConst ("Host: ")).append (ephost->host).append (uCRLF);
577         // no User-Agent:
578         if (proxyid.length () > 0) {
579             ustring idpw;
580             idpw.assign (proxyid).append (uColon).append (proxypw);
581             msg.append (CharConst ("Proxy-Authorization: Basic ")).append (base64Encode (idpw.begin (), idpw.end ())).append (uCRLF);
582         }
583         msg.append (uCRLF);
584         write (&*msg.begin (), msg.length ()); // CONNECT HOST...
585
586         int  rc = readReplyHead ();
587         if (rc == 200) {
588             return SslClient::connect2 ();
589         } else {
590         }
591     }
592     return false;
593 }
594
595 int  ProxySslClient::readReplyHead () {
596     TcpBuf  buf;
597     ustring  line;
598     int  responseCode = 0;
599     umatch  m;
600     static uregex  re_crlf ("\\r\\n");
601
602     buf.tail += read3 (&*buf.start, buf.buf.length ());
603     if (buf.getln2 (*this, line)) {
604     } else {
605         return 0;
606     }
607
608     if (! match (line.substr (0, 5), CharConst ("HTTP/"))) {
609         return 0;               // bad protocol
610     } else {
611         ustring::size_type  p1 = line.find (' ', 0);
612         ustring::size_type  p2 = line.find (' ', p1 + 1);
613         responseCode = strtoul (line.substr (p1 + 1, p2));
614         while (buf.getln2 (*this, line)) {
615             // XXX
616             if (line.empty ())
617                 break;
618         }
619     }
620     return responseCode;
621 }
622
623 ssize_t  ProxySslClient::read3 (void* buf, size_t nbytes) {
624     struct timeval  delta;
625     fd_set rfds;
626     ssize_t rlen, total;
627     int r;
628
629     FD_ZERO (&rfds);
630     total = 0;
631
632     FD_SET (sd, &rfds);
633     delta.tv_sec = timeLimit;
634     delta.tv_usec = 0;
635     errno = 0;
636     r = select (sd + 1, &rfds, NULL, NULL, &delta);
637     if (r <= 0)
638         return total;
639     FD_CLR (sd, &rfds);
640     rlen = read2 (buf, nbytes);
641     if (rlen <= 0)
642         return total;
643     nbytes -= rlen;
644     buf = (char*)buf + rlen;
645     total += rlen;
646
647     while (nbytes > 0) {
648         FD_SET (sd, &rfds);
649         delta.tv_sec = 0;
650         delta.tv_usec = 0;
651         errno = 0;
652         r = select (sd + 1, &rfds, NULL, NULL, &delta);
653         if (r <= 0)
654             return total;
655         FD_CLR (sd, &rfds);
656         rlen = read2 (buf, nbytes);
657         if (rlen <= 0)
658             return total;
659         nbytes -= rlen;
660         buf = (char*)buf + rlen;
661         total += rlen;
662     }
663     return total;
664 }
665