authentik.providers.oauth2.tests.utils

OAuth test helpers

 1"""OAuth test helpers"""
 2
 3from typing import Any
 4
 5from django.test import TestCase
 6from jwcrypto.jwe import JWE
 7from jwcrypto.jwk import JWK
 8from jwt import decode
 9
10from authentik.core.tests.utils import create_test_cert
11from authentik.crypto.models import CertificateKeyPair
12from authentik.providers.oauth2.models import AccessToken, JWTAlgorithms, OAuth2Provider
13
14
class OAuthTestCase(TestCase):
    """OAuth test helpers.

    Base test case providing a shared test keypair and helpers that decode
    and check tokens issued by an ``OAuth2Provider``.
    """

    keypair: CertificateKeyPair
    # Claims that must be present and non-None in both the access token and the ID token.
    required_jwt_keys = [
        "exp",
        "iat",
        "acr",
        "sub",
        "iss",
    ]

    @classmethod
    def setUpClass(cls) -> None:
        """Create the shared test keypair, then run the parent class setup."""
        # NOTE(review): the keypair is created *before* super().setUpClass(). If the parent
        # opens class-wide transactions there (as django.test.TestCase does), this row is
        # presumably created outside them and not rolled back with the class. Subclasses'
        # setUpTestData() may rely on cls.keypair already existing, so confirm before reordering.
        cls.keypair = create_test_cert()
        super().setUpClass()

    def assert_non_none_or_unset(self, container: dict, key: str):
        """Check that a key, if set, is not none.

        An absent ``key`` passes; a present ``key`` whose value is ``None`` fails.
        """
        if key in container:
            self.assertIsNotNone(container[key])

    def validate_jwe(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
        """Decrypt a JWE access token, then validate the contained JWT.

        The token is decrypted with the PEM key in ``provider.encryption_key``.
        ``token.token`` is overwritten in place with the decrypted payload before
        it is passed to ``validate_jwt``.

        Returns the decoded access-token claims from ``validate_jwt``.
        """
        private_key = JWK.from_pem(provider.encryption_key.key_data.encode())

        jwetoken = JWE()
        jwetoken.deserialize(token.token, key=private_key)
        # validate_jwt reads token.token, so swap in the plaintext JWT.
        token.token = jwetoken.payload.decode()
        return self.validate_jwt(token, provider)

    def validate_jwt(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
        """Decode the access token and check that all required claims are set.

        ``token.token`` is decoded with ``provider.jwt_key``. For any algorithm
        other than HS256, the public key of ``provider.signing_key`` is used
        instead. The expected audience is ``provider.client_id``.

        For the token's ID token, the optional claims (at_hash, nonce, c_hash,
        amr, auth_time) must not be ``None`` when present. Every claim in
        ``required_jwt_keys`` must be present and non-``None`` in both tokens.

        Returns the decoded access-token claims.
        """
        key, alg = provider.jwt_key
        if alg != JWTAlgorithms.HS256:
            key = provider.signing_key.public_key
        jwt = decode(
            token.token,
            key,
            algorithms=[alg],
            audience=provider.client_id,
        )
        id_token = token.id_token.to_dict()
        for optional_claim in ("at_hash", "nonce", "c_hash", "amr", "auth_time"):
            self.assert_non_none_or_unset(id_token, optional_claim)
        for claim in self.required_jwt_keys:
            # .get() so a missing claim fails with this message instead of a bare KeyError.
            self.assertIsNotNone(jwt.get(claim), f"Key {claim} is missing in access_token")
            self.assertIsNotNone(id_token.get(claim), f"Key {claim} is missing in id_token")
        return jwt
class OAuthTestCase(django.test.testcases.TestCase):
class OAuthTestCase(TestCase):
    """OAuth test helpers.

    Base test case providing a shared test keypair and helpers that decode
    and check tokens issued by an ``OAuth2Provider``.
    """

    keypair: CertificateKeyPair
    # Claims that must be present and non-None in both the access token and the ID token.
    required_jwt_keys = [
        "exp",
        "iat",
        "acr",
        "sub",
        "iss",
    ]

    @classmethod
    def setUpClass(cls) -> None:
        """Create the shared test keypair, then run the parent class setup."""
        # NOTE(review): the keypair is created *before* super().setUpClass(). If the parent
        # opens class-wide transactions there (as django.test.TestCase does), this row is
        # presumably created outside them and not rolled back with the class. Subclasses'
        # setUpTestData() may rely on cls.keypair already existing, so confirm before reordering.
        cls.keypair = create_test_cert()
        super().setUpClass()

    def assert_non_none_or_unset(self, container: dict, key: str):
        """Check that a key, if set, is not none.

        An absent ``key`` passes; a present ``key`` whose value is ``None`` fails.
        """
        if key in container:
            self.assertIsNotNone(container[key])

    def validate_jwe(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
        """Decrypt a JWE access token, then validate the contained JWT.

        The token is decrypted with the PEM key in ``provider.encryption_key``.
        ``token.token`` is overwritten in place with the decrypted payload before
        it is passed to ``validate_jwt``.

        Returns the decoded access-token claims from ``validate_jwt``.
        """
        private_key = JWK.from_pem(provider.encryption_key.key_data.encode())

        jwetoken = JWE()
        jwetoken.deserialize(token.token, key=private_key)
        # validate_jwt reads token.token, so swap in the plaintext JWT.
        token.token = jwetoken.payload.decode()
        return self.validate_jwt(token, provider)

    def validate_jwt(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
        """Decode the access token and check that all required claims are set.

        ``token.token`` is decoded with ``provider.jwt_key``. For any algorithm
        other than HS256, the public key of ``provider.signing_key`` is used
        instead. The expected audience is ``provider.client_id``.

        For the token's ID token, the optional claims (at_hash, nonce, c_hash,
        amr, auth_time) must not be ``None`` when present. Every claim in
        ``required_jwt_keys`` must be present and non-``None`` in both tokens.

        Returns the decoded access-token claims.
        """
        key, alg = provider.jwt_key
        if alg != JWTAlgorithms.HS256:
            key = provider.signing_key.public_key
        jwt = decode(
            token.token,
            key,
            algorithms=[alg],
            audience=provider.client_id,
        )
        id_token = token.id_token.to_dict()
        for optional_claim in ("at_hash", "nonce", "c_hash", "amr", "auth_time"):
            self.assert_non_none_or_unset(id_token, optional_claim)
        for claim in self.required_jwt_keys:
            # .get() so a missing claim fails with this message instead of a bare KeyError.
            self.assertIsNotNone(jwt.get(claim), f"Key {claim} is missing in access_token")
            self.assertIsNotNone(id_token.get(claim), f"Key {claim} is missing in id_token")
        return jwt

OAuth test helpers

required_jwt_keys = ["exp", "iat", "acr", "sub", "iss"]
@classmethod
def setUpClass(cls) -> None:
28    @classmethod
29    def setUpClass(cls) -> None:
30        cls.keypair = create_test_cert()
31        super().setUpClass()

Create the shared test keypair (`cls.keypair`) via `create_test_cert()`, then run the parent class's `setUpClass()`.

def assert_non_none_or_unset(self, container: dict, key: str):
33    def assert_non_none_or_unset(self, container: dict, key: str):
34        """Check that a key, if set, is not none"""
35        if key in container:
36            self.assertIsNotNone(container[key])

Check that a key, if set, is not none

def validate_jwe( self, token: authentik.providers.oauth2.models.AccessToken, provider: authentik.providers.oauth2.models.OAuth2Provider) -> dict[str, typing.Any]:
38    def validate_jwe(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
39        """Validate JWEs"""
40        private_key = JWK.from_pem(provider.encryption_key.key_data.encode())
41
42        jwetoken = JWE()
43        jwetoken.deserialize(token.token, key=private_key)
44        token.token = jwetoken.payload.decode()
45        return self.validate_jwt(token, provider)

Validate JWEs

def validate_jwt( self, token: authentik.providers.oauth2.models.AccessToken, provider: authentik.providers.oauth2.models.OAuth2Provider) -> dict[str, typing.Any]:
47    def validate_jwt(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
48        """Validate that all required fields are set"""
49        key, alg = provider.jwt_key
50        if alg != JWTAlgorithms.HS256:
51            key = provider.signing_key.public_key
52        jwt = decode(
53            token.token,
54            key,
55            algorithms=[alg],
56            audience=provider.client_id,
57        )
58        id_token = token.id_token.to_dict()
59        self.assert_non_none_or_unset(id_token, "at_hash")
60        self.assert_non_none_or_unset(id_token, "nonce")
61        self.assert_non_none_or_unset(id_token, "c_hash")
62        self.assert_non_none_or_unset(id_token, "amr")
63        self.assert_non_none_or_unset(id_token, "auth_time")
64        for key in self.required_jwt_keys:
65            self.assertIsNotNone(jwt[key], f"Key {key} is missing in access_token")
66            self.assertIsNotNone(id_token[key], f"Key {key} is missing in id_token")
67        return jwt

Validate that all required fields are set