OSDN Git Service

任務名をダブルクリックでクリップボードにコピーする
[kancollesniffer/KancolleSniffer.git] / KancolleSniffer / HttpProxy.cs
// Copyright (c) 2015 Kazuhiro Fujieda <fujieda@users.osdn.me>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 \r
15 using System;\r
16 using System.Globalization;\r
17 using System.IO;\r
18 using System.IO.Compression;\r
19 using System.Net;\r
20 using System.Net.Sockets;\r
21 using System.Text;\r
22 using System.Text.RegularExpressions;\r
23 using System.Threading.Tasks;\r
24 \r
25 namespace KancolleSniffer\r
26 {\r
27     public class HttpProxy\r
28     {\r
/// <summary>
/// Raised by a connection worker once a response has been received from the
/// server and before that response is written back to the client.
/// </summary>
public static event Action<Session> AfterSessionComplete;

/// <summary>
/// Port handed to the listener. <see cref="Start"/> overwrites it with the
/// port of the listener's local endpoint after binding.
/// </summary>
public static int LocalPort { get; set; }

/// <summary>Set to true by <see cref="Start"/> and to false by <see cref="Stop"/>.</summary>
public static bool IsInListening { get; private set; }

/// <summary>When true, outgoing HTTP connections are made to the upstream proxy below.</summary>
public static bool IsEnableUpstreamProxy { get; set; }

/// <summary>Host of the upstream proxy, used only while <see cref="IsEnableUpstreamProxy"/> is true.</summary>
public static string UpstreamProxyHost { get; set; }

/// <summary>Port of the upstream proxy, used only while <see cref="IsEnableUpstreamProxy"/> is true.</summary>
public static int UpstreamProxyPort { get; set; }

// Instance created by Startup() and stopped by Shutdown().
private static HttpProxy _httpProxy;

// Listening socket owned by this instance; assigned in Start().
private TcpListener _listener;
38 \r
/// <summary>
/// Creates the shared proxy instance and starts it on the given port.
/// </summary>
/// <param name="port">Value stored in <see cref="LocalPort"/> before the listener is created.</param>
/// <param name="dummy0">Not used by this method.</param>
/// <param name="dummy1">Not used by this method.</param>
public static void Startup(int port, bool dummy0, bool dummy1)
{
    LocalPort = port;
    var proxy = new HttpProxy();
    // Publish the instance before starting it so Shutdown() can reach it
    // even when Start() throws.
    _httpProxy = proxy;
    proxy.Start();
}
45 \r
/// <summary>
/// Binds a listener to the loopback address on <see cref="LocalPort"/>, records
/// the port actually bound, and starts the accept loop on a background task.
/// </summary>
public void Start()
{
    var listener = new TcpListener(IPAddress.Loopback, LocalPort);
    _listener = listener;
    listener.Start();

    // Read the port back from the bound endpoint rather than trusting the request.
    var endpoint = (IPEndPoint)listener.LocalEndpoint;
    LocalPort = endpoint.Port;
    IsInListening = true;

    Task.Run(new Action(AcceptClient));
}
54 \r
/// <summary>
/// Stops the instance created by <see cref="Startup"/>; does nothing when no
/// instance has been created.
/// </summary>
public static void Shutdown()
{
    var proxy = _httpProxy;
    if (proxy == null)
        return;
    proxy.Stop();
}
59 \r
/// <summary>
/// Marks the proxy as not listening and closes the listener socket.
/// Safe to call when <see cref="Start"/> never got as far as creating the
/// listener, and safe to call more than once.
/// </summary>
public void Stop()
{
    IsInListening = false;
    // _listener is assigned only inside Start(). If the TcpListener constructor
    // threw there (or Start() was never called) the field is still null, and a
    // later Shutdown() must not fail with a NullReferenceException.
    var listener = _listener;
    if (listener == null)
        return;
    // Closing the underlying socket first unblocks a pending AcceptSocket().
    listener.Server.Close();
    listener.Stop();
}
66 \r
/// <summary>
/// Accept loop: blocks on the listener and hands every accepted socket to a
/// new <c>HttpClient</c> worker running on its own task. The loop ends when
/// accepting throws; <see cref="Stop"/> is always called on the way out.
/// </summary>
public void AcceptClient()
{
    try
    {
        for (;;)
        {
            // Declared inside the loop so each worker closure captures its own socket.
            var accepted = _listener.AcceptSocket();
            Task.Run(() =>
            {
                var worker = new HttpClient(accepted);
                worker.ProcessRequest();
            });
        }
    }
    catch (SocketException)
    {
        // Socket errors from AcceptSocket() end the loop without being reported.
    }
    finally
    {
        Stop();
    }
}
85 \r
/// <summary>
/// Serves exactly one accepted client connection: reads a single request,
/// then either tunnels it (CONNECT), hands it to <c>LogServer</c> (requests
/// addressed to this machine), or relays it to the origin server / upstream
/// proxy and passes the response back. Both sockets are closed afterwards.
/// </summary>
private class HttpClient
{
    private readonly Socket _client;
    private Socket _server;
    private readonly Session _session;
    private readonly HttpStream _clientStream;
    private HttpStream _serverStream;

    // Captures host and optional port of an absolute "http://host[:port]/" URI.
    private static readonly Regex HostAndPortRegex =
        new Regex("http://([^:/]+)(?::(\\d+))?/", RegexOptions.Compiled);

    // Captures host and optional port of a "host[:port]" authority.
    private static readonly Regex AuthorityRegex = new Regex("([^:]+)(?::(\\d+))?");

    public HttpClient(Socket client)
    {
        _client = client;
        _clientStream = new HttpStream(client);
        _session = new Session();
    }

    /// <summary>
    /// Runs the whole request/response exchange. Every exception is contained
    /// here (logged to debug.log in DEBUG builds, dropped otherwise) and the
    /// sockets are always closed.
    /// </summary>
    public void ProcessRequest()
    {
        try
        {
            ReceiveRequest();
            // Method stays null when the client closed without sending a request line.
            if (_session.Request.Method == null)
                return;
            if (_session.Request.Method == "CONNECT")
            {
                HandleConnect();
                return;
            }
            if (IsLocalHost(_session.Request.Host))
            {
                LogServer.Process(_client, _session.Request.RequestLine);
                return;
            }
            SendRequest();
            ReceiveRequestBody();
            SendRequestBody();
            ReceiveResponse();
            // StatusCode stays null when the server closed without a status line.
            if (_session.Response.StatusCode == null)
                return;
            AfterSessionComplete?.Invoke(_session);
            SendResponse();
        }
#if DEBUG
        catch (Exception e)
        {
            File.AppendAllText("debug.log", $"[{DateTime.Now:g}] " + e + "\r\n");
        }
#else // ReSharper disable once EmptyGeneralCatchClause
        catch
        {
        }
#endif
        finally
        {
            Close();
        }
    }

    /// <summary>
    /// True when the Host header names this machine. A missing Host header
    /// (null) is treated as "not local" so that requests carrying an absolute
    /// URI but no Host header are still relayed instead of failing on a null
    /// dereference. The comparison is ordinal and ignores case.
    /// </summary>
    private static bool IsLocalHost(string host)
    {
        if (host == null)
            return false;
        return host.StartsWith("localhost", StringComparison.OrdinalIgnoreCase) ||
               host.StartsWith("127.0.0.1", StringComparison.Ordinal);
    }

    /// <summary>Reads the request line and headers; leaves the request untouched on immediate EOF.</summary>
    private void ReceiveRequest()
    {
        var requestLine = _clientStream.ReadLine();
        if (requestLine == "")
            return;
        _session.Request.RequestLine = requestLine;
        _session.Request.Headers = _clientStream.ReadHeaders();
    }

    /// <summary>Reads a request body only when Content-Length or Transfer-Encoding announces one.</summary>
    private void ReceiveRequestBody()
    {
        if (_session.Request.ContentLength != -1 || _session.Request.TransferEncoding != null)
            _session.Request.ReadBody(_clientStream);
    }

    /// <summary>Connects to the destination and sends the request line plus rewritten headers.</summary>
    private void SendRequest()
    {
        _server = ConnectServer();
        _serverStream =
            new HttpStream(_server).WriteLines(_session.Request.RequestLine + _session.Request.ModifiedHeaders);
    }

    private void SendRequestBody()
    {
        _serverStream.Write(_session.Request.Body);
    }

    /// <summary>Reads status line, headers and (when one is expected) the body of the response.</summary>
    private void ReceiveResponse()
    {
        var statusLine = _serverStream.ReadLine();
        if (statusLine == "")
            return;
        _session.Response.StatusLine = statusLine;
        _session.Response.Headers = _serverStream.ReadHeaders();
        if (HasBody)
            _session.Response.ReadBody(_serverStream);
    }

    /// <summary>
    /// False for responses to HEAD and for 1xx, 204 and 304 status codes;
    /// true otherwise.
    /// </summary>
    private bool HasBody
    {
        get
        {
            if (_session.Request.Method == "HEAD")
                return false;
            var code = _session.Response.StatusCode;
            return !(code.StartsWith("1", StringComparison.Ordinal) || code == "204" || code == "304");
        }
    }

    private void SendResponse()
    {
        _clientStream.WriteLines(_session.Response.StatusLine + _session.Response.ModifiedHeaders)
            .Write(_session.Response.Body);
    }

    /// <summary>
    /// Handles CONNECT by opening a direct TCP connection to the requested
    /// authority (default port 443) and copying bytes in both directions until
    /// both directions have finished. The upstream proxy setting is not
    /// consulted here.
    /// </summary>
    private void HandleConnect()
    {
        var host = "";
        var port = 443;
        if (!ParseAuthority(_session.Request.PathAndQuery, ref host, ref port))
            return;
        // Assigned before Connect() so Close() disposes it even when connecting fails.
        _server = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        _server.Connect(host, port);
        _clientStream.WriteLines("HTTP/1.0 200 Connection established\r\n\r\n");
        Task[] tasks =
        {
            Task.Run(() => { TunnelSockets(_client, _server); }),
            Task.Run(() => { TunnelSockets(_server, _client); })
        };
        Task.WaitAll(tasks);
    }

    /// <summary>
    /// Copies bytes from one socket to the other until the source reports end
    /// of stream or a send is cut short, then shuts down the sending side of
    /// the destination. Socket errors end the copy silently.
    /// </summary>
    private void TunnelSockets(Socket from, Socket to)
    {
        try
        {
            var buf = new byte[8192];
            while (true)
            {
                var n = from.Receive(buf);
                if (n == 0)
                    break;
                var sent = to.Send(buf, n, SocketFlags.None);
                if (sent < n)
                    break;
            }
            to.Shutdown(SocketShutdown.Send);
        }
        catch (SocketException)
        {
        }
    }

    /// <summary>
    /// Opens the outgoing connection for a plain HTTP request. With the
    /// upstream proxy enabled the request line is forwarded unchanged to that
    /// proxy. Otherwise the destination is taken from the absolute URI in the
    /// request line (which is then reduced to its path form) or, failing
    /// that, from the Host header.
    /// </summary>
    /// <exception cref="HttpProxyAbort">No usable destination host or port could be determined.</exception>
    private Socket ConnectServer()
    {
        string host = null;
        var port = 80;
        if (IsEnableUpstreamProxy)
        {
            host = UpstreamProxyHost;
            port = UpstreamProxyPort;
        }
        else
        {
            var m = HostAndPortRegex.Match(_session.Request.RequestLine);
            if (m.Success)
            {
                host = m.Groups[1].Value;
                if (m.Groups[2].Success && !TryParsePort(m.Groups[2].Value, ref port))
                    throw new HttpProxyAbort("Can't find destination host");
                // Drop "http://host[:port]" but keep the trailing '/'.
                _session.Request.RequestLine = _session.Request.RequestLine.Remove(m.Index, m.Length - 1);
            }
            if (host == null && !ParseAuthority(_session.Request.Host, ref host, ref port))
                throw new HttpProxyAbort("Can't find destination host");
        }
        var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
        try
        {
            socket.Connect(host, port);
        }
        catch
        {
            // The socket has not been handed to _server yet, so Close() would
            // never see it; release it here before propagating the failure.
            socket.Close();
            throw;
        }
        return socket;
    }

    /// <summary>
    /// Splits "host[:port]" into its parts. <paramref name="port"/> keeps its
    /// incoming value when no port is present. Returns false, leaving both
    /// arguments untouched, when the authority is empty or the port is not a
    /// parsable number.
    /// </summary>
    private bool ParseAuthority(string authority, ref string host, ref int port)
    {
        if (string.IsNullOrEmpty(authority))
            return false;
        var m = AuthorityRegex.Match(authority);
        if (!m.Success)
            return false;
        if (m.Groups[2].Success && !TryParsePort(m.Groups[2].Value, ref port))
            return false;
        host = m.Groups[1].Value;
        return true;
    }

    /// <summary>
    /// Parses a decimal port. "\d" in the regexes also matches non-ASCII
    /// digits and arbitrarily long runs, so parsing can fail; in that case
    /// <paramref name="port"/> is left unchanged and false is returned.
    /// </summary>
    private static bool TryParsePort(string text, ref int port)
    {
        if (!int.TryParse(text, NumberStyles.None, CultureInfo.InvariantCulture, out var value))
            return false;
        port = value;
        return true;
    }

    private void Close()
    {
        SocketClose(_server);
        SocketClose(_client);
    }

    /// <summary>Best-effort shutdown and close; errors from an already broken socket are ignored.</summary>
    private void SocketClose(Socket socket)
    {
        if (socket == null)
            return;
        try
        {
            socket.Shutdown(SocketShutdown.Both);
        }
        // ReSharper disable EmptyGeneralCatchClause
        catch
        {
        }
        try
        {
            socket.Close();
        }
        catch
        // ReSharper restore EmptyGeneralCatchClause
        {
        }
    }
}
308 \r
/// <summary>
/// One request/response pair handled by the proxy. Both halves are created
/// empty together with the session and filled in as the exchange proceeds.
/// </summary>
public class Session
{
    public Session()
    {
        Request = new Request();
        Response = new Response();
    }

    /// <summary>The request received from the client.</summary>
    public Request Request { get; set; }

    /// <summary>The response received from the server.</summary>
    public Response Response { get; set; }
}
314 \r
/// <summary>
/// State shared by requests and responses: the raw header block, a few
/// header values extracted from it, and the (decoded) body.
/// </summary>
public class Message
{
    private string _headers;

    private static readonly Regex CharsetRegex = new Regex("charset=([\\w-]+)",
        RegexOptions.Compiled | RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);

    /// <summary>Body bytes as stored by <see cref="ReadBody"/>; null when no body was read.</summary>
    public byte[] Body { get; set; }

    /// <summary>Value of Content-Length, or -1 when absent or not a non-negative number.</summary>
    public int ContentLength { get; set; } = -1;

    /// <summary>Lower-cased value of Transfer-Encoding, or null when absent.</summary>
    public string TransferEncoding { get; set; }

    public string ContentType { get; set; }
    public string ContentEncoding { get; set; }
    public string Host { get; set; }

    /// <summary>
    /// True when <see cref="ReadBody"/> stored a body that is still encoded
    /// with a Content-Encoding this class cannot decode. In that case the
    /// Content-Encoding header must stay on the message.
    /// </summary>
    protected bool HasEncodedBody { get; private set; }

    /// <summary>
    /// Raw header block including the terminating empty line. Assigning it
    /// re-extracts the header-derived properties.
    /// </summary>
    public string Headers
    {
        get => _headers;
        set
        {
            _headers = value;
            SetHeaders(_headers);
        }
    }

    /// <summary>Header block to forward: connection-management headers replaced by "Connection: close".</summary>
    public virtual string ModifiedHeaders => SetConnectionClose(Headers);

    private string SetConnectionClose(string headers)
    {
        return InsertHeader(RemoveHeaders(headers,
            new[] {"connection", "keep-alive", "proxy-connection"}), "Connection: close\r\n");
    }

    /// <summary>
    /// Removes every occurrence of each named field from the header block.
    /// A field that appears more than once is removed completely.
    /// </summary>
    protected string RemoveHeaders(string headers, string[] fields)
    {
        foreach (var f in fields)
        {
            // Each removal shortens the string, so this loop terminates.
            for (var m = MatchField(f, headers); m.Success; m = MatchField(f, headers))
                headers = headers.Remove(m.Index, m.Length);
        }
        return headers;
    }

    /// <summary>Inserts a header line just before the final CRLF that ends the header block.</summary>
    protected string InsertHeader(string headers, string header)
    {
        return headers.Insert(headers.Length - 2, header);
    }

    /// <summary>Extracts the header values this class exposes as properties.</summary>
    protected virtual void SetHeaders(string headers)
    {
        var s = GetField("content-length");
        if (s != null)
        {
            // Signs are rejected so a negative length cannot slip through as a
            // value that is neither "absent" (-1) nor a usable size.
            const NumberStyles styles = NumberStyles.AllowLeadingWhite | NumberStyles.AllowTrailingWhite;
            ContentLength = int.TryParse(s, styles, CultureInfo.InvariantCulture, out var len) ? len : -1;
        }
        TransferEncoding = GetField("transfer-encoding")?.ToLower(CultureInfo.InvariantCulture);
        ContentType = GetField("content-type");
        ContentEncoding = GetField("content-encoding");
        Host = GetField("host");
    }

    /// <summary>
    /// Matches one "name: value CRLF" line, case-insensitively; group 1 is the value.
    /// Only spaces and tabs are skipped after the colon, so a header with an
    /// empty value does not match and cannot swallow the following line.
    /// </summary>
    protected Match MatchField(string name, string headers)
    {
        var regex = new Regex("^" + Regex.Escape(name) + ":[ \\t]*([^\r]+)\r\n",
            RegexOptions.CultureInvariant | RegexOptions.IgnoreCase | RegexOptions.Multiline);
        return regex.Match(headers);
    }

    /// <summary>Returns the value of the first header with the given name, or null.</summary>
    protected string GetField(string name)
    {
        var m = MatchField(name, Headers);
        return m.Success ? m.Groups[1].Value : null;
    }

    /// <summary>
    /// The body decoded with the charset named in Content-Type. ASCII is used
    /// when no charset is given or the named charset is not available.
    /// </summary>
    public string BodyAsString
    {
        get
        {
            if (Body == null)
                return "";
            var encoding = Encoding.ASCII;
            var m = CharsetRegex.Match(ContentType ?? "");
            if (m.Success)
            {
                var name = m.Groups[1].Value;
                if (string.Equals(name, "utf8", StringComparison.OrdinalIgnoreCase))
                    name = "UTF-8";
                try
                {
                    encoding = Encoding.GetEncoding(name);
                }
                catch (ArgumentException)
                {
                    // Unknown charset label: keep the ASCII default instead of failing the caller.
                }
            }
            return encoding.GetString(Body);
        }
    }

    /// <summary>
    /// Reads the body from the stream (chunked, fixed length, or until the
    /// peer closes) and decodes gzip/deflate content encodings in place.
    /// </summary>
    /// <exception cref="HttpProxyAbort">The body could not be decompressed.</exception>
    public void ReadBody(HttpStream stream)
    {
        if (TransferEncoding != null && TransferEncoding.Contains("chunked"))
        {
            Body = stream.ReadChunked();
        }
        else if (ContentLength > 0)
        {
            var buf = new byte[ContentLength];
            stream.Read(buf, 0, ContentLength);
            Body = buf;
        }
        else if (ContentLength < 0)
        {
            Body = stream.ReadToEnd();
        }
        // ContentLength == 0: there is no body to read; Body is left as it was.
        DecodeBody();
    }

    /// <summary>
    /// Decompresses <see cref="Body"/> according to Content-Encoding. A missing
    /// body is left alone, and a body in an unsupported encoding is kept
    /// byte-for-byte (flagged through <see cref="HasEncodedBody"/>) rather
    /// than being replaced by an empty array.
    /// </summary>
    private void DecodeBody()
    {
        HasEncodedBody = false;
        if (ContentEncoding == null || Body == null)
            return;
        // Content-coding names are compared case-insensitively.
        var coding = ContentEncoding.Trim().ToLowerInvariant();
        var gzip = coding == "gzip" || coding == "x-gzip";
        if (!gzip && coding != "deflate")
        {
            HasEncodedBody = true;
            return;
        }
        try
        {
            using (var source = new MemoryStream(Body))
            using (var decoder = gzip
                ? (Stream)new GZipStream(source, CompressionMode.Decompress)
                : new DeflateStream(source, CompressionMode.Decompress))
            using (var decoded = new MemoryStream())
            {
                decoder.CopyTo(decoded);
                Body = decoded.ToArray();
            }
        }
        catch (Exception ex)
        {
            throw new HttpProxyAbort($"Fail to decode {ContentEncoding}: " + ex.Message);
        }
    }
}

/// <summary>
/// An HTTP request: request line plus the common message state.
/// </summary>
/// <remarks>
/// NOTE(review): requests inherit <see cref="Message.ModifiedHeaders"/>, which
/// does not adjust Content-Length / Transfer-Encoding / Content-Encoding even
/// though <see cref="Message.ReadBody"/> de-chunks and decompresses the body.
/// Confirm whether chunked or compressed request bodies need to be supported.
/// </remarks>
public class Request : Message
{
    private string _requestLine;

    /// <summary>
    /// The raw request line. Assigning it also sets <see cref="Method"/> and
    /// <see cref="PathAndQuery"/>.
    /// </summary>
    /// <exception cref="HttpProxyAbort">The line has fewer than three space-separated parts.</exception>
    public string RequestLine
    {
        get => _requestLine;
        set
        {
            _requestLine = value;
            var f = _requestLine.Split(' ');
            if (f.Length < 3)
                throw new HttpProxyAbort("Invalid request line");
            Method = f[0];
            PathAndQuery = f[1];
        }
    }

    public string Method { get; private set; }
    public string PathAndQuery { get; private set; }
}

/// <summary>An HTTP response: status line plus the common message state.</summary>
public class Response : Message
{
    private string _statusLine;

    /// <summary>
    /// Header block to send to the client. Framing headers are replaced by a
    /// Content-Length matching the stored body. Content-Encoding is removed
    /// as well, unless the body is still in an encoding that could not be
    /// decoded, in which case the client needs that header.
    /// </summary>
    public override string ModifiedHeaders
    {
        get
        {
            var fields = HasEncodedBody
                ? new[] {"transfer-encoding", "content-length"}
                : new[] {"transfer-encoding", "content-encoding", "content-length"};
            return InsertContentLength(RemoveHeaders(base.ModifiedHeaders, fields));
        }
    }

    private string InsertContentLength(string headers)
    {
        return Body == null ? headers : InsertHeader(headers, $"Content-Length: {Body.Length}\r\n");
    }

    /// <summary>The raw status line. Assigning it also sets <see cref="StatusCode"/>.</summary>
    /// <exception cref="HttpProxyAbort">The line has fewer than three space-separated parts.</exception>
    public string StatusLine
    {
        get => _statusLine;
        set
        {
            _statusLine = value;
            var f = _statusLine.Split(' ');
            if (f.Length < 3)
                throw new HttpProxyAbort("Invalid status line");
            StatusCode = f[1];
        }
    }

    public string StatusCode { get; private set; }
}
496 \r
/// <summary>
/// Thrown by the proxy's own parsing and decoding code to abandon the
/// current session; the message describes what could not be processed.
/// </summary>
private class HttpProxyAbort : Exception
{
    /// <param name="message">Text passed unchanged to <see cref="Exception"/>.</param>
    public HttpProxyAbort(string message)
        : base(message)
    {
    }
}
503 \r
/// <summary>
/// Minimal buffered reader/writer over a socket for HTTP/1.x framing. Line
/// reads go through an internal 4096-byte buffer; bulk reads drain that buffer
/// first and then receive directly into the caller's array.
/// </summary>
public class HttpStream
{
    // ReadLine() turns each byte into the char with the same code, which is
    // ISO-8859-1 decoding. Writing with the same encoding sends header bytes
    // >= 0x80 back out unchanged instead of turning them into '?'.
    private static readonly Encoding LineEncoding = Encoding.GetEncoding("ISO-8859-1");

    private readonly Socket _socket;
    private readonly byte[] _buffer = new byte[4096];
    private int _available;
    private int _position;

    public HttpStream(Socket socket)
    {
        _socket = socket;
        socket.NoDelay = true;
    }

    /// <summary>
    /// Reads up to and including the next '\n'. Returns what was read so far
    /// when the peer closes first, which is "" at end of stream.
    /// </summary>
    public string ReadLine()
    {
        var sb = new StringBuilder();
        int ch;
        while ((ch = ReadByte()) != -1)
        {
            sb.Append((char)ch);
            if (ch == '\n')
                break;
        }
        return sb.ToString();
    }

    // Returns the next byte, refilling the buffer from the socket when it is
    // exhausted; -1 when the peer has closed the connection.
    private int ReadByte()
    {
        if (_position < _available)
            return _buffer[_position++];
        _available = _socket.Receive(_buffer, 0, _buffer.Length, SocketFlags.None);
        _position = 0;
        return _available == 0 ? -1 : _buffer[_position++];
    }

    /// <summary>Sends the text as-is (the caller supplies the line terminators) and returns this stream.</summary>
    public HttpStream WriteLines(string s)
    {
        var buf = LineEncoding.GetBytes(s);
        Write(buf, 0, buf.Length);
        return this;
    }

    /// <summary>
    /// Reads header lines up to and including the empty line that ends the
    /// header block, and returns them concatenated.
    /// </summary>
    /// <exception cref="HttpProxyAbort">
    /// The peer closed the connection before the empty line. Without this
    /// check ReadLine() would keep returning "" and the loop would never end.
    /// </exception>
    public string ReadHeaders()
    {
        var sb = new StringBuilder();
        while (true)
        {
            var line = ReadLine();
            if (line == "")
                throw new HttpProxyAbort("Connection closed while reading headers");
            sb.Append(line);
            if (line == "\r\n")
                return sb.ToString();
        }
    }

    /// <summary>
    /// Reads a chunked body and returns the concatenated chunk data. Chunk
    /// extensions are ignored and trailer lines are read and discarded.
    /// </summary>
    /// <exception cref="HttpProxyAbort">A chunk size is not a usable hexadecimal number.</exception>
    public byte[] ReadChunked()
    {
        var buf = new MemoryStream();
        while (true)
        {
            var size = ReadLine();
            // Shorter than one hex digit plus CRLF: end of stream or an empty line.
            if (size.Length < 3)
                break;
            var ext = size.IndexOf(';');
            size = ext == -1 ? size.Substring(0, size.Length - 2) : size.Substring(0, ext);
            if (!int.TryParse(size, NumberStyles.HexNumber, CultureInfo.InvariantCulture, out var val))
                throw new HttpProxyAbort("Can't parse chunk size: " + size);
            // Hex input such as "FFFFFFFF" parses to a negative int.
            if (val < 0)
                throw new HttpProxyAbort("Can't parse chunk size: " + size);
            if (val == 0)
                break;
            var chunk = new byte[val];
            Read(chunk, 0, chunk.Length);
            buf.Write(chunk, 0, chunk.Length);
            // Consume the CRLF that follows the chunk data.
            ReadLine();
        }
        string line;
        do
        {
            line = ReadLine();
        } while (line != "" && line != "\r\n");
        return buf.ToArray();
    }

    /// <summary>Reads until the peer closes the connection and returns everything received.</summary>
    public byte[] ReadToEnd()
    {
        var result = new MemoryStream();
        var buf = new byte[4096];
        int len;
        while ((len = Read(buf, 0, buf.Length)) > 0)
            result.Write(buf, 0, len);
        return result.ToArray();
    }

    /// <summary>Sends the whole array; a null body is ignored. Returns this stream.</summary>
    public HttpStream Write(byte[] body)
    {
        if (body != null)
            Write(body, 0, body.Length);
        return this;
    }

    /// <summary>
    /// Fills <paramref name="buf"/> starting at <paramref name="offset"/> with
    /// up to <paramref name="count"/> bytes, blocking until that many have
    /// arrived or the peer closes.
    /// </summary>
    /// <returns>The number of bytes stored; less than <paramref name="count"/> only at end of stream.</returns>
    public int Read(byte[] buf, int offset, int count)
    {
        var total = 0;
        while (count > 0)
        {
            int n;
            if (_position < _available)
            {
                // Bytes already pulled in by ReadLine() come first. They must
                // land at the caller's offset, not at the start of the array.
                n = Math.Min(count, _available - _position);
                Buffer.BlockCopy(_buffer, _position, buf, offset, n);
                _position += n;
            }
            else
            {
                n = _socket.Receive(buf, offset, count, SocketFlags.None);
                if (n == 0)
                    break;
            }
            count -= n;
            offset += n;
            total += n;
        }
        return total;
    }

    /// <summary>Sends <paramref name="count"/> bytes, repeating partial sends; stops if a send reports 0 bytes.</summary>
    public void Write(byte[] buf, int offset, int count)
    {
        while (count > 0)
        {
            var n = _socket.Send(buf, offset, count, SocketFlags.None);
            if (n == 0)
                return;
            count -= n;
            offset += n;
        }
    }
}
639     }\r
640 }