authentik.enterprise.endpoints.connectors.agent.views.apple_token

  1from typing import Any
  2
  3from django.http import HttpRequest, HttpResponse
  4from django.urls import reverse
  5from django.utils.decorators import method_decorator
  6from django.utils.timezone import now
  7from django.views import View
  8from django.views.decorators.csrf import csrf_exempt
  9from jwt import PyJWTError, decode, encode, get_unverified_header
 10from rest_framework.exceptions import ValidationError
 11from structlog.stdlib import get_logger
 12
 13from authentik.common.oauth.constants import TOKEN_TYPE
 14from authentik.core.models import AuthenticatedSession, Session, User
 15from authentik.core.sessions import SessionStore
 16from authentik.crypto.apps import MANAGED_KEY
 17from authentik.crypto.models import CertificateKeyPair
 18from authentik.endpoints.connectors.agent.models import (
 19    AgentConnector,
 20    AgentDeviceConnection,
 21    AgentDeviceUserBinding,
 22    AppleIndependentSecureEnclave,
 23    AppleNonce,
 24    DeviceAuthenticationToken,
 25)
 26from authentik.enterprise.endpoints.connectors.agent.http import JWEResponse
 27from authentik.events.models import Event, EventAction
 28from authentik.events.signals import SESSION_LOGIN_EVENT
 29from authentik.flows.planner import PLAN_CONTEXT_DEVICE
 30from authentik.lib.utils.time import timedelta_from_string
 31from authentik.providers.oauth2.id_token import IDToken
 32from authentik.providers.oauth2.models import JWTAlgorithms
 33from authentik.root.middleware import SessionMiddleware
 34
 35LOGGER = get_logger()
 36
 37
 38@method_decorator(csrf_exempt, name="dispatch")
 39class TokenView(View):
 40
 41    device_connection: AgentDeviceConnection
 42    connector: AgentConnector
 43
 44    def post(self, request: HttpRequest) -> HttpResponse:
 45        assertion = request.POST.get("assertion", request.POST.get("request"))
 46        if not assertion:
 47            return HttpResponse(status=400)
 48        self.now = now()
 49        try:
 50            self.jwt_request = self.validate_request_token(assertion)
 51        except PyJWTError as exc:
 52            LOGGER.warning("failed to parse JWT", exc=exc)
 53            raise ValidationError("Invalid request") from None
 54        version = request.POST.get("platform_sso_version")
 55        grant_type = request.POST.get("grant_type")
 56        handler_func = (
 57            f"handle_v{version}_{grant_type}".replace("-", "_")
 58            .replace("+", "_")
 59            .replace(":", "_")
 60            .replace(".", "_")
 61        )
 62        handler = getattr(self, handler_func, None)
 63        if not handler:
 64            LOGGER.debug("Handler not found", handler=handler_func)
 65            return HttpResponse(status=400)
 66        LOGGER.debug("sending to handler", handler=handler_func)
 67        return handler()
 68
 69    def validate_request_token(self, assertion: str) -> dict[str, Any]:
 70        # Decode without validation to get header
 71        header = get_unverified_header(assertion)
 72        LOGGER.debug("token header", header=header)
 73        expected_kid = header["kid"]
 74
 75        self.device_connection = (
 76            AgentDeviceConnection.objects.filter(apple_sign_key_id=expected_kid)
 77            .select_related("device")
 78            .first()
 79        )
 80        self.connector = AgentConnector.objects.get(pk=self.device_connection.connector.pk)
 81        LOGGER.debug("got device", device=self.device_connection.device)
 82
 83        expected_aud = self.request.build_absolute_uri(
 84            reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
 85        )
 86        if not self.device_connection.apple_signing_key:
 87            LOGGER.warning("Failed to issue token for device, no apple_signing_key")
 88            raise ValidationError("Invalid request")
 89        # Properly decode the JWT with the key from the device
 90        decoded = decode(
 91            assertion,
 92            self.device_connection.apple_signing_key,
 93            algorithms=["ES256"],
 94            audience=expected_aud,
 95            issuer=str(self.connector.pk),
 96        )
 97        self.remote_nonce = decoded.get("nonce")
 98
 99        # Check that the nonce hasn't been used before
