authentik.providers.rac.consumer_client
RAC Client consumer
"""RAC Client consumer"""

from hashlib import sha256

from asgiref.sync import async_to_sync
from channels.db import database_sync_to_async
from channels.exceptions import ChannelFull, DenyConnection
from channels.generic.websocket import AsyncWebsocketConsumer
from django.db import connection
from django.http.request import QueryDict
from structlog.stdlib import BoundLogger, get_logger

from authentik.outposts.consumer import build_outpost_group_instance
from authentik.outposts.models import Outpost, OutpostState, OutpostType
from authentik.providers.rac.models import ConnectionToken, RACProvider


def build_rac_client_group() -> str:
    """
    Global broadcast group to which messages are sent when the outpost connects back
    to authentik for a specific connection.
    The `RACClientConsumer` adds itself to this group on connection, and removes itself
    once it has been assigned a specific outpost channel (or when it disconnects).
    """
    return sha256(f"{connection.schema_name}/group_rac_client".encode()).hexdigest()


def build_rac_client_group_session(session_key: str) -> str:
    """
    A group for all connections in a given authentik session.
    A disconnect message is sent to this group when the session expires or is deleted.
    """
    return sha256(f"{connection.schema_name}/group_rac_client_{session_key}".encode()).hexdigest()


def build_rac_client_group_token(token: str) -> str:
    """
    A group for all connections with a specific token. In almost all cases this is
    just one connection; the group is used to disconnect the connection when the
    token is deleted.
    `RACClientConsumer` only joins this group when the token is kept after connecting
    (i.e. the provider does not delete tokens on connect).
    """
    return sha256(f"{connection.schema_name}/group_rac_token_{token}".encode()).hexdigest()


# Step 1: Client connects to this websocket endpoint
# Step 2: We prepare all the connection args for Guac
# Step 3: Send a websocket message to a single outpost that has this provider assigned
#         (Currently sending to all of them)
#         (Should probably do different load balancing algorithms)
# Step 4: Outpost creates a websocket connection back to authentik
#         with /ws/outpost_rac/<our_channel_id>/
# Step 5: This consumer transfers data between the two channels


