authentik.sources.oauth.clients.oauth2
OAuth 2 Clients
"""OAuth 2 Clients"""

from json import loads
from typing import Any
from urllib.parse import parse_qsl

from django.utils.crypto import constant_time_compare, get_random_string
from django.utils.translation import gettext as _
from requests.auth import AuthBase, HTTPBasicAuth
from requests.exceptions import RequestException
from requests.models import Response
from structlog.stdlib import get_logger

from authentik.lib.generators import generate_id
from authentik.providers.oauth2.utils import pkce_s256_challenge
from authentik.sources.oauth.clients.base import BaseOAuthClient
from authentik.sources.oauth.models import (
    AuthorizationCodeAuthMethod,
    PKCEMethod,
)

LOGGER = get_logger()
# Session key under which the PKCE code verifier is kept between the
# redirect and the token request
SESSION_KEY_OAUTH_PKCE = "authentik/sources/oauth/pkce"


class OAuth2Client(BaseOAuthClient):
    """OAuth2 Client

    Implements the authorization-code flow for a source: builds the redirect
    parameters, validates the returned ``state`` and exchanges the returned
    code for an access token.
    """

    # Headers sent with the access token request
    _default_headers = {
        "Accept": "application/json",
    }

    def get_request_arg(self, key: str, default: Any | None = None) -> Any:
        """Depending on request type, get data from post or get"""
        if self.request.method == "POST":
            return self.request.POST.get(key, default)
        return self.request.GET.get(key, default)

    def check_application_state(self) -> bool:
        """Check optional state parameter.

        Returns True only when a state is stored in the session, a state was
        returned with the request, and both compare equal.
        """
        stored = self.request.session.get(self.session_key, None)
        returned = self.get_request_arg("state", None)
        check = False
        if stored is not None:
            if returned is not None:
                check = constant_time_compare(stored, returned)
            else:
                LOGGER.warning("No state parameter returned by the source.")
        else:
            LOGGER.warning("No state stored in the session.")
        return check

    def get_application_state(self) -> str:
        """Generate state optional parameter."""
        return get_random_string(32)

    def get_client_id(self) -> str:
        """Get client id"""
        return self.source.consumer_key

    def get_client_secret(self) -> str:
        """Get client secret"""
        return self.source.consumer_secret

    def get_access_token_args(self, callback: str | None, code: str | None) -> dict[str, Any]:
        """Build the form data sent with the access token request.

        ``redirect_uri`` and ``code`` are only included when given. The PKCE
        code verifier is read from the session when a request is available and
        PKCE is enabled; client credentials are added when the source type
        authenticates via the POST body.
        """
        args = {
            "grant_type": "authorization_code",
        }
        if callback:
            args["redirect_uri"] = callback
        if code:
            args["code"] = code
        if self.request:
            pkce_mode = self.source.source_type.pkce
            if self.source.source_type.urls_customizable and self.source.pkce:
                pkce_mode = self.source.pkce
            if pkce_mode != PKCEMethod.NONE:
                args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
        if (
            self.source.source_type.authorization_code_auth_method
            == AuthorizationCodeAuthMethod.POST_BODY
        ):
            args["client_id"] = self.get_client_id()
            args["client_secret"] = self.get_client_secret()
        return args

    def get_access_token_auth(self) -> AuthBase | None:
        """Return HTTP basic auth for the token request, if the source type uses it."""
        if (
            self.source.source_type.authorization_code_auth_method
            == AuthorizationCodeAuthMethod.BASIC_AUTH
        ):
            return HTTPBasicAuth(self.get_client_id(), self.get_client_secret())
        return None

    def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
        """Fetch access token from callback request.

        Returns the decoded JSON token response on success, a dict with an
        ``error`` key when the state check fails or no code was returned, and
        None when the token request itself fails.
        """
        callback = self.request.build_absolute_uri(self.callback or self.request.path)
        if not self.check_application_state():
            LOGGER.warning("Application state check failed.")
            return {"error": "State check failed."}
        code = self.get_request_arg("code", None)
        if not code:
            LOGGER.warning("No code returned by the source")
            error = self.get_request_arg("error", None)
            error_desc = self.get_request_arg("error_description", None)
            return {"error": error_desc or error or _("No token received.")}
        try:
            access_token_url = self.source.source_type.access_token_url or ""
            if self.source.source_type.urls_customizable and self.source.access_token_url:
                access_token_url = self.source.access_token_url
            response = self.do_request(
                "post",
                access_token_url,
                auth=self.get_access_token_auth(),
                data=self.get_access_token_args(callback, code),
                # Hand out a copy so the class-level defaults are never mutated
                headers=dict(self._default_headers),
                **request_kwargs,
            )
            response.raise_for_status()
        except RequestException as exc:
            LOGGER.warning(
                "Unable to fetch access token",
                exc=exc,
                # A Response is falsy for 4xx/5xx status codes, so compare
                # against None to still log the error body
                response=exc.response.text if exc.response is not None else str(exc),
            )
            return None
        return response.json()

    def get_redirect_args(self) -> dict[str, str]:
        """Get request parameters for redirect url.

        Stores the generated state and, when PKCE is enabled, the code
        verifier in the session.
        """
        callback = self.request.build_absolute_uri(self.callback)
        client_id: str = self.get_client_id()
        args: dict[str, str] = {
            "client_id": client_id,
            "redirect_uri": callback,
            "response_type": "code",
        }
        state = self.get_application_state()
        if state is not None:
            args["state"] = state
            self.request.session[self.session_key] = state
        pkce_mode = self.source.source_type.pkce
        if self.source.source_type.urls_customizable and self.source.pkce:
            pkce_mode = self.source.pkce
        if pkce_mode != PKCEMethod.NONE:
            verifier = generate_id(length=128)
            self.request.session[SESSION_KEY_OAUTH_PKCE] = verifier
            # https://datatracker.ietf.org/doc/html/rfc7636#section-4.2
            if pkce_mode == PKCEMethod.PLAIN:
                args["code_challenge"] = verifier
            elif pkce_mode == PKCEMethod.S256:
                args["code_challenge"] = pkce_s256_challenge(verifier)
            args["code_challenge_method"] = str(pkce_mode)
        return args

    def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
        """Parse token and secret from raw token response."""
        # Load as json first then parse as query string
        try:
            token_data = loads(raw_token)
        except ValueError:
            return dict(parse_qsl(raw_token))
        return token_data

    def do_request(self, method: str, url: str, **kwargs) -> Response:
        """Build remote url request. Constructs necessary auth.

        When a ``token`` keyword argument is given it is removed from
        ``kwargs`` and its access token is sent both as the ``access_token``
        query parameter and in the ``Authorization`` header.
        """
        if "token" in kwargs:
            token = kwargs.pop("token")

            # Work on copies: the caller's mappings (possibly shared, such as
            # a class-level headers dict) must not end up holding the token
            params = dict(kwargs.get("params") or {})
            params["access_token"] = token["access_token"]
            kwargs["params"] = params

            headers = dict(kwargs.get("headers") or {})
            headers["Authorization"] = f"{token['token_type']} {token['access_token']}"
            kwargs["headers"] = headers
        return super().do_request(method, url, **kwargs)

    @property
    def session_key(self):
        """Session key under which the state for this source is stored."""
        return f"oauth-client-{self.source.name}-request-state"


