authentik.enterprise.providers.scim.auth_oauth2
"""OAuth2 bearer-token authentication for SCIM provider requests."""

from datetime import timedelta
from typing import TYPE_CHECKING

from django.utils.timezone import now
from requests import Request, RequestException
from structlog.stdlib import get_logger

from authentik.providers.scim.clients.exceptions import SCIMRequestException
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection

if TYPE_CHECKING:
    from authentik.providers.scim.models import SCIMProvider


class SCIMOAuthException(SCIMRequestException):
    """Exceptions related to OAuth operations for SCIM requests"""


class SCIMOAuthAuth:
    """Auth hook that adds an OAuth bearer token to outgoing SCIM requests.

    Called with a request, it sets the ``Authorization: Bearer ...`` header and
    returns the request. The token is stored as a ``UserOAuthSourceConnection``
    for the provider's OAuth source and ``auth_oauth_user``, and re-fetched
    when the stored connection is no longer valid.
    """

    def __init__(self, provider: "SCIMProvider"):
        # The annotation is quoted because SCIMProvider is only imported under
        # TYPE_CHECKING; an unquoted name raises NameError at class definition.
        self.provider = provider
        self.user = provider.auth_oauth_user
        self.logger = get_logger().bind()
        # Resolve a stored token, or fetch a new one, up front.
        self.connection = self.get_connection()

    def retrieve_token(self):
        """Request a new access token from the provider's OAuth source.

        The request body starts from the client's access-token arguments with
        ``grant_type=password``; ``provider.auth_oauth_params`` is merged last
        and may override any field.

        Returns:
            The decoded JSON token response, or ``None`` if the provider has
            no OAuth source configured.

        Raises:
            SCIMOAuthException: if the HTTP request fails, returns an error
                status, or the response body contains an ``error`` key.
        """
        if not self.provider.auth_oauth:
            return None
        source: OAuthSource = self.provider.auth_oauth
        client = OAuth2Client(source, None)
        # Use the source's own token URL only when its type allows overriding.
        access_token_url = source.source_type.access_token_url or ""
        if source.source_type.urls_customizable and source.access_token_url:
            access_token_url = source.access_token_url
        data = client.get_access_token_args(None, None)
        data["grant_type"] = "password"
        data.update(self.provider.auth_oauth_params)
        try:
            response = client.do_request(
                "POST",
                access_token_url,
                auth=client.get_access_token_auth(),
                data=data,
                headers=client._default_headers,
            )
            response.raise_for_status()
            body = response.json()
            # Some token endpoints report failures in the body of a 2xx reply.
            if "error" in body:
                self.logger.info("Failed to get new OAuth token", error=body["error"])
                raise SCIMOAuthException(response, body["error"])
            return body
        except RequestException as exc:
            raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc

    def get_connection(self):
        """Return an OAuth connection holding a usable access token.

        Reuses a stored connection for this source and user whose ``expires``
        is in the future and which has an access token. Otherwise it fetches a
        new token and upserts the connection row.

        Raises:
            SCIMOAuthException: if no OAuth source is configured, the token
                request fails, or the response has no ``access_token``.
        """
        existing = UserOAuthSourceConnection.objects.filter(
            source=self.provider.auth_oauth, user=self.user, expires__gt=now()
        ).first()
        if existing and existing.access_token:
            return existing
        token_data = self.retrieve_token()
        if token_data is None:
            # retrieve_token() returns None when no OAuth source is set; fail
            # with a SCIM error instead of a TypeError on subscripting None.
            raise SCIMOAuthException(None, message="No OAuth source configured")
        if "access_token" not in token_data:
            raise SCIMOAuthException(
                None, message="OAuth token response did not include an access token"
            )
        access_token = token_data["access_token"]
        # Without expires_in the row is stored with expires == now(), so the
        # expires__gt=now() lookup above will not reuse it later.
        expires_in = int(token_data.get("expires_in", 0))
        connection, _ = UserOAuthSourceConnection.objects.update_or_create(
            source=self.provider.auth_oauth,
            user=self.user,
            defaults={
                "access_token": access_token,
                "expires": now() + timedelta(seconds=expires_in),
            },
        )
        return connection

    def __call__(self, request: Request) -> Request:
        """Attach the bearer token, renewing the connection if it is invalid."""
        if not self.connection.is_valid:
            self.logger.info("OAuth token expired, renewing token")
            self.connection = self.get_connection()
        request.headers["Authorization"] = f"Bearer {self.connection.access_token}"
        return request
class SCIMOAuthException(SCIMRequestException):
    """Raised when an OAuth operation performed for a SCIM request fails.

    Adds no behaviour of its own beyond ``SCIMRequestException``; it is a
    distinct type for OAuth-related failures.
    """
