authentik.enterprise.providers.microsoft_entra.clients.base

  1from asyncio import run
  2from collections.abc import Coroutine
  3from dataclasses import asdict
  4from typing import Any
  5
  6import httpx
  7from azure.core.exceptions import (
  8    ClientAuthenticationError,
  9    ServiceRequestError,
 10    ServiceResponseError,
 11)
 12from azure.identity.aio import ClientSecretCredential
 13from django.db.models import Model
 14from django.http import HttpResponseBadRequest, HttpResponseNotFound
 15from kiota_abstractions.api_error import APIError
 16from kiota_abstractions.request_information import RequestInformation
 17from kiota_authentication_azure.azure_identity_authentication_provider import (
 18    AzureIdentityAuthenticationProvider,
 19)
 20from kiota_http.kiota_client_factory import KiotaClientFactory
 21from msgraph.generated.models.entity import Entity
 22from msgraph.generated.models.o_data_errors.o_data_error import ODataError
 23from msgraph.graph_request_adapter import GraphRequestAdapter, options
 24from msgraph.graph_service_client import GraphServiceClient
 25from msgraph_core import GraphClientFactory
 26from opentelemetry import trace
 27
 28from authentik.enterprise.providers.microsoft_entra.models import MicrosoftEntraProvider
 29from authentik.events.utils import sanitize_item
 30from authentik.lib.sync.outgoing import HTTP_CONFLICT
 31from authentik.lib.sync.outgoing.base import SAFE_METHODS, BaseOutgoingSyncClient
 32from authentik.lib.sync.outgoing.exceptions import (
 33    BadRequestSyncException,
 34    DryRunRejected,
 35    NotFoundSyncException,
 36    ObjectExistsSyncException,
 37    StopSync,
 38    TransientSyncException,
 39)
 40
 41
 42class AuthentikRequestAdapter(GraphRequestAdapter):
 43    def __init__(self, auth_provider, provider: MicrosoftEntraProvider, client=None):
 44        super().__init__(auth_provider, client)
 45        self._provider = provider
 46
 47    async def get_http_response_message(
 48        self,
 49        request_info: RequestInformation,
 50        parent_span: trace.Span,
 51        claims: str = "",
 52    ) -> httpx.Response:
 53        if self._provider.dry_run and request_info.http_method.value.upper() not in SAFE_METHODS:
 54            raise DryRunRejected(
 55                url=request_info.url,
 56                method=request_info.http_method.value,
 57                body=request_info.content.decode("utf-8"),
 58            )
 59        return await super().get_http_response_message(request_info, parent_span, claims=claims)
 60
 61
 62class MicrosoftEntraSyncClient[TModel: Model, TConnection: Model, TSchema: dict](
 63    BaseOutgoingSyncClient[TModel, TConnection, TSchema, MicrosoftEntraProvider]
 64):
 65    """Base client for syncing to microsoft entra"""
 66
 67    domains: list
 68
 69    def __init__(self, provider: MicrosoftEntraProvider) -> None:
 70        super().__init__(provider)
 71        self.credentials = provider.microsoft_credentials()
 72        self.__prefetch_domains()
 73
 74    def get_request_adapter(
 75        self, credentials: ClientSecretCredential, scopes: list[str] | None = None
 76    ) -> AuthentikRequestAdapter:
 77        if scopes:
 78            auth_provider = AzureIdentityAuthenticationProvider(
 79                credentials=credentials, scopes=scopes
 80            )
 81        else:
 82            auth_provider = AzureIdentityAuthenticationProvider(credentials=credentials)
 83
 84        return AuthentikRequestAdapter(
 85            auth_provider=auth_provider,
 86            provider=self.provider,
 87            client=GraphClientFactory.create_with_default_middleware(
 88                options=options, client=KiotaClientFactory.get_default_client()
 89            ),
 90        )
 91
 92    @property
 93    def client(self):
 94        return GraphServiceClient(request_adapter=self.get_request_adapter(**self.credentials))
 95
 96    def _request[T](self, request: Coroutine[Any, Any, T]) -> T:
 97        try:
 98            return run(request)
 99        except ClientAuthenticationError as exc:
