authentik.enterprise.endpoints.connectors.agent.views.apple_token

  1from typing import Any
  2
  3from django.http import HttpRequest, HttpResponse
  4from django.urls import reverse
  5from django.utils.decorators import method_decorator
  6from django.utils.timezone import now
  7from django.views import View
  8from django.views.decorators.csrf import csrf_exempt
  9from jwt import PyJWTError, decode, encode, get_unverified_header
 10from rest_framework.exceptions import ValidationError
 11from structlog.stdlib import get_logger
 12
 13from authentik.common.oauth.constants import TOKEN_TYPE
 14from authentik.core.models import AuthenticatedSession, Session, User
 15from authentik.core.sessions import SessionStore
 16from authentik.crypto.apps import MANAGED_KEY
 17from authentik.crypto.models import CertificateKeyPair
 18from authentik.endpoints.connectors.agent.models import (
 19    AgentConnector,
 20    AgentDeviceConnection,
 21    AgentDeviceUserBinding,
 22    AppleNonce,
 23    DeviceAuthenticationToken,
 24)
 25from authentik.enterprise.endpoints.connectors.agent.http import JWEResponse
 26from authentik.events.models import Event, EventAction
 27from authentik.events.signals import SESSION_LOGIN_EVENT
 28from authentik.flows.planner import PLAN_CONTEXT_DEVICE
 29from authentik.lib.utils.time import timedelta_from_string
 30from authentik.providers.oauth2.id_token import IDToken
 31from authentik.providers.oauth2.models import JWTAlgorithms
 32from authentik.root.middleware import SessionMiddleware
 33
 34LOGGER = get_logger()
 35
 36
 37@method_decorator(csrf_exempt, name="dispatch")
 38class TokenView(View):
 39
 40    device_connection: AgentDeviceConnection
 41    connector: AgentConnector
 42
 43    def post(self, request: HttpRequest) -> HttpResponse:
 44        assertion = request.POST.get("assertion", request.POST.get("request"))
 45        if not assertion:
 46            return HttpResponse(status=400)
 47        self.now = now()
 48        try:
 49            self.jwt_request = self.validate_request_token(assertion)
 50        except PyJWTError as exc:
 51            LOGGER.warning("failed to parse JWT", exc=exc)
 52            raise ValidationError("Invalid request") from None
 53        version = request.POST.get("platform_sso_version")
 54        grant_type = request.POST.get("grant_type")
 55        handler_func = (
 56            f"handle_v{version}_{grant_type}".replace("-", "_")
 57            .replace("+", "_")
 58            .replace(":", "_")
 59            .replace(".", "_")
 60        )
 61        handler = getattr(self, handler_func, None)
 62        if not handler:
 63            LOGGER.debug("Handler not found", handler=handler_func)
 64            return HttpResponse(status=400)
 65        LOGGER.debug("sending to handler", handler=handler_func)
 66        return handler()
 67
 68    def validate_request_token(self, assertion: str) -> dict[str, Any]:
 69        # Decode without validation to get header
 70        header = get_unverified_header(assertion)
 71        LOGGER.debug("token header", header=header)
 72        expected_kid = header["kid"]
 73
 74        self.device_connection = (
 75            AgentDeviceConnection.objects.filter(apple_sign_key_id=expected_kid)
 76            .select_related("device")
 77            .first()
 78        )
 79        self.connector = AgentConnector.objects.get(pk=self.device_connection.connector.pk)
 80        LOGGER.debug("got device", device=self.device_connection.device)
 81
 82        expected_aud = self.request.build_absolute_uri(
 83            reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
 84        )
 85        if not self.device_connection.apple_signing_key:
 86            LOGGER.warning("Failed to issue token for device, no apple_signing_key")
 87            raise ValidationError("Invalid request")
 88        # Properly decode the JWT with the key from the device
 89        decoded = decode(
 90            assertion,
 91            self.device_connection.apple_signing_key,
 92            algorithms=["ES256"],
 93            audience=expected_aud,
 94            issuer=str(self.connector.pk),
 95        )
 96        self.remote_nonce = decoded.get("nonce")
 97
 98        # Check that the nonce hasn't been used before
 99        nonce = AppleNonce.objects.filter(nonce=decoded["request_nonce"]).first()
