100 lines
3.2 KiB
Python
100 lines
3.2 KiB
Python
# passerelle - uniform access to multiple data sources and services
|
|
# Copyright (C) 2018 Entr'ouvert
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Affero General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
import base64
|
|
import hashlib
|
|
import hmac
|
|
import time
|
|
from urllib import parse as urlparse
|
|
from uuid import uuid4
|
|
|
|
from django.utils.encoding import force_bytes, force_str
|
|
from requests.auth import AuthBase
|
|
|
|
|
|
class HawkAuth(AuthBase):
    """``requests`` authentication handler implementing Hawk request signing.

    When applied to a request, it computes a Hawk payload hash. It then
    computes an HMAC over the normalized request string (timestamp, nonce,
    method, URI, host, port, payload hash, ext) and stores the result in an
    ``Authorization: Hawk ...`` header.
    """

    # Port signed when the URL carries no explicit port, keyed by URL scheme.
    _DEFAULT_PORTS = {'http': '80', 'https': '443'}

    def __init__(self, id, key, algorithm='sha256', ext=''):
        """Store the Hawk credentials and signing parameters.

        :param id: Hawk key identifier (str); sent in clear in the header.
        :param key: shared secret (str) used as the HMAC key.
        :param algorithm: hash algorithm name accepted by ``hashlib.new``
            (e.g. ``'sha256'``); used for both the payload hash and the MAC.
        :param ext: optional application data included in the MAC and, when
            non-empty, appended to the header as ``ext="..."``.
        """
        self.id = id.encode('utf-8')
        self.key = key.encode('utf-8')
        self.algorithm = algorithm
        # NOTE(review): timestamp and nonce are generated once per instance,
        # so every request signed by the same instance reuses them. Hawk uses
        # the nonce for replay protection, so a server may reject the second
        # request; creating one HawkAuth per request avoids this. TODO confirm
        # callers never reuse an instance.
        self.timestamp = str(int(time.time()))
        self.nonce = uuid4().hex
        self.ext = ext

    def get_payload_hash(self, req):
        """Return the base64-encoded Hawk payload hash of ``req``.

        The hashed input is the literal ``hawk.1.payload``, the content type
        and the request body, each followed by a newline. The Hawk
        specification requires the content type to be normalized: any
        parameters (such as ``; charset=utf-8``) are removed and it is
        lowercased.
        A missing body is hashed as an empty string.
        """
        # force_str tolerates a bytes header value; ``or ''`` tolerates None.
        content_type = force_str(req.headers.get('Content-Type') or '')
        content_type = content_type.split(';')[0].strip().lower()
        p_hash = hashlib.new(self.algorithm)
        p_hash.update(b'hawk.1.payload\n')
        p_hash.update(force_bytes(content_type + '\n'))
        p_hash.update(force_bytes(req.body or ''))
        p_hash.update(b'\n')
        return force_str(base64.b64encode(p_hash.digest()))

    @staticmethod
    def _get_port(url_parts):
        """Return the port to sign for a parsed URL, as a string.

        :raises ValueError: if the URL has no explicit port and its scheme has
            no known default port.
        """
        if url_parts.port is not None:
            return str(url_parts.port)
        try:
            return HawkAuth._DEFAULT_PORTS[url_parts.scheme]
        except KeyError:
            raise ValueError('cannot determine port for URL scheme %r' % url_parts.scheme) from None

    def get_authorization_header(self, req):
        """Build the Hawk ``Authorization`` header value for ``req``.

        The MAC is computed over the normalized string ``hawk.1.header``,
        timestamp, nonce, upper-cased method, path (plus query), host, port,
        payload hash and ext, each followed by a newline.

        :raises ValueError: if no port can be determined for the URL.
        """
        url_parts = urlparse.urlparse(req.url)
        uri = url_parts.path
        if url_parts.query:
            uri += '?' + url_parts.query
        port = self._get_port(url_parts)
        payload_hash = self.get_payload_hash(req)
        normalized = '\n'.join(
            [
                'hawk.1.header',
                self.timestamp,
                self.nonce,
                req.method.upper(),
                uri,
                url_parts.hostname,
                port,
                payload_hash,
                self.ext,
                '',  # the normalized string must end with a newline
            ]
        )
        # hmac accepts the algorithm *name*, matching hashlib.new() above.
        result = hmac.new(force_bytes(self.key), force_bytes(normalized), self.algorithm)
        mac = force_str(base64.b64encode(result.digest()))
        authorization = 'Hawk id="%s", ts="%s", nonce="%s", hash="%s", mac="%s"' % (
            force_str(self.id),
            self.timestamp,
            self.nonce,
            payload_hash,
            mac,
        )
        if self.ext:
            authorization += ', ext="%s"' % self.ext
        return authorization

    def __call__(self, r):
        """Set the Hawk ``Authorization`` header on request ``r`` and return it."""
        r.headers['Authorization'] = self.get_authorization_header(r)
        return r
|
|
|
|
|
|
class HttpBearerAuth(AuthBase):
    """``requests`` authentication handler that sends a bearer token.

    Every request it is applied to gets an ``Authorization: Bearer <token>``
    header. Two handlers compare equal when their ``token`` values match.
    """

    def __init__(self, token):
        """:param token: token string appended after ``Bearer `` in the header."""
        self.token = token

    def __eq__(self, other):
        """Equal to any object exposing a ``token`` attribute with the same value."""
        mine = self.token
        theirs = getattr(other, 'token', None)
        return mine == theirs

    def __ne__(self, other):
        """Negation of ``__eq__``."""
        equal = self == other
        return not equal

    def __call__(self, r):
        """Attach the bearer ``Authorization`` header to request ``r`` and return it."""
        header_value = 'Bearer ' + self.token
        r.headers['Authorization'] = header_value
        return r
|