authentik.outposts.consumer
Outpost websocket handler
"""Outpost websocket handler"""

from dataclasses import asdict, dataclass, field
from datetime import datetime
from enum import IntEnum
from hashlib import sha256
from typing import Any
from uuid import UUID

from asgiref.sync import async_to_sync
from channels.exceptions import DenyConnection
from channels.generic.websocket import JsonWebsocketConsumer
from dacite.core import from_dict
from dacite.data import Data
from django.db import connection
from django.http.request import QueryDict
from guardian.shortcuts import get_objects_for_user
from structlog.stdlib import BoundLogger, get_logger

from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState


def build_outpost_group(outpost_pk: str | UUID) -> str:
    """Return the channel-layer group name shared by every connection of an outpost.

    The name is built from the current database schema (tenant) and the outpost's
    primary key, then SHA-256 hashed into a fixed-length hex string.
    """
    # NOTE(review): hashing presumably keeps the group name within the channel
    # layer's length/charset limits regardless of the schema name — confirm.
    return sha256(f"{connection.schema_name}/group_outpost_{str(outpost_pk)}".encode()).hexdigest()


def build_outpost_group_instance(outpost_pk: str | UUID, instance: str) -> str:
    """Return the channel-layer group name for one specific instance of an outpost.

    Same scheme as ``build_outpost_group``, with the instance identifier appended
    before hashing so that each instance gets its own group.
    """
    return sha256(
        f"{connection.schema_name}/group_outpost_{str(outpost_pk)}_{instance}".encode()
    ).hexdigest()


class WebsocketMessageInstruction(IntEnum):
    """Commands which can be triggered over Websocket"""

    # The integer values are what travels over the wire in WebsocketMessage.instruction;
    # the outpost side presumably relies on these exact numbers — do not renumber.

    # Simple message used by either side when a message is acknowledged
    ACK = 0

    # Message used by outposts to report their alive status
    HELLO = 1

    # Message sent by us to trigger an Update
    TRIGGER_UPDATE = 2

    # Provider specific message
    PROVIDER_SPECIFIC = 3

    # Session ended
    SESSION_END = 4


@dataclass(slots=True)
class WebsocketMessage:
    """A single WebSocket message exchanged with an outpost, in either direction."""

    # Numeric value of a WebsocketMessageInstruction
    instruction: int
    # Instruction-specific payload; each message gets its own dict
    args: dict[str, Any] = field(default_factory=dict)