100        nonce = AppleNonce.objects.filter(nonce=decoded["request_nonce"]).first()
101        if not nonce:
102            raise ValidationError("Invalid nonce")
103        self.nonce = nonce
104        nonce.delete()
105        return decoded
106
107    def validate_embedded_assertion(
108        self, assertion: str
109    ) -> tuple[AgentDeviceUserBinding | AppleIndependentSecureEnclave, dict]:
110        """Decode an embedded assertion and validate it by looking up the matching device user"""
111        decode_unvalidated = get_unverified_header(assertion)
112        expected_kid = decode_unvalidated["kid"]
113
114        device_user = AgentDeviceUserBinding.objects.filter(
115            target=self.device_connection.device, apple_enclave_key_id=expected_kid
116        ).first()
117        if not device_user:
118            independent_user = AppleIndependentSecureEnclave.objects.filter(
119                apple_enclave_key_id=expected_kid
120            ).first()
121            if not independent_user:
122                LOGGER.warning("Could not find device user binding or independent enclave for user")
123                raise ValidationError("Invalid request")
124            device_user = independent_user
125        decoded: dict[str, Any] = decode(
126            assertion,
127            device_user.apple_secure_enclave_key,
128            audience=str(self.device_connection.device.pk),
129            algorithms=["ES256"],
130        )
131        if decoded.get("nonce") != self.jwt_request.get("nonce"):
132            LOGGER.warning("Mis-matched nonce to outer assertion")
133            raise ValidationError("Invalid nonce")
134        return device_user, decoded
135
136    def create_auth_session(self, user: User):
137        event = Event.new(
138            EventAction.LOGIN,
139            app="authentik.endpoints.connectors.agent",
140            **{
141                PLAN_CONTEXT_DEVICE: self.device_connection.device,
142            },
143        ).from_http(self.request, user=user)
144        store = SessionStore()
145        store[SESSION_LOGIN_EVENT] = event
146        store.save()
147        session = Session.objects.filter(session_key=store.session_key).first()
148        session.expires = self.now + timedelta_from_string(self.connector.auth_session_duration)
149        AuthenticatedSession.objects.create(session=session, user=user)
150        session = SessionMiddleware.encode_session(store.session_key, user)
151        return session
152
153    def create_id_token(self, user: User, **kwargs):
154        issuer = self.request.build_absolute_uri(
155            reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
156        )
157        id_token = IDToken(
158            iss=issuer,
159            sub=user.username,
160            aud=str(self.connector.pk),
161            exp=int(
162                (self.now + timedelta_from_string(self.connector.auth_session_duration)).timestamp()
163            ),
164            iat=int(now().timestamp()),
165            **kwargs,
166        )
167        kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
168        return encode(
169            id_token.to_dict(),
170            kp.private_key,
171            headers={
172                "kid": kp.kid,
173            },
174            algorithm=JWTAlgorithms.from_private_key(kp.private_key),
175        )
176
177    def handle_v1_0_urn_ietf_params_oauth_grant_type_jwt_bearer(self):
178        try:
179            user, inner = self.validate_embedded_assertion(self.jwt_request["assertion"])
180        except PyJWTError as exc:
181            LOGGER.warning("failed to validate inner assertion", exc=exc)
182            raise ValidationError("Invalid request") from None
183        id_token = self.create_id_token(user.user)
184        auth_token = DeviceAuthenticationToken.objects.create(
185            device=self.device_connection.device,
186            connector=self.connector,
187            user=user.user,
188            device_token=self.nonce.device_token,
189        )
190        return JWEResponse(
191            {
192                "refresh_token": auth_token.token,
193                "refresh_token_expires_in": int((auth_token.expires - now()).total_seconds()),
194                "id_token": id_token,
195                "token_type": TOKEN_TYPE,
196                "session_key": self.create_auth_session(user.user),
197            },
198            device=self.device_connection,
199            apv=self.jwt_request["jwe_crypto"]["apv"],
200        )
LOGGER = get_logger()


