authentik.outposts.consumer

Outpost websocket handler

  1"""Outpost websocket handler"""
  2
  3from dataclasses import asdict, dataclass, field
  4from datetime import datetime
  5from enum import IntEnum
  6from hashlib import sha256
  7from typing import Any
  8from uuid import UUID
  9
 10from asgiref.sync import async_to_sync
 11from channels.exceptions import DenyConnection
 12from channels.generic.websocket import JsonWebsocketConsumer
 13from dacite.core import from_dict
 14from dacite.data import Data
 15from django.db import connection
 16from django.http.request import QueryDict
 17from guardian.shortcuts import get_objects_for_user
 18from structlog.stdlib import BoundLogger, get_logger
 19
 20from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
 21from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
 22
 23
 24def build_outpost_group(outpost_pk: str | UUID) -> str:
 25    return sha256(f"{connection.schema_name}/group_outpost_{str(outpost_pk)}".encode()).hexdigest()
 26
 27
 28def build_outpost_group_instance(outpost_pk: str | UUID, instance: str) -> str:
 29    return sha256(
 30        f"{connection.schema_name}/group_outpost_{str(outpost_pk)}_{instance}".encode()
 31    ).hexdigest()
 32
 33
 34class WebsocketMessageInstruction(IntEnum):
 35    """Commands which can be triggered over Websocket"""
 36
 37    # Simple message used by either side when a message is acknowledged
 38    ACK = 0
 39
 40    # Message used by outposts to report their alive status
 41    HELLO = 1
 42
 43    # Message sent by us to trigger an Update
 44    TRIGGER_UPDATE = 2
 45
 46    # Provider specific message
 47    PROVIDER_SPECIFIC = 3
 48
 49    # Session ended
 50    SESSION_END = 4
 51
 52
 53@dataclass(slots=True)
 54class WebsocketMessage:
 55    """Complete Websocket Message that is being sent"""
 56
 57    instruction: int
 58    args: dict[str, Any] = field(default_factory=dict)
 59
 60
 61class OutpostConsumer(JsonWebsocketConsumer):
 62    """Handler for Outposts that connect over websockets for health checks and live updates"""
 63
 64    outpost: Outpost | None = None
 65    logger: BoundLogger
 66
 67    instance_uid: str | None = None
 68
 69    def __init__(self, *args, **kwargs):
 70        super().__init__(*args, **kwargs)
 71        self.logger = get_logger()
 72
 73    def connect(self):
 74        uuid = self.scope["url_route"]["kwargs"]["pk"]
 75        user = self.scope["user"]
 76        self.outpost: Outpost | None = (
 77            get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
 78        )
 79        if self.outpost is None:
 80            raise DenyConnection()
 81        self.logger = self.logger.bind(outpost=self.outpost)
 82        try:
 83            self.accept()
 84        except RuntimeError as exc:
 85            self.logger.warning("runtime error during accept", exc=exc)
 86            raise DenyConnection() from None
 87        query = QueryDict(self.scope["query_string"].decode())
 88        self.instance_uid = query.get("instance_uuid", self.channel_name)
 89        async_to_sync(self.channel_layer.group_add)(
 90            build_outpost_group(self.outpost.pk), self.channel_name
 91        )
 92        async_to_sync(self.channel_layer.group_add)(
 93            build_outpost_group_instance(self.outpost.pk, self.instance_uid),
 94            self.channel_name,
 95        )
 96        GAUGE_OUTPOSTS_CONNECTED.labels(
 97            tenant=connection.schema_name,
 98            outpost=self.outpost.name,
 99            uid=self.instance_uid,
100            expected=self.outpost.config.kubernetes_replicas,
101        ).inc()
102
103    def disconnect(self, code):
104        if self.outpost:
105            async_to_sync(self.channel_layer.group_discard)(
106                build_outpost_group(self.outpost.pk), self.channel_name
107            )
108            if self.instance_uid:
109                async_to_sync(self.channel_layer.group_discard)(
110                    build_outpost_group_instance(self.outpost.pk, self.instance_uid),
111                    self.channel_name,
112                )
113        if self.outpost and self.instance_uid:
114            GAUGE_OUTPOSTS_CONNECTED.labels(
115                tenant=connection.schema_name,
116                outpost=self.outpost.name,
117                uid=self.instance_uid,
118                expected=self.outpost.config.kubernetes_replicas,
119            ).dec()
120
121    def receive_json(self, content: Data, **kwargs):
122        msg = from_dict(WebsocketMessage, content)
123        if not self.outpost:
124            raise DenyConnection()
125
126        state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
127        state.last_seen = datetime.now()
128        state.hostname = msg.args.pop("hostname", "")
129
130        if msg.instruction == WebsocketMessageInstruction.HELLO:
131            state.version = msg.args.pop("version", None)
132            state.build_hash = msg.args.pop("buildHash", "")
133            state.golang_version = msg.args.pop("golangVersion", "")
134            state.openssl_enabled = msg.args.pop("opensslEnabled", False)
135            state.openssl_version = msg.args.pop("opensslVersion", "")
136            state.fips_enabled = msg.args.pop("fipsEnabled", False)
137            state.args.update(msg.args)
138        elif msg.instruction == WebsocketMessageInstruction.ACK:
139            return
140        GAUGE_OUTPOSTS_LAST_UPDATE.labels(
141            tenant=connection.schema_name,
142            outpost=self.outpost.name,
143            uid=self.instance_uid or "",
144            version=state.version or "",
145        ).set_to_current_time()
146        state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)
147
148        response = WebsocketMessage(instruction=WebsocketMessageInstruction.ACK)
149        self.send_json(asdict(response))
150
151    def event_update(self, event):  # pragma: no cover
152        """Event handler which is called by post_save signals, Send update instruction"""
153        self.send_json(
154            asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
155        )
156
157    def event_session_end(self, event):
158        """Event handler which is called when a session is ended"""
159        self.send_json(
160            asdict(
161                WebsocketMessage(instruction=WebsocketMessageInstruction.SESSION_END, args=event)
162            )
163        )
164
165    def event_provider_specific(self, event):
166        """Event handler which can be called by provider-specific
167        implementations to send specific messages to the outpost"""
168        self.send_json(
169            asdict(
170                WebsocketMessage(
171                    instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
172                )
173            )
174        )
def build_outpost_group(outpost_pk: str | UUID) -> str:
    """Return the channel-layer group name shared by every instance of an outpost.

    The name embeds the current database schema (the tenant) and is hashed
    with SHA-256, so the result is always a 64-character hex digest.
    """
    raw_name = f"{connection.schema_name}/group_outpost_{outpost_pk}"
    return sha256(raw_name.encode()).hexdigest()
