authentik.providers.saml.processors.metadata_parser

SAML ServiceProvider Metadata Parser and dataclass

  1"""SAML ServiceProvider Metadata Parser and dataclass"""
  2
  3from dataclasses import dataclass
  4
  5import xmlsec
  6from cryptography.hazmat.backends import default_backend
  7from cryptography.x509 import load_pem_x509_certificate
  8from defusedxml.lxml import fromstring
  9from lxml import etree  # nosec
 10from structlog.stdlib import get_logger
 11
 12from authentik.common.saml.constants import (
 13    NS_MAP,
 14    NS_SAML_METADATA,
 15    SAML_BINDING_POST,
 16    SAML_BINDING_REDIRECT,
 17)
 18from authentik.crypto.models import CertificateKeyPair, format_cert
 19from authentik.flows.models import Flow
 20from authentik.providers.saml.models import SAMLBindings, SAMLPropertyMapping, SAMLProvider
 21from authentik.sources.saml.models import SAMLNameIDPolicy
 22
# Module-level structlog logger
LOGGER = get_logger()
 24
 25
 26@dataclass(slots=True)
 27class ServiceProviderMetadata:
 28    """SP Metadata Dataclass"""
 29
 30    entity_id: str
 31
 32    acs_binding: str
 33    acs_location: str
 34
 35    auth_n_request_signed: bool
 36    assertion_signed: bool
 37    name_id_policy: SAMLNameIDPolicy
 38
 39    signing_keypair: CertificateKeyPair | None = None
 40    encryption_keypair: CertificateKeyPair | None = None
 41
 42    # Single Logout Service (optional)
 43    sls_binding: str | None = None
 44    sls_location: str | None = None
 45
 46    def to_provider(
 47        self, name: str, authorization_flow: Flow, invalidation_flow: Flow
 48    ) -> SAMLProvider:
 49        """Create a SAMLProvider instance from the details. `name` is required,
 50        as depending on the metadata CertificateKeypairs might have to be created."""
 51        provider = SAMLProvider.objects.create(
 52            name=name, authorization_flow=authorization_flow, invalidation_flow=invalidation_flow
 53        )
 54        provider.issuer = self.entity_id
 55        provider.sp_binding = self.acs_binding
 56        provider.acs_url = self.acs_location
 57        provider.default_name_id_policy = self.name_id_policy
 58        # Single Logout Service
 59        if self.sls_location:
 60            provider.sls_url = self.sls_location
 61        if self.sls_binding:
 62            provider.sls_binding = self.sls_binding
 63        if self.signing_keypair and self.auth_n_request_signed:
 64            self.signing_keypair.name = f"Provider {name} - SAML Signing Certificate"
 65            self.signing_keypair.save()
 66            provider.verification_kp = self.signing_keypair
 67        if self.encryption_keypair:
 68            self.encryption_keypair.name = f"Provider {name} - SAML Encryption Certificate"
 69            self.encryption_keypair.save()
 70            provider.encryption_kp = self.encryption_keypair
 71        if self.assertion_signed:
 72            provider.signing_kp = CertificateKeyPair.objects.exclude(key_data__iexact="").first()
 73        # Set all auto-generated Property-mappings as defaults
 74        # They should provide a sane default for most applications:
 75        provider.property_mappings.set(SAMLPropertyMapping.objects.exclude(managed__isnull=True))
 76        provider.save()
 77        return provider
 78
 79
 80class ServiceProviderMetadataParser:
 81    """Service-Provider Metadata Parser"""
 82
 83    def get_signing_cert(self, root: etree.Element) -> CertificateKeyPair | None:
 84        """Extract signing X509Certificate from metadata, when given."""
 85        signing_certs = root.xpath(
 86            '//md:SPSSODescriptor/md:KeyDescriptor[@use="signing"]//ds:X509Certificate/text()',
 87            namespaces=NS_MAP,
 88        )
 89        if len(signing_certs) < 1:
 90            return None
 91        raw_cert = format_cert(signing_certs[0])
 92        # sanity check, make sure the certificate is valid.
 93        load_pem_x509_certificate(raw_cert.encode("utf-8"), default_backend())
 94        return CertificateKeyPair(
 95            certificate_data=raw_cert,
 96        )
 97
 98    def get_encryption_cert(self, root: etree.Element) -> CertificateKeyPair | None:
 99        """Extract encryption X509Certificate from metadata, when given."""
100        encryption_certs = root.xpath(
101            '//md:SPSSODescriptor/md:KeyDescriptor[@use="encryption"]//ds:X509Certificate/text()',
102            namespaces=NS_MAP,
103        )
104        if len(encryption_certs) < 1:
105            return None
106        raw_cert = format_cert(encryption_certs[0])
107        # sanity check, make sure the certificate is valid.
108        load_pem_x509_certificate(raw_cert.encode("utf-8"), default_backend())
109        return CertificateKeyPair(
110            certificate_data=raw_cert,
111        )
112
113    def check_signature(self, root: etree.Element, keypair: CertificateKeyPair):
114        """If Metadata is signed, check validity of signature"""
115        xmlsec.tree.add_ids(root, ["ID"])
116        signature_nodes = root.xpath("/md:EntityDescriptor/ds:Signature", namespaces=NS_MAP)
117        if len(signature_nodes) != 1:
118            # No Signature
119            return
120
121        signature_node = signature_nodes[0]
122
123        if signature_node is not None:
124            try:
125                ctx = xmlsec.SignatureContext()
126                key = xmlsec.Key.from_memory(
127                    keypair.certificate_data,
128                    xmlsec.constants.KeyDataFormatCertPem,
129                    None,
130                )
131                ctx.key = key
132                ctx.verify(signature_node)
133            except xmlsec.Error as exc:
134                raise ValueError("Failed to verify Metadata signature") from exc
135
136    def parse(self, raw_xml: str) -> ServiceProviderMetadata:
137        """Parse raw XML to ServiceProviderMetadata"""
138        root = fromstring(raw_xml.encode())
139
140        entity_id = root.attrib["entityID"]
141        sp_sso_descriptors = root.findall(f"{{{NS_SAML_METADATA}}}SPSSODescriptor")
142        if len(sp_sso_descriptors) < 1:
143            raise ValueError("no SPSSODescriptor objects found.")
144        # For now we'll only look at the first descriptor.
145        # Even if multiple descriptors exist, we can only configure one
146        descriptor = sp_sso_descriptors[0]
147        auth_n_request_signed = False
148        if "AuthnRequestsSigned" in descriptor.attrib:
149            auth_n_request_signed = descriptor.attrib["AuthnRequestsSigned"].lower() == "true"
150
151        assertion_signed = False
152        if "WantAssertionsSigned" in descriptor.attrib:
153            assertion_signed = descriptor.attrib["WantAssertionsSigned"].lower() == "true"
154
155        acs_services = descriptor.findall(f"{{{NS_SAML_METADATA}}}AssertionConsumerService")
156        if len(acs_services) < 1:
157            raise ValueError("No AssertionConsumerService found.")
158
159        acs_service = acs_services[0]
160        acs_binding = {
161            SAML_BINDING_REDIRECT: SAMLBindings.REDIRECT,
162            SAML_BINDING_POST: SAMLBindings.POST,
163        }[acs_service.attrib["Binding"]]
164        acs_location = acs_service.attrib["Location"]
165
166        signing_keypair = self.get_signing_cert(root)
167        if signing_keypair:
168            self.check_signature(root, signing_keypair)
169        encryption_keypair = self.get_encryption_cert(root)
170
171        name_id_format = descriptor.findall(f"{{{NS_SAML_METADATA}}}NameIDFormat")
172        name_id_policy = SAMLNameIDPolicy.UNSPECIFIED
173        if len(name_id_format) > 0:
174            name_id_policy = SAMLNameIDPolicy(name_id_format[0].text)
175
176        # Parse SingleLogoutService (optional)
177        sls_binding = None
178        sls_location = None
179        sls_services = descriptor.findall(f"{{{NS_SAML_METADATA}}}SingleLogoutService")
180        if len(sls_services) > 0:
181            sls_service = sls_services[0]
182            sls_binding = {
183                SAML_BINDING_REDIRECT: SAMLBindings.REDIRECT,
184                SAML_BINDING_POST: SAMLBindings.POST,
185            }.get(sls_service.attrib.get("Binding"))
186            sls_location = sls_service.attrib.get("Location")
187
188        return ServiceProviderMetadata(
189            entity_id=entity_id,
190            acs_binding=acs_binding,
191            acs_location=acs_location,
192            auth_n_request_signed=auth_n_request_signed,
193            assertion_signed=assertion_signed,
194            signing_keypair=signing_keypair,
195            encryption_keypair=encryption_keypair,
196            name_id_policy=name_id_policy,
197            sls_binding=sls_binding,
198            sls_location=sls_location,
199        )
LOGGER = <BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None, context_class=None, initial_values={}, logger_factory_args=())>
@dataclass(slots=True)
class ServiceProviderMetadata:
@dataclass(slots=True)
class ServiceProviderMetadata:
    """Settings of a SAML Service Provider as read from its metadata XML.

    `to_provider` turns these values into a persisted SAMLProvider.
    """

    # Entity ID of the SP; becomes the provider's issuer
    entity_id: str

    # AssertionConsumerService binding and URL
    acs_binding: str
    acs_location: str

    # AuthnRequestsSigned / WantAssertionsSigned flags of the SP descriptor
    auth_n_request_signed: bool
    assertion_signed: bool
    name_id_policy: SAMLNameIDPolicy

    # Keypairs carrying the SP's certificates; to_provider names and saves them
    signing_keypair: CertificateKeyPair | None = None
    encryption_keypair: CertificateKeyPair | None = None

    # Single Logout Service (optional)
    sls_binding: str | None = None
    sls_location: str | None = None

    def to_provider(
        self, name: str, authorization_flow: Flow, invalidation_flow: Flow
    ) -> SAMLProvider:
        """Create and persist a SAMLProvider from this metadata.

        `name` is required because, depending on the metadata, CertificateKeyPairs
        may have to be saved, and they are named after the provider.

        Side effects: creates the provider, may save keypairs, assigns all managed
        SAMLPropertyMappings and saves the provider again.
        """
        provider = SAMLProvider.objects.create(
            name=name,
            authorization_flow=authorization_flow,
            invalidation_flow=invalidation_flow,
        )
        provider.issuer = self.entity_id
        provider.sp_binding, provider.acs_url = self.acs_binding, self.acs_location
        provider.default_name_id_policy = self.name_id_policy

        # Optional SLS values only override the provider defaults when present
        if self.sls_location:
            provider.sls_url = self.sls_location
        if self.sls_binding:
            provider.sls_binding = self.sls_binding

        # The SP certificate is only kept when the SP signs its AuthnRequests
        verification_kp = self.signing_keypair if self.auth_n_request_signed else None
        if verification_kp:
            verification_kp.name = f"Provider {name} - SAML Signing Certificate"
            verification_kp.save()
            provider.verification_kp = verification_kp

        encryption_kp = self.encryption_keypair
        if encryption_kp:
            encryption_kp.name = f"Provider {name} - SAML Encryption Certificate"
            encryption_kp.save()
            provider.encryption_kp = encryption_kp

        if self.assertion_signed:
            # First keypair with non-empty key_data (presumably one with a private key)
            provider.signing_kp = CertificateKeyPair.objects.exclude(key_data__iexact="").first()

        # Default to every managed (auto-generated) property mapping; they are
        # meant to be sane defaults for most applications.
        managed_mappings = SAMLPropertyMapping.objects.exclude(managed__isnull=True)
        provider.property_mappings.set(managed_mappings)
        provider.save()
        return provider

SP Metadata Dataclass

ServiceProviderMetadata( entity_id: str, acs_binding: str, acs_location: str, auth_n_request_signed: bool, assertion_signed: bool, name_id_policy: authentik.sources.saml.models.SAMLNameIDPolicy, signing_keypair: authentik.crypto.models.CertificateKeyPair | None = None, encryption_keypair: authentik.crypto.models.CertificateKeyPair | None = None, sls_binding: str | None = None, sls_location: str | None = None)
entity_id: str
acs_binding: str
acs_location: str
auth_n_request_signed: bool
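assertion_signed: bool
name_id_policy: authentik.sources.saml.models.SAMLNameIDPolicy
signing_keypair: authentik.crypto.models.CertificateKeyPair | None
encryption_keypair: authentik.crypto.models.CertificateKeyPair | None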
sls_binding: str | None
sls_location: str | None
def to_provider( self, name: str, authorization_flow: authentik.flows.models.Flow, invalidation_flow: authentik.flows.models.Flow) -> authentik.providers.saml.models.SAMLProvider:
47    def to_provider(
48        self, name: str, authorization_flow: Flow, invalidation_flow: Flow
49    ) -> SAMLProvider:
50        """Create a SAMLProvider instance from the details. `name` is required,
51        as depending on the metadata CertificateKeypairs might have to be created."""
52        provider = SAMLProvider.objects.create(
53            name=name, authorization_flow=authorization_flow, invalidation_flow=invalidation_flow
54        )
55        provider.issuer = self.entity_id
56        provider.sp_binding = self.acs_binding
57        provider.acs_url = self.acs_location
58        provider.default_name_id_policy = self.name_id_policy
59        # Single Logout Service
60        if self.sls_location:
61            provider.sls_url = self.sls_location
62        if self.sls_binding:
63            provider.sls_binding = self.sls_binding
64        if self.signing_keypair and self.auth_n_request_signed:
65            self.signing_keypair.name = f"Provider {name} - SAML Signing Certificate"
66            self.signing_keypair.save()
67            provider.verification_kp = self.signing_keypair
68        if self.encryption_keypair:
69            self.encryption_keypair.name = f"Provider {name} - SAML Encryption Certificate"
70            self.encryption_keypair.save()
71            provider.encryption_kp = self.encryption_keypair
72        if self.assertion_signed:
73            provider.signing_kp = CertificateKeyPair.objects.exclude(key_data__iexact="").first()
74        # Set all auto-generated Property-mappings as defaults
75        # They should provide a sane default for most applications:
76        provider.property_mappings.set(SAMLPropertyMapping.objects.exclude(managed__isnull=True))
77        provider.save()
78        return provider

Create a SAMLProvider instance from the details. name is required because, depending on the metadata, CertificateKeyPairs might have to be created, and they are named after the provider.

class ServiceProviderMetadataParser:
 81class ServiceProviderMetadataParser:
 82    """Service-Provider Metadata Parser"""
 83
 84    def get_signing_cert(self, root: etree.Element) -> CertificateKeyPair | None:
 85        """Extract signing X509Certificate from metadata, when given."""
 86        signing_certs = root.xpath(
 87            '//md:SPSSODescriptor/md:KeyDescriptor[@use="signing"]//ds:X509Certificate/text()',
 88            namespaces=NS_MAP,
 89        )
 90        if len(signing_certs) < 1:
 91            return None
 92        raw_cert = format_cert(signing_certs[0])
 93        # sanity check, make sure the certificate is valid.
 94        load_pem_x509_certificate(raw_cert.encode("utf-8"), default_backend())
 95        return CertificateKeyPair(
 96            certificate_data=raw_cert,
 97        )
 98
 99    def get_encryption_cert(self, root: etree.Element) -> CertificateKeyPair | None:
100        """Extract encryption X509Certificate from metadata, when given."""
101        encryption_certs = root.xpath(
102            '//md:SPSSODescriptor/md:KeyDescriptor[@use="encryption"]//ds:X509Certificate/text()',
103            namespaces=NS_MAP,
104        )
105        if len(encryption_certs) < 1:
106            return None
107        raw_cert = format_cert(encryption_certs[0])
108        # sanity check, make sure the certificate is valid.
109        load_pem_x509_certificate(raw_cert.encode("utf-8"), default_backend())
110        return CertificateKeyPair(
111            certificate_data=raw_cert,
112        )
113
114    def check_signature(self, root: etree.Element, keypair: CertificateKeyPair):
115        """If Metadata is signed, check validity of signature"""
116        xmlsec.tree.add_ids(root, ["ID"])
117        signature_nodes = root.xpath("/md:EntityDescriptor/ds:Signature", namespaces=NS_MAP)
118        if len(signature_nodes) != 1:
119            # No Signature
120            return
121
122        signature_node = signature_nodes[0]
123
124        if signature_node is not None:
125            try:
126                ctx = xmlsec.SignatureContext()
127                key = xmlsec.Key.from_memory(
128                    keypair.certificate_data,
129                    xmlsec.constants.KeyDataFormatCertPem,
130                    None,
131                )
132                ctx.key = key
133                ctx.verify(signature_node)
134            except xmlsec.Error as exc:
135                raise ValueError("Failed to verify Metadata signature") from exc
136
137    def parse(self, raw_xml: str) -> ServiceProviderMetadata:
138        """Parse raw XML to ServiceProviderMetadata"""
139        root = fromstring(raw_xml.encode())
140
141        entity_id = root.attrib["entityID"]
142        sp_sso_descriptors = root.findall(f"{{{NS_SAML_METADATA}}}SPSSODescriptor")
143        if len(sp_sso_descriptors) < 1:
144            raise ValueError("no SPSSODescriptor objects found.")
145        # For now we'll only look at the first descriptor.
146        # Even if multiple descriptors exist, we can only configure one
147        descriptor = sp_sso_descriptors[0]
148        auth_n_request_signed = False
149        if "AuthnRequestsSigned" in descriptor.attrib:
150            auth_n_request_signed = descriptor.attrib["AuthnRequestsSigned"].lower() == "true"
151
152        assertion_signed = False
153        if "WantAssertionsSigned" in descriptor.attrib:
154            assertion_signed = descriptor.attrib["WantAssertionsSigned"].lower() == "true"
155
156        acs_services = descriptor.findall(f"{{{NS_SAML_METADATA}}}AssertionConsumerService")
157        if len(acs_services) < 1:
158            raise ValueError("No AssertionConsumerService found.")
159
160        acs_service = acs_services[0]
161        acs_binding = {
162            SAML_BINDING_REDIRECT: SAMLBindings.REDIRECT,
163            SAML_BINDING_POST: SAMLBindings.POST,
164        }[acs_service.attrib["Binding"]]
165        acs_location = acs_service.attrib["Location"]
166
167        signing_keypair = self.get_signing_cert(root)
168        if signing_keypair:
169            self.check_signature(root, signing_keypair)
170        encryption_keypair = self.get_encryption_cert(root)
171
172        name_id_format = descriptor.findall(f"{{{NS_SAML_METADATA}}}NameIDFormat")
173        name_id_policy = SAMLNameIDPolicy.UNSPECIFIED
174        if len(name_id_format) > 0:
175            name_id_policy = SAMLNameIDPolicy(name_id_format[0].text)
176
177        # Parse SingleLogoutService (optional)
178        sls_binding = None
179        sls_location = None
180        sls_services = descriptor.findall(f"{{{NS_SAML_METADATA}}}SingleLogoutService")
181        if len(sls_services) > 0:
182            sls_service = sls_services[0]
183            sls_binding = {
184                SAML_BINDING_REDIRECT: SAMLBindings.REDIRECT,
185                SAML_BINDING_POST: SAMLBindings.POST,
186            }.get(sls_service.attrib.get("Binding"))
187            sls_location = sls_service.attrib.get("Location")
188
189        return ServiceProviderMetadata(
190            entity_id=entity_id,
191            acs_binding=acs_binding,
192            acs_location=acs_location,
193            auth_n_request_signed=auth_n_request_signed,
194            assertion_signed=assertion_signed,
195            signing_keypair=signing_keypair,
196            encryption_keypair=encryption_keypair,
197            name_id_policy=name_id_policy,
198            sls_binding=sls_binding,
199            sls_location=sls_location,
200        )

Service-Provider Metadata Parser

def get_signing_cert( self, root: lxml.etree.Element) -> authentik.crypto.models.CertificateKeyPair | None:
84    def get_signing_cert(self, root: etree.Element) -> CertificateKeyPair | None:
85        """Extract signing X509Certificate from metadata, when given."""
86        signing_certs = root.xpath(
87            '//md:SPSSODescriptor/md:KeyDescriptor[@use="signing"]//ds:X509Certificate/text()',
88            namespaces=NS_MAP,
89        )
90        if len(signing_certs) < 1:
91            return None
92        raw_cert = format_cert(signing_certs[0])
93        # sanity check, make sure the certificate is valid.
94        load_pem_x509_certificate(raw_cert.encode("utf-8"), default_backend())
95        return CertificateKeyPair(
96            certificate_data=raw_cert,
97        )

Extract signing X509Certificate from metadata, when given.

def get_encryption_cert( self, root: lxml.etree.Element) -> authentik.crypto.models.CertificateKeyPair | None:
 99    def get_encryption_cert(self, root: etree.Element) -> CertificateKeyPair | None:
100        """Extract encryption X509Certificate from metadata, when given."""
101        encryption_certs = root.xpath(
102            '//md:SPSSODescriptor/md:KeyDescriptor[@use="encryption"]//ds:X509Certificate/text()',
103            namespaces=NS_MAP,
104        )
105        if len(encryption_certs) < 1:
106            return None
107        raw_cert = format_cert(encryption_certs[0])
108        # sanity check, make sure the certificate is valid.
109        load_pem_x509_certificate(raw_cert.encode("utf-8"), default_backend())
110        return CertificateKeyPair(
111            certificate_data=raw_cert,
112        )

Extract encryption X509Certificate from metadata, when given.

def check_signature( self, root: lxml.etree.Element, keypair: authentik.crypto.models.CertificateKeyPair):
114    def check_signature(self, root: etree.Element, keypair: CertificateKeyPair):
115        """If Metadata is signed, check validity of signature"""
116        xmlsec.tree.add_ids(root, ["ID"])
117        signature_nodes = root.xpath("/md:EntityDescriptor/ds:Signature", namespaces=NS_MAP)
118        if len(signature_nodes) != 1:
119            # No Signature
120            return
121
122        signature_node = signature_nodes[0]
123
124        if signature_node is not None:
125            try:
126                ctx = xmlsec.SignatureContext()
127                key = xmlsec.Key.from_memory(
128                    keypair.certificate_data,
129                    xmlsec.constants.KeyDataFormatCertPem,
130                    None,
131                )
132                ctx.key = key
133                ctx.verify(signature_node)
134            except xmlsec.Error as exc:
135                raise ValueError("Failed to verify Metadata signature") from exc

If Metadata is signed, check validity of signature

def parse( self, raw_xml: str) -> ServiceProviderMetadata:
137    def parse(self, raw_xml: str) -> ServiceProviderMetadata:
138        """Parse raw XML to ServiceProviderMetadata"""
139        root = fromstring(raw_xml.encode())
140
141        entity_id = root.attrib["entityID"]
142        sp_sso_descriptors = root.findall(f"{{{NS_SAML_METADATA}}}SPSSODescriptor")
143        if len(sp_sso_descriptors) < 1:
144            raise ValueError("no SPSSODescriptor objects found.")
145        # For now we'll only look at the first descriptor.
146        # Even if multiple descriptors exist, we can only configure one
147        descriptor = sp_sso_descriptors[0]
148        auth_n_request_signed = False
149        if "AuthnRequestsSigned" in descriptor.attrib:
150            auth_n_request_signed = descriptor.attrib["AuthnRequestsSigned"].lower() == "true"
151
152        assertion_signed = False
153        if "WantAssertionsSigned" in descriptor.attrib:
154            assertion_signed = descriptor.attrib["WantAssertionsSigned"].lower() == "true"
155
156        acs_services = descriptor.findall(f"{{{NS_SAML_METADATA}}}AssertionConsumerService")
157        if len(acs_services) < 1:
158            raise ValueError("No AssertionConsumerService found.")
159
160        acs_service = acs_services[0]
161        acs_binding = {
162            SAML_BINDING_REDIRECT: SAMLBindings.REDIRECT,
163            SAML_BINDING_POST: SAMLBindings.POST,
164        }[acs_service.attrib["Binding"]]
165        acs_location = acs_service.attrib["Location"]
166
167        signing_keypair = self.get_signing_cert(root)
168        if signing_keypair:
169            self.check_signature(root, signing_keypair)
170        encryption_keypair = self.get_encryption_cert(root)
171
172        name_id_format = descriptor.findall(f"{{{NS_SAML_METADATA}}}NameIDFormat")
173        name_id_policy = SAMLNameIDPolicy.UNSPECIFIED
174        if len(name_id_format) > 0:
175            name_id_policy = SAMLNameIDPolicy(name_id_format[0].text)
176
177        # Parse SingleLogoutService (optional)
178        sls_binding = None
179        sls_location = None
180        sls_services = descriptor.findall(f"{{{NS_SAML_METADATA}}}SingleLogoutService")
181        if len(sls_services) > 0:
182            sls_service = sls_services[0]
183            sls_binding = {
184                SAML_BINDING_REDIRECT: SAMLBindings.REDIRECT,
185                SAML_BINDING_POST: SAMLBindings.POST,
186            }.get(sls_service.attrib.get("Binding"))
187            sls_location = sls_service.attrib.get("Location")
188
189        return ServiceProviderMetadata(
190            entity_id=entity_id,
191            acs_binding=acs_binding,
192            acs_location=acs_location,
193            auth_n_request_signed=auth_n_request_signed,
194            assertion_signed=assertion_signed,
195            signing_keypair=signing_keypair,
196            encryption_keypair=encryption_keypair,
197            name_id_policy=name_id_policy,
198            sls_binding=sls_binding,
199            sls_location=sls_location,
200        )

Parse raw XML to ServiceProviderMetadata