authentik.enterprise.providers.scim.auth_oauth2
"""OAuth2 bearer-token authentication for outgoing SCIM provider requests."""

from datetime import timedelta
from typing import TYPE_CHECKING, Any

from django.utils.timezone import now
from requests import Request, RequestException
from structlog.stdlib import get_logger

from authentik.common.oauth.constants import GRANT_TYPE_PASSWORD, GRANT_TYPE_REFRESH_TOKEN
from authentik.providers.scim.clients.exceptions import SCIMRequestException
from authentik.providers.scim.models import SCIMAuthenticationMode
from authentik.sources.oauth.clients.base import BaseOAuthClient
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection

if TYPE_CHECKING:
    from authentik.providers.scim.models import SCIMProvider


class SCIMOAuthException(SCIMRequestException):
    """Exceptions related to OAuth operations for SCIM requests"""


class SCIMOAuthAuth:
    """Callable that attaches an OAuth bearer token to outgoing SCIM requests.

    The token is kept on a `UserOAuthSourceConnection` for the provider's OAuth
    source and user. It is looked up (and fetched if necessary) when the object
    is constructed, and renewed when the stored connection is no longer valid.
    """

    def __init__(self, provider: SCIMProvider):
        self.provider = provider
        self.user = provider.auth_oauth_user
        self.logger = get_logger().bind()
        # NOTE: this may already perform a token request if no usable token is stored.
        self.connection = self.get_connection()

    def retrieve_token(self, conn: UserOAuthSourceConnection | None) -> dict[str, Any]:
        """Request a token from the token endpoint of the provider's OAuth source.

        In silent mode the grant type is `GRANT_TYPE_PASSWORD`; in interactive
        mode it is `GRANT_TYPE_REFRESH_TOKEN`, using the refresh token stored on
        `conn`. The provider's `auth_oauth_params` are merged into the request
        last and therefore override any of the values set here.

        Returns the decoded JSON body of the token response.

        Raises SCIMOAuthException if interactive mode is used without an existing
        connection, if the response body contains an `error` key, or if the
        request itself fails.
        """
        source: OAuthSource = self.provider.auth_oauth
        client: BaseOAuthClient = source.source_type.callback_view(request=None).get_client(source)
        access_token_url = source.source_type.access_token_url or ""
        # A URL configured on the source only wins for source types that allow customisation.
        if source.source_type.urls_customizable and source.access_token_url:
            access_token_url = source.access_token_url
        data = client.get_access_token_args(None, None)
        if self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_SILENT:
            data["grant_type"] = GRANT_TYPE_PASSWORD
        elif self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_INTERACTIVE:
            data["grant_type"] = GRANT_TYPE_REFRESH_TOKEN
            if not conn:
                raise SCIMOAuthException(None, "Could not refresh SCIM OAuth token")
            data["refresh_token"] = conn.refresh_token
        data.update(self.provider.auth_oauth_params)
        try:
            response = client.do_request(
                "POST",
                access_token_url,
                auth=client.get_access_token_auth(),
                data=data,
                headers=client._default_headers,
            )
            response.raise_for_status()
            body = response.json()
            if "error" in body:
                self.logger.info("Failed to get new OAuth token", error=body["error"])
                raise SCIMOAuthException(response, body["error"])
            return body
        except RequestException as exc:
            raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc

    def get_connection(self):
        """Return a connection holding a usable access token, fetching one if needed.

        Returns None when the provider has no OAuth source configured. If a stored
        connection has an access token that has not expired yet it is returned
        unchanged; otherwise a token is requested via `retrieve_token` and the
        connection is created or updated with the result.

        Raises SCIMOAuthException (from `retrieve_token`) if no token can be obtained.
        """
        source = self.provider.auth_oauth
        if not source:
            return None
        conn = UserOAuthSourceConnection.objects.filter(source=source, user=self.user).first()
        if conn and conn.access_token and conn.expires > now():
            return conn
        token = self.retrieve_token(conn)
        access_token = token["access_token"]
        # `expires_in` may be missing or null; both are treated as "already expired".
        expires_in = int(token.get("expires_in") or 0)
        # The token endpoint is allowed to omit `refresh_token` from its response, in
        # which case the previously issued one stays valid. Keep it instead of wiping
        # it, otherwise the next refresh would have nothing to send.
        refresh_token = token.get("refresh_token") or (conn.refresh_token if conn else None)
        connection, _ = UserOAuthSourceConnection.objects.update_or_create(
            source=source,
            user=self.user,
            defaults={
                "access_token": access_token,
                "refresh_token": refresh_token,
                "expires": now() + timedelta(seconds=expires_in),
                # When using `update_or_create`, `last_updated` is not updated
                "last_updated": now(),
            },
        )
        return connection

    def __call__(self, request: Request) -> Request:
        """Set the `Authorization: Bearer ...` header on `request` and return it.

        The token is renewed first when the current connection is missing or no
        longer valid.

        Raises SCIMOAuthException if no connection is available (for example when
        the provider has no OAuth source configured) or if renewing the token fails.
        """
        # `get_connection` returns None without an OAuth source, so guard before
        # touching attributes of the connection.
        if self.connection is None or not self.connection.is_valid:
            self.logger.info("OAuth token expired, renewing token")
            self.connection = self.get_connection()
        if self.connection is None:
            raise SCIMOAuthException(None, "No OAuth connection available for SCIM provider")
        request.headers["Authorization"] = f"Bearer {self.connection.access_token}"
        return request
