authentik.endpoints.connectors.agent.auth

  1from typing import Any
  2
  3from django.db.models import Model, Q
  4from django.http import HttpRequest
  5from django.utils.timezone import now
  6from drf_spectacular.extensions import OpenApiAuthenticationExtension
  7from jwt import PyJWTError, decode, encode
  8from rest_framework.authentication import BaseAuthentication, get_authorization_header
  9from rest_framework.exceptions import PermissionDenied
 10from rest_framework.request import Request
 11from structlog.stdlib import get_logger
 12
 13from authentik.api.authentication import VirtualUser, validate_auth
 14from authentik.core.middleware import CTX_AUTH_VIA
 15from authentik.core.models import User
 16from authentik.crypto.apps import MANAGED_KEY
 17from authentik.crypto.models import CertificateKeyPair
 18from authentik.endpoints.connectors.agent.models import AgentConnector, DeviceToken, EnrollmentToken
 19from authentik.endpoints.models import Device
 20from authentik.lib.utils.time import timedelta_from_string
 21from authentik.policies.engine import PolicyEngine
 22from authentik.policies.models import PolicyBindingModel
 23from authentik.providers.oauth2.models import AccessToken, JWTAlgorithms, OAuth2Provider
 24
 25LOGGER = get_logger()
 26PLATFORM_ISSUER = "goauthentik.io/platform"
 27
 28
 29class DeviceUser(VirtualUser):
 30
 31    username = "authentik:endpoints:device"
 32
 33    def has_perm(self, perm: str, obj: Model | None = None) -> bool:
 34        if perm in [
 35            "authentik_core.view_user",
 36            "authentik_core.view_group",
 37        ]:
 38            return True
 39        return False
 40
 41
 42class AgentEnrollmentAuth(BaseAuthentication):
 43
 44    def authenticate(self, request: Request) -> tuple[User, Any] | None:
 45        auth = get_authorization_header(request)
 46        key = validate_auth(auth)
 47        token = EnrollmentToken.objects.filter(key=key).first()
 48        if not token:
 49            raise PermissionDenied()
 50        if not token.connector.enabled:
 51            raise PermissionDenied()
 52        CTX_AUTH_VIA.set("endpoint_token_enrollment")
 53        return (DeviceUser(), token)
 54
 55
 56class AgentAuth(BaseAuthentication):
 57
 58    def authenticate(self, request: Request) -> tuple[User, Any] | None:
 59        auth = get_authorization_header(request)
 60        key = validate_auth(auth, format="bearer+agent")
 61        if not key:
 62            return None
 63        device_token = DeviceToken.objects.filter(key=key).first()
 64        if not device_token:
 65            raise PermissionDenied()
 66        if not device_token.device.connector.enabled:
 67            raise PermissionDenied()
 68        if device_token.device.device.is_expired:
 69            raise PermissionDenied()
 70        CTX_AUTH_VIA.set("endpoint_token")
 71        return (DeviceUser(), device_token)
 72
 73
 74def agent_auth_issue_token(device: Device, connector: AgentConnector, user: User, **kwargs):
 75    kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
 76    if not kp:
 77        return None, None
 78    exp = now() + timedelta_from_string(connector.auth_session_duration)
 79    token = encode(
 80        {
 81            "iss": PLATFORM_ISSUER,
 82            "aud": str(device.pk),
 83            "iat": int(now().timestamp()),
 84            "exp": int(exp.timestamp()),
 85            "preferred_username": user.username,
 86            **kwargs,
 87        },
 88        kp.private_key,
 89        headers={
 90            "kid": kp.kid,
 91        },
 92        algorithm=JWTAlgorithms.from_private_key(kp.private_key),
 93    )
 94    return token, exp
 95
 96
 97class DeviceAuthFedAuthentication(BaseAuthentication):
 98
 99    def authenticate(self, request):
