authentik.enterprise.providers.ssf.views.stream

  1from uuid import uuid4
  2
  3from django.http import HttpRequest
  4from django.urls import reverse
  5from rest_framework.exceptions import PermissionDenied, ValidationError
  6from rest_framework.fields import CharField, ChoiceField, ListField, SerializerMethodField
  7from rest_framework.request import Request
  8from rest_framework.response import Response
  9from structlog.stdlib import get_logger
 10
 11from authentik.api.validation import validate
 12from authentik.core.api.utils import ModelSerializer, PassiveSerializer
 13from authentik.enterprise.providers.ssf.models import (
 14    DeliveryMethods,
 15    EventTypes,
 16    SSFProvider,
 17    Stream,
 18    StreamStatus,
 19)
 20from authentik.enterprise.providers.ssf.tasks import send_ssf_events
 21from authentik.enterprise.providers.ssf.views.base import SSFStreamView
 22
 23LOGGER = get_logger()  # structlog logger; StreamView.post logs the verification dispatch with it
 24
 25
 26class StreamDeliverySerializer(PassiveSerializer):
 27    method = ChoiceField(choices=[(x.value, x.value) for x in DeliveryMethods])
 28    endpoint_url = CharField(required=False)
 29    authorization_header = CharField(required=False)
 30
 31    def validate_method(self, method: DeliveryMethods):
 32        """Currently only push is supported"""
 33        if method == DeliveryMethods.RISC_POLL:
 34            raise ValidationError("Polling for SSF events is not currently supported.")
 35        return method
 36
 37    def validate(self, attrs: dict) -> dict:
 38        if attrs.get("method") in [DeliveryMethods.RISC_PUSH, DeliveryMethods.RFC_PUSH]:
 39            if not attrs.get("endpoint_url"):
 40                raise ValidationError("Endpoint URL is required when using push.")
 41        return attrs
 42
 43
 44class StreamSerializer(ModelSerializer):
 45    delivery = StreamDeliverySerializer()
 46    events_requested = ListField(
 47        child=ChoiceField(choices=[(x.value, x.value) for x in EventTypes])
 48    )
 49    format = CharField(default="iss_sub")
 50    aud = ListField(child=CharField(), allow_empty=True, default=list)
 51
 52    def create(self, validated_data):
 53        provider: SSFProvider = validated_data["provider"]
 54        request: HttpRequest = self.context["request"]
 55        iss = request.build_absolute_uri(
 56            reverse(
 57                "authentik_providers_ssf:configuration",
 58                kwargs={
 59                    "application_slug": provider.backchannel_application.slug,
 60                },
 61            )
 62        )
 63        # Ensure that streams always get SET verification events sent to them
 64        validated_data["events_requested"].append(EventTypes.SET_VERIFICATION)
 65        stream_id = uuid4()
 66        default_aud = f"goauthentik.io/providers/ssf/{str(stream_id)}"
 67        return super().create(
 68            {
 69                "delivery_method": validated_data["delivery"]["method"],
 70                "endpoint_url": validated_data["delivery"].get("endpoint_url"),
 71                "authorization_header": validated_data["delivery"].get("authorization_header"),
 72                "format": validated_data["format"],
 73                "provider": validated_data["provider"],
 74                "events_requested": validated_data["events_requested"],
 75                "aud": validated_data["aud"] or [default_aud],
 76                "iss": iss,
 77                "pk": stream_id,
 78            }
 79        )
 80
 81    class Meta:
 82        model = Stream
 83        fields = [
 84            "delivery",
 85            "events_requested",
 86            "format",
 87            "aud",
 88        ]
 89
 90
 91class StreamResponseSerializer(PassiveSerializer):
 92    stream_id = CharField(source="pk")
 93    iss = CharField()
 94    aud = ListField(child=CharField())
 95    delivery = SerializerMethodField()
 96    format = CharField()
 97
 98    events_requested = ListField(child=CharField())
 99    events_supported = SerializerMethodField()