class OutpostConsumer(JsonWebsocketConsumer):
    """Handler for Outposts that connect over websockets for health checks and live updates"""

    # Outpost this connection belongs to; set in connect()
    outpost: Outpost | None = None
    logger: BoundLogger

    # Identifier of the connected outpost instance; set in connect() from the
    # "instance_uuid" query parameter, falling back to this consumer's channel name
    instance_uid: str | None = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = get_logger()

    def connect(self):
        """Authorize and accept an outpost connection, then join its channel groups.

        The outpost is looked up by the ``pk`` URL route kwarg among the outposts on
        which the connecting user holds ``authentik_outposts.view_outpost``.

        Raises:
            DenyConnection: if no such outpost exists for this user, or if accepting
                the connection raises a ``RuntimeError``.
        """
        uuid = self.scope["url_route"]["kwargs"]["pk"]
        user = self.scope["user"]
        self.outpost: Outpost | None = (
            get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
        )
        if self.outpost is None:
            raise DenyConnection()
        self.logger = self.logger.bind(outpost=self.outpost)
        try:
            self.accept()
        except RuntimeError as exc:
            self.logger.warning("runtime error during accept", exc=exc)
            raise DenyConnection() from None
        query = QueryDict(self.scope["query_string"].decode())
        self.instance_uid = query.get("instance_uuid", self.channel_name)
        # Join the group shared by all instances of this outpost...
        async_to_sync(self.channel_layer.group_add)(
            build_outpost_group(self.outpost.pk), self.channel_name
        )
        # ...and the group that addresses only this instance
        async_to_sync(self.channel_layer.group_add)(
            build_outpost_group_instance(self.outpost.pk, self.instance_uid),
            self.channel_name,
        )
        GAUGE_OUTPOSTS_CONNECTED.labels(
            tenant=connection.schema_name,
            outpost=self.outpost.name,
            uid=self.instance_uid,
            expected=self.outpost.config.kubernetes_replicas,
        ).inc()

    def disconnect(self, code):
        """Leave the channel groups joined in connect() and decrement the connection gauge.

        Each step is guarded because connect() may have denied the connection before
        ``outpost`` or ``instance_uid`` was set.
        """
        if self.outpost:
            async_to_sync(self.channel_layer.group_discard)(
                build_outpost_group(self.outpost.pk), self.channel_name
            )
            if self.instance_uid:
                async_to_sync(self.channel_layer.group_discard)(
                    build_outpost_group_instance(self.outpost.pk, self.instance_uid),
                    self.channel_name,
                )
        if self.outpost and self.instance_uid:
            # Mirrors the labels used for inc() in connect()
            GAUGE_OUTPOSTS_CONNECTED.labels(
                tenant=connection.schema_name,
                outpost=self.outpost.name,
                uid=self.instance_uid,
                expected=self.outpost.config.kubernetes_replicas,
            ).dec()

    def receive_json(self, content: Data, **kwargs):
        """Record a message from the outpost in its instance state and acknowledge it.

        ``HELLO`` messages additionally store the version/build details the outpost
        reports. ``ACK`` messages are ignored: nothing is saved and no reply is sent.
        Every other message refreshes ``last_seen``/``hostname``, is saved, and is
        answered with an ``ACK``.

        Raises:
            DenyConnection: if no outpost is bound to this consumer.
        """
        msg = from_dict(WebsocketMessage, content)
        if not self.outpost:
            raise DenyConnection()

        state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
        state.last_seen = datetime.now()
        state.hostname = msg.args.pop("hostname", "")

        if msg.instruction == WebsocketMessageInstruction.HELLO:
            # Known keys are popped so that only the remaining, unrecognized args
            # are merged into state.args below
            state.version = msg.args.pop("version", None)
            state.build_hash = msg.args.pop("buildHash", "")
            state.golang_version = msg.args.pop("golangVersion", "")
            state.openssl_enabled = msg.args.pop("opensslEnabled", False)
            state.openssl_version = msg.args.pop("opensslVersion", "")
            state.fips_enabled = msg.args.pop("fipsEnabled", False)
            state.args.update(msg.args)
        elif msg.instruction == WebsocketMessageInstruction.ACK:
            return
        GAUGE_OUTPOSTS_LAST_UPDATE.labels(
            tenant=connection.schema_name,
            outpost=self.outpost.name,
            uid=self.instance_uid or "",
            version=state.version or "",
        ).set_to_current_time()
        # NOTE(review): the timeout presumably lets the stored state expire if the
        # outpost stays silent for 1.5 hello intervals — confirm in OutpostState.save
        state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)

        response = WebsocketMessage(instruction=WebsocketMessageInstruction.ACK)
        self.send_json(asdict(response))

    def event_update(self, event):  # pragma: no cover
        """Event handler called by post_save signals; sends a TRIGGER_UPDATE instruction.

        The event payload itself is not used.
        """
        self.send_json(
            asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
        )

    def event_session_end(self, event):
        """Event handler called when a session is ended.

        The channel-layer event is forwarded unchanged as the ``args`` of a
        ``SESSION_END`` message.
        """
        self.send_json(
            asdict(
                WebsocketMessage(instruction=WebsocketMessageInstruction.SESSION_END, args=event)
            )
        )

    def event_provider_specific(self, event):
        """Event handler which can be called by provider-specific
        implementations to send specific messages to the outpost.

        The channel-layer event is forwarded unchanged as the ``args`` of a
        ``PROVIDER_SPECIFIC`` message.
        """
        self.send_json(
            asdict(
                WebsocketMessage(
                    instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
                )
            )
        )
def
build_outpost_group(outpost_pk: str | uuid.UUID) -> str:
def
build_outpost_group_instance(outpost_pk: str | uuid.UUID, instance: str) -> str:
class
WebsocketMessageInstruction(enum.IntEnum):
class WebsocketMessageInstruction(IntEnum):
    """Commands which can be triggered over Websocket"""

    # The integer values are what travels over the wire in WebsocketMessage.instruction;
    # the outpost side presumably relies on these exact numbers — do not renumber.

    # Simple message used by either side when a message is acknowledged
    ACK = 0

    # Message used by outposts to report their alive status
    HELLO = 1

    # Message sent by us to trigger an Update
    TRIGGER_UPDATE = 2

    # Provider specific message
    PROVIDER_SPECIFIC = 3

    # Session ended
    SESSION_END = 4
