a2_rbac: move managers from django_rbac (#70894)
This commit is contained in:
parent
14e25ac186
commit
58dd0ae0be
|
@ -14,13 +14,301 @@
|
|||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
import contextlib
|
||||
import datetime
|
||||
import threading
|
||||
|
||||
from authentic2.a2_rbac import models
|
||||
from django_rbac.managers import AbstractBaseManager
|
||||
from django_rbac.managers import RoleManager as BaseRoleManager
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import connection, models
|
||||
from django.db.models import query
|
||||
from django.db.models.query import Prefetch, Q
|
||||
from django.db.transaction import atomic
|
||||
|
||||
from django_rbac import utils
|
||||
from django_rbac.utils import get_operation
|
||||
|
||||
from . import models as a2_models
|
||||
from . import signals
|
||||
|
||||
|
||||
class AbstractBaseManager(models.Manager):
|
||||
def get_by_natural_key(self, uuid):
|
||||
return self.get(uuid=uuid)
|
||||
|
||||
|
||||
class OperationManager(models.Manager):
|
||||
def get_by_natural_key(self, slug):
|
||||
return self.get(slug=slug)
|
||||
|
||||
def has_perm(self, user, operation_slug, object_or_model, ou=None):
|
||||
"""Test if an user can do the operation given by operation_slug
|
||||
on the given object_or_model eventually scoped by an organizational
|
||||
unit given by ou.
|
||||
|
||||
Returns True or False.
|
||||
"""
|
||||
ou_query = query.Q(ou__isnull=True)
|
||||
if ou:
|
||||
ou_query |= query.Q(ou=ou.as_scope())
|
||||
ct = ContentType.objects.get_for_model(object_or_model)
|
||||
target_query = query.Q(target_ct=ContentType.objects.get_for_model(ContentType), target_id=ct.pk)
|
||||
if isinstance(object_or_model, models.Model):
|
||||
target_query |= query.Q(target_ct=ct, target_id=object.pk)
|
||||
Permission = utils.get_permission_model()
|
||||
qs = Permission.objects.for_user(user)
|
||||
qs = qs.filter(operation__slug=operation_slug)
|
||||
qs = qs.filter(ou_query & target_query)
|
||||
return qs.exists()
|
||||
|
||||
|
||||
class PermissionManagerBase(models.Manager):
|
||||
def get_by_natural_key(self, operation_slug, ou_nk, target_ct, target_nk):
|
||||
qs = self.filter(operation__slug=operation_slug)
|
||||
if ou_nk:
|
||||
OrganizationalUnit = utils.get_ou_model()
|
||||
try:
|
||||
ou = OrganizationalUnit.objects.get_by_natural_key(*ou_nk)
|
||||
except OrganizationalUnit.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
qs = qs.filter(ou=ou)
|
||||
else:
|
||||
qs = qs.filter(ou__isnull=True)
|
||||
try:
|
||||
target_ct = ContentType.objects.get_by_natural_key(*target_ct)
|
||||
except ContentType.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
target_model = target_ct.model_class()
|
||||
try:
|
||||
target = target_model.objects.get_by_natural_key(*target_nk)
|
||||
except target_model.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
return qs.get(target_ct=ContentType.objects.get_for_model(target), target_id=target.pk)
|
||||
|
||||
|
||||
class PermissionQueryset(query.QuerySet):
|
||||
def by_target_ct(self, target):
|
||||
"""Filter permission whose target content-type matches the content
|
||||
type of the target argument
|
||||
"""
|
||||
target_ct = ContentType.objects.get_for_model(target)
|
||||
return self.filter(target_ct=target_ct)
|
||||
|
||||
def by_target(self, target):
|
||||
'''Filter permission whose target matches target'''
|
||||
return self.by_target_ct(target).filter(target_id=target.pk)
|
||||
|
||||
def for_user(self, user):
|
||||
"""Retrieve all permissions hold by an user through its role and
|
||||
inherited roles.
|
||||
"""
|
||||
Role = utils.get_role_model()
|
||||
roles = Role.objects.for_user(user=user)
|
||||
return self.filter(roles__in=roles)
|
||||
|
||||
def cleanup(self):
|
||||
count = 0
|
||||
for p in self:
|
||||
if not p.target and (p.target_ct_id or p.target_id):
|
||||
p.delete()
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
PermissionManager = PermissionManagerBase.from_queryset(PermissionQueryset)
|
||||
|
||||
|
||||
class IntCast(models.Func):
|
||||
function = 'int'
|
||||
template = 'CAST((%(expressions)s) AS %(function)s)'
|
||||
|
||||
|
||||
class RoleQuerySet(query.QuerySet):
|
||||
def for_user(self, user):
|
||||
if hasattr(user, 'apiclient_roles'):
|
||||
queryset = self.filter(apiclients=user)
|
||||
else:
|
||||
queryset = self.filter(members=user)
|
||||
return queryset.parents().distinct()
|
||||
|
||||
def parents(self, include_self=True, annotate=False, direct=None):
|
||||
assert annotate is False or direct is not True, 'annotate=True cannot be used with direct=True'
|
||||
if direct is None:
|
||||
qs = self.model.objects.filter(
|
||||
child_relation__deleted__isnull=True,
|
||||
child_relation__child__in=self,
|
||||
)
|
||||
else:
|
||||
qs = self.model.objects.filter(
|
||||
child_relation__deleted__isnull=True,
|
||||
child_relation__child__in=self,
|
||||
child_relation__direct=direct,
|
||||
)
|
||||
if include_self:
|
||||
qs = self | qs
|
||||
qs = qs.distinct()
|
||||
if annotate:
|
||||
qs = qs.annotate(direct=models.Max(IntCast('child_relation__direct')))
|
||||
return qs
|
||||
|
||||
def children(self, include_self=True, annotate=False, direct=None):
|
||||
assert annotate is False or direct is not True, 'annotate=True cannot be used with direct=True'
|
||||
if direct is None:
|
||||
qs = self.model.objects.filter(
|
||||
parent_relation__deleted__isnull=True,
|
||||
parent_relation__parent__in=self,
|
||||
)
|
||||
else:
|
||||
qs = self.model.objects.filter(
|
||||
parent_relation__deleted__isnull=True,
|
||||
parent_relation__parent__in=self,
|
||||
parent_relation__direct=direct,
|
||||
)
|
||||
if include_self:
|
||||
qs = self | qs
|
||||
qs = qs.distinct()
|
||||
if annotate:
|
||||
qs = qs.annotate(direct=models.Max(IntCast('parent_relation__direct')))
|
||||
return qs
|
||||
|
||||
def all_members(self):
|
||||
User = get_user_model()
|
||||
prefetch = Prefetch('roles', queryset=self, to_attr='direct')
|
||||
return (
|
||||
User.objects.filter(
|
||||
Q(roles__in=self)
|
||||
| Q(roles__parent_relation__parent__in=self, roles__parent_relation__deleted__isnull=True)
|
||||
)
|
||||
.distinct()
|
||||
.prefetch_related(prefetch)
|
||||
)
|
||||
|
||||
def by_admin_scope_ct(self, admin_scope):
|
||||
admin_scope_ct = ContentType.objects.get_for_model(admin_scope)
|
||||
return self.filter(admin_scope_ct=admin_scope_ct)
|
||||
|
||||
def cleanup(self):
|
||||
count = 0
|
||||
for r in self.filter(Q(admin_scope_ct_id__isnull=False) | Q(admin_scope_id__isnull=False)):
|
||||
if not r.admin_scope:
|
||||
r.delete()
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
BaseRoleManager = AbstractBaseManager.from_queryset(RoleQuerySet)
|
||||
|
||||
|
||||
class RoleParentingManager(models.Manager):
|
||||
class Local(threading.local):
|
||||
DO_UPDATE_CLOSURE = True
|
||||
CLOSURE_UPDATED = False
|
||||
|
||||
tls = Local()
|
||||
|
||||
def get_by_natural_key(self, parent_nk, child_nk, direct):
|
||||
Role = utils.get_role_model()
|
||||
try:
|
||||
parent = Role.objects.get_by_natural_key(*parent_nk)
|
||||
except Role.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
try:
|
||||
child = Role.objects.get_by_natural_key(*child_nk)
|
||||
except Role.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
return self.get(parent=parent, child=child, direct=direct)
|
||||
|
||||
def soft_create(self, parent, child):
|
||||
with atomic(savepoint=False):
|
||||
rp, created = self.get_or_create(parent=parent, child=child, direct=True)
|
||||
new = created or rp.deleted
|
||||
if not created and rp.deleted:
|
||||
rp.created = datetime.datetime.now()
|
||||
rp.deleted = None
|
||||
rp.save(update_fields=['created', 'deleted'])
|
||||
if new:
|
||||
signals.post_soft_create.send(sender=self.model, instance=rp)
|
||||
|
||||
def soft_delete(self, parent, child):
|
||||
qs = self.filter(parent=parent, child=child, deleted__isnull=True, direct=True)
|
||||
with atomic(savepoint=False):
|
||||
rp = qs.first()
|
||||
if rp:
|
||||
count = qs.update(deleted=datetime.datetime.now())
|
||||
# read-commited, view of tables can change during transaction
|
||||
if count:
|
||||
signals.post_soft_delete.send(sender=self.model, instance=rp)
|
||||
|
||||
def update_transitive_closure(self):
|
||||
"""Recompute the transitive closure of the inheritance relation
|
||||
from scratch. Add missing indirect relations and delete
|
||||
obsolete indirect relations.
|
||||
"""
|
||||
if not self.tls.DO_UPDATE_CLOSURE:
|
||||
self.tls.CLOSURE_UPDATED = True
|
||||
return
|
||||
|
||||
with atomic(savepoint=False):
|
||||
# existing direct paths
|
||||
direct = set(self.filter(direct=True, deleted__isnull=True).values_list('parent_id', 'child_id'))
|
||||
old_indirects = set(
|
||||
self.filter(direct=False, deleted__isnull=True).values_list('parent_id', 'child_id')
|
||||
)
|
||||
indirects = set(direct)
|
||||
|
||||
while True:
|
||||
changed = False
|
||||
for (i, j) in list(indirects):
|
||||
for (k, l) in direct:
|
||||
if j == k and i != l and (i, l) not in indirects:
|
||||
indirects.add((i, l))
|
||||
changed = True
|
||||
if not changed:
|
||||
break
|
||||
|
||||
with connection.cursor() as cur:
|
||||
# Delete old ones
|
||||
obsolete = old_indirects - indirects - direct
|
||||
if obsolete:
|
||||
sql = '''UPDATE "%s" AS relation \
|
||||
SET deleted = now()\
|
||||
FROM (VALUES %s) AS dead(parent_id, child_id) \
|
||||
WHERE relation.direct = 'false' AND relation.parent_id = dead.parent_id \
|
||||
AND relation.child_id = dead.child_id AND deleted IS NULL''' % (
|
||||
self.model._meta.db_table,
|
||||
', '.join('(%s, %s)' % (a, b) for a, b in obsolete),
|
||||
)
|
||||
cur.execute(sql)
|
||||
# Create new indirect relations
|
||||
new = indirects - old_indirects - direct
|
||||
if new:
|
||||
new_values = ', '.join(
|
||||
(
|
||||
"(%s, %s, 'false', now(), NULL)" % (parent_id, child_id)
|
||||
for parent_id, child_id in new
|
||||
)
|
||||
)
|
||||
sql = '''INSERT INTO "%s" (parent_id, child_id, direct, created, deleted) VALUES %s \
|
||||
ON CONFLICT (parent_id, child_id, direct) DO UPDATE SET created = EXCLUDED.created, deleted = NULL''' % (
|
||||
self.model._meta.db_table,
|
||||
new_values,
|
||||
)
|
||||
cur.execute(sql)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def defer_update_transitive_closure():
|
||||
from . import utils
|
||||
|
||||
RoleParentingManager.tls.DO_UPDATE_CLOSURE = False
|
||||
try:
|
||||
yield
|
||||
if RoleParentingManager.tls.CLOSURE_UPDATED:
|
||||
utils.get_role_parenting_model().objects.update_transitive_closure()
|
||||
finally:
|
||||
RoleParentingManager.tls.DO_UPDATE_CLOSURE = True
|
||||
RoleParentingManager.tls.CLOSURE_UPDATED = False
|
||||
|
||||
|
||||
class OrganizationalUnitManager(AbstractBaseManager):
|
||||
def get_by_natural_key(self, uuid):
|
||||
|
@ -70,7 +358,7 @@ class RoleManager(BaseRoleManager):
|
|||
# find an operation matching the template
|
||||
op = get_operation(operation)
|
||||
if create:
|
||||
perm, _ = models.Permission.objects.update_or_create(
|
||||
perm, _ = a2_models.Permission.objects.update_or_create(
|
||||
operation=op,
|
||||
target_ct=ContentType.objects.get_for_model(instance),
|
||||
target_id=instance.pk,
|
||||
|
@ -79,13 +367,13 @@ class RoleManager(BaseRoleManager):
|
|||
)
|
||||
else:
|
||||
try:
|
||||
perm = models.Permission.objects.get(
|
||||
perm = a2_models.Permission.objects.get(
|
||||
operation=op,
|
||||
target_ct=ContentType.objects.get_for_model(instance),
|
||||
target_id=instance.pk,
|
||||
**kwargs,
|
||||
)
|
||||
except models.Permission.DoesNotExist:
|
||||
except a2_models.Permission.DoesNotExist:
|
||||
return None
|
||||
|
||||
# in which ou do we put the role ?
|
||||
|
@ -157,8 +445,8 @@ class RoleManager(BaseRoleManager):
|
|||
kwargs['ou__isnull'] = True
|
||||
else:
|
||||
try:
|
||||
ou = models.OrganizationalUnit.objects.get_by_natural_key(*ou_natural_key)
|
||||
except models.OrganizationalUnit.DoesNotExist:
|
||||
ou = a2_models.OrganizationalUnit.objects.get_by_natural_key(*ou_natural_key)
|
||||
except a2_models.OrganizationalUnit.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
kwargs['ou'] = ou
|
||||
if service_natural_key is None:
|
||||
|
|
|
@ -37,7 +37,6 @@ from model_utils.managers import QueryManager
|
|||
from authentic2.decorators import errorcollector
|
||||
from authentic2.utils.cache import GlobalCache
|
||||
from authentic2.validators import HexaColourValidator
|
||||
from django_rbac import managers as rbac_managers
|
||||
from django_rbac import utils as rbac_utils
|
||||
|
||||
from . import app_settings, fields, managers
|
||||
|
@ -55,7 +54,7 @@ class AbstractBase(models.Model):
|
|||
slug = models.SlugField(max_length=256, verbose_name=_('slug'))
|
||||
description = models.TextField(verbose_name=_('description'), blank=True)
|
||||
|
||||
objects = rbac_managers.AbstractBaseManager()
|
||||
objects = managers.AbstractBaseManager()
|
||||
|
||||
def __str__(self):
|
||||
return str(self.name)
|
||||
|
@ -280,7 +279,7 @@ class Permission(models.Model):
|
|||
target_id = models.PositiveIntegerField()
|
||||
target = GenericForeignKey('target_ct', 'target_id')
|
||||
|
||||
objects = rbac_managers.PermissionManager()
|
||||
objects = managers.PermissionManager()
|
||||
|
||||
class Meta:
|
||||
verbose_name = _('permission')
|
||||
|
@ -412,7 +411,7 @@ class Role(AbstractBase):
|
|||
default=True, verbose_name=_('Allow adding or deleting role members')
|
||||
)
|
||||
|
||||
objects = rbac_managers.RoleQuerySet.as_manager()
|
||||
objects = managers.RoleQuerySet.as_manager()
|
||||
|
||||
def add_child(self, child):
|
||||
RoleParenting = rbac_utils.get_role_parenting_model()
|
||||
|
@ -720,7 +719,7 @@ class RoleParenting(models.Model):
|
|||
created = models.DateTimeField(verbose_name=_('Creation date'), auto_now_add=True)
|
||||
deleted = models.DateTimeField(verbose_name=_('Deletion date'), null=True)
|
||||
|
||||
objects = rbac_managers.RoleParentingManager()
|
||||
objects = managers.RoleParentingManager()
|
||||
alive = QueryManager(deleted__isnull=True)
|
||||
|
||||
def natural_key(self):
|
||||
|
@ -778,7 +777,7 @@ class Operation(models.Model):
|
|||
|
||||
_registry = {}
|
||||
|
||||
objects = rbac_managers.OperationManager()
|
||||
objects = managers.OperationManager()
|
||||
|
||||
|
||||
Operation._meta.natural_key = ['slug']
|
||||
|
|
|
@ -22,9 +22,10 @@ from django.utils.translation import override
|
|||
|
||||
from authentic2.a2_rbac.models import OrganizationalUnit, Role
|
||||
from authentic2.utils.misc import get_fk_model
|
||||
from django_rbac.managers import defer_update_transitive_closure
|
||||
from django_rbac.utils import get_operation, get_role_parenting_model
|
||||
|
||||
from .managers import defer_update_transitive_closure
|
||||
|
||||
|
||||
def create_default_ou(app_config, verbosity=2, interactive=True, using=DEFAULT_DB_ALIAS, **kwargs):
|
||||
if not router.allow_migrate(using, OrganizationalUnit):
|
||||
|
|
|
@ -1,292 +0,0 @@
|
|||
import contextlib
|
||||
import datetime
|
||||
import threading
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import connection, models
|
||||
from django.db.models import query
|
||||
from django.db.models.query import Prefetch, Q
|
||||
from django.db.transaction import atomic
|
||||
|
||||
from authentic2.a2_rbac import signals
|
||||
|
||||
from . import utils
|
||||
|
||||
|
||||
class AbstractBaseManager(models.Manager):
|
||||
def get_by_natural_key(self, uuid):
|
||||
return self.get(uuid=uuid)
|
||||
|
||||
|
||||
class OperationManager(models.Manager):
|
||||
def get_by_natural_key(self, slug):
|
||||
return self.get(slug=slug)
|
||||
|
||||
def has_perm(self, user, operation_slug, object_or_model, ou=None):
|
||||
"""Test if an user can do the operation given by operation_slug
|
||||
on the given object_or_model eventually scoped by an organizational
|
||||
unit given by ou.
|
||||
|
||||
Returns True or False.
|
||||
"""
|
||||
ou_query = query.Q(ou__isnull=True)
|
||||
if ou:
|
||||
ou_query |= query.Q(ou=ou.as_scope())
|
||||
ct = ContentType.objects.get_for_model(object_or_model)
|
||||
target_query = query.Q(target_ct=ContentType.objects.get_for_model(ContentType), target_id=ct.pk)
|
||||
if isinstance(object_or_model, models.Model):
|
||||
target_query |= query.Q(target_ct=ct, target_id=object.pk)
|
||||
Permission = utils.get_permission_model()
|
||||
qs = Permission.objects.for_user(user)
|
||||
qs = qs.filter(operation__slug=operation_slug)
|
||||
qs = qs.filter(ou_query & target_query)
|
||||
return qs.exists()
|
||||
|
||||
|
||||
class PermissionManagerBase(models.Manager):
|
||||
def get_by_natural_key(self, operation_slug, ou_nk, target_ct, target_nk):
|
||||
qs = self.filter(operation__slug=operation_slug)
|
||||
if ou_nk:
|
||||
OrganizationalUnit = utils.get_ou_model()
|
||||
try:
|
||||
ou = OrganizationalUnit.objects.get_by_natural_key(*ou_nk)
|
||||
except OrganizationalUnit.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
qs = qs.filter(ou=ou)
|
||||
else:
|
||||
qs = qs.filter(ou__isnull=True)
|
||||
try:
|
||||
target_ct = ContentType.objects.get_by_natural_key(*target_ct)
|
||||
except ContentType.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
target_model = target_ct.model_class()
|
||||
try:
|
||||
target = target_model.objects.get_by_natural_key(*target_nk)
|
||||
except target_model.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
return qs.get(target_ct=ContentType.objects.get_for_model(target), target_id=target.pk)
|
||||
|
||||
|
||||
class PermissionQueryset(query.QuerySet):
|
||||
def by_target_ct(self, target):
|
||||
"""Filter permission whose target content-type matches the content
|
||||
type of the target argument
|
||||
"""
|
||||
target_ct = ContentType.objects.get_for_model(target)
|
||||
return self.filter(target_ct=target_ct)
|
||||
|
||||
def by_target(self, target):
|
||||
'''Filter permission whose target matches target'''
|
||||
return self.by_target_ct(target).filter(target_id=target.pk)
|
||||
|
||||
def for_user(self, user):
|
||||
"""Retrieve all permissions hold by an user through its role and
|
||||
inherited roles.
|
||||
"""
|
||||
Role = utils.get_role_model()
|
||||
roles = Role.objects.for_user(user=user)
|
||||
return self.filter(roles__in=roles)
|
||||
|
||||
def cleanup(self):
|
||||
count = 0
|
||||
for p in self:
|
||||
if not p.target and (p.target_ct_id or p.target_id):
|
||||
p.delete()
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
PermissionManager = PermissionManagerBase.from_queryset(PermissionQueryset)
|
||||
|
||||
|
||||
class IntCast(models.Func):
|
||||
function = 'int'
|
||||
template = 'CAST((%(expressions)s) AS %(function)s)'
|
||||
|
||||
|
||||
class RoleQuerySet(query.QuerySet):
|
||||
def for_user(self, user):
|
||||
if hasattr(user, 'apiclient_roles'):
|
||||
queryset = self.filter(apiclients=user)
|
||||
else:
|
||||
queryset = self.filter(members=user)
|
||||
return queryset.parents().distinct()
|
||||
|
||||
def parents(self, include_self=True, annotate=False, direct=None):
|
||||
assert annotate is False or direct is not True, 'annotate=True cannot be used with direct=True'
|
||||
if direct is None:
|
||||
qs = self.model.objects.filter(
|
||||
child_relation__deleted__isnull=True,
|
||||
child_relation__child__in=self,
|
||||
)
|
||||
else:
|
||||
qs = self.model.objects.filter(
|
||||
child_relation__deleted__isnull=True,
|
||||
child_relation__child__in=self,
|
||||
child_relation__direct=direct,
|
||||
)
|
||||
if include_self:
|
||||
qs = self | qs
|
||||
qs = qs.distinct()
|
||||
if annotate:
|
||||
qs = qs.annotate(direct=models.Max(IntCast('child_relation__direct')))
|
||||
return qs
|
||||
|
||||
def children(self, include_self=True, annotate=False, direct=None):
|
||||
assert annotate is False or direct is not True, 'annotate=True cannot be used with direct=True'
|
||||
if direct is None:
|
||||
qs = self.model.objects.filter(
|
||||
parent_relation__deleted__isnull=True,
|
||||
parent_relation__parent__in=self,
|
||||
)
|
||||
else:
|
||||
qs = self.model.objects.filter(
|
||||
parent_relation__deleted__isnull=True,
|
||||
parent_relation__parent__in=self,
|
||||
parent_relation__direct=direct,
|
||||
)
|
||||
if include_self:
|
||||
qs = self | qs
|
||||
qs = qs.distinct()
|
||||
if annotate:
|
||||
qs = qs.annotate(direct=models.Max(IntCast('parent_relation__direct')))
|
||||
return qs
|
||||
|
||||
def all_members(self):
|
||||
User = get_user_model()
|
||||
prefetch = Prefetch('roles', queryset=self, to_attr='direct')
|
||||
return (
|
||||
User.objects.filter(
|
||||
Q(roles__in=self)
|
||||
| Q(roles__parent_relation__parent__in=self, roles__parent_relation__deleted__isnull=True)
|
||||
)
|
||||
.distinct()
|
||||
.prefetch_related(prefetch)
|
||||
)
|
||||
|
||||
def by_admin_scope_ct(self, admin_scope):
|
||||
admin_scope_ct = ContentType.objects.get_for_model(admin_scope)
|
||||
return self.filter(admin_scope_ct=admin_scope_ct)
|
||||
|
||||
def cleanup(self):
|
||||
count = 0
|
||||
for r in self.filter(Q(admin_scope_ct_id__isnull=False) | Q(admin_scope_id__isnull=False)):
|
||||
if not r.admin_scope:
|
||||
r.delete()
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
RoleManager = AbstractBaseManager.from_queryset(RoleQuerySet)
|
||||
|
||||
|
||||
class RoleParentingManager(models.Manager):
|
||||
class Local(threading.local):
|
||||
DO_UPDATE_CLOSURE = True
|
||||
CLOSURE_UPDATED = False
|
||||
|
||||
tls = Local()
|
||||
|
||||
def get_by_natural_key(self, parent_nk, child_nk, direct):
|
||||
Role = utils.get_role_model()
|
||||
try:
|
||||
parent = Role.objects.get_by_natural_key(*parent_nk)
|
||||
except Role.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
try:
|
||||
child = Role.objects.get_by_natural_key(*child_nk)
|
||||
except Role.DoesNotExist:
|
||||
raise self.model.DoesNotExist
|
||||
return self.get(parent=parent, child=child, direct=direct)
|
||||
|
||||
def soft_create(self, parent, child):
|
||||
with atomic(savepoint=False):
|
||||
rp, created = self.get_or_create(parent=parent, child=child, direct=True)
|
||||
new = created or rp.deleted
|
||||
if not created and rp.deleted:
|
||||
rp.created = datetime.datetime.now()
|
||||
rp.deleted = None
|
||||
rp.save(update_fields=['created', 'deleted'])
|
||||
if new:
|
||||
signals.post_soft_create.send(sender=self.model, instance=rp)
|
||||
|
||||
def soft_delete(self, parent, child):
|
||||
qs = self.filter(parent=parent, child=child, deleted__isnull=True, direct=True)
|
||||
with atomic(savepoint=False):
|
||||
rp = qs.first()
|
||||
if rp:
|
||||
count = qs.update(deleted=datetime.datetime.now())
|
||||
# read-commited, view of tables can change during transaction
|
||||
if count:
|
||||
signals.post_soft_delete.send(sender=self.model, instance=rp)
|
||||
|
||||
def update_transitive_closure(self):
|
||||
"""Recompute the transitive closure of the inheritance relation
|
||||
from scratch. Add missing indirect relations and delete
|
||||
obsolete indirect relations.
|
||||
"""
|
||||
if not self.tls.DO_UPDATE_CLOSURE:
|
||||
self.tls.CLOSURE_UPDATED = True
|
||||
return
|
||||
|
||||
with atomic(savepoint=False):
|
||||
# existing direct paths
|
||||
direct = set(self.filter(direct=True, deleted__isnull=True).values_list('parent_id', 'child_id'))
|
||||
old_indirects = set(
|
||||
self.filter(direct=False, deleted__isnull=True).values_list('parent_id', 'child_id')
|
||||
)
|
||||
indirects = set(direct)
|
||||
|
||||
while True:
|
||||
changed = False
|
||||
for (i, j) in list(indirects):
|
||||
for (k, l) in direct:
|
||||
if j == k and i != l and (i, l) not in indirects:
|
||||
indirects.add((i, l))
|
||||
changed = True
|
||||
if not changed:
|
||||
break
|
||||
|
||||
with connection.cursor() as cur:
|
||||
# Delete old ones
|
||||
obsolete = old_indirects - indirects - direct
|
||||
if obsolete:
|
||||
sql = '''UPDATE "%s" AS relation \
|
||||
SET deleted = now()\
|
||||
FROM (VALUES %s) AS dead(parent_id, child_id) \
|
||||
WHERE relation.direct = 'false' AND relation.parent_id = dead.parent_id \
|
||||
AND relation.child_id = dead.child_id AND deleted IS NULL''' % (
|
||||
self.model._meta.db_table,
|
||||
', '.join('(%s, %s)' % (a, b) for a, b in obsolete),
|
||||
)
|
||||
cur.execute(sql)
|
||||
# Create new indirect relations
|
||||
new = indirects - old_indirects - direct
|
||||
if new:
|
||||
new_values = ', '.join(
|
||||
(
|
||||
"(%s, %s, 'false', now(), NULL)" % (parent_id, child_id)
|
||||
for parent_id, child_id in new
|
||||
)
|
||||
)
|
||||
sql = '''INSERT INTO "%s" (parent_id, child_id, direct, created, deleted) VALUES %s \
|
||||
ON CONFLICT (parent_id, child_id, direct) DO UPDATE SET created = EXCLUDED.created, deleted = NULL''' % (
|
||||
self.model._meta.db_table,
|
||||
new_values,
|
||||
)
|
||||
cur.execute(sql)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def defer_update_transitive_closure():
|
||||
from . import utils
|
||||
|
||||
RoleParentingManager.tls.DO_UPDATE_CLOSURE = False
|
||||
try:
|
||||
yield
|
||||
if RoleParentingManager.tls.CLOSURE_UPDATED:
|
||||
utils.get_role_parenting_model().objects.update_transitive_closure()
|
||||
finally:
|
||||
RoleParentingManager.tls.DO_UPDATE_CLOSURE = True
|
||||
RoleParentingManager.tls.CLOSURE_UPDATED = False
|
Loading…
Reference in New Issue