diff --git a/AUTHORS b/AUTHORS index a6c6ef1d2..a1591b6da 100644 --- a/AUTHORS +++ b/AUTHORS @@ -66,4 +66,5 @@ pySilver Shaheed Haque Vinay Karanam Eduardo Oliveira +Andrea Greco Dominik George diff --git a/CHANGELOG.md b/CHANGELOG.md index b5a70cd5a..f86c13edc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * #651 Batch expired token deletions in `cleartokens` management command * Added pt-BR translations. * #1070 Add a Celery task for clearing expired tokens, e.g. to be scheduled as a [periodic task](https://docs.celeryproject.org/en/stable/userguide/periodic-tasks.html) +* #1069 OIDC: Re-introduce [additional claims](https://django-oauth-toolkit.readthedocs.io/en/latest/oidc.html#adding-claims-to-the-id-token) beyond `sub` to the id_token. ### Fixed * #1012 Return status for introspecting a nonexistent token from 401 to the correct value of 200 per [RFC 7662](https://datatracker.ietf.org/doc/html/rfc7662#section-2.2). diff --git a/docs/oidc.rst b/docs/oidc.rst index ba69e984f..143bec5e5 100644 --- a/docs/oidc.rst +++ b/docs/oidc.rst @@ -245,16 +245,45 @@ required claims, eg ``iss``, ``aud``, ``exp``, ``iat``, ``auth_time`` etc), and the ``sub`` claim will use the primary key of the user as the value. You'll probably want to customize this and add additional claims or change what is sent for the ``sub`` claim. To do so, you will need to add a method to -our custom validator:: +our custom validator. It takes one of two forms: +The first form gets passed a request object, and should return a dictionary +mapping a claim name to claim data:: class CustomOAuth2Validator(OAuth2Validator): - def get_additional_claims(self, request): - return { - "sub": request.user.email, - "first_name": request.user.first_name, - "last_name": request.user.last_name, - } + claims = {} + claims["email"] = request.user.get_user_email() + claims["username"] = request.user.get_full_name() + + return claims + +The second form gets no request object, and should return a dictionary +mapping a claim name to a callable, accepting a request and producing +the claim data:: + class CustomOAuth2Validator(OAuth2Validator): + def get_additional_claims(self): + def get_user_email(request): + return request.user.get_user_email() + + claims = {} + claims["email"] = get_user_email + claims["username"] = lambda r: r.user.get_full_name() + + return claims + +Standard claim ``sub`` is included by default, to remove it override ``get_claim_dict``. + +In some cases, it might be desirable to not list all claims in discovery info. To customize +which claims are advertised, you can override the ``get_discovery_claims`` method to return +a list of claim names to advertise. If your ``get_additional_claims`` uses the first form +and you still want to advertise claims, you can also override ``get_discovery_claims``. + +In order to help lcients discover claims early, they can be advertised in the discovery +info, under the ``claims_supported`` key. In order for the discovery info view to automatically +add all claims your validator returns, you need to use the second form (producing callables), +because the discovery info views are requested with an unauthenticated request, so directly +producing claim data would fail. If you use the first form, producing claim data directly, +your claims will not be added to discovery info. .. note:: This ``request`` object is not a ``django.http.Request`` object, but an diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index f3a24e258..4d9480be1 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -1,6 +1,7 @@ import base64 import binascii import http.client +import inspect import json import logging import uuid @@ -725,18 +726,40 @@ def _save_id_token(self, jti, request, expires, *args, **kwargs): ) return id_token + @classmethod + def _get_additional_claims_is_request_agnostic(cls): + return len(inspect.signature(cls.get_additional_claims).parameters) == 1 + def get_jwt_bearer_token(self, token, token_handler, request): return self.get_id_token(token, token_handler, request) - def get_oidc_claims(self, token, token_handler, request): - # Required OIDC claims - claims = { - "sub": str(request.user.id), - } + def get_claim_dict(self, request): + if self._get_additional_claims_is_request_agnostic(): + claims = {"sub": lambda r: str(r.user.id)} + else: + claims = {"sub": str(request.user.id)} # https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims - claims.update(**self.get_additional_claims(request)) + if self._get_additional_claims_is_request_agnostic(): + add = self.get_additional_claims() + else: + add = self.get_additional_claims(request) + claims.update(add) + + return claims + + def get_discovery_claims(self, request): + claims = ["sub"] + if self._get_additional_claims_is_request_agnostic(): + claims += list(self.get_claim_dict(request).keys()) + return claims + + def get_oidc_claims(self, token, token_handler, request): + data = self.get_claim_dict(request) + claims = {} + for k, v in data.items(): + claims[k] = v(request) if callable(v) else v return claims def get_id_token_dictionary(self, token, token_handler, request): diff --git a/oauth2_provider/views/oidc.py b/oauth2_provider/views/oidc.py index b4bb8869b..e66b30a86 100644 --- a/oauth2_provider/views/oidc.py +++ b/oauth2_provider/views/oidc.py @@ -45,6 +45,11 @@ def get(self, request, *args, **kwargs): signing_algorithms = [Application.HS256_ALGORITHM] if oauth2_settings.OIDC_RSA_PRIVATE_KEY: signing_algorithms = [Application.RS256_ALGORITHM, Application.HS256_ALGORITHM] + + validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS + validator = validator_class() + oidc_claims = list(set(validator.get_discovery_claims(request))) + data = { "issuer": issuer_url, "authorization_endpoint": authorization_endpoint, @@ -57,6 +62,7 @@ def get(self, request, *args, **kwargs): "token_endpoint_auth_methods_supported": ( oauth2_settings.OIDC_TOKEN_ENDPOINT_AUTH_METHODS_SUPPORTED ), + "claims_supported": oidc_claims, } response = JsonResponse(data) response["Access-Control-Allow-Origin"] = "*" diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py index 46040f86d..fa514ac92 100644 --- a/tests/test_oidc_views.py +++ b/tests/test_oidc_views.py @@ -29,6 +29,7 @@ def test_get_connect_discovery_info(self): "subject_types_supported": ["public"], "id_token_signing_alg_values_supported": ["RS256", "HS256"], "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], + "claims_supported": ["sub"], } response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) self.assertEqual(response.status_code, 200) @@ -55,6 +56,7 @@ def test_get_connect_discovery_info_without_issuer_url(self): "subject_types_supported": ["public"], "id_token_signing_alg_values_supported": ["RS256", "HS256"], "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"], + "claims_supported": ["sub"], } response = self.client.get(reverse("oauth2_provider:oidc-connect-discovery-info")) self.assertEqual(response.status_code, 200) @@ -146,11 +148,47 @@ def test_userinfo_endpoint_bad_token(oidc_tokens, client): assert rsp.status_code == 401 +EXAMPLE_EMAIL = "example.email@example.com" + + +def claim_user_email(request): + return EXAMPLE_EMAIL + + +@pytest.mark.django_db +def test_userinfo_endpoint_custom_claims_callable(oidc_tokens, client, oauth2_settings): + class CustomValidator(OAuth2Validator): + def get_additional_claims(self): + return { + "username": claim_user_email, + "email": claim_user_email, + } + + oidc_tokens.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator + auth_header = "Bearer %s" % oidc_tokens.access_token + rsp = client.get( + reverse("oauth2_provider:user-info"), + HTTP_AUTHORIZATION=auth_header, + ) + data = rsp.json() + assert "sub" in data + assert data["sub"] == str(oidc_tokens.user.pk) + + assert "username" in data + assert data["username"] == EXAMPLE_EMAIL + + assert "email" in data + assert data["email"] == EXAMPLE_EMAIL + + @pytest.mark.django_db -def test_userinfo_endpoint_custom_claims(oidc_tokens, client, oauth2_settings): +def test_userinfo_endpoint_custom_claims_plain(oidc_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): def get_additional_claims(self, request): - return {"state": "very nice"} + return { + "username": EXAMPLE_EMAIL, + "email": EXAMPLE_EMAIL, + } oidc_tokens.oauth2_settings.OAUTH2_VALIDATOR_CLASS = CustomValidator auth_header = "Bearer %s" % oidc_tokens.access_token @@ -161,5 +199,9 @@ def get_additional_claims(self, request): data = rsp.json() assert "sub" in data assert data["sub"] == str(oidc_tokens.user.pk) - assert "state" in data - assert data["state"] == "very nice" + + assert "username" in data + assert data["username"] == EXAMPLE_EMAIL + + assert "email" in data + assert data["email"] == EXAMPLE_EMAIL