authentik.providers.rac.consumer_client

RAC Client consumer

  1"""RAC Client consumer"""
  2
  3from hashlib import sha256
  4
  5from asgiref.sync import async_to_sync
  6from channels.db import database_sync_to_async
  7from channels.exceptions import ChannelFull, DenyConnection
  8from channels.generic.websocket import AsyncWebsocketConsumer
  9from django.db import connection
 10from django.http.request import QueryDict
 11from structlog.stdlib import BoundLogger, get_logger
 12
 13from authentik.outposts.consumer import build_outpost_group_instance
 14from authentik.outposts.models import Outpost, OutpostState, OutpostType
 15from authentik.providers.rac.models import ConnectionToken, RACProvider
 16
 17
 18def build_rac_client_group() -> str:
 19    """
 20    Global broadcast group, which messages are sent to when the outpost connects back
 21    to authentik for a specific connection
 22    The `RACClientConsumer` consumer adds itself to this group on connection,
 23    and removes itself once it has been assigned a specific outpost channel
 24    """
 25    return sha256(f"{connection.schema_name}/group_rac_client".encode()).hexdigest()
 26
 27
 28def build_rac_client_group_session(session_key: str) -> str:
 29    """
 30    A group for all connections in a given authentik session ID
 31    A disconnect message is sent to this group when the session expires/is deleted
 32    """
 33    return sha256(f"{connection.schema_name}/group_rac_client_{session_key}".encode()).hexdigest()
 34
 35
 36def build_rac_client_group_token(token: str) -> str:
 37    """
 38    A group for all connections with a specific token, which in almost all cases
 39    is just one connection, however this is used to disconnect the connection
 40    when the token is deleted
 41    """
 42    return sha256(f"{connection.schema_name}/group_rac_token_{token}".encode()).hexdigest()
 43
 44
 45# Step 1: Client connects to this websocket endpoint
 46# Step 2: We prepare all the connection args for Guac
 47# Step 3: Send a websocket message to a single outpost that has this provider assigned
 48#         (Currently sending to all of them)
 49#         (Should probably do different load balancing algorithms)
 50# Step 4: Outpost creates a websocket connection back to authentik
 51#         with /ws/outpost_rac/<our_channel_id>/
 52# Step 5: This consumer transfers data between the two channels
 53
 54
 55class RACClientConsumer(AsyncWebsocketConsumer):
 56    """RAC client consumer the browser connects to"""
 57
 58    dest_channel_id: str = ""
 59    provider: RACProvider
 60    token: ConnectionToken
 61    logger: BoundLogger
 62
 63    async def connect(self):
 64        self.logger = get_logger()
 65        await self.accept("guacamole")
 66        await self.channel_layer.group_add(build_rac_client_group(), self.channel_name)
 67        await self.channel_layer.group_add(
 68            build_rac_client_group_session(self.scope["session"].session_key),
 69            self.channel_name,
 70        )
 71        await self.init_outpost_connection()
 72
 73    async def disconnect(self, code):
 74        self.logger.debug("Disconnecting")
 75        if self.dest_channel_id:
 76            # Tell the outpost we're disconnecting
 77            await self.channel_layer.send(
 78                self.dest_channel_id,
 79                {
 80                    "type": "event.disconnect",
 81                },
 82            )
 83
 84    @database_sync_to_async
 85    def init_outpost_connection(self):
 86        """Initialize guac connection settings"""
 87        self.token = (
 88            ConnectionToken.objects.filter(
 89                token=self.scope["url_route"]["kwargs"]["token"],
 90                session__session__session_key=self.scope["session"].session_key,
 91            )
 92            .select_related("endpoint", "provider", "session", "session__user")
 93            .first()
 94        )
 95        if not self.token:
 96            raise DenyConnection()
 97        self.provider = self.token.provider
 98        params = self.token.get_settings()
 99        self.logger = get_logger().bind(
100            endpoint=self.token.endpoint.name, user=self.scope["user"].username
101        )
102        msg = {
103            "type": "event.provider.specific",
104            "sub_type": "init_connection",
105            "dest_channel_id": self.channel_name,
106            "params": params,
107            "protocol": self.token.endpoint.protocol,
108        }
109        query = QueryDict(self.scope["query_string"].decode())
110        for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
111            value = query.get(key, None)
112            if not value:
113                continue
114            msg[key] = str(value)
115        outposts = Outpost.objects.filter(
116            type=OutpostType.RAC,
117            providers__in=[self.provider],
118        )
119        if not outposts.exists():
120            self.logger.warning("Provider has no outpost")
121            raise DenyConnection()
122        for outpost in outposts:
123            # Sort all states for the outpost by connection count
124            states = sorted(
125                OutpostState.for_outpost(outpost),
126                key=lambda state: int(state.args.get("active_connections", 0)),
127            )
128            if len(states) < 1:
129                continue
130            self.logger.debug("Sending out connection broadcast")
131            group = build_outpost_group_instance(outpost.pk, states[0].uid)
132            async_to_sync(self.channel_layer.group_send)(group, msg)
133        if self.provider and self.provider.delete_token_on_disconnect:
134            self.logger.info("Deleting connection token to prevent reconnect", token=self.token)
135            self.token.delete()
136
137    async def receive(self, text_data=None, bytes_data=None):
138        """Mirror data received from client to the dest_channel_id
139        which is the channel talking to guacd"""
140        if self.dest_channel_id == "":
141            return
142        if self.token.is_expired:
143            await self.event_disconnect({"reason": "token_expiry"})
144            return
145        try:
146            await self.channel_layer.send(
147                self.dest_channel_id,
148                {
149                    "type": "event.send",
150                    "text_data": text_data,
151                    "bytes_data": bytes_data,
152                },
153            )
154        except ChannelFull:
155            pass
156
157    async def event_outpost_connected(self, event: dict):
158        """Handle event broadcasted from outpost consumer, and check if they
159        created a connection for us"""
160        outpost_channel = event.get("outpost_channel")
161        if event.get("client_channel") != self.channel_name:
162            return
163        if self.dest_channel_id != "":
164            # We've already selected an outpost channel, so tell the other channel to disconnect
165            # This should never happen since we remove ourselves from the broadcast group
166            await self.channel_layer.send(
167                outpost_channel,
168                {
169                    "type": "event.disconnect",
170                },
171            )
172            return
173        self.logger.debug("Connected to a single outpost instance")
174        self.dest_channel_id = outpost_channel
175        # Since we have a specific outpost channel now, we can remove
176        # ourselves from the global broadcast group
177        await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)
178
179    async def event_send(self, event: dict):
180        """Handler called by outpost websocket that sends data to this specific
181        client connection"""
182        if self.token.is_expired:
183            await self.event_disconnect({"reason": "token_expiry"})
184            return
185        await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
186
187    async def event_disconnect(self, event: dict):
188        """Disconnect when the session ends"""
189        self.logger.info("Disconnecting RAC connection", reason=event.get("reason"))
190        await self.close()
def build_rac_client_group() -> str:
def build_rac_client_group() -> str:
    """
    Name of the global broadcast group that receives a message whenever an outpost
    connects back to authentik for a specific connection.
    `RACClientConsumer` joins this group when it connects and leaves it again once
    it has been assigned a specific outpost channel.

    The name is the hex SHA-256 digest of the schema-prefixed group identifier.
    """
    raw_name = f"{connection.schema_name}/group_rac_client"
    return sha256(raw_name.encode()).hexdigest()

