diff --git a/provider/oauth2/tests.py b/provider/oauth2/tests.py index 6a051bc..988af23 100644 --- a/provider/oauth2/tests.py +++ b/provider/oauth2/tests.py @@ -12,7 +12,7 @@ from ..compat import skipIfCustomUser from ..templatetags.scope import scopes from ..utils import now as date_now from .forms import ClientForm -from .models import Client, Grant, AccessToken +from .models import Client, Grant, AccessToken, RefreshToken from .backends import BasicClientBackend, RequestParamsClientBackend from .backends import AccessTokenBackend @@ -531,3 +531,77 @@ class ScopeTest(TestCase): names.sort() self.assertEqual('read read+write write', ' '.join(names)) + + +class CleanExpiredTest(BaseOAuth2TestCase): + fixtures = ['test_oauth2'] + + def setUp(self): + self._old_oauth_clean_expired = getattr(settings, + 'OAUTH_CLEAN_EXPIRED', None) + settings.OAUTH_CLEAN_EXPIRED = True + + def tearDown(self): + if self._old_oauth_clean_expired is not None: + settings.OAUTH_CLEAN_EXPIRED = self._old_oauth_clean_expired + else: + delattr(settings, 'OAUTH_CLEAN_EXPIRED') + + def test_clear_expired(self): + self.login() + + self._login_and_authorize() + + response = self.client.get(self.redirect_url()) + + self.assertEqual(302, response.status_code) + location = response['Location'] + self.assertFalse('error' in location) + self.assertTrue('code' in location) + + # verify that Grant with code exists + code = urlparse.parse_qs(location)['code'][0] + self.assertTrue(Grant.objects.filter(code=code).exists()) + + from pprint import pprint + + # use the code/grant + response = self.client.post(self.access_token_url(), { + 'grant_type': 'authorization_code', + 'client_id': self.get_client().client_id, + 'client_secret': self.get_client().client_secret, + 'code': code}) + self.assertEquals(200, response.status_code) + token = json.loads(response.content) + self.assertTrue('access_token' in token) + access_token = token['access_token'] + self.assertTrue('refresh_token' in token) + refresh_token = token['refresh_token'] + + # make sure the grant is gone + self.assertFalse(Grant.objects.filter(code=code).exists()) + # and verify that the AccessToken and RefreshToken exist + self.assertTrue(AccessToken.objects.filter(token=access_token) + .exists()) + self.assertTrue(RefreshToken.objects.filter(token=refresh_token) + .exists()) + + # refresh the token + response = self.client.post(self.access_token_url(), { + 'grant_type': 'refresh_token', + 'refresh_token': token['refresh_token'], + 'client_id': self.get_client().client_id, + 'client_secret': self.get_client().client_secret, + }) + self.assertEqual(200, response.status_code) + token = json.loads(response.content) + self.assertTrue('access_token' in token) + self.assertNotEquals(access_token, token['access_token']) + self.assertTrue('refresh_token' in token) + self.assertNotEquals(refresh_token, token['refresh_token']) + + # make sure the orig AccessToken and RefreshToken are gone + self.assertFalse(AccessToken.objects.filter(token=access_token) + .exists()) + self.assertFalse(RefreshToken.objects.filter(token=refresh_token) + .exists()) diff --git a/provider/oauth2/views.py b/provider/oauth2/views.py index 0c3af84..fa9fbb1 100644 --- a/provider/oauth2/views.py +++ b/provider/oauth2/views.py @@ -1,4 +1,5 @@ from datetime import timedelta +from django.conf import settings from django.core.urlresolvers import reverse from ..views import Capture, Authorize, Redirect from ..views import AccessToken as AccessTokenView, OAuthError @@ -116,13 +117,22 @@ class AccessTokenView(AccessTokenView): ) def invalidate_grant(self, grant): - grant.expires = now() - timedelta(days=1) - grant.save() + if getattr(settings, 'OAUTH_CLEAN_EXPIRED', False): + grant.delete() + else: + grant.expires = now() - timedelta(days=1) + grant.save() def invalidate_refresh_token(self, rt): - rt.expired = True - rt.save() + if getattr(settings, 'OAUTH_CLEAN_EXPIRED', False): + rt.delete() + else: + rt.expired = True + rt.save() def invalidate_access_token(self, at): - at.expires = now() - timedelta(days=1) - at.save() + if getattr(settings, 'OAUTH_CLEAN_EXPIRED', False): + at.delete() + else: + at.expires = now() - timedelta(days=1) + at.save()