Commands which can be triggered over Websocket
ACK =
<WebsocketMessageInstruction.ACK: 0>
HELLO =
<WebsocketMessageInstruction.HELLO: 1>
TRIGGER_UPDATE =
<WebsocketMessageInstruction.TRIGGER_UPDATE: 2>
PROVIDER_SPECIFIC =
<WebsocketMessageInstruction.PROVIDER_SPECIFIC: 3>
SESSION_END =
<WebsocketMessageInstruction.SESSION_END: 4>
@dataclass(slots=True)
class
WebsocketMessage:
54@dataclass(slots=True) 55class WebsocketMessage: 56 """Complete Websocket Message that is being sent""" 57 58 instruction: int 59 args: dict[str, Any] = field(default_factory=dict)
Complete Websocket Message that is being sent
class
OutpostConsumer(channels.generic.websocket.JsonWebsocketConsumer):
class OutpostConsumer(JsonWebsocketConsumer):
    """Handler for Outposts that connect over websockets for health checks and live updates"""

    # Outpost this connection belongs to; set in connect()
    outpost: Outpost | None = None
    logger: BoundLogger

    # Identifier of the connected outpost instance; set in connect() from the
    # "instance_uuid" query parameter, falling back to this consumer's channel name
    instance_uid: str | None = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger = get_logger()

    def connect(self):
        """Authorize and accept an outpost connection, then join its channel groups.

        The outpost is looked up by the ``pk`` URL route kwarg among the outposts on
        which the connecting user holds ``authentik_outposts.view_outpost``.

        Raises:
            DenyConnection: if no such outpost exists for this user, or if accepting
                the connection raises a ``RuntimeError``.
        """
        uuid = self.scope["url_route"]["kwargs"]["pk"]
        user = self.scope["user"]
        self.outpost: Outpost | None = (
            get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
        )
        if self.outpost is None:
            raise DenyConnection()
        self.logger = self.logger.bind(outpost=self.outpost)
        try:
            self.accept()
        except RuntimeError as exc:
            self.logger.warning("runtime error during accept", exc=exc)
            raise DenyConnection() from None
        query = QueryDict(self.scope["query_string"].decode())
        self.instance_uid = query.get("instance_uuid", self.channel_name)
        # Join the group shared by all instances of this outpost...
        async_to_sync(self.channel_layer.group_add)(
            build_outpost_group(self.outpost.pk), self.channel_name
        )
        # ...and the group that addresses only this instance
        async_to_sync(self.channel_layer.group_add)(
            build_outpost_group_instance(self.outpost.pk, self.instance_uid),
            self.channel_name,
        )
        GAUGE_OUTPOSTS_CONNECTED.labels(
            tenant=connection.schema_name,
            outpost=self.outpost.name,
            uid=self.instance_uid,
            expected=self.outpost.config.kubernetes_replicas,
        ).inc()

    def disconnect(self, code):
        """Leave the channel groups joined in connect() and decrement the connection gauge.

        Each step is guarded because connect() may have denied the connection before
        ``outpost`` or ``instance_uid`` was set.
        """
        if self.outpost:
            async_to_sync(self.channel_layer.group_discard)(
                build_outpost_group(self.outpost.pk), self.channel_name
            )
            if self.instance_uid:
                async_to_sync(self.channel_layer.group_discard)(
                    build_outpost_group_instance(self.outpost.pk, self.instance_uid),
                    self.channel_name,
                )
        if self.outpost and self.instance_uid:
            # Mirrors the labels used for inc() in connect()
            GAUGE_OUTPOSTS_CONNECTED.labels(
                tenant=connection.schema_name,
                outpost=self.outpost.name,
                uid=self.instance_uid,
                expected=self.outpost.config.kubernetes_replicas,
            ).dec()

    def receive_json(self, content: Data, **kwargs):
        """Record a message from the outpost in its instance state and acknowledge it.

        ``HELLO`` messages additionally store the version/build details the outpost
        reports. ``ACK`` messages are ignored: nothing is saved and no reply is sent.
        Every other message refreshes ``last_seen``/``hostname``, is saved, and is
        answered with an ``ACK``.

        Raises:
            DenyConnection: if no outpost is bound to this consumer.
        """
        msg = from_dict(WebsocketMessage, content)
        if not self.outpost:
            raise DenyConnection()

        state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
        state.last_seen = datetime.now()
        state.hostname = msg.args.pop("hostname", "")

        if msg.instruction == WebsocketMessageInstruction.HELLO:
            # Known keys are popped so that only the remaining, unrecognized args
            # are merged into state.args below
            state.version = msg.args.pop("version", None)
            state.build_hash = msg.args.pop("buildHash", "")
            state.golang_version = msg.args.pop("golangVersion", "")
            state.openssl_enabled = msg.args.pop("opensslEnabled", False)
            state.openssl_version = msg.args.pop("opensslVersion", "")
            state.fips_enabled = msg.args.pop("fipsEnabled", False)
            state.args.update(msg.args)
        elif msg.instruction == WebsocketMessageInstruction.ACK:
            return
        GAUGE_OUTPOSTS_LAST_UPDATE.labels(
            tenant=connection.schema_name,
            outpost=self.outpost.name,
            uid=self.instance_uid or "",
            version=state.version or "",
        ).set_to_current_time()
        # NOTE(review): the timeout presumably lets the stored state expire if the
        # outpost stays silent for 1.5 hello intervals — confirm in OutpostState.save
        state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)

        response = WebsocketMessage(instruction=WebsocketMessageInstruction.ACK)
        self.send_json(asdict(response))

    def event_update(self, event):  # pragma: no cover
        """Event handler called by post_save signals; sends a TRIGGER_UPDATE instruction.

        The event payload itself is not used.
        """
        self.send_json(
            asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
        )

    def event_session_end(self, event):
        """Event handler called when a session is ended.

        The channel-layer event is forwarded unchanged as the ``args`` of a
        ``SESSION_END`` message.
        """
        self.send_json(
            asdict(
                WebsocketMessage(instruction=WebsocketMessageInstruction.SESSION_END, args=event)
            )
        )

    def event_provider_specific(self, event):
        """Event handler which can be called by provider-specific
        implementations to send specific messages to the outpost.

        The channel-layer event is forwarded unchanged as the ``args`` of a
        ``PROVIDER_SPECIFIC`` message.
        """
        self.send_json(
            asdict(
                WebsocketMessage(
                    instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
                )
            )
        )
