authentik.providers.oauth2.views.jwks
authentik OAuth2 JWKS Views
"""authentik OAuth2 JWKS Views"""

from base64 import b64encode, urlsafe_b64encode
from collections.abc import Generator
from typing import Literal

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.ec import (
    SECP256R1,
    SECP384R1,
    SECP521R1,
    EllipticCurvePrivateKey,
    EllipticCurvePublicKey,
)
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from cryptography.hazmat.primitives.serialization import Encoding
from django.http import Http404, HttpRequest, HttpResponse, JsonResponse
from django.views import View
from jwt.utils import base64url_encode

from authentik.core.models import Application
from authentik.crypto.models import CertificateKeyPair
from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider

# See https://notes.salrahman.com/generate-es256-es384-es512-private-keys/
# and _CURVE_TYPES in cryptography's `ec` module, where the curve classes below are defined.
# Maps curve classes to their JWK "crv" names.
ec_crv_map = {
    SECP256R1: "P-256",
    SECP384R1: "P-384",
    SECP521R1: "P-521",
}
# Byte length of one curve coordinate (ceil(key_size / 8)); JWK "x"/"y" must be
# encoded at exactly this length, even when the integer has leading zero bytes.
min_length_map = {
    SECP256R1: 32,
    SECP384R1: 48,
    SECP521R1: 66,
}


# https://github.com/jpadilla/pyjwt/issues/709
def bytes_from_int(val: int, min_length: int = 0) -> bytes:
    """Custom bytes_from_int that accepts a minimum length

    Encode a non-negative integer as unsigned big-endian bytes, left-padded with
    zero bytes to at least ``min_length`` bytes. ``0`` with no minimum length
    encodes as ``b""``.

    Raises:
        ValueError: if ``val`` is negative.
    """
    if val < 0:
        # Negative ints have no unsigned encoding; a shift-until-zero loop would
        # also never terminate for them (-1 >> 8 == -1).
        raise ValueError("Must be a non-negative integer")
    byte_length = (val.bit_length() + 7) // 8
    return val.to_bytes(max(byte_length, min_length), "big", signed=False)


def to_base64url_uint(val: int, min_length: int = 0) -> bytes:
    """Custom to_base64url_uint that accepts a minimum length

    Base64url-encode ``val`` as an unsigned big-endian integer of at least
    ``min_length`` bytes. ``0`` with no minimum length is encoded as a single
    zero byte.

    Raises:
        ValueError: if ``val`` is negative.
    """
    if val < 0:
        raise ValueError("Must be a positive integer")

    int_bytes = bytes_from_int(val, min_length)

    if len(int_bytes) == 0:
        int_bytes = b"\x00"

    return base64url_encode(int_bytes)