100    events_delivered = ListField(child=CharField(), source="events_requested")
101
102    def get_delivery(self, instance: Stream) -> StreamDeliverySerializer:
103        return {
104            "method": instance.delivery_method,
105            "endpoint_url": instance.endpoint_url,
106        }
107
108    def get_events_supported(self, instance: Stream) -> list[str]:
109        return [x.value for x in EventTypes]
110
111
class StreamView(SSFStreamView):
    """SSF stream configuration endpoint: read, create, replace, update and disable a stream."""

    def get(self, request: Request, *args, **kwargs):
        """Return the configuration of the stream resolved by ``get_object()``."""
        stream = self.get_object()
        return Response(
            StreamResponseSerializer(instance=stream, context={"request": request}).data
        )

    @validate(StreamSerializer)
    def post(self, request: Request, *args, body: StreamSerializer, **kwargs) -> Response:
        """Create a stream for ``self.provider`` and send it an initial verification event.

        Raises:
            PermissionDenied: the user lacks ``add_stream`` on the provider.
        """
        if not request.user.has_perm("authentik_providers_ssf.add_stream", self.provider):
            raise PermissionDenied(
                "User does not have permission to create stream for this provider."
            )
        instance: Stream = body.save(provider=self.provider)

        LOGGER.info("Sending verification event", stream=instance)
        send_ssf_events(
            EventTypes.SET_VERIFICATION,
            {
                "state": None,
            },
            stream_filter={"pk": instance.uuid},
            sub_id={"format": "opaque", "id": str(instance.uuid)},
        )
        response = StreamResponseSerializer(instance=instance, context={"request": request}).data
        return Response(response, status=201)

    def _save_stream(self, request: Request, partial: bool) -> Response:
        """Shared body of PUT (``partial=False``) and PATCH (``partial=True``)."""
        # NOTE(review): unlike post(), no explicit permission check happens here;
        # confirm that SSFStreamView.get_object() enforces access to the stream.
        stream = self.get_object()
        # Pass the request context like get()/post() do, so serializer code that
        # reads self.context["request"] works on updates too.
        serializer = StreamSerializer(
            stream, data=request.data, partial=partial, context={"request": request}
        )
        serializer.is_valid(raise_exception=True)
        serializer.save()
        response = StreamResponseSerializer(
            instance=serializer.instance, context={"request": request}
        ).data
        return Response(response, status=200)

    def patch(self, request: Request, *args, **kwargs) -> Response:
        """Partially update the stream; only the supplied fields change."""
        return self._save_stream(request, partial=True)

    def put(self, request: Request, *args, **kwargs) -> Response:
        """Replace the stream configuration; all required fields must be supplied."""
        return self._save_stream(request, partial=False)

    def delete(self, request: Request, *args, **kwargs) -> Response:
        """Soft-delete: mark the stream DISABLED and save it; the row itself is kept."""
        stream = self.get_object()
        stream.status = StreamStatus.DISABLED
        stream.save()
        return Response(status=204)
165
166
class StreamVerifyView(SSFStreamView):
    """Trigger a SET verification event for an existing stream."""

    def post(self, request: Request, *args, **kwargs):
        """Send one verification event to this stream, echoing the optional ``state``."""
        stream = self.get_object()
        verification_payload = {"state": request.data.get("state", None)}
        target_uuid = stream.uuid
        send_ssf_events(
            EventTypes.SET_VERIFICATION,
            verification_payload,
            stream_filter={"pk": target_uuid},
            sub_id={"format": "opaque", "id": str(target_uuid)},
        )
        return Response(status=204)
181
182
class StreamStatusView(SSFStreamView):
    """Report the id and status of a stream."""

    def get(self, request: Request, *args, **kwargs):
        """Return ``stream_id`` and ``status`` as strings.

        ``any_status=True`` is passed so the lookup is presumably not limited to
        enabled streams (confirm against SSFStreamView.get_object).
        """
        stream = self.get_object(any_status=True)
        status_body = {"stream_id": str(stream.pk), "status": str(stream.status)}
        return Response(status_body)
