authentik.sources.oauth.clients.oauth2

OAuth 2 Clients

  1"""OAuth 2 Clients"""
  2
  3from json import loads
  4from typing import Any
  5from urllib.parse import parse_qsl
  6
  7from django.utils.crypto import constant_time_compare, get_random_string
  8from django.utils.translation import gettext as _
  9from requests.auth import AuthBase, HTTPBasicAuth
 10from requests.exceptions import RequestException
 11from requests.models import Response
 12from structlog.stdlib import get_logger
 13
 14from authentik.lib.generators import generate_id
 15from authentik.providers.oauth2.utils import pkce_s256_challenge
 16from authentik.sources.oauth.clients.base import BaseOAuthClient
 17from authentik.sources.oauth.models import (
 18    AuthorizationCodeAuthMethod,
 19    PKCEMethod,
 20)
 21
 22LOGGER = get_logger()
 23SESSION_KEY_OAUTH_PKCE = "authentik/sources/oauth/pkce"
 24
 25
 26class OAuth2Client(BaseOAuthClient):
 27    """OAuth2 Client"""
 28
 29    _default_headers = {
 30        "Accept": "application/json",
 31    }
 32
 33    def get_request_arg(self, key: str, default: Any | None = None) -> Any:
 34        """Depending on request type, get data from post or get"""
 35        if self.request.method == "POST":
 36            return self.request.POST.get(key, default)
 37        return self.request.GET.get(key, default)
 38
 39    def check_application_state(self) -> bool:
 40        """Check optional state parameter."""
 41        stored = self.request.session.get(self.session_key, None)
 42        returned = self.get_request_arg("state", None)
 43        check = False
 44        if stored is not None:
 45            if returned is not None:
 46                check = constant_time_compare(stored, returned)
 47            else:
 48                LOGGER.warning("No state parameter returned by the source.")
 49        else:
 50            LOGGER.warning("No state stored in the session.")
 51        return check
 52
 53    def get_application_state(self) -> str:
 54        """Generate state optional parameter."""
 55        return get_random_string(32)
 56
 57    def get_client_id(self) -> str:
 58        """Get client id"""
 59        return self.source.consumer_key
 60
 61    def get_client_secret(self) -> str:
 62        """Get client secret"""
 63        return self.source.consumer_secret
 64
 65    def get_access_token_args(self, callback: str | None, code: str | None) -> dict[str, Any]:
 66        args = {
 67            "grant_type": "authorization_code",
 68        }
 69        if callback:
 70            args["redirect_uri"] = callback
 71        if code:
 72            args["code"] = code
 73        if self.request:
 74            pkce_mode = self.source.source_type.pkce
 75            if self.source.source_type.urls_customizable and self.source.pkce:
 76                pkce_mode = self.source.pkce
 77            if pkce_mode != PKCEMethod.NONE:
 78                args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
 79        if (
 80            self.source.source_type.authorization_code_auth_method
 81            == AuthorizationCodeAuthMethod.POST_BODY
 82        ):
 83            args["client_id"] = self.get_client_id()
 84            args["client_secret"] = self.get_client_secret()
 85        return args
 86
 87    def get_access_token_auth(self) -> AuthBase | None:
 88        if (
 89            self.source.source_type.authorization_code_auth_method
 90            == AuthorizationCodeAuthMethod.BASIC_AUTH
 91        ):
 92            return HTTPBasicAuth(self.get_client_id(), self.get_client_secret())
 93        return None
 94
 95    def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
 96        """Fetch access token from callback request."""
 97        callback = self.request.build_absolute_uri(self.callback or self.request.path)
 98        if not self.check_application_state():
 99            LOGGER.warning("Application state check failed.")