class JWKSView(View):
    """Show JWK key data (RSA and EC keys) for Provider"""

    @staticmethod
    def get_jwk_for_key(key: CertificateKeyPair, use: Literal["sig", "enc"]) -> dict | None:
        """Convert a certificate-key pair into JWK

        Args:
            key: Certificate-key pair whose public key (and certificate) is published.
            use: ``"sig"`` for a signing key, ``"enc"`` for an encryption key.

        Returns:
            The JWK as a dict, or ``None`` when the pair has no private key or
            its key type is neither RSA nor elliptic-curve.
        """
        private_key = key.private_key
        if not private_key:
            return None

        key_data = {}

        if use == "sig":
            key_data["alg"] = JWTAlgorithms.from_private_key(private_key)
        elif use == "enc":
            # NOTE(review): also advertised for EC keys, although RSA-OAEP-256 is an
            # RSA-only algorithm (RFC 7518 §4.3) — confirm EC encryption keys are
            # not expected here.
            key_data["alg"] = "RSA-OAEP-256"
            key_data["enc"] = "A256CBC-HS512"

        if isinstance(private_key, RSAPrivateKey):
            public_key: RSAPublicKey = private_key.public_key()
            public_numbers = public_key.public_numbers()
            key_data["kid"] = key.kid
            key_data["kty"] = "RSA"
            key_data["use"] = use
            key_data["n"] = to_base64url_uint(public_numbers.n).decode()
            key_data["e"] = to_base64url_uint(public_numbers.e).decode()
        elif isinstance(private_key, EllipticCurvePrivateKey):
            public_key: EllipticCurvePublicKey = private_key.public_key()
            public_numbers = public_key.public_numbers()
            curve = public_key.curve
            curve_type = type(curve)
            # Coordinates must be padded to the curve's full coordinate size. Fall
            # back to ceil(key_size / 8) for curves missing from min_length_map
            # instead of raising KeyError ("crv" below falls back the same way).
            coord_length = min_length_map.get(curve_type, (curve.key_size + 7) // 8)
            key_data["kid"] = key.kid
            key_data["kty"] = "EC"
            key_data["use"] = use
            key_data["x"] = to_base64url_uint(public_numbers.x, coord_length).decode()
            key_data["y"] = to_base64url_uint(public_numbers.y, coord_length).decode()
            key_data["crv"] = ec_crv_map.get(curve_type, curve.name)
        else:
            # Unsupported key type: without "kty" and key material the dict is not
            # a valid JWK, so don't publish the partially filled entry.
            return None
        # X.509 parameters: DER certificate chain plus SHA-1 / SHA-256 thumbprints,
        # base64url-encoded without padding.
        key_data["x5c"] = [b64encode(key.certificate.public_bytes(Encoding.DER)).decode("utf-8")]
        key_data["x5t"] = (
            urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA1()))  # nosec
            .decode("utf-8")
            .rstrip("=")
        )
        key_data["x5t#S256"] = (
            urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA256()))
            .decode("utf-8")
            .rstrip("=")
        )
        return key_data

    def get_keys(self) -> Generator[dict | None]:
        """Yield the JWKs of the OAuth2 provider of the ``application_slug`` application.

        The signing key (``use="sig"``) is yielded before the encryption key
        (``use="enc"``); unset keys are skipped. Yielded values may be ``None``.

        Raises:
            Http404: if no OAuth2 provider is attached to that application.
        """
        provider_ids = Application.objects.filter(
            slug=self.kwargs["application_slug"],
        ).values_list(
            "provider_id",
            flat=True,
        )
        provider = (
            OAuth2Provider.objects.select_related("signing_key", "encryption_key")
            .filter(pk__in=provider_ids)
            .first()
        )

        if provider is None:
            raise Http404()

        if signing_key := provider.signing_key:
            yield JWKSView.get_jwk_for_key(signing_key, "sig")
        if encryption_key := provider.encryption_key:
            yield JWKSView.get_jwk_for_key(encryption_key, "enc")

    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
        """Show JWK Key data for Provider

        Responds with ``{"keys": [...]}``, or ``{}`` when no key produced a JWK.
        """
        response_data = {}
        for jwk in self.get_keys():
            if jwk:
                response_data.setdefault("keys", [])
                response_data["keys"].append(jwk)

        response = JsonResponse(response_data)
        # Allow cross-origin reads of the public key set from any origin.
        response["Access-Control-Allow-Origin"] = "*"

        return response
ec_crv_map =
{<class 'cryptography.hazmat.primitives.asymmetric.ec.SECP256R1'>: 'P-256', <class 'cryptography.hazmat.primitives.asymmetric.ec.SECP384R1'>: 'P-384', <class 'cryptography.hazmat.primitives.asymmetric.ec.SECP521R1'>: 'P-521'}
min_length_map =
{<class 'cryptography.hazmat.primitives.asymmetric.ec.SECP256R1'>: 32, <class 'cryptography.hazmat.primitives.asymmetric.ec.SECP384R1'>: 48, <class 'cryptography.hazmat.primitives.asymmetric.ec.SECP521R1'>: 66}
def
bytes_from_int(val: int, min_length: int = 0) -> bytes:
def bytes_from_int(val: int, min_length: int = 0) -> bytes:
    """Custom bytes_from_int that accepts a minimum length

    Encode a non-negative integer as unsigned big-endian bytes, left-padded with
    zero bytes to at least ``min_length`` bytes. ``0`` with no minimum length
    encodes as ``b""``.

    Raises:
        ValueError: if ``val`` is negative.
    """
    if val < 0:
        # Negative ints have no unsigned encoding; a shift-until-zero loop would
        # also never terminate for them (-1 >> 8 == -1).
        raise ValueError("Must be a non-negative integer")
    byte_length = (val.bit_length() + 7) // 8
    return val.to_bytes(max(byte_length, min_length), "big", signed=False)
Custom bytes_from_int that accepts a minimum length
def
to_base64url_uint(val: int, min_length: int = 0) -> bytes:
def to_base64url_uint(val: int, min_length: int = 0) -> bytes:
    """Custom to_base64url_uint that accepts a minimum length

    Base64url-encode ``val`` as an unsigned big-endian integer of at least
    ``min_length`` bytes. When that encoding is empty (``val == 0`` and no
    minimum length), a single zero byte is encoded instead.

    Raises:
        ValueError: if ``val`` is negative.
    """
    if val < 0:
        raise ValueError("Must be a positive integer")
    # An empty byte string is falsy, so `or` substitutes the lone zero byte.
    raw = bytes_from_int(val, min_length) or b"\x00"
    return base64url_encode(raw)
