diff --git a/mellon/sessions_backends/cached_db.py b/mellon/sessions_backends/cached_db.py index f4cd363..e46c1f6 100644 --- a/mellon/sessions_backends/cached_db.py +++ b/mellon/sessions_backends/cached_db.py @@ -15,8 +15,8 @@ from django.contrib.sessions.backends.cached_db import SessionStore as BaseSessionStore -from . import db +from .db import MellonMixin -class SessionStore(db.SessionStore, BaseSessionStore): +class SessionStore(MellonMixin, BaseSessionStore): pass diff --git a/mellon/sessions_backends/db.py b/mellon/sessions_backends/db.py index be3b722..bd751d9 100644 --- a/mellon/sessions_backends/db.py +++ b/mellon/sessions_backends/db.py @@ -18,11 +18,12 @@ from django.contrib.sessions.backends.db import SessionStore as BaseSessionStore from mellon import utils -class SessionStore(BaseSessionStore): +class MellonMixin: def get_session_not_on_or_after(self): - session_not_on_or_after = self.get('mellon_session', {}).get('session_not_on_or_after') - if session_not_on_or_after: - return utils.iso8601_to_datetime(session_not_on_or_after) + if hasattr(self, '_session_cache'): + session_not_on_or_after = self.get('mellon_session', {}).get('session_not_on_or_after') + if session_not_on_or_after: + return utils.iso8601_to_datetime(session_not_on_or_after) return None def get_expiry_age(self, **kwargs): @@ -36,3 +37,7 @@ class SessionStore(BaseSessionStore): if session_not_on_or_after and 'expiry' not in kwargs: kwargs['expiry'] = session_not_on_or_after return super().get_expiry_date(**kwargs) + + +class SessionStore(MellonMixin, BaseSessionStore): + pass diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..b129a2a --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,47 @@ +import datetime + +import pytest +from django.core.cache import cache +from django.utils.timezone import now + +from mellon.sessions_backends import cached_db, db + +cls_param = pytest.mark.parametrize('cls', [cached_db.SessionStore, db.SessionStore], ids=['cached_db', 'db']) + + +@cls_param +def test_basic(db, cls): + cls = cached_db.SessionStore + + session1 = cls() + session1['foo'] = 'bar' + session1.save() + + session = cls(session_key=session1.session_key) + assert session['foo'] == 'bar' + + # check with loading from cache + cache.clear() + session = cls(session_key=session1.session_key) + assert session['foo'] == 'bar' + + +@cls_param +def test_expiry(db, cls, freezer): + cls = cached_db.SessionStore + + session1 = cls() + session1['foo'] = 'bar' + session1['mellon_session'] = { + 'session_not_on_or_after': (now() + datetime.timedelta(hours=1)).isoformat() + } + session1.save() + + freezer.tick(3599) + + session = cls(session_key=session1.session_key) + assert session['foo'] == 'bar' + + freezer.tick(2) + session = cls(session_key=session1.session_key) + assert 'foo' not in session