LOGGER = <BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None, context_class=None, initial_values={}, logger_factory_args=())>
class StreamDeliverySerializer(authentik.core.api.utils.PassiveSerializer):
27class StreamDeliverySerializer(PassiveSerializer):
28    method = ChoiceField(choices=[(x.value, x.value) for x in DeliveryMethods])
29    endpoint_url = CharField(required=False)
30    authorization_header = CharField(required=False)
31
32    def validate_method(self, method: DeliveryMethods):
33        """Currently only push is supported"""
34        if method == DeliveryMethods.RISC_POLL:
35            raise ValidationError("Polling for SSF events is not currently supported.")
36        return method
37
38    def validate(self, attrs: dict) -> dict:
39        if attrs.get("method") in [DeliveryMethods.RISC_PUSH, DeliveryMethods.RFC_PUSH]:
40            if not attrs.get("endpoint_url"):
41                raise ValidationError("Endpoint URL is required when using push.")
42        return attrs

Base serializer class which doesn't implement create/update methods

(Inherited docstring: this class defines none of its own. It validates the nested `delivery` object of a stream: `method` is required and must not be RISC poll, and push methods additionally require `endpoint_url`.)

method
endpoint_url
authorization_header
def validate_method( self, method: authentik.enterprise.providers.ssf.models.DeliveryMethods):
32    def validate_method(self, method: DeliveryMethods):
33        """Currently only push is supported"""
34        if method == DeliveryMethods.RISC_POLL:
35            raise ValidationError("Polling for SSF events is not currently supported.")
36        return method

Currently only push is supported. Raises a ValidationError for `DeliveryMethods.RISC_POLL`; any other method is returned unchanged.

def validate(self, attrs: dict) -> dict:
38    def validate(self, attrs: dict) -> dict:
39        if attrs.get("method") in [DeliveryMethods.RISC_PUSH, DeliveryMethods.RFC_PUSH]:
40            if not attrs.get("endpoint_url"):
41                raise ValidationError("Endpoint URL is required when using push.")
42        return attrs

Requires a non-empty `endpoint_url` when `method` is RISC_PUSH or RFC_PUSH; otherwise returns `attrs` unchanged.
class StreamSerializer(authentik.core.api.utils.ModelSerializer):
45class StreamSerializer(ModelSerializer):
46    delivery = StreamDeliverySerializer()
47    events_requested = ListField(
48        child=ChoiceField(choices=[(x.value, x.value) for x in EventTypes])
49    )
50    format = CharField(default="iss_sub")
51    aud = ListField(child=CharField(), allow_empty=True, default=list)
52
53    def create(self, validated_data):
54        provider: SSFProvider = validated_data["provider"]
55        request: HttpRequest = self.context["request"]
56        iss = request.build_absolute_uri(
57            reverse(
58                "authentik_providers_ssf:configuration",
59                kwargs={
60                    "application_slug": provider.backchannel_application.slug,
61                },
62            )
63        )
64        # Ensure that streams always get SET verification events sent to them
65        validated_data["events_requested"].append(EventTypes.SET_VERIFICATION)  # NOTE(review): duplicates the event if the client already requested it
66        stream_id = uuid4()
67        default_aud = f"goauthentik.io/providers/ssf/{str(stream_id)}"
68        return super().create(
69            {
70                "delivery_method": validated_data["delivery"]["method"],
71                "endpoint_url": validated_data["delivery"].get("endpoint_url"),
72                "authorization_header": validated_data["delivery"].get("authorization_header"),
73                "format": validated_data["format"],
74                "provider": validated_data["provider"],
75                "events_requested": validated_data["events_requested"],
76                "aud": validated_data["aud"] or [default_aud],
77                "iss": iss,
78                "pk": stream_id,
79            }
80        )
81
82    class Meta:
83        model = Stream
84        fields = [
85            "delivery",
86            "events_requested",
87            "format",
88            "aud",
89        ]

A ModelSerializer is just a regular Serializer, except that:

  • A set of default fields are automatically populated.
  • A set of default validators are automatically populated.
  • Default .create() and .update() implementations are provided.

The process of automatically determining a set of serializer fields based on the model fields is reasonably complex, but you almost certainly don't need to dig into the implementation.