100            return {"error": "State check failed."}
101        code = self.get_request_arg("code", None)
102        if not code:
103            LOGGER.warning("No code returned by the source")
104            error = self.get_request_arg("error", None)
105            error_desc = self.get_request_arg("error_description", None)
106            return {"error": error_desc or error or _("No token received.")}
107        try:
108            access_token_url = self.source.source_type.access_token_url or ""
109            if self.source.source_type.urls_customizable and self.source.access_token_url:
110                access_token_url = self.source.access_token_url
111            response = self.do_request(
112                "post",
113                access_token_url,
114                auth=self.get_access_token_auth(),
115                data=self.get_access_token_args(callback, code),
116                headers=self._default_headers,
117                **request_kwargs,
118            )
119            response.raise_for_status()
120        except RequestException as exc:
121            LOGGER.warning(
122                "Unable to fetch access token",
123                exc=exc,
124                response=exc.response.text if exc.response else str(exc),
125            )
126            return None
127        return response.json()
128
129    def get_redirect_args(self) -> dict[str, str]:
130        """Get request parameters for redirect url."""
131        callback = self.request.build_absolute_uri(self.callback)
132        client_id: str = self.get_client_id()
133        args: dict[str, str] = {
134            "client_id": client_id,
135            "redirect_uri": callback,
136            "response_type": "code",
137        }
138        state = self.get_application_state()
139        if state is not None:
140            args["state"] = state
141            self.request.session[self.session_key] = state
142        pkce_mode = self.source.source_type.pkce
143        if self.source.source_type.urls_customizable and self.source.pkce:
144            pkce_mode = self.source.pkce
145        if pkce_mode != PKCEMethod.NONE:
146            verifier = generate_id(length=128)
147            self.request.session[SESSION_KEY_OAUTH_PKCE] = verifier
148            # https://datatracker.ietf.org/doc/html/rfc7636#section-4.2
149            if pkce_mode == PKCEMethod.PLAIN:
150                args["code_challenge"] = verifier
151            elif pkce_mode == PKCEMethod.S256:
152                args["code_challenge"] = pkce_s256_challenge(verifier)
153            args["code_challenge_method"] = str(pkce_mode)
154        return args
155
156    def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
157        """Parse token and secret from raw token response."""
158        # Load as json first then parse as query string
159        try:
160            token_data = loads(raw_token)
161        except ValueError:
162            return dict(parse_qsl(raw_token))
163        return token_data
164
165    def do_request(self, method: str, url: str, **kwargs) -> Response:
166        """Build remote url request. Constructs necessary auth."""
167        if "token" in kwargs:
168            token = kwargs.pop("token")
169
170            params = kwargs.get("params", {})
171            params["access_token"] = token["access_token"]
172            kwargs["params"] = params
173
174            headers = kwargs.get("headers", {})
175            headers["Authorization"] = f"{token['token_type']} {token['access_token']}"
176            kwargs["headers"] = headers
177        return super().do_request(method, url, **kwargs)
178
179    @property
180    def session_key(self):
181        return f"oauth-client-{self.source.name}-request-state"
182
183
class UserprofileHeaderAuthClient(OAuth2Client):
    """OAuth client which only sends authentication via header, not querystring"""

    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
        """Fetch user profile information.

        The token is sent only in the ``Authorization`` header
        (``"<token_type> <access_token>"``). The source's own profile URL wins when the
        source type allows customisation.

        Returns:
            The decoded JSON profile.
            None if no profile URL is configured or the request fails.
        """
        profile_url = self.source.source_type.profile_url or ""
        if self.source.source_type.urls_customizable and self.source.profile_url:
            profile_url = self.source.profile_url
        if not profile_url:
            return None
        try:
            # The request itself is inside the ``try`` so connection/timeout errors are
            # logged and turned into None, just like HTTP error statuses.
            response = self.session.request(
                "get",
                profile_url,
                headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
            )
            response.raise_for_status()
        except RequestException as exc:
            # ``Response.__bool__`` returns ``Response.ok`` (False for status >= 400),
            # so test for None explicitly to log the error body.
            LOGGER.warning(
                "Unable to fetch user profile from profile_url",
                exc=exc,
                response=exc.response.text if exc.response is not None else str(exc),
            )
            return None
        return response.json()