100        if not nonce:
101            raise ValidationError("Invalid nonce")
102        self.nonce = nonce
103        nonce.delete()
104        return decoded
105
106    def validate_embedded_assertion(self, assertion: str) -> tuple[AgentDeviceUserBinding, dict]:
107        """Decode an embedded assertion and validate it by looking up the matching device user"""
108        decode_unvalidated = get_unverified_header(assertion)
109        expected_kid = decode_unvalidated["kid"]
110
111        device_user = AgentDeviceUserBinding.objects.filter(
112            target=self.device_connection.device, apple_enclave_key_id=expected_kid
113        ).first()
114        if not device_user:
115            LOGGER.warning("Could not find device user binding for user")
116            raise ValidationError("Invalid request")
117        decoded: dict[str, Any] = decode(
118            assertion,
119            device_user.apple_secure_enclave_key,
120            audience=str(self.device_connection.device.pk),
121            algorithms=["ES256"],
122        )
123        if decoded.get("nonce") != self.jwt_request.get("nonce"):
124            LOGGER.warning("Mis-matched nonce to outer assertion")
125            raise ValidationError("Invalid nonce")
126        return device_user, decoded
127
128    def create_auth_session(self, user: User):
129        event = Event.new(
130            EventAction.LOGIN,
131            app="authentik.endpoints.connectors.agent",
132            **{
133                PLAN_CONTEXT_DEVICE: self.device_connection.device,
134            },
135        ).from_http(self.request, user=user)
136        store = SessionStore()
137        store[SESSION_LOGIN_EVENT] = event
138        store.save()
139        session = Session.objects.filter(session_key=store.session_key).first()
140        session.expires = self.now + timedelta_from_string(self.connector.auth_session_duration)
141        AuthenticatedSession.objects.create(session=session, user=user)
142        session = SessionMiddleware.encode_session(store.session_key, user)
143        return session
144
145    def create_id_token(self, user: User, **kwargs):
146        issuer = self.request.build_absolute_uri(
147            reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
148        )
149        id_token = IDToken(
150            iss=issuer,
151            sub=user.username,
152            aud=str(self.connector.pk),
153            exp=int(
154                (self.now + timedelta_from_string(self.connector.auth_session_duration)).timestamp()
155            ),
156            iat=int(now().timestamp()),
157            **kwargs,
158        )
159        kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
160        return encode(
161            id_token.to_dict(),
162            kp.private_key,
163            headers={
164                "kid": kp.kid,
165            },
166            algorithm=JWTAlgorithms.from_private_key(kp.private_key),
167        )
168
169    def handle_v1_0_urn_ietf_params_oauth_grant_type_jwt_bearer(self):
170        try:
171            user, inner = self.validate_embedded_assertion(self.jwt_request["assertion"])
172        except PyJWTError as exc:
173            LOGGER.warning("failed to validate inner assertion", exc=exc)
174            raise ValidationError("Invalid request") from None
175        id_token = self.create_id_token(user.user)
176        auth_token = DeviceAuthenticationToken.objects.create(
177            device=self.device_connection.device,
178            connector=self.connector,
179            user=user.user,
180            device_token=self.nonce.device_token,
181        )
182        return JWEResponse(
183            {
184                "refresh_token": auth_token.token,
185                "refresh_token_expires_in": int((auth_token.expires - now()).total_seconds()),
186                "id_token": id_token,
187                "token_type": TOKEN_TYPE,
188                "session_key": self.create_auth_session(user.user),
189            },
190            device=self.device_connection,
191            apv=self.jwt_request["jwe_crypto"]["apv"],
192        )
LOGGER = <BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None, context_class=None, initial_values={}, logger_factory_args=())>
@method_decorator(csrf_exempt, name='dispatch')
class TokenView(django.views.generic.base.View):
 38@method_decorator(csrf_exempt, name="dispatch")
 39class TokenView(View):
 40
 41    device_connection: AgentDeviceConnection
 42    connector: AgentConnector
 43
 44    def post(self, request: HttpRequest) -> HttpResponse:
 45        assertion = request.POST.get("assertion", request.POST.get("request"))
 46        if not assertion:
 47            return HttpResponse(status=400)
 48        self.now = now()
 49        try:
 50            self.jwt_request = self.validate_request_token(assertion)
 51        except PyJWTError as exc:
 52            LOGGER.warning("failed to parse JWT", exc=exc)
 53            raise ValidationError("Invalid request") from None
 54        version = request.POST.get("platform_sso_version")
 55        grant_type = request.POST.get("grant_type")
 56        handler_func = (
 57            f"handle_v{version}_{grant_type}".replace("-", "_")
 58            .replace("+", "_")
 59            .replace(":", "_")
 60            .replace(".", "_")
 61        )
 62        handler = getattr(self, handler_func, None)
 63        if not handler:
 64            LOGGER.debug("Handler not found", handler=handler_func)
 65            return HttpResponse(status=400)
 66        LOGGER.debug("sending to handler", handler=handler_func)
 67        return handler()
 68
 69    def validate_request_token(self, assertion: str) -> dict[str, Any]:
 70        # Decode without validation to get header
 71        header = get_unverified_header(assertion)
 72        LOGGER.debug("token header", header=header)
 73        expected_kid = header["kid"]
 74
 75        self.device_connection = (
 76            AgentDeviceConnection.objects.filter(apple_sign_key_id=expected_kid)
 77            .select_related("device")
 78            .first()
 79        )
 80        self.connector = AgentConnector.objects.get(pk=self.device_connection.connector.pk)
 81        LOGGER.debug("got device", device=self.device_connection.device)
 82
 83        expected_aud = self.request.build_absolute_uri(
 84            reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
 85        )
 86        if not self.device_connection.apple_signing_key:
 87            LOGGER.warning("Failed to issue token for device, no apple_signing_key")
 88            raise ValidationError("Invalid request")
 89        # Properly decode the JWT with the key from the device
 90        decoded = decode(
 91            assertion,
 92            self.device_connection.apple_signing_key,
 93            algorithms=["ES256"],
 94            audience=expected_aud,
 95            issuer=str(self.connector.pk),
 96        )
 97        self.remote_nonce = decoded.get("nonce")
 98
 99        # Check that the nonce hasn't been used before