Exceptions related to OAuth operations for SCIM requests
class
SCIMOAuthAuth:
class SCIMOAuthAuth:
    """Auth hook that adds an OAuth bearer token to outgoing SCIM requests.

    Called with a request, it sets the ``Authorization: Bearer ...`` header and
    returns the request. The token is stored as a ``UserOAuthSourceConnection``
    for the provider's OAuth source and ``auth_oauth_user``, and re-fetched
    when the stored connection is no longer valid.
    """

    def __init__(self, provider: "SCIMProvider"):
        # The annotation is quoted because SCIMProvider is only imported under
        # TYPE_CHECKING; an unquoted name raises NameError at class definition.
        self.provider = provider
        self.user = provider.auth_oauth_user
        self.logger = get_logger().bind()
        # Resolve a stored token, or fetch a new one, up front.
        self.connection = self.get_connection()

    def retrieve_token(self):
        """Request a new access token from the provider's OAuth source.

        The request body starts from the client's access-token arguments with
        ``grant_type=password``; ``provider.auth_oauth_params`` is merged last
        and may override any field.

        Returns:
            The decoded JSON token response, or ``None`` if the provider has
            no OAuth source configured.

        Raises:
            SCIMOAuthException: if the HTTP request fails, returns an error
                status, or the response body contains an ``error`` key.
        """
        if not self.provider.auth_oauth:
            return None
        source: OAuthSource = self.provider.auth_oauth
        client = OAuth2Client(source, None)
        # Use the source's own token URL only when its type allows overriding.
        access_token_url = source.source_type.access_token_url or ""
        if source.source_type.urls_customizable and source.access_token_url:
            access_token_url = source.access_token_url
        data = client.get_access_token_args(None, None)
        data["grant_type"] = "password"
        data.update(self.provider.auth_oauth_params)
        try:
            response = client.do_request(
                "POST",
                access_token_url,
                auth=client.get_access_token_auth(),
                data=data,
                headers=client._default_headers,
            )
            response.raise_for_status()
            body = response.json()
            # Some token endpoints report failures in the body of a 2xx reply.
            if "error" in body:
                self.logger.info("Failed to get new OAuth token", error=body["error"])
                raise SCIMOAuthException(response, body["error"])
            return body
        except RequestException as exc:
            raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc

    def get_connection(self):
        """Return an OAuth connection holding a usable access token.

        Reuses a stored connection for this source and user whose ``expires``
        is in the future and which has an access token. Otherwise it fetches a
        new token and upserts the connection row.

        Raises:
            SCIMOAuthException: if no OAuth source is configured, the token
                request fails, or the response has no ``access_token``.
        """
        existing = UserOAuthSourceConnection.objects.filter(
            source=self.provider.auth_oauth, user=self.user, expires__gt=now()
        ).first()
        if existing and existing.access_token:
            return existing
        token_data = self.retrieve_token()
        if token_data is None:
            # retrieve_token() returns None when no OAuth source is set; fail
            # with a SCIM error instead of a TypeError on subscripting None.
            raise SCIMOAuthException(None, message="No OAuth source configured")
        if "access_token" not in token_data:
            raise SCIMOAuthException(
                None, message="OAuth token response did not include an access token"
            )
        access_token = token_data["access_token"]
        # Without expires_in the row is stored with expires == now(), so the
        # expires__gt=now() lookup above will not reuse it later.
        expires_in = int(token_data.get("expires_in", 0))
        connection, _ = UserOAuthSourceConnection.objects.update_or_create(
            source=self.provider.auth_oauth,
            user=self.user,
            defaults={
                "access_token": access_token,
                "expires": now() + timedelta(seconds=expires_in),
            },
        )
        return connection

    def __call__(self, request: Request) -> Request:
        """Attach the bearer token, renewing the connection if it is invalid."""
        if not self.connection.is_valid:
            self.logger.info("OAuth token expired, renewing token")
            self.connection = self.get_connection()
        request.headers["Authorization"] = f"Bearer {self.connection.access_token}"
        return request
def
retrieve_token(self):
def retrieve_token(self):
    """Fetch a fresh token response from the provider's OAuth source.

    Form fields come from the client's access-token arguments plus
    ``grant_type=password``. ``provider.auth_oauth_params`` is applied last,
    so it can override any of them.

    Returns:
        The parsed JSON body of the token response, or ``None`` when the
        provider has no OAuth source.

    Raises:
        SCIMOAuthException: on transport/HTTP errors, or when the body
            carries an ``error`` key.
    """
    source: OAuthSource = self.provider.auth_oauth
    if not source:
        return None

    client = OAuth2Client(source, None)
    source_type = source.source_type
    if source_type.urls_customizable and source.access_token_url:
        token_url = source.access_token_url
    else:
        token_url = source_type.access_token_url or ""

    form = client.get_access_token_args(None, None)
    form["grant_type"] = "password"
    form.update(self.provider.auth_oauth_params)

    try:
        reply = client.do_request(
            "POST",
            token_url,
            auth=client.get_access_token_auth(),
            data=form,
            headers=client._default_headers,
        )
        reply.raise_for_status()
        payload = reply.json()
        if "error" in payload:
            failure = payload["error"]
            self.logger.info("Failed to get new OAuth token", error=failure)
            raise SCIMOAuthException(reply, failure)
        return payload
    except RequestException as exc:
        raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc
def
get_connection(self):
def get_connection(self):
    """Return an OAuth connection holding a usable access token.

    Reuses a stored connection for this source and user whose ``expires`` is
    in the future and which has an access token. Otherwise it fetches a new
    token and upserts the connection row.

    Raises:
        SCIMOAuthException: if no OAuth source is configured, the token
            request fails, or the response has no ``access_token``.
    """
    existing = UserOAuthSourceConnection.objects.filter(
        source=self.provider.auth_oauth, user=self.user, expires__gt=now()
    ).first()
    if existing and existing.access_token:
        return existing
    token_data = self.retrieve_token()
    if token_data is None:
        # retrieve_token() returns None when no OAuth source is set; fail with
        # a SCIM error instead of a TypeError on subscripting None.
        raise SCIMOAuthException(None, message="No OAuth source configured")
    if "access_token" not in token_data:
        raise SCIMOAuthException(
            None, message="OAuth token response did not include an access token"
        )
    access_token = token_data["access_token"]
    # Without expires_in the row is stored with expires == now(), so the
    # expires__gt=now() lookup above will not reuse it later.
    expires_in = int(token_data.get("expires_in", 0))
    connection, _ = UserOAuthSourceConnection.objects.update_or_create(
        source=self.provider.auth_oauth,
        user=self.user,
        defaults={
            "access_token": access_token,
            "expires": now() + timedelta(seconds=expires_in),
        },
    )
    return connection