100            raise StopSync(exc, None, None) from exc
101        except ODataError as exc:
102            raise StopSync(exc, None, None) from exc
103        except (ServiceRequestError, ServiceResponseError) as exc:
104            raise TransientSyncException("Failed to sent request") from exc
105        except APIError as exc:
106            if exc.response_status_code == HttpResponseNotFound.status_code:
107                raise NotFoundSyncException("Object not found") from exc
108            if exc.response_status_code == HttpResponseBadRequest.status_code:
109                raise BadRequestSyncException("Bad request", exc.response_headers) from exc
110            if exc.response_status_code == HTTP_CONFLICT:
111                raise ObjectExistsSyncException("Object exists", exc.response_headers) from exc
112            raise exc
113
114    def __prefetch_domains(self):
115        self.domains = []
116        organizations = self._request(self.client.organization.get())
117        next_link = True
118        while next_link:
119            for org in organizations.value:
120                self.domains.extend([x.name for x in org.verified_domains])
121            next_link = organizations.odata_next_link
122            if not next_link:
123                break
124            organizations = self._request(self.client.organization.with_url(next_link).get())
125
126    def check_email_valid(self, *emails: str):
127        for email in emails:
128            if not any(email.endswith(f"@{domain_name}") for domain_name in self.domains):
129                raise BadRequestSyncException(f"Invalid email domain: {email}")
130
131    def entity_as_dict(self, entity: Entity) -> dict:
132        """Create a dictionary of a model instance, making sure to remove (known) things
133        we can't JSON serialize"""
134        raw_data = asdict(entity)
135        raw_data.pop("backing_store", None)
136        return sanitize_item(raw_data)
class AuthentikRequestAdapter(msgraph.graph_request_adapter.GraphRequestAdapter):
class AuthentikRequestAdapter(GraphRequestAdapter):
    """Graph request adapter that honours the provider's dry-run setting.

    When ``provider.dry_run`` is enabled, any request whose HTTP method is not listed in
    ``SAFE_METHODS`` is rejected with ``DryRunRejected`` instead of being sent.
    """

    def __init__(self, auth_provider, provider: MicrosoftEntraProvider, client=None):
        super().__init__(auth_provider, client)
        # Kept so the dry-run flag is checked on every outgoing request
        self._provider = provider

    async def get_http_response_message(
        self,
        request_info: RequestInformation,
        parent_span: trace.Span,
        claims: str = "",
    ) -> httpx.Response:
        """Send the request unless dry-run mode forbids it.

        Raises:
            DryRunRejected: the provider is in dry-run mode and the request's method is
                not in ``SAFE_METHODS``. ``body`` is the decoded request payload, or
                ``None`` when the request carries no payload.
        """
        if self._provider.dry_run and request_info.http_method.value.upper() not in SAFE_METHODS:
            # Body-less requests (e.g. DELETE) leave ``content`` unset; decoding it
            # unconditionally would raise AttributeError instead of DryRunRejected.
            content = request_info.content
            raise DryRunRejected(
                url=request_info.url,
                method=request_info.http_method.value,
                body=content.decode("utf-8") if content is not None else None,
            )
        return await super().get_http_response_message(request_info, parent_span, claims=claims)

Service responsible for translating abstract Request Info into concrete native HTTP requests.

AuthentikRequestAdapter( auth_provider, provider: authentik.enterprise.providers.microsoft_entra.models.MicrosoftEntraProvider, client=None)
44    def __init__(self, auth_provider, provider: MicrosoftEntraProvider, client=None):
45        super().__init__(auth_provider, client)
46        self._provider = provider
async def get_http_response_message( self, request_info: kiota_abstractions.request_information.RequestInformation, parent_span: opentelemetry.trace.span.Span, claims: str = '') -> httpx.Response:
48    async def get_http_response_message(
49        self,
50        request_info: RequestInformation,
51        parent_span: trace.Span,
52        claims: str = "",
53    ) -> httpx.Response:
54        if self._provider.dry_run and request_info.http_method.value.upper() not in SAFE_METHODS:
55            raise DryRunRejected(
56                url=request_info.url,
57                method=request_info.http_method.value,
58                body=request_info.content.decode("utf-8"),
59            )
60        return await super().get_http_response_message(request_info, parent_span, claims=claims)
 63class MicrosoftEntraSyncClient[TModel: Model, TConnection: Model, TSchema: dict](
 64    BaseOutgoingSyncClient[TModel, TConnection, TSchema, MicrosoftEntraProvider]
 65):
 66    """Base client for syncing to microsoft entra"""
 67
 68    domains: list
 69
 70    def __init__(self, provider: MicrosoftEntraProvider) -> None:
 71        super().__init__(provider)
 72        self.credentials = provider.microsoft_credentials()
 73        self.__prefetch_domains()
 74
 75    def get_request_adapter(
 76        self, credentials: ClientSecretCredential, scopes: list[str] | None = None
 77    ) -> AuthentikRequestAdapter:
 78        if scopes:
 79            auth_provider = AzureIdentityAuthenticationProvider(
 80                credentials=credentials, scopes=scopes
 81            )
 82        else:
 83            auth_provider = AzureIdentityAuthenticationProvider(credentials=credentials)
 84
 85        return AuthentikRequestAdapter(
 86            auth_provider=auth_provider,
 87            provider=self.provider,
 88            client=GraphClientFactory.create_with_default_middleware(
 89                options=options, client=KiotaClientFactory.get_default_client()
 90            ),
 91        )
 92
 93    @property
 94    def client(self):
 95        return GraphServiceClient(request_adapter=self.get_request_adapter(**self.credentials))
 96
 97    def _request[T](self, request: Coroutine[Any, Any, T]) -> T:
 98        try:
 99            return run(request)