@method_decorator(csrf_exempt, name="dispatch")
class TokenView(View):
    """Apple Platform SSO token endpoint of the agent connector.

    ``post`` verifies the device-signed request assertion and dispatches it to a
    ``handle_v<platform_sso_version>_<grant_type>`` method. CSRF protection is
    disabled for this view.
    """

    # Both are set by validate_request_token from the outer assertion's ``kid``.
    device_connection: AgentDeviceConnection
    connector: AgentConnector

    def post(self, request: HttpRequest) -> HttpResponse:
        """Verify the PSSO request assertion and dispatch it to the matching grant handler.

        Returns HTTP 400 when no assertion is posted or when no handler exists for the
        ``platform_sso_version``/``grant_type`` combination. Raises ``ValidationError``
        when the assertion cannot be verified.
        """
        assertion = request.POST.get("assertion", request.POST.get("request"))
        if not assertion:
            return HttpResponse(status=400)
        self.now = now()
        try:
            self.jwt_request = self.validate_request_token(assertion)
        except PyJWTError as exc:
            LOGGER.warning("failed to parse JWT", exc=exc)
            raise ValidationError("Invalid request") from None
        version = request.POST.get("platform_sso_version")
        grant_type = request.POST.get("grant_type")
        # Turn e.g. "1.0" + "urn:ietf:params:oauth:grant-type:jwt-bearer" into a method
        # name; the fixed "handle_v" prefix limits which attributes are reachable.
        handler_func = f"handle_v{version}_{grant_type}".translate(str.maketrans("-+:.", "____"))
        handler = getattr(self, handler_func, None)
        if not handler:
            LOGGER.debug("Handler not found", handler=handler_func)
            return HttpResponse(status=400)
        LOGGER.debug("sending to handler", handler=handler_func)
        return handler()

    def validate_request_token(self, assertion: str) -> dict[str, Any]:
        """Verify the outer request assertion and consume its one-time nonce.

        The device connection is looked up by the token's ``kid`` header, and the ES256
        signature is verified with that device's signing key. The audience must be this
        endpoint's absolute URL and the issuer the connector's pk. The ``AppleNonce``
        referenced by ``request_nonce`` is then deleted. Sets ``device_connection``,
        ``connector``, ``remote_nonce`` and ``nonce``.

        Raises:
            PyJWTError: the token is malformed or fails verification.
            ValidationError: unknown key/device, missing signing key, or a missing,
                unknown or already consumed nonce.
        """
        # The header is read unverified only to learn which key to verify with.
        header = get_unverified_header(assertion)
        LOGGER.debug("token header", header=header)
        expected_kid = header.get("kid")
        if not expected_kid:
            LOGGER.warning("Request token has no kid header")
            raise ValidationError("Invalid request")

        device_connection = (
            AgentDeviceConnection.objects.filter(apple_sign_key_id=expected_kid)
            .select_related("device")
            .first()
        )
        if not device_connection:
            LOGGER.warning("No device connection for signing key", kid=expected_kid)
            raise ValidationError("Invalid request")
        self.device_connection = device_connection
        connector = AgentConnector.objects.filter(pk=device_connection.connector.pk).first()
        if not connector:
            LOGGER.warning("Device connection does not belong to an agent connector")
            raise ValidationError("Invalid request")
        self.connector = connector
        LOGGER.debug("got device", device=device_connection.device)

        if not device_connection.apple_signing_key:
            LOGGER.warning("Failed to issue token for device, no apple_signing_key")
            raise ValidationError("Invalid request")
        expected_aud = self.request.build_absolute_uri(
            reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
        )
        # Properly decode the JWT with the key from the device
        decoded = decode(
            assertion,
            device_connection.apple_signing_key,
            algorithms=["ES256"],
            audience=expected_aud,
            issuer=str(connector.pk),
        )
        self.remote_nonce = decoded.get("nonce")

        # Check that the nonce hasn't been used before
        request_nonce = decoded.get("request_nonce")
        if not request_nonce:
            raise ValidationError("Invalid nonce")
        nonce = AppleNonce.objects.filter(nonce=request_nonce).first()
        if not nonce:
            raise ValidationError("Invalid nonce")
        # Deleting consumes the nonce. A zero delete count means another request
        # removed it between the lookup and here, so treat this one as a replay.
        deleted, _ = nonce.delete()
        if not deleted:
            raise ValidationError("Invalid nonce")
        self.nonce = nonce
        return decoded

    def validate_embedded_assertion(
        self, assertion: str
    ) -> tuple[AgentDeviceUserBinding | AppleIndependentSecureEnclave, dict]:
        """Decode an embedded assertion and validate it by looking up the matching device user

        The verification key is found via the assertion's ``kid`` header. Lookup tries this
        device's user bindings first, then independent secure enclaves. The assertion must
        be ES256-signed with the device's pk as audience, and it must carry the same
        (non-empty) nonce as the outer request.

        Raises:
            PyJWTError: signature or claim verification failed.
            ValidationError: no matching key, or the nonce is missing or differs.
        """
        expected_kid = get_unverified_header(assertion).get("kid")
        if not expected_kid:
            LOGGER.warning("Embedded assertion has no kid header")
            raise ValidationError("Invalid request")

        device_user = AgentDeviceUserBinding.objects.filter(
            target=self.device_connection.device, apple_enclave_key_id=expected_kid
        ).first()
        if not device_user:
            device_user = AppleIndependentSecureEnclave.objects.filter(
                apple_enclave_key_id=expected_kid
            ).first()
        if not device_user:
            LOGGER.warning("Could not find device user binding or independent enclave for user")
            raise ValidationError("Invalid request")
        decoded: dict[str, Any] = decode(
            assertion,
            device_user.apple_secure_enclave_key,
            audience=str(self.device_connection.device.pk),
            algorithms=["ES256"],
        )
        expected_nonce = self.jwt_request.get("nonce")
        # A nonce missing on both sides must not count as a match.
        if not expected_nonce or decoded.get("nonce") != expected_nonce:
            LOGGER.warning("Mis-matched nonce to outer assertion")
            raise ValidationError("Invalid nonce")
        return device_user, decoded

    def create_auth_session(self, user: User):
        """Create a server-side session for ``user`` and return its encoded session key.

        A LOGIN event carrying the device is built from the request and stored in a new
        session. The session's expiry is set to ``self.now`` plus the connector's
        ``auth_session_duration`` and persisted. An ``AuthenticatedSession`` then links
        the session to ``user``.
        """
        event = Event.new(
            EventAction.LOGIN,
            app="authentik.endpoints.connectors.agent",
            **{
                PLAN_CONTEXT_DEVICE: self.device_connection.device,
            },
        ).from_http(self.request, user=user)
        store = SessionStore()
        store[SESSION_LOGIN_EVENT] = event
        store.save()
        session = Session.objects.get(session_key=store.session_key)
        session.expires = self.now + timedelta_from_string(self.connector.auth_session_duration)
        # Assigning the attribute only changes the in-memory instance; save it so the
        # connector-specific expiry is actually stored.
        session.save()
        AuthenticatedSession.objects.create(session=session, user=user)
        return SessionMiddleware.encode_session(store.session_key, user)

    def create_id_token(self, user: User, **kwargs):
        """Build an ID token for ``user`` and sign it with the managed key pair.

        ``iss`` is this endpoint's absolute URL and ``aud`` is the connector's pk. ``iat``
        and ``exp`` both derive from the request timestamp ``self.now``. Extra claims can
        be passed through ``kwargs``.
        """
        issuer = self.request.build_absolute_uri(
            reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
        )
        id_token = IDToken(
            iss=issuer,
            sub=user.username,
            aud=str(self.connector.pk),
            exp=int(
                (self.now + timedelta_from_string(self.connector.auth_session_duration)).timestamp()
            ),
            iat=int(self.now.timestamp()),
            **kwargs,
        )
        kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
        return encode(
            id_token.to_dict(),
            kp.private_key,
            headers={
                "kid": kp.kid,
            },
            algorithm=JWTAlgorithms.from_private_key(kp.private_key),
        )

    def handle_v1_0_urn_ietf_params_oauth_grant_type_jwt_bearer(self):
        """Handle the v1.0 ``urn:ietf:params:oauth:grant-type:jwt-bearer`` grant.

        Verifies the embedded user assertion. Returns a ``JWEResponse`` holding a refresh
        token, an ID token and an encoded session key for the user.

        Raises ``ValidationError`` when the embedded assertion or ``jwe_crypto.apv`` is
        missing, or when the embedded assertion does not verify.
        """
        embedded_assertion = self.jwt_request.get("assertion")
        jwe_crypto = self.jwt_request.get("jwe_crypto")
        apv = jwe_crypto.get("apv") if isinstance(jwe_crypto, dict) else None
        # Check everything the response needs before any token or session rows are
        # written, so a malformed request cannot leave them behind.
        if not embedded_assertion or apv is None:
            LOGGER.warning("Request is missing the embedded assertion or jwe_crypto.apv")
            raise ValidationError("Invalid request")
        try:
            binding, _ = self.validate_embedded_assertion(embedded_assertion)
        except PyJWTError as exc:
            LOGGER.warning("failed to validate inner assertion", exc=exc)
            raise ValidationError("Invalid request") from None
        user = binding.user
        id_token = self.create_id_token(user)
        auth_token = DeviceAuthenticationToken.objects.create(
            device=self.device_connection.device,
            connector=self.connector,
            user=user,
            device_token=self.nonce.device_token,
        )
        return JWEResponse(
            {
                "refresh_token": auth_token.token,
                "refresh_token_expires_in": int((auth_token.expires - now()).total_seconds()),
                "id_token": id_token,
                "token_type": TOKEN_TYPE,
                "session_key": self.create_auth_session(user),
            },
            device=self.device_connection,
            apv=apv,
        )