def build_outpost_group_instance(outpost_pk: str | UUID, instance: str) -> str:
    """Return the channel-layer group name for one specific outpost instance.

    Uses the same schema-scoped naming as the outpost-wide group, with
    ``_{instance}`` appended before hashing with SHA-256.
    """
    raw_name = f"{connection.schema_name}/group_outpost_{outpost_pk}_{instance}"
    return sha256(raw_name.encode()).hexdigest()
class WebsocketMessageInstruction(IntEnum):
    """Instruction codes carried in the ``instruction`` field of a websocket message."""

    ACK = 0  # acknowledgement, sent by either side
    HELLO = 1  # alive/status report sent by the outpost
    TRIGGER_UPDATE = 2  # sent by authentik to make the outpost update
    PROVIDER_SPECIFIC = 3  # payload whose meaning is defined by the provider
    SESSION_END = 4  # tells the outpost that a session has ended

Commands that can be triggered over the WebSocket connection.

@dataclass(slots=True)
class WebsocketMessage:
54@dataclass(slots=True)
55class WebsocketMessage:
56    """Complete Websocket Message that is being sent"""
57
58    instruction: int
59    args: dict[str, Any] = field(default_factory=dict)

Complete Websocket Message that is being sent

WebsocketMessage(instruction: int, args: dict[str, typing.Any] = <factory>)
instruction: int
args: dict[str, typing.Any]
class OutpostConsumer(JsonWebsocketConsumer):
    """Websocket consumer that outposts use for health checks and live updates.

    Incoming messages refresh the sending instance's OutpostState; the
    ``event_*`` handlers relay channel-layer events to the outpost as
    WebsocketMessage payloads.
    """

    # Outpost resolved during connect(); None until then
    outpost: Outpost | None = None
    logger: BoundLogger

    # Identifier of this outpost instance (query parameter or channel name)
    instance_uid: str | None = None

    def __init__(self, *args, **kwargs):
        """Run the base consumer initialiser, then attach a structlog logger."""
        super().__init__(*args, **kwargs)
        self.logger = get_logger()

    def connect(self):
        """Accept the socket when the user may view the requested outpost.

        Raises DenyConnection if the outpost is not visible to the user or if
        accepting the socket fails with a RuntimeError.
        """
        requested_pk = self.scope["url_route"]["kwargs"]["pk"]
        visible_outposts = get_objects_for_user(
            self.scope["user"], "authentik_outposts.view_outpost"
        )
        self.outpost = visible_outposts.filter(pk=requested_pk).first()
        if self.outpost is None:
            raise DenyConnection()
        self.logger = self.logger.bind(outpost=self.outpost)
        try:
            self.accept()
        except RuntimeError as exc:
            self.logger.warning("runtime error during accept", exc=exc)
            raise DenyConnection() from None
        params = QueryDict(self.scope["query_string"].decode())
        # Without an explicit instance id, the channel name identifies the instance
        self.instance_uid = params.get("instance_uuid", self.channel_name)
        # Join the outpost-wide group first, then this instance's own group
        for group_name in (
            build_outpost_group(self.outpost.pk),
            build_outpost_group_instance(self.outpost.pk, self.instance_uid),
        ):
            async_to_sync(self.channel_layer.group_add)(group_name, self.channel_name)
        GAUGE_OUTPOSTS_CONNECTED.labels(
            tenant=connection.schema_name,
            outpost=self.outpost.name,
            uid=self.instance_uid,
            expected=self.outpost.config.kubernetes_replicas,
        ).inc()

    def disconnect(self, code):
        """Leave the channel groups joined in connect() and decrement the gauge."""
        if not self.outpost:
            # connect() never resolved an outpost, so nothing was joined
            return
        async_to_sync(self.channel_layer.group_discard)(
            build_outpost_group(self.outpost.pk), self.channel_name
        )
        if not self.instance_uid:
            return
        async_to_sync(self.channel_layer.group_discard)(
            build_outpost_group_instance(self.outpost.pk, self.instance_uid),
            self.channel_name,
        )
        GAUGE_OUTPOSTS_CONNECTED.labels(
            tenant=connection.schema_name,
            outpost=self.outpost.name,
            uid=self.instance_uid,
            expected=self.outpost.config.kubernetes_replicas,
        ).dec()

    def receive_json(self, content: Data, **kwargs):
        """Update this instance's OutpostState from an incoming message.

        Raises DenyConnection when no outpost is bound to the consumer. ACK
        messages are neither persisted nor answered. HELLO messages also record
        the reported version/build details and merge any remaining keys into
        ``state.args``. Every non-ACK message is saved and answered with an ACK.
        """
        message = from_dict(WebsocketMessage, content)
        if not self.outpost:
            raise DenyConnection()

        state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
        state.last_seen = datetime.now()
        state.hostname = message.args.pop("hostname", "")

        if message.instruction == WebsocketMessageInstruction.ACK:
            return
        if message.instruction == WebsocketMessageInstruction.HELLO:
            # (state attribute, message key, default when the key is absent)
            hello_fields = (
                ("version", "version", None),
                ("build_hash", "buildHash", ""),
                ("golang_version", "golangVersion", ""),
                ("openssl_enabled", "opensslEnabled", False),
                ("openssl_version", "opensslVersion", ""),
                ("fips_enabled", "fipsEnabled", False),
            )
            for attribute, key, default in hello_fields:
                setattr(state, attribute, message.args.pop(key, default))
            # Whatever the outpost sent beyond the known fields is kept as-is
            state.args.update(message.args)

        GAUGE_OUTPOSTS_LAST_UPDATE.labels(
            tenant=connection.schema_name,
            outpost=self.outpost.name,
            uid=self.instance_uid or "",
            version=state.version or "",
        ).set_to_current_time()
        state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)

        acknowledgement = WebsocketMessage(instruction=WebsocketMessageInstruction.ACK)
        self.send_json(asdict(acknowledgement))

    def event_update(self, event):  # pragma: no cover
        """Tell the outpost to update; dispatched from post_save signals."""
        update = WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)
        self.send_json(asdict(update))

    def event_session_end(self, event):
        """Forward a session-end event; the event dict becomes the message args."""
        session_end = WebsocketMessage(
            instruction=WebsocketMessageInstruction.SESSION_END, args=event
        )
        self.send_json(asdict(session_end))

    def event_provider_specific(self, event):
        """Relay a provider-defined event to the outpost as PROVIDER_SPECIFIC."""
        provider_message = WebsocketMessage(
            instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
        )
        self.send_json(asdict(provider_message))

