utils: add dumps/loads for confidentiality protected tokens (#61130)

This commit is contained in:
Benjamin Dauvergne 2022-01-26 16:48:20 +01:00
parent 2d93d95fc5
commit 0795cbeb89
2 changed files with 70 additions and 7 deletions

View File

@ -25,6 +25,8 @@ from Cryptodome.Cipher import AES
from Cryptodome.Hash import HMAC, SHA256
from Cryptodome.Protocol.KDF import PBKDF2
from django.conf import settings
from django.core import signing
from django.core.signing import BadSignature, SignatureExpired # pylint: disable=unused-import
from django.utils.crypto import constant_time_compare
from django.utils.encoding import force_bytes
@ -50,7 +52,7 @@ def get_hashclass(name):
return None
def aes_base64_encrypt(key, data):
def aes_base64_encrypt(key, data, urlsafe=False, sep=b'$'):
"""Generate an AES key from any key material using PBKDF2, and encrypt data using CFB mode. A
new IV is generated each time, the IV is also used as salt for PBKDF2.
"""
@ -58,10 +60,13 @@ def aes_base64_encrypt(key, data):
aes_key = PBKDF2(key, iv)
aes = AES.new(aes_key, AES.MODE_CFB, iv=iv)
crypted = aes.encrypt(data)
return b'%s$%s' % (base64.b64encode(iv), base64.b64encode(crypted))
if urlsafe:
return b'%s%s%s' % (base64url_encode(iv), sep, base64url_encode(crypted))
else:
return b'%s%s%s' % (base64.b64encode(iv), sep, base64.b64encode(crypted))
def aes_base64_decrypt(key, payload, raise_on_error=True):
def aes_base64_decrypt(key, payload, raise_on_error=True, urlsafe=False, sep=b'$'):
'''Decrypt data encrypted with aes_base64_encrypt'''
if not isinstance(payload, bytes):
try:
@ -69,14 +74,20 @@ def aes_base64_decrypt(key, payload, raise_on_error=True):
except Exception:
raise DecryptionError('payload is not an ASCII string')
try:
iv, crypted = payload.split(b'$')
iv, crypted = payload.split(sep)
except (ValueError, TypeError):
if raise_on_error:
raise DecryptionError('bad payload')
return None
if urlsafe:
decode = base64url_decode
else:
decode = base64.b64decode
try:
iv = base64.b64decode(iv)
crypted = base64.b64decode(crypted)
iv = decode(iv)
crypted = decode(crypted)
except Base64Error:
if raise_on_error:
raise DecryptionError('incorrect base64 encoding')
@ -221,3 +232,25 @@ def hash_chain(n, seed=None, encoded_seed=None):
for dummy in range(n - 1):
chain.append(hashlib.sha256(chain[-1] + settings.SECRET_KEY.encode()).digest())
return [base64url_encode(x).decode('ascii') for x in chain]
def dumps(obj, key=None, **kwargs):
if not key:
key = settings.SECRET_KEY
return aes_base64_encrypt(
key.encode(), signing.dumps(obj, key=key, **kwargs).encode(), urlsafe=True, sep=b':'
).decode()
def loads(s, key=None, **kwargs):
if not key:
key = settings.SECRET_KEY
try:
decrypted = aes_base64_decrypt(key.encode(), s.encode(), urlsafe=True, sep=b':')
except DecryptionError:
return signing.loads(s, key=key, **kwargs)
try:
decrypted = decrypted.decode()
except UnicodeDecodeError:
raise BadSignature
return signing.loads(decrypted, key=key, **kwargs)

View File

@ -14,13 +14,14 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import datetime
import random
import uuid
import pytest
from django.utils.encoding import force_bytes
from authentic2 import crypto
from authentic2.utils import crypto
key = b'1234'
@ -72,3 +73,32 @@ def test_hmac_url():
key = 'é'
url = 'https://example.invalid/\u0000'
assert crypto.check_hmac_url(key, url, crypto.hmac_url(key, url))
def test_dumps_loads(settings, freezer):
data = {'a': 1, 'b': 'foo', 'bar': 'zib@!$#$#$#$#'}
token = crypto.dumps(data)
assert token.encode('ascii')
assert crypto.loads(token) == data
settings.SECRET_KEY = 'bb'
with pytest.raises(crypto.BadSignature):
assert crypto.loads(token)
token = crypto.dumps(data, key='aa')
with pytest.raises(crypto.BadSignature):
assert crypto.loads(token)
assert crypto.loads(token, key='aa') == data
freezer.move_to(datetime.timedelta(seconds=100))
with pytest.raises(crypto.SignatureExpired):
crypto.loads(token, key='aa', max_age=10)
assert crypto.loads(token, key='aa') == data
def test_dumps_loads_retrocompatibility():
from django.core import signing
data = {'a': 1, 'b': 'foo', 'bar': 'zib@!$#$#$#$#'}
token = signing.dumps(data)
assert crypto.loads(token) == data