authentik.providers.oauth2.tests.utils
OAuth test helpers
"""OAuth test helpers"""

from typing import Any

from django.test import TestCase
from jwcrypto.jwe import JWE
from jwcrypto.jwk import JWK
from jwt import decode

from authentik.core.tests.utils import create_test_cert
from authentik.crypto.models import CertificateKeyPair
from authentik.providers.oauth2.models import AccessToken, JWTAlgorithms, OAuth2Provider


class OAuthTestCase(TestCase):
    """OAuth test helpers"""

    # Test keypair, created once per test class in setUpClass via create_test_cert().
    keypair: CertificateKeyPair
    # Claims that validate_jwt requires to be present and non-None in both the
    # decoded access token and the ID token.
    required_jwt_keys = [
        "exp",
        "iat",
        "acr",
        "sub",
        "iss",
    ]

    @classmethod
    def setUpClass(cls) -> None:
        """Create the shared test keypair for this test class.

        ``super().setUpClass()`` runs first so that, with Django's ``TestCase``,
        the keypair is created inside the class-wide transaction and is rolled
        back when the class finishes, instead of leaking into later test classes.
        """
        super().setUpClass()
        cls.keypair = create_test_cert()

    def assert_non_none_or_unset(self, container: dict, key: str):
        """Check that a key, if set, is not None"""
        if key in container:
            self.assertIsNotNone(container[key])

    def validate_jwe(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
        """Decrypt a JWE token with the provider's encryption key, then validate it as a JWT.

        Note: the decrypted payload is written back to ``token.token`` on the
        passed-in object before delegating to ``validate_jwt``.
        """
        private_key = JWK.from_pem(provider.encryption_key.key_data.encode())

        jwetoken = JWE()
        jwetoken.deserialize(token.token, key=private_key)
        token.token = jwetoken.payload.decode()
        return self.validate_jwt(token, provider)

    def validate_jwt(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
        """Decode ``token.token`` and check that all required claims are set.

        The access token is decoded against ``provider`` (key, algorithm and
        ``client_id`` as audience). Optional ID-token claims must be absent or
        non-None. Every claim in ``required_jwt_keys`` must be present and
        non-None in both the access token and ``token.id_token``.

        Returns the decoded access-token claims.
        """
        verify_key, alg = provider.jwt_key
        # For non-HS256 algorithms, verify against the signing key's public key;
        # for HS256 the key returned by provider.jwt_key is used directly.
        if alg != JWTAlgorithms.HS256:
            verify_key = provider.signing_key.public_key
        jwt = decode(
            token.token,
            verify_key,
            algorithms=[alg],
            audience=provider.client_id,
        )
        id_token = token.id_token.to_dict()
        for optional_claim in ("at_hash", "nonce", "c_hash", "amr", "auth_time"):
            self.assert_non_none_or_unset(id_token, optional_claim)
        for claim in self.required_jwt_keys:
            # .get() so a missing claim fails this assertion with its message
            # rather than raising a bare KeyError.
            self.assertIsNotNone(jwt.get(claim), f"Key {claim} is missing in access_token")
            self.assertIsNotNone(id_token.get(claim), f"Key {claim} is missing in id_token")
        return jwt
class
OAuthTestCase(django.test.testcases.TestCase):
class OAuthTestCase(TestCase):
    """OAuth test helpers"""

    # Test keypair, created once per test class in setUpClass via create_test_cert().
    keypair: CertificateKeyPair
    # Claims that validate_jwt requires to be present and non-None in both the
    # decoded access token and the ID token.
    required_jwt_keys = [
        "exp",
        "iat",
        "acr",
        "sub",
        "iss",
    ]

    @classmethod
    def setUpClass(cls) -> None:
        """Create the shared test keypair for this test class.

        ``super().setUpClass()`` runs first so that, with Django's ``TestCase``,
        the keypair is created inside the class-wide transaction and is rolled
        back when the class finishes, instead of leaking into later test classes.
        """
        super().setUpClass()
        cls.keypair = create_test_cert()

    def assert_non_none_or_unset(self, container: dict, key: str):
        """Check that a key, if set, is not None"""
        if key in container:
            self.assertIsNotNone(container[key])

    def validate_jwe(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
        """Decrypt a JWE token with the provider's encryption key, then validate it as a JWT.

        Note: the decrypted payload is written back to ``token.token`` on the
        passed-in object before delegating to ``validate_jwt``.
        """
        private_key = JWK.from_pem(provider.encryption_key.key_data.encode())

        jwetoken = JWE()
        jwetoken.deserialize(token.token, key=private_key)
        token.token = jwetoken.payload.decode()
        return self.validate_jwt(token, provider)

    def validate_jwt(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
        """Decode ``token.token`` and check that all required claims are set.

        The access token is decoded against ``provider`` (key, algorithm and
        ``client_id`` as audience). Optional ID-token claims must be absent or
        non-None. Every claim in ``required_jwt_keys`` must be present and
        non-None in both the access token and ``token.id_token``.

        Returns the decoded access-token claims.
        """
        verify_key, alg = provider.jwt_key
        # For non-HS256 algorithms, verify against the signing key's public key;
        # for HS256 the key returned by provider.jwt_key is used directly.
        if alg != JWTAlgorithms.HS256:
            verify_key = provider.signing_key.public_key
        jwt = decode(
            token.token,
            verify_key,
            algorithms=[alg],
            audience=provider.client_id,
        )
        id_token = token.id_token.to_dict()
        for optional_claim in ("at_hash", "nonce", "c_hash", "amr", "auth_time"):
            self.assert_non_none_or_unset(id_token, optional_claim)
        for claim in self.required_jwt_keys:
            # .get() so a missing claim fails this assertion with its message
            # rather than raising a bare KeyError.
            self.assertIsNotNone(jwt.get(claim), f"Key {claim} is missing in access_token")
            self.assertIsNotNone(id_token.get(claim), f"Key {claim} is missing in id_token")
        return jwt
OAuth test helpers
@classmethod
def
setUpClass(cls) -> None:
@classmethod
def setUpClass(cls) -> None:
    """Create the shared test keypair for this test class.

    ``super().setUpClass()`` runs first so that, with Django's ``TestCase``,
    the keypair is created inside the class-wide transaction and is rolled
    back when the class finishes, instead of leaking into later test classes.
    """
    super().setUpClass()
    cls.keypair = create_test_cert()
Hook method for setting up class fixture before running tests in the class.
def
assert_non_none_or_unset(self, container: dict, key: str):
def assert_non_none_or_unset(self, container: dict, key: str):
    """Assert that ``key`` is either absent from ``container`` or maps to a non-None value."""
    if key not in container:
        # Optional key not present: nothing to check.
        return
    self.assertIsNotNone(container[key])
Check that a key, if set, is not None
def
validate_jwe( self, token: authentik.providers.oauth2.models.AccessToken, provider: authentik.providers.oauth2.models.OAuth2Provider) -> dict[str, typing.Any]:
def validate_jwe(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
    """Decrypt a JWE token, then hand the plaintext to ``validate_jwt``.

    The token is decrypted with a JWK built from the PEM in
    ``provider.encryption_key.key_data``. The decrypted payload overwrites
    ``token.token`` on the caller's object before validation.
    """
    pem_bytes = provider.encryption_key.key_data.encode()
    decryption_key = JWK.from_pem(pem_bytes)

    encrypted = JWE()
    encrypted.deserialize(token.token, key=decryption_key)
    plaintext = encrypted.payload.decode()
    token.token = plaintext
    return self.validate_jwt(token, provider)
Validate JWEs
def
validate_jwt( self, token: authentik.providers.oauth2.models.AccessToken, provider: authentik.providers.oauth2.models.OAuth2Provider) -> dict[str, typing.Any]:
def validate_jwt(self, token: AccessToken, provider: OAuth2Provider) -> dict[str, Any]:
    """Decode ``token.token`` and check that all required claims are set.

    The access token is decoded against ``provider`` (key, algorithm and
    ``client_id`` as audience). Optional ID-token claims must be absent or
    non-None. Every claim in ``required_jwt_keys`` must be present and
    non-None in both the access token and ``token.id_token``.

    Returns the decoded access-token claims.
    """
    verify_key, alg = provider.jwt_key
    # For non-HS256 algorithms, verify against the signing key's public key;
    # for HS256 the key returned by provider.jwt_key is used directly.
    if alg != JWTAlgorithms.HS256:
        verify_key = provider.signing_key.public_key
    jwt = decode(
        token.token,
        verify_key,
        algorithms=[alg],
        audience=provider.client_id,
    )
    id_token = token.id_token.to_dict()
    for optional_claim in ("at_hash", "nonce", "c_hash", "amr", "auth_time"):
        self.assert_non_none_or_unset(id_token, optional_claim)
    for claim in self.required_jwt_keys:
        # .get() so a missing claim fails this assertion with its message
        # rather than raising a bare KeyError.
        self.assertIsNotNone(jwt.get(claim), f"Key {claim} is missing in access_token")
        self.assertIsNotNone(id_token.get(claim), f"Key {claim} is missing in id_token")
    return jwt
Validate that all required fields are set