Handler for Outposts that connect over websockets for health checks and live updates

# Outpost resolved during connect(); None until then
outpost: Outpost | None = None
# structlog logger; bound to the outpost once connect() resolves it
logger: BoundLogger
# Identifier of this outpost instance (query parameter or channel name)
instance_uid: str | None = None


def __init__(self, *args, **kwargs):
    """Run the base consumer initialiser, then attach a structlog logger."""
    super().__init__(*args, **kwargs)
    self.logger = get_logger()
def connect(self):
    """Accept the socket when the user may view the requested outpost.

    Raises DenyConnection if the outpost is not visible to the user or if
    accepting the socket fails with a RuntimeError.
    """
    requested_pk = self.scope["url_route"]["kwargs"]["pk"]
    visible_outposts = get_objects_for_user(
        self.scope["user"], "authentik_outposts.view_outpost"
    )
    self.outpost = visible_outposts.filter(pk=requested_pk).first()
    if self.outpost is None:
        raise DenyConnection()
    self.logger = self.logger.bind(outpost=self.outpost)
    try:
        self.accept()
    except RuntimeError as exc:
        self.logger.warning("runtime error during accept", exc=exc)
        raise DenyConnection() from None
    params = QueryDict(self.scope["query_string"].decode())
    # Without an explicit instance id, the channel name identifies the instance
    self.instance_uid = params.get("instance_uuid", self.channel_name)
    # Join the outpost-wide group first, then this instance's own group
    for group_name in (
        build_outpost_group(self.outpost.pk),
        build_outpost_group_instance(self.outpost.pk, self.instance_uid),
    ):
        async_to_sync(self.channel_layer.group_add)(group_name, self.channel_name)
    GAUGE_OUTPOSTS_CONNECTED.labels(
        tenant=connection.schema_name,
        outpost=self.outpost.name,
        uid=self.instance_uid,
        expected=self.outpost.config.kubernetes_replicas,
    ).inc()