Intentionally simple parent class for all views. Only implements dispatch-by-method and simple sanity checking.

def post(self, request: HttpRequest) -> HttpResponse:
    """Handle a Platform SSO token request.

    The assertion is read from ``assertion``, falling back to ``request``, and then
    verified. The request is dispatched to
    ``handle_v<platform_sso_version>_<grant_type>``, with ``-``, ``+``, ``:`` and ``.``
    mapped to ``_``. Responds 400 when the assertion is absent or no such handler exists.
    """
    params = request.POST
    assertion = params.get("assertion", params.get("request"))
    if not assertion:
        return HttpResponse(status=400)
    self.now = now()
    try:
        self.jwt_request = self.validate_request_token(assertion)
    except PyJWTError as exc:
        LOGGER.warning("failed to parse JWT", exc=exc)
        raise ValidationError("Invalid request") from None
    version, grant_type = params.get("platform_sso_version"), params.get("grant_type")
    handler_func = f"handle_v{version}_{grant_type}".translate(str.maketrans("-+:.", "____"))
    handler = getattr(self, handler_func, None)
    if handler:
        LOGGER.debug("sending to handler", handler=handler_func)
        return handler()
    LOGGER.debug("Handler not found", handler=handler_func)
    return HttpResponse(status=400)
def validate_request_token(self, assertion: str) -> dict[str, Any]:
    """Verify the outer request assertion and consume its one-time nonce.

    The device connection is looked up by the token's ``kid`` header, and the ES256
    signature is verified with that device's signing key. The audience must be this
    endpoint's absolute URL and the issuer the connector's pk. The ``AppleNonce``
    referenced by ``request_nonce`` is then deleted. Sets ``device_connection``,
    ``connector``, ``remote_nonce`` and ``nonce``.

    Raises:
        PyJWTError: the token is malformed or fails verification.
        ValidationError: unknown key/device, missing signing key, or a missing,
            unknown or already consumed nonce.
    """
    # The header is read unverified only to learn which key to verify with.
    header = get_unverified_header(assertion)
    LOGGER.debug("token header", header=header)
    expected_kid = header.get("kid")
    if not expected_kid:
        LOGGER.warning("Request token has no kid header")
        raise ValidationError("Invalid request")

    device_connection = (
        AgentDeviceConnection.objects.filter(apple_sign_key_id=expected_kid)
        .select_related("device")
        .first()
    )
    if not device_connection:
        LOGGER.warning("No device connection for signing key", kid=expected_kid)
        raise ValidationError("Invalid request")
    self.device_connection = device_connection
    connector = AgentConnector.objects.filter(pk=device_connection.connector.pk).first()
    if not connector:
        LOGGER.warning("Device connection does not belong to an agent connector")
        raise ValidationError("Invalid request")
    self.connector = connector
    LOGGER.debug("got device", device=device_connection.device)

    if not device_connection.apple_signing_key:
        LOGGER.warning("Failed to issue token for device, no apple_signing_key")
        raise ValidationError("Invalid request")
    expected_aud = self.request.build_absolute_uri(
        reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
    )
    # Properly decode the JWT with the key from the device
    decoded = decode(
        assertion,
        device_connection.apple_signing_key,
        algorithms=["ES256"],
        audience=expected_aud,
        issuer=str(connector.pk),
    )
    self.remote_nonce = decoded.get("nonce")

    # Check that the nonce hasn't been used before
    request_nonce = decoded.get("request_nonce")
    if not request_nonce:
        raise ValidationError("Invalid nonce")
    nonce = AppleNonce.objects.filter(nonce=request_nonce).first()
    if not nonce:
        raise ValidationError("Invalid nonce")
    # Deleting consumes the nonce. A zero delete count means another request
    # removed it between the lookup and here, so treat this one as a replay.
    deleted, _ = nonce.delete()
    if not deleted:
        raise ValidationError("Invalid nonce")
    self.nonce = nonce
    return decoded