class RACClientConsumer(AsyncWebsocketConsumer):
    """RAC client consumer the browser connects to"""

    # Channel of the outpost instance we are paired with; "" until an outpost connects back
    dest_channel_id: str = ""
    provider: RACProvider
    token: ConnectionToken
    logger: BoundLogger

    async def connect(self):
        """Accept the websocket, join the client groups and ask an outpost to connect back.

        Raises:
            DenyConnection: propagated from `init_outpost_connection` when the token is
                invalid or no outpost instance can serve the connection.
        """
        self.logger = get_logger()
        await self.accept("guacamole")
        await self.channel_layer.group_add(build_rac_client_group(), self.channel_name)
        await self.channel_layer.group_add(
            build_rac_client_group_session(self.scope["session"].session_key),
            self.channel_name,
        )
        await self.init_outpost_connection()
        # Join the per-token group so that deleting the token can disconnect this client.
        # Skipped when the token was already deleted in `init_outpost_connection`, since
        # there is no token left whose deletion could be signalled.
        # NOTE(review): relies on a handler elsewhere sending "event.disconnect" to this
        # group on token deletion, as described on `build_rac_client_group_token`.
        if not (self.provider and self.provider.delete_token_on_disconnect):
            await self.channel_layer.group_add(
                build_rac_client_group_token(self.token.token),
                self.channel_name,
            )

    async def disconnect(self, code):
        """Notify the paired outpost (if any) and leave every group joined in `connect`."""
        self.logger.debug("Disconnecting")
        if self.dest_channel_id:
            # Tell the outpost we're disconnecting
            await self.channel_layer.send(
                self.dest_channel_id,
                {
                    "type": "event.disconnect",
                },
            )
        # Groups joined via group_add are not left automatically when the socket closes;
        # without this, the dead channel stays in each group until the channel layer's
        # group expiry and keeps receiving broadcasts. Discarding a group we never joined
        # is a no-op.
        await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)
        await self.channel_layer.group_discard(
            build_rac_client_group_session(self.scope["session"].session_key),
            self.channel_name,
        )
        # `token` may be unset or None if connection setup failed early
        token = getattr(self, "token", None)
        if token:
            await self.channel_layer.group_discard(
                build_rac_client_group_token(token.token),
                self.channel_name,
            )

    @database_sync_to_async
    def init_outpost_connection(self):
        """Initialize guac connection settings and ask outpost instances to connect back.

        Looks up the connection token from the URL route for the current session, builds
        the `init_connection` message (including optional screen/audio query parameters)
        and sends it to the instance with the fewest active connections of every RAC
        outpost serving the token's provider.

        Raises:
            DenyConnection: if no token matches this session, the provider has no RAC
                outpost, or none of its outposts has a known instance to send to.
        """
        self.token = (
            ConnectionToken.objects.filter(
                token=self.scope["url_route"]["kwargs"]["token"],
                session__session__session_key=self.scope["session"].session_key,
            )
            .select_related("endpoint", "provider", "session", "session__user")
            .first()
        )
        if not self.token:
            raise DenyConnection()
        self.provider = self.token.provider
        params = self.token.get_settings()
        self.logger = get_logger().bind(
            endpoint=self.token.endpoint.name, user=self.scope["user"].username
        )
        msg = {
            "type": "event.provider.specific",
            "sub_type": "init_connection",
            "dest_channel_id": self.channel_name,
            "params": params,
            "protocol": self.token.endpoint.protocol,
        }
        query = QueryDict(self.scope["query_string"].decode())
        for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
            value = query.get(key, None)
            if not value:
                continue
            msg[key] = str(value)
        outposts = Outpost.objects.filter(
            type=OutpostType.RAC,
            providers__in=[self.provider],
        )
        if not outposts.exists():
            self.logger.warning("Provider has no outpost")
            raise DenyConnection()
        sent = False
        for outpost in outposts:
            # Pick the state with the lowest connection count (first one on ties)
            state = min(
                OutpostState.for_outpost(outpost),
                key=lambda instance_state: int(
                    instance_state.args.get("active_connections", 0)
                ),
                default=None,
            )
            if state is None:
                # No known instance for this outpost, nothing to send to
                continue
            self.logger.debug("Sending out connection broadcast")
            group = build_outpost_group_instance(outpost.pk, state.uid)
            async_to_sync(self.channel_layer.group_send)(group, msg)
            sent = True
        if not sent:
            # Nobody can connect back to us: fail fast instead of leaving the client on an
            # open socket forever, and don't consume a one-shot token for nothing.
            self.logger.warning("Provider has no outpost instance available")
            raise DenyConnection()
        if self.provider and self.provider.delete_token_on_disconnect:
            self.logger.info("Deleting connection token to prevent reconnect", token=self.token)
            self.token.delete()

    async def receive(self, text_data=None, bytes_data=None):
        """Mirror data received from the client to `dest_channel_id`, which is the
        channel talking to guacd.

        Data arriving before an outpost channel has been assigned is discarded; an
        expired token disconnects the client instead of forwarding.
        """
        if self.dest_channel_id == "":
            return
        if self.token.is_expired:
            await self.event_disconnect({"reason": "token_expiry"})
            return
        try:
            await self.channel_layer.send(
                self.dest_channel_id,
                {
                    "type": "event.send",
                    "text_data": text_data,
                    "bytes_data": bytes_data,
                },
            )
        except ChannelFull:
            # Best-effort: drop this frame rather than tearing down the connection,
            # but leave a trace so back-pressure problems can be diagnosed.
            self.logger.debug("Outpost channel is full, dropping client data")

    async def event_outpost_connected(self, event: dict):
        """Handle an event broadcast from the outpost consumer, and check whether it
        created a connection for us"""
        outpost_channel = event.get("outpost_channel")
        if event.get("client_channel") != self.channel_name:
            return
        if self.dest_channel_id != "":
            # We've already selected an outpost channel, so tell the other channel to disconnect
            # This should never happen since we remove ourselves from the broadcast group
            await self.channel_layer.send(
                outpost_channel,
                {
                    "type": "event.disconnect",
                },
            )
            return
        self.logger.debug("Connected to a single outpost instance")
        self.dest_channel_id = outpost_channel
        # Since we have a specific outpost channel now, we can remove
        # ourselves from the global broadcast group
        await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)

    async def event_send(self, event: dict):
        """Handler called by the outpost websocket to send data to this specific
        client connection"""
        if self.token.is_expired:
            await self.event_disconnect({"reason": "token_expiry"})
            return
        await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))

    async def event_disconnect(self, event: dict):
        """Disconnect when the session ends"""
        self.logger.info("Disconnecting RAC connection", reason=event.get("reason"))
        await self.close()
