authentik.sources.oauth.types.registry

Source type manager

  1"""Source type manager"""
  2
  3from collections.abc import Callable
  4from enum import Enum
  5from typing import Any
  6
  7from django.http.request import HttpRequest
  8from django.templatetags.static import static
  9from django.urls.base import reverse
 10from structlog.stdlib import get_logger
 11
 12from authentik.flows.challenge import Challenge, RedirectChallenge
 13from authentik.sources.oauth.models import AuthorizationCodeAuthMethod, OAuthSource, PKCEMethod
 14from authentik.sources.oauth.views.callback import OAuthCallback
 15from authentik.sources.oauth.views.redirect import OAuthRedirect
 16
# Module-level structlog logger; SourceTypeRegistry.find_type uses it to warn
# when no registered source type matches the requested name.
LOGGER = get_logger()
 18
 19
class RequestKind(Enum):
    """Enum of OAuth request kinds.

    Each member selects which view ``SourceTypeRegistry.find`` returns for a
    source type.
    """

    # Selects the source type's ``callback_view``.
    CALLBACK = "callback"
    # Selects the source type's ``redirect_view``.
    REDIRECT = "redirect"
 25
 26
 27class SourceType:
 28    """Source type, allows overriding of urls and views per type"""
 29
 30    callback_view = OAuthCallback
 31    redirect_view = OAuthRedirect
 32    name: str = "default"
 33    verbose_name: str = "Default source type"
 34
 35    urls_customizable = False
 36
 37    request_token_url: str | None = None
 38    authorization_url: str | None = None
 39    access_token_url: str | None = None
 40    profile_url: str | None = None
 41    oidc_well_known_url: str | None = None
 42    oidc_jwks_url: str | None = None
 43    pkce: PKCEMethod = PKCEMethod.NONE
 44
 45    authorization_code_auth_method: AuthorizationCodeAuthMethod = (
 46        AuthorizationCodeAuthMethod.BASIC_AUTH
 47    )
 48
 49    def icon_url(self) -> str:
 50        """Get Icon URL for login"""
 51        return static(f"authentik/sources/{self.name}.svg")
 52
 53    def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
 54        """Allow types to return custom challenges"""
 55        return RedirectChallenge(
 56            data={
 57                "to": reverse(
 58                    "authentik_sources_oauth:oauth-client-login",
 59                    kwargs={"source_slug": source.slug},
 60                ),
 61            }
 62        )
 63
 64    def get_base_user_properties(
 65        self, source: OAuthSource, info: dict[str, Any], **kwargs
 66    ) -> dict[str, Any | dict[str, Any]]:
 67        """Get base user properties for enrollment/update"""
 68        return info
 69
 70    def get_base_group_properties(
 71        self, source: OAuthSource, group_id: str, **kwargs
 72    ) -> dict[str, Any | dict[str, Any]]:
 73        """Get base group properties for creation/update"""
 74        return {
 75            "name": group_id,
 76        }
 77
 78
 79class SourceTypeRegistry:
 80    """Registry to hold all Source types."""
 81
 82    def __init__(self) -> None:
 83        self.__sources: list[type[SourceType]] = []
 84
 85    def register(self):
 86        """Class decorator to register classes inline."""
 87
 88        def inner_wrapper(cls):
 89            self.__sources.append(cls)
 90            return cls
 91
 92        return inner_wrapper
 93
 94    def get(self):
 95        """Get a list of all source types"""
 96        return self.__sources
 97
 98    def get_name_tuple(self):
 99        """Get list of tuples of all registered names"""
100        return [(x.name, x.verbose_name) for x in self.__sources]
101
102    def find_type(self, type_name: str) -> type[SourceType]:
103        """Find type based on source"""
104        found_type = None
105        for src_type in self.__sources:
106            if src_type.name == type_name:
107                return src_type
108        if not found_type:
109            found_type = SourceType
110            LOGGER.warning(
111                "no matching type found, using default",
112                wanted=type_name,
113                have=[x.name for x in self.__sources],
114            )
115        return found_type
116
117    def find(self, type_name: str, kind: RequestKind) -> Callable:
118        """Find fitting Source Type"""
119        found_type = self.find_type(type_name)
120        if kind == RequestKind.CALLBACK:
121            return found_type.callback_view
122        if kind == RequestKind.REDIRECT:
123            return found_type.redirect_view
124        raise ValueError
125
126
# Shared registry instance; source type classes are added to it with the
# ``@registry.register()`` class decorator.
registry = SourceTypeRegistry()
LOGGER = <BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None, context_class=None, initial_values={}, logger_factory_args=())>
class RequestKind(enum.Enum):
class RequestKind(Enum):
    """Enum of OAuth request kinds.

    Each member selects which view ``SourceTypeRegistry.find`` returns for a
    source type.
    """

    # Selects the source type's ``callback_view``.
    CALLBACK = "callback"
    # Selects the source type's ``redirect_view``.
    REDIRECT = "redirect"