Custom to_base64url_uint that accepts a minimum length
class
JWKSView(django.views.generic.base.View):
class JWKSView(View):
    """Show JWK key data (RSA and EC keys) for Provider"""

    @staticmethod
    def get_jwk_for_key(key: CertificateKeyPair, use: Literal["sig", "enc"]) -> dict | None:
        """Convert a certificate-key pair into JWK

        Args:
            key: Certificate-key pair whose public key (and certificate) is published.
            use: ``"sig"`` for a signing key, ``"enc"`` for an encryption key.

        Returns:
            The JWK as a dict, or ``None`` when the pair has no private key or
            its key type is neither RSA nor elliptic-curve.
        """
        private_key = key.private_key
        if not private_key:
            return None

        key_data = {}

        if use == "sig":
            key_data["alg"] = JWTAlgorithms.from_private_key(private_key)
        elif use == "enc":
            # NOTE(review): also advertised for EC keys, although RSA-OAEP-256 is an
            # RSA-only algorithm (RFC 7518 §4.3) — confirm EC encryption keys are
            # not expected here.
            key_data["alg"] = "RSA-OAEP-256"
            key_data["enc"] = "A256CBC-HS512"

        if isinstance(private_key, RSAPrivateKey):
            public_key: RSAPublicKey = private_key.public_key()
            public_numbers = public_key.public_numbers()
            key_data["kid"] = key.kid
            key_data["kty"] = "RSA"
            key_data["use"] = use
            key_data["n"] = to_base64url_uint(public_numbers.n).decode()
            key_data["e"] = to_base64url_uint(public_numbers.e).decode()
        elif isinstance(private_key, EllipticCurvePrivateKey):
            public_key: EllipticCurvePublicKey = private_key.public_key()
            public_numbers = public_key.public_numbers()
            curve = public_key.curve
            curve_type = type(curve)
            # Coordinates must be padded to the curve's full coordinate size. Fall
            # back to ceil(key_size / 8) for curves missing from min_length_map
            # instead of raising KeyError ("crv" below falls back the same way).
            coord_length = min_length_map.get(curve_type, (curve.key_size + 7) // 8)
            key_data["kid"] = key.kid
            key_data["kty"] = "EC"
            key_data["use"] = use
            key_data["x"] = to_base64url_uint(public_numbers.x, coord_length).decode()
            key_data["y"] = to_base64url_uint(public_numbers.y, coord_length).decode()
            key_data["crv"] = ec_crv_map.get(curve_type, curve.name)
        else:
            # Unsupported key type: without "kty" and key material the dict is not
            # a valid JWK, so don't publish the partially filled entry.
            return None
        # X.509 parameters: DER certificate chain plus SHA-1 / SHA-256 thumbprints,
        # base64url-encoded without padding.
        key_data["x5c"] = [b64encode(key.certificate.public_bytes(Encoding.DER)).decode("utf-8")]
        key_data["x5t"] = (
            urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA1()))  # nosec
            .decode("utf-8")
            .rstrip("=")
        )
        key_data["x5t#S256"] = (
            urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA256()))
            .decode("utf-8")
            .rstrip("=")
        )
        return key_data

    def get_keys(self) -> Generator[dict | None]:
        """Yield the JWKs of the OAuth2 provider of the ``application_slug`` application.

        The signing key (``use="sig"``) is yielded before the encryption key
        (``use="enc"``); unset keys are skipped. Yielded values may be ``None``.

        Raises:
            Http404: if no OAuth2 provider is attached to that application.
        """
        provider_ids = Application.objects.filter(
            slug=self.kwargs["application_slug"],
        ).values_list(
            "provider_id",
            flat=True,
        )
        provider = (
            OAuth2Provider.objects.select_related("signing_key", "encryption_key")
            .filter(pk__in=provider_ids)
            .first()
        )

        if provider is None:
            raise Http404()

        if signing_key := provider.signing_key:
            yield JWKSView.get_jwk_for_key(signing_key, "sig")
        if encryption_key := provider.encryption_key:
            yield JWKSView.get_jwk_for_key(encryption_key, "enc")

    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
        """Show JWK Key data for Provider

        Responds with ``{"keys": [...]}``, or ``{}`` when no key produced a JWK.
        """
        response_data = {}
        for jwk in self.get_keys():
            if jwk:
                response_data.setdefault("keys", [])
                response_data["keys"].append(jwk)

        response = JsonResponse(response_data)
        # Allow cross-origin reads of the public key set from any origin.
        response["Access-Control-Allow-Origin"] = "*"

        return response