def build_rac_client_group() -> str:
    """Return the name of the global RAC client broadcast group.

    Messages are broadcast to this group when an outpost connects back to
    authentik for a specific connection. `RACClientConsumer` joins it when a
    client connects and leaves it once paired with a specific outpost channel.

    The name is the SHA-256 hex digest of a key prefixed with the current
    database connection's schema name. Presumably the hashing keeps the name
    short and within the channel layer's allowed characters.
    """
    group_key = f"{connection.schema_name}/group_rac_client"
    return sha256(group_key.encode()).hexdigest()
Global broadcast group to which messages are sent when the outpost connects back
to authentik for a specific connection.
The RACClientConsumer adds itself to this group on connection,
and removes itself once it has been assigned a specific outpost channel.
def build_rac_client_group_session(session_key: str) -> str:
    """Return the group name shared by all RAC connections of one authentik session.

    A disconnect message is sent to this group when the session expires or is
    deleted.

    Args:
        session_key: key of the authentik session the connections belong to.

    Returns:
        The SHA-256 hex digest of a schema-prefixed, per-session key.
    """
    schema = connection.schema_name
    group_key = f"{schema}/group_rac_client_{session_key}"
    digest = sha256(group_key.encode())
    return digest.hexdigest()
A group for all connections in a given authentik session. A disconnect message is sent to this group when the session expires or is deleted.
def build_rac_client_group_token(token: str) -> str:
    """Return the group name shared by all RAC connections using one connection token.

    In almost all cases only a single connection uses a token. The group exists
    so that the connection can be disconnected when the token is deleted.

    Args:
        token: the connection token value.

    Returns:
        The SHA-256 hex digest of a schema-prefixed, per-token key.
    """
    schema = connection.schema_name
    group_key = f"{schema}/group_rac_token_{token}"
    digest = sha256(group_key.encode())
    return digest.hexdigest()