If the ModelSerializer class doesn't generate the set of fields that you need you should either declare the extra/differing fields explicitly on the serializer class, or simply use a Serializer class.

(Inherited ModelSerializer docstring: StreamSerializer defines none. It overrides `create()` only; `update()` is the DRF default, which, like the default `create()` described below, does not handle writable nested fields such as `delivery`. NOTE(review): confirm that PUT/PATCH requests carrying `delivery` behave as intended.)

delivery
events_requested
format
aud
def create(self, validated_data):
53    def create(self, validated_data):
54        provider: SSFProvider = validated_data["provider"]
55        request: HttpRequest = self.context["request"]
56        iss = request.build_absolute_uri(
57            reverse(
58                "authentik_providers_ssf:configuration",
59                kwargs={
60                    "application_slug": provider.backchannel_application.slug,
61                },
62            )
63        )
64        # Ensure that streams always get SET verification events sent to them
65        validated_data["events_requested"].append(EventTypes.SET_VERIFICATION)  # NOTE(review): duplicates the event if the client already requested it
66        stream_id = uuid4()
67        default_aud = f"goauthentik.io/providers/ssf/{str(stream_id)}"
68        return super().create(
69            {
70                "delivery_method": validated_data["delivery"]["method"],
71                "endpoint_url": validated_data["delivery"].get("endpoint_url"),
72                "authorization_header": validated_data["delivery"].get("authorization_header"),
73                "format": validated_data["format"],
74                "provider": validated_data["provider"],
75                "events_requested": validated_data["events_requested"],
76                "aud": validated_data["aud"] or [default_aud],
77                "iss": iss,
78                "pk": stream_id,
79            }
80        )

(This override flattens the nested `delivery` object onto `delivery_method` / `endpoint_url` / `authorization_header`. It builds `iss` from the request in the serializer context, adds the SET verification event to `events_requested`, pre-generates the primary key, and defaults `aud` to `goauthentik.io/providers/ssf/<stream id>` when none is given. The inherited DRF docstring follows.)

We have a bit of extra checking around this in order to provide descriptive messages when something goes wrong, but this method is essentially just:

return ExampleModel.objects.create(**validated_data)

If there are many to many fields present on the instance then they cannot be set until the model is instantiated, in which case the implementation is like so:

example_relationship = validated_data.pop('example_relationship')
instance = ExampleModel.objects.create(**validated_data)
instance.example_relationship = example_relationship
return instance

The default implementation also does not handle nested relationships. If you want to support writable nested relationships you'll need to write an explicit .create() method.

class StreamSerializer.Meta:
82    class Meta:
83        model = Stream
84        fields = [
85            "delivery",
86            "events_requested",
87            "format",
88            "aud",
89        ]
fields = ['delivery', 'events_requested', 'format', 'aud']
class StreamResponseSerializer(authentik.core.api.utils.PassiveSerializer):
 92class StreamResponseSerializer(PassiveSerializer):
 93    stream_id = CharField(source="pk")
 94    iss = CharField()
 95    aud = ListField(child=CharField())
 96    delivery = SerializerMethodField()
 97    format = CharField()
 98
 99    events_requested = ListField(child=CharField())
100    events_supported = SerializerMethodField()
101    events_delivered = ListField(child=CharField(), source="events_requested")  # same source list as events_requested

Base serializer class which doesn't implement create/update methods

(Inherited docstring: this class defines none of its own. It renders a Stream for the stream endpoints; `stream_id` is the primary key and `events_delivered` mirrors `events_requested`.)

stream_id
iss
aud
delivery
format
events_requested
events_supported
events_delivered
def get_delivery( self, instance: authentik.enterprise.providers.ssf.models.Stream) -> StreamDeliverySerializer:
103    def get_delivery(self, instance: Stream) -> StreamDeliverySerializer:
104        return {
105            "method": instance.delivery_method,
106            "endpoint_url": instance.endpoint_url,
107        }

Returns only `method` and `endpoint_url`; the stream's `authorization_header` is not included in the response.

