diff --git a/README.txt b/README.txt index 292d70e..da7d555 100644 --- a/README.txt +++ b/README.txt @@ -18,3 +18,16 @@ A2_IDP_CAS_SERVICES A sequence of URL prefixes, any URL starting with A2_IDP_CAS_PROVIDER Class implementating CAS views, default to `authentic2_idp_cas.views.CasProvider` A2_IDP_CAS_TICKET_EXPIRATION Ticket lifetime + +Roadmap +======= + +- implement proxy tickets +- add test for samlValidate +- implement CAS 3.0 new constraints +- implement CAS 3.0 logout +- add service field to CasService model, use domain only to filter them + easily +- add model to store attribute configuration for a service +- add way to set attribute configuration for a service from settings + diff --git a/authentic2_idp_cas/constants.py b/authentic2_idp_cas/constants.py index b90c10b..417bdd6 100644 --- a/authentic2_idp_cas/constants.py +++ b/authentic2_idp_cas/constants.py @@ -42,6 +42,9 @@ PROXY_TICKET_ELT = 'proxyTicket' PROXY_FAILURE_ELT = 'proxyFailure' +# XML Elements for CAS 3.0 +ATTRIBUTES_ELT = 'attributes' + # Templates CAS10_VALIDATION_FAILURE = 'no\n\n' @@ -56,3 +59,42 @@ CAS20_VALIDATION_SUCCESS = ''' + + + + + + + + + + + + {audience} + + + + + {name_id} + + urn:oasis:names:tc:SAML:1.0:cm:artifact + + + {attributes} + + + + + {name_id} + + urn:oasis:names:tc:SAML:1.0:cm:artifact + + + + + + +''' + diff --git a/authentic2_idp_cas/tests.py b/authentic2_idp_cas/tests.py index c1750f3..59bd6fa 100644 --- a/authentic2_idp_cas/tests.py +++ b/authentic2_idp_cas/tests.py @@ -2,11 +2,12 @@ from xml.etree import ElementTree as ET from django.test import TestCase -from django.test.client import RequestFactory +from django.test.client import RequestFactory, Client +from django.test.utils import override_settings from authentic2.compat import get_user_model -from .models import CasTicket +from .models import CasTicket, CasService from . import views from . import constants @@ -14,12 +15,49 @@ from . import constants class CasTests(TestCase): LOGIN = 'test' PASSWORD = 'test' + DOMAIN = 'casclient.com' + SERVICE = 'https://%s/' % DOMAIN def setUp(self): User = get_user_model() self.user = User.objects.create_user(self.LOGIN, password=self.PASSWORD) self.factory = RequestFactory() + def test_cas_login_blacklist_failure(self): + client = Client() + response = client.get('/idp/cas/login/', {'service': self.SERVICE}) + self.assertEqual(response.status_code, 400) + self.assertIn('is not allowed', response.content) + + @override_settings(A2_IDP_CAS_SERVICES=(SERVICE,)) + def test_cas_login_settings_whitelist(self): + self.helper_test_cas_login() + + def test_cas_login_model_whitelist(self): + CasService.objects.create( + name=self.DOMAIN, + slug=self.DOMAIN, + domain=self.DOMAIN) + self.helper_test_cas_login() + + def helper_test_cas_login(self): + client = Client() + response = client.get('/idp/cas/login/', {'service': self.SERVICE}) + self.assertIn('Location', response) + self.assertTrue(response['Location'].startswith('http://testserver/login')) + response = client.post(response['Location'], { + 'username': self.LOGIN, + 'password': self.PASSWORD, + 'submit-password': ''}) + self.assertTrue(response['Location'].startswith('http://testserver/idp/cas/continue/')) + response = client.get(response['Location']) + self.assertTrue(response['Location'].startswith('https://casclient.com/?ticket=ST-')) + # verify ticket + ticket = response['Location'].split('ticket=')[1] + response = client.get('/idp/cas/serviceValidate/', {'service': self.SERVICE, 'ticket': ticket}) + self.assertEqual(response.content, ''' +test''') + def test_service_validate_with_default_attributes(self): CasTicket.objects.create( ticket_id='ST-xxx', @@ -32,28 +70,15 @@ class CasTests(TestCase): def get_attributes(self, request, st): assert st.service == 'yyy' assert st.ticket_id == 'ST-xxx' - return { 'username': 'bob', 'email': 'bob@example.com' }, 'default' + return 'bob', { 'username': 'bob', 'email': 'bob@example.com' } provider = TestCasProvider() response = provider.service_validate(request) - print response.content root = ET.fromstring(response.content) ns_ctx = { 'cas': constants.CAS_NAMESPACE } - user_elt = root.find('cas:authenticationSuccess/cas:utilisateur', namespaces=ns_ctx) - assert user_elt is not None - - def test_service_validate_with_custom_attributes(self): - CasTicket.objects.create( - ticket_id='ST-xxx', - service='yyy', - user=self.user, - validity=True) - request = self.factory.get('/idp/cas/serviceValidate', - {'service': 'yyy', 'ticket': 'ST-xxx'}) - class TestCasProvider(views.CasProvider): - def get_attributes(self, request, st): - assert st.service == 'yyy' - assert st.ticket_id == 'ST-xxx' - return { 'username': 'bob', 'email': 'bob@example.com' }, 'utilisateur' - provider = TestCasProvider() - response = provider.service_validate(request) - print response.content + user_elt = root.find('cas:authenticationSuccess/cas:user', namespaces=ns_ctx) + self.assertIsNotNone(user_elt) + self.assertEqual(user_elt.text, 'bob') + username_elt = root.find('cas:authenticationSuccess/cas:attributes/cas:username', namespaces=ns_ctx) + self.assertEqual(username_elt.text, 'bob') + email_elt = root.find('cas:authenticationSuccess/cas:attributes/cas:email', namespaces=ns_ctx) + self.assertEqual(email_elt.text, 'bob@example.com') diff --git a/authentic2_idp_cas/views.py b/authentic2_idp_cas/views.py index 2ab93b5..6f32a2b 100644 --- a/authentic2_idp_cas/views.py +++ b/authentic2_idp_cas/views.py @@ -7,7 +7,6 @@ from django.http import HttpResponseRedirect, HttpResponseBadRequest, \ HttpResponse, HttpResponseNotAllowed from django.core.urlresolvers import reverse from django.contrib.auth.views import redirect_to_login -from django.utils.http import urlquote, urlencode from django.conf.urls import patterns, url from django.conf import settings @@ -20,50 +19,13 @@ from constants import SERVICE_PARAM, RENEW_PARAM, GATEWAY_PARAM, ID_PARAM, \ CAS10_VALIDATION_FAILURE, CAS10_VALIDATION_SUCCESS, PGT_URL_PARAM, \ INVALID_REQUEST_ERROR, INVALID_TICKET_ERROR, INVALID_SERVICE_ERROR, \ INTERNAL_ERROR, CAS20_VALIDATION_FAILURE, \ - CAS_NAMESPACE, USER_ELT, SERVICE_RESPONSE_ELT, AUTHENTICATION_SUCCESS_ELT + CAS_NAMESPACE, USER_ELT, SERVICE_RESPONSE_ELT, AUTHENTICATION_SUCCESS_ELT, \ + SAML_RESPONSE_TEMPLATE, ATTRIBUTES_ELT from . import models, utils, app_settings logger = logging.getLogger(__name__) -SAML_RESPONSE_TEMPLATE = ''' - - - - - - - - - - - - {audience} - - - - - {name_id} - - urn:oasis:names:tc:SAML:1.0:cm:artifact - - - {attributes} - - - - - {name_id} - - urn:oasis:names:tc:SAML:1.0:cm:artifact - - - - - - -''' - class CasProvider(object): def get_url(self): return patterns('cas', @@ -121,10 +83,10 @@ class CasProvider(object): else: scheme, domain, x, x, x, x = urlparse.urlparse(service) try: - cas_service = models.CasService.get(domain=domain) + models.CasService.objects.get(domain=domain) except models.CasService.DoesNotExist: - self.failure(request, 'service %r is not allowed' % service) - return self.handle_login(request, cas_service, service, renew, gateway) + return self.failure(request, 'service %r is not allowed' % service) + return self.handle_login(request, service, renew, gateway) def must_authenticate(self, request, renew): '''Does the user needs to authenticate ? @@ -143,7 +105,7 @@ class CasProvider(object): ''' return request.user.username - def handle_login(self, request, cas_service, service, renew, gateway, + def handle_login(self, request, service, renew, gateway, duration=None): ''' Handle a login request @@ -279,11 +241,9 @@ renew:%s and gateway:%s' % (service, renew, gateway)) content_type='text/xml') def get_attributes(self, request, st): - # XXX: st.service contains the requesting service URL, use it to match CAS attribute policy - return {}, False + return request.user.username, {} - def saml_build_attributes(self, request, st): - attributes, section = self.get_attributes(request, st) + def saml_build_attributes(self, request, st, attributes): result = [] for key, value in attributes.iteritems(): key = key.encode('utf-8') @@ -311,6 +271,7 @@ renew:%s and gateway:%s' % (service, renew, gateway)) st = None if st is None or not st.valid(): return self.saml_error(request, INVALID_TICKET_ERROR) + username, attributes = self.get_attributes(request, st) new_id = self.generate_id() ctx = { @@ -321,31 +282,27 @@ renew:%s and gateway:%s' % (service, renew, gateway)) 'not_before': '', # XXX: issue time - lag 'not_on_or_after': '', # XXX issue time + lag, 'audience': st.service.encode('utf-8'), - 'name_id': request.user.username, - 'attributes': self.saml_build_attributes(request, st), + 'name_id': unicode(username).encode('utf-8'), + 'attributes': self.saml_build_attributes(request, st, attributes), } return HttpResponse(SAML_RESPONSE_TEMPLATE.format(**ctx), content_type='text/xml') def service_validate_success_response(self, request, st): - attributes, section = self.get_attributes(request, st) + username, attributes = self.get_attributes(request, st) try: ET.register_namespace('cas', 'http://www.yale.edu/tp/cas') except AttributeError: ET._namespace_map['http://www.yale.edu/tp/cas'] = 'cas' root = ET.Element('{%s}%s' % (CAS_NAMESPACE, SERVICE_RESPONSE_ELT)) success = ET.SubElement(root, '{%s}%s' % (CAS_NAMESPACE, AUTHENTICATION_SUCCESS_ELT)) + user = ET.SubElement(success, '{%s}%s' % (CAS_NAMESPACE, USER_ELT)) + user.text = unicode(username) if attributes: - if section == 'default': - user = success - else: - user = ET.SubElement(success, '{%s}%s' % (CAS_NAMESPACE, section)) + container = ET.SubElement(success, '{%s}%s' % (CAS_NAMESPACE, ATTRIBUTES_ELT)) for key, value in attributes.iteritems(): - elt = ET.SubElement(user, '{%s}%s' % (CAS_NAMESPACE, key)) + elt = ET.SubElement(container, '{%s}%s' % (CAS_NAMESPACE, key)) elt.text = unicode(value) - else: - user = ET.SubElement(success, '{%s}%s' % (CAS_NAMESPACE, USER_ELT)) - user.text = unicode(st.user) return HttpResponse(ET.tostring(root, encoding='utf8'), content_type='text/xml') @@ -383,7 +340,7 @@ renew:%s and gateway:%s' % (service, renew, gateway)) return self.service_validate_success_response(request, st) except Exception: logger.exception('error in cas:service_validate') - return self.cas20_error(INTERNAL_ERROR) + return self.cas20_error(request, INTERNAL_ERROR) def logout(self, request): url = request.REQUEST.get('url') @@ -395,19 +352,10 @@ renew:%s and gateway:%s' % (service, renew, gateway)) class Authentic2CasProvider(CasProvider): def authenticate(self, request, st, passive=False): - next = '%s?id=%s' % (reverse(self.continue_cas), - urlquote(st.ticket_id)) - if passive: - if getattr(settings, 'AUTH_SSL', False): - query = { 'next': next, - 'nonce': st.ticket_id } - return HttpResponseRedirect('%s?%s' % - (reverse('user_signin_ssl'), urlencode(query))) - else: - return self.cas_failure(request, st, - '''user needs to login and no passive authentication \ -is possible''') - return auth2_redirect_to_login(request, next=next, nonce=st.ticket_id) + next_url = utils.url_add_parameters(reverse(self.continue_cas), + id=st.sticket_id) + return auth2_redirect_to_login(request, next=next_url, + nonce=st.ticket_id) def check_authentication(self, request, st): try: