OSDN Git Service

HTTPプロキシがポート番号指定のリクエストを正しく扱えないのを直す
[kancollesniffer/KancolleSniffer.git] / KancolleSniffer / HttpProxy.cs
1 // Copyright (c) 2015 Kazuhiro Fujieda <fujieda@users.osdn.me>\r
2 //\r
3 // Licensed under the Apache License, Version 2.0 (the "License");\r
4 // you may not use this file except in compliance with the License.\r
5 // You may obtain a copy of the License at\r
6 //\r
7 //    http://www.apache.org/licenses/LICENSE-2.0\r
8 //\r
9 // Unless required by applicable law or agreed to in writing, software\r
10 // distributed under the License is distributed on an "AS IS" BASIS,\r
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\r
12 // See the License for the specific language governing permissions and\r
13 // limitations under the License.\r
14 \r
15 using System;\r
16 using System.Globalization;\r
17 using System.IO;\r
18 using System.IO.Compression;\r
19 using System.Net;\r
20 using System.Net.Sockets;\r
21 using System.Text;\r
22 using System.Text.RegularExpressions;\r
23 using System.Threading.Tasks;\r
24 \r
namespace KancolleSniffer
{
    /// <summary>
    /// Minimal HTTP forward proxy bound to the loopback interface. Each accepted connection
    /// carries exactly one request/response exchange (both directions are rewritten to
    /// "Connection: close"), and the finished exchange is reported through
    /// <see cref="AfterSessionComplete"/>.
    /// </summary>
    public class HttpProxy
    {
        private static HttpProxy _httpProxy;
        public static int LocalPort { get; set; }
        public static string UpstreamProxyHost { get; set; }
        public static int UpstreamProxyPort { get; set; }
        public static bool IsEnableUpstreamProxy { get; set; }
        public static bool IsInListening { get; private set; }

        /// <summary>
        /// Raised from the per-connection task after the response has been relayed to the client.
        /// </summary>
        public static event Action<Session> AfterSessionComplete;

        private TcpListener _listener;

        /// <summary>
        /// Creates the process-wide proxy instance and starts listening on <paramref name="port"/>.
        /// </summary>
        /// <param name="port">Loopback port to listen on; the port actually bound is stored in <see cref="LocalPort"/>.</param>
        /// <param name="dummy0">Unused.</param>
        /// <param name="dummy1">Unused.</param>
        public static void Startup(int port, bool dummy0, bool dummy1)
        {
            LocalPort = port;
            _httpProxy = new HttpProxy();
            _httpProxy.Start();
        }

        /// <summary>Starts listening and runs the accept loop on a background task.</summary>
        public void Start()
        {
            _listener = new TcpListener(IPAddress.Loopback, LocalPort);
            _listener.Start();
            // Read back the bound port so callers see the real value (e.g. when LocalPort was 0).
            LocalPort = ((IPEndPoint)_listener.LocalEndpoint).Port;
            IsInListening = true;
            Task.Run(() => AcceptClient());
        }

        /// <summary>Stops the proxy created by <see cref="Startup"/>, if any.</summary>
        public static void Shutdown()
        {
            _httpProxy?.Stop();
        }

        /// <summary>
        /// Closes the listening socket. A blocked accept in <see cref="AcceptClient"/> is
        /// expected to fail as a result, which ends the accept loop.
        /// </summary>
        public void Stop()
        {
            IsInListening = false;
            _listener.Server.Close();
            _listener.Stop();
        }

        /// <summary>
        /// Accepts connections until the listener is closed, handling each one on its own task.
        /// </summary>
        public void AcceptClient()
        {
            try
            {
                while (true)
                {
                    var client = _listener.AcceptSocket();
                    Task.Run(() => new HttpClient(client).ProcessRequest());
                }
            }
            catch (SocketException)
            {
                // Typically raised when Stop() closes the listening socket during AcceptSocket.
            }
            catch (ObjectDisposedException)
            {
                // The listening socket was already disposed by Stop().
            }
            catch (InvalidOperationException)
            {
                // AcceptSocket was called after the listener had been stopped.
            }
            finally
            {
                Stop();
            }
        }