def get_events_supported( self, instance: authentik.enterprise.providers.ssf.models.Stream) -> list[str]:
109    def get_events_supported(self, instance: Stream) -> list[str]:
110        return [x.value for x in EventTypes]

Returns the value of every `EventTypes` member; `instance` is not used.
113class StreamView(SSFStreamView):
114
115    def get(self, request: Request, *args, **kwargs):
116        stream = self.get_object()
117        return Response(
118            StreamResponseSerializer(instance=stream, context={"request": request}).data
119        )
120
121    @validate(StreamSerializer)
122    def post(self, request: Request, *args, body: StreamSerializer, **kwargs) -> Response:
123        if not request.user.has_perm("authentik_providers_ssf.add_stream", self.provider):
124            raise PermissionDenied(
125                "User does not have permission to create stream for this provider."
126            )
127        instance: Stream = body.save(provider=self.provider)
128
129        LOGGER.info("Sending verification event", stream=instance)
130        send_ssf_events(
131            EventTypes.SET_VERIFICATION,
132            {
133                "state": None,
134            },
135            stream_filter={"pk": instance.uuid},
136            sub_id={"format": "opaque", "id": str(instance.uuid)},
137        )
138        response = StreamResponseSerializer(instance=instance, context={"request": request}).data
139        return Response(response, status=201)
140
141    def patch(self, request: Request, *args, **kwargs) -> Response:
142        stream = self.get_object()
143        serializer = StreamSerializer(stream, data=request.data, partial=True)  # NOTE(review): no serializer context passed, unlike get/post
144        serializer.is_valid(raise_exception=True)
145        serializer.save()
146        response = StreamResponseSerializer(
147            instance=serializer.instance, context={"request": request}
148        ).data
149        return Response(response, status=200)
150
151    def put(self, request: Request, *args, **kwargs) -> Response:
152        stream = self.get_object()
153        serializer = StreamSerializer(stream, data=request.data)  # NOTE(review): no serializer context passed, unlike get/post
154        serializer.is_valid(raise_exception=True)
155        serializer.save()
156        response = StreamResponseSerializer(
157            instance=serializer.instance, context={"request": request}
158        ).data
159        return Response(response, status=200)
160
161    def delete(self, request: Request, *args, **kwargs) -> Response:
162        stream = self.get_object()
163        stream.status = StreamStatus.DISABLED
164        stream.save()
165        return Response(status=204)

Intentionally simple parent class for all views. Only implements dispatch-by-method and simple sanity checking.

