utils: add a safe_get_or_create() primitive (#52929)
This commit is contained in:
parent
03a874be58
commit
c514c67927
|
@ -0,0 +1,57 @@
|
|||
# authentic2 - versatile identity manager
|
||||
# Copyright (C) 2010-2021 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import connection
|
||||
|
||||
|
||||
def poisson_random(frequency):
|
||||
'''Generate random numbers following a poisson distribution'''
|
||||
return -math.log(1.0 - random.random()) / frequency
|
||||
|
||||
|
||||
SAFE_GET_OR_CREATE_RETRIES = 3
|
||||
|
||||
|
||||
class ConcurrencyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def safe_get_or_create(model, defaults=None, **kwargs):
|
||||
assert (
|
||||
getattr(settings, 'TESTING', False) or not connection.in_atomic_block
|
||||
), 'safe_get_or_create cannot be used in inside a transaction'
|
||||
|
||||
defaults = defaults or {}
|
||||
exception = None
|
||||
for dummy in range(SAFE_GET_OR_CREATE_RETRIES):
|
||||
try:
|
||||
instance, created = model.objects.get_or_create(defaults=defaults, **kwargs)
|
||||
except model.MultipleObjectsReturned as e:
|
||||
exception = e
|
||||
time.sleep(max(poisson_random(1), 0.5))
|
||||
continue
|
||||
|
||||
if created and model.objects.filter(**kwargs).exclude(pk=instance.pk).exists():
|
||||
instance.delete()
|
||||
time.sleep(max(poisson_random(1), 0.5))
|
||||
continue
|
||||
return instance, created
|
||||
raise exception
|
|
@ -57,3 +57,5 @@ A2_MAX_EMAILS_FOR_ADDRESS = None
|
|||
|
||||
A2_TOKEN_EXISTS_WARNING = False
|
||||
A2_REDIRECT_WHITELIST = ['http://sp.org/']
|
||||
|
||||
TESTING = True
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# authentic2 - versatile identity manager
|
||||
# Copyright (C) 2010-2021 Entr'ouvert
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Affero General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
# authentic2
|
||||
|
||||
import threading
|
||||
|
||||
from django.db import connection
|
||||
|
||||
from authentic2.custom_user.models import User
|
||||
from authentic2.utils.models import safe_get_or_create
|
||||
|
||||
|
||||
def test_safe_get_or_create(transactional_db, concurrency):
|
||||
EMAIL = 'john.doe@example.net'
|
||||
barrier = threading.Barrier(concurrency)
|
||||
users = []
|
||||
exceptions = []
|
||||
threads = []
|
||||
|
||||
def thread_run():
|
||||
try:
|
||||
barrier.wait()
|
||||
user, created = safe_get_or_create(User, email=EMAIL, defaults={'email': EMAIL})
|
||||
except Exception as e:
|
||||
exceptions.append(e)
|
||||
else:
|
||||
if created:
|
||||
users.append(user)
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
for _ in range(concurrency):
|
||||
threads.append(threading.Thread(target=thread_run))
|
||||
threads[-1].start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
assert len(users) == 1
|
||||
assert User.objects.count() == 1
|
||||
assert len(exceptions) == 0
|
||||
users[0].delete()
|
Loading…
Reference in New Issue