100        nonce = AppleNonce.objects.filter(nonce=decoded["request_nonce"]).first()
101        if not nonce:
102            raise ValidationError("Invalid nonce")
103        self.nonce = nonce
104        nonce.delete()
105        return decoded
106
107    def validate_embedded_assertion(self, assertion: str) -> tuple[AgentDeviceUserBinding, dict]:
108        """Decode an embedded assertion and validate it by looking up the matching device user"""
109        decode_unvalidated = get_unverified_header(assertion)
110        expected_kid = decode_unvalidated["kid"]
111
112        device_user = AgentDeviceUserBinding.objects.filter(
113            target=self.device_connection.device, apple_enclave_key_id=expected_kid
114        ).first()
115        if not device_user:
116            LOGGER.warning("Could not find device user binding for user")
117            raise ValidationError("Invalid request")
118        decoded: dict[str, Any] = decode(
119            assertion,
120            device_user.apple_secure_enclave_key,
121            audience=str(self.device_connection.device.pk),
122            algorithms=["ES256"],
123        )
124        if decoded.get("nonce") != self.jwt_request.get("nonce"):
125            LOGGER.warning("Mis-matched nonce to outer assertion")
126            raise ValidationError("Invalid nonce")
127        return device_user, decoded
128
129    def create_auth_session(self, user: User):
130        event = Event.new(
131            EventAction.LOGIN,
132            app="authentik.endpoints.connectors.agent",
133            **{
134                PLAN_CONTEXT_DEVICE: self.device_connection.device,
135            },
136        ).from_http(self.request, user=user)
137        store = SessionStore()
138        store[SESSION_LOGIN_EVENT] = event
139        store.save()
140        session = Session.objects.filter(session_key=store.session_key).first()
141        session.expires = self.now + timedelta_from_string(self.connector.auth_session_duration)
142        AuthenticatedSession.objects.create(session=session, user=user)
143        session = SessionMiddleware.encode_session(store.session_key, user)
144        return session
145
146    def create_id_token(self, user: User, **kwargs):
147        issuer = self.request.build_absolute_uri(
148            reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
149        )
150        id_token = IDToken(
151            iss=issuer,
152            sub=user.username,
153            aud=str(self.connector.pk),
154            exp=int(
155                (self.now + timedelta_from_string(self.connector.auth_session_duration)).timestamp()
156            ),
157            iat=int(now().timestamp()),
158            **kwargs,
159        )
160        kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
161        return encode(
162            id_token.to_dict(),
163            kp.private_key,
164            headers={
165                "kid": kp.kid,
166            },
167            algorithm=JWTAlgorithms.from_private_key(kp.private_key),
168        )
169
170    def handle_v1_0_urn_ietf_params_oauth_grant_type_jwt_bearer(self):
171        try:
172            user, inner = self.validate_embedded_assertion(self.jwt_request["assertion"])
173        except PyJWTError as exc:
174            LOGGER.warning("failed to validate inner assertion", exc=exc)
175            raise ValidationError("Invalid request") from None
176        id_token = self.create_id_token(user.user)
177        auth_token = DeviceAuthenticationToken.objects.create(
178            device=self.device_connection.device,
179            connector=self.connector,
180            user=user.user,
181            device_token=self.nonce.device_token,
182        )
183        return JWEResponse(
184            {
185                "refresh_token": auth_token.token,
186                "refresh_token_expires_in": int((auth_token.expires - now()).total_seconds()),
187                "id_token": id_token,
188                "token_type": TOKEN_TYPE,
189                "session_key": self.create_auth_session(user.user),
190            },
191            device=self.device_connection,
192            apv=self.jwt_request["jwe_crypto"]["apv"],
193        )