class UserprofileHeaderAuthClient(OAuth2Client):
    """OAuth client which only sends authentication via header, not querystring"""

    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
        """Fetch user profile information.

        Returns the decoded JSON profile, or None when no profile URL is
        configured or the request fails.
        """
        profile_url = self.source.source_type.profile_url or ""
        if self.source.source_type.urls_customizable and self.source.profile_url:
            profile_url = self.source.profile_url
        if profile_url == "":
            return None
        try:
            # The request is issued inside the try so that connection errors
            # are handled the same way as HTTP error statuses
            response = self.session.request(
                "get",
                profile_url,
                headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
            )
            response.raise_for_status()
        except RequestException as exc:
            LOGGER.warning(
                "Unable to fetch user profile from profile_url",
                exc=exc,
                # A Response is falsy for 4xx/5xx status codes, so compare
                # against None to still log the error body
                response=exc.response.text if exc.response is not None else str(exc),
            )
            return None
        return response.json()
LOGGER =
<BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None, context_class=None, initial_values={}, logger_factory_args=())>
SESSION_KEY_OAUTH_PKCE =
'authentik/sources/oauth/pkce'
class OAuth2Client(BaseOAuthClient):
    """OAuth2 Client

    Implements the authorization-code flow for a source: builds the redirect
    parameters, validates the returned ``state`` and exchanges the returned
    code for an access token.
    """

    # Headers sent with the access token request
    _default_headers = {
        "Accept": "application/json",
    }

    def get_request_arg(self, key: str, default: Any | None = None) -> Any:
        """Depending on request type, get data from post or get"""
        if self.request.method == "POST":
            return self.request.POST.get(key, default)
        return self.request.GET.get(key, default)

    def check_application_state(self) -> bool:
        """Check optional state parameter.

        Returns True only when a state is stored in the session, a state was
        returned with the request, and both compare equal.
        """
        stored = self.request.session.get(self.session_key, None)
        returned = self.get_request_arg("state", None)
        check = False
        if stored is not None:
            if returned is not None:
                check = constant_time_compare(stored, returned)
            else:
                LOGGER.warning("No state parameter returned by the source.")
        else:
            LOGGER.warning("No state stored in the session.")
        return check

    def get_application_state(self) -> str:
        """Generate state optional parameter."""
        return get_random_string(32)

    def get_client_id(self) -> str:
        """Get client id"""
        return self.source.consumer_key

    def get_client_secret(self) -> str:
        """Get client secret"""
        return self.source.consumer_secret

    def get_access_token_args(self, callback: str | None, code: str | None) -> dict[str, Any]:
        """Build the form data sent with the access token request.

        ``redirect_uri`` and ``code`` are only included when given. The PKCE
        code verifier is read from the session when a request is available and
        PKCE is enabled; client credentials are added when the source type
        authenticates via the POST body.
        """
        args = {
            "grant_type": "authorization_code",
        }
        if callback:
            args["redirect_uri"] = callback
        if code:
            args["code"] = code
        if self.request:
            pkce_mode = self.source.source_type.pkce
            if self.source.source_type.urls_customizable and self.source.pkce:
                pkce_mode = self.source.pkce
            if pkce_mode != PKCEMethod.NONE:
                args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
        if (
            self.source.source_type.authorization_code_auth_method
            == AuthorizationCodeAuthMethod.POST_BODY
        ):
            args["client_id"] = self.get_client_id()
            args["client_secret"] = self.get_client_secret()
        return args

    def get_access_token_auth(self) -> AuthBase | None:
        """Return HTTP basic auth for the token request, if the source type uses it."""
        if (
            self.source.source_type.authorization_code_auth_method
            == AuthorizationCodeAuthMethod.BASIC_AUTH
        ):
            return HTTPBasicAuth(self.get_client_id(), self.get_client_secret())
        return None

    def get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
        """Fetch access token from callback request.

        Returns the decoded JSON token response on success, a dict with an
        ``error`` key when the state check fails or no code was returned, and
        None when the token request itself fails.
        """
        callback = self.request.build_absolute_uri(self.callback or self.request.path)
        if not self.check_application_state():
            LOGGER.warning("Application state check failed.")
            return {"error": "State check failed."}
        code = self.get_request_arg("code", None)
        if not code:
            LOGGER.warning("No code returned by the source")
            error = self.get_request_arg("error", None)
            error_desc = self.get_request_arg("error_description", None)
            return {"error": error_desc or error or _("No token received.")}
        try:
            access_token_url = self.source.source_type.access_token_url or ""
            if self.source.source_type.urls_customizable and self.source.access_token_url:
                access_token_url = self.source.access_token_url
            response = self.do_request(
                "post",
                access_token_url,
                auth=self.get_access_token_auth(),
                data=self.get_access_token_args(callback, code),
                # Hand out a copy so the class-level defaults are never mutated
                headers=dict(self._default_headers),
                **request_kwargs,
            )
            response.raise_for_status()
        except RequestException as exc:
            LOGGER.warning(
                "Unable to fetch access token",
                exc=exc,
                # A Response is falsy for 4xx/5xx status codes, so compare
                # against None to still log the error body
                response=exc.response.text if exc.response is not None else str(exc),
            )
            return None
        return response.json()

    def get_redirect_args(self) -> dict[str, str]:
        """Get request parameters for redirect url.

        Stores the generated state and, when PKCE is enabled, the code
        verifier in the session.
        """
        callback = self.request.build_absolute_uri(self.callback)
        client_id: str = self.get_client_id()
        args: dict[str, str] = {
            "client_id": client_id,
            "redirect_uri": callback,
            "response_type": "code",
        }
        state = self.get_application_state()
        if state is not None:
            args["state"] = state
            self.request.session[self.session_key] = state
        pkce_mode = self.source.source_type.pkce
        if self.source.source_type.urls_customizable and self.source.pkce:
            pkce_mode = self.source.pkce
        if pkce_mode != PKCEMethod.NONE:
            verifier = generate_id(length=128)
            self.request.session[SESSION_KEY_OAUTH_PKCE] = verifier
            # https://datatracker.ietf.org/doc/html/rfc7636#section-4.2
            if pkce_mode == PKCEMethod.PLAIN:
                args["code_challenge"] = verifier
            elif pkce_mode == PKCEMethod.S256:
                args["code_challenge"] = pkce_s256_challenge(verifier)
            args["code_challenge_method"] = str(pkce_mode)
        return args

    def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
        """Parse token and secret from raw token response."""
        # Load as json first then parse as query string
        try:
            token_data = loads(raw_token)
        except ValueError:
            return dict(parse_qsl(raw_token))
        return token_data

    def do_request(self, method: str, url: str, **kwargs) -> Response:
        """Build remote url request. Constructs necessary auth.

        When a ``token`` keyword argument is given it is removed from
        ``kwargs`` and its access token is sent both as the ``access_token``
        query parameter and in the ``Authorization`` header.
        """
        if "token" in kwargs:
            token = kwargs.pop("token")

            # Work on copies: the caller's mappings (possibly shared, such as
            # a class-level headers dict) must not end up holding the token
            params = dict(kwargs.get("params") or {})
            params["access_token"] = token["access_token"]
            kwargs["params"] = params

            headers = dict(kwargs.get("headers") or {})
            headers["Authorization"] = f"{token['token_type']} {token['access_token']}"
            kwargs["headers"] = headers
        return super().do_request(method, url, **kwargs)

    @property
    def session_key(self):
        """Session key under which the state for this source is stored."""
        return f"oauth-client-{self.source.name}-request-state"
