misc: add lock model (#64485)
This commit is contained in:
parent
873ebb3c7a
commit
e555ca5a0a
|
@ -0,0 +1,24 @@
|
|||
# Generated by Django 2.2.27 on 2022-04-07 20:24
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('authentic2', '0040_add_external_guid'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name='Lock',
|
||||
fields=[
|
||||
('created', models.DateTimeField(auto_now_add=True, verbose_name='Creation date')),
|
||||
('name', models.TextField(primary_key=True, serialize=False, verbose_name='Name')),
|
||||
],
|
||||
options={
|
||||
'verbose_name': 'Lock',
|
||||
'verbose_name_plural': 'Lock',
|
||||
},
|
||||
),
|
||||
]
|
|
@ -28,7 +28,7 @@ from django.contrib.postgres.fields import jsonb
|
|||
from django.contrib.postgres.indexes import GinIndex
|
||||
from django.contrib.postgres.search import SearchVectorField
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models, transaction
|
||||
from django.db import DatabaseError, models, transaction
|
||||
from django.db.models.query import Q
|
||||
from django.urls import reverse
|
||||
from django.utils import timezone
|
||||
|
@ -559,3 +559,43 @@ class Token(models.Model):
|
|||
def cleanup(cls, now=None):
|
||||
now = now or timezone.now()
|
||||
cls.objects.filter(expires__lte=now).delete()
|
||||
|
||||
|
||||
class Lock(models.Model):
|
||||
created = models.DateTimeField(auto_now_add=True, verbose_name=_('Creation date'))
|
||||
name = models.TextField(verbose_name=_('Name'), primary_key=True)
|
||||
|
||||
class Error(Exception):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def cleanup(cls, now=None, age=None):
|
||||
age = age if age is not None else datetime.timedelta(hours=1)
|
||||
now = now or timezone.now()
|
||||
with transaction.atomic(savepoint=False):
|
||||
pks = (
|
||||
cls.objects.filter(created__lte=now - age)
|
||||
.select_for_update(skip_locked=True)
|
||||
.values_list('pk', flat=True)
|
||||
)
|
||||
cls.objects.filter(pk__in=pks).delete()
|
||||
|
||||
@classmethod
|
||||
def lock(cls, *args, nowait=False):
|
||||
# force ordering to prevent deadlock
|
||||
names = sorted(args)
|
||||
for name in names:
|
||||
dummy, dummy = cls.objects.get_or_create(name=name)
|
||||
try:
|
||||
cls.objects.select_for_update(nowait=nowait).get(name=name)
|
||||
except transaction.TransactionManagementError:
|
||||
raise
|
||||
except DatabaseError:
|
||||
# happen only if nowait=True, in this case the error must be
|
||||
# intercepted with "except Lock.Error:", this error is
|
||||
# recoverable (i.e. the transaction can continue after)
|
||||
raise cls.Error
|
||||
|
||||
class Meta:
|
||||
verbose_name = _('Lock')
|
||||
verbose_name_plural = _('Lock')
|
||||
|
|
|
@ -14,12 +14,15 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import datetime
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
from django.db import connection, transaction
|
||||
|
||||
from authentic2.a2_rbac.models import Role
|
||||
from authentic2.custom_user.models import User
|
||||
from authentic2.models import Attribute, Service
|
||||
from authentic2.models import Attribute, Lock, Service
|
||||
from authentic2.utils.misc import ServiceAccessDenied
|
||||
|
||||
|
||||
|
@ -52,3 +55,113 @@ def test_service_authorize(db):
|
|||
user.is_superuser = True
|
||||
user.save()
|
||||
assert service.authorize(user)
|
||||
|
||||
|
||||
class TestLock:
|
||||
def test_wait(self, transactional_db):
|
||||
with pytest.raises(transaction.TransactionManagementError):
|
||||
Lock.lock('a')
|
||||
|
||||
with transaction.atomic():
|
||||
Lock.lock('a')
|
||||
|
||||
count = 50
|
||||
l1 = ['a'] * count
|
||||
l2 = []
|
||||
|
||||
def f():
|
||||
import random
|
||||
|
||||
for _ in range(count):
|
||||
with transaction.atomic():
|
||||
locks = ['a', 'b']
|
||||
random.shuffle(locks)
|
||||
Lock.lock(*locks)
|
||||
if l1:
|
||||
l2.append(l1.pop())
|
||||
l2.append('x')
|
||||
connection.close()
|
||||
|
||||
thread1 = threading.Thread(target=f)
|
||||
thread2 = threading.Thread(target=f)
|
||||
thread1.start()
|
||||
thread2.start()
|
||||
thread1.join()
|
||||
thread2.join()
|
||||
|
||||
assert len(l1) == 0
|
||||
assert len(l2) == count + 2
|
||||
|
||||
def test_nowait(self, transactional_db):
|
||||
barrier1 = threading.Barrier(2)
|
||||
barrier2 = threading.Barrier(2)
|
||||
|
||||
# prevent contention on the unique index lock
|
||||
Lock.objects.create(name='lock-name')
|
||||
|
||||
def locker():
|
||||
try:
|
||||
with transaction.atomic():
|
||||
Lock.lock('lock-name')
|
||||
barrier1.wait()
|
||||
barrier2.wait()
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
exception = None
|
||||
|
||||
def locker_nowait():
|
||||
nonlocal exception
|
||||
|
||||
try:
|
||||
with transaction.atomic():
|
||||
barrier1.wait()
|
||||
try:
|
||||
Lock.lock('lock-name', nowait=True)
|
||||
except Lock.Error as e:
|
||||
exception = e
|
||||
finally:
|
||||
barrier2.wait()
|
||||
connection.close()
|
||||
|
||||
locker_thread = threading.Thread(target=locker)
|
||||
locker_nowait_thread = threading.Thread(target=locker_nowait)
|
||||
locker_thread.start()
|
||||
locker_nowait_thread.start()
|
||||
locker_thread.join()
|
||||
locker_nowait_thread.join()
|
||||
assert exception is not None
|
||||
|
||||
def test_clean(self, transactional_db):
|
||||
import time
|
||||
import uuid
|
||||
|
||||
count = 0
|
||||
|
||||
def take_locks():
|
||||
nonlocal count
|
||||
for _ in range(100):
|
||||
with transaction.atomic():
|
||||
name = str(uuid.uuid4())
|
||||
Lock.lock(name)
|
||||
time.sleep(0.01)
|
||||
assert Lock.objects.get(name=name)
|
||||
count += 1
|
||||
connection.close()
|
||||
|
||||
thread1 = threading.Thread(target=take_locks)
|
||||
thread1.start()
|
||||
|
||||
def clean():
|
||||
while thread1.is_alive():
|
||||
time.sleep(0.001)
|
||||
Lock.cleanup(age=datetime.timedelta(seconds=0))
|
||||
connection.close()
|
||||
|
||||
thread2 = threading.Thread(target=clean)
|
||||
thread2.start()
|
||||
thread1.join()
|
||||
thread2.join()
|
||||
Lock.cleanup(age=datetime.timedelta(seconds=0))
|
||||
assert Lock.objects.count() == 0
|
||||
assert count == 100
|
||||
|
|
Loading…
Reference in New Issue