Handler for Outposts that connect over websockets for health checks and live updates
def
connect(self):
def connect(self):
    """Authorize and accept an outpost connection, then join its channel groups.

    The outpost is looked up by the ``pk`` URL route kwarg among the outposts on
    which the connecting user holds ``authentik_outposts.view_outpost``.

    Raises:
        DenyConnection: if no such outpost exists for this user, or if accepting
            the connection raises a ``RuntimeError``.
    """
    uuid = self.scope["url_route"]["kwargs"]["pk"]
    user = self.scope["user"]
    self.outpost: Outpost | None = (
        get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
    )
    if self.outpost is None:
        raise DenyConnection()
    self.logger = self.logger.bind(outpost=self.outpost)
    try:
        self.accept()
    except RuntimeError as exc:
        self.logger.warning("runtime error during accept", exc=exc)
        raise DenyConnection() from None
    query = QueryDict(self.scope["query_string"].decode())
    self.instance_uid = query.get("instance_uuid", self.channel_name)
    # Join the group shared by all instances of this outpost...
    async_to_sync(self.channel_layer.group_add)(
        build_outpost_group(self.outpost.pk), self.channel_name
    )
    # ...and the group that addresses only this instance
    async_to_sync(self.channel_layer.group_add)(
        build_outpost_group_instance(self.outpost.pk, self.instance_uid),
        self.channel_name,
    )
    GAUGE_OUTPOSTS_CONNECTED.labels(
        tenant=connection.schema_name,
        outpost=self.outpost.name,
        uid=self.instance_uid,
        expected=self.outpost.config.kubernetes_replicas,
    ).inc()
