authentik.providers.oauth2.views.jwks

authentik OAuth2 JWKS Views

  1"""authentik OAuth2 JWKS Views"""
  2
  3from base64 import b64encode, urlsafe_b64encode
  4from collections.abc import Generator
  5from typing import Literal
  6
  7from cryptography.hazmat.primitives import hashes
  8from cryptography.hazmat.primitives.asymmetric.ec import (
  9    SECP256R1,
 10    SECP384R1,
 11    SECP521R1,
 12    EllipticCurvePrivateKey,
 13    EllipticCurvePublicKey,
 14)
 15from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
 16from cryptography.hazmat.primitives.serialization import Encoding
 17from django.http import Http404, HttpRequest, HttpResponse, JsonResponse
 18from django.views import View
 19from jwt.utils import base64url_encode
 20
 21from authentik.core.models import Application
 22from authentik.crypto.models import CertificateKeyPair
 23from authentik.providers.oauth2.models import JWTAlgorithms, OAuth2Provider
 24
 25# See https://notes.salrahman.com/generate-es256-es384-es512-private-keys/
 26# and _CURVE_TYPES in the same file as the below curve files
 27ec_crv_map = {
 28    SECP256R1: "P-256",
 29    SECP384R1: "P-384",
 30    SECP521R1: "P-521",
 31}
 32min_length_map = {
 33    SECP256R1: 32,
 34    SECP384R1: 48,
 35    SECP521R1: 66,
 36}
 37
 38
 39# https://github.com/jpadilla/pyjwt/issues/709
 40def bytes_from_int(val: int, min_length: int = 0) -> bytes:
 41    """Custom bytes_from_int that accepts a minimum length"""
 42    remaining = val
 43    byte_length = 0
 44
 45    while remaining != 0:
 46        remaining >>= 8
 47        byte_length += 1
 48    length = max([byte_length, min_length])
 49    return val.to_bytes(length, "big", signed=False)
 50
 51
 52def to_base64url_uint(val: int, min_length: int = 0) -> bytes:
 53    """Custom to_base64url_uint that accepts a minimum length"""
 54    if val < 0:
 55        raise ValueError("Must be a positive integer")
 56
 57    int_bytes = bytes_from_int(val, min_length)
 58
 59    if len(int_bytes) == 0:
 60        int_bytes = b"\x00"
 61
 62    return base64url_encode(int_bytes)
 63
 64
 65class JWKSView(View):
 66    """Show RSA Key data for Provider"""
 67
 68    @staticmethod
 69    def get_jwk_for_key(key: CertificateKeyPair, use: Literal["sig", "enc"]) -> dict | None:
 70        """Convert a certificate-key pair into JWK"""
 71        private_key = key.private_key
 72        key_data = None
 73        if not private_key:
 74            return key_data
 75
 76        key_data = {}
 77
 78        if use == "sig":
 79            key_data["alg"] = JWTAlgorithms.from_private_key(private_key)
 80        elif use == "enc":
 81            key_data["alg"] = "RSA-OAEP-256"
 82            key_data["enc"] = "A256CBC-HS512"
 83
 84        if isinstance(private_key, RSAPrivateKey):
 85            public_key: RSAPublicKey = private_key.public_key()
 86            public_numbers = public_key.public_numbers()
 87            key_data["kid"] = key.kid
 88            key_data["kty"] = "RSA"
 89            key_data["use"] = use
 90            key_data["n"] = to_base64url_uint(public_numbers.n).decode()
 91            key_data["e"] = to_base64url_uint(public_numbers.e).decode()
 92        elif isinstance(private_key, EllipticCurvePrivateKey):
 93            public_key: EllipticCurvePublicKey = private_key.public_key()
 94            public_numbers = public_key.public_numbers()
 95            curve_type = type(public_key.curve)
 96            key_data["kid"] = key.kid
 97            key_data["kty"] = "EC"
 98            key_data["use"] = use
 99            key_data["x"] = to_base64url_uint(public_numbers.x, min_length_map[curve_type]).decode()