class SCIMOAuthException(SCIMRequestException):
    """Exceptions related to OAuth operations for SCIM requests.

    Raised when an OAuth token for outgoing SCIM requests cannot be obtained
    or refreshed.
    """
Exceptions related to OAuth operations for SCIM requests
class
SCIMOAuthAuth:
class SCIMOAuthAuth:
    """Callable that attaches an OAuth bearer token to outgoing SCIM requests.

    The token is kept on a `UserOAuthSourceConnection` for the provider's OAuth
    source and user. It is looked up (and fetched if necessary) when the object
    is constructed, and renewed when the stored connection is no longer valid.
    """

    def __init__(self, provider: SCIMProvider):
        self.provider = provider
        self.user = provider.auth_oauth_user
        self.logger = get_logger().bind()
        # NOTE: this may already perform a token request if no usable token is stored.
        self.connection = self.get_connection()

    def retrieve_token(self, conn: UserOAuthSourceConnection | None) -> dict[str, Any]:
        """Request a token from the token endpoint of the provider's OAuth source.

        In silent mode the grant type is `GRANT_TYPE_PASSWORD`; in interactive
        mode it is `GRANT_TYPE_REFRESH_TOKEN`, using the refresh token stored on
        `conn`. The provider's `auth_oauth_params` are merged into the request
        last and therefore override any of the values set here.

        Returns the decoded JSON body of the token response.

        Raises SCIMOAuthException if interactive mode is used without an existing
        connection, if the response body contains an `error` key, or if the
        request itself fails.
        """
        source: OAuthSource = self.provider.auth_oauth
        client: BaseOAuthClient = source.source_type.callback_view(request=None).get_client(source)
        access_token_url = source.source_type.access_token_url or ""
        # A URL configured on the source only wins for source types that allow customisation.
        if source.source_type.urls_customizable and source.access_token_url:
            access_token_url = source.access_token_url
        data = client.get_access_token_args(None, None)
        if self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_SILENT:
            data["grant_type"] = GRANT_TYPE_PASSWORD
        elif self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_INTERACTIVE:
            data["grant_type"] = GRANT_TYPE_REFRESH_TOKEN
            if not conn:
                raise SCIMOAuthException(None, "Could not refresh SCIM OAuth token")
            data["refresh_token"] = conn.refresh_token
        data.update(self.provider.auth_oauth_params)
        try:
            response = client.do_request(
                "POST",
                access_token_url,
                auth=client.get_access_token_auth(),
                data=data,
                headers=client._default_headers,
            )
            response.raise_for_status()
            body = response.json()
            if "error" in body:
                self.logger.info("Failed to get new OAuth token", error=body["error"])
                raise SCIMOAuthException(response, body["error"])
            return body
        except RequestException as exc:
            raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc

    def get_connection(self):
        """Return a connection holding a usable access token, fetching one if needed.

        Returns None when the provider has no OAuth source configured. If a stored
        connection has an access token that has not expired yet it is returned
        unchanged; otherwise a token is requested via `retrieve_token` and the
        connection is created or updated with the result.

        Raises SCIMOAuthException (from `retrieve_token`) if no token can be obtained.
        """
        source = self.provider.auth_oauth
        if not source:
            return None
        conn = UserOAuthSourceConnection.objects.filter(source=source, user=self.user).first()
        if conn and conn.access_token and conn.expires > now():
            return conn
        token = self.retrieve_token(conn)
        access_token = token["access_token"]
        # `expires_in` may be missing or null; both are treated as "already expired".
        expires_in = int(token.get("expires_in") or 0)
        # The token endpoint is allowed to omit `refresh_token` from its response, in
        # which case the previously issued one stays valid. Keep it instead of wiping
        # it, otherwise the next refresh would have nothing to send.
        refresh_token = token.get("refresh_token") or (conn.refresh_token if conn else None)
        connection, _ = UserOAuthSourceConnection.objects.update_or_create(
            source=source,
            user=self.user,
            defaults={
                "access_token": access_token,
                "refresh_token": refresh_token,
                "expires": now() + timedelta(seconds=expires_in),
                # When using `update_or_create`, `last_updated` is not updated
                "last_updated": now(),
            },
        )
        return connection

    def __call__(self, request: Request) -> Request:
        """Set the `Authorization: Bearer ...` header on `request` and return it.

        The token is renewed first when the current connection is missing or no
        longer valid.

        Raises SCIMOAuthException if no connection is available (for example when
        the provider has no OAuth source configured) or if renewing the token fails.
        """
        # `get_connection` returns None without an OAuth source, so guard before
        # touching attributes of the connection.
        if self.connection is None or not self.connection.is_valid:
            self.logger.info("OAuth token expired, renewing token")
            self.connection = self.get_connection()
        if self.connection is None:
            raise SCIMOAuthException(None, "No OAuth connection available for SCIM provider")
        request.headers["Authorization"] = f"Bearer {self.connection.access_token}"
        return request