def disconnect(self, code):
    """Leave the channel groups joined in connect() and decrement the gauge."""
    if not self.outpost:
        # connect() never resolved an outpost, so nothing was joined
        return
    async_to_sync(self.channel_layer.group_discard)(
        build_outpost_group(self.outpost.pk), self.channel_name
    )
    if not self.instance_uid:
        return
    async_to_sync(self.channel_layer.group_discard)(
        build_outpost_group_instance(self.outpost.pk, self.instance_uid),
        self.channel_name,
    )
    GAUGE_OUTPOSTS_CONNECTED.labels(
        tenant=connection.schema_name,
        outpost=self.outpost.name,
        uid=self.instance_uid,
        expected=self.outpost.config.kubernetes_replicas,
    ).dec()

Called when a WebSocket connection is closed.

def receive_json(self, content: Data, **kwargs):
    """Update this instance's OutpostState from an incoming message.

    Raises DenyConnection when no outpost is bound to the consumer. ACK
    messages are neither persisted nor answered. HELLO messages also record
    the reported version/build details and merge any remaining keys into
    ``state.args``. Every non-ACK message is saved and answered with an ACK.
    """
    message = from_dict(WebsocketMessage, content)
    if not self.outpost:
        raise DenyConnection()

    state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
    state.last_seen = datetime.now()
    state.hostname = message.args.pop("hostname", "")

    if message.instruction == WebsocketMessageInstruction.ACK:
        return
    if message.instruction == WebsocketMessageInstruction.HELLO:
        # (state attribute, message key, default when the key is absent)
        hello_fields = (
            ("version", "version", None),
            ("build_hash", "buildHash", ""),
            ("golang_version", "golangVersion", ""),
            ("openssl_enabled", "opensslEnabled", False),
            ("openssl_version", "opensslVersion", ""),
            ("fips_enabled", "fipsEnabled", False),
        )
        for attribute, key, default in hello_fields:
            setattr(state, attribute, message.args.pop(key, default))
        # Whatever the outpost sent beyond the known fields is kept as-is
        state.args.update(message.args)

    GAUGE_OUTPOSTS_LAST_UPDATE.labels(
        tenant=connection.schema_name,
        outpost=self.outpost.name,
        uid=self.instance_uid or "",
        version=state.version or "",
    ).set_to_current_time()
    state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)

    acknowledgement = WebsocketMessage(instruction=WebsocketMessageInstruction.ACK)
    self.send_json(asdict(acknowledgement))

Called with decoded JSON content.

def event_update(self, event):  # pragma: no cover
    """Tell the outpost to update; dispatched from post_save signals."""
    update = WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE)
    self.send_json(asdict(update))

Event handler called by post_save signals; sends an update instruction to the outpost.

def event_session_end(self, event):
    """Forward a session-end event; the event dict becomes the message args."""
    session_end = WebsocketMessage(
        instruction=WebsocketMessageInstruction.SESSION_END, args=event
    )
    self.send_json(asdict(session_end))

Event handler which is called when a session is ended

def event_provider_specific(self, event):
    """Relay a provider-defined event to the outpost as PROVIDER_SPECIFIC."""
    provider_message = WebsocketMessage(
        instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
    )
    self.send_json(asdict(provider_message))

Event handler which can be called by provider-specific implementations to send specific messages to the outpost