A group for all connections with a specific token. In almost all cases this is just one connection; the group is used to disconnect the connection when the token is deleted.
class RACClientConsumer(AsyncWebsocketConsumer):
    """RAC client consumer the browser connects to"""

    # Channel of the outpost instance we are paired with; "" until an outpost connects back
    dest_channel_id: str = ""
    provider: RACProvider
    token: ConnectionToken
    logger: BoundLogger

    async def connect(self):
        """Accept the websocket, join the client groups and ask an outpost to connect back.

        Raises:
            DenyConnection: propagated from `init_outpost_connection` when the token is
                invalid or no outpost instance can serve the connection.
        """
        self.logger = get_logger()
        await self.accept("guacamole")
        await self.channel_layer.group_add(build_rac_client_group(), self.channel_name)
        await self.channel_layer.group_add(
            build_rac_client_group_session(self.scope["session"].session_key),
            self.channel_name,
        )
        await self.init_outpost_connection()
        # Join the per-token group so that deleting the token can disconnect this client.
        # Skipped when the token was already deleted in `init_outpost_connection`, since
        # there is no token left whose deletion could be signalled.
        # NOTE(review): relies on a handler elsewhere sending "event.disconnect" to this
        # group on token deletion, as described on `build_rac_client_group_token`.
        if not (self.provider and self.provider.delete_token_on_disconnect):
            await self.channel_layer.group_add(
                build_rac_client_group_token(self.token.token),
                self.channel_name,
            )

    async def disconnect(self, code):
        """Notify the paired outpost (if any) and leave every group joined in `connect`."""
        self.logger.debug("Disconnecting")
        if self.dest_channel_id:
            # Tell the outpost we're disconnecting
            await self.channel_layer.send(
                self.dest_channel_id,
                {
                    "type": "event.disconnect",
                },
            )
        # Groups joined via group_add are not left automatically when the socket closes;
        # without this, the dead channel stays in each group until the channel layer's
        # group expiry and keeps receiving broadcasts. Discarding a group we never joined
        # is a no-op.
        await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)
        await self.channel_layer.group_discard(
            build_rac_client_group_session(self.scope["session"].session_key),
            self.channel_name,
        )
        # `token` may be unset or None if connection setup failed early
        token = getattr(self, "token", None)
        if token:
            await self.channel_layer.group_discard(
                build_rac_client_group_token(token.token),
                self.channel_name,
            )

    @database_sync_to_async
    def init_outpost_connection(self):
        """Initialize guac connection settings and ask outpost instances to connect back.

        Looks up the connection token from the URL route for the current session, builds
        the `init_connection` message (including optional screen/audio query parameters)
        and sends it to the instance with the fewest active connections of every RAC
        outpost serving the token's provider.

        Raises:
            DenyConnection: if no token matches this session, the provider has no RAC
                outpost, or none of its outposts has a known instance to send to.
        """
        self.token = (
            ConnectionToken.objects.filter(
                token=self.scope["url_route"]["kwargs"]["token"],
                session__session__session_key=self.scope["session"].session_key,
            )
            .select_related("endpoint", "provider", "session", "session__user")
            .first()
        )
        if not self.token:
            raise DenyConnection()
        self.provider = self.token.provider
        params = self.token.get_settings()
        self.logger = get_logger().bind(
            endpoint=self.token.endpoint.name, user=self.scope["user"].username
        )
        msg = {
            "type": "event.provider.specific",
            "sub_type": "init_connection",
            "dest_channel_id": self.channel_name,
            "params": params,
            "protocol": self.token.endpoint.protocol,
        }
        query = QueryDict(self.scope["query_string"].decode())
        for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
            value = query.get(key, None)
            if not value:
                continue
            msg[key] = str(value)
        outposts = Outpost.objects.filter(
            type=OutpostType.RAC,
            providers__in=[self.provider],
        )
        if not outposts.exists():
            self.logger.warning("Provider has no outpost")
            raise DenyConnection()
        sent = False
        for outpost in outposts:
            # Pick the state with the lowest connection count (first one on ties)
            state = min(
                OutpostState.for_outpost(outpost),
                key=lambda instance_state: int(
                    instance_state.args.get("active_connections", 0)
                ),
                default=None,
            )
            if state is None:
                # No known instance for this outpost, nothing to send to
                continue
            self.logger.debug("Sending out connection broadcast")
            group = build_outpost_group_instance(outpost.pk, state.uid)
            async_to_sync(self.channel_layer.group_send)(group, msg)
            sent = True
        if not sent:
            # Nobody can connect back to us: fail fast instead of leaving the client on an
            # open socket forever, and don't consume a one-shot token for nothing.
            self.logger.warning("Provider has no outpost instance available")
            raise DenyConnection()
        if self.provider and self.provider.delete_token_on_disconnect:
            self.logger.info("Deleting connection token to prevent reconnect", token=self.token)
            self.token.delete()

    async def receive(self, text_data=None, bytes_data=None):
        """Mirror data received from the client to `dest_channel_id`, which is the
        channel talking to guacd.

        Data arriving before an outpost channel has been assigned is discarded; an
        expired token disconnects the client instead of forwarding.
        """
        if self.dest_channel_id == "":
            return
        if self.token.is_expired:
            await self.event_disconnect({"reason": "token_expiry"})
            return
        try:
            await self.channel_layer.send(
                self.dest_channel_id,
                {
                    "type": "event.send",
                    "text_data": text_data,
                    "bytes_data": bytes_data,
                },
            )
        except ChannelFull:
            # Best-effort: drop this frame rather than tearing down the connection,
            # but leave a trace so back-pressure problems can be diagnosed.
            self.logger.debug("Outpost channel is full, dropping client data")

    async def event_outpost_connected(self, event: dict):
        """Handle an event broadcast from the outpost consumer, and check whether it
        created a connection for us"""
        outpost_channel = event.get("outpost_channel")
        if event.get("client_channel") != self.channel_name:
            return
        if self.dest_channel_id != "":
            # We've already selected an outpost channel, so tell the other channel to disconnect
            # This should never happen since we remove ourselves from the broadcast group
            await self.channel_layer.send(
                outpost_channel,
                {
                    "type": "event.disconnect",
                },
            )
            return
        self.logger.debug("Connected to a single outpost instance")
        self.dest_channel_id = outpost_channel
        # Since we have a specific outpost channel now, we can remove
        # ourselves from the global broadcast group
        await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)

    async def event_send(self, event: dict):
        """Handler called by the outpost websocket to send data to this specific
        client connection"""
        if self.token.is_expired:
            await self.event_disconnect({"reason": "token_expiry"})
            return
        await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))

    async def event_disconnect(self, event: dict):
        """Disconnect when the session ends"""
        self.logger.info("Disconnecting RAC connection", reason=event.get("reason"))
        await self.close()