LOGGER = <BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None, context_class=None, initial_values={}, logger_factory_args=())>
SESSION_KEY_OAUTH_PKCE = 'authentik/sources/oauth/pkce'
class OAuth2Client(authentik.sources.oauth.clients.base.BaseOAuthClient):
 27class OAuth2Client(BaseOAuthClient):
 28    """OAuth2 Client"""
 29
 30    _default_headers = {
 31        "Accept": "application/json",
 32    }
 33
 34    def get_request_arg(self, key: str, default: Any | None = None) -> Any:
 35        """Depending on request type, get data from post or get"""
 36        if self.request.method == "POST":
 37            return self.request.POST.get(key, default)
 38        return self.request.GET.get(key, default)
 39
 40    def check_application_state(self) -> bool:
 41        """Check optional state parameter."""
 42        stored = self.request.session.get(self.session_key, None)
 43        returned = self.get_request_arg("state", None)
 44        check = False
 45        if stored is not None:
 46            if returned is not None:
 47                check = constant_time_compare(stored, returned)
 48            else:
 49                LOGGER.warning("No state parameter returned by the source.")
 50        else:
 51            LOGGER.warning("No state stored in the session.")
 52        return check
 53
 54    def get_application_state(self) -> str:
 55        """Generate state optional parameter."""
 56        return get_random_string(32)
 57
 58    def get_client_id(self) -> str:
 59        """Get client id"""
 60        return self.source.consumer_key
 61
 62    def get_client_secret(self) -> str:
 63        """Get client secret"""
 64        return self.source.consumer_secret
 65
 66    def get_access_token_args(self, callback: str | None, code: str | None) -> dict[str, Any]:
 67        args = {
 68            "grant_type": "authorization_code",
 69        }
 70        if callback:
 71            args["redirect_uri"] = callback
 72        if code:
 73            args["code"] = code
 74        if self.request:
 75            pkce_mode = self.source.source_type.pkce
 76            if self.source.source_type.urls_customizable and self.source.pkce:
 77                pkce_mode = self.source.pkce
 78            if pkce_mode != PKCEMethod.NONE:
 79                args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
 80        if (
 81            self.source.source_type.authorization_code_auth_method
 82            == AuthorizationCodeAuthMethod.POST_BODY
 83        ):
 84            args["client_id"] = self.get_client_id()
 85            args["client_secret"] = self.get_client_secret()
 86        return args
 87
 88    def get_access_token_auth(self) -> AuthBase | None:
 89        if (
 90            self.source.source_type.authorization_code_auth_method
 91            == AuthorizationCodeAuthMethod.BASIC_AUTH
 92        ):
 93            return HTTPBasicAuth(self.get_client_id(), self.get_client_secret())
 94        return None
 95
 96    def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
 97        """Fetch access token from callback request."""
 98        callback = self.request.build_absolute_uri(self.callback or self.request.path)
 99        if not self.check_application_state():
100            LOGGER.warning("Application state check failed.")
101            return {"error": "State check failed."}
102        code = self.get_request_arg("code", None)
103        if not code:
104            LOGGER.warning("No code returned by the source")
105            error = self.get_request_arg("error", None)
106            error_desc = self.get_request_arg("error_description", None)
107            return {"error": error_desc or error or _("No token received.")}
108        try:
109            access_token_url = self.source.source_type.access_token_url or ""
110            if self.source.source_type.urls_customizable and self.source.access_token_url:
111                access_token_url = self.source.access_token_url
112            response = self.do_request(
113                "post",
114                access_token_url,
115                auth=self.get_access_token_auth(),
116                data=self.get_access_token_args(callback, code),
117                headers=self._default_headers,
118                **request_kwargs,
119            )
120            response.raise_for_status()
121        except RequestException as exc:
122            LOGGER.warning(
123                "Unable to fetch access token",
124                exc=exc,
125                response=exc.response.text if exc.response else str(exc),
126            )
127            return None
128        return response.json()
129
130    def get_redirect_args(self) -> dict[str, str]:
131        """Get request parameters for redirect url."""
132        callback = self.request.build_absolute_uri(self.callback)
133        client_id: str = self.get_client_id()
134        args: dict[str, str] = {
135            "client_id": client_id,
136            "redirect_uri": callback,
137            "response_type": "code",
138        }
139        state = self.get_application_state()
140        if state is not None:
141            args["state"] = state
142            self.request.session[self.session_key] = state
143        pkce_mode = self.source.source_type.pkce
144        if self.source.source_type.urls_customizable and self.source.pkce:
145            pkce_mode = self.source.pkce
146        if pkce_mode != PKCEMethod.NONE:
147            verifier = generate_id(length=128)
148            self.request.session[SESSION_KEY_OAUTH_PKCE] = verifier
149            # https://datatracker.ietf.org/doc/html/rfc7636#section-4.2
150            if pkce_mode == PKCEMethod.PLAIN:
151                args["code_challenge"] = verifier
152            elif pkce_mode == PKCEMethod.S256:
153                args["code_challenge"] = pkce_s256_challenge(verifier)
154            args["code_challenge_method"] = str(pkce_mode)
155        return args
156
157    def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
158        """Parse token and secret from raw token response."""
159        # Load as json first then parse as query string
160        try:
161            token_data = loads(raw_token)
162        except ValueError:
163            return dict(parse_qsl(raw_token))
164        return token_data
165
166    def do_request(self, method: str, url: str, **kwargs) -> Response:
167        """Build remote url request. Constructs necessary auth."""
168        if "token" in kwargs:
169            token = kwargs.pop("token")
170
171            params = kwargs.get("params", {})
172            params["access_token"] = token["access_token"]
173            kwargs["params"] = params
174
175            headers = kwargs.get("headers", {})
176            headers["Authorization"] = f"{token['token_type']} {token['access_token']}"
177            kwargs["headers"] = headers
178        return super().do_request(method, url, **kwargs)
179
180    @property
181    def session_key(self):
182        return f"oauth-client-{self.source.name}-request-state"

