authentik.enterprise.providers.scim.auth_oauth2

 1from datetime import timedelta
 2from typing import TYPE_CHECKING
 3
 4from django.utils.timezone import now
 5from requests import Request, RequestException
 6from structlog.stdlib import get_logger
 7
 8from authentik.providers.scim.clients.exceptions import SCIMRequestException
 9from authentik.sources.oauth.clients.oauth2 import OAuth2Client
10from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
11
12if TYPE_CHECKING:
13    from authentik.providers.scim.models import SCIMProvider
14
15
class SCIMOAuthException(SCIMRequestException):
    """Raised when an OAuth operation needed to authenticate a SCIM request fails."""
18
19
class SCIMOAuthAuth:
    """Callable auth hook that adds an OAuth bearer token to outgoing SCIM requests.

    The token is obtained from ``provider.auth_oauth`` using the OAuth ``password``
    grant and stored as a ``UserOAuthSourceConnection`` for
    ``provider.auth_oauth_user``, so it can be reused until its expiry passes.
    """

    # "SCIMProvider" is quoted because it is only imported under TYPE_CHECKING;
    # an unquoted annotation is evaluated when the method is defined and would
    # raise NameError at import time (absent deferred annotation evaluation).
    def __init__(self, provider: "SCIMProvider"):
        """Bind to *provider* and resolve an access token immediately.

        Raises:
            SCIMOAuthException: if no usable token can be obtained.
        """
        self.provider = provider
        self.user = provider.auth_oauth_user
        self.logger = get_logger().bind()
        self.connection = self.get_connection()

    def retrieve_token(self):
        """Request a new token from the provider's OAuth source.

        The request body starts from the client's default access-token arguments,
        sets ``grant_type=password``, then applies ``provider.auth_oauth_params``
        (which therefore take precedence).

        Returns:
            The decoded JSON token response, or ``None`` when the provider has no
            OAuth source configured.

        Raises:
            SCIMOAuthException: if the HTTP request raises a ``RequestException``
                (including non-2xx statuses via ``raise_for_status``) or the
                response body contains an ``error`` key.
        """
        if not self.provider.auth_oauth:
            return None
        source: OAuthSource = self.provider.auth_oauth
        client = OAuth2Client(source, None)
        # Prefer the source's own token URL when its type allows customization.
        access_token_url = source.source_type.access_token_url or ""
        if source.source_type.urls_customizable and source.access_token_url:
            access_token_url = source.access_token_url
        data = client.get_access_token_args(None, None)
        data["grant_type"] = "password"
        data.update(self.provider.auth_oauth_params)
        try:
            response = client.do_request(
                "POST",
                access_token_url,
                auth=client.get_access_token_auth(),
                data=data,
                headers=client._default_headers,
            )
            response.raise_for_status()
            body = response.json()
            if "error" in body:
                self.logger.info("Failed to get new OAuth token", error=body["error"])
                raise SCIMOAuthException(response, body["error"])
            return body
        except RequestException as exc:
            raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc

    def get_connection(self):
        """Return a connection holding a usable access token, fetching one if needed.

        A stored, not-yet-expired connection for this source and user that has an
        access token is reused. Otherwise a new token is requested and stored
        (created or updated) with an expiry of ``now() + expires_in`` seconds.

        Raises:
            SCIMOAuthException: if no OAuth source is configured, the token
                request fails, or the token response has no ``access_token``.
        """
        connection = UserOAuthSourceConnection.objects.filter(
            source=self.provider.auth_oauth, user=self.user, expires__gt=now()
        ).first()
        if connection and connection.access_token:
            return connection
        token = self.retrieve_token()
        if token is None:
            # retrieve_token() returns None when no OAuth source is configured;
            # indexing it would otherwise fail with an opaque TypeError.
            raise SCIMOAuthException(None, message="No OAuth source configured")
        access_token = token.get("access_token")
        if not access_token:
            raise SCIMOAuthException(
                None, message="OAuth token response did not contain an access token"
            )
        # expires_in may be absent or null; treat both as 0. The stored connection
        # then fails the expires__gt=now() filter above on the next lookup.
        expires_in = int(token.get("expires_in") or 0)
        connection, _ = UserOAuthSourceConnection.objects.update_or_create(
            source=self.provider.auth_oauth,
            user=self.user,
            defaults={
                "access_token": access_token,
                "expires": now() + timedelta(seconds=expires_in),
            },
        )
        return connection

    def __call__(self, request: Request) -> Request:
        """Set the ``Authorization: Bearer`` header on *request*.

        The stored connection is refreshed first if its ``is_valid`` attribute
        reports that it is no longer usable.
        """
        if not self.connection.is_valid:
            self.logger.info("OAuth token expired, renewing token")
            self.connection = self.get_connection()
        request.headers["Authorization"] = f"Bearer {self.connection.access_token}"
        return request