RAC client consumer the browser connects to
async def connect(self):
    """Accept the websocket, join the client groups and ask an outpost to connect back.

    Raises:
        DenyConnection: propagated from `init_outpost_connection` when the token is
            invalid or no outpost instance can serve the connection.
    """
    self.logger = get_logger()
    await self.accept("guacamole")
    await self.channel_layer.group_add(build_rac_client_group(), self.channel_name)
    await self.channel_layer.group_add(
        build_rac_client_group_session(self.scope["session"].session_key),
        self.channel_name,
    )
    await self.init_outpost_connection()
    # Join the per-token group so that deleting the token can disconnect this client.
    # Skipped when the token was already deleted in `init_outpost_connection`, since
    # there is no token left whose deletion could be signalled.
    # NOTE(review): relies on a handler elsewhere sending "event.disconnect" to this
    # group on token deletion, as described on `build_rac_client_group_token`.
    if not (self.provider and self.provider.delete_token_on_disconnect):
        await self.channel_layer.group_add(
            build_rac_client_group_token(self.token.token),
            self.channel_name,
        )
async def disconnect(self, code):
    """Notify the paired outpost (if any) and leave every group joined in `connect`."""
    self.logger.debug("Disconnecting")
    if self.dest_channel_id:
        # Tell the outpost we're disconnecting
        await self.channel_layer.send(
            self.dest_channel_id,
            {
                "type": "event.disconnect",
            },
        )
    # Groups joined via group_add are not left automatically when the socket closes;
    # without this, the dead channel stays in each group until the channel layer's
    # group expiry and keeps receiving broadcasts. Discarding a group we never joined
    # is a no-op.
    await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)
    await self.channel_layer.group_discard(
        build_rac_client_group_session(self.scope["session"].session_key),
        self.channel_name,
    )
    # `token` may be unset or None if connection setup failed early
    token = getattr(self, "token", None)
    if token:
        await self.channel_layer.group_discard(
            build_rac_client_group_token(token.token),
            self.channel_name,
        )