        /// <summary>Handles a single proxied exchange for one accepted client socket.</summary>
        private class HttpClient
        {
            // Absolute-form request target: "METHOD http://host[:port][/path] VERSION".
            // Anchored at the start of the request line so that a URL inside the query string of an
            // origin-form request ("GET /r?u=http://other/ HTTP/1.1") is never taken as the destination.
            // Group 1 = "http://host[:port]", group 2 = host, group 3 = port.
            private static readonly Regex AbsoluteTargetRegex =
                new Regex("^[^ ]+ (http://([^:/?#\\s]+)(?::([0-9]+))?)(?=[/? ])",
                    RegexOptions.Compiled | RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);

            // "host[:port]" as found in a Host header.
            private static readonly Regex AuthorityRegex =
                new Regex("^([^:\\s]+)(?::([0-9]+))?", RegexOptions.CultureInvariant);

            private readonly Socket _client;
            private Socket _server;
            private readonly Session _session;
            private readonly HttpStream _clientStream;
            private HttpStream _serverStream;

            public HttpClient(Socket client)
            {
                _client = client;
                _clientStream = new HttpStream(client);
                _session = new Session();
            }

            /// <summary>
            /// Relays one request to the destination and its response back to the client, then
            /// raises <see cref="AfterSessionComplete"/>. Network errors and protocol violations
            /// abort the exchange silently; both sockets are always closed.
            /// </summary>
            public void ProcessRequest()
            {
                try
                {
                    ReceiveRequest();
                    SendRequest();
                    ReceiveResponse();
                    SendResponse();
                    // Close before notifying so the client is not kept waiting on the handlers.
                    Close();
                    AfterSessionComplete?.Invoke(_session);
                }
                catch (SocketException)
                {
                }
                catch (IOException)
                {
                }
                catch (HttpProxyAbort)
                {
                }
                finally
                {
                    Close();
                }
            }

            private void ReceiveRequest()
            {
                var requestLine = _clientStream.ReadLine();
                _session.Request.RequestLine = requestLine;
                _session.Request.Headers = _clientStream.ReadHeaders();
                if (_session.Request.ContentLength != -1 || _session.Request.TransferEncoding != null)
                    _session.Request.ReadBody(_clientStream);
            }

            private void SendRequest()
            {
                _server = ConnectServer();
                _serverStream = new HttpStream(_server).
                    WriteLines(_session.Request.RequestLine + _session.Request.ModifiedHeaders).
                    Write(_session.Request.Body);
            }

            private void ReceiveResponse()
            {
                _session.Response.StatusLine = _serverStream.ReadLine();
                _session.Response.Headers = _serverStream.ReadHeaders();
                if (HasBody)
                    _session.Response.ReadBody(_serverStream);
            }

            /// <summary>
            /// False for responses that never carry a body: replies to HEAD, 1xx, 204 and 304.
            /// </summary>
            private bool HasBody
            {
                get
                {
                    var code = _session.Response.StatusCode;
                    return !(_session.Request.Method == "HEAD" ||
                             code.StartsWith("1", StringComparison.Ordinal) || code == "204" || code == "304");
                }
            }

            private void SendResponse()
            {
                _clientStream.WriteLines(_session.Response.StatusLine + _session.Response.ModifiedHeaders)
                    .Write(_session.Response.Body);
            }

            /// <summary>
            /// Opens a connection to the upstream proxy (when enabled) or to the origin server.
            /// For direct connections an absolute-form request line is rewritten to origin-form.
            /// </summary>
            /// <exception cref="HttpProxyAbort">No destination could be determined, or the port is invalid.</exception>
            private Socket ConnectServer()
            {
                string host = null;
                var port = 80;
                if (IsEnableUpstreamProxy)
                {
                    // The upstream proxy needs the absolute URI, so the request line is left as is.
                    host = UpstreamProxyHost;
                    port = UpstreamProxyPort;
                }
                else if (!TryTakeAbsoluteTarget(ref host, ref port) &&
                         !ParseAuthority(_session.Request.Host, ref host, ref port))
                {
                    throw new HttpProxyAbort("Can't find destination host");
                }
                return Connect(host, port);
            }