Global broadcast group to which messages are sent when the outpost connects back to authentik for a specific connection. The RACClientConsumer consumer adds itself to this group on connection, and removes itself once it has been assigned a specific outpost channel.

def build_rac_client_group_session(session_key: str) -> str:
def build_rac_client_group_session(session_key: str) -> str:
    """
    Name of the group holding every connection of one authentik session.
    A disconnect message is sent to this group when the session expires or is deleted.

    The name is the hex SHA-256 digest of the schema-prefixed identifier, which
    embeds `session_key`.
    """
    raw_name = f"{connection.schema_name}/group_rac_client_{session_key}"
    return sha256(raw_name.encode()).hexdigest()

A group for all connections in a given authentik session ID. A disconnect message is sent to this group when the session expires or is deleted.

def build_rac_client_group_token(token: str) -> str:
def build_rac_client_group_token(token: str) -> str:
    """
    Name of the group holding every connection that uses a specific token.
    In almost all cases that is a single connection; the group exists so that the
    connection can be disconnected when the token is deleted.

    The name is the hex SHA-256 digest of the schema-prefixed identifier, which
    embeds `token`.
    """
    raw_name = f"{connection.schema_name}/group_rac_token_{token}"
    return sha256(raw_name.encode()).hexdigest()

A group for all connections with a specific token, which in almost all cases is just one connection; however, it is used to disconnect the connection when the token is deleted.

