From 9049d6350f88f2a67ceff582b84bfb1c1d2a0ee2 Mon Sep 17 00:00:00 2001 From: Ross McFarland Date: Sun, 3 Nov 2013 13:07:10 -0800 Subject: [PATCH] implement OAUTH_CLEAN_EXPIRED, clean as you go this avoids keeping around unneeded/no-longer usable objects. it's optional defaulting to False, the current behavior. adds a setting OAUTH_CLEAN_EXPIRED that optionally cleans out objects as they're expired which should keep the size of the grant, access, and refresh token tables in check. once grants are used they are deleted. and once a RefreshToken is used both it and its corresponding AccessToken are deleted. --- provider/oauth2/tests.py | 76 +++++++++++++++++++++++++++++++++++++++- provider/oauth2/views.py | 22 ++++++++---- 2 files changed, 91 insertions(+), 7 deletions(-) diff --git a/provider/oauth2/tests.py b/provider/oauth2/tests.py index 6a051bc..988af23 100644 --- a/provider/oauth2/tests.py +++ b/provider/oauth2/tests.py @@ -12,7 +12,7 @@ from ..compat import skipIfCustomUser from ..templatetags.scope import scopes from ..utils import now as date_now from .forms import ClientForm -from .models import Client, Grant, AccessToken +from .models import Client, Grant, AccessToken, RefreshToken from .backends import BasicClientBackend, RequestParamsClientBackend from .backends import AccessTokenBackend @@ -531,3 +531,77 @@ class ScopeTest(TestCase): names.sort() self.assertEqual('read read+write write', ' '.join(names)) + + +class CleanExpiredTest(BaseOAuth2TestCase): + fixtures = ['test_oauth2'] + + def setUp(self): + self._old_oauth_clean_expired = getattr(settings, + 'OAUTH_CLEAN_EXPIRED', None) + settings.OAUTH_CLEAN_EXPIRED = True + + def tearDown(self): + if self._old_oauth_clean_expired is not None: + settings.OAUTH_CLEAN_EXPIRED = self._old_oauth_clean_expired + else: + delattr(settings, 'OAUTH_CLEAN_EXPIRED') + + def test_clear_expired(self): + self.login() + + self._login_and_authorize() + + response = self.client.get(self.redirect_url()) + + self.assertEqual(302, response.status_code) + location = response['Location'] + self.assertFalse('error' in location) + self.assertTrue('code' in location) + + # verify that Grant with code exists + code = urlparse.parse_qs(location)['code'][0] + self.assertTrue(Grant.objects.filter(code=code).exists()) + + from pprint import pprint + + # use the code/grant + response = self.client.post(self.access_token_url(), { + 'grant_type': 'authorization_code', + 'client_id': self.get_client().client_id, + 'client_secret': self.get_client().client_secret, + 'code': code}) + self.assertEquals(200, response.status_code) + token = json.loads(response.content) + self.assertTrue('access_token' in token) + access_token = token['access_token'] + self.assertTrue('refresh_token' in token) + refresh_token = token['refresh_token'] + + # make sure the grant is gone + self.assertFalse(Grant.objects.filter(code=code).exists()) + # and verify that the AccessToken and RefreshToken exist + self.assertTrue(AccessToken.objects.filter(token=access_token) + .exists()) + self.assertTrue(RefreshToken.objects.filter(token=refresh_token) + .exists()) + + # refresh the token + response = self.client.post(self.access_token_url(), { + 'grant_type': 'refresh_token', + 'refresh_token': token['refresh_token'], + 'client_id': self.get_client().client_id, + 'client_secret': self.get_client().client_secret, + }) + self.assertEqual(200, response.status_code) + token = json.loads(response.content) + self.assertTrue('access_token' in token) + self.assertNotEquals(access_token, token['access_token']) + self.assertTrue('refresh_token' in token) + self.assertNotEquals(refresh_token, token['refresh_token']) + + # make sure the orig AccessToken and RefreshToken are gone + self.assertFalse(AccessToken.objects.filter(token=access_token) + .exists()) + self.assertFalse(RefreshToken.objects.filter(token=refresh_token) + .exists()) diff --git a/provider/oauth2/views.py b/provider/oauth2/views.py index 0c3af84..fa9fbb1 100644 --- a/provider/oauth2/views.py +++ b/provider/oauth2/views.py @@ -1,4 +1,5 @@ from datetime import timedelta +from django.conf import settings from django.core.urlresolvers import reverse from ..views import Capture, Authorize, Redirect from ..views import AccessToken as AccessTokenView, OAuthError @@ -116,13 +117,22 @@ class AccessTokenView(AccessTokenView): ) def invalidate_grant(self, grant): - grant.expires = now() - timedelta(days=1) - grant.save() + if getattr(settings, 'OAUTH_CLEAN_EXPIRED', False): + grant.delete() + else: + grant.expires = now() - timedelta(days=1) + grant.save() def invalidate_refresh_token(self, rt): - rt.expired = True - rt.save() + if getattr(settings, 'OAUTH_CLEAN_EXPIRED', False): + rt.delete() + else: + rt.expired = True + rt.save() def invalidate_access_token(self, at): - at.expires = now() - timedelta(days=1) - at.save() + if getattr(settings, 'OAUTH_CLEAN_EXPIRED', False): + at.delete() + else: + at.expires = now() - timedelta(days=1) + at.save()