diff --git a/pywebpush/__init__.py b/pywebpush/__init__.py index d834100..3312745 100644 --- a/pywebpush/__init__.py +++ b/pywebpush/__init__.py @@ -95,7 +95,7 @@ class WebPusher: "aes128gcm" # draft-httpbis-encryption-encoding-04 ] - def __init__(self, subscription_info): + def __init__(self, subscription_info, requests_session=None): """Initialize using the info provided by the client PushSubscription object (See https://developer.mozilla.org/en-US/docs/Web/API/PushManager/subscribe) @@ -104,7 +104,16 @@ class WebPusher: the client. :type subscription_info: dict + :param requests_session: a requests.Session object to optimize requests + to the same client. + :type requests_session: requests.Session + """ + if requests_session is None: + self.requests_method = requests + else: + self.requests_method = requests_session + if 'endpoint' not in subscription_info: raise WebPushException("subscription_info missing endpoint URL") self.subscription_info = subscription_info @@ -285,10 +294,10 @@ class WebPusher: # Authorization / Crypto-Key (VAPID headers) if curl: return self.as_curl(endpoint, encoded_data, headers) - return requests.post(endpoint, - data=encoded_data, - headers=headers, - timeout=timeout) + return self.requests_method.post(endpoint, + data=encoded_data, + headers=headers, + timeout=timeout) def webpush(subscription_info, diff --git a/pywebpush/tests/test_webpush.py b/pywebpush/tests/test_webpush.py index e70f681..136f749 100644 --- a/pywebpush/tests/test_webpush.py +++ b/pywebpush/tests/test_webpush.py @@ -318,3 +318,21 @@ class WebpushTestCase(unittest.TestCase): eq_(mock_post.call_args[1].get('timeout'), 5.2) webpush(subscription_info, timeout=10.001) eq_(mock_post.call_args[1].get('timeout'), 10.001) + + @patch("requests.Session") + def test_send_using_requests_session(self, mock_session): + subscription_info = self._gen_subscription_info() + headers = {"Crypto-Key": "pre-existing", + "Authentication": "bearer vapid"} + data = "Mary had a little lamb" + WebPusher(subscription_info, + requests_session=mock_session).send(data, headers) + eq_(subscription_info.get('endpoint'), + mock_session.post.call_args[0][0]) + pheaders = mock_session.post.call_args[1].get('headers') + eq_(pheaders.get('ttl'), '0') + ok_('encryption' in pheaders) + eq_(pheaders.get('AUTHENTICATION'), headers.get('Authentication')) + ckey = pheaders.get('crypto-key') + ok_('pre-existing' in ckey) + eq_(pheaders.get('content-encoding'), 'aesgcm')