(Inherited docstring: StreamView defines none; the text appears to come from Django's base `View`. StreamView handles GET, POST, PATCH, PUT and DELETE for a single SSF stream, which is resolved with `self.get_object()`.)

def get(self, request: rest_framework.request.Request, *args, **kwargs):
115    def get(self, request: Request, *args, **kwargs):
116        stream = self.get_object()
117        return Response(
118            StreamResponseSerializer(instance=stream, context={"request": request}).data
119        )

Returns the stream rendered by StreamResponseSerializer.

@validate(StreamSerializer)
def post( self, request: rest_framework.request.Request, *args, body: StreamSerializer, **kwargs) -> rest_framework.response.Response:
121    @validate(StreamSerializer)
122    def post(self, request: Request, *args, body: StreamSerializer, **kwargs) -> Response:
123        if not request.user.has_perm("authentik_providers_ssf.add_stream", self.provider):
124            raise PermissionDenied(
125                "User does not have permission to create stream for this provider."
126            )
127        instance: Stream = body.save(provider=self.provider)
128
129        LOGGER.info("Sending verification event", stream=instance)
130        send_ssf_events(
131            EventTypes.SET_VERIFICATION,
132            {
133                "state": None,
134            },
135            stream_filter={"pk": instance.uuid},
136            sub_id={"format": "opaque", "id": str(instance.uuid)},
137        )
138        response = StreamResponseSerializer(instance=instance, context={"request": request}).data
139        return Response(response, status=201)

Requires `authentik_providers_ssf.add_stream` on the provider (otherwise PermissionDenied). Saves the new stream, sends it a SET verification event with `state: None`, and responds 201 with the rendered stream.

def patch( self, request: rest_framework.request.Request, *args, **kwargs) -> rest_framework.response.Response:
141    def patch(self, request: Request, *args, **kwargs) -> Response:
142        stream = self.get_object()
143        serializer = StreamSerializer(stream, data=request.data, partial=True)  # NOTE(review): no serializer context passed, unlike get/post
144        serializer.is_valid(raise_exception=True)
145        serializer.save()
146        response = StreamResponseSerializer(
147            instance=serializer.instance, context={"request": request}
148        ).data
149        return Response(response, status=200)

Partial update through StreamSerializer (`partial=True`); responds 200. NOTE(review): unlike post, no explicit permission check is made here; confirm that `get_object()` enforces access.

def put( self, request: rest_framework.request.Request, *args, **kwargs) -> rest_framework.response.Response:
151    def put(self, request: Request, *args, **kwargs) -> Response:
152        stream = self.get_object()
153        serializer = StreamSerializer(stream, data=request.data)  # NOTE(review): no serializer context passed, unlike get/post
154        serializer.is_valid(raise_exception=True)
155        serializer.save()
156        response = StreamResponseSerializer(
157            instance=serializer.instance, context={"request": request}
158        ).data
159        return Response(response, status=200)

Full replace through StreamSerializer (all required fields must be sent); responds 200. NOTE(review): unlike post, no explicit permission check is made here; confirm that `get_object()` enforces access.

def delete( self, request: rest_framework.request.Request, *args, **kwargs) -> rest_framework.response.Response:
161    def delete(self, request: Request, *args, **kwargs) -> Response:
162        stream = self.get_object()
163        stream.status = StreamStatus.DISABLED
164        stream.save()
165        return Response(status=204)

Soft delete: sets the stream status to DISABLED and saves it (the row is kept); responds 204 with no body.
168class StreamVerifyView(SSFStreamView):
169
170    def post(self, request: Request, *args, **kwargs):
171        stream = self.get_object()
172        state = request.data.get("state", None)
173        send_ssf_events(
174            EventTypes.SET_VERIFICATION,
175            {
176                "state": state,
177            },
178            stream_filter={"pk": stream.uuid},
179            sub_id={"format": "opaque", "id": str(stream.uuid)},
180        )
181        return Response(status=204)

Intentionally simple parent class for all views. Only implements dispatch-by-method and simple sanity checking.

(Inherited docstring: StreamVerifyView defines none; the text appears to come from Django's base `View`. This view sends on-demand SET verification events.)

def post(self, request: rest_framework.request.Request, *args, **kwargs):
170    def post(self, request: Request, *args, **kwargs):
171        stream = self.get_object()
172        state = request.data.get("state", None)
173        send_ssf_events(
174            EventTypes.SET_VERIFICATION,
175            {
176                "state": state,
177            },
178            stream_filter={"pk": stream.uuid},
179            sub_id={"format": "opaque", "id": str(stream.uuid)},
180        )
181        return Response(status=204)

Sends a SET verification event to this stream only, echoing the request's optional `state` value; responds 204 with no body.
184class StreamStatusView(SSFStreamView):
185
186    def get(self, request: Request, *args, **kwargs):
187        stream = self.get_object(any_status=True)
188        return Response(
189            {
190                "stream_id": str(stream.pk),
191                "status": str(stream.status),
192            }
193        )

Intentionally simple parent class for all views. Only implements dispatch-by-method and simple sanity checking.

(Inherited docstring: StreamStatusView defines none; the text appears to come from Django's base `View`. This view reports a stream's status.)

def get(self, request: rest_framework.request.Request, *args, **kwargs):
186    def get(self, request: Request, *args, **kwargs):
187        stream = self.get_object(any_status=True)
188        return Response(
189            {
190                "stream_id": str(stream.pk),
191                "status": str(stream.status),
192            }
193        )

Looks the stream up with `any_status=True` (presumably so that disabled streams are still found; confirm against SSFStreamView) and returns its id and status as strings.