def validate_embedded_assertion(
    self, assertion: str
) -> tuple[AgentDeviceUserBinding | AppleIndependentSecureEnclave, dict]:
    """Decode an embedded assertion and validate it by looking up the matching device user

    The verification key is found via the assertion's ``kid`` header. Lookup tries this
    device's user bindings first, then independent secure enclaves. The assertion must
    be ES256-signed with the device's pk as audience, and it must carry the same
    (non-empty) nonce as the outer request.

    Raises:
        PyJWTError: signature or claim verification failed.
        ValidationError: no matching key, or the nonce is missing or differs.
    """
    expected_kid = get_unverified_header(assertion).get("kid")
    if not expected_kid:
        LOGGER.warning("Embedded assertion has no kid header")
        raise ValidationError("Invalid request")

    device_user = AgentDeviceUserBinding.objects.filter(
        target=self.device_connection.device, apple_enclave_key_id=expected_kid
    ).first()
    if not device_user:
        device_user = AppleIndependentSecureEnclave.objects.filter(
            apple_enclave_key_id=expected_kid
        ).first()
    if not device_user:
        LOGGER.warning("Could not find device user binding or independent enclave for user")
        raise ValidationError("Invalid request")
    decoded: dict[str, Any] = decode(
        assertion,
        device_user.apple_secure_enclave_key,
        audience=str(self.device_connection.device.pk),
        algorithms=["ES256"],
    )
    expected_nonce = self.jwt_request.get("nonce")
    # A nonce missing on both sides must not count as a match.
    if not expected_nonce or decoded.get("nonce") != expected_nonce:
        LOGGER.warning("Mis-matched nonce to outer assertion")
        raise ValidationError("Invalid nonce")
    return device_user, decoded