            /// <summary>
            /// If the request line carries an absolute-form "http://" target, extracts host and port
            /// from it and rewrites the request line to origin-form ("/path?query").
            /// </summary>
            /// <returns>True when an absolute-form target was found.</returns>
            private bool TryTakeAbsoluteTarget(ref string host, ref int port)
            {
                var requestLine = _session.Request.RequestLine;
                var m = AbsoluteTargetRegex.Match(requestLine);
                if (!m.Success)
                    return false;
                host = m.Groups[2].Value;
                if (m.Groups[3].Success)
                    port = ParsePort(m.Groups[3].Value);
                var target = m.Groups[1];
                var rest = requestLine.Substring(target.Index + target.Length);
                // "http://host" or "http://host?q" has an empty path, which is sent as "/".
                var path = rest.StartsWith("/", StringComparison.Ordinal) ? "" : "/";
                _session.Request.RequestLine = requestLine.Substring(0, target.Index) + path + rest;
                return true;
            }

            /// <summary>Parses "host[:port]"; leaves <paramref name="port"/> unchanged when absent.</summary>
            private static bool ParseAuthority(string authority, ref string host, ref int port)
            {
                if (string.IsNullOrEmpty(authority))
                    return false;
                var m = AuthorityRegex.Match(authority.Trim());
                if (!m.Success)
                    return false;
                host = m.Groups[1].Value;
                if (m.Groups[2].Success)
                    port = ParsePort(m.Groups[2].Value);
                return true;
            }

            /// <summary>Parses a decimal TCP port, rejecting values that overflow or are out of range.</summary>
            private static int ParsePort(string s)
            {
                int port;
                if (!int.TryParse(s, NumberStyles.None, CultureInfo.InvariantCulture, out port) ||
                    port <= 0 || port > IPEndPoint.MaxPort)
                    throw new HttpProxyAbort("Invalid port number");
                return port;
            }

            private static Socket Connect(string host, int port)
            {
                var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
                try
                {
                    socket.Connect(host, port);
                    return socket;
                }
                catch
                {
                    // The socket is not yet owned by anyone; close it so a failed connect doesn't leak it.
                    socket.Close();
                    throw;
                }
            }

            private void Close()
            {
                _serverStream?.Close();
                _clientStream?.Close();
                _server?.Close();
                _client.Close();
            }
        }

        /// <summary>One request/response pair handled by the proxy.</summary>
        public class Session
        {
            public Request Request { get; set; } = new Request();
            public Response Response { get; set; } = new Response();
        }

        /// <summary>
        /// Header block and body shared by requests and responses. <see cref="Headers"/> holds the
        /// raw header lines (each ending in "\r\n") followed by the terminating blank "\r\n".
        /// </summary>
        public class Message
        {
            private static readonly Regex CharsetRegex = new Regex("charset=\"?([\\w-]+)",
                RegexOptions.Compiled | RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);

            private static readonly string[] ConnectionFields = {"connection", "keep-alive", "proxy-connection"};
            private static readonly string[] FramingFields = {"transfer-encoding", "content-length"};
            private static readonly string[] FramingAndCodingFields =
                {"transfer-encoding", "content-encoding", "content-length"};

            private string _headers;
            // Set by ReadBody: the body was read from the wire, so its original framing no longer applies.
            private bool _bodyRead;
            // Set when Body was decompressed, so Content-Encoding must not be forwarded.
            private bool _bodyDecoded;

            public byte[] Body { get; set; }

            public int ContentLength { get; set; } = -1;
            public string TransferEncoding { get; set; }
            public string ContentType { get; set; }
            public string ContentEncoding { get; set; }
            public string Host { get; set; }

            /// <summary>Raw header block; setting it refreshes the parsed header properties.</summary>
            public string Headers
            {
                get { return _headers; }
                set
                {
                    _headers = value;
                    SetHeaders(_headers);
                }
            }

            /// <summary>
            /// Headers to forward: connection-management fields replaced by "Connection: close" and,
            /// when a body is present, framing fields replaced by a Content-Length matching <see cref="Body"/>.
            /// </summary>
            public virtual string ModifiedHeaders =>
                HasBodyToFrame ? NormalizeBodyFraming(SetConnectionClose(Headers)) : SetConnectionClose(Headers);

            private bool HasBodyToFrame => _bodyRead || Body != null;

            protected string SetConnectionClose(string headers) =>
                InsertHeader(RemoveHeaders(headers, ConnectionFields), "Connection: close\r\n");