OAuth2 Client

def get_request_arg(self, key: str, default: Any | None = None) -> Any:
34    def get_request_arg(self, key: str, default: Any | None = None) -> Any:
35        """Depending on request type, get data from post or get"""
36        if self.request.method == "POST":
37            return self.request.POST.get(key, default)
38        return self.request.GET.get(key, default)

Depending on request type, get data from post or get

def check_application_state(self) -> bool:
40    def check_application_state(self) -> bool:
41        """Check optional state parameter."""
42        stored = self.request.session.get(self.session_key, None)
43        returned = self.get_request_arg("state", None)
44        check = False
45        if stored is not None:
46            if returned is not None:
47                check = constant_time_compare(stored, returned)
48            else:
49                LOGGER.warning("No state parameter returned by the source.")
50        else:
51            LOGGER.warning("No state stored in the session.")
52        return check

Check optional state parameter.

def get_application_state(self) -> str:
54    def get_application_state(self) -> str:
55        """Generate state optional parameter."""
56        return get_random_string(32)

Generate state optional parameter.

def get_client_id(self) -> str:
58    def get_client_id(self) -> str:
59        """Get client id"""
60        return self.source.consumer_key

Get client id

def get_client_secret(self) -> str:
62    def get_client_secret(self) -> str:
63        """Get client secret"""
64        return self.source.consumer_secret

Get client secret

def get_access_token_args(self, callback: str | None, code: str | None) -> dict[str, typing.Any]:
66    def get_access_token_args(self, callback: str | None, code: str | None) -> dict[str, Any]:
67        args = {
68            "grant_type": "authorization_code",
69        }
70        if callback:
71            args["redirect_uri"] = callback
72        if code:
73            args["code"] = code
74        if self.request:
75            pkce_mode = self.source.source_type.pkce
76            if self.source.source_type.urls_customizable and self.source.pkce:
77                pkce_mode = self.source.pkce
78            if pkce_mode != PKCEMethod.NONE:
79                args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
80        if (
81            self.source.source_type.authorization_code_auth_method
82            == AuthorizationCodeAuthMethod.POST_BODY
83        ):
84            args["client_id"] = self.get_client_id()
85            args["client_secret"] = self.get_client_secret()
86        return args
def get_access_token_auth(self) -> requests.auth.AuthBase | None:
88    def get_access_token_auth(self) -> AuthBase | None:
89        if (
90            self.source.source_type.authorization_code_auth_method
91            == AuthorizationCodeAuthMethod.BASIC_AUTH
92        ):
93            return HTTPBasicAuth(self.get_client_id(), self.get_client_secret())
94        return None
def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
 96    def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
 97        """Fetch access token from callback request."""
 98        callback = self.request.build_absolute_uri(self.callback or self.request.path)
 99        if not self.check_application_state():
100            LOGGER.warning("Application state check failed.")
101            return {"error": "State check failed."}
102        code = self.get_request_arg("code", None)
103        if not code:
104            LOGGER.warning("No code returned by the source")
105            error = self.get_request_arg("error", None)
106            error_desc = self.get_request_arg("error_description", None)
107            return {"error": error_desc or error or _("No token received.")}
108        try:
109            access_token_url = self.source.source_type.access_token_url or ""
110            if self.source.source_type.urls_customizable and self.source.access_token_url:
111                access_token_url = self.source.access_token_url
112            response = self.do_request(
113                "post",
114                access_token_url,
115                auth=self.get_access_token_auth(),
116                data=self.get_access_token_args(callback, code),
117                headers=self._default_headers,
118                **request_kwargs,
119            )
120            response.raise_for_status()
121        except RequestException as exc:
122            LOGGER.warning(
123                "Unable to fetch access token",
124                exc=exc,
125                response=exc.response.text if exc.response else str(exc),
126            )
127            return None
128        return response.json()

