authentik.enterprise.providers.scim.auth_oauth2

 1from datetime import timedelta
 2from typing import TYPE_CHECKING, Any
 3
 4from django.utils.timezone import now
 5from requests import Request, RequestException
 6from structlog.stdlib import get_logger
 7
 8from authentik.common.oauth.constants import GRANT_TYPE_PASSWORD, GRANT_TYPE_REFRESH_TOKEN
 9from authentik.providers.scim.clients.exceptions import SCIMRequestException
10from authentik.providers.scim.models import SCIMAuthenticationMode
11from authentik.sources.oauth.clients.base import BaseOAuthClient
12from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
13
14if TYPE_CHECKING:
15    from authentik.providers.scim.models import SCIMProvider
16
17
class SCIMOAuthException(SCIMRequestException):
    """Raised when the OAuth token used to authenticate SCIM requests cannot be obtained or refreshed."""
20
21
class SCIMOAuthAuth:
    """Callable auth for ``requests`` that adds an OAuth bearer token to SCIM requests.

    The token is stored in a ``UserOAuthSourceConnection`` for the provider's OAuth
    source (``provider.auth_oauth``) and user (``provider.auth_oauth_user``). It is
    reused while unexpired. Otherwise a new one is requested from the source's token
    endpoint, using the password grant (``OAUTH_SILENT``) or the refresh-token grant
    (``OAUTH_INTERACTIVE``).
    """

    # Annotations that name ``SCIMProvider`` (imported only under TYPE_CHECKING),
    # project models or ``requests`` types are written as strings so they are not
    # evaluated when the methods are defined.

    def __init__(self, provider: "SCIMProvider"):
        self.provider = provider
        self.user = provider.auth_oauth_user
        self.logger = get_logger().bind()
        # Obtain (or fetch) a token up front; this may raise SCIMOAuthException.
        self.connection = self.get_connection()

    def retrieve_token(self, conn: "UserOAuthSourceConnection | None") -> dict[str, Any]:
        """Request a new token from the OAuth source's token endpoint.

        :param conn: The existing connection, if any. Its refresh token is required
            when the provider uses ``OAUTH_INTERACTIVE``.
        :return: The decoded JSON token response.
        :raises SCIMOAuthException: If no refresh token is available for the
            refresh-token grant, the HTTP request fails, or the response body
            contains an ``error`` field.
        """
        source: OAuthSource = self.provider.auth_oauth
        client: BaseOAuthClient = source.source_type.callback_view(request=None).get_client(source)
        # Use the source-level URL when the source type allows overriding it.
        access_token_url = source.source_type.access_token_url or ""
        if source.source_type.urls_customizable and source.access_token_url:
            access_token_url = source.access_token_url
        data = client.get_access_token_args(None, None)
        if self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_SILENT:
            data["grant_type"] = GRANT_TYPE_PASSWORD
        elif self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_INTERACTIVE:
            data["grant_type"] = GRANT_TYPE_REFRESH_TOKEN
            # Fail fast: a refresh-token grant without a refresh token cannot succeed.
            if not conn or not conn.refresh_token:
                raise SCIMOAuthException(None, "Could not refresh SCIM OAuth token")
            data["refresh_token"] = conn.refresh_token
        # Provider-configured parameters override the defaults set above.
        data.update(self.provider.auth_oauth_params or {})
        try:
            response = client.do_request(
                "POST",
                access_token_url,
                auth=client.get_access_token_auth(),
                data=data,
                headers=client._default_headers,
            )
            response.raise_for_status()
            body = response.json()
        except RequestException as exc:
            raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc
        if "error" in body:
            self.logger.info("Failed to get new OAuth token", error=body["error"])
            raise SCIMOAuthException(response, body["error"])
        return body

    def get_connection(self) -> "UserOAuthSourceConnection | None":
        """Return a connection holding an unexpired access token, fetching one if needed.

        :return: ``None`` when the provider has no OAuth source configured. Otherwise
            the existing, newly created or updated ``UserOAuthSourceConnection``.
        :raises SCIMOAuthException: If a new token is needed and fetching it fails,
            or the token response has no ``access_token``.
        """
        if not self.provider.auth_oauth:
            return None
        conn = UserOAuthSourceConnection.objects.filter(
            source=self.provider.auth_oauth, user=self.user
        ).first()
        if conn and conn.access_token and conn.expires > now():
            return conn
        token = self.retrieve_token(conn)
        access_token = token.get("access_token")
        if not access_token:
            raise SCIMOAuthException(None, "OAuth token response did not include an access token")
        # `expires_in` may be absent or null. Treat that as expiring immediately,
        # which forces a fresh fetch on the next call.
        expires_in = int(token.get("expires_in") or 0)
        # The server may not issue a new refresh token (RFC 6749 section 6). Keep the
        # previous one instead of wiping it, so later refreshes can still succeed.
        refresh_token = token.get("refresh_token") or (conn.refresh_token if conn else None)
        conn, _ = UserOAuthSourceConnection.objects.update_or_create(
            source=self.provider.auth_oauth,
            user=self.user,
            defaults={
                "access_token": access_token,
                "refresh_token": refresh_token,
                "expires": now() + timedelta(seconds=expires_in),
                # When using `update_or_create`, `last_updated` is not updated
                "last_updated": now(),
            },
        )
        return conn

    def __call__(self, request: "Request") -> "Request":
        """Set ``Authorization: Bearer <access token>`` on *request* and return it.

        The token is renewed first when there is no connection or it is no longer valid.

        :raises SCIMOAuthException: If no connection can be obtained (no OAuth source).
        """
        if self.connection is None or not self.connection.is_valid:
            self.logger.info("OAuth token expired, renewing token")
            self.connection = self.get_connection()
        if self.connection is None:
            raise SCIMOAuthException(None, "No OAuth source configured for SCIM provider")
        request.headers["Authorization"] = f"Bearer {self.connection.access_token}"
        return request