Enum of OAuth request types

CALLBACK = <RequestKind.CALLBACK: 'callback'>
REDIRECT = <RequestKind.REDIRECT: 'redirect'>
class SourceType:
class SourceType:
    """Base OAuth source type.

    Per-provider types override the class attributes (views, endpoint URLs,
    PKCE and auth-method settings) and the hook methods below; the values here
    are the generic defaults.
    """

    # Views returned by SourceTypeRegistry.find for RequestKind.CALLBACK and
    # RequestKind.REDIRECT respectively.
    callback_view = OAuthCallback
    redirect_view = OAuthRedirect

    # Registry lookup key (also used to build the icon path) and display label.
    name: str = "default"
    verbose_name: str = "Default source type"

    # NOTE(review): presumably controls whether the endpoint URLs below may be
    # edited per source — confirm against callers.
    urls_customizable = False

    # Provider endpoint URLs; all unset (None) for the generic type.
    request_token_url: str | None = None
    authorization_url: str | None = None
    access_token_url: str | None = None
    profile_url: str | None = None
    oidc_well_known_url: str | None = None
    oidc_jwks_url: str | None = None

    # PKCE is disabled by default.
    pkce: PKCEMethod = PKCEMethod.NONE

    # NOTE(review): presumably how client credentials are presented when the
    # authorization code is exchanged — confirm against the token client.
    authorization_code_auth_method: AuthorizationCodeAuthMethod = (
        AuthorizationCodeAuthMethod.BASIC_AUTH
    )

    def icon_url(self) -> str:
        """Return the static-file URL of the SVG icon named after this type."""
        icon_path = f"authentik/sources/{self.name}.svg"
        return static(icon_path)

    def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
        """Return the challenge used to start a login with ``source``.

        The default is a redirect to the OAuth client login view for the
        source's slug; types may override this to return a custom challenge.
        """
        login_url = reverse(
            "authentik_sources_oauth:oauth-client-login",
            kwargs={"source_slug": source.slug},
        )
        return RedirectChallenge(data={"to": login_url})

    def get_base_user_properties(
        self, source: OAuthSource, info: dict[str, Any], **kwargs
    ) -> dict[str, Any | dict[str, Any]]:
        """Return the base user properties used for enrollment/update.

        The default hands back ``info`` itself (the same object, not a copy).
        """
        return info

    def get_base_group_properties(
        self, source: OAuthSource, group_id: str, **kwargs
    ) -> dict[str, Any | dict[str, Any]]:
        """Return the base group properties used for creation/update.

        The default sets only the group's ``name``, taken from ``group_id``.
        """
        return dict(name=group_id)

Source type; allows overriding URLs and views per type.

name: str = 'default'
verbose_name: str = 'Default source type'
urls_customizable = False
request_token_url: str | None = None
authorization_url: str | None = None
access_token_url: str | None = None
profile_url: str | None = None
oidc_well_known_url: str | None = None
oidc_jwks_url: str | None = None
authorization_code_auth_method: authentik.sources.oauth.models.AuthorizationCodeAuthMethod = AuthorizationCodeAuthMethod.BASIC_AUTH
def icon_url(self) -> str:
50    def icon_url(self) -> str:
51        """Get Icon URL for login"""
52        return static(f"authentik/sources/{self.name}.svg")

Get Icon URL for login

def login_challenge( self, source: authentik.sources.oauth.models.OAuthSource, request: django.http.request.HttpRequest) -> authentik.flows.challenge.Challenge:
54    def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
55        """Allow types to return custom challenges"""
56        return RedirectChallenge(
57            data={
58                "to": reverse(
59                    "authentik_sources_oauth:oauth-client-login",
60                    kwargs={"source_slug": source.slug},
61                ),
62            }
63        )