Fetch access token from callback request.

def get_redirect_args(self) -> dict[str, str]:
130    def get_redirect_args(self) -> dict[str, str]:
131        """Get request parameters for redirect url."""
132        callback = self.request.build_absolute_uri(self.callback)
133        client_id: str = self.get_client_id()
134        args: dict[str, str] = {
135            "client_id": client_id,
136            "redirect_uri": callback,
137            "response_type": "code",
138        }
139        state = self.get_application_state()
140        if state is not None:
141            args["state"] = state
142            self.request.session[self.session_key] = state
143        pkce_mode = self.source.source_type.pkce
144        if self.source.source_type.urls_customizable and self.source.pkce:
145            pkce_mode = self.source.pkce
146        if pkce_mode != PKCEMethod.NONE:
147            verifier = generate_id(length=128)
148            self.request.session[SESSION_KEY_OAUTH_PKCE] = verifier
149            # https://datatracker.ietf.org/doc/html/rfc7636#section-4.2
150            if pkce_mode == PKCEMethod.PLAIN:
151                args["code_challenge"] = verifier
152            elif pkce_mode == PKCEMethod.S256:
153                args["code_challenge"] = pkce_s256_challenge(verifier)
154            args["code_challenge_method"] = str(pkce_mode)
155        return args

Get request parameters for redirect url.

def parse_raw_token(self, raw_token: str) -> dict[str, typing.Any]:
157    def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
158        """Parse token and secret from raw token response."""
159        # Load as json first then parse as query string
160        try:
161            token_data = loads(raw_token)
162        except ValueError:
163            return dict(parse_qsl(raw_token))
164        return token_data

Parse token and secret from raw token response.

def do_request(self, method: str, url: str, **kwargs) -> requests.models.Response:
166    def do_request(self, method: str, url: str, **kwargs) -> Response:
167        """Build remote url request. Constructs necessary auth."""
168        if "token" in kwargs:
169            token = kwargs.pop("token")
170
171            params = kwargs.get("params", {})
172            params["access_token"] = token["access_token"]
173            kwargs["params"] = params
174
175            headers = kwargs.get("headers", {})
176            headers["Authorization"] = f"{token['token_type']} {token['access_token']}"
177            kwargs["headers"] = headers
178        return super().do_request(method, url, **kwargs)

Build remote url request. Constructs necessary auth.

session_key
180    @property
181    def session_key(self):
182        return f"oauth-client-{self.source.name}-request-state"

Return the session key under which this source's OAuth `state` value is stored.

class UserprofileHeaderAuthClient(OAuth2Client):
class UserprofileHeaderAuthClient(OAuth2Client):
    """OAuth client which only sends authentication via header, not querystring"""

    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
        """Fetch user profile information.

        The token is sent only in the ``Authorization`` header
        (``"<token_type> <access_token>"``). The source's own profile URL wins when the
        source type allows customisation.

        Returns:
            The decoded JSON profile.
            None if no profile URL is configured or the request fails.
        """
        profile_url = self.source.source_type.profile_url or ""
        if self.source.source_type.urls_customizable and self.source.profile_url:
            profile_url = self.source.profile_url
        if not profile_url:
            return None
        try:
            # The request itself is inside the ``try`` so connection/timeout errors are
            # logged and turned into None, just like HTTP error statuses.
            response = self.session.request(
                "get",
                profile_url,
                headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
            )
            response.raise_for_status()
        except RequestException as exc:
            # ``Response.__bool__`` returns ``Response.ok`` (False for status >= 400),
            # so test for None explicitly to log the error body.
            LOGGER.warning(
                "Unable to fetch user profile from profile_url",
                exc=exc,
                response=exc.response.text if exc.response is not None else str(exc),
            )
            return None
        return response.json()

OAuth client which only sends authentication via header, not querystring

def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
188    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
189        "Fetch user profile information."
190        profile_url = self.source.source_type.profile_url or ""
191        if self.source.source_type.urls_customizable and self.source.profile_url:
192            profile_url = self.source.profile_url
193        if profile_url == "":
194            return None
195        response = self.session.request(
196            "get",
197            profile_url,
198            headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
199        )
200        try:
201            response.raise_for_status()
202        except RequestException as exc:
203            LOGGER.warning(
204                "Unable to fetch user profile from profile_url",
205                exc=exc,
206                response=exc.response.text if exc.response else str(exc),
207            )
208            return None
209        return response.json()

Fetch user profile information.