100        except ClientAuthenticationError as exc:
101            raise StopSync(exc, None, None) from exc
102        except ODataError as exc:
103            raise StopSync(exc, None, None) from exc
104        except (ServiceRequestError, ServiceResponseError) as exc:
105            raise TransientSyncException("Failed to sent request") from exc
106        except APIError as exc:
107            if exc.response_status_code == HttpResponseNotFound.status_code:
108                raise NotFoundSyncException("Object not found") from exc
109            if exc.response_status_code == HttpResponseBadRequest.status_code:
110                raise BadRequestSyncException("Bad request", exc.response_headers) from exc
111            if exc.response_status_code == HTTP_CONFLICT:
112                raise ObjectExistsSyncException("Object exists", exc.response_headers) from exc
113            raise exc
114
115    def __prefetch_domains(self):
116        self.domains = []
117        organizations = self._request(self.client.organization.get())
118        next_link = True
119        while next_link:
120            for org in organizations.value:
121                self.domains.extend([x.name for x in org.verified_domains])
122            next_link = organizations.odata_next_link
123            if not next_link:
124                break
125            organizations = self._request(self.client.organization.with_url(next_link).get())
126
127    def check_email_valid(self, *emails: str):
128        for email in emails:
129            if not any(email.endswith(f"@{domain_name}") for domain_name in self.domains):
130                raise BadRequestSyncException(f"Invalid email domain: {email}")
131
132    def entity_as_dict(self, entity: Entity) -> dict:
133        """Create a dictionary of a model instance, making sure to remove (known) things
134        we can't JSON serialize"""
135        raw_data = asdict(entity)
136        raw_data.pop("backing_store", None)
137        return sanitize_item(raw_data)

Base client for syncing to Microsoft Entra

MicrosoftEntraSyncClient( provider: authentik.enterprise.providers.microsoft_entra.models.MicrosoftEntraProvider)
70    def __init__(self, provider: MicrosoftEntraProvider) -> None:
71        super().__init__(provider)
72        self.credentials = provider.microsoft_credentials()
73        self.__prefetch_domains()
domains: list
credentials
def get_request_adapter( self, credentials: azure.identity.aio._credentials.client_secret.ClientSecretCredential, scopes: list[str] | None = None) -> AuthentikRequestAdapter:
75    def get_request_adapter(
76        self, credentials: ClientSecretCredential, scopes: list[str] | None = None
77    ) -> AuthentikRequestAdapter:
78        if scopes:
79            auth_provider = AzureIdentityAuthenticationProvider(
80                credentials=credentials, scopes=scopes
81            )
82        else:
83            auth_provider = AzureIdentityAuthenticationProvider(credentials=credentials)
84
85        return AuthentikRequestAdapter(
86            auth_provider=auth_provider,
87            provider=self.provider,
88            client=GraphClientFactory.create_with_default_middleware(
89                options=options, client=KiotaClientFactory.get_default_client()
90            ),
91        )
client
93    @property
94    def client(self):
95        return GraphServiceClient(request_adapter=self.get_request_adapter(**self.credentials))
def check_email_valid(self, *emails: str):
127    def check_email_valid(self, *emails: str):
128        for email in emails:
129            if not any(email.endswith(f"@{domain_name}") for domain_name in self.domains):
130                raise BadRequestSyncException(f"Invalid email domain: {email}")
def entity_as_dict(self, entity: msgraph.generated.models.entity.Entity) -> dict:
132    def entity_as_dict(self, entity: Entity) -> dict:
133        """Create a dictionary of a model instance, making sure to remove (known) things
134        we can't JSON serialize"""
135        raw_data = asdict(entity)
136        raw_data.pop("backing_store", None)
137        return sanitize_item(raw_data)

Create a dictionary of a model instance, making sure to remove (known) things we can't JSON serialize