def
disconnect(self, code):
def disconnect(self, code):
    """Leave the channel groups joined in connect() and decrement the connection gauge.

    Each step is guarded because connect() may have denied the connection before
    ``outpost`` or ``instance_uid`` was set.
    """
    if self.outpost:
        async_to_sync(self.channel_layer.group_discard)(
            build_outpost_group(self.outpost.pk), self.channel_name
        )
        if self.instance_uid:
            async_to_sync(self.channel_layer.group_discard)(
                build_outpost_group_instance(self.outpost.pk, self.instance_uid),
                self.channel_name,
            )
    if self.outpost and self.instance_uid:
        # Mirrors the labels used for inc() in connect()
        GAUGE_OUTPOSTS_CONNECTED.labels(
            tenant=connection.schema_name,
            outpost=self.outpost.name,
            uid=self.instance_uid,
            expected=self.outpost.config.kubernetes_replicas,
        ).dec()
Called when a WebSocket connection is closed.
def
receive_json(self, content: dacite.data.Data, **kwargs):
def receive_json(self, content: Data, **kwargs):
    """Record a message from the outpost in its instance state and acknowledge it.

    ``HELLO`` messages additionally store the version/build details the outpost
    reports. ``ACK`` messages are ignored: nothing is saved and no reply is sent.
    Every other message refreshes ``last_seen``/``hostname``, is saved, and is
    answered with an ``ACK``.

    Raises:
        DenyConnection: if no outpost is bound to this consumer.
    """
    msg = from_dict(WebsocketMessage, content)
    if not self.outpost:
        raise DenyConnection()

    state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
    state.last_seen = datetime.now()
    state.hostname = msg.args.pop("hostname", "")

    if msg.instruction == WebsocketMessageInstruction.HELLO:
        # Known keys are popped so that only the remaining, unrecognized args
        # are merged into state.args below
        state.version = msg.args.pop("version", None)
        state.build_hash = msg.args.pop("buildHash", "")
        state.golang_version = msg.args.pop("golangVersion", "")
        state.openssl_enabled = msg.args.pop("opensslEnabled", False)
        state.openssl_version = msg.args.pop("opensslVersion", "")
        state.fips_enabled = msg.args.pop("fipsEnabled", False)
        state.args.update(msg.args)
    elif msg.instruction == WebsocketMessageInstruction.ACK:
        return
    GAUGE_OUTPOSTS_LAST_UPDATE.labels(
        tenant=connection.schema_name,
        outpost=self.outpost.name,
        uid=self.instance_uid or "",
        version=state.version or "",
    ).set_to_current_time()
    # NOTE(review): the timeout presumably lets the stored state expire if the
    # outpost stays silent for 1.5 hello intervals — confirm in OutpostState.save
    state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)

    response = WebsocketMessage(instruction=WebsocketMessageInstruction.ACK)
    self.send_json(asdict(response))
Called with decoded JSON content.
def
event_update(self, event):
def event_update(self, event):  # pragma: no cover
    """Event handler called by post_save signals; sends a TRIGGER_UPDATE instruction.

    The event payload itself is not used.
    """
    self.send_json(
        asdict(WebsocketMessage(instruction=WebsocketMessageInstruction.TRIGGER_UPDATE))
    )
Event handler which is called by post_save signals; sends an update instruction to the outpost.
def
event_session_end(self, event):
def event_session_end(self, event):
    """Event handler called when a session is ended.

    The channel-layer event is forwarded unchanged as the ``args`` of a
    ``SESSION_END`` message.
    """
    self.send_json(
        asdict(
            WebsocketMessage(instruction=WebsocketMessageInstruction.SESSION_END, args=event)
        )
    )
Event handler which is called when a session is ended
def
event_provider_specific(self, event):
def event_provider_specific(self, event):
    """Event handler which can be called by provider-specific
    implementations to send specific messages to the outpost.

    The channel-layer event is forwarded unchanged as the ``args`` of a
    ``PROVIDER_SPECIFIC`` message.
    """
    self.send_json(
        asdict(
            WebsocketMessage(
                instruction=WebsocketMessageInstruction.PROVIDER_SPECIFIC, args=event
            )
        )
    )
Event handler which can be called by provider-specific implementations to send specific messages to the outpost