class SCIMOAuthException(SCIMRequestException):
    """Raised when the OAuth token used to authenticate SCIM requests cannot be obtained or refreshed."""

Exceptions related to OAuth operations for SCIM requests

class SCIMOAuthAuth:
class SCIMOAuthAuth:
    """Callable auth for ``requests`` that adds an OAuth bearer token to SCIM requests.

    The token is stored in a ``UserOAuthSourceConnection`` for the provider's OAuth
    source (``provider.auth_oauth``) and user (``provider.auth_oauth_user``). It is
    reused while unexpired. Otherwise a new one is requested from the source's token
    endpoint, using the password grant (``OAUTH_SILENT``) or the refresh-token grant
    (``OAUTH_INTERACTIVE``).
    """

    # Annotations that name ``SCIMProvider`` (imported only under TYPE_CHECKING),
    # project models or ``requests`` types are written as strings so they are not
    # evaluated when the methods are defined.

    def __init__(self, provider: "SCIMProvider"):
        self.provider = provider
        self.user = provider.auth_oauth_user
        self.logger = get_logger().bind()
        # Obtain (or fetch) a token up front; this may raise SCIMOAuthException.
        self.connection = self.get_connection()

    def retrieve_token(self, conn: "UserOAuthSourceConnection | None") -> dict[str, Any]:
        """Request a new token from the OAuth source's token endpoint.

        :param conn: The existing connection, if any. Its refresh token is required
            when the provider uses ``OAUTH_INTERACTIVE``.
        :return: The decoded JSON token response.
        :raises SCIMOAuthException: If no refresh token is available for the
            refresh-token grant, the HTTP request fails, or the response body
            contains an ``error`` field.
        """
        source: OAuthSource = self.provider.auth_oauth
        client: BaseOAuthClient = source.source_type.callback_view(request=None).get_client(source)
        # Use the source-level URL when the source type allows overriding it.
        access_token_url = source.source_type.access_token_url or ""
        if source.source_type.urls_customizable and source.access_token_url:
            access_token_url = source.access_token_url
        data = client.get_access_token_args(None, None)
        if self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_SILENT:
            data["grant_type"] = GRANT_TYPE_PASSWORD
        elif self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_INTERACTIVE:
            data["grant_type"] = GRANT_TYPE_REFRESH_TOKEN
            # Fail fast: a refresh-token grant without a refresh token cannot succeed.
            if not conn or not conn.refresh_token:
                raise SCIMOAuthException(None, "Could not refresh SCIM OAuth token")
            data["refresh_token"] = conn.refresh_token
        # Provider-configured parameters override the defaults set above.
        data.update(self.provider.auth_oauth_params or {})
        try:
            response = client.do_request(
                "POST",
                access_token_url,
                auth=client.get_access_token_auth(),
                data=data,
                headers=client._default_headers,
            )
            response.raise_for_status()
            body = response.json()
        except RequestException as exc:
            raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc
        if "error" in body:
            self.logger.info("Failed to get new OAuth token", error=body["error"])
            raise SCIMOAuthException(response, body["error"])
        return body

    def get_connection(self) -> "UserOAuthSourceConnection | None":
        """Return a connection holding an unexpired access token, fetching one if needed.

        :return: ``None`` when the provider has no OAuth source configured. Otherwise
            the existing, newly created or updated ``UserOAuthSourceConnection``.
        :raises SCIMOAuthException: If a new token is needed and fetching it fails,
            or the token response has no ``access_token``.
        """
        if not self.provider.auth_oauth:
            return None
        conn = UserOAuthSourceConnection.objects.filter(
            source=self.provider.auth_oauth, user=self.user
        ).first()
        if conn and conn.access_token and conn.expires > now():
            return conn
        token = self.retrieve_token(conn)
        access_token = token.get("access_token")
        if not access_token:
            raise SCIMOAuthException(None, "OAuth token response did not include an access token")
        # `expires_in` may be absent or null. Treat that as expiring immediately,
        # which forces a fresh fetch on the next call.
        expires_in = int(token.get("expires_in") or 0)
        # The server may not issue a new refresh token (RFC 6749 section 6). Keep the
        # previous one instead of wiping it, so later refreshes can still succeed.
        refresh_token = token.get("refresh_token") or (conn.refresh_token if conn else None)
        conn, _ = UserOAuthSourceConnection.objects.update_or_create(
            source=self.provider.auth_oauth,
            user=self.user,
            defaults={
                "access_token": access_token,
                "refresh_token": refresh_token,
                "expires": now() + timedelta(seconds=expires_in),
                # When using `update_or_create`, `last_updated` is not updated
                "last_updated": now(),
            },
        )
        return conn

    def __call__(self, request: "Request") -> "Request":
        """Set ``Authorization: Bearer <access token>`` on *request* and return it.

        The token is renewed first when there is no connection or it is no longer valid.

        :raises SCIMOAuthException: If no connection can be obtained (no OAuth source).
        """
        if self.connection is None or not self.connection.is_valid:
            self.logger.info("OAuth token expired, renewing token")
            self.connection = self.get_connection()
        if self.connection is None:
            raise SCIMOAuthException(None, "No OAuth source configured for SCIM provider")
        request.headers["Authorization"] = f"Bearer {self.connection.access_token}"
        return request
