authentik.endpoints.connectors.agent.auth

  1from typing import Any
  2
  3from django.http import HttpRequest
  4from django.utils.timezone import now
  5from drf_spectacular.extensions import OpenApiAuthenticationExtension
  6from jwt import PyJWTError, decode, encode
  7from rest_framework.authentication import BaseAuthentication, get_authorization_header
  8from rest_framework.exceptions import PermissionDenied
  9from rest_framework.request import Request
 10from structlog.stdlib import get_logger
 11
 12from authentik.api.authentication import IPCUser, validate_auth
 13from authentik.core.middleware import CTX_AUTH_VIA
 14from authentik.core.models import User
 15from authentik.crypto.apps import MANAGED_KEY
 16from authentik.crypto.models import CertificateKeyPair
 17from authentik.endpoints.connectors.agent.models import AgentConnector, DeviceToken, EnrollmentToken
 18from authentik.endpoints.models import Device
 19from authentik.lib.utils.time import timedelta_from_string
 20from authentik.policies.engine import PolicyEngine
 21from authentik.policies.models import PolicyBindingModel
 22from authentik.providers.oauth2.models import AccessToken, JWTAlgorithms, OAuth2Provider
 23
 24LOGGER = get_logger()
 25PLATFORM_ISSUER = "goauthentik.io/platform"
 26
 27
 28class DeviceUser(IPCUser):
 29    username = "authentik:endpoints:device"
 30
 31
 32class AgentEnrollmentAuth(BaseAuthentication):
 33
 34    def authenticate(self, request: Request) -> tuple[User, Any] | None:
 35        auth = get_authorization_header(request)
 36        key = validate_auth(auth)
 37        token = EnrollmentToken.objects.filter(key=key).first()
 38        if not token:
 39            raise PermissionDenied()
 40        if not token.connector.enabled:
 41            raise PermissionDenied()
 42        CTX_AUTH_VIA.set("endpoint_token_enrollment")
 43        return (DeviceUser(), token)
 44
 45
 46class AgentAuth(BaseAuthentication):
 47
 48    def authenticate(self, request: Request) -> tuple[User, Any] | None:
 49        auth = get_authorization_header(request)
 50        key = validate_auth(auth, format="bearer+agent")
 51        if not key:
 52            return None
 53        device_token = DeviceToken.objects.filter(key=key).first()
 54        if not device_token:
 55            raise PermissionDenied()
 56        if not device_token.device.connector.enabled:
 57            raise PermissionDenied()
 58        if device_token.device.device.is_expired:
 59            raise PermissionDenied()
 60        CTX_AUTH_VIA.set("endpoint_token")
 61        return (DeviceUser(), device_token)
 62
 63
 64def agent_auth_issue_token(device: Device, connector: AgentConnector, user: User, **kwargs):
 65    kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
 66    if not kp:
 67        return None, None
 68    exp = now() + timedelta_from_string(connector.auth_session_duration)
 69    token = encode(
 70        {
 71            "iss": PLATFORM_ISSUER,
 72            "aud": str(device.pk),
 73            "iat": int(now().timestamp()),
 74            "exp": int(exp.timestamp()),
 75            "preferred_username": user.username,
 76            **kwargs,
 77        },
 78        kp.private_key,
 79        headers={
 80            "kid": kp.kid,
 81        },
 82        algorithm=JWTAlgorithms.from_private_key(kp.private_key),
 83    )
 84    return token, exp
 85
 86
 87class DeviceAuthFedAuthentication(BaseAuthentication):
 88
 89    def authenticate(self, request):
 90        raw_token = validate_auth(get_authorization_header(request))
 91        if not raw_token:
 92            LOGGER.warning("Missing token")
 93            return None
 94        device = Device.objects.filter(name=request.query_params.get("device")).first()
 95        if not device:
 96            LOGGER.warning("Couldn't find device")
 97            return None
 98        connectors_for_device = AgentConnector.objects.filter(device__in=[device])
 99        connector = connectors_for_device.first()
100        providers = OAuth2Provider.objects.filter(agentconnector__in=connectors_for_device)
101        federated_token = AccessToken.objects.filter(
102            token=raw_token, provider__in=providers
103        ).first()
104        if not federated_token:
105            LOGGER.warning("Couldn't lookup provider")
106            return None
107        _key, _alg = federated_token.provider.jwt_key
108        try:
109            decode(
110                raw_token,
111                _key.public_key(),
112                algorithms=[_alg],
113                options={
114                    "verify_aud": False,
115                },
116            )
117            LOGGER.info(
118                "successfully verified JWT with provider", provider=federated_token.provider.name
119            )
120            return (federated_token.user, (federated_token, device, connector))
121        except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
122            LOGGER.warning("failed to verify JWT", exc=exc, provider=federated_token.provider.name)
123            return None
124
125
class DeviceFederationAuthSchema(OpenApiAuthenticationExtension):
    """drf-spectacular extension describing DeviceAuthFedAuthentication in the OpenAPI schema."""

    name = "device_federation"
    target_class = DeviceAuthFedAuthentication

    def get_security_definition(self, auto_schema):
        """Describe the scheme as HTTP bearer authentication."""
        security_definition = {"type": "http", "scheme": "bearer"}
        return security_definition
135
136
def check_device_policies(device: Device, user: User, request: HttpRequest):
    """Evaluate access policies for ``user`` on ``device``.

    If the device has an access group, the policies bound to that group are
    evaluated first. A passing group result is returned immediately. Otherwise,
    or when there is no access group, the result of the policies bound directly
    to the device is returned. Access is therefore granted when either level passes.
    """
    access_group = device.access_group
    if access_group:
        group_result = check_pbm_policies(access_group, user, request)
        if group_result.passing:
            return group_result
    return check_pbm_policies(device, user, request)
144
145
def check_pbm_policies(pbm: PolicyBindingModel, user: User, request: HttpRequest):
    """Run the policy engine for ``pbm`` against ``user`` and return its result.

    Caching is disabled and the engine mode is taken from ``pbm.policy_engine_mode``.
    ``empty_result`` is set to ``False``, presumably the outcome when no policies
    are bound (confirm in PolicyEngine).
    """
    engine = PolicyEngine(pbm, user, request)
    engine.use_cache = False
    engine.empty_result = False
    engine.mode = pbm.policy_engine_mode
    engine.build()
    outcome = engine.result
    LOGGER.debug(
        "PolicyAccessView user_has_access", user=user.username, result=outcome, pbm=pbm.pk
    )
    return outcome
LOGGER = <BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None, context_class=None, initial_values={}, logger_factory_args=())>
PLATFORM_ISSUER = 'goauthentik.io/platform'
class DeviceUser(authentik.api.authentication.IPCUser):
29class DeviceUser(IPCUser):
30    username = "authentik:endpoints:device"

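'Virtual' user that requests authenticated with an endpoint agent token are attributed to (inherits from IPCUser, the 'virtual' user for IPC communication between authentik core and the authentik router)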

username = 'authentik:endpoints:device'
class AgentEnrollmentAuth(rest_framework.authentication.BaseAuthentication):
33class AgentEnrollmentAuth(BaseAuthentication):
34
35    def authenticate(self, request: Request) -> tuple[User, Any] | None:
36        auth = get_authorization_header(request)
37        key = validate_auth(auth)
38        token = EnrollmentToken.objects.filter(key=key).first()
39        if not token:
40            raise PermissionDenied()
41        if not token.connector.enabled:
42            raise PermissionDenied()
43        CTX_AUTH_VIA.set("endpoint_token_enrollment")
44        return (DeviceUser(), token)

Authenticates endpoint agent enrollment requests with an enrollment token. (Inherited note: all authentication classes should extend BaseAuthentication.)

def authenticate( self, request: rest_framework.request.Request) -> tuple[authentik.core.models.User, Any] | None:
35    def authenticate(self, request: Request) -> tuple[User, Any] | None:
36        auth = get_authorization_header(request)
37        key = validate_auth(auth)
38        token = EnrollmentToken.objects.filter(key=key).first()
39        if not token:
40            raise PermissionDenied()
41        if not token.connector.enabled:
42            raise PermissionDenied()
43        CTX_AUTH_VIA.set("endpoint_token_enrollment")
44        return (DeviceUser(), token)

Authenticate the request and return a two-tuple of (user, token).

class AgentAuth(rest_framework.authentication.BaseAuthentication):
47class AgentAuth(BaseAuthentication):
48
49    def authenticate(self, request: Request) -> tuple[User, Any] | None:
50        auth = get_authorization_header(request)
51        key = validate_auth(auth, format="bearer+agent")
52        if not key:
53            return None
54        device_token = DeviceToken.objects.filter(key=key).first()
55        if not device_token:
56            raise PermissionDenied()
57        if not device_token.device.connector.enabled:
58            raise PermissionDenied()
59        if device_token.device.device.is_expired:
60            raise PermissionDenied()
61        CTX_AUTH_VIA.set("endpoint_token")
62        return (DeviceUser(), device_token)

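Authenticates endpoint agents with a per-device token sent in a `bearer+agent` Authorization header. (Inherited note: all authentication classes should extend BaseAuthentication.)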

def authenticate( self, request: rest_framework.request.Request) -> tuple[authentik.core.models.User, Any] | None:
49    def authenticate(self, request: Request) -> tuple[User, Any] | None:
50        auth = get_authorization_header(request)
51        key = validate_auth(auth, format="bearer+agent")
52        if not key:
53            return None
54        device_token = DeviceToken.objects.filter(key=key).first()
55        if not device_token:
56            raise PermissionDenied()
57        if not device_token.device.connector.enabled:
58            raise PermissionDenied()
59        if device_token.device.device.is_expired:
60            raise PermissionDenied()
61        CTX_AUTH_VIA.set("endpoint_token")
62        return (DeviceUser(), device_token)

Authenticate the request and return a two-tuple of (user, token).

def agent_auth_issue_token( device: authentik.endpoints.models.Device, connector: authentik.endpoints.connectors.agent.models.AgentConnector, user: authentik.core.models.User, **kwargs):
65def agent_auth_issue_token(device: Device, connector: AgentConnector, user: User, **kwargs):
66    kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
67    if not kp:
68        return None, None
69    exp = now() + timedelta_from_string(connector.auth_session_duration)
70    token = encode(
71        {
72            "iss": PLATFORM_ISSUER,
73            "aud": str(device.pk),
74            "iat": int(now().timestamp()),
75            "exp": int(exp.timestamp()),
76            "preferred_username": user.username,
77            **kwargs,
78        },
79        kp.private_key,
80        headers={
81            "kid": kp.kid,
82        },
83        algorithm=JWTAlgorithms.from_private_key(kp.private_key),
84    )
85    return token, exp
class DeviceAuthFedAuthentication(rest_framework.authentication.BaseAuthentication):
 88class DeviceAuthFedAuthentication(BaseAuthentication):
 89
 90    def authenticate(self, request):
 91        raw_token = validate_auth(get_authorization_header(request))
 92        if not raw_token:
 93            LOGGER.warning("Missing token")
 94            return None
 95        device = Device.objects.filter(name=request.query_params.get("device")).first()
 96        if not device:
 97            LOGGER.warning("Couldn't find device")
 98            return None
 99        connectors_for_device = AgentConnector.objects.filter(device__in=[device])
100        connector = connectors_for_device.first()
101        providers = OAuth2Provider.objects.filter(agentconnector__in=connectors_for_device)
102        federated_token = AccessToken.objects.filter(
103            token=raw_token, provider__in=providers
104        ).first()
105        if not federated_token:
106            LOGGER.warning("Couldn't lookup provider")
107            return None
108        _key, _alg = federated_token.provider.jwt_key
109        try:
110            decode(
111                raw_token,
112                _key.public_key(),
113                algorithms=[_alg],
114                options={
115                    "verify_aud": False,
116                },
117            )
118            LOGGER.info(
119                "successfully verified JWT with provider", provider=federated_token.provider.name
120            )
121            return (federated_token.user, (federated_token, device, connector))
122        except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
123            LOGGER.warning("failed to verify JWT", exc=exc, provider=federated_token.provider.name)
124            return None

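Authenticates requests with an OAuth2 access token issued by a provider linked to the device's agent connector. (Inherited note: all authentication classes should extend BaseAuthentication.)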

def authenticate(self, request):
 90    def authenticate(self, request):
 91        raw_token = validate_auth(get_authorization_header(request))
 92        if not raw_token:
 93            LOGGER.warning("Missing token")
 94            return None
 95        device = Device.objects.filter(name=request.query_params.get("device")).first()
 96        if not device:
 97            LOGGER.warning("Couldn't find device")
 98            return None
 99        connectors_for_device = AgentConnector.objects.filter(device__in=[device])
100        connector = connectors_for_device.first()
101        providers = OAuth2Provider.objects.filter(agentconnector__in=connectors_for_device)
102        federated_token = AccessToken.objects.filter(
103            token=raw_token, provider__in=providers
104        ).first()
105        if not federated_token:
106            LOGGER.warning("Couldn't lookup provider")
107            return None
108        _key, _alg = federated_token.provider.jwt_key
109        try:
110            decode(
111                raw_token,
112                _key.public_key(),
113                algorithms=[_alg],
114                options={
115                    "verify_aud": False,
116                },
117            )
118            LOGGER.info(
119                "successfully verified JWT with provider", provider=federated_token.provider.name
120            )
121            return (federated_token.user, (federated_token, device, connector))
122        except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
123            LOGGER.warning("failed to verify JWT", exc=exc, provider=federated_token.provider.name)
124            return None

Authenticate the request and return a two-tuple of (user, token).

class DeviceFederationAuthSchema(drf_spectacular.plumbing.OpenApiGeneratorExtension[ForwardRef('OpenApiAuthenticationExtension')]):
127class DeviceFederationAuthSchema(OpenApiAuthenticationExtension):
128    """Auth schema"""
129
130    target_class = DeviceAuthFedAuthentication
131    name = "device_federation"
132
133    def get_security_definition(self, auto_schema):
134        """Auth schema"""
135        return {"type": "http", "scheme": "bearer"}

Auth schema: drf-spectacular OpenAPI authentication extension describing DeviceAuthFedAuthentication.

target_class = <class 'DeviceAuthFedAuthentication'>
name = 'device_federation'
def get_security_definition(self, auto_schema):
133    def get_security_definition(self, auto_schema):
134        """Auth schema"""
135        return {"type": "http", "scheme": "bearer"}

Auth schema: returns the OpenAPI security definition, an HTTP bearer scheme.

def check_device_policies( device: authentik.endpoints.models.Device, user: authentik.core.models.User, request: django.http.request.HttpRequest):
138def check_device_policies(device: Device, user: User, request: HttpRequest):
139    """Check policies bound to device group and device"""
140    if device.access_group:
141        result = check_pbm_policies(device.access_group, user, request)
142        if result.passing:
143            return result
144    return check_pbm_policies(device, user, request)

Check policies bound to device group and device

def check_pbm_policies( pbm: authentik.policies.models.PolicyBindingModel, user: authentik.core.models.User, request: django.http.request.HttpRequest):
147def check_pbm_policies(pbm: PolicyBindingModel, user: User, request: HttpRequest):
148    policy_engine = PolicyEngine(pbm, user, request)
149    policy_engine.use_cache = False
150    policy_engine.empty_result = False
151    policy_engine.mode = pbm.policy_engine_mode
152    policy_engine.build()
153    result = policy_engine.result
154    LOGGER.debug("PolicyAccessView user_has_access", user=user.username, result=result, pbm=pbm.pk)
155    return result