api: apply unflatten to input JSON (#66742)

It should help dumb clients to make API calls.
This commit is contained in:
Benjamin Dauvergne 2022-06-28 21:04:00 +02:00
parent 3a6355673e
commit 80c0e0fdd2
4 changed files with 204 additions and 0 deletions

View File

@ -309,6 +309,11 @@ MIGRATION_MODULES = {
# Django REST Framework
REST_FRAMEWORK = {
'NON_FIELD_ERRORS_KEY': '__all__',
'DEFAULT_PARSER_CLASSES': [
'authentic2.utils.rest_framework.UnflattenJSONParser',
'rest_framework.parsers.FormParser',
'rest_framework.parsers.MultiPartParser',
],
'DEFAULT_AUTHENTICATION_CLASSES': (
'authentic2.authentication.Authentic2Authentication',
'rest_framework.authentication.SessionAuthentication',

View File

@ -0,0 +1,97 @@
# authentic2 - versatile identity manager
# Copyright (C) 2022 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from rest_framework.parsers import JSONParser
_FLATTEN_SEPARATOR = '/'
def _is_number(string):
if hasattr(string, 'isdecimal'):
return string.isdecimal() and [ord(c) < 256 for c in string]
else: # str PY2
return string.isdigit()
def _unflatten(d, separator=_FLATTEN_SEPARATOR):
"""Transform:
{"a/b/0/x": "1234"}
into:
{"a": {"b": [{"x": "1234"}]}}
"""
if not isinstance(d, dict) or not d: # unflattening an empty dict has no sense
return d
# ok d is a dict
def map_digits(parts):
return [int(x) if _is_number(x) else x for x in parts]
keys = [(map_digits(key.split(separator)), key) for key in d]
keys.sort()
def set_path(path, orig_key, d, value, i=0):
assert path
key, tail = path[i], path[i + 1 :]
if not tail: # end of path, set thevalue
if isinstance(key, int):
assert isinstance(d, list)
if len(d) != key:
raise ValueError('incomplete array before %s' % orig_key)
d.append(value)
else:
assert isinstance(d, dict)
d[key] = value
else:
new = [] if isinstance(tail[0], int) else {}
if isinstance(key, int):
assert isinstance(d, list)
if len(d) < key:
raise ValueError(
'incomplete array before %s in %s'
% (separator.join(map(str, path[: i + 1])), orig_key)
)
if len(d) == key:
d.append(new)
else:
new = d[key]
else:
new = d.setdefault(key, new)
set_path(path, orig_key, new, value, i + 1)
# Is the first level an array or a dict ?
if isinstance(keys[0][0][0], int):
new = []
else:
new = {}
for path, key in keys:
value = d[key]
set_path(path, key, new, value)
return new
class UnflattenJSONParser(JSONParser):
def parse(self, *args, **kwargs):
result = super().parse(*args, **kwargs)
if isinstance(result, dict) and any('/' in key for key in result):
result = _unflatten(result)
return result

View File

@ -222,6 +222,15 @@ class TestViews:
assert resp.json == {'err': 0, 'data': []}
assert not set(roles.grandchild.children(include_self=False, direct=True))
def test_delete_unflatten(self, app, roles):
assert set(roles.parent.children(include_self=False, direct=True)) == {roles.child}
resp = app.delete_json(
'/api/roles/%s/relationships/parents/' % roles.grandchild.uuid,
params={'parent/uuid': roles.child.uuid},
)
assert resp.json == {'err': 0, 'data': []}
assert not set(roles.grandchild.children(include_self=False, direct=True))
class TestPermission:
@pytest.fixture
def user(self, simple_user):

View File

@ -0,0 +1,93 @@
# authentic2 - versatile identity manager
# Copyright (C) 2010-2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# authentic2
import io
import json
import pytest
from authentic2.utils.rest_framework import _FLATTEN_SEPARATOR as SEP
from authentic2.utils.rest_framework import UnflattenJSONParser
from authentic2.utils.rest_framework import _unflatten as unflatten
def test_unflatten_base():
assert unflatten('') == ''
assert unflatten('a') == 'a'
assert unflatten([]) == []
assert unflatten([1]) == [1]
assert unflatten({}) == {}
assert unflatten(0) == 0
assert unflatten(1) == 1
assert unflatten(False) is False
assert unflatten(True) is True
def test_unflatten_dict():
assert unflatten(
{
'a' + SEP + 'b' + SEP + '0': 1,
'a' + SEP + 'c' + SEP + '1': 'a',
'a' + SEP + 'b' + SEP + '1': True,
'a' + SEP + 'c' + SEP + '0': [1],
}
) == {
'a': {
'b': [1, True],
'c': [[1], 'a'],
}
}
def test_unflatten_array():
assert unflatten(
{
'0' + SEP + 'b' + SEP + '0': 1,
'1' + SEP + 'c' + SEP + '1': 'a',
'0' + SEP + 'b' + SEP + '1': True,
'1' + SEP + 'c' + SEP + '0': [1],
}
) == [{'b': [1, True]}, {'c': [[1], 'a']}]
def test_unflatten_missing_final_index():
with pytest.raises(ValueError) as exc_info:
unflatten({'1': 1})
assert 'incomplete' in exc_info.value.args[0]
def test_unflatten_missing_intermediate_index():
with pytest.raises(ValueError) as exc_info:
unflatten({'a' + SEP + '1' + SEP + 'b': 1})
assert 'incomplete' in exc_info.value.args[0]
class TestUnflattenJsonParser:
@pytest.fixture
def parser(self):
return UnflattenJSONParser()
def test_parse(self, parser):
in_json = {
'a/b/c': {'d/e': 1},
'b/0': 1,
'b/1': 2,
}
out_json = {'a': {'b': {'c': {'d/e': 1}}}, 'b': [1, 2]}
stream = io.BytesIO(json.dumps(in_json).encode())
assert parser.parse(stream) == out_json