Intentionally simple parent class for all views. Only implements dispatch-by-method and simple sanity checking.

def post( self, request: django.http.request.HttpRequest) -> django.http.response.HttpResponse:
44    def post(self, request: HttpRequest) -> HttpResponse:
45        assertion = request.POST.get("assertion", request.POST.get("request"))
46        if not assertion:
47            return HttpResponse(status=400)
48        self.now = now()
49        try:
50            self.jwt_request = self.validate_request_token(assertion)
51        except PyJWTError as exc:
52            LOGGER.warning("failed to parse JWT", exc=exc)
53            raise ValidationError("Invalid request") from None
54        version = request.POST.get("platform_sso_version")
55        grant_type = request.POST.get("grant_type")
56        handler_func = (
57            f"handle_v{version}_{grant_type}".replace("-", "_")
58            .replace("+", "_")
59            .replace(":", "_")
60            .replace(".", "_")
61        )
62        handler = getattr(self, handler_func, None)
63        if not handler:
64            LOGGER.debug("Handler not found", handler=handler_func)
65            return HttpResponse(status=400)
66        LOGGER.debug("sending to handler", handler=handler_func)
67        return handler()
def validate_request_token(self, assertion: str) -> dict[str, typing.Any]:
 69    def validate_request_token(self, assertion: str) -> dict[str, Any]:
 70        # Decode without validation to get header
 71        header = get_unverified_header(assertion)
 72        LOGGER.debug("token header", header=header)
 73        expected_kid = header["kid"]
 74
 75        self.device_connection = (
 76            AgentDeviceConnection.objects.filter(apple_sign_key_id=expected_kid)
 77            .select_related("device")
 78            .first()
 79        )
 80        self.connector = AgentConnector.objects.get(pk=self.device_connection.connector.pk)
 81        LOGGER.debug("got device", device=self.device_connection.device)
 82
 83        expected_aud = self.request.build_absolute_uri(
 84            reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
 85        )
 86        if not self.device_connection.apple_signing_key:
 87            LOGGER.warning("Failed to issue token for device, no apple_signing_key")
 88            raise ValidationError("Invalid request")
 89        # Properly decode the JWT with the key from the device
 90        decoded = decode(
 91            assertion,
 92            self.device_connection.apple_signing_key,
 93            algorithms=["ES256"],
 94            audience=expected_aud,
 95            issuer=str(self.connector.pk),
 96        )
 97        self.remote_nonce = decoded.get("nonce")
 98
 99        # Check that the nonce hasn't been used before