100            key_data["y"] = to_base64url_uint(public_numbers.y, min_length_map[curve_type]).decode()
101            key_data["crv"] = ec_crv_map.get(curve_type, public_key.curve.name)
102        else:
103            return key_data
104        key_data["x5c"] = [b64encode(key.certificate.public_bytes(Encoding.DER)).decode("utf-8")]
105        key_data["x5t"] = (
106            urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA1()))  # nosec
107            .decode("utf-8")
108            .rstrip("=")
109        )
110        key_data["x5t#S256"] = (
111            urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA256()))
112            .decode("utf-8")
113            .rstrip("=")
114        )
115        return key_data
116
117    def get_keys(self) -> Generator[dict | None]:
118        provider_ids = Application.objects.filter(
119            slug=self.kwargs["application_slug"],
120        ).values_list(
121            "provider_id",
122            flat=True,
123        )
124        provider = (
125            OAuth2Provider.objects.select_related("signing_key", "encryption_key")
126            .filter(pk__in=provider_ids)
127            .first()
128        )
129
130        if provider is None:
131            raise Http404()
132
133        if signing_key := provider.signing_key:
134            yield JWKSView.get_jwk_for_key(signing_key, "sig")
135        if encryption_key := provider.encryption_key:
136            yield JWKSView.get_jwk_for_key(encryption_key, "enc")
137
138    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
139        """Show JWK Key data for Provider"""
140        response_data = {}
141        for jwk in self.get_keys():
142            if jwk:
143                response_data.setdefault("keys", [])
144                response_data["keys"].append(jwk)
145
146        response = JsonResponse(response_data)
147        response["Access-Control-Allow-Origin"] = "*"
148
149        return response
ec_crv_map = {<class 'cryptography.hazmat.primitives.asymmetric.ec.SECP256R1'>: 'P-256', <class 'cryptography.hazmat.primitives.asymmetric.ec.SECP384R1'>: 'P-384', <class 'cryptography.hazmat.primitives.asymmetric.ec.SECP521R1'>: 'P-521'}
min_length_map = {<class 'cryptography.hazmat.primitives.asymmetric.ec.SECP256R1'>: 32, <class 'cryptography.hazmat.primitives.asymmetric.ec.SECP384R1'>: 48, <class 'cryptography.hazmat.primitives.asymmetric.ec.SECP521R1'>: 66}
def bytes_from_int(val: int, min_length: int = 0) -> bytes:
def bytes_from_int(val: int, min_length: int = 0) -> bytes:
    """Custom bytes_from_int that accepts a minimum length

    Encode ``val`` as an unsigned big-endian byte string, left-padded with zero
    bytes to at least ``min_length`` bytes. With the default ``min_length``,
    ``0`` encodes to ``b""``.

    Raises:
        ValueError: if ``val`` is negative.
    """
    # Reject negatives explicitly: the former ``while remaining != 0: remaining >>= 8``
    # loop never terminated for them, because ``-1 >> 8 == -1``.
    if val < 0:
        raise ValueError("Must be a non-negative integer")
    # Smallest number of bytes that can hold val (0 for val == 0).
    byte_length = (val.bit_length() + 7) // 8
    return val.to_bytes(max(byte_length, min_length), "big", signed=False)

Custom bytes_from_int that accepts a minimum length

def to_base64url_uint(val: int, min_length: int = 0) -> bytes:
def to_base64url_uint(val: int, min_length: int = 0) -> bytes:
    """Base64url-encode a non-negative integer, zero-padded to ``min_length`` bytes.

    Like PyJWT's ``to_base64url_uint`` but with a minimum byte length, so that
    fixed-width values (such as EC coordinates) keep their leading zero bytes.
    When the big-endian encoding would be empty (``val == 0`` and no minimum
    length), a single ``0x00`` byte is encoded instead.

    Raises:
        ValueError: if ``val`` is negative.
    """
    if val >= 0:
        # An empty byte string is falsy, so zero falls back to one 0x00 byte.
        raw = bytes_from_int(val, min_length) or b"\x00"
        return base64url_encode(raw)
    raise ValueError("Must be a positive integer")

Custom to_base64url_uint that accepts a minimum length

class JWKSView(django.views.generic.base.View):
 66class JWKSView(View):
 67    """Show RSA Key data for Provider"""
 68
 69    @staticmethod
 70    def get_jwk_for_key(key: CertificateKeyPair, use: Literal["sig", "enc"]) -> dict | None:
 71        """Convert a certificate-key pair into JWK"""
 72        private_key = key.private_key
 73        key_data = None
 74        if not private_key:
 75            return key_data
 76
 77        key_data = {}
 78
 79        if use == "sig":
 80            key_data["alg"] = JWTAlgorithms.from_private_key(private_key)
 81        elif use == "enc":
 82            key_data["alg"] = "RSA-OAEP-256"
 83            key_data["enc"] = "A256CBC-HS512"
 84
 85        if isinstance(private_key, RSAPrivateKey):
 86            public_key: RSAPublicKey = private_key.public_key()
 87            public_numbers = public_key.public_numbers()
 88            key_data["kid"] = key.kid
 89            key_data["kty"] = "RSA"
 90            key_data["use"] = use
 91            key_data["n"] = to_base64url_uint(public_numbers.n).decode()
 92            key_data["e"] = to_base64url_uint(public_numbers.e).decode()
 93        elif isinstance(private_key, EllipticCurvePrivateKey):
 94            public_key: EllipticCurvePublicKey = private_key.public_key()
 95            public_numbers = public_key.public_numbers()
 96            curve_type = type(public_key.curve)
 97            key_data["kid"] = key.kid
 98            key_data["kty"] = "EC"
 99            key_data["use"] = use
