OSDN Git Service

これでいいはずと思います
[pybbs/pybbs.git] / linebot / webhook.py
1 # -*- coding: utf-8 -*-
2
3 #  Licensed under the Apache License, Version 2.0 (the "License"); you may
4 #  not use this file except in compliance with the License. You may obtain
5 #  a copy of the License at
6 #
7 #       https://www.apache.org/licenses/LICENSE-2.0
8 #
9 #  Unless required by applicable law or agreed to in writing, software
10 #  distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 #  WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 #  License for the specific language governing permissions and limitations
13 #  under the License.
14
15 """linebot.http_client webhook."""
16
17 from __future__ import unicode_literals
18
19 import base64
20 import hashlib
21 import hmac
22 import inspect
23 import json
24
25 from .exceptions import InvalidSignatureError
26 from .models.events import (
27     MessageEvent,
28     FollowEvent,
29     UnfollowEvent,
30     JoinEvent,
31     LeaveEvent,
32     PostbackEvent,
33     BeaconEvent,
34     AccountLinkEvent,
35 )
36 from .utils import LOGGER, PY3, safe_compare_digest
37
38
39 if hasattr(hmac, "compare_digest"):
40     def compare_digest(val1, val2):
41         """compare_digest function.
42
43         If hmac module has compare_digest function, use it.
44         Or not, use linebot.utils.safe_compare_digest.
45
46         :param val1: string or bytes for compare
47         :type val1: str | bytes
48         :param val2: string or bytes for compare
49         :type val2: str | bytes
50         :rtype: bool
51         :return: result
52         """
53         return hmac.compare_digest(val1, val2)
54 else:
55     def compare_digest(val1, val2):
56         """compare_digest function.
57
58         If hmac module has compare_digest function, use it.
59         Or not, use linebot.utils.safe_compare_digest.
60
61         :param val1: string or bytes for compare
62         :type val1: str | bytes
63         :param val2: string or bytes for compare
64         :type val2: str | bytes
65         :rtype: bool
66         :return: result
67         """
68         return safe_compare_digest(val1, val2)
69
70
71 class SignatureValidator(object):
72     """Signature validator.
73
74     https://devdocs.line.me/en/#webhook-authentication
75     """
76
77     def __init__(self, channel_secret):
78         """__init__ method.
79
80         :param str channel_secret: Channel secret (as text)
81         """
82         self.channel_secret = channel_secret.encode('utf-8')
83
84     def validate(self, body, signature):
85         """Check signature.
86
87         https://devdocs.line.me/en/#webhook-authentication
88
89         :param str body: Request body (as text)
90         :param str signature: X-Line-Signature value (as text)
91         :rtype: bool
92         :return: result
93         """
94         gen_signature = hmac.new(
95             self.channel_secret,
96             body.encode('utf-8'),
97             hashlib.sha256
98         ).digest()
99
100         return compare_digest(
101                 signature.encode('utf-8'), base64.b64encode(gen_signature)
102         )
103
104
105 class WebhookParser(object):
106     """Webhook Parser."""
107
108     def __init__(self, channel_secret):
109         """__init__ method.
110
111         :param str channel_secret: Channel secret (as text)
112         """
113         self.signature_validator = SignatureValidator(channel_secret)
114
115     def parse(self, body, signature):
116         """Parse webhook request body as text.
117
118         :param str body: Webhook request body (as text)
119         :param str signature: X-Line-Signature value (as text)
120         :rtype: list[T <= :py:class:`linebot.models.events.Event`]
121         :return:
122         """
123         if not self.signature_validator.validate(body, signature):
124             raise InvalidSignatureError(
125                 'Invalid signature. signature=' + signature)
126
127         body_json = json.loads(body)
128         events = []
129         for event in body_json['events']:
130             event_type = event['type']
131             if event_type == 'message':
132                 events.append(MessageEvent.new_from_json_dict(event))
133             elif event_type == 'follow':
134                 events.append(FollowEvent.new_from_json_dict(event))
135             elif event_type == 'unfollow':
136                 events.append(UnfollowEvent.new_from_json_dict(event))
137             elif event_type == 'join':
138                 events.append(JoinEvent.new_from_json_dict(event))
139             elif event_type == 'leave':
140                 events.append(LeaveEvent.new_from_json_dict(event))
141             elif event_type == 'postback':
142                 events.append(PostbackEvent.new_from_json_dict(event))
143             elif event_type == 'beacon':
144                 events.append(BeaconEvent.new_from_json_dict(event))
145             elif event_type == 'accountLink':
146                 events.append(AccountLinkEvent.new_from_json_dict(event))
147             else:
148                 LOGGER.warn('Unknown event type. type=' + event_type)
149
150         return events
151
152
153 class WebhookHandler(object):
154     """Webhook Handler."""
155
156     def __init__(self, channel_secret):
157         """__init__ method.
158
159         :param str channel_secret: Channel secret (as text)
160         """
161         self.parser = WebhookParser(channel_secret)
162         self._handlers = {}
163         self._default = None
164
165     def add(self, event, message=None):
166         """[Decorator] Add handler method.
167
168         :param event: Specify a kind of Event which you want to handle
169         :type event: T <= :py:class:`linebot.models.events.Event` class
170         :param message: (optional) If event is MessageEvent,
171             specify kind of Messages which you want to handle
172         :type: message: T <= :py:class:`linebot.models.messages.Message` class
173         :rtype: func
174         :return: decorator
175         """
176         def decorator(func):
177             if isinstance(message, (list, tuple)):
178                 for it in message:
179                     self.__add_handler(func, event, message=it)
180             else:
181                 self.__add_handler(func, event, message=message)
182
183             return func
184
185         return decorator
186
187     def default(self):
188         """[Decorator] Set default handler method.
189
190         :rtype: func
191         :return:
192         """
193         def decorator(func):
194             self._default = func
195             return func
196
197         return decorator
198
199     def handle(self, body, signature):
200         """Handle webhook.
201
202         :param str body: Webhook request body (as text)
203         :param str signature: X-Line-Signature value (as text)
204         """
205         events = self.parser.parse(body, signature)
206
207         for event in events:
208             func = None
209             key = None
210
211             if isinstance(event, MessageEvent):
212                 key = self.__get_handler_key(
213                     event.__class__, event.message.__class__)
214                 func = self._handlers.get(key, None)
215
216             if func is None:
217                 key = self.__get_handler_key(event.__class__)
218                 func = self._handlers.get(key, None)
219
220             if func is None:
221                 func = self._default
222
223             if func is None:
224                 LOGGER.info('No handler of ' + key + ' and no default handler')
225             else:
226                 args_count = self.__get_args_count(func)
227                 if args_count == 0:
228                     func()
229                 else:
230                     func(event)
231
232     def __add_handler(self, func, event, message=None):
233         key = self.__get_handler_key(event, message=message)
234         self._handlers[key] = func
235
236     @staticmethod
237     def __get_args_count(func):
238         if PY3:
239             arg_spec = inspect.getfullargspec(func)
240             return len(arg_spec.args)
241         else:
242             arg_spec = inspect.getargspec(func)
243             return len(arg_spec.args)
244
245     @staticmethod
246     def __get_handler_key(event, message=None):
247         if message is None:
248             return event.__name__
249         else:
250             return event.__name__ + '_' + message.__name__