class SCIMOAuthException(SCIMRequestException):
    """Raised when an OAuth operation needed to authenticate a SCIM request fails."""

Exceptions related to OAuth operations for SCIM requests

class SCIMOAuthAuth:
class SCIMOAuthAuth:
    """Callable auth hook that adds an OAuth bearer token to outgoing SCIM requests.

    The token is obtained from ``provider.auth_oauth`` using the OAuth ``password``
    grant and stored as a ``UserOAuthSourceConnection`` for
    ``provider.auth_oauth_user``, so it can be reused until its expiry passes.
    """

    # "SCIMProvider" is quoted because it is only imported under TYPE_CHECKING;
    # an unquoted annotation is evaluated when the method is defined and would
    # raise NameError at import time (absent deferred annotation evaluation).
    def __init__(self, provider: "SCIMProvider"):
        """Bind to *provider* and resolve an access token immediately.

        Raises:
            SCIMOAuthException: if no usable token can be obtained.
        """
        self.provider = provider
        self.user = provider.auth_oauth_user
        self.logger = get_logger().bind()
        self.connection = self.get_connection()

    def retrieve_token(self):
        """Request a new token from the provider's OAuth source.

        The request body starts from the client's default access-token arguments,
        sets ``grant_type=password``, then applies ``provider.auth_oauth_params``
        (which therefore take precedence).

        Returns:
            The decoded JSON token response, or ``None`` when the provider has no
            OAuth source configured.

        Raises:
            SCIMOAuthException: if the HTTP request raises a ``RequestException``
                (including non-2xx statuses via ``raise_for_status``) or the
                response body contains an ``error`` key.
        """
        if not self.provider.auth_oauth:
            return None
        source: OAuthSource = self.provider.auth_oauth
        client = OAuth2Client(source, None)
        # Prefer the source's own token URL when its type allows customization.
        access_token_url = source.source_type.access_token_url or ""
        if source.source_type.urls_customizable and source.access_token_url:
            access_token_url = source.access_token_url
        data = client.get_access_token_args(None, None)
        data["grant_type"] = "password"
        data.update(self.provider.auth_oauth_params)
        try:
            response = client.do_request(
                "POST",
                access_token_url,
                auth=client.get_access_token_auth(),
                data=data,
                headers=client._default_headers,
            )
            response.raise_for_status()
            body = response.json()
            if "error" in body:
                self.logger.info("Failed to get new OAuth token", error=body["error"])
                raise SCIMOAuthException(response, body["error"])
            return body
        except RequestException as exc:
            raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc

    def get_connection(self):
        """Return a connection holding a usable access token, fetching one if needed.

        A stored, not-yet-expired connection for this source and user that has an
        access token is reused. Otherwise a new token is requested and stored
        (created or updated) with an expiry of ``now() + expires_in`` seconds.

        Raises:
            SCIMOAuthException: if no OAuth source is configured, the token
                request fails, or the token response has no ``access_token``.
        """
        connection = UserOAuthSourceConnection.objects.filter(
            source=self.provider.auth_oauth, user=self.user, expires__gt=now()
        ).first()
        if connection and connection.access_token:
            return connection
        token = self.retrieve_token()
        if token is None:
            # retrieve_token() returns None when no OAuth source is configured;
            # indexing it would otherwise fail with an opaque TypeError.
            raise SCIMOAuthException(None, message="No OAuth source configured")
        access_token = token.get("access_token")
        if not access_token:
            raise SCIMOAuthException(
                None, message="OAuth token response did not contain an access token"
            )
        # expires_in may be absent or null; treat both as 0. The stored connection
        # then fails the expires__gt=now() filter above on the next lookup.
        expires_in = int(token.get("expires_in") or 0)
        connection, _ = UserOAuthSourceConnection.objects.update_or_create(
            source=self.provider.auth_oauth,
            user=self.user,
            defaults={
                "access_token": access_token,
                "expires": now() + timedelta(seconds=expires_in),
            },
        )
        return connection

    def __call__(self, request: Request) -> Request:
        """Set the ``Authorization: Bearer`` header on *request*.

        The stored connection is refreshed first if its ``is_valid`` attribute
        reports that it is no longer usable.
        """
        if not self.connection.is_valid:
            self.logger.info("OAuth token expired, renewing token")
            self.connection = self.get_connection()
        request.headers["Authorization"] = f"Bearer {self.connection.access_token}"
        return request
