1 // Copyright (c) 2015 Kazuhiro Fujieda <fujieda@users.osdn.me>
\r
3 // Licensed under the Apache License, Version 2.0 (the "License");
\r
4 // you may not use this file except in compliance with the License.
\r
5 // You may obtain a copy of the License at
\r
7 // http://www.apache.org/licenses/LICENSE-2.0
\r
9 // Unless required by applicable law or agreed to in writing, software
\r
10 // distributed under the License is distributed on an "AS IS" BASIS,
\r
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
\r
12 // See the License for the specific language governing permissions and
\r
13 // limitations under the License.
\r
16 using System.Collections;
\r
17 using System.Globalization;
\r
19 using System.IO.Compression;
\r
21 using System.Net.Sockets;
\r
23 using System.Text.RegularExpressions;
\r
24 using System.Threading.Tasks;
\r
26 namespace KancolleSniffer.Net
\r
28 public class HttpProxy
\r
30 private static HttpProxy _httpProxy;
\r
31 public static int LocalPort { get; set; }
\r
32 public static string UpstreamProxyHost { get; set; }
\r
33 public static int UpstreamProxyPort { get; set; }
\r
34 public static bool IsEnableUpstreamProxy { get; set; }
\r
35 public static bool IsInListening { get; private set; }
\r
36 public static event Action<Session> AfterSessionComplete;
\r
38 private TcpListener _listener;
\r
40 private static readonly object SyncObj = new object();
\r
42 public static void Startup(int port, bool dummy0, bool dummy1)
\r
45 _httpProxy = new HttpProxy();
\r
51 _listener = new TcpListener(IPAddress.Loopback, LocalPort);
\r
53 LocalPort = ((IPEndPoint)_listener.LocalEndpoint).Port;
\r
54 IsInListening = true;
\r
55 Task.Run(AcceptClient);
\r
58 public static void Shutdown()
\r
65 IsInListening = false;
\r
66 _listener.Server.Close();
\r
70 public void AcceptClient()
\r
76 var client = _listener.AcceptSocket();
\r
77 Task.Run(() => new HttpClient(client).ProcessRequest());
\r
80 catch (SocketException)
\r
89 private class HttpClient
\r
91 private readonly Socket _client;
\r
92 private Socket _server;
\r
93 private string _host;
\r
95 private Session _session;
\r
96 private HttpStream _clientStream;
\r
97 private HttpStream _serverStream;
\r
99 public HttpClient(Socket client)
\r
104 public void ProcessRequest()
\r
110 _clientStream = new HttpStream(_client);
\r
111 _session = new Session();
\r
112 if (CheckServerTimeOut())
\r
115 if (_session.Request.Method == null)
\r
117 if (_session.Request.Method == "CONNECT")
\r
122 if (_session.Request.Host.StartsWith("localhost") ||
\r
123 _session.Request.Host.StartsWith("127.0.0.1"))
\r
125 LogServer.Process(_client, _session.Request.RequestLine);
\r
129 ReceiveRequestBody();
\r
132 if (_session.Response.StatusCode == null)
\r
134 AfterSessionComplete?.Invoke(_session);
\r
136 } while (_client.Connected && _server.Connected &&
\r
137 _session.Request.IsKeepAlive && _session.Response.IsKeepAlive);
\r
140 catch (Exception e)
\r
143 File.AppendAllText("debug.log", $"[{DateTime.Now:g}] " + e + "\r\n");
\r
145 #else // ReSharper disable once EmptyGeneralCatchClause
\r
156 private bool CheckServerTimeOut()
\r
158 if (_server == null)
\r
160 var readList = new ArrayList {_client, _server};
\r
161 // ReSharper disable once AssignNullToNotNullAttribute
\r
162 Socket.Select(readList, null, null, -1);
\r
163 return readList.Count == 1 && readList[0] == _server && _server.Available == 0;
\r
166 private void ReceiveRequest()
\r
168 var requestLine = _clientStream.ReadLine();
\r
169 if (requestLine == "")
\r
171 _session.Request.RequestLine = requestLine;
\r
172 _session.Request.Headers = _clientStream.ReadHeaders();
\r
175 private void ReceiveRequestBody()
\r
177 if (_session.Request.ContentLength != -1 || _session.Request.TransferEncoding != null)
\r
178 _session.Request.ReadBody(_clientStream);
\r
181 private void SendRequest()
\r
183 GetHostAndPort(out var host, out var port);
\r
184 if (_server == null || host != _host || port != _port || IsSocketDead(_server))
\r
186 SocketClose(_server);
\r
187 _server = ConnectServer(host, port);
\r
192 new HttpStream(_server).WriteLines(_session.Request.RequestLine + _session.Request.ModifiedHeaders);
\r
195 private Socket ConnectServer(string host, int port)
\r
197 var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
\r
198 socket.Connect(host, port);
\r
202 private void GetHostAndPort(out string host, out int port)
\r
204 if (IsEnableUpstreamProxy)
\r
206 host = UpstreamProxyHost;
\r
207 port = UpstreamProxyPort;
\r
211 MakeRequestUrlRelative(out host, out port);
\r
212 if (host == null && !ParseAuthority(_session.Request.Host, ref host, ref port))
\r
213 throw new HttpProxyAbort("Can't find destination host");
\r
217 private static readonly Regex HostAndPortRegex =
\r
218 new Regex("http://([^:/]+)(?::(\\d+))?/", RegexOptions.Compiled);
\r
220 private void MakeRequestUrlRelative(out string host, out int port)
\r
224 var m = HostAndPortRegex.Match(_session.Request.RequestLine);
\r
227 host = m.Groups[1].Value;
\r
228 if (m.Groups[2].Success)
\r
229 port = int.Parse(m.Groups[2].Value);
\r
230 _session.Request.RequestLine = _session.Request.RequestLine.Remove(m.Index, m.Length - 1);
\r
233 bool IsSocketDead(Socket s) => (s.Poll(1000, SelectMode.SelectRead) && s.Available == 0) || !s.Connected;
\r
235 private void SendRequestBody()
\r
237 _serverStream.Write(_session.Request.Body);
\r
240 private void ReceiveResponse()
\r
242 var statusLine = _serverStream.ReadLine();
\r
243 if (statusLine == "")
\r
245 _session.Response.StatusLine = statusLine;
\r
246 _session.Response.Headers = _serverStream.ReadHeaders();
\r
248 _session.Response.ReadBody(_serverStream);
\r
251 private bool HasBody
\r
255 var code = _session.Response.StatusCode;
\r
256 return (!(_session.Request.Method == "HEAD" ||
\r
257 code.StartsWith("1") || code == "204" || code == "304"));
\r
261 private void SendResponse()
\r
263 _clientStream.WriteLines(_session.Response.StatusLine + _session.Response.ModifiedHeaders)
\r
264 .Write(_session.Response.Body);
\r
267 private void HandleConnect()
\r
271 if (!ParseAuthority(_session.Request.PathAndQuery, ref host, ref port))
\r
273 _server = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
\r
274 _server.Connect(host, port);
\r
275 _clientStream.WriteLines("HTTP/1.0 200 Connection established\r\n\r\n");
\r
278 Task.Run(() => { TunnelSockets(_client, _server); }),
\r
279 Task.Run(() => { TunnelSockets(_server, _client); })
\r
281 Task.WaitAll(tasks);
\r
284 private void TunnelSockets(Socket from, Socket to)
\r
288 var buf = new byte[8192];
\r
291 var n = from.Receive(buf);
\r
294 var sent = to.Send(buf, n, SocketFlags.None);
\r
298 to.Shutdown(SocketShutdown.Send);
\r
300 catch (SocketException)
\r
305 private static readonly Regex AuthorityRegex = new Regex("([^:]+)(?::(\\d+))?");
\r
307 private bool ParseAuthority(string authority, ref string host, ref int port)
\r
309 if (string.IsNullOrEmpty(authority))
\r
311 var m = AuthorityRegex.Match(authority);
\r
314 host = m.Groups[1].Value;
\r
315 if (m.Groups[2].Success)
\r
316 port = int.Parse(m.Groups[2].Value);
\r
320 private void Close()
\r
322 SocketClose(_server);
\r
323 SocketClose(_client);
\r
326 private void SocketClose(Socket socket)
\r
328 if (socket == null)
\r
332 socket.Shutdown(SocketShutdown.Both);
\r
334 // ReSharper disable EmptyGeneralCatchClause
\r
344 // ReSharper restore EmptyGeneralCatchClause
\r
350 public class Session
\r
352 public Request Request { get; set; } = new Request();
\r
353 public Response Response { get; set; } = new Response();
\r
356 public class Message
\r
358 private string _headers;
\r
359 public byte[] Body { get; set; }
\r
361 private static readonly Regex CharsetRegex = new Regex("charset=([\\w-]+)",
\r
362 RegexOptions.Compiled | RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);
\r
364 public int ContentLength { get; set; } = -1;
\r
365 public string TransferEncoding { get; set; }
\r
366 public string ContentType { get; set; }
\r
367 public string ContentEncoding { get; set; }
\r
368 public string Host { get; set; }
\r
369 public bool IsKeepAlive;
\r
371 public string Headers
\r
377 SetHeaders(_headers);
\r
381 public virtual string ModifiedHeaders => RemoveHeaders(Headers, new[] {"proxy-connection"});
\r
383 protected string RemoveHeaders(string headers, string[] fields)
\r
385 foreach (var f in fields)
\r
387 var m = MatchField(f, headers);
\r
390 headers = headers.Remove(m.Index, m.Length);
\r
395 protected string InsertHeader(string headers, string header)
\r
397 return headers.Insert(headers.Length - 2, header);
\r
400 protected virtual void SetHeaders(string headers)
\r
402 var s = GetField("content-length");
\r
405 ContentLength = int.TryParse(s, out var len) ? len : -1;
\r
407 TransferEncoding = GetField("transfer-encoding")?.ToLower(CultureInfo.InvariantCulture);
\r
408 ContentType = GetField("content-type");
\r
409 ContentEncoding = GetField("content-encoding");
\r
410 Host = GetField("host");
\r
411 IsKeepAlive = GetField("connection")?.ToLower(CultureInfo.InvariantCulture) != "close";
\r
414 protected Match MatchField(string name, string headers)
\r
416 var regex = new Regex("^" + name + ":\\s*([^\r]+)\r\n",
\r
417 RegexOptions.CultureInvariant | RegexOptions.IgnoreCase | RegexOptions.Multiline);
\r
418 return regex.Match(headers);
\r
421 protected string GetField(string name)
\r
423 var m = MatchField(name, Headers);
\r
424 return m.Success ? m.Groups[1].Value : null;
\r
427 public string BodyAsString
\r
433 var m = CharsetRegex.Match(ContentType ?? "");
\r
434 var encoding = Encoding.ASCII;
\r
437 var name = m.Groups[1].Value;
\r
438 if (name == "utf8")
\r
440 encoding = Encoding.GetEncoding(name);
\r
442 return encoding.GetString(Body);
\r
446 public void ReadBody(HttpStream stream)
\r
448 if (TransferEncoding != null && TransferEncoding.Contains("chunked"))
\r
450 Body = stream.ReadChunked();
\r
452 else if (ContentLength == 0)
\r
455 else if (ContentLength > 0)
\r
457 var buf = new byte[ContentLength];
\r
458 stream.Read(buf, 0, ContentLength);
\r
463 Body = stream.ReadToEnd();
\r
465 if (ContentEncoding == null)
\r
467 var dc = new MemoryStream();
\r
470 if (ContentEncoding == "gzip")
\r
471 new GZipStream(new MemoryStream(Body), CompressionMode.Decompress).CopyTo(dc);
\r
472 else if (ContentEncoding == "deflate")
\r
473 new DeflateStream(new MemoryStream(Body), CompressionMode.Decompress).CopyTo(dc);
\r
475 catch (Exception ex)
\r
477 throw new HttpProxyAbort($"Fail to decode {ContentEncoding}: " + ex.Message);
\r
479 Body = dc.ToArray();
\r
483 public class Request : Message
\r
485 private string _requestLine;
\r
487 public string RequestLine
\r
489 get => _requestLine;
\r
492 _requestLine = value;
\r
493 var f = _requestLine.Split(' ');
\r
495 throw new HttpProxyAbort("Invalid request line");
\r
497 PathAndQuery = f.Length < 2 ? "" : f[1];
\r
501 public string Method { get; private set; }
\r
502 public string PathAndQuery { get; private set; }
\r
505 public class Response : Message
\r
507 private string _statusLine;
\r
509 public override string ModifiedHeaders =>
\r
510 InsertContentLength(RemoveHeaders(base.ModifiedHeaders,
\r
511 new[] {"transfer-encoding", "content-encoding", "content-length"}));
\r
513 private string InsertContentLength(string headers)
\r
515 return Body == null ? headers : InsertHeader(headers, $"Content-Length: {Body.Length}\r\n");
\r
518 public string StatusLine
\r
520 get => _statusLine;
\r
523 _statusLine = value;
\r
524 var f = _statusLine.Split(' ');
\r
526 throw new HttpProxyAbort("Invalid status line");
\r
527 StatusCode = _statusLine.Split(' ')[1];
\r
531 public string StatusCode { get; private set; }
\r
534 private class HttpProxyAbort : Exception
\r
536 public HttpProxyAbort(string message) : base(message)
\r
541 public class HttpStream
\r
543 private readonly Socket _socket;
\r
544 private readonly byte[] _buffer = new byte[4096];
\r
545 private int _available;
\r
546 private int _position;
\r
548 public HttpStream(Socket socket)
\r
551 socket.NoDelay = true;
\r
554 public string ReadLine()
\r
556 var sb = new StringBuilder();
\r
558 while ((ch = ReadByte()) != -1)
\r
560 sb.Append((char)ch);
\r
564 return sb.ToString();
\r
567 private int ReadByte()
\r
569 if (_position < _available)
\r
570 return _buffer[_position++];
\r
571 _available = _socket.Receive(_buffer, 0, _buffer.Length, SocketFlags.None);
\r
573 return _available == 0 ? -1 : _buffer[_position++];
\r
576 public HttpStream WriteLines(string s)
\r
578 var buf = Encoding.ASCII.GetBytes(s);
\r
579 Write(buf, 0, buf.Length);
\r
583 public string ReadHeaders()
\r
585 var sb = new StringBuilder();
\r
591 } while (line != "\r\n");
\r
592 return sb.ToString();
\r
595 public byte[] ReadChunked()
\r
597 var buf = new MemoryStream();
\r
600 var size = ReadLine();
\r
601 if (size.Length < 3)
\r
603 var ext = size.IndexOf(';');
\r
604 size = ext == -1 ? size.Substring(0, size.Length - 2) : size.Substring(0, ext);
\r
605 if (!int.TryParse(size, NumberStyles.HexNumber, CultureInfo.InvariantCulture, out var val))
\r
606 throw new HttpProxyAbort("Can't parse chunk size: " + size);
\r
609 var chunk = new byte[val];
\r
610 Read(chunk, 0, chunk.Length);
\r
611 buf.Write(chunk, 0, chunk.Length);
\r
618 } while (line != "" && line != "\r\n");
\r
619 return buf.ToArray();
\r
622 public byte[] ReadToEnd()
\r
624 var result = new MemoryStream();
\r
625 var buf = new byte[4096];
\r
627 while ((len = Read(buf, 0, buf.Length)) > 0)
\r
628 result.Write(buf, 0, len);
\r
629 return result.ToArray();
\r
632 public HttpStream Write(byte[] body)
\r
635 Write(body, 0, body.Length);
\r
639 public int Read(byte[] buf, int offset, int count)
\r
645 if (_position < _available)
\r
647 n = Math.Min(count, _available - _position);
\r
648 Buffer.BlockCopy(_buffer, _position, buf, 0, n);
\r
653 n = _socket.Receive(buf, offset, count, SocketFlags.None);
\r
655 return total == 0 ? n : total;
\r
660 } while (count > 0);
\r
664 public void Write(byte[] buf, int offset, int count)
\r
668 var n = _socket.Send(buf, offset, count, SocketFlags.None);
\r
673 } while (count > 0);
\r