100        nonce = AppleNonce.objects.filter(nonce=decoded["request_nonce"]).first()
101        if not nonce:
102            raise ValidationError("Invalid nonce")
103        self.nonce = nonce
104        nonce.delete()
105        return decoded
def validate_embedded_assertion( self, assertion: str) -> tuple[authentik.endpoints.connectors.agent.models.AgentDeviceUserBinding, dict]:
107    def validate_embedded_assertion(self, assertion: str) -> tuple[AgentDeviceUserBinding, dict]:
108        """Decode an embedded assertion and validate it by looking up the matching device user"""
109        decode_unvalidated = get_unverified_header(assertion)
110        expected_kid = decode_unvalidated["kid"]
111
112        device_user = AgentDeviceUserBinding.objects.filter(
113            target=self.device_connection.device, apple_enclave_key_id=expected_kid
114        ).first()
115        if not device_user:
116            LOGGER.warning("Could not find device user binding for user")
117            raise ValidationError("Invalid request")
118        decoded: dict[str, Any] = decode(
119            assertion,
120            device_user.apple_secure_enclave_key,
121            audience=str(self.device_connection.device.pk),
122            algorithms=["ES256"],
123        )
124        if decoded.get("nonce") != self.jwt_request.get("nonce"):
125            LOGGER.warning("Mis-matched nonce to outer assertion")
126            raise ValidationError("Invalid nonce")
127        return device_user, decoded

Decode an embedded assertion and validate it by looking up the matching device user

def create_auth_session(self, user: authentik.core.models.User):
129    def create_auth_session(self, user: User):
130        event = Event.new(
131            EventAction.LOGIN,
132            app="authentik.endpoints.connectors.agent",
133            **{
134                PLAN_CONTEXT_DEVICE: self.device_connection.device,
135            },
136        ).from_http(self.request, user=user)
137        store = SessionStore()
138        store[SESSION_LOGIN_EVENT] = event
139        store.save()
140        session = Session.objects.filter(session_key=store.session_key).first()
141        session.expires = self.now + timedelta_from_string(self.connector.auth_session_duration)
142        AuthenticatedSession.objects.create(session=session, user=user)
143        session = SessionMiddleware.encode_session(store.session_key, user)
144        return session
def create_id_token(self, user: authentik.core.models.User, **kwargs):
146    def create_id_token(self, user: User, **kwargs):
147        issuer = self.request.build_absolute_uri(
148            reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
149        )
150        id_token = IDToken(
151            iss=issuer,
152            sub=user.username,
153            aud=str(self.connector.pk),
154            exp=int(
155                (self.now + timedelta_from_string(self.connector.auth_session_duration)).timestamp()
156            ),
157            iat=int(now().timestamp()),
158            **kwargs,
159        )
160        kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
161        return encode(
162            id_token.to_dict(),
163            kp.private_key,
164            headers={
165                "kid": kp.kid,
166            },
167            algorithm=JWTAlgorithms.from_private_key(kp.private_key),
168        )
def handle_v1_0_urn_ietf_params_oauth_grant_type_jwt_bearer(self):
170    def handle_v1_0_urn_ietf_params_oauth_grant_type_jwt_bearer(self):
171        try:
172            user, inner = self.validate_embedded_assertion(self.jwt_request["assertion"])
173        except PyJWTError as exc:
174            LOGGER.warning("failed to validate inner assertion", exc=exc)
175            raise ValidationError("Invalid request") from None
176        id_token = self.create_id_token(user.user)
177        auth_token = DeviceAuthenticationToken.objects.create(
178            device=self.device_connection.device,
179            connector=self.connector,
180            user=user.user,
181            device_token=self.nonce.device_token,
182        )
183        return JWEResponse(
184            {
185                "refresh_token": auth_token.token,
186                "refresh_token_expires_in": int((auth_token.expires - now()).total_seconds()),
187                "id_token": id_token,
188                "token_type": TOKEN_TYPE,
189                "session_key": self.create_auth_session(user.user),
190            },
191            device=self.device_connection,
192            apv=self.jwt_request["jwe_crypto"]["apv"],
193        )
def dispatch(self, request, *args, **kwargs):
135    def dispatch(self, request, *args, **kwargs):
136        # Try to dispatch to the right method; if a method doesn't exist,
137        # defer to the error handler. Also defer to the error handler if the
138        # request method isn't on the approved list.
139        if request.method.lower() in self.http_method_names:
140            handler = getattr(
141                self, request.method.lower(), self.http_method_not_allowed
142            )
143        else:
144            handler = self.http_method_not_allowed
145        return handler(request, *args, **kwargs)