Called when a WebSocket connection is closed.
@database_sync_to_async
def init_outpost_connection(self):
    """Initialize guac connection settings and ask outpost instances to connect back.

    Looks up the connection token from the URL route for the current session, builds
    the `init_connection` message (including optional screen/audio query parameters)
    and sends it to the instance with the fewest active connections of every RAC
    outpost serving the token's provider.

    Raises:
        DenyConnection: if no token matches this session, the provider has no RAC
            outpost, or none of its outposts has a known instance to send to.
    """
    self.token = (
        ConnectionToken.objects.filter(
            token=self.scope["url_route"]["kwargs"]["token"],
            session__session__session_key=self.scope["session"].session_key,
        )
        .select_related("endpoint", "provider", "session", "session__user")
        .first()
    )
    if not self.token:
        raise DenyConnection()
    self.provider = self.token.provider
    params = self.token.get_settings()
    self.logger = get_logger().bind(
        endpoint=self.token.endpoint.name, user=self.scope["user"].username
    )
    msg = {
        "type": "event.provider.specific",
        "sub_type": "init_connection",
        "dest_channel_id": self.channel_name,
        "params": params,
        "protocol": self.token.endpoint.protocol,
    }
    query = QueryDict(self.scope["query_string"].decode())
    for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
        value = query.get(key, None)
        if not value:
            continue
        msg[key] = str(value)
    outposts = Outpost.objects.filter(
        type=OutpostType.RAC,
        providers__in=[self.provider],
    )
    if not outposts.exists():
        self.logger.warning("Provider has no outpost")
        raise DenyConnection()
    sent = False
    for outpost in outposts:
        # Pick the state with the lowest connection count (first one on ties)
        state = min(
            OutpostState.for_outpost(outpost),
            key=lambda instance_state: int(instance_state.args.get("active_connections", 0)),
            default=None,
        )
        if state is None:
            # No known instance for this outpost, nothing to send to
            continue
        self.logger.debug("Sending out connection broadcast")
        group = build_outpost_group_instance(outpost.pk, state.uid)
        async_to_sync(self.channel_layer.group_send)(group, msg)
        sent = True
    if not sent:
        # Nobody can connect back to us: fail fast instead of leaving the client on an
        # open socket forever, and don't consume a one-shot token for nothing.
        self.logger.warning("Provider has no outpost instance available")
        raise DenyConnection()
    if self.provider and self.provider.delete_token_on_disconnect:
        self.logger.info("Deleting connection token to prevent reconnect", token=self.token)
        self.token.delete()
Initialize guac connection settings
async def receive(self, text_data=None, bytes_data=None):
    """Mirror data received from the client to `dest_channel_id`, which is the
    channel talking to guacd.

    Data arriving before an outpost channel has been assigned is discarded; an
    expired token disconnects the client instead of forwarding.
    """
    if self.dest_channel_id == "":
        return
    if self.token.is_expired:
        await self.event_disconnect({"reason": "token_expiry"})
        return
    try:
        await self.channel_layer.send(
            self.dest_channel_id,
            {
                "type": "event.send",
                "text_data": text_data,
                "bytes_data": bytes_data,
            },
        )
    except ChannelFull:
        # Best-effort: drop this frame rather than tearing down the connection,
        # but leave a trace so back-pressure problems can be diagnosed.
        self.logger.debug("Outpost channel is full, dropping client data")
Mirror data received from the client to dest_channel_id, which is the channel talking to guacd.
async def event_outpost_connected(self, event: dict):
    """Handle an outpost-connected broadcast and pair with the offered outpost channel.

    Events addressed to a different client channel are ignored. On the first
    matching event the offered outpost channel becomes `dest_channel_id` and this
    consumer leaves the global broadcast group. If a channel was already chosen,
    the newly offered one is told to disconnect instead.
    """
    if event.get("client_channel") != self.channel_name:
        return
    offered_channel = event.get("outpost_channel")
    if self.dest_channel_id == "":
        self.logger.debug("Connected to a single outpost instance")
        self.dest_channel_id = offered_channel
        # Paired with a specific outpost channel now, so stop listening to the global broadcast
        await self.channel_layer.group_discard(build_rac_client_group(), self.channel_name)
        return
    # Already paired: reject the surplus outpost channel. Not expected to happen,
    # since the broadcast group is left as soon as a channel is chosen.
    disconnect_message = {
        "type": "event.disconnect",
    }
    await self.channel_layer.send(offered_channel, disconnect_message)
Handle an event broadcast from the outpost consumer, and check whether it created a connection for us.
async def event_send(self, event: dict):
    """Forward data sent by the outpost websocket to this specific client connection.

    If the connection token has expired, the client is disconnected instead.
    """
    if not self.token.is_expired:
        await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
        return
    await self.event_disconnect({"reason": "token_expiry"})
Handler called by the outpost websocket to send data to this specific client connection.