Show JWK key data (RSA and EC keys) for Provider
@staticmethod
def
get_jwk_for_key( key: authentik.crypto.models.CertificateKeyPair, use: Literal['sig', 'enc']) -> dict | None:
@staticmethod
def get_jwk_for_key(key: CertificateKeyPair, use: Literal["sig", "enc"]) -> dict | None:
    """Convert a certificate-key pair into JWK

    Args:
        key: Certificate-key pair whose public key (and certificate) is published.
        use: ``"sig"`` for a signing key, ``"enc"`` for an encryption key.

    Returns:
        The JWK as a dict, or ``None`` when the pair has no private key or its
        key type is neither RSA nor elliptic-curve.
    """
    private_key = key.private_key
    if not private_key:
        return None

    key_data = {}

    if use == "sig":
        key_data["alg"] = JWTAlgorithms.from_private_key(private_key)
    elif use == "enc":
        # NOTE(review): also advertised for EC keys, although RSA-OAEP-256 is an
        # RSA-only algorithm (RFC 7518 §4.3) — confirm EC encryption keys are
        # not expected here.
        key_data["alg"] = "RSA-OAEP-256"
        key_data["enc"] = "A256CBC-HS512"

    if isinstance(private_key, RSAPrivateKey):
        public_key: RSAPublicKey = private_key.public_key()
        public_numbers = public_key.public_numbers()
        key_data["kid"] = key.kid
        key_data["kty"] = "RSA"
        key_data["use"] = use
        key_data["n"] = to_base64url_uint(public_numbers.n).decode()
        key_data["e"] = to_base64url_uint(public_numbers.e).decode()
    elif isinstance(private_key, EllipticCurvePrivateKey):
        public_key: EllipticCurvePublicKey = private_key.public_key()
        public_numbers = public_key.public_numbers()
        curve = public_key.curve
        curve_type = type(curve)
        # Coordinates must be padded to the curve's full coordinate size. Fall
        # back to ceil(key_size / 8) for curves missing from min_length_map
        # instead of raising KeyError ("crv" below falls back the same way).
        coord_length = min_length_map.get(curve_type, (curve.key_size + 7) // 8)
        key_data["kid"] = key.kid
        key_data["kty"] = "EC"
        key_data["use"] = use
        key_data["x"] = to_base64url_uint(public_numbers.x, coord_length).decode()
        key_data["y"] = to_base64url_uint(public_numbers.y, coord_length).decode()
        key_data["crv"] = ec_crv_map.get(curve_type, curve.name)
    else:
        # Unsupported key type: without "kty" and key material the dict is not a
        # valid JWK, so don't publish the partially filled entry.
        return None
    # X.509 parameters: DER certificate chain plus SHA-1 / SHA-256 thumbprints,
    # base64url-encoded without padding.
    key_data["x5c"] = [b64encode(key.certificate.public_bytes(Encoding.DER)).decode("utf-8")]
    key_data["x5t"] = (
        urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA1()))  # nosec
        .decode("utf-8")
        .rstrip("=")
    )
    key_data["x5t#S256"] = (
        urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA256()))
        .decode("utf-8")
        .rstrip("=")
    )
    return key_data
Convert a certificate-key pair into JWK
def
get_keys(self) -> Generator[dict | None]:
def get_keys(self) -> Generator[dict | None]:
    """Yield one JWK per key configured on the application's OAuth2 provider.

    The provider is looked up via the ``application_slug`` URL kwarg. The
    signing key (``use="sig"``) comes before the encryption key
    (``use="enc"``), and unset keys are skipped. Yielded values may be ``None``.

    Raises:
        Http404: if no OAuth2 provider is attached to that application.
    """
    linked_provider_ids = Application.objects.filter(
        slug=self.kwargs["application_slug"]
    ).values_list("provider_id", flat=True)
    provider = (
        OAuth2Provider.objects.select_related("signing_key", "encryption_key")
        .filter(pk__in=linked_provider_ids)
        .first()
    )
    if provider is None:
        raise Http404()

    for attribute, use in (("signing_key", "sig"), ("encryption_key", "enc")):
        cert_key_pair = getattr(provider, attribute)
        if cert_key_pair:
            yield JWKSView.get_jwk_for_key(cert_key_pair, use)
def
get( self, request: django.http.request.HttpRequest, *args, **kwargs) -> django.http.response.HttpResponse:
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
    """Show JWK Key data for Provider

    The body is ``{"keys": [...]}`` holding every truthy JWK from
    ``get_keys``, or ``{}`` when there are none.
    """
    published = [jwk for jwk in self.get_keys() if jwk]
    body = {"keys": published} if published else {}

    response = JsonResponse(body)
    # Allow cross-origin reads of the public key set from any origin.
    response["Access-Control-Allow-Origin"] = "*"
    return response
Show JWK Key data for Provider