Decode an embedded assertion and validate it by looking up the matching device user

def create_auth_session(self, user: User):
    """Create a server-side session for ``user`` and return its encoded session key.

    A LOGIN event carrying the device is built from the request and stored in a new
    session. The session's expiry is set to ``self.now`` plus the connector's
    ``auth_session_duration`` and persisted. An ``AuthenticatedSession`` then links
    the session to ``user``.
    """
    event = Event.new(
        EventAction.LOGIN,
        app="authentik.endpoints.connectors.agent",
        **{
            PLAN_CONTEXT_DEVICE: self.device_connection.device,
        },
    ).from_http(self.request, user=user)
    store = SessionStore()
    store[SESSION_LOGIN_EVENT] = event
    store.save()
    session = Session.objects.get(session_key=store.session_key)
    session.expires = self.now + timedelta_from_string(self.connector.auth_session_duration)
    # Assigning the attribute only changes the in-memory instance; save it so the
    # connector-specific expiry is actually stored.
    session.save()
    AuthenticatedSession.objects.create(session=session, user=user)
    return SessionMiddleware.encode_session(store.session_key, user)
def create_id_token(self, user: User, **kwargs):
    """Build an ID token for ``user`` and sign it with the managed key pair.

    ``iss`` is this endpoint's absolute URL and ``aud`` is the connector's pk. ``iat``
    and ``exp`` both derive from the request timestamp ``self.now``, so the lifetime is
    exactly ``auth_session_duration``. Extra claims can be passed through ``kwargs``.
    """
    issuer = self.request.build_absolute_uri(
        reverse("authentik_enterprise_endpoints_connectors_agent:psso-token")
    )
    id_token = IDToken(
        iss=issuer,
        sub=user.username,
        aud=str(self.connector.pk),
        exp=int(
            (self.now + timedelta_from_string(self.connector.auth_session_duration)).timestamp()
        ),
        iat=int(self.now.timestamp()),
        **kwargs,
    )
    kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
    return encode(
        id_token.to_dict(),
        kp.private_key,
        headers={
            "kid": kp.kid,
        },
        algorithm=JWTAlgorithms.from_private_key(kp.private_key),
    )