100        raw_token = validate_auth(get_authorization_header(request))
101        if not raw_token:
102            LOGGER.warning("Missing token")
103            return None
104        device = (
105            Device.objects.filter(
106                Q(
107                    name=request.query_params.get("device"),
108                )
109                | Q(
110                    **{
111                        "deviceconnection__devicefactsnapshot__"
112                        "data__vendor__goauthentik.io/platform__"
113                        "ssh_host_keys__contains": request.query_params.get("device"),
114                    }
115                )
116            )
117            .distinct()
118            .first()
119        )
120        if not device:
121            LOGGER.warning("Couldn't find device")
122            return None
123        connectors_for_device = AgentConnector.objects.filter(device__in=[device])
124        connector = connectors_for_device.first()
125        providers = OAuth2Provider.objects.filter(agentconnector__in=connectors_for_device)
126        federated_token = AccessToken.objects.filter(
127            token=raw_token, provider__in=providers
128        ).first()
129        if not federated_token:
130            LOGGER.warning("Couldn't lookup provider")
131            return None
132        _key, _alg = federated_token.provider.jwt_key
133        try:
134            decode(
135                raw_token,
136                _key.public_key(),
137                algorithms=[_alg],
138                options={
139                    "verify_aud": False,
140                },
141            )
142            LOGGER.info(
143                "successfully verified JWT with provider", provider=federated_token.provider.name
144            )
145            return (federated_token.user, (federated_token, device, connector))
146        except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
147            LOGGER.warning("failed to verify JWT", exc=exc, provider=federated_token.provider.name)
148            return None
149
150
class DeviceFederationAuthSchema(OpenApiAuthenticationExtension):
    """drf-spectacular extension describing ``DeviceAuthFedAuthentication`` in OpenAPI."""

    target_class = DeviceAuthFedAuthentication
    name = "device_federation"

    def get_security_definition(self, auto_schema):
        """Declare the scheme as HTTP bearer authentication."""
        return dict(type="http", scheme="bearer")
160
161
def check_device_policies(device: Device, user: User, request: HttpRequest):
    """Evaluate whether ``user`` may access ``device``.

    If the device has an access group, the group's policies are evaluated first, and a
    passing group result is returned as-is. Otherwise (no group, or the group result does
    not pass) the result of the policies bound directly to the device is returned.
    """
    group = device.access_group
    if not group:
        return check_pbm_policies(device, user, request)
    group_result = check_pbm_policies(group, user, request)
    if group_result.passing:
        return group_result
    return check_pbm_policies(device, user, request)
169
170
def check_pbm_policies(pbm: PolicyBindingModel, user: User, request: HttpRequest):
    """Run the policies bound to ``pbm`` for ``user`` and return the engine result.

    The engine is configured as follows:
    - caching is disabled;
    - ``empty_result`` is ``False`` (presumably: no bindings means not passing);
    - the evaluation mode comes from ``pbm.policy_engine_mode``.
    """
    engine = PolicyEngine(pbm, user, request)
    engine.use_cache = False
    engine.empty_result = False
    engine.mode = pbm.policy_engine_mode
    engine.build()
    outcome = engine.result
    LOGGER.debug("PolicyAccessView user_has_access", user=user.username, result=outcome, pbm=pbm.pk)
    return outcome
LOGGER = <BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None, context_class=None, initial_values={}, logger_factory_args=())>
PLATFORM_ISSUER = 'goauthentik.io/platform'
class DeviceUser(VirtualUser):
    """Virtual principal returned by the agent authenticators in this module.

    It only holds the two read-only permissions checked in ``has_perm``.
    """

    username = "authentik:endpoints:device"

    def has_perm(self, perm: str, obj: Model | None = None) -> bool:
        """Return whether ``perm`` is one of the read permissions granted to devices.

        ``obj`` is accepted for interface compatibility and ignored: the answer depends
        on ``perm`` alone.
        """
        return perm in ("authentik_core.view_user", "authentik_core.view_group")
