OSDN Git Service

a8f12acb438a5092ef222a590eb8970afa6f9f4c
[kancollesniffer/KancolleSniffer.git] / KancolleSniffer / HttpProxy.cs
1 // Copyright (c) 2015 Kazuhiro Fujieda <fujieda@users.osdn.me>\r
2 //\r
3 // Licensed under the Apache License, Version 2.0 (the "License");\r
4 // you may not use this file except in compliance with the License.\r
5 // You may obtain a copy of the License at\r
6 //\r
7 //    http://www.apache.org/licenses/LICENSE-2.0\r
8 //\r
9 // Unless required by applicable law or agreed to in writing, software\r
10 // distributed under the License is distributed on an "AS IS" BASIS,\r
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r
12 // See the License for the specific language governing permissions and\r
13 // limitations under the License.\r
14 \r
15 using System;\r
16 using System.Globalization;\r
17 using System.IO;\r
18 using System.IO.Compression;\r
19 using System.Net;\r
20 using System.Net.Sockets;\r
21 using System.Text;\r
22 using System.Text.RegularExpressions;\r
23 using System.Threading.Tasks;\r
24 \r
25 namespace KancolleSniffer\r
26 {\r
27     public class HttpProxy\r
28     {\r
// The proxy instance created by Startup(); Shutdown() calls Stop() on it when it is not null.
29         private static HttpProxy _httpProxy;\r
// Port handed to the TcpListener in Start(); Start() then overwrites it with the port of the listener's local endpoint.
30         public static int LocalPort { get; set; }\r
// Host of the upstream proxy; read by ConnectServer() when IsEnableUpstreamProxy is true.
31         public static string UpstreamProxyHost { get; set; }\r
// Port of the upstream proxy; read by ConnectServer() when IsEnableUpstreamProxy is true.
32         public static int UpstreamProxyPort { get; set; }\r
// When true, ConnectServer() connects to the upstream proxy instead of the host named in the request.
33         public static bool IsEnableUpstreamProxy { get; set; }\r
// Set to true at the end of Start() and to false at the beginning of Stop().
34         public static bool IsInListening { get; private set; }\r
// Invoked by HttpClient.ProcessRequest() with the session after the response has been sent and the sockets closed.
35         public static event Action<Session> AfterSessionComplete;\r
36 \r
// Listener created in Start() and bound to the loopback address.
37         private TcpListener _listener;\r
38 \r
/// <summary>
/// Records the requested port, creates the proxy instance and starts it.
/// </summary>
/// <param name="port">Port stored in <see cref="LocalPort"/> before the listener is started.</param>
/// <param name="dummy0">Not used by this method.</param>
/// <param name="dummy1">Not used by this method.</param>
public static void Startup(int port, bool dummy0, bool dummy1)
{
    LocalPort = port;
    var proxy = new HttpProxy();
    _httpProxy = proxy;
    proxy.Start();
}
45 \r
/// <summary>
/// Starts listening on the loopback address at <see cref="LocalPort"/> and
/// launches the accept loop on a background task.
/// </summary>
public void Start()
{
    var listener = new TcpListener(IPAddress.Loopback, LocalPort);
    _listener = listener;
    listener.Start();
    // Read the port back from the bound endpoint and publish it.
    var bound = (IPEndPoint)listener.LocalEndpoint;
    LocalPort = bound.Port;
    IsInListening = true;
    Task.Run(() => { AcceptClient(); });
}
54 \r
/// <summary>
/// Stops the proxy instance created by <see cref="Startup"/>, if there is one.
/// </summary>
public static void Shutdown()
{
    var proxy = _httpProxy;
    if (proxy != null)
        proxy.Stop();
}
59 \r
/// <summary>
/// Marks the proxy as no longer listening and closes the listener.
/// Safe to call on an instance that was never started and safe to call repeatedly
/// (AcceptClient() calls it again from its finally block).
/// </summary>
public void Stop()
{
    IsInListening = false;
    var listener = _listener;
    // Start() may never have run on this instance; there is nothing to close then.
    if (listener == null)
        return;
    // Closing the underlying socket makes a blocked AcceptSocket() call fail,
    // which ends the accept loop.
    listener.Server.Close();
    listener.Stop();
}
66 \r
/// <summary>
/// Accept loop: takes incoming connections from the listener and handles each one
/// on its own task. The loop ends when AcceptSocket() throws; the listener is
/// stopped on the way out in every case.
/// </summary>
public void AcceptClient()
{
    try
    {
        for (;;)
        {
            // Declared inside the loop so every task captures its own socket.
            var accepted = _listener.AcceptSocket();
            Task.Run(() =>
            {
                var handler = new HttpClient(accepted);
                handler.ProcessRequest();
            });
        }
    }
    catch (SocketException)
    {
        // A socket error ends the loop without being reported.
    }
    finally
    {
        Stop();
    }
}
85 \r
86         private class HttpClient\r
87         {\r
// Socket accepted from the listener; passed to the constructor.
88             private readonly Socket _client;\r
// Socket to the destination (or upstream proxy); assigned in SendRequest() or HandleConnect().
89             private Socket _server;\r
// Request/response pair filled in while the exchange is relayed.
90             private readonly Session _session;\r
// Buffered reader/writer over _client.
91             private readonly HttpStream _clientStream;\r
// Buffered reader/writer over _server; assigned in SendRequest().
92             private HttpStream _serverStream;\r
93 \r
/// <summary>
/// Wraps an accepted client socket and prepares an empty session for it.
/// </summary>
/// <param name="client">Socket returned by the listener for this connection.</param>
public HttpClient(Socket client)
{
    _session = new Session();
    _client = client;
    _clientStream = new HttpStream(client);
}
100 \r
/// <summary>
/// Handles one client connection: reads the request, then either tunnels it
/// (CONNECT), hands it to LogServer (requests addressed to localhost / 127.0.0.1),
/// or relays it to the server and relays the response back.
/// AfterSessionComplete is raised only after a complete relay.
/// All exceptions are contained here; DEBUG builds append them to debug.log.
/// </summary>
public void ProcessRequest()
{
    try
    {
        ReceiveRequest();
        var request = _session.Request;
        // Method stays null when the client closed the connection without sending a request line.
        if (request.Method == null)
            return;
        if (request.Method == "CONNECT")
        {
            HandleConnect();
            return;
        }
        // Host is null when the request carries no Host header; such a request can
        // still be routed from an absolute URI in the request line, so it must not
        // fail here. Ordinal comparison: this is protocol text, not linguistic text.
        var hostHeader = request.Host;
        if (hostHeader != null &&
            (hostHeader.StartsWith("localhost", StringComparison.Ordinal) ||
             hostHeader.StartsWith("127.0.0.1", StringComparison.Ordinal)))
        {
            LogServer.Process(_client, request.RequestLine);
            return;
        }
        SendRequest();
        ReceiveRequestBody();
        SendRequestBody();
        ReceiveResponse();
        // StatusCode stays null when the server closed the connection without a status line.
        if (_session.Response.StatusCode == null)
            return;
        SendResponse();
        // Close the sockets before notifying subscribers so the handlers do not
        // keep the connections open.
        Close();
        AfterSessionComplete?.Invoke(_session);
    }
#if DEBUG
    catch (Exception e)
    {
        File.AppendAllText("debug.log", $"[{DateTime.Now:g}] " + e + "\r\n");
    }
#else // ReSharper disable once EmptyGeneralCatchClause
    catch
    {
    }
#endif
    finally
    {
        Close();
    }
}
143 \r
/// <summary>
/// Reads the request line and the header block from the client.
/// Leaves the session's request untouched when the first line is empty.
/// </summary>
private void ReceiveRequest()
{
    var firstLine = _clientStream.ReadLine();
    if (firstLine != "")
    {
        var request = _session.Request;
        request.RequestLine = firstLine;
        request.Headers = _clientStream.ReadHeaders();
    }
}
152 \r
/// <summary>
/// Reads the request body from the client when the headers announce one,
/// i.e. a Content-Length was parsed or a Transfer-Encoding is present.
/// </summary>
private void ReceiveRequestBody()
{
    var request = _session.Request;
    if (request.ContentLength == -1 && request.TransferEncoding == null)
        return;
    request.ReadBody(_clientStream);
}
158 \r
/// <summary>
/// Connects to the destination and forwards the request line followed by the
/// modified headers. The request line is read after ConnectServer() returns
/// because ConnectServer() may rewrite it.
/// </summary>
private void SendRequest()
{
    _server = ConnectServer();
    var upstream = new HttpStream(_server);
    var head = _session.Request.RequestLine + _session.Request.ModifiedHeaders;
    // The field is assigned only once the head has been written.
    _serverStream = upstream.WriteLines(head);
}
165 \r
/// <summary>
/// Forwards the request body to the server; HttpStream.Write(byte[]) ignores a null body.
/// </summary>
private void SendRequestBody()
{
    var body = _session.Request.Body;
    _serverStream.Write(body);
}
170 \r
/// <summary>
/// Reads the status line, the headers and, when the response carries one, the body
/// from the server. Leaves the response untouched when the first line is empty.
/// </summary>
private void ReceiveResponse()
{
    var firstLine = _serverStream.ReadLine();
    if (firstLine != "")
    {
        var response = _session.Response;
        response.StatusLine = firstLine;
        response.Headers = _serverStream.ReadHeaders();
        if (HasBody)
            response.ReadBody(_serverStream);
    }
}
181 \r
/// <summary>
/// False for responses to HEAD requests and for status codes 1xx, 204 and 304;
/// true otherwise.
/// </summary>
private bool HasBody
{
    get
    {
        var status = _session.Response.StatusCode;
        if (_session.Request.Method == "HEAD")
            return false;
        if (status.StartsWith("1"))
            return false;
        return status != "204" && status != "304";
    }
}
191 \r
/// <summary>
/// Sends the status line, the modified headers and the body back to the client.
/// </summary>
private void SendResponse()
{
    var response = _session.Response;
    var head = response.StatusLine + response.ModifiedHeaders;
    var writer = _clientStream.WriteLines(head);
    writer.Write(response.Body);
}
197 \r
/// <summary>
/// Serves a CONNECT request: connects directly to the authority named in the
/// request (port 443 when none is given), confirms to the client, then copies
/// bytes in both directions until both directions have finished.
/// Returns without replying when the authority cannot be parsed.
/// </summary>
private void HandleConnect()
{
    var host = "";
    var port = 443;
    var parsed = ParseAuthority(_session.Request.PathAndQuery, ref host, ref port);
    if (!parsed)
        return;
    _server = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
    _server.Connect(host, port);
    _clientStream.WriteLines("HTTP/1.0 200 Connection established\r\n\r\n");
    var clientToServer = Task.Run(() => { TunnnelSockets(_client, _server); });
    var serverToClient = Task.Run(() => { TunnnelSockets(_server, _client); });
    Task.WaitAll(clientToServer, serverToClient);
}
214 \r
/// <summary>
/// Copies bytes from one socket to the other until the source reports end of
/// stream or a send transfers fewer bytes than requested, then shuts down the
/// sending side of the destination. Socket errors end the copy silently.
/// </summary>
/// <param name="from">Socket to read from.</param>
/// <param name="to">Socket to write to.</param>
private void TunnnelSockets(Socket from, Socket to)
{
    try
    {
        var buffer = new byte[8192];
        int received;
        while ((received = from.Receive(buffer)) != 0)
        {
            if (to.Send(buffer, received, SocketFlags.None) < received)
                break;
        }
        to.Shutdown(SocketShutdown.Send);
    }
    catch (SocketException)
    {
    }
}
235 \r
// Matches the scheme-and-authority prefix of an absolute http URI, up to and
// including the first '/' of the path. Group 1 is the host, group 2 the optional port.
private static readonly Regex HostAndPortRegex =
    new Regex("http://([^:/]+)(?::(\\d+))?/", RegexOptions.Compiled);

/// <summary>
/// Opens a TCP connection for the current request.
/// With the upstream proxy enabled, connects to it and leaves the request line
/// untouched. Otherwise the destination is taken from an absolute URI in the
/// request line (which is then rewritten to origin form), or from the Host header.
/// </summary>
/// <returns>A connected socket.</returns>
/// <exception cref="HttpProxyAbort">No destination host could be determined.</exception>
private Socket ConnectServer()
{
    string host = null;
    var port = 80;
    if (IsEnableUpstreamProxy)
    {
        host = UpstreamProxyHost;
        port = UpstreamProxyPort;
    }
    else
    {
        var requestLine = _session.Request.RequestLine;
        var m = HostAndPortRegex.Match(requestLine);
        // Accept the match only when it is the request target itself, i.e. it
        // starts right after the method. A URL embedded later in the line (for
        // example in a query string) must not choose the destination.
        if (m.Success && m.Index == requestLine.IndexOf(' ') + 1)
        {
            host = m.Groups[1].Value;
            if (m.Groups[2].Success)
                port = int.Parse(m.Groups[2].Value);
            // Drop "http://host[:port]" but keep the '/' that starts the path.
            _session.Request.RequestLine = requestLine.Remove(m.Index, m.Length - 1);
        }
        if (host == null && !ParseAuthority(_session.Request.Host, ref host, ref port))
            throw new HttpProxyAbort("Can't find destination host");
    }
    var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
    try
    {
        socket.Connect(host, port);
    }
    catch
    {
        // The caller never sees this socket, so it has to be released here.
        socket.Close();
        throw;
    }
    return socket;
}
264 \r
// Group 1 is the host (everything before the first ':'), group 2 the optional port.
private static readonly Regex AuthorityRegex = new Regex("([^:]+)(?::(\\d+))?");

/// <summary>
/// Splits a "host[:port]" string into its parts.
/// </summary>
/// <param name="authority">Text to parse; surrounding whitespace is ignored.</param>
/// <param name="host">Receives the host name. Left unchanged on failure.</param>
/// <param name="port">Receives the port when one is present; otherwise keeps the
/// caller's default. Left unchanged on failure.</param>
/// <returns>
/// False when <paramref name="authority"/> is null or empty, contains no host,
/// or names a port that is not a number in the valid TCP range.
/// </returns>
private bool ParseAuthority(string authority, ref string host, ref int port)
{
    if (string.IsNullOrEmpty(authority))
        return false;
    // Header values may carry trailing whitespace, which must not end up in the host name.
    var m = AuthorityRegex.Match(authority.Trim());
    if (!m.Success)
        return false;
    var parsedPort = port;
    if (m.Groups[2].Success)
    {
        // TryParse instead of Parse: an over-long digit run must be rejected, not thrown.
        if (!int.TryParse(m.Groups[2].Value, NumberStyles.None, CultureInfo.InvariantCulture, out parsedPort) ||
            parsedPort > IPEndPoint.MaxPort)
            return false;
    }
    host = m.Groups[1].Value;
    port = parsedPort;
    return true;
}
279 \r
/// <summary>
/// Closes both streams and both sockets. The server side is skipped when no
/// server connection was made.
/// </summary>
private void Close()
{
    if (_serverStream != null)
        _serverStream.Close();
    if (_clientStream != null)
        _clientStream.Close();
    if (_server != null)
        _server.Close();
    _client.Close();
}
287         }\r
288 \r
/// <summary>
/// One relayed exchange: the request received from the client and the response
/// received from the server. Both start out as empty messages.
/// </summary>
public class Session
{
    public Session()
    {
        Request = new Request();
        Response = new Response();
    }

    public Request Request { get; set; }
    public Response Response { get; set; }
}
294 \r
295         public class Message\r
296         {\r
// Backing field of the Headers property.
297             private string _headers;\r
// Message body; set by ReadBody() and null until then.
298             public byte[] Body { get; set; }\r
299 \r
// Extracts the charset parameter from a Content-Type value; used by BodyAsString.
300             private static readonly Regex CharsetRegx = new Regex("charset=([\\w-]+)",\r
301                 RegexOptions.Compiled | RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);\r
302 \r
// Parsed Content-Length header; -1 when no value has been parsed or the value is not a number.
303             public int ContentLength { get; set; } = -1;\r
// Lower-cased Transfer-Encoding header value, or null when the header is absent.
304             public string TransferEncoding { get; set; }\r
// Content-Type header value, or null when the header is absent.
305             public string ContentType { get; set; }\r
// Content-Encoding header value, or null when the header is absent.
306             public string ContentEncoding { get; set; }\r
// Host header value, or null when the header is absent.
307             public string Host { get; set; }\r
308 \r
/// <summary>
/// The raw header block. Assigning it also refreshes the parsed header properties
/// through SetHeaders().
/// </summary>
public string Headers
{
    get { return _headers; }
    set
    {
        _headers = value;
        SetHeaders(value);
    }
}
318 \r
/// <summary>
/// The header block as it is forwarded: connection-management headers replaced
/// by a single "Connection: close".
/// </summary>
public virtual string ModifiedHeaders
{
    get { return SetConnectionClose(Headers); }
}
320 \r
/// <summary>
/// Removes the connection, keep-alive and proxy-connection headers and adds
/// "Connection: close" in their place.
/// </summary>
/// <param name="headers">Header block to rewrite.</param>
/// <returns>The rewritten header block.</returns>
private string SetConnectionClose(string headers)
{
    var hopByHop = new[] {"connection", "keep-alive", "proxy-connection"};
    var stripped = RemoveHeaders(headers, hopByHop);
    return InsertHeader(stripped, "Connection: close\r\n");
}
326 \r
/// <summary>
/// Removes every header line whose field name is listed in <paramref name="fields"/>.
/// All occurrences of a field are removed, not just the first, so a repeated
/// header cannot survive the rewrite.
/// </summary>
/// <param name="headers">Header block to filter.</param>
/// <param name="fields">Field names to remove; matched case-insensitively by MatchField().</param>
/// <returns>The header block without the listed fields.</returns>
protected string RemoveHeaders(string headers, string[] fields)
{
    foreach (var f in fields)
    {
        // Each pass removes a non-empty match, so the loop always terminates.
        for (var m = MatchField(f, headers); m.Success; m = MatchField(f, headers))
            headers = headers.Remove(m.Index, m.Length);
    }
    return headers;
}
338 \r
/// <summary>
/// Inserts a header line two characters before the end of the header block,
/// i.e. in front of the final CRLF when the block ends with one.
/// </summary>
/// <param name="headers">Header block to extend.</param>
/// <param name="header">Header line to insert, including its line terminator.</param>
/// <returns>The extended header block.</returns>
protected string InsertHeader(string headers, string header)
{
    var insertAt = headers.Length - 2;
    return headers.Insert(insertAt, header);
}
343 \r
/// <summary>
/// Refreshes the parsed header properties from the current Headers value.
/// ContentLength is only touched when a Content-Length header is present; it
/// becomes -1 when the value is not a number.
/// </summary>
/// <param name="headers">Not read here; the fields are looked up through GetField().</param>
protected virtual void SetHeaders(string headers)
{
    var lengthText = GetField("content-length");
    if (lengthText != null)
    {
        int parsed;
        if (!int.TryParse(lengthText, out parsed))
            parsed = -1;
        ContentLength = parsed;
    }
    var transferEncoding = GetField("transfer-encoding");
    TransferEncoding = transferEncoding == null
        ? null
        : transferEncoding.ToLower(CultureInfo.InvariantCulture);
    ContentType = GetField("content-type");
    ContentEncoding = GetField("content-encoding");
    Host = GetField("host");
}
356 \r
/// <summary>
/// Finds the first header line with the given field name.
/// </summary>
/// <param name="name">Field name; matched case-insensitively at the start of a line.</param>
/// <param name="headers">Header block to search.</param>
/// <returns>
/// The match covering the whole line including its CRLF; group 1 is the value
/// without leading blanks. The match fails for a missing field and for a field
/// with an empty value.
/// </returns>
protected Match MatchField(string name, string headers)
{
    // Only spaces and tabs may be skipped after the colon. "\s*" would also
    // swallow the CRLF of an empty-valued header and capture the following
    // header line as if it were this field's value.
    var regex = new Regex("^" + Regex.Escape(name) + ":[ \\t]*([^\r]+)\r\n",
        RegexOptions.CultureInvariant | RegexOptions.IgnoreCase | RegexOptions.Multiline);
    return regex.Match(headers);
}
363 \r
/// <summary>
/// Returns the value of the named header from Headers, or null when
/// MatchField() finds no such header.
/// </summary>
/// <param name="name">Field name to look up.</param>
protected string GetField(string name)
{
    var found = MatchField(name, Headers);
    if (!found.Success)
        return null;
    return found.Groups[1].Value;
}
369 \r
/// <summary>
/// The body decoded as text. The encoding is taken from the charset parameter of
/// Content-Type; ASCII is used when there is no charset or the charset is not
/// known to the runtime. Returns an empty string when there is no body.
/// </summary>
public string BodyAsString
{
    get
    {
        if (Body == null)
            return "";
        var encoding = Encoding.ASCII;
        var m = CharsetRegx.Match(ContentType ?? "");
        if (m.Success)
        {
            var name = m.Groups[1].Value;
            // "utf8" is not a registered encoding name; charset names are
            // case-insensitive, so the alias check has to be as well.
            if (name.Equals("utf8", StringComparison.OrdinalIgnoreCase))
                name = "UTF-8";
            try
            {
                encoding = Encoding.GetEncoding(name);
            }
            catch (ArgumentException)
            {
                // Unknown charset label: keep the ASCII default rather than
                // throwing out of a property getter.
            }
        }
        return encoding.GetString(Body);
    }
}
388 \r
/// <summary>
/// Reads the message body from <paramref name="stream"/> and stores it in Body.
/// Chunked transfer encoding takes precedence; otherwise a positive
/// Content-Length reads exactly that many bytes, zero reads nothing, and a
/// negative value reads until the peer closes the connection.
/// A gzip or deflate Content-Encoding is then decoded in place.
/// </summary>
/// <param name="stream">Stream positioned right after the header block.</param>
/// <exception cref="HttpProxyAbort">The body could not be decompressed.</exception>
public void ReadBody(HttpStream stream)
{
    if (TransferEncoding != null && TransferEncoding.Contains("chunked"))
    {
        Body = stream.ReadChunked();
    }
    else if (ContentLength > 0)
    {
        var buf = new byte[ContentLength];
        stream.Read(buf, 0, ContentLength);
        Body = buf;
    }
    else if (ContentLength < 0)
    {
        Body = stream.ReadToEnd();
    }
    // ContentLength == 0: there is nothing to read and Body is left as it was.

    if (Body == null || ContentEncoding == null)
        return;
    var coding = ContentEncoding.Trim();
    var isGzip = coding.Equals("gzip", StringComparison.OrdinalIgnoreCase);
    var isDeflate = coding.Equals("deflate", StringComparison.OrdinalIgnoreCase);
    // Any other coding (identity, br, ...) cannot be decoded here. Keep the bytes
    // as received instead of replacing the body with an empty array.
    // NOTE(review): Response.ModifiedHeaders still strips Content-Encoding, so a
    // body in an unsupported coding reaches the client without that header.
    if (!isGzip && !isDeflate)
        return;
    try
    {
        using (var source = new MemoryStream(Body))
        using (var decoder = isGzip
            ? (Stream)new GZipStream(source, CompressionMode.Decompress)
            : new DeflateStream(source, CompressionMode.Decompress))
        using (var decoded = new MemoryStream())
        {
            decoder.CopyTo(decoded);
            Body = decoded.ToArray();
        }
    }
    catch (Exception ex)
    {
        throw new HttpProxyAbort($"Fail to decode {ContentEncoding}: " + ex.Message);
    }
}
424         }\r
425 \r
/// <summary>
/// An HTTP request: the request line plus the headers and body inherited from Message.
/// </summary>
public class Request : Message
{
    private string _requestLine;

    /// <summary>
    /// The request line as received. Assigning it also sets Method and PathAndQuery
    /// from its first two space-separated fields.
    /// </summary>
    /// <exception cref="HttpProxyAbort">The line has fewer than three space-separated fields.</exception>
    public string RequestLine
    {
        get => _requestLine;
        set
        {
            _requestLine = value;
            var f = _requestLine.Split(' ');
            if (f.Length < 3)
                throw new HttpProxyAbort("Invalid request line");
            Method = f[0];
            // At least three fields are guaranteed here, so f[1] always exists.
            PathAndQuery = f[1];
        }
    }

    /// <summary>First field of the request line; null until RequestLine is assigned.</summary>
    public string Method { get; private set; }

    /// <summary>Second field of the request line; null until RequestLine is assigned.</summary>
    public string PathAndQuery { get; private set; }
}
447 \r
/// <summary>
/// An HTTP response: the status line plus the headers and body inherited from Message.
/// </summary>
public class Response : Message
{
    private string _statusLine;

    /// <summary>
    /// The header block as sent to the client: the base rewrite, minus
    /// transfer-encoding, content-encoding and content-length, plus a
    /// Content-Length computed from Body when there is a body.
    /// </summary>
    public override string ModifiedHeaders
    {
        get
        {
            var dropped = new[] {"transfer-encoding", "content-encoding", "content-length"};
            var stripped = RemoveHeaders(base.ModifiedHeaders, dropped);
            return InsertContentLength(stripped);
        }
    }

    // Adds a Content-Length header for Body; returns the headers unchanged when Body is null.
    private string InsertContentLength(string headers)
    {
        if (Body == null)
            return headers;
        return InsertHeader(headers, $"Content-Length: {Body.Length}\r\n");
    }

    /// <summary>
    /// The status line as received. Assigning it also sets StatusCode from its
    /// second space-separated field.
    /// </summary>
    /// <exception cref="HttpProxyAbort">The line has fewer than three space-separated fields.</exception>
    public string StatusLine
    {
        get { return _statusLine; }
        set
        {
            _statusLine = value;
            var parts = _statusLine.Split(' ');
            if (parts.Length < 3)
                throw new HttpProxyAbort("Invalid status line");
            StatusCode = parts[1];
        }
    }

    /// <summary>Second field of the status line; null until StatusLine is assigned.</summary>
    public string StatusCode { get; private set; }
}
476 \r
/// <summary>
/// Thrown inside the proxy to abandon the current exchange, e.g. for a malformed
/// request line or an undecodable body.
/// </summary>
private class HttpProxyAbort : Exception
{
    public HttpProxyAbort(string message) : base(message) { }
}
483 \r
484         public class HttpStream\r
485         {\r
// Socket all reads and writes go through; passed to the constructor.
486             private readonly Socket _socket;\r
// Read-ahead buffer filled by ReadByte().
487             private readonly byte[] _buffer = new byte[4096];\r
// Number of valid bytes currently in _buffer.
488             private int _available;\r
// Index of the next unread byte in _buffer.
489             private int _position;\r
490 \r
/// <summary>
/// Wraps a socket and sets its NoDelay option.
/// </summary>
/// <param name="socket">Socket to read from and write to.</param>
public HttpStream(Socket socket)
{
    _socket = socket;
    _socket.NoDelay = true;
}
496 \r
/// <summary>
/// Reads bytes up to and including the next '\n' and returns them as a string,
/// one character per byte. Returns what was read so far (possibly an empty
/// string) when the stream ends first.
/// </summary>
public string ReadLine()
{
    var line = new StringBuilder();
    for (var b = ReadByte(); b != -1; b = ReadByte())
    {
        line.Append((char)b);
        if (b == '\n')
            break;
    }
    return line.ToString();
}
509 \r
/// <summary>
/// Returns the next byte, refilling the read-ahead buffer from the socket when it
/// is exhausted, or -1 when the socket reports end of stream.
/// </summary>
private int ReadByte()
{
    if (_position >= _available)
    {
        _available = _socket.Receive(_buffer, 0, _buffer.Length, SocketFlags.None);
        _position = 0;
        if (_available == 0)
            return -1;
    }
    return _buffer[_position++];
}
518 \r
/// <summary>
/// Sends the text encoded as ASCII.
/// </summary>
/// <param name="s">Text to send.</param>
/// <returns>This stream, to allow chaining.</returns>
public HttpStream WriteLines(string s)
{
    var bytes = Encoding.ASCII.GetBytes(s);
    Write(bytes, 0, bytes.Length);
    return this;
}
525 \r
/// <summary>
/// Reads header lines up to and including the blank line that ends the block.
/// </summary>
/// <returns>The header block, ending with the terminating "\r\n".</returns>
/// <exception cref="HttpProxyAbort">The connection ended before the blank line arrived.</exception>
public string ReadHeaders()
{
    var sb = new StringBuilder();
    while (true)
    {
        var line = ReadLine();
        // ReadLine() keeps returning "" once the peer has closed the connection;
        // without this check the loop would spin forever waiting for "\r\n".
        if (line.Length == 0)
            throw new HttpProxyAbort("Connection closed while reading headers");
        sb.Append(line);
        if (line == "\r\n")
            return sb.ToString();
    }
}
537 \r
/// <summary>
/// Reads a chunked body and returns the concatenated chunk data. Reading stops at
/// a zero-size chunk or at a size line shorter than three characters; the lines
/// that follow are then consumed up to a blank line or the end of the stream.
/// </summary>
/// <exception cref="HttpProxyAbort">A chunk size is not a hexadecimal number.</exception>
public byte[] ReadChunked()
{
    var body = new MemoryStream();
    for (;;)
    {
        var sizeLine = ReadLine();
        if (sizeLine.Length < 3)
            break;
        // The size ends at a chunk extension (';') or at the trailing CRLF.
        var semicolon = sizeLine.IndexOf(';');
        var hex = semicolon == -1
            ? sizeLine.Substring(0, sizeLine.Length - 2)
            : sizeLine.Substring(0, semicolon);
        int chunkSize;
        if (!int.TryParse(hex, NumberStyles.HexNumber, CultureInfo.InvariantCulture, out chunkSize))
            throw new HttpProxyAbort("Can't parse chunk size: " + hex);
        if (chunkSize == 0)
            break;
        var chunk = new byte[chunkSize];
        Read(chunk, 0, chunk.Length);
        body.Write(chunk, 0, chunk.Length);
        // Consume the line terminator that follows the chunk data.
        ReadLine();
    }
    // Skip the remaining lines up to the blank line or the end of the stream.
    while (true)
    {
        var trailer = ReadLine();
        if (trailer == "" || trailer == "\r\n")
            break;
    }
    return body.ToArray();
}
564 \r
/// <summary>
/// Reads until Read() returns no more data and returns everything that was read.
/// </summary>
public byte[] ReadToEnd()
{
    var collected = new MemoryStream();
    var block = new byte[4096];
    for (var got = Read(block, 0, block.Length); got > 0; got = Read(block, 0, block.Length))
        collected.Write(block, 0, got);
    return collected.ToArray();
}
574 \r
/// <summary>
/// Sends the whole array; a null array sends nothing.
/// </summary>
/// <param name="body">Bytes to send, or null.</param>
/// <returns>This stream, to allow chaining.</returns>
public HttpStream Write(byte[] body)
{
    if (body == null)
        return this;
    Write(body, 0, body.Length);
    return this;
}
581 \r
/// <summary>
/// Reads up to <paramref name="count"/> bytes into <paramref name="buf"/> starting
/// at <paramref name="offset"/>. Bytes already held in the read-ahead buffer are
/// delivered first, then the socket is read until the request is satisfied or the
/// peer closes the connection.
/// </summary>
/// <param name="buf">Destination array.</param>
/// <param name="offset">Index in <paramref name="buf"/> of the first byte to store.</param>
/// <param name="count">Number of bytes wanted.</param>
/// <returns>
/// The number of bytes stored; less than <paramref name="count"/> only when the
/// stream ended first.
/// </returns>
public int Read(byte[] buf, int offset, int count)
{
    var total = 0;
    // A zero-length request returns immediately without touching the socket.
    while (count > 0)
    {
        int n;
        if (_position < _available)
        {
            n = Math.Min(count, _available - _position);
            // Copy to the caller's offset; writing to index 0 would overwrite the
            // start of the array whenever offset is not zero.
            Buffer.BlockCopy(_buffer, _position, buf, offset, n);
            _position += n;
        }
        else
        {
            n = _socket.Receive(buf, offset, count, SocketFlags.None);
            if (n == 0)
                return total;
        }
        count -= n;
        offset += n;
        total += n;
    }
    return total;
}
606 \r
/// <summary>
/// Sends <paramref name="count"/> bytes of <paramref name="buf"/> starting at
/// <paramref name="offset"/>, repeating the send until everything has gone out.
/// Gives up when a send transfers no bytes.
/// </summary>
public void Write(byte[] buf, int offset, int count)
{
    for (;;)
    {
        var sent = _socket.Send(buf, offset, count, SocketFlags.None);
        if (sent == 0)
            return;
        offset += sent;
        count -= sent;
        if (count <= 0)
            return;
    }
}
618 \r
/// <summary>
/// Closes the underlying socket.
/// </summary>
/// <returns>This stream, to allow chaining.</returns>
public HttpStream Close()
{
    var socket = _socket;
    socket.Close();
    return this;
}
624         }\r
625     }\r
626 }