class RACClientConsumer(channels.generic.websocket.AsyncWebsocketConsumer):
 56class RACClientConsumer(AsyncWebsocketConsumer):
 57    """RAC client consumer the browser connects to"""
 58
 59    dest_channel_id: str = ""
 60    provider: RACProvider
 61    token: ConnectionToken
 62    logger: BoundLogger
 63
 64    async def connect(self):
 65        self.logger = get_logger()
 66        await self.accept("guacamole")
 67        await self.channel_layer.group_add(build_rac_client_group(), self.channel_name)
 68        await self.channel_layer.group_add(
 69            build_rac_client_group_session(self.scope["session"].session_key),
 70            self.channel_name,
 71        )
 72        await self.init_outpost_connection()
 73
 74    async def disconnect(self, code):
 75        self.logger.debug("Disconnecting")
 76        if self.dest_channel_id:
 77            # Tell the outpost we're disconnecting
 78            await self.channel_layer.send(
 79                self.dest_channel_id,
 80                {
 81                    "type": "event.disconnect",
 82                },
 83            )
 84
 85    @database_sync_to_async
 86    def init_outpost_connection(self):
 87        """Initialize guac connection settings"""
 88        self.token = (
 89            ConnectionToken.objects.filter(
 90                token=self.scope["url_route"]["kwargs"]["token"],
 91                session__session__session_key=self.scope["session"].session_key,
 92            )
 93            .select_related("endpoint", "provider", "session", "session__user")
 94            .first()
 95        )
 96        if not self.token:
 97            raise DenyConnection()
 98        self.provider = self.token.provider
 99        params = self.token.get_settings()
100        self.logger = get_logger().bind(
101            endpoint=self.token.endpoint.name, user=self.scope["user"].username
102        )
103        msg = {
104            "type": "event.provider.specific",
105            "sub_type": "init_connection",
106            "dest_channel_id": self.channel_name,
107            "params": params,
108            "protocol": self.token.endpoint.protocol,
109        }
110        query = QueryDict(self.scope["query_string"].decode())
111        for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
112            value = query.get(key, None)
113            if not value:
114                continue
115            msg[key] = str(value)
116        outposts = Outpost.objects.filter(
117            type=OutpostType.RAC,
118            providers__in=[self.provider],
119        )
120        if not outposts.exists():
121            self.logger.warning("Provider has no outpost")
122            raise DenyConnection()
123        for outpost in outposts:
124            # Sort all states for the outpost by connection count
125            states = sorted(
126                OutpostState.for_outpost(outpost),
127                key=lambda state: int(state.args.get("active_connections", 0)),
128            )
129            if len(states) < 1:
130                continue
131            self.logger.debug("Sending out connection broadcast")
132            group = build_outpost_group_instance(outpost.pk, states[0].uid)
133            async_to_sync(self.channel_layer.group_send)(group, msg)
134        if self.provider and self.provider.delete_token_on_disconnect:
135            self.logger.info("Deleting connection token to prevent reconnect", token=self.token)
136            self.token.delete()
137
138    async def receive(self, text_data=None, bytes_data=None):
139        """Mirror data received from client to the dest_channel_id
140        which is the channel talking to guacd"""
141        if self.dest_channel_id == "":
142            return
143        if self.token.is_expired:
144            await self.event_disconnect({"reason": "token_expiry"})
145            return
146        try:
147            await self.channel_layer.send(
148                self.dest_channel_id,
149                {
150                    "type": "event.send",
151                    "text_data": text_data,
152                    "bytes_data": bytes_data,
153                },
154            )
155        except ChannelFull:
156            pass
157
158    async def event_outpost_connected(self, event: dict):
159        """Handle event broadcasted from outpost consumer, and check if they
160        created a connection for us"""
161        outpost_channel = event.get("outpost_channel")
162        if event.get("client_channel") != self.channel_name:
163            return
164        if self.dest_channel_id != "":
165            # We've already selected an outpost channel, so tell the other channel to disconnect
166            # This should never happen since we remove ourselves from the broadcast group
167            await self.channel_layer.send(
168                outpost_channel,
169                {
170                    "type": "event.disconnect",
171                },
172            )
173            return
174        self.logger.debug("Connected to a single outpost instance")
175        self.dest_channel_id = outpost_channel
176        # Since we have a specific outpost channel now, we can remove
177        # ourselves from the global broadcast group
178        await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)
179
180    async def event_send(self, event: dict):
181        """Handler called by outpost websocket that sends data to this specific
182        client connection"""
183        if self.token.is_expired:
184            await self.event_disconnect({"reason": "token_expiry"})
185            return
186        await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
187
188    async def event_disconnect(self, event: dict):
189        """Disconnect when the session ends"""
190        self.logger.info("Disconnecting RAC connection", reason=event.get("reason"))
191        await self.close()

RAC client consumer the browser connects to