def handle_v1_0_urn_ietf_params_oauth_grant_type_jwt_bearer(self):
    """Handle the v1.0 ``urn:ietf:params:oauth:grant-type:jwt-bearer`` grant.

    Verifies the embedded user assertion. Returns a ``JWEResponse`` holding a refresh
    token, an ID token and an encoded session key for the user.

    Raises ``ValidationError`` when the embedded assertion or ``jwe_crypto.apv`` is
    missing, or when the embedded assertion does not verify.
    """
    embedded_assertion = self.jwt_request.get("assertion")
    jwe_crypto = self.jwt_request.get("jwe_crypto")
    apv = jwe_crypto.get("apv") if isinstance(jwe_crypto, dict) else None
    # Check everything the response needs before any token or session rows are
    # written, so a malformed request cannot leave them behind.
    if not embedded_assertion or apv is None:
        LOGGER.warning("Request is missing the embedded assertion or jwe_crypto.apv")
        raise ValidationError("Invalid request")
    try:
        binding, _ = self.validate_embedded_assertion(embedded_assertion)
    except PyJWTError as exc:
        LOGGER.warning("failed to validate inner assertion", exc=exc)
        raise ValidationError("Invalid request") from None
    user = binding.user
    id_token = self.create_id_token(user)
    auth_token = DeviceAuthenticationToken.objects.create(
        device=self.device_connection.device,
        connector=self.connector,
        user=user,
        device_token=self.nonce.device_token,
    )
    return JWEResponse(
        {
            "refresh_token": auth_token.token,
            "refresh_token_expires_in": int((auth_token.expires - now()).total_seconds()),
            "id_token": id_token,
            "token_type": TOKEN_TYPE,
            "session_key": self.create_auth_session(user),
        },
        device=self.device_connection,
        apv=apv,
    )
def dispatch(self, request, *args, **kwargs):
    """Route ``request`` to the handler method named after its lower-cased HTTP verb.

    Verbs not listed in ``http_method_names`` are sent to ``http_method_not_allowed``.
    So are listed verbs that have no matching method.
    """
    verb = request.method.lower()
    fallback = self.http_method_not_allowed
    handler = getattr(self, verb, fallback) if verb in self.http_method_names else fallback
    return handler(request, *args, **kwargs)