class AgentEnrollmentAuth(BaseAuthentication):
    """Authenticate agent enrollment requests with an ``EnrollmentToken``."""

    def authenticate(self, request: Request) -> tuple[User, Any] | None:
        """Resolve the enrollment token carried in the ``Authorization`` header.

        Returns:
            ``(DeviceUser(), enrollment_token)`` on success.

        Raises:
            PermissionDenied: no ``EnrollmentToken`` matches the key, or the token's
                connector is disabled.
        """
        # NOTE(review): unlike AgentAuth, an absent/unparsable key is not short-circuited
        # to ``None``; it is still looked up, so such requests end in PermissionDenied
        # unless a token happens to match it.
        key = validate_auth(get_authorization_header(request))
        enrollment_token = EnrollmentToken.objects.filter(key=key).first()
        if not enrollment_token or not enrollment_token.connector.enabled:
            raise PermissionDenied()
        # Record how this request was authenticated.
        CTX_AUTH_VIA.set("endpoint_token_enrollment")
        return (DeviceUser(), enrollment_token)

class AgentAuth(BaseAuthentication):
    """Authenticate enrolled agents with a ``DeviceToken``.

    The key is parsed from the ``Authorization`` header by ``validate_auth`` using the
    ``"bearer+agent"`` format.
    """

    def authenticate(self, request: Request) -> tuple[User, Any] | None:
        """Resolve the device token carried in the ``Authorization`` header.

        Returns:
            ``None`` when no agent key is present (DRF then tries other authenticators),
            otherwise ``(DeviceUser(), device_token)``.

        Raises:
            PermissionDenied: the key matches no ``DeviceToken``, the token's connector
                is disabled, or the token's device is expired.
        """
        key = validate_auth(get_authorization_header(request), format="bearer+agent")
        if not key:
            return None
        device_token = DeviceToken.objects.filter(key=key).first()
        if not device_token:
            raise PermissionDenied()
        # ``device_token.device`` exposes both the ``connector`` and the actual ``device``;
        # the connector check runs first, exactly as before.
        binding = device_token.device
        if not binding.connector.enabled or binding.device.is_expired:
            raise PermissionDenied()
        # Record how this request was authenticated.
        CTX_AUTH_VIA.set("endpoint_token")
        return (DeviceUser(), device_token)

def agent_auth_issue_token(device: Device, connector: AgentConnector, user: User, **kwargs):
    """Mint a platform JWT for ``user`` on ``device``, signed with the managed key pair.

    Args:
        device: Device the token is issued for; ``str(device.pk)`` becomes ``aud``.
        connector: Connector whose ``auth_session_duration`` sets the token lifetime.
        user: User whose ``username`` is written to ``preferred_username``.
        **kwargs: Extra claims, merged last (so they can override the claims above).

    Returns:
        ``(token, exp)`` where ``exp`` is the expiry datetime, or ``(None, None)`` when
        no ``CertificateKeyPair`` with ``managed=MANAGED_KEY`` exists.
    """
    kp = CertificateKeyPair.objects.filter(managed=MANAGED_KEY).first()
    if not kp:
        return None, None
    # Use one issue instant for both ``iat`` and ``exp``. Two separate ``now()`` calls can
    # straddle a second boundary, and after truncation to whole seconds ``exp - iat``
    # would then differ from the configured session duration.
    issued_at = now()
    exp = issued_at + timedelta_from_string(connector.auth_session_duration)
    token = encode(
        {
            "iss": PLATFORM_ISSUER,
            "aud": str(device.pk),
            "iat": int(issued_at.timestamp()),
            "exp": int(exp.timestamp()),
            "preferred_username": user.username,
            **kwargs,
        },
        kp.private_key,
        headers={"kid": kp.kid},
        algorithm=JWTAlgorithms.from_private_key(kp.private_key),
    )
    return token, exp