Allow types to return custom challenges

def get_base_user_properties( self, source: authentik.sources.oauth.models.OAuthSource, info: dict[str, typing.Any], **kwargs) -> dict[str, typing.Any | dict[str, typing.Any]]:
65    def get_base_user_properties(
66        self, source: OAuthSource, info: dict[str, Any], **kwargs
67    ) -> dict[str, Any | dict[str, Any]]:
68        """Get base user properties for enrollment/update"""
69        return info

Get base user properties for enrollment/update

def get_base_group_properties( self, source: authentik.sources.oauth.models.OAuthSource, group_id: str, **kwargs) -> dict[str, typing.Any | dict[str, typing.Any]]:
71    def get_base_group_properties(
72        self, source: OAuthSource, group_id: str, **kwargs
73    ) -> dict[str, Any | dict[str, Any]]:
74        """Get base group properties for creation/update"""
75        return {
76            "name": group_id,
77        }

Get base group properties for creation/update

class SourceTypeRegistry:
 80class SourceTypeRegistry:
 81    """Registry to hold all Source types."""
 82
 83    def __init__(self) -> None:
 84        self.__sources: list[type[SourceType]] = []
 85
 86    def register(self):
 87        """Class decorator to register classes inline."""
 88
 89        def inner_wrapper(cls):
 90            self.__sources.append(cls)
 91            return cls
 92
 93        return inner_wrapper
 94
 95    def get(self):
 96        """Get a list of all source types"""
 97        return self.__sources
 98
 99    def get_name_tuple(self):
100        """Get list of tuples of all registered names"""
101        return [(x.name, x.verbose_name) for x in self.__sources]
102
103    def find_type(self, type_name: str) -> type[SourceType]:
104        """Find type based on source"""
105        found_type = None
106        for src_type in self.__sources:
107            if src_type.name == type_name:
108                return src_type
109        if not found_type:
110            found_type = SourceType
111            LOGGER.warning(
112                "no matching type found, using default",
113                wanted=type_name,
114                have=[x.name for x in self.__sources],
115            )
116        return found_type
117
118    def find(self, type_name: str, kind: RequestKind) -> Callable:
119        """Find fitting Source Type"""
120        found_type = self.find_type(type_name)
121        if kind == RequestKind.CALLBACK:
122            return found_type.callback_view
123        if kind == RequestKind.REDIRECT:
124            return found_type.redirect_view
125        raise ValueError

Registry to hold all Source types.

def register(self):
86    def register(self):
87        """Class decorator to register classes inline."""
88
89        def inner_wrapper(cls):
90            self.__sources.append(cls)
91            return cls
92
93        return inner_wrapper

Class decorator to register classes inline.

def get(self):
95    def get(self):
96        """Get a list of all source types"""
97        return self.__sources

Get a list of all source types

def get_name_tuple(self):
 99    def get_name_tuple(self):
100        """Get list of tuples of all registered names"""
101        return [(x.name, x.verbose_name) for x in self.__sources]

Get a list of (name, verbose_name) tuples for all registered source types.

def find_type( self, type_name: str) -> type[SourceType]:
103    def find_type(self, type_name: str) -> type[SourceType]:
104        """Find type based on source"""
105        found_type = None
106        for src_type in self.__sources:
107            if src_type.name == type_name:
108                return src_type
109        if not found_type:
110            found_type = SourceType
111            LOGGER.warning(
112                "no matching type found, using default",
113                wanted=type_name,
114                have=[x.name for x in self.__sources],
115            )
116        return found_type

Find the registered source type matching the given name, falling back to the default type.

def find( self, type_name: str, kind: RequestKind) -> Callable:
118    def find(self, type_name: str, kind: RequestKind) -> Callable:
119        """Find fitting Source Type"""
120        found_type = self.find_type(type_name)
121        if kind == RequestKind.CALLBACK:
122            return found_type.callback_view
123        if kind == RequestKind.REDIRECT:
124            return found_type.redirect_view
125        raise ValueError

Find the view for the given source type name and request kind.

registry = <SourceTypeRegistry object>