OAuth2 Client
def
get_request_arg(self, key: str, default: Any | None = None) -> Any:
34 def get_request_arg(self, key: str, default: Any | None = None) -> Any: 35 """Depending on request type, get data from post or get""" 36 if self.request.method == "POST": 37 return self.request.POST.get(key, default) 38 return self.request.GET.get(key, default)
Depending on request type, get data from post or get
def
check_application_state(self) -> bool:
def check_application_state(self) -> bool:
    """Compare the ``state`` returned with the request against the one kept in the session.

    A warning is logged and False is returned when either value is missing;
    otherwise the result of the constant-time comparison is returned.
    """
    expected = self.request.session.get(self.session_key, None)
    received = self.get_request_arg("state", None)
    if expected is None:
        LOGGER.warning("No state stored in the session.")
        return False
    if received is None:
        LOGGER.warning("No state parameter returned by the source.")
        return False
    return constant_time_compare(expected, received)
Check optional state parameter.
def
get_application_state(self) -> str:
def get_application_state(self) -> str:
    """Produce a new random value for the optional ``state`` parameter."""
    state_length = 32
    return get_random_string(state_length)
Generate state optional parameter.
def
get_client_secret(self) -> str:
def get_client_secret(self) -> str:
    """Return the consumer secret configured on the source."""
    source = self.source
    return source.consumer_secret