class DeviceAuthFedAuthentication(BaseAuthentication):
    """Authenticate a user with an OAuth2 access token federated through a device's connector.

    The device is selected by the ``device`` query parameter. It is matched either against
    the device name or against the SSH host keys in the device's fact snapshots.
    """

    def authenticate(self, request):
        """Return ``(user, (access_token, device, connector))``, or ``None`` if not applicable.

        ``None`` is returned when:
        - the bearer token is missing;
        - the ``device`` query parameter is missing/empty or matches no device;
        - the token was not issued by a provider linked to one of the device's connectors;
        - JWT verification of the token fails.
        """
        raw_token = validate_auth(get_authorization_header(request))
        if not raw_token:
            LOGGER.warning("Missing token")
            return None
        device_identifier = request.query_params.get("device")
        if not device_identifier:
            # Nothing to look up: querying with ``None``/``""`` would only filter on a
            # meaningless value (e.g. ``name IS NULL``).
            LOGGER.warning("Couldn't find device")
            return None
        device = self._find_device(device_identifier)
        if not device:
            LOGGER.warning("Couldn't find device")
            return None
        connectors_for_device = AgentConnector.objects.filter(device__in=[device])
        # NOTE(review): this is the device's first connector. It is not necessarily the
        # connector whose provider issued ``federated_token`` — confirm this is intended.
        connector = connectors_for_device.first()
        providers = OAuth2Provider.objects.filter(agentconnector__in=connectors_for_device)
        federated_token = AccessToken.objects.filter(
            token=raw_token, provider__in=providers
        ).first()
        if not federated_token:
            LOGGER.warning("Couldn't lookup provider")
            return None
        provider = federated_token.provider
        _key, _alg = provider.jwt_key
        try:
            # Signature and time-based claims are checked; the audience deliberately is not.
            decode(
                raw_token,
                _key.public_key(),
                algorithms=[_alg],
                options={
                    "verify_aud": False,
                },
            )
        except (PyJWTError, ValueError, TypeError, AttributeError) as exc:
            LOGGER.warning("failed to verify JWT", exc=exc, provider=provider.name)
            return None
        LOGGER.info("successfully verified JWT with provider", provider=provider.name)
        return (federated_token.user, (federated_token, device, connector))

    @staticmethod
    def _find_device(identifier: str) -> Device | None:
        """Find a device by name or by an SSH host key reported in its fact snapshots."""
        return (
            Device.objects.filter(
                Q(name=identifier)
                | Q(
                    **{
                        "deviceconnection__devicefactsnapshot__"
                        "data__vendor__goauthentik.io/platform__"
                        "ssh_host_keys__contains": identifier,
                    }
                )
            )
            .distinct()
            .first()
        )

class DeviceFederationAuthSchema(OpenApiAuthenticationExtension):
    """drf-spectacular extension describing ``DeviceAuthFedAuthentication`` in OpenAPI."""

    target_class = DeviceAuthFedAuthentication
    name = "device_federation"

    def get_security_definition(self, auto_schema):
        """Declare the scheme as HTTP bearer authentication."""
        return dict(type="http", scheme="bearer")

def check_device_policies(device: Device, user: User, request: HttpRequest):
    """Evaluate whether ``user`` may access ``device``.

    If the device has an access group, the group's policies are evaluated first, and a
    passing group result is returned as-is. Otherwise (no group, or the group result does
    not pass) the result of the policies bound directly to the device is returned.
    """
    group = device.access_group
    if not group:
        return check_pbm_policies(device, user, request)
    group_result = check_pbm_policies(group, user, request)
    if group_result.passing:
        return group_result
    return check_pbm_policies(device, user, request)

def check_pbm_policies(pbm: PolicyBindingModel, user: User, request: HttpRequest):
    """Run the policies bound to ``pbm`` for ``user`` and return the engine result.

    The engine is configured as follows:
    - caching is disabled;
    - ``empty_result`` is ``False`` (presumably: no bindings means not passing);
    - the evaluation mode comes from ``pbm.policy_engine_mode``.
    """
    engine = PolicyEngine(pbm, user, request)
    engine.use_cache = False
    engine.empty_result = False
    engine.mode = pbm.policy_engine_mode
    engine.build()
    outcome = engine.result
    LOGGER.debug("PolicyAccessView user_has_access", user=user.username, result=outcome, pbm=pbm.pk)
    return outcome