models: lock user model when changing multiple attribute values (#37390)

This commit is contained in:
Benjamin Dauvergne 2019-11-12 11:05:14 +01:00
parent 173f63f647
commit 3d3df4e858
3 changed files with 71 additions and 62 deletions

View File

@ -14,7 +14,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from django.db import models
from django.db import models, transaction
from django.utils import timezone
from django.core.mail import send_mail
from django.utils import six
@ -55,29 +55,33 @@ class Attributes(object):
for atv in self.owner.attribute_values.all():
attribute = get_attributes_map()[atv.attribute_id]
atv.attribute = attribute
values[attribute.name] = atv
if attribute.multiple:
values.setdefault(attribute.name, []).append(atv)
else:
values[attribute.name] = atv
self.__dict__['values'] = owner._a2_attributes_cache
def __setattr__(self, name, value):
atv = self.values.get(name)
if atv:
if isinstance(atv, (list, tuple)):
attribute = atv[0].attribute
else:
attribute = atv.attribute
attribute.set_value(self.owner, value, verified=bool(self.verified), attribute_value=atv)
else:
attribute = get_attributes_map().get(name)
if not attribute:
raise AttributeError(name)
self.values[name] = attribute.set_value(self.owner, value, verified=bool(self.verified))
attribute = get_attributes_map().get(name)
if not attribute:
raise AttributeError(name)
update_fields = ['modified']
if name in ['first_name', 'last_name']:
if getattr(self.owner, name) != value:
setattr(self.owner, name, value)
update_fields.append(name)
self.owner.save(update_fields=update_fields)
with transaction.atomic():
if attribute.multiple:
attribute.set_value(self.owner, value, verified=bool(self.verified))
else:
atv = self.values.get(name)
self.values[name] = attribute.set_value(
self.owner, value,
verified=bool(self.verified),
attribute_value=atv)
update_fields = ['modified']
if name in ['first_name', 'last_name']:
if getattr(self.owner, name) != value:
setattr(self.owner, name, value)
update_fields.append(name)
self.owner.save(update_fields=update_fields)
def __getattr__(self, name):
if name not in get_attributes_map():

View File

@ -245,13 +245,15 @@ class Attribute(models.Model):
AttributeValue.objects.with_owner(owner).filter(attribute=self).delete()
return
if self.multiple:
assert isinstance(value, (list, set, tuple))
values = value
avs = []
content_list = []
with transaction.atomic():
if self.multiple:
assert isinstance(value, (list, set, tuple))
values = value
avs = []
content_list = []
list(owner.__class__.objects.filter(pk=owner.pk).select_for_update())
with transaction.atomic():
for value in values:
content = serialize(value)
av, created = AttributeValue.objects.get_or_create(
@ -273,23 +275,23 @@ class Attribute(models.Model):
object_id=owner.pk,
multiple=True
).exclude(content__in=content_list).delete()
return avs
else:
content = serialize(value)
if attribute_value:
av, created = attribute_value, False
return avs
else:
av, created = AttributeValue.objects.get_or_create(
content_type=ContentType.objects.get_for_model(owner),
object_id=owner.pk,
attribute=self,
multiple=False,
defaults={'content': content, 'verified': verified})
if not created and (av.content != content or av.verified != verified):
av.content = content
av.verified = verified
av.save()
return av
content = serialize(value)
if attribute_value:
av, created = attribute_value, False
else:
av, created = AttributeValue.objects.get_or_create(
content_type=ContentType.objects.get_for_model(owner),
object_id=owner.pk,
attribute=self,
multiple=False,
defaults={'content': content, 'verified': verified})
if not created and (av.content != content or av.verified != verified):
av.content = content
av.verified = verified
av.save()
return av
def natural_key(self):
return (self.name,)

View File

@ -25,11 +25,11 @@ from utils import skipif_sqlite
@skipif_sqlite
def test_attribute_value_uniqueness(migrations, transactional_db, simple_user, concurrency):
from django.db.transaction import set_autocommit
#from django.db.transaction import set_autocommit
# disabled default attributes
Attribute.objects.update(disabled=True)
set_autocommit(True)
#set_autocommit(True)
acount = Attribute.objects.count()
single_at = Attribute.objects.create(
@ -44,23 +44,26 @@ def test_attribute_value_uniqueness(migrations, transactional_db, simple_user, c
multiple=True)
assert Attribute.objects.count() == acount + 2
def map_threads(f, l):
threads = []
for i in l:
threads.append(threading.Thread(target=f, args=(i,)))
threads[-1].start()
for thread in threads:
thread.join()
AttributeValue.objects.all().delete()
def f(i):
simple_user.attributes.multiple = [str(i)]
connection.close()
map_threads(f, range(concurrency))
map_threads(f, range(concurrency))
assert AttributeValue.objects.filter(attribute=multiple_at).count() == 1
for i in range(10):
def map_threads(f, l):
threads = []
for i in l:
threads.append(threading.Thread(target=f, args=(i,)))
threads[-1].start()
for thread in threads:
thread.join()
def f(i):
simple_user.attributes.single = str(i)
connection.close()
map_threads(f, range(concurrency))
assert AttributeValue.objects.filter(attribute=single_at).count() == 1
def f(i):
simple_user.attributes.multiple = [str(i)]
connection.close()
map_threads(f, range(concurrency))
map_threads(f, range(concurrency))
assert AttributeValue.objects.filter(attribute=multiple_at).count() == 1
def f(i):
simple_user.attributes.single = str(i)
connection.close()
map_threads(f, range(concurrency))
assert AttributeValue.objects.filter(attribute=single_at).count() == 1