misc: add lock model (#64485)

This commit is contained in:
Benjamin Dauvergne 2022-04-07 23:03:50 +02:00
parent 873ebb3c7a
commit e555ca5a0a
3 changed files with 179 additions and 2 deletions

View File

@ -0,0 +1,24 @@
# Generated by Django 2.2.27 on 2022-04-07 20:24
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authentic2', '0040_add_external_guid'),
]
operations = [
migrations.CreateModel(
name='Lock',
fields=[
('created', models.DateTimeField(auto_now_add=True, verbose_name='Creation date')),
('name', models.TextField(primary_key=True, serialize=False, verbose_name='Name')),
],
options={
'verbose_name': 'Lock',
'verbose_name_plural': 'Lock',
},
),
]

View File

@ -28,7 +28,7 @@ from django.contrib.postgres.fields import jsonb
from django.contrib.postgres.indexes import GinIndex
from django.contrib.postgres.search import SearchVectorField
from django.core.exceptions import ValidationError
from django.db import models, transaction
from django.db import DatabaseError, models, transaction
from django.db.models.query import Q
from django.urls import reverse
from django.utils import timezone
@ -559,3 +559,43 @@ class Token(models.Model):
def cleanup(cls, now=None):
now = now or timezone.now()
cls.objects.filter(expires__lte=now).delete()
class Lock(models.Model):
created = models.DateTimeField(auto_now_add=True, verbose_name=_('Creation date'))
name = models.TextField(verbose_name=_('Name'), primary_key=True)
class Error(Exception):
pass
@classmethod
def cleanup(cls, now=None, age=None):
age = age if age is not None else datetime.timedelta(hours=1)
now = now or timezone.now()
with transaction.atomic(savepoint=False):
pks = (
cls.objects.filter(created__lte=now - age)
.select_for_update(skip_locked=True)
.values_list('pk', flat=True)
)
cls.objects.filter(pk__in=pks).delete()
@classmethod
def lock(cls, *args, nowait=False):
# force ordering to prevent deadlock
names = sorted(args)
for name in names:
dummy, dummy = cls.objects.get_or_create(name=name)
try:
cls.objects.select_for_update(nowait=nowait).get(name=name)
except transaction.TransactionManagementError:
raise
except DatabaseError:
# happen only if nowait=True, in this case the error must be
# intercepted with "except Lock.Error:", this error is
# recoverable (i.e. the transaction can continue after)
raise cls.Error
class Meta:
verbose_name = _('Lock')
verbose_name_plural = _('Lock')

View File

@ -14,12 +14,15 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import datetime
import threading
import pytest
from django.db import connection, transaction
from authentic2.a2_rbac.models import Role
from authentic2.custom_user.models import User
from authentic2.models import Attribute, Service
from authentic2.models import Attribute, Lock, Service
from authentic2.utils.misc import ServiceAccessDenied
@ -52,3 +55,113 @@ def test_service_authorize(db):
user.is_superuser = True
user.save()
assert service.authorize(user)
class TestLock:
def test_wait(self, transactional_db):
with pytest.raises(transaction.TransactionManagementError):
Lock.lock('a')
with transaction.atomic():
Lock.lock('a')
count = 50
l1 = ['a'] * count
l2 = []
def f():
import random
for _ in range(count):
with transaction.atomic():
locks = ['a', 'b']
random.shuffle(locks)
Lock.lock(*locks)
if l1:
l2.append(l1.pop())
l2.append('x')
connection.close()
thread1 = threading.Thread(target=f)
thread2 = threading.Thread(target=f)
thread1.start()
thread2.start()
thread1.join()
thread2.join()
assert len(l1) == 0
assert len(l2) == count + 2
def test_nowait(self, transactional_db):
barrier1 = threading.Barrier(2)
barrier2 = threading.Barrier(2)
# prevent contention on the unique index lock
Lock.objects.create(name='lock-name')
def locker():
try:
with transaction.atomic():
Lock.lock('lock-name')
barrier1.wait()
barrier2.wait()
finally:
connection.close()
exception = None
def locker_nowait():
nonlocal exception
try:
with transaction.atomic():
barrier1.wait()
try:
Lock.lock('lock-name', nowait=True)
except Lock.Error as e:
exception = e
finally:
barrier2.wait()
connection.close()
locker_thread = threading.Thread(target=locker)
locker_nowait_thread = threading.Thread(target=locker_nowait)
locker_thread.start()
locker_nowait_thread.start()
locker_thread.join()
locker_nowait_thread.join()
assert exception is not None
def test_clean(self, transactional_db):
import time
import uuid
count = 0
def take_locks():
nonlocal count
for _ in range(100):
with transaction.atomic():
name = str(uuid.uuid4())
Lock.lock(name)
time.sleep(0.01)
assert Lock.objects.get(name=name)
count += 1
connection.close()
thread1 = threading.Thread(target=take_locks)
thread1.start()
def clean():
while thread1.is_alive():
time.sleep(0.001)
Lock.cleanup(age=datetime.timedelta(seconds=0))
connection.close()
thread2 = threading.Thread(target=clean)
thread2.start()
thread1.join()
thread2.join()
Lock.cleanup(age=datetime.timedelta(seconds=0))
assert Lock.objects.count() == 0
assert count == 100