provider
user
logger
connection
def retrieve_token( self, conn: authentik.sources.oauth.models.UserOAuthSourceConnection | None) -> dict[str, typing.Any]:
def retrieve_token(self, conn: "UserOAuthSourceConnection | None") -> dict[str, Any]:
    """Request a new token from the OAuth source's token endpoint.

    :param conn: The existing connection, if any. Its refresh token is required
        when the provider uses ``OAUTH_INTERACTIVE``.
    :return: The decoded JSON token response.
    :raises SCIMOAuthException: If no refresh token is available for the
        refresh-token grant, the HTTP request fails, or the response body
        contains an ``error`` field.
    """
    source: OAuthSource = self.provider.auth_oauth
    client: BaseOAuthClient = source.source_type.callback_view(request=None).get_client(source)
    # Use the source-level URL when the source type allows overriding it.
    access_token_url = source.source_type.access_token_url or ""
    if source.source_type.urls_customizable and source.access_token_url:
        access_token_url = source.access_token_url
    data = client.get_access_token_args(None, None)
    if self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_SILENT:
        data["grant_type"] = GRANT_TYPE_PASSWORD
    elif self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_INTERACTIVE:
        data["grant_type"] = GRANT_TYPE_REFRESH_TOKEN
        # Fail fast: a refresh-token grant without a refresh token cannot succeed.
        if not conn or not conn.refresh_token:
            raise SCIMOAuthException(None, "Could not refresh SCIM OAuth token")
        data["refresh_token"] = conn.refresh_token
    # Provider-configured parameters override the defaults set above.
    data.update(self.provider.auth_oauth_params or {})
    try:
        response = client.do_request(
            "POST",
            access_token_url,
            auth=client.get_access_token_auth(),
            data=data,
            headers=client._default_headers,
        )
        response.raise_for_status()
        body = response.json()
    except RequestException as exc:
        raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc
    if "error" in body:
        self.logger.info("Failed to get new OAuth token", error=body["error"])
        raise SCIMOAuthException(response, body["error"])
    return body
def get_connection(self):
def get_connection(self) -> "UserOAuthSourceConnection | None":
    """Return a connection holding an unexpired access token, fetching one if needed.

    :return: ``None`` when the provider has no OAuth source configured. Otherwise
        the existing, newly created or updated ``UserOAuthSourceConnection``.
    :raises SCIMOAuthException: If a new token is needed and fetching it fails,
        or the token response has no ``access_token``.
    """
    if not self.provider.auth_oauth:
        return None
    conn = UserOAuthSourceConnection.objects.filter(
        source=self.provider.auth_oauth, user=self.user
    ).first()
    if conn and conn.access_token and conn.expires > now():
        return conn
    token = self.retrieve_token(conn)
    access_token = token.get("access_token")
    if not access_token:
        raise SCIMOAuthException(None, "OAuth token response did not include an access token")
    # `expires_in` may be absent or null. Treat that as expiring immediately,
    # which forces a fresh fetch on the next call.
    expires_in = int(token.get("expires_in") or 0)
    # The server may not issue a new refresh token (RFC 6749 section 6). Keep the
    # previous one instead of wiping it, so later refreshes can still succeed.
    refresh_token = token.get("refresh_token") or (conn.refresh_token if conn else None)
    conn, _ = UserOAuthSourceConnection.objects.update_or_create(
        source=self.provider.auth_oauth,
        user=self.user,
        defaults={
            "access_token": access_token,
            "refresh_token": refresh_token,
            "expires": now() + timedelta(seconds=expires_in),
            # When using `update_or_create`, `last_updated` is not updated
            "last_updated": now(),
        },
    )
    return conn