            /// <summary>
            /// Drops Transfer-Encoding/Content-Length (and Content-Encoding if the body was decoded)
            /// and, when there is a body, appends a Content-Length describing <see cref="Body"/> as forwarded.
            /// </summary>
            protected string NormalizeBodyFraming(string headers)
            {
                var stripped = RemoveHeaders(headers, _bodyDecoded ? FramingAndCodingFields : FramingFields);
                return HasBodyToFrame ? InsertHeader(stripped, $"Content-Length: {Body?.Length ?? 0}\r\n") : stripped;
            }

            /// <summary>Removes every occurrence of each named field (case-insensitive).</summary>
            protected string RemoveHeaders(string headers, string[] fields)
            {
                foreach (var field in fields)
                {
                    for (var m = MatchField(field, headers); m.Success; m = MatchField(field, headers))
                        headers = headers.Remove(m.Index, m.Length);
                }
                return headers;
            }

            /// <summary>Inserts a full header line just before the terminating blank "\r\n".</summary>
            protected string InsertHeader(string headers, string header)
            {
                return headers.Insert(headers.Length - 2, header);
            }

            protected virtual void SetHeaders(string headers)
            {
                var s = GetField("content-length");
                if (s != null)
                {
                    int len;
                    ContentLength = int.TryParse(s, out len) ? len : -1;
                }
                TransferEncoding = GetField("transfer-encoding")?.ToLower(CultureInfo.InvariantCulture);
                ContentType = GetField("content-type");
                ContentEncoding = GetField("content-encoding");
                Host = GetField("host");
            }

            /// <summary>Finds the first "name: value\r\n" line; group 1 is the value.</summary>
            protected Match MatchField(string name, string headers)
            {
                // Only spaces/tabs may precede the value: \s would also match CRLF and let an
                // empty-valued field swallow the following header line.
                var regex = new Regex("^" + Regex.Escape(name) + ":[ \t]*([^\r]+)\r\n",
                    RegexOptions.CultureInvariant | RegexOptions.IgnoreCase | RegexOptions.Multiline);
                return regex.Match(headers);
            }

            protected string GetField(string name)
            {
                var m = MatchField(name, Headers);
                return m.Success ? m.Groups[1].Value : null;
            }

            /// <summary>
            /// Body decoded with the Content-Type charset, or ASCII when the charset is absent or
            /// unknown. Returns "" when there is no body.
            /// </summary>
            public string BodyAsString
            {
                get
                {
                    if (Body == null)
                        return "";
                    var encoding = Encoding.ASCII;
                    var m = CharsetRegex.Match(ContentType ?? "");
                    if (m.Success)
                    {
                        try
                        {
                            encoding = Encoding.GetEncoding(m.Groups[1].Value);
                        }
                        catch (ArgumentException)
                        {
                            // Unknown charset name: keep the ASCII fallback instead of failing.
                        }
                    }
                    return encoding.GetString(Body);
                }
            }

            /// <summary>
            /// Reads the body using chunked coding, Content-Length, or read-to-close (in that order
            /// of preference), then decompresses gzip/deflate content.
            /// </summary>
            public void ReadBody(HttpStream stream)
            {
                _bodyRead = true;
                if (TransferEncoding != null && TransferEncoding.Contains("chunked"))
                {
                    Body = stream.ReadChunked();
                }
                else if (ContentLength > 0)
                {
                    var buf = new byte[ContentLength];
                    var n = stream.Read(buf, 0, buf.Length);
                    if (n < buf.Length)
                        Array.Resize(ref buf, Math.Max(n, 0)); // peer closed early; don't forward zero padding
                    Body = buf;
                }
                else if (ContentLength < 0)
                {
                    Body = stream.ReadToEnd();
                }
                // ContentLength == 0: there is no body, Body is left unchanged.
                DecodeBody();
            }

