From c514c6792777c95102b0a871c816987ceb24cc1b Mon Sep 17 00:00:00 2001 From: Benjamin Dauvergne Date: Sat, 17 Apr 2021 17:53:44 +0200 Subject: [PATCH] utils: add a safe_get_or_create() primitive (#52929) --- src/authentic2/utils/models.py | 57 ++++++++++++++++++++++++++++++++++ tests/settings.py | 2 ++ tests/test_utils_models.py | 53 +++++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+) create mode 100644 src/authentic2/utils/models.py create mode 100644 tests/test_utils_models.py diff --git a/src/authentic2/utils/models.py b/src/authentic2/utils/models.py new file mode 100644 index 000000000..131c3a233 --- /dev/null +++ b/src/authentic2/utils/models.py @@ -0,0 +1,57 @@ +# authentic2 - versatile identity manager +# Copyright (C) 2010-2021 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + +import math +import random +import time + +from django.conf import settings +from django.db import connection + + +def poisson_random(frequency): + '''Generate random numbers following a poisson distribution''' + return -math.log(1.0 - random.random()) / frequency + + +SAFE_GET_OR_CREATE_RETRIES = 3 + + +class ConcurrencyError(Exception): + pass + + +def safe_get_or_create(model, defaults=None, **kwargs): + assert ( + getattr(settings, 'TESTING', False) or not connection.in_atomic_block + ), 'safe_get_or_create cannot be used in inside a transaction' + + defaults = defaults or {} + exception = None + for dummy in range(SAFE_GET_OR_CREATE_RETRIES): + try: + instance, created = model.objects.get_or_create(defaults=defaults, **kwargs) + except model.MultipleObjectsReturned as e: + exception = e + time.sleep(max(poisson_random(1), 0.5)) + continue + + if created and model.objects.filter(**kwargs).exclude(pk=instance.pk).exists(): + instance.delete() + time.sleep(max(poisson_random(1), 0.5)) + continue + return instance, created + raise exception diff --git a/tests/settings.py b/tests/settings.py index b3cd2fcea..cd0a33769 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -57,3 +57,5 @@ A2_MAX_EMAILS_FOR_ADDRESS = None A2_TOKEN_EXISTS_WARNING = False A2_REDIRECT_WHITELIST = ['http://sp.org/'] + +TESTING = True diff --git a/tests/test_utils_models.py b/tests/test_utils_models.py new file mode 100644 index 000000000..6d61b7abd --- /dev/null +++ b/tests/test_utils_models.py @@ -0,0 +1,53 @@ +# authentic2 - versatile identity manager +# Copyright (C) 2010-2021 Entr'ouvert +# +# This program is free software: you can redistribute it and/or modify it +# under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +# authentic2 + +import threading + +from django.db import connection + +from authentic2.custom_user.models import User +from authentic2.utils.models import safe_get_or_create + + +def test_safe_get_or_create(transactional_db, concurrency): + EMAIL = 'john.doe@example.net' + barrier = threading.Barrier(concurrency) + users = [] + exceptions = [] + threads = [] + + def thread_run(): + try: + barrier.wait() + user, created = safe_get_or_create(User, email=EMAIL, defaults={'email': EMAIL}) + except Exception as e: + exceptions.append(e) + else: + if created: + users.append(user) + finally: + connection.close() + + for _ in range(concurrency): + threads.append(threading.Thread(target=thread_run)) + threads[-1].start() + for thread in threads: + thread.join() + assert len(users) == 1 + assert User.objects.count() == 1 + assert len(exceptions) == 0 + users[0].delete()