4 #include "util_const.h"
5 #include "util_base64.h"
6 #include "util_check.h"
7 #include "util_string.h"
9 #include <boost/regex.hpp>
11 #include <sys/types.h>
12 #include <sys/socket.h>
13 #include <sys/socket.h>
20 #define SSL_NAME_LEN 256
23 static int case_diffs (const char *s, const char *t) {
43 return ((int)(unsigned int) x) - ((int)(unsigned int) y);
47 //============================================================
48 bool TcpBuf::empty () {
52 bool TcpBuf::fill (TcpClient& tc) {
55 if (start != buf.begin ()) {
57 start = tail = buf.begin ();
59 memmove (&*buf.begin (), &*start, tail - start);
60 tail -= start - buf.begin ();
64 s = tc.read (&*tail, buf.end () - tail);
72 bool TcpBuf::getln (TcpClient& tc, ustring& ans) {
73 boost::match_results<ustring::iterator> m;
74 static uregex re_crlf ("\\r\\n");
78 if (regex_search (start, tail, m, re_crlf, boost::regex_constants::match_single_line)) {
79 ans.assign (start, m[0].first);
82 } else if (fill (tc)) {
83 if (regex_search (start, tail, m, re_crlf, boost::regex_constants::match_single_line)) {
84 ans.assign (start, m[0].first);
92 bool TcpBuf::getln2 (TcpClient& tc, ustring& ans) {
93 boost::match_results<ustring::iterator> m;
94 static uregex re_crlf ("\\r\\n");
96 if (regex_search (start, tail, m, re_crlf, boost::regex_constants::match_single_line)) {
97 ans.assign (start, m[0].first);
104 size_t TcpBuf::size () {
108 void TcpBuf::consume () {
109 start = tail = buf.begin ();
112 //============================================================
113 bool TcpClient::connect (const HostSpec* host) {
115 const char* bindaddr = NULL;
116 struct addrinfo hints;
117 struct addrinfo* res;
118 struct addrinfo* res0;
130 snprintf (pbuf, sizeof(pbuf), "%d", host->port);
131 memset (&hints, 0, sizeof (hints));
132 hints.ai_family = af;
133 hints.ai_socktype = SOCK_STREAM;
134 hints.ai_protocol = 0;
135 if ((err = getaddrinfo (host->host.c_str (), pbuf, &hints, &res0)) != 0) {
139 for (sd = -1, res = res0; res; res = res->ai_next) {
140 if ((sd = socket (res->ai_family, res->ai_socktype, res->ai_protocol)) == -1)
142 if (bindaddr != NULL && *bindaddr != '\0' && ! bind (bindaddr)) {
143 // failed to bind to 'bindaddr'
148 if (::connect (sd, res->ai_addr, res->ai_addrlen) == 0)
176 bool TcpClient::connect2 () {
181 bool TcpClient::bind (const char* addr) {
182 struct addrinfo hints;
183 struct addrinfo* res;
184 struct addrinfo* res0;
186 memset (&hints, 0, sizeof (hints));
187 hints.ai_family = af;
188 hints.ai_socktype = SOCK_STREAM;
189 hints.ai_protocol = 0;
190 if ((err = getaddrinfo (addr, NULL, &hints, &res0)) != 0)
192 for (res = res0; res; res = res->ai_next)
193 if (::bind (sd, res->ai_addr, res->ai_addrlen) == 0)
198 void TcpClient::noPush () {
201 setsockopt(sd, IPPROTO_TCP, TCP_NOPUSH, &val, sizeof(val));
204 void TcpClient::close () {
211 ssize_t TcpClient::read (void* buf, size_t nbytes) {
213 struct timeval timeout;
214 struct timeval delta;
221 gettimeofday (&timeout, NULL);
222 timeout.tv_sec += timeLimit;
227 while (timeLimit && ! FD_ISSET (sd, &rfds)) {
229 gettimeofday (&now, NULL);
230 delta.tv_sec = timeout.tv_sec - now.tv_sec;
231 delta.tv_usec = timeout.tv_usec - now.tv_usec;
232 if (delta.tv_usec < 0) {
233 delta.tv_usec += 1000000;
236 if (delta.tv_sec < 0) {
242 r = select (sd + 1, &rfds, NULL, NULL, &delta);
250 rlen = read2 (buf, nbytes);
259 buf = (char*)buf + rlen;
265 ssize_t TcpClient::write (const void* buf, size_t nbytes) {
268 iov.iov_base = __DECONST(char *, buf);
269 iov.iov_len = nbytes;
270 return write (&iov, 1);
273 void TcpClient::flush_write () {
277 setsockopt(sd, IPPROTO_TCP, TCP_NOPUSH, &val, sizeof(val));
279 setsockopt(sd, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val));
282 ssize_t TcpClient::write (struct iovec* iov, int iovcnt) {
284 struct timeval timeout;
285 struct timeval delta;
292 gettimeofday (&timeout, NULL);
293 timeout.tv_sec += timeLimit;
298 while (timeLimit && ! FD_ISSET (sd, &writefds)) {
299 FD_SET (sd, &writefds);
300 gettimeofday (&now, NULL);
301 delta.tv_sec = timeout.tv_sec - now.tv_sec;
302 delta.tv_usec = timeout.tv_usec - now.tv_usec;
303 if (delta.tv_usec < 0) {
304 delta.tv_usec += 1000000;
307 if (delta.tv_sec < 0) {
313 r = select (sd + 1, NULL, &writefds, NULL, &delta);
321 wlen = write2 (iov, iovcnt);
323 /* we consider a short write a failure */
334 while (iovcnt > 0 && wlen >= (ssize_t)iov->iov_len) {
335 wlen -= iov->iov_len;
340 iov->iov_len -= wlen;
341 iov->iov_base = __DECONST(char *, iov->iov_base) + wlen;
347 ssize_t TcpClient::write2 (struct iovec* iov, int iovcnt) {
348 return writev (sd, iov, iovcnt);
351 ssize_t TcpClient::read2 (void* buf, size_t nbytes) {
352 return ::read (sd, buf, nbytes);
355 //============================================================
356 bool SslClient::connect (const HostSpec* conhost, const HostSpec* _ephost) {
360 rc = TcpClient::connect (conhost);
364 bool SslClient::connect2 () {
368 bool SslClient::sslOpen () {
369 if (!SSL_library_init ()){
371 throw (ustring (CharConst ("SSL library init failed\n")));
375 SSL_load_error_strings ();
377 ssl_meth = SSLv23_client_method ();
378 ssl_ctx.reset (SSL_CTX_new (ssl_meth));
379 SSL_CTX_set_mode (ssl_ctx.get (), SSL_MODE_AUTO_RETRY);
380 SSL_CTX_set_options (ssl_ctx.get (), SSL_OP_ALL | SSL_OP_NO_TICKET | SSL_OP_NO_SSLv2);
381 if (! setupPeerVerification (kCERTFILE)) {
383 throw (ustring (CharConst ("SSL certificate error\n")));
387 ssl.reset (SSL_new (ssl_ctx.get ()));
388 if (ssl.get () == NULL){
390 throw (ustring (CharConst ("SSL context creation failed\n")));
393 SSL_set_fd (ssl.get (), sd);
394 // if (SSL_connect (ssl.get ()) == -1){
395 #if OPENSSL_VERSION_NUMBER >= 0x0090806fL && !defined(OPENSSL_NO_TLSEXT)
396 if (! SSL_set_tlsext_host_name (ssl.get (), ephost->host.c_str ())) {
397 fprintf (stderr, "TLS server name indication extension failed for host %s\n", ephost->host.c_str ());
404 while ((rc = SSL_connect (ssl.get ())) == -1) {
405 int ssl_err = SSL_get_error (ssl.get (), rc);
406 if (ssl_err != SSL_ERROR_WANT_READ && ssl_err != SSL_ERROR_WANT_WRITE) {
407 ERR_print_errors_fp (stderr);
412 if (fnoverify || verifyCA ()) {
421 void SslClient::loadCAFile (const char* certfile, int depth) {
422 if (!SSL_CTX_load_verify_locations (ssl_ctx.get (), certfile, NULL))
424 SSL_CTX_set_verify_depth (ssl_ctx.get (), depth);
428 * Callback for SSL certificate verification, this is called on server
429 * cert verification. It takes no decision, but informs the user in case
430 * verification failed.
432 static int fetch_ssl_cb_verify_crt (int verified, X509_STORE_CTX *ctx) {
439 if ((crt = X509_STORE_CTX_get_current_cert (ctx)) != NULL &&
440 (name = X509_get_subject_name (crt)) != NULL)
441 str = X509_NAME_oneline (name, 0, 0);
442 fprintf (stderr, "Certificate verification failed for %s\n",
443 str != NULL ? str : "no relevant certificate");
449 bool SslClient::setupPeerVerification (const char* certfile, int depth) {
451 SSL_CTX_set_verify (ssl_ctx.get (), SSL_VERIFY_PEER, fetch_ssl_cb_verify_crt);
452 SSL_CTX_load_verify_locations (ssl_ctx.get (), certfile, NULL);
457 bool SslClient::verifyCA () {
463 if ((rc = SSL_get_verify_result (ssl.get ())) != X509_V_OK) {
465 case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY:
466 throw (ustring (CharConst ("unable to get issuer cert locally.")));
468 std::cerr << rc << ": X509 failed\n";
472 cert.reset (SSL_get_peer_certificate (ssl.get ()));
476 lhost.resize (ephost->host.length ());
477 std::transform (ephost->host.begin (), ephost->host.end (), lhost.begin (), ::tolower);
478 ans = verifyCN (lhost, cert.get ()) || verifyAltName (lhost, cert.get ());
480 std::cerr << lhost << ": hostname not match\n";
485 bool SslClient::verifyCN (const ustring& lhost, X509* cert) {
486 char buf[SSL_NAME_LEN];
491 len = X509_NAME_get_text_by_NID (X509_get_subject_name (cert), NID_commonName, buf, sizeof (buf));
493 if (len > 0 && lhost.length () > 0) {
494 if (len >= sizeof (buf)) {
495 std::cerr << "Too long common name of the server certificate\n";
498 cn.assign (buf, len);
499 std::transform (cn.begin (), cn.end (), cn.begin (), ::tolower);
500 return globMatch (lhost, cn);
505 bool SslClient::verifyAltName (const ustring& lhost, X509* cert) {
507 X509_EXTENSION* ext = NULL;
508 GENERAL_NAMES* names;
513 pos = X509_get_ext_by_NID (cert, NID_subject_alt_name, -1);
515 ext = X509_get_ext (cert, pos);
517 names = (GENERAL_NAMES*)X509V3_EXT_d2i (ext);
518 n = sk_GENERAL_NAME_num (names);
519 for (i = 0; i < n; ++ i) {
520 name = sk_GENERAL_NAME_value (names, i);
521 if (name->type == GEN_DNS) {
522 ASN1_STRING_to_UTF8 (&dns, name->d.dNSName);
523 // std::cerr << "dns:" << dns << "\n";
524 ustring name (char_type (dns));
525 std::transform (name.begin (), name.end (), name.begin (), ::tolower);
526 if (globMatch (lhost, ustring (char_type (dns)))) {
539 bool SslClient::globMatch (const ustring& lhost, const ustring& name) {
540 if (matchHead (name, CharConst ("*."))) {
541 // XXX RFC2459(= HTTP over TLS)では、*は一つのサブドメインのみにマッチする。
542 ustring::size_type p = lhost.find ('.');
543 if (p != ustring::npos) {
544 return ustring (lhost.begin () + p, lhost.end ()) == ustring (name.begin () + 1, name.end ());
549 return lhost == name;
553 ssize_t SslClient::write2 (struct iovec* iov, int iovcnt) {
555 return SSL_write(ssl.get (), iov->iov_base, iov->iov_len);
557 return TcpClient::write2 (iov, iovcnt);
561 ssize_t SslClient::read2 (void* buf, size_t nbytes) {
563 return SSL_read (ssl.get (), buf, nbytes);
565 return TcpClient::read2 (buf, nbytes);
569 //============================================================
570 bool ProxySslClient::connect2 () {
573 if (checkHostname (ephost->host)) {
574 msg.assign (CharConst ("CONNECT "));
575 msg.append (ephost->host).append (uColon).append (to_ustring (ephost->port)).append (CharConst (" HTTP/1.0" kCRLF));
576 msg.append (CharConst ("Host: ")).append (ephost->host).append (uCRLF);
578 if (proxyid.length () > 0) {
580 idpw.assign (proxyid).append (uColon).append (proxypw);
581 msg.append (CharConst ("Proxy-Authorization: Basic ")).append (base64Encode (idpw.begin (), idpw.end ())).append (uCRLF);
584 write (&*msg.begin (), msg.length ()); // CONNECT HOST...
586 int rc = readReplyHead ();
588 return SslClient::connect2 ();
595 int ProxySslClient::readReplyHead () {
598 int responseCode = 0;
600 static uregex re_crlf ("\\r\\n");
602 buf.tail += read3 (&*buf.start, buf.buf.length ());
603 if (buf.getln2 (*this, line)) {
608 if (! match (line.substr (0, 5), CharConst ("HTTP/"))) {
609 return 0; // bad protocol
611 ustring::size_type p1 = line.find (' ', 0);
612 ustring::size_type p2 = line.find (' ', p1 + 1);
613 responseCode = strtoul (line.substr (p1 + 1, p2));
614 while (buf.getln2 (*this, line)) {
623 ssize_t ProxySslClient::read3 (void* buf, size_t nbytes) {
624 struct timeval delta;
633 delta.tv_sec = timeLimit;
636 r = select (sd + 1, &rfds, NULL, NULL, &delta);
640 rlen = read2 (buf, nbytes);
644 buf = (char*)buf + rlen;
652 r = select (sd + 1, &rfds, NULL, NULL, &delta);
656 rlen = read2 (buf, nbytes);
660 buf = (char*)buf + rlen;