Get client secret
def
get_access_token_args(self, callback: str | None, code: str | None) -> dict[str, typing.Any]:
def get_access_token_args(self, callback: str | None, code: str | None) -> dict[str, Any]:
    """Assemble the form data for the authorization-code token request.

    ``redirect_uri`` and ``code`` are only set when a value is given. With a
    request available and PKCE enabled, the code verifier is taken from the
    session. Client credentials are included when the source type
    authenticates through the POST body.
    """
    source = self.source
    source_type = source.source_type
    payload = {"grant_type": "authorization_code"}
    if callback:
        payload["redirect_uri"] = callback
    if code:
        payload["code"] = code
    if self.request:
        # A customizable source may override the PKCE mode of its type
        pkce_mode = source_type.pkce
        if source_type.urls_customizable and source.pkce:
            pkce_mode = source.pkce
        if pkce_mode != PKCEMethod.NONE:
            payload["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
    if source_type.authorization_code_auth_method == AuthorizationCodeAuthMethod.POST_BODY:
        payload.update(
            client_id=self.get_client_id(),
            client_secret=self.get_client_secret(),
        )
    return payload
def
get_access_token(self, **request_kwargs) -> dict[str, Any] | None:
96 def get_access_token(self, **request_kwargs) -> dict[str, Any] | None: 97 """Fetch access token from callback request.""" 98 callback = self.request.build_absolute_uri(self.callback or self.request.path) 99 if not self.check_application_state(): 100 LOGGER.warning("Application state check failed.") 101 return {"error": "State check failed."} 102 code = self.get_request_arg("code", None) 103 if not code: 104 LOGGER.warning("No code returned by the source") 105 error = self.get_request_arg("error", None) 106 error_desc = self.get_request_arg("error_description", None) 107 return {"error": error_desc or error or _("No token received.")} 108 try: 109 access_token_url = self.source.source_type.access_token_url or "" 110 if self.source.source_type.urls_customizable and self.source.access_token_url: 111 access_token_url = self.source.access_token_url 112 response = self.do_request( 113 "post", 114 access_token_url, 115 auth=self.get_access_token_auth(), 116 data=self.get_access_token_args(callback, code), 117 headers=self._default_headers, 118 **request_kwargs, 119 ) 120 response.raise_for_status() 121 except RequestException as exc: 122 LOGGER.warning( 123 "Unable to fetch access token", 124 exc=exc, 125 response=exc.response.text if exc.response else str(exc), 126 ) 127 return None 128 return response.json()
Fetch access token from callback request.
def
get_redirect_args(self) -> dict[str, str]:
def get_redirect_args(self) -> dict[str, str]:
    """Build the query parameters for the authorization redirect.

    The generated state is saved in the session under ``session_key``. When
    PKCE is enabled, a new code verifier is saved in the session as well and
    the matching challenge is added to the parameters.
    """
    request = self.request
    redirect_uri = request.build_absolute_uri(self.callback)
    params: dict[str, str] = {
        "client_id": self.get_client_id(),
        "redirect_uri": redirect_uri,
        "response_type": "code",
    }
    state = self.get_application_state()
    if state is not None:
        params["state"] = state
        request.session[self.session_key] = state

    source = self.source
    source_type = source.source_type
    # A customizable source may override the PKCE mode of its type
    pkce_mode = source_type.pkce
    if source_type.urls_customizable and source.pkce:
        pkce_mode = source.pkce
    if pkce_mode == PKCEMethod.NONE:
        return params

    verifier = generate_id(length=128)
    request.session[SESSION_KEY_OAUTH_PKCE] = verifier
    # https://datatracker.ietf.org/doc/html/rfc7636#section-4.2
    if pkce_mode == PKCEMethod.PLAIN:
        params["code_challenge"] = verifier
    elif pkce_mode == PKCEMethod.S256:
        params["code_challenge"] = pkce_s256_challenge(verifier)
    params["code_challenge_method"] = str(pkce_mode)
    return params
Get request parameters for redirect url.
def
parse_raw_token(self, raw_token: str) -> dict[str, typing.Any]:
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
    """Decode a raw token response given as JSON or, failing that, as a query string."""
    try:
        return loads(raw_token)
    except ValueError:
        # Not valid JSON: interpret the payload as URL-encoded key/value pairs
        pairs = parse_qsl(raw_token)
        return dict(pairs)
Parse token and secret from raw token response.
def
do_request(self, method: str, url: str, **kwargs) -> requests.models.Response:
def do_request(self, method: str, url: str, **kwargs) -> Response:
    """Build remote url request. Constructs necessary auth.

    When a ``token`` keyword argument is given it is removed from ``kwargs``
    and its access token is sent both as the ``access_token`` query parameter
    and in the ``Authorization`` header. All remaining keyword arguments are
    passed on to the parent implementation.
    """
    if "token" in kwargs:
        token = kwargs.pop("token")

        # Work on copies: the caller's mappings (possibly shared, such as a
        # class-level headers dict) must not end up holding the token.
        # ``or {}`` also covers an explicit ``params=None`` / ``headers=None``.
        params = dict(kwargs.get("params") or {})
        params["access_token"] = token["access_token"]
        kwargs["params"] = params

        headers = dict(kwargs.get("headers") or {})
        headers["Authorization"] = f"{token['token_type']} {token['access_token']}"
        kwargs["headers"] = headers
    return super().do_request(method, url, **kwargs)
Build remote url request. Constructs necessary auth.
class UserprofileHeaderAuthClient(OAuth2Client):
    """OAuth client which only sends authentication via header, not querystring"""

    def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
        """Fetch user profile information.

        The token is sent only in the ``Authorization`` header.

        Returns the decoded JSON profile, or None when no profile URL is
        configured or the request fails.
        """
        # A customizable source may override the profile URL of its type
        profile_url = self.source.source_type.profile_url or ""
        if self.source.source_type.urls_customizable and self.source.profile_url:
            profile_url = self.source.profile_url
        if profile_url == "":
            return None
        try:
            # The request is issued inside the try so that connection errors
            # are handled the same way as HTTP error statuses
            response = self.session.request(
                "get",
                profile_url,
                headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
            )
            response.raise_for_status()
        except RequestException as exc:
            LOGGER.warning(
                "Unable to fetch user profile from profile_url",
                exc=exc,
                # A Response is falsy for 4xx/5xx status codes, so compare
                # against None to still log the error body
                response=exc.response.text if exc.response is not None else str(exc),
            )
            return None
        return response.json()
OAuth client which only sends authentication via header, not querystring
def
get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None:
188 def get_profile_info(self, token: dict[str, str]) -> dict[str, Any] | None: 189 "Fetch user profile information." 190 profile_url = self.source.source_type.profile_url or "" 191 if self.source.source_type.urls_customizable and self.source.profile_url: 192 profile_url = self.source.profile_url 193 if profile_url == "": 194 return None 195 response = self.session.request( 196 "get", 197 profile_url, 198 headers={"Authorization": f"{token['token_type']} {token['access_token']}"}, 199 ) 200 try: 201 response.raise_for_status() 202 except RequestException as exc: 203 LOGGER.warning( 204 "Unable to fetch user profile from profile_url", 205 exc=exc, 206 response=exc.response.text if exc.response else str(exc), 207 ) 208 return None 209 return response.json()
Fetch user profile information.