utils: use an exclusive lock on model's table in safe_get_or_create (#60658)

This commit is contained in:
Benjamin Dauvergne 2022-01-14 10:08:47 +01:00
parent 4f96d751c0
commit 20a8b32ee6
2 changed files with 10 additions and 36 deletions

View File

@ -14,24 +14,8 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import math
import random
import time
from django.conf import settings
from django.db import connection
def poisson_random(frequency):
'''Generate random numbers following a poisson distribution'''
return -math.log(1.0 - random.random()) / frequency
SAFE_GET_OR_CREATE_RETRIES = 3
class ConcurrencyError(Exception):
pass
from django.db import connection, transaction
def safe_get_or_create(model, defaults=None, **kwargs):
@ -39,19 +23,11 @@ def safe_get_or_create(model, defaults=None, **kwargs):
getattr(settings, 'TESTING', False) or not connection.in_atomic_block
), 'safe_get_or_create cannot be used in inside a transaction'
defaults = defaults or {}
exception = None
for dummy in range(SAFE_GET_OR_CREATE_RETRIES):
try:
instance, created = model.objects.get_or_create(defaults=defaults, **kwargs)
except model.MultipleObjectsReturned as e:
exception = e
time.sleep(max(poisson_random(1), 0.5))
continue
if created and model.objects.filter(**kwargs).exclude(pk=instance.pk).exists():
instance.delete()
time.sleep(max(poisson_random(1), 0.5))
continue
return instance, created
raise exception
try:
return model.objects.get(**kwargs), False
except model.DoesNotExist:
pass
with transaction.atomic():
with connection.cursor() as cur:
cur.execute('LOCK TABLE "%s" IN EXCLUSIVE MODE' % model._meta.db_table)
return model.objects.get_or_create(defaults=defaults, **kwargs)

View File

@ -17,7 +17,6 @@
import threading
from django.core.exceptions import MultipleObjectsReturned
from django.db import connection
from authentic2.custom_user.models import User
@ -48,8 +47,7 @@ def test_safe_get_or_create(transactional_db, concurrency):
threads[-1].start()
for thread in threads:
thread.join()
assert not exceptions
assert len(users) == 1
assert User.objects.count() == 1
assert all(isinstance(exception, MultipleObjectsReturned) for exception in exceptions)
assert len(exceptions) < (0.5 * concurrency) # 50% of failure is 'ok-ish' with a lot of concurrency
users[0].delete()