dest_channel_id: str = ''
logger: structlog.stdlib.BoundLogger
async def connect(self):
64    async def connect(self):
65        self.logger = get_logger()
66        await self.accept("guacamole")
67        await self.channel_layer.group_add(build_rac_client_group(), self.channel_name)
68        await self.channel_layer.group_add(
69            build_rac_client_group_session(self.scope["session"].session_key),
70            self.channel_name,
71        )
72        await self.init_outpost_connection()
async def disconnect(self, code):
74    async def disconnect(self, code):
75        self.logger.debug("Disconnecting")
76        if self.dest_channel_id:
77            # Tell the outpost we're disconnecting
78            await self.channel_layer.send(
79                self.dest_channel_id,
80                {
81                    "type": "event.disconnect",
82                },
83            )

Called when a WebSocket connection is closed.

@database_sync_to_async
def init_outpost_connection(self):
 85    @database_sync_to_async
 86    def init_outpost_connection(self):
 87        """Initialize guac connection settings"""
 88        self.token = (
 89            ConnectionToken.objects.filter(
 90                token=self.scope["url_route"]["kwargs"]["token"],
 91                session__session__session_key=self.scope["session"].session_key,
 92            )
 93            .select_related("endpoint", "provider", "session", "session__user")
 94            .first()
 95        )
 96        if not self.token:
 97            raise DenyConnection()
 98        self.provider = self.token.provider
 99        params = self.token.get_settings()
100        self.logger = get_logger().bind(
101            endpoint=self.token.endpoint.name, user=self.scope["user"].username
102        )
103        msg = {
104            "type": "event.provider.specific",
105            "sub_type": "init_connection",
106            "dest_channel_id": self.channel_name,
107            "params": params,
108            "protocol": self.token.endpoint.protocol,
109        }
110        query = QueryDict(self.scope["query_string"].decode())
111        for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
112            value = query.get(key, None)
113            if not value:
114                continue
115            msg[key] = str(value)
116        outposts = Outpost.objects.filter(
117            type=OutpostType.RAC,
118            providers__in=[self.provider],
119        )
120        if not outposts.exists():
121            self.logger.warning("Provider has no outpost")
122            raise DenyConnection()
123        for outpost in outposts:
124            # Sort all states for the outpost by connection count
125            states = sorted(
126                OutpostState.for_outpost(outpost),
127                key=lambda state: int(state.args.get("active_connections", 0)),
128            )
129            if len(states) < 1:
130                continue
131            self.logger.debug("Sending out connection broadcast")
132            group = build_outpost_group_instance(outpost.pk, states[0].uid)
133            async_to_sync(self.channel_layer.group_send)(group, msg)
134        if self.provider and self.provider.delete_token_on_disconnect:
135            self.logger.info("Deleting connection token to prevent reconnect", token=self.token)
136            self.token.delete()

Initialize guac connection settings

async def receive(self, text_data=None, bytes_data=None):
138    async def receive(self, text_data=None, bytes_data=None):
139        """Mirror data received from client to the dest_channel_id
140        which is the channel talking to guacd"""
141        if self.dest_channel_id == "":
142            return
143        if self.token.is_expired:
144            await self.event_disconnect({"reason": "token_expiry"})
145            return
146        try:
147            await self.channel_layer.send(
148                self.dest_channel_id,
149                {
150                    "type": "event.send",
151                    "text_data": text_data,
152                    "bytes_data": bytes_data,
153                },
154            )
155        except ChannelFull:
156            pass

Mirror data received from the client to dest_channel_id, which is the channel talking to guacd.

async def event_outpost_connected(self, event: dict):
158    async def event_outpost_connected(self, event: dict):
159        """Handle event broadcasted from outpost consumer, and check if they
160        created a connection for us"""
161        outpost_channel = event.get("outpost_channel")
162        if event.get("client_channel") != self.channel_name:
163            return
164        if self.dest_channel_id != "":
165            # We've already selected an outpost channel, so tell the other channel to disconnect
166            # This should never happen since we remove ourselves from the broadcast group
167            await self.channel_layer.send(
168                outpost_channel,
169                {
170                    "type": "event.disconnect",
171                },
172            )
173            return
174        self.logger.debug("Connected to a single outpost instance")
175        self.dest_channel_id = outpost_channel
176        # Since we have a specific outpost channel now, we can remove
177        # ourselves from the global broadcast group
178        await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)

Handle the event broadcast by the outpost consumer, and check whether it created a connection for us.

async def event_send(self, event: dict):
180    async def event_send(self, event: dict):
181        """Handler called by outpost websocket that sends data to this specific
182        client connection"""
183        if self.token.is_expired:
184            await self.event_disconnect({"reason": "token_expiry"})
185            return
186        await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))

Handler called by outpost websocket that sends data to this specific client connection

async def event_disconnect(self, event: dict):
188    async def event_disconnect(self, event: dict):
189        """Disconnect when the session ends"""
190        self.logger.info("Disconnecting RAC connection", reason=event.get("reason"))
191        await self.close()

Disconnect when the session ends