            /// <summary>
            /// Decompresses gzip/deflate bodies in place. Other codings, empty bodies and corrupt data
            /// are left untouched (and keep their Content-Encoding header when forwarded).
            /// </summary>
            private void DecodeBody()
            {
                if (Body == null || Body.Length == 0 || ContentEncoding == null)
                    return;
                var coding = ContentEncoding.Trim();
                Stream decoder;
                if (coding.Equals("gzip", StringComparison.OrdinalIgnoreCase) ||
                    coding.Equals("x-gzip", StringComparison.OrdinalIgnoreCase))
                {
                    decoder = new GZipStream(new MemoryStream(Body), CompressionMode.Decompress);
                }
                else if (coding.Equals("deflate", StringComparison.OrdinalIgnoreCase))
                {
                    // HTTP "deflate" is zlib-wrapped, but DeflateStream only reads raw deflate data;
                    // skip the 2-byte zlib header when present (raw deflate is passed through as is).
                    var skip = HasZlibHeader(Body) ? 2 : 0;
                    decoder = new DeflateStream(new MemoryStream(Body, skip, Body.Length - skip),
                        CompressionMode.Decompress);
                }
                else
                {
                    return;
                }
                try
                {
                    using (decoder)
                    using (var decoded = new MemoryStream())
                    {
                        decoder.CopyTo(decoded);
                        Body = decoded.ToArray();
                        _bodyDecoded = true;
                    }
                }
                catch (InvalidDataException)
                {
                    // Corrupt compressed data: forward the body exactly as received.
                }
            }

            /// <summary>True if the data starts with a valid zlib (RFC 1950) header for the deflate method.</summary>
            private static bool HasZlibHeader(byte[] data) =>
                data.Length >= 2 && (data[0] & 0x0F) == 8 && (data[0] >> 4) <= 7 &&
                ((data[0] << 8) | data[1]) % 31 == 0;
        }

        public class Request : Message
        {
            private string _requestLine;

            /// <summary>
            /// Raw request line including its line terminator; setting it parses
            /// <see cref="Method"/> and <see cref="PathAndQuery"/>.
            /// </summary>
            /// <exception cref="HttpProxyAbort">The line has fewer than three space-separated fields.</exception>
            public string RequestLine
            {
                get { return _requestLine; }
                set
                {
                    var fields = value.Split(' ');
                    if (fields.Length < 3)
                        throw new HttpProxyAbort("Invalid request line");
                    _requestLine = value;
                    Method = fields[0];
                    PathAndQuery = fields[1];
                }
            }

            public string Method { get; private set; }
            public string PathAndQuery { get; private set; }
        }

        public class Response : Message
        {
            private string _statusLine;

            /// <summary>
            /// Headers to forward. Framing fields are always normalized because the proxy relays the
            /// body it has buffered (de-chunked and possibly decompressed).
            /// </summary>
            public override string ModifiedHeaders => NormalizeBodyFraming(SetConnectionClose(Headers));

            /// <summary>Raw status line; setting it parses <see cref="StatusCode"/>.</summary>
            /// <exception cref="HttpProxyAbort">The line has fewer than three space-separated fields.</exception>
            public string StatusLine
            {
                get { return _statusLine; }
                set
                {
                    _statusLine = value;
                    var fields = _statusLine.Split(' ');
                    if (fields.Length < 3)
                        throw new HttpProxyAbort("Invalid status line");
                    StatusCode = fields[1];
                }
            }

            public string StatusCode { get; private set; }
        }

        /// <summary>Aborts the current exchange; caught and swallowed by the per-connection handler.</summary>
        private class HttpProxyAbort : Exception
        {
            public HttpProxyAbort(string message) : base(message)
            {
            }
        }

        /// <summary>Buffered line/byte reader and writer over a connected socket.</summary>
        public class HttpStream
        {
            // ReadLine maps every byte to the char with the same code point (ISO-8859-1), so header
            // text is written back with the same encoding to stay byte-transparent.
            private static readonly Encoding Latin1 = Encoding.GetEncoding(28591);

            private readonly Socket _socket;
            private readonly byte[] _buffer = new byte[4096];
            private int _available;
            private int _position;

            public HttpStream(Socket socket)
            {
                _socket = socket;
                socket.NoDelay = true;
            }

            /// <summary>
            /// Reads up to and including the next "\n". At end of stream returns what was read,
            /// which is "" when nothing was left.
            /// </summary>
            public string ReadLine()
            {
                var sb = new StringBuilder();
                int ch;
                while ((ch = ReadByte()) != -1)
                {
                    sb.Append((char)ch);
                    if (ch == '\n')
                        break;
                }
                return sb.ToString();
            }