def
retrieve_token( self, conn: authentik.sources.oauth.models.UserOAuthSourceConnection | None) -> dict[str, typing.Any]:
def retrieve_token(self, conn: UserOAuthSourceConnection | None) -> dict[str, Any]:
    """Fetch a token from the token endpoint of the provider's OAuth source.

    The grant type follows the provider's authentication mode: silent mode
    sends `GRANT_TYPE_PASSWORD`, interactive mode sends
    `GRANT_TYPE_REFRESH_TOKEN` together with the refresh token stored on
    `conn`. The provider's `auth_oauth_params` are merged in last, so they
    take precedence over the values built here.

    Returns the decoded JSON body of the token response.

    Raises SCIMOAuthException when interactive mode has no connection to
    refresh from, when the response body carries an `error` key, or when
    the HTTP request fails.
    """
    oauth_source: OAuthSource = self.provider.auth_oauth
    source_type = oauth_source.source_type
    oauth_client: BaseOAuthClient = source_type.callback_view(request=None).get_client(
        oauth_source
    )

    # A URL set on the source is only honoured for source types that allow it;
    # everything else uses the source type's own endpoint.
    if source_type.urls_customizable and oauth_source.access_token_url:
        token_url = oauth_source.access_token_url
    else:
        token_url = source_type.access_token_url or ""

    payload = oauth_client.get_access_token_args(None, None)
    auth_mode = self.provider.auth_mode
    if auth_mode == SCIMAuthenticationMode.OAUTH_SILENT:
        payload["grant_type"] = GRANT_TYPE_PASSWORD
    elif auth_mode == SCIMAuthenticationMode.OAUTH_INTERACTIVE:
        if not conn:
            raise SCIMOAuthException(None, "Could not refresh SCIM OAuth token")
        payload["grant_type"] = GRANT_TYPE_REFRESH_TOKEN
        payload["refresh_token"] = conn.refresh_token
    # Applied last so provider-level parameters can override the defaults above.
    payload.update(self.provider.auth_oauth_params)

    try:
        token_response = oauth_client.do_request(
            "POST",
            token_url,
            auth=oauth_client.get_access_token_auth(),
            data=payload,
            headers=oauth_client._default_headers,
        )
        token_response.raise_for_status()
        token_body = token_response.json()
        if "error" in token_body:
            error = token_body["error"]
            self.logger.info("Failed to get new OAuth token", error=error)
            raise SCIMOAuthException(token_response, error)
        return token_body
    except RequestException as exc:
        raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc
def
get_connection(self):
def get_connection(self):
    """Return a connection holding a usable access token, fetching one if needed.

    Returns None when the provider has no OAuth source configured. If a stored
    connection has an access token that has not expired yet it is returned
    unchanged; otherwise a token is requested via `retrieve_token` and the
    connection is created or updated with the result.

    Raises SCIMOAuthException (from `retrieve_token`) if no token can be obtained.
    """
    source = self.provider.auth_oauth
    if not source:
        return None
    conn = UserOAuthSourceConnection.objects.filter(source=source, user=self.user).first()
    if conn and conn.access_token and conn.expires > now():
        return conn
    token = self.retrieve_token(conn)
    access_token = token["access_token"]
    # `expires_in` may be missing or null; both are treated as "already expired".
    expires_in = int(token.get("expires_in") or 0)
    # The token endpoint is allowed to omit `refresh_token` from its response, in
    # which case the previously issued one stays valid. Keep it instead of wiping
    # it, otherwise the next refresh would have nothing to send.
    refresh_token = token.get("refresh_token") or (conn.refresh_token if conn else None)
    connection, _ = UserOAuthSourceConnection.objects.update_or_create(
        source=source,
        user=self.user,
        defaults={
            "access_token": access_token,
            "refresh_token": refresh_token,
            "expires": now() + timedelta(seconds=expires_in),
            # When using `update_or_create`, `last_updated` is not updated
            "last_updated": now(),
        },
    )
    return connection