100            key_data["x"] = to_base64url_uint(public_numbers.x, min_length_map[curve_type]).decode()
101            key_data["y"] = to_base64url_uint(public_numbers.y, min_length_map[curve_type]).decode()
102            key_data["crv"] = ec_crv_map.get(curve_type, public_key.curve.name)
103        else:
104            return key_data
105        key_data["x5c"] = [b64encode(key.certificate.public_bytes(Encoding.DER)).decode("utf-8")]
106        key_data["x5t"] = (
107            urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA1()))  # nosec
108            .decode("utf-8")
109            .rstrip("=")
110        )
111        key_data["x5t#S256"] = (
112            urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA256()))
113            .decode("utf-8")
114            .rstrip("=")
115        )
116        return key_data
117
118    def get_keys(self) -> Generator[dict | None]:
119        provider_ids = Application.objects.filter(
120            slug=self.kwargs["application_slug"],
121        ).values_list(
122            "provider_id",
123            flat=True,
124        )
125        provider = (
126            OAuth2Provider.objects.select_related("signing_key", "encryption_key")
127            .filter(pk__in=provider_ids)
128            .first()
129        )
130
131        if provider is None:
132            raise Http404()
133
134        if signing_key := provider.signing_key:
135            yield JWKSView.get_jwk_for_key(signing_key, "sig")
136        if encryption_key := provider.encryption_key:
137            yield JWKSView.get_jwk_for_key(encryption_key, "enc")
138
139    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
140        """Show JWK Key data for Provider"""
141        response_data = {}
142        for jwk in self.get_keys():
143            if jwk:
144                response_data.setdefault("keys", [])
145                response_data["keys"].append(jwk)
146
147        response = JsonResponse(response_data)
148        response["Access-Control-Allow-Origin"] = "*"
149
150        return response

Show RSA and EC key data for Provider

@staticmethod
def get_jwk_for_key( key: authentik.crypto.models.CertificateKeyPair, use: Literal['sig', 'enc']) -> dict | None:
 69    @staticmethod
 70    def get_jwk_for_key(key: CertificateKeyPair, use: Literal["sig", "enc"]) -> dict | None:
 71        """Convert a certificate-key pair into JWK"""
 72        private_key = key.private_key
 73        key_data = None
 74        if not private_key:
 75            return key_data
 76
 77        key_data = {}
 78
 79        if use == "sig":
 80            key_data["alg"] = JWTAlgorithms.from_private_key(private_key)
 81        elif use == "enc":
 82            key_data["alg"] = "RSA-OAEP-256"
 83            key_data["enc"] = "A256CBC-HS512"
 84
 85        if isinstance(private_key, RSAPrivateKey):
 86            public_key: RSAPublicKey = private_key.public_key()
 87            public_numbers = public_key.public_numbers()
 88            key_data["kid"] = key.kid
 89            key_data["kty"] = "RSA"
 90            key_data["use"] = use
 91            key_data["n"] = to_base64url_uint(public_numbers.n).decode()
 92            key_data["e"] = to_base64url_uint(public_numbers.e).decode()
 93        elif isinstance(private_key, EllipticCurvePrivateKey):
 94            public_key: EllipticCurvePublicKey = private_key.public_key()
 95            public_numbers = public_key.public_numbers()
 96            curve_type = type(public_key.curve)
 97            key_data["kid"] = key.kid
 98            key_data["kty"] = "EC"
 99            key_data["use"] = use
100            key_data["x"] = to_base64url_uint(public_numbers.x, min_length_map[curve_type]).decode()
101            key_data["y"] = to_base64url_uint(public_numbers.y, min_length_map[curve_type]).decode()
102            key_data["crv"] = ec_crv_map.get(curve_type, public_key.curve.name)
103        else:
104            return key_data
105        key_data["x5c"] = [b64encode(key.certificate.public_bytes(Encoding.DER)).decode("utf-8")]
106        key_data["x5t"] = (
107            urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA1()))  # nosec
108            .decode("utf-8")
109            .rstrip("=")
110        )
111        key_data["x5t#S256"] = (
112            urlsafe_b64encode(key.certificate.fingerprint(hashes.SHA256()))
113            .decode("utf-8")
114            .rstrip("=")
115        )
116        return key_data

Convert a certificate-key pair into JWK

def get_keys(self) -> Generator[dict | None]:
118    def get_keys(self) -> Generator[dict | None]:
119        provider_ids = Application.objects.filter(
120            slug=self.kwargs["application_slug"],
121        ).values_list(
122            "provider_id",
123            flat=True,
124        )
125        provider = (
126            OAuth2Provider.objects.select_related("signing_key", "encryption_key")
127            .filter(pk__in=provider_ids)
128            .first()
129        )
130
131        if provider is None:
132            raise Http404()
133
134        if signing_key := provider.signing_key:
135            yield JWKSView.get_jwk_for_key(signing_key, "sig")
136        if encryption_key := provider.encryption_key:
137            yield JWKSView.get_jwk_for_key(encryption_key, "enc")
def get( self, request: django.http.request.HttpRequest, *args, **kwargs) -> django.http.response.HttpResponse:
139    def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
140        """Show JWK Key data for Provider"""
141        response_data = {}
142        for jwk in self.get_keys():
143            if jwk:
144                response_data.setdefault("keys", [])
145                response_data["keys"].append(jwk)
146
147        response = JsonResponse(response_data)
148        response["Access-Control-Allow-Origin"] = "*"
149
150        return response

Show JWK Key data for Provider