            private int ReadByte()
            {
                if (_position < _available)
                    return _buffer[_position++];
                _available = _socket.Receive(_buffer, 0, _buffer.Length, SocketFlags.None);
                _position = 0;
                return _available == 0 ? -1 : _buffer[_position++];
            }

            public HttpStream WriteLines(string s)
            {
                var buf = Latin1.GetBytes(s);
                Write(buf, 0, buf.Length);
                return this;
            }

            /// <summary>Reads header lines up to and including the blank "\r\n" line.</summary>
            /// <exception cref="EndOfStreamException">The peer closed the connection first.</exception>
            public string ReadHeaders()
            {
                var sb = new StringBuilder();
                while (true)
                {
                    var line = ReadLine();
                    if (line.Length == 0)
                        throw new EndOfStreamException("Connection closed before the end of the header section");
                    sb.Append(line);
                    if (line == "\r\n")
                        return sb.ToString();
                }
            }

            /// <summary>
            /// Reads a chunked body and returns the concatenated chunk data. Stops early, returning
            /// what was read, on a malformed size line or a premature end of stream.
            /// </summary>
            public byte[] ReadChunked()
            {
                var body = new MemoryStream();
                while (true)
                {
                    int size;
                    if (!TryParseChunkSize(ReadLine(), out size))
                        break;
                    if (size == 0)
                    {
                        SkipTrailers();
                        break;
                    }
                    var chunk = new byte[size];
                    var n = Read(chunk, 0, size);
                    if (n <= 0)
                        break;
                    body.Write(chunk, 0, n);
                    if (n < size)
                        break; // connection closed in the middle of a chunk
                    ReadLine(); // CRLF terminating the chunk data
                }
                return body.ToArray();
            }

            /// <summary>Parses a hex chunk-size line, ignoring any ";extension" part.</summary>
            private static bool TryParseChunkSize(string line, out int size)
            {
                var text = line.TrimEnd('\r', '\n');
                var ext = text.IndexOf(';');
                if (ext >= 0)
                    text = text.Substring(0, ext);
                // AllowHexSpecifier maps "FFFFFFFF" to -1, so negative values are rejected explicitly.
                return int.TryParse(text.Trim(), NumberStyles.AllowHexSpecifier, CultureInfo.InvariantCulture,
                    out size) && size >= 0;
            }

            /// <summary>Consumes trailer fields after the last chunk, up to the blank line or end of stream.</summary>
            private void SkipTrailers()
            {
                string line;
                do
                {
                    line = ReadLine();
                } while (line.Length > 0 && line != "\r\n" && line != "\n");
            }

            public byte[] ReadToEnd()
            {
                var result = new MemoryStream();
                var buf = new byte[4096];
                int len;
                while ((len = Read(buf, 0, buf.Length)) > 0)
                    result.Write(buf, 0, len);
                return result.ToArray();
            }

            public HttpStream Write(byte[] body)
            {
                if (body != null)
                    Write(body, 0, body.Length);
                return this;
            }

            /// <summary>
            /// Reads up to <paramref name="count"/> bytes into <paramref name="buf"/> at
            /// <paramref name="offset"/>, draining the internal line buffer first.
            /// </summary>
            /// <returns>Bytes read; fewer than requested only at end of stream.</returns>
            public int Read(byte[] buf, int offset, int count)
            {
                try
                {
                    var total = 0;
                    do
                    {
                        int n;
                        if (_position < _available)
                        {
                            n = Math.Min(count, _available - _position);
                            Buffer.BlockCopy(_buffer, _position, buf, offset, n);
                            _position += n;
                        }
                        else
                        {
                            n = _socket.Receive(buf, offset, count, SocketFlags.None);
                            if (n == 0)
                                return total;
                        }
                        count -= n;
                        offset += n;
                        total += n;
                    } while (count > 0);
                    return total;
                }
                catch (IOException)
                {
                    return -1;
                }
            }

            public void Write(byte[] buf, int offset, int count)
            {
                try
                {
                    do
                    {
                        var n = _socket.Send(buf, offset, count, SocketFlags.None);
                        if (n == 0)
                            return;
                        count -= n;
                        offset += n;
                    } while (count > 0);
                }
                catch (IOException)
                {
                }
            }

            public HttpStream Close()
            {
                _socket.Close();
                return this;
            }
        }
    }
}