provider
user
logger
connection
def retrieve_token(self):
def retrieve_token(self):
    """Fetch a fresh token from the provider's OAuth source via the password grant.

    Returns the decoded JSON token response, or ``None`` when the provider has no
    OAuth source. ``provider.auth_oauth_params`` are merged into the request body
    last, so they override the default arguments and ``grant_type``.

    Raises SCIMOAuthException when the HTTP request raises a RequestException
    (non-2xx statuses included) or the response body carries an ``error`` key.
    """
    source: OAuthSource = self.provider.auth_oauth
    if not source:
        return None
    client = OAuth2Client(source, None)
    source_type = source.source_type
    default_url = source_type.access_token_url or ""
    token_url = (
        source.access_token_url
        if source_type.urls_customizable and source.access_token_url
        else default_url
    )
    payload = client.get_access_token_args(None, None)
    payload["grant_type"] = "password"
    payload.update(self.provider.auth_oauth_params)
    try:
        response = client.do_request(
            "POST",
            token_url,
            auth=client.get_access_token_auth(),
            data=payload,
            headers=client._default_headers,
        )
        response.raise_for_status()
        body = response.json()
        if "error" not in body:
            return body
        oauth_error = body["error"]
        self.logger.info("Failed to get new OAuth token", error=oauth_error)
        raise SCIMOAuthException(response, oauth_error)
    except RequestException as exc:
        raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc
def get_connection(self):
def get_connection(self):
    """Return a connection holding a usable access token, fetching one if needed.

    A stored, not-yet-expired connection for this source and user that has an
    access token is reused. Otherwise a new token is requested and stored
    (created or updated) with an expiry of ``now() + expires_in`` seconds.

    Raises:
        SCIMOAuthException: if no OAuth source is configured, the token request
            fails, or the token response has no ``access_token``.
    """
    connection = UserOAuthSourceConnection.objects.filter(
        source=self.provider.auth_oauth, user=self.user, expires__gt=now()
    ).first()
    if connection and connection.access_token:
        return connection
    token = self.retrieve_token()
    if token is None:
        # retrieve_token() returns None when no OAuth source is configured;
        # indexing it would otherwise fail with an opaque TypeError.
        raise SCIMOAuthException(None, message="No OAuth source configured")
    access_token = token.get("access_token")
    if not access_token:
        raise SCIMOAuthException(
            None, message="OAuth token response did not contain an access token"
        )
    # expires_in may be absent or null; treat both as 0. The stored connection
    # then fails the expires__gt=now() filter above on the next lookup.
    expires_in = int(token.get("expires_in") or 0)
    connection, _ = UserOAuthSourceConnection.objects.update_or_create(
        source=self.provider.auth_oauth,
        user=self.user,
        defaults={
            "access_token": access_token,
            "expires": now() + timedelta(seconds=expires_in),
        },
    )
    return connection