authentik.policies.engine
authentik policy engine
"""authentik policy engine"""

from collections.abc import Iterable
from multiprocessing import Pipe, current_process
from multiprocessing.connection import Connection

from django.core.cache import cache
from django.db.models import Count, Q, QuerySet
from django.http import HttpRequest
from sentry_sdk import start_span
from sentry_sdk.tracing import Span
from structlog.stdlib import BoundLogger, get_logger

from authentik.core.models import User
from authentik.lib.utils.reflection import class_to_path
from authentik.policies.apps import HIST_POLICIES_ENGINE_TOTAL_TIME, HIST_POLICIES_EXECUTION_TIME
from authentik.policies.exceptions import PolicyEngineException
from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel, PolicyEngineMode
from authentik.policies.process import PolicyProcess, cache_key
from authentik.policies.types import PolicyRequest, PolicyResult

# Process object of the process this module was imported in; `build` consults its
# daemon flag to decide between running a policy inline or in a separate process.
CURRENT_PROCESS = current_process()


class PolicyProcessInfo:
    """Hold a policy process together with its connection, binding and result"""

    process: PolicyProcess
    connection: Connection
    result: PolicyResult | None
    binding: PolicyBinding

    def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding):
        self.process = process
        self.connection = connection
        self.binding = binding
        # No result is known at construction time; the engine fills it in later
        self.result = None


class PolicyEngine:
    """Orchestrate policy checking, launch tasks and return result"""

    use_cache: bool
    request: PolicyRequest

    logger: BoundLogger
    mode: PolicyEngineMode
    # Allow objects with no policies attached to pass
    empty_result: bool

    def __init__(self, pbm: PolicyBindingModel, user: User, request: HttpRequest | None = None):
        """Prepare an engine checking `user` against the bindings of `pbm`

        `request`, when given, is attached to the policy request.
        Raises PolicyEngineException when `pbm` is not a PolicyBindingModel or
        when `user` is falsy."""
        # Validate the arguments before reading any attribute of `pbm`, so an invalid
        # object raises the intended PolicyEngineException and not an AttributeError.
        if not isinstance(pbm, PolicyBindingModel):  # pragma: no cover
            raise PolicyEngineException(f"{pbm} is not instance of PolicyBindingModel")
        if not user:
            raise PolicyEngineException("User must be set")
        self.logger = get_logger().bind()
        self.mode = pbm.policy_engine_mode
        # For backwards compatibility, set empty_result to true
        # objects with no policies attached will pass.
        self.empty_result = True
        self.__pbm = pbm
        self.request = PolicyRequest(user)
        self.request.obj = pbm
        if request:
            self.request.set_http_request(request)
        self.__cached_policies: list[PolicyResult] = []
        self.__processes: list[PolicyProcessInfo] = []
        self.use_cache = True
        self.__expected_result_count = 0
        self.__static_result: PolicyResult | None = None

    def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
        """Return the enabled bindings targeting this engine's object, sorted by `order`"""
        enabled_bindings = PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
        return enabled_bindings.order_by("order")

    def _check_policy_type(self, binding: PolicyBinding):
        """Check policy type, make sure it's not the root class as that has no logic implemented"""
        policy = binding.policy
        if policy is None:
            return
        # Exact class comparison on purpose: only the bare root class is rejected
        if policy.__class__ == Policy:
            raise PolicyEngineException(f"Policy '{policy}' is root type")

    def _check_cache(self, binding: PolicyBinding):
        """Try to take the result for `binding` from the cache

        Returns True (and remembers the cached result) when a truthy cache entry
        exists, False when caching is disabled or nothing usable was found."""
        if not self.use_cache:
            return False
        # It's a bit silly to time this, but
        with HIST_POLICIES_EXECUTION_TIME.labels(
            binding_order=binding.order,
            binding_target_type=binding.target_type,
            binding_target_name=binding.target_name,
            object_type=class_to_path(self.request.obj.__class__),
            mode="cache_retrieve",
        ).time():
            key = cache_key(binding, self.request)
            cached_policy = cache.get(key, None)
            if not cached_policy:
                return False
        self.logger.debug(
            "P_ENG: Taking result from cache",
            binding=binding,
            cache_key=key,
            request=self.request,
        )
        self.__cached_policies.append(cached_policy)
        return True

    def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
        """Check static bindings (bound to a user or group, without a policy) if possible

        Counts are gathered with one aggregate query; the outcome is stored as the
        engine's static result. Nothing is stored when both counts are zero."""
        static_filter = Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
        aggregates = {"total": Count("pk", filter=static_filter)}
        user = self.request.user
        # The "passing" count is only computed for users that have a primary key
        if user.pk:
            all_groups = user.all_groups()
            matches_directly = Q(Q(user=user) | Q(group__in=all_groups), negate=False)
            matches_negated = Q(
                Q(~Q(user=user), user__isnull=False)
                | Q(~Q(group__in=all_groups), group__isnull=False),
                negate=True,
            )
            aggregates["passing"] = Count(
                "pk", filter=Q(matches_directly | matches_negated, enabled=True)
            )
        counts = bindings.aggregate(**aggregates)
        total = counts["total"]
        passing_count = counts.get("passing", 0)
        if total == 0 and passing_count == 0:
            # If we didn't find any static bindings, do nothing
            return
        self.logger.debug("P_ENG: Found static bindings", **counts)
        if self.mode == PolicyEngineMode.MODE_ANY:
            # Any passing static binding -> passing
            passing = passing_count > 0
        elif self.mode == PolicyEngineMode.MODE_ALL:
            # All static bindings are passing -> passing
            passing = passing_count == total
        else:
            passing = False
        self.__static_result = PolicyResult(passing)

    def build(self) -> PolicyEngine:
        """Build wrapper which monitors performance

        Evaluates every policy binding (from cache where possible), stores the
        results on the engine and returns the engine itself."""
        with (
            start_span(
                op="authentik.policy.engine.build",
                name=self.__pbm,
            ) as span,
            HIST_POLICIES_ENGINE_TOTAL_TIME.labels(
                obj_type=class_to_path(self.__pbm.__class__),
                obj_pk=str(self.__pbm.pk),
            ).time(),
        ):
            span: Span
            span.set_data("pbm", self.__pbm)
            span.set_data("request", self.request)
            all_bindings = self.bindings()
            if isinstance(all_bindings, QuerySet):
                self.compute_static_bindings(all_bindings)
                policy_bindings = [candidate for candidate in all_bindings if candidate.policy]
            else:
                policy_bindings = all_bindings
            for binding in policy_bindings:
                self.__expected_result_count += 1

                self._check_policy_type(binding)
                if self._check_cache(binding):
                    continue
                self.logger.debug("P_ENG: Evaluating policy", binding=binding, request=self.request)
                # One-way pipe: the engine keeps the receiving end, the task gets the sender
                receiver, sender = Pipe(False)
                process = PolicyProcess(binding, self.request, sender)
                process.daemon = False
                self.logger.debug("P_ENG: Starting Process", binding=binding, request=self.request)
                # Start a separate process only when the current process is flagged
                # as daemon in its config; otherwise run the task inline.
                if CURRENT_PROCESS._config.get("daemon"):
                    process.start()
                else:
                    process.run()
                self.__processes.append(
                    PolicyProcessInfo(process=process, connection=receiver, binding=binding)
                )
            # If all policies are cached, we have an empty list here.
            for info in self.__processes:
                if info.process.is_alive():
                    info.process.join(info.binding.timeout)
                # Only call .recv() if no result is saved, otherwise we just deadlock here
                # NOTE(review): recv() is called without a timeout; if join() returned
                # because of the timeout this presumably blocks - confirm.
                if not info.result:
                    info.result = info.connection.recv()
                outcome = info.result
                if not (outcome and outcome._exec_time):
                    continue
                HIST_POLICIES_EXECUTION_TIME.labels(
                    binding_order=info.binding.order,
                    binding_target_type=info.binding.target_type,
                    binding_target_name=info.binding.target_name,
                    object_type=(
                        class_to_path(self.request.obj.__class__) if self.request.obj else ""
                    ),
                    mode="execute_process",
                ).observe(outcome._exec_time)
            return self

    @property
    def result(self) -> PolicyResult:
        """Get policy-checking result

        Combines process, cached and static results according to the engine mode.
        Raises AssertionError when fewer results than expected were collected."""
        self.__processes.sort(key=lambda info: info.binding.order)
        all_results: list[PolicyResult] = [
            info.result for info in self.__processes if info.result
        ]
        all_results.extend(self.__cached_policies)
        if len(all_results) < self.__expected_result_count:  # pragma: no cover
            raise AssertionError("Got less results than polices")
        if self.__static_result:
            all_results.append(self.__static_result)
        # No results, no policies attached -> passing
        if not all_results:
            return PolicyResult(self.empty_result)
        if self.mode == PolicyEngineMode.MODE_ANY:
            passing = any(entry.passing for entry in all_results)
        elif self.mode == PolicyEngineMode.MODE_ALL:
            passing = all(entry.passing for entry in all_results)
        else:
            passing = False
        combined = PolicyResult(passing)
        combined.source_results = all_results
        combined.messages = tuple(
            message for entry in all_results for message in entry.messages
        )
        return combined

    @property
    def passing(self) -> bool:
        """Only get true/false if user passes"""
        return self.result.passing
CURRENT_PROCESS =
<_MainProcess name='MainProcess' parent=None started>
class
PolicyProcessInfo:
class PolicyProcessInfo:
    """Bundle a policy process with its connection, its binding and its result"""

    process: PolicyProcess
    connection: Connection
    result: PolicyResult | None
    binding: PolicyBinding

    def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding):
        self.process, self.connection, self.binding = process, connection, binding
        # The result is unknown when the info object is created
        self.result = None
Plain container class holding all information and communication channels to a process
PolicyProcessInfo( process: authentik.policies.process.PolicyProcess, connection: multiprocessing.connection.Connection, binding: authentik.policies.models.PolicyBinding)
result: authentik.policies.types.PolicyResult | None
class
PolicyEngine:
class PolicyEngine:
    """Orchestrate policy checking, launch tasks and return result"""

    use_cache: bool
    request: PolicyRequest

    logger: BoundLogger
    mode: PolicyEngineMode
    # Allow objects with no policies attached to pass
    empty_result: bool

    def __init__(self, pbm: PolicyBindingModel, user: User, request: HttpRequest | None = None):
        """Prepare an engine checking `user` against the bindings of `pbm`

        `request`, when given, is attached to the policy request.
        Raises PolicyEngineException when `pbm` is not a PolicyBindingModel or
        when `user` is falsy."""
        # Validate the arguments before reading any attribute of `pbm`, so an invalid
        # object raises the intended PolicyEngineException and not an AttributeError.
        if not isinstance(pbm, PolicyBindingModel):  # pragma: no cover
            raise PolicyEngineException(f"{pbm} is not instance of PolicyBindingModel")
        if not user:
            raise PolicyEngineException("User must be set")
        self.logger = get_logger().bind()
        self.mode = pbm.policy_engine_mode
        # For backwards compatibility, set empty_result to true
        # objects with no policies attached will pass.
        self.empty_result = True
        self.__pbm = pbm
        self.request = PolicyRequest(user)
        self.request.obj = pbm
        if request:
            self.request.set_http_request(request)
        self.__cached_policies: list[PolicyResult] = []
        self.__processes: list[PolicyProcessInfo] = []
        self.use_cache = True
        self.__expected_result_count = 0
        self.__static_result: PolicyResult | None = None

    def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
        """Return the enabled bindings targeting this engine's object, sorted by `order`"""
        enabled_bindings = PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
        return enabled_bindings.order_by("order")

    def _check_policy_type(self, binding: PolicyBinding):
        """Check policy type, make sure it's not the root class as that has no logic implemented"""
        policy = binding.policy
        if policy is None:
            return
        # Exact class comparison on purpose: only the bare root class is rejected
        if policy.__class__ == Policy:
            raise PolicyEngineException(f"Policy '{policy}' is root type")

    def _check_cache(self, binding: PolicyBinding):
        """Try to take the result for `binding` from the cache

        Returns True (and remembers the cached result) when a truthy cache entry
        exists, False when caching is disabled or nothing usable was found."""
        if not self.use_cache:
            return False
        # It's a bit silly to time this, but
        with HIST_POLICIES_EXECUTION_TIME.labels(
            binding_order=binding.order,
            binding_target_type=binding.target_type,
            binding_target_name=binding.target_name,
            object_type=class_to_path(self.request.obj.__class__),
            mode="cache_retrieve",
        ).time():
            key = cache_key(binding, self.request)
            cached_policy = cache.get(key, None)
            if not cached_policy:
                return False
        self.logger.debug(
            "P_ENG: Taking result from cache",
            binding=binding,
            cache_key=key,
            request=self.request,
        )
        self.__cached_policies.append(cached_policy)
        return True

    def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
        """Check static bindings (bound to a user or group, without a policy) if possible

        Counts are gathered with one aggregate query; the outcome is stored as the
        engine's static result. Nothing is stored when both counts are zero."""
        static_filter = Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
        aggregates = {"total": Count("pk", filter=static_filter)}
        user = self.request.user
        # The "passing" count is only computed for users that have a primary key
        if user.pk:
            all_groups = user.all_groups()
            matches_directly = Q(Q(user=user) | Q(group__in=all_groups), negate=False)
            matches_negated = Q(
                Q(~Q(user=user), user__isnull=False)
                | Q(~Q(group__in=all_groups), group__isnull=False),
                negate=True,
            )
            aggregates["passing"] = Count(
                "pk", filter=Q(matches_directly | matches_negated, enabled=True)
            )
        counts = bindings.aggregate(**aggregates)
        total = counts["total"]
        passing_count = counts.get("passing", 0)
        if total == 0 and passing_count == 0:
            # If we didn't find any static bindings, do nothing
            return
        self.logger.debug("P_ENG: Found static bindings", **counts)
        if self.mode == PolicyEngineMode.MODE_ANY:
            # Any passing static binding -> passing
            passing = passing_count > 0
        elif self.mode == PolicyEngineMode.MODE_ALL:
            # All static bindings are passing -> passing
            passing = passing_count == total
        else:
            passing = False
        self.__static_result = PolicyResult(passing)

    def build(self) -> PolicyEngine:
        """Build wrapper which monitors performance

        Evaluates every policy binding (from cache where possible), stores the
        results on the engine and returns the engine itself."""
        with (
            start_span(
                op="authentik.policy.engine.build",
                name=self.__pbm,
            ) as span,
            HIST_POLICIES_ENGINE_TOTAL_TIME.labels(
                obj_type=class_to_path(self.__pbm.__class__),
                obj_pk=str(self.__pbm.pk),
            ).time(),
        ):
            span: Span
            span.set_data("pbm", self.__pbm)
            span.set_data("request", self.request)
            all_bindings = self.bindings()
            if isinstance(all_bindings, QuerySet):
                self.compute_static_bindings(all_bindings)
                policy_bindings = [candidate for candidate in all_bindings if candidate.policy]
            else:
                policy_bindings = all_bindings
            for binding in policy_bindings:
                self.__expected_result_count += 1

                self._check_policy_type(binding)
                if self._check_cache(binding):
                    continue
                self.logger.debug("P_ENG: Evaluating policy", binding=binding, request=self.request)
                # One-way pipe: the engine keeps the receiving end, the task gets the sender
                receiver, sender = Pipe(False)
                process = PolicyProcess(binding, self.request, sender)
                process.daemon = False
                self.logger.debug("P_ENG: Starting Process", binding=binding, request=self.request)
                # Start a separate process only when the current process is flagged
                # as daemon in its config; otherwise run the task inline.
                if CURRENT_PROCESS._config.get("daemon"):
                    process.start()
                else:
                    process.run()
                self.__processes.append(
                    PolicyProcessInfo(process=process, connection=receiver, binding=binding)
                )
            # If all policies are cached, we have an empty list here.
            for info in self.__processes:
                if info.process.is_alive():
                    info.process.join(info.binding.timeout)
                # Only call .recv() if no result is saved, otherwise we just deadlock here
                # NOTE(review): recv() is called without a timeout; if join() returned
                # because of the timeout this presumably blocks - confirm.
                if not info.result:
                    info.result = info.connection.recv()
                outcome = info.result
                if not (outcome and outcome._exec_time):
                    continue
                HIST_POLICIES_EXECUTION_TIME.labels(
                    binding_order=info.binding.order,
                    binding_target_type=info.binding.target_type,
                    binding_target_name=info.binding.target_name,
                    object_type=(
                        class_to_path(self.request.obj.__class__) if self.request.obj else ""
                    ),
                    mode="execute_process",
                ).observe(outcome._exec_time)
            return self

    @property
    def result(self) -> PolicyResult:
        """Get policy-checking result

        Combines process, cached and static results according to the engine mode.
        Raises AssertionError when fewer results than expected were collected."""
        self.__processes.sort(key=lambda info: info.binding.order)
        all_results: list[PolicyResult] = [
            info.result for info in self.__processes if info.result
        ]
        all_results.extend(self.__cached_policies)
        if len(all_results) < self.__expected_result_count:  # pragma: no cover
            raise AssertionError("Got less results than polices")
        if self.__static_result:
            all_results.append(self.__static_result)
        # No results, no policies attached -> passing
        if not all_results:
            return PolicyResult(self.empty_result)
        if self.mode == PolicyEngineMode.MODE_ANY:
            passing = any(entry.passing for entry in all_results)
        elif self.mode == PolicyEngineMode.MODE_ALL:
            passing = all(entry.passing for entry in all_results)
        else:
            passing = False
        combined = PolicyResult(passing)
        combined.source_results = all_results
        combined.messages = tuple(
            message for entry in all_results for message in entry.messages
        )
        return combined

    @property
    def passing(self) -> bool:
        """Only get true/false if user passes"""
        return self.result.passing
Orchestrate policy checking, launch tasks and return result
PolicyEngine( pbm: authentik.policies.models.PolicyBindingModel, user: authentik.core.models.User, request: django.http.request.HttpRequest = None)
def __init__(self, pbm: PolicyBindingModel, user: User, request: HttpRequest | None = None):
    """Prepare an engine checking `user` against the bindings of `pbm`

    `request`, when given, is attached to the policy request.
    Raises PolicyEngineException when `pbm` is not a PolicyBindingModel or
    when `user` is falsy."""
    # Validate the arguments before reading any attribute of `pbm`, so an invalid
    # object raises the intended PolicyEngineException and not an AttributeError.
    if not isinstance(pbm, PolicyBindingModel):  # pragma: no cover
        raise PolicyEngineException(f"{pbm} is not instance of PolicyBindingModel")
    if not user:
        raise PolicyEngineException("User must be set")
    self.logger = get_logger().bind()
    self.mode = pbm.policy_engine_mode
    # For backwards compatibility, set empty_result to true
    # objects with no policies attached will pass.
    self.empty_result = True
    self.__pbm = pbm
    self.request = PolicyRequest(user)
    self.request.obj = pbm
    if request:
        self.request.set_http_request(request)
    self.__cached_policies: list[PolicyResult] = []
    self.__processes: list[PolicyProcessInfo] = []
    self.use_cache = True
    self.__expected_result_count = 0
    self.__static_result: PolicyResult | None = None
def
bindings( self) -> django.db.models.query.QuerySet | Iterable[authentik.policies.models.PolicyBinding]:
def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
    """Return the enabled bindings targeting this engine's object, sorted by `order`"""
    enabled_bindings = PolicyBinding.objects.filter(target=self.__pbm, enabled=True)
    return enabled_bindings.order_by("order")
Return the enabled policy bindings that target this object, ordered by their `order` field
def
compute_static_bindings(self, bindings: django.db.models.query.QuerySet):
def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
    """Check static bindings (bound to a user or group, without a policy) if possible

    Counts are gathered with one aggregate query; the outcome is stored as the
    engine's static result. Nothing is stored when both counts are zero."""
    static_filter = Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
    aggregates = {"total": Count("pk", filter=static_filter)}
    user = self.request.user
    # The "passing" count is only computed for users that have a primary key
    if user.pk:
        all_groups = user.all_groups()
        matches_directly = Q(Q(user=user) | Q(group__in=all_groups), negate=False)
        matches_negated = Q(
            Q(~Q(user=user), user__isnull=False)
            | Q(~Q(group__in=all_groups), group__isnull=False),
            negate=True,
        )
        aggregates["passing"] = Count(
            "pk", filter=Q(matches_directly | matches_negated, enabled=True)
        )
    counts = bindings.aggregate(**aggregates)
    total = counts["total"]
    passing_count = counts.get("passing", 0)
    if total == 0 and passing_count == 0:
        # If we didn't find any static bindings, do nothing
        return
    self.logger.debug("P_ENG: Found static bindings", **counts)
    if self.mode == PolicyEngineMode.MODE_ANY:
        # Any passing static binding -> passing
        passing = passing_count > 0
    elif self.mode == PolicyEngineMode.MODE_ALL:
        # All static bindings are passing -> passing
        passing = passing_count == total
    else:
        passing = False
    self.__static_result = PolicyResult(passing)
Check static bindings if possible
def build(self) -> PolicyEngine:
    """Build wrapper which monitors performance

    Evaluates every policy binding (from cache where possible), stores the
    results on the engine and returns the engine itself."""
    with (
        start_span(
            op="authentik.policy.engine.build",
            name=self.__pbm,
        ) as span,
        HIST_POLICIES_ENGINE_TOTAL_TIME.labels(
            obj_type=class_to_path(self.__pbm.__class__),
            obj_pk=str(self.__pbm.pk),
        ).time(),
    ):
        span: Span
        span.set_data("pbm", self.__pbm)
        span.set_data("request", self.request)
        all_bindings = self.bindings()
        if isinstance(all_bindings, QuerySet):
            self.compute_static_bindings(all_bindings)
            policy_bindings = [candidate for candidate in all_bindings if candidate.policy]
        else:
            policy_bindings = all_bindings
        for binding in policy_bindings:
            self.__expected_result_count += 1

            self._check_policy_type(binding)
            if self._check_cache(binding):
                continue
            self.logger.debug("P_ENG: Evaluating policy", binding=binding, request=self.request)
            # One-way pipe: the engine keeps the receiving end, the task gets the sender
            receiver, sender = Pipe(False)
            process = PolicyProcess(binding, self.request, sender)
            process.daemon = False
            self.logger.debug("P_ENG: Starting Process", binding=binding, request=self.request)
            # Start a separate process only when the current process is flagged
            # as daemon in its config; otherwise run the task inline.
            if CURRENT_PROCESS._config.get("daemon"):
                process.start()
            else:
                process.run()
            self.__processes.append(
                PolicyProcessInfo(process=process, connection=receiver, binding=binding)
            )
        # If all policies are cached, we have an empty list here.
        for info in self.__processes:
            if info.process.is_alive():
                info.process.join(info.binding.timeout)
            # Only call .recv() if no result is saved, otherwise we just deadlock here
            # NOTE(review): recv() is called without a timeout; if join() returned
            # because of the timeout this presumably blocks - confirm.
            if not info.result:
                info.result = info.connection.recv()
            outcome = info.result
            if not (outcome and outcome._exec_time):
                continue
            HIST_POLICIES_EXECUTION_TIME.labels(
                binding_order=info.binding.order,
                binding_target_type=info.binding.target_type,
                binding_target_name=info.binding.target_name,
                object_type=(
                    class_to_path(self.request.obj.__class__) if self.request.obj else ""
                ),
                mode="execute_process",
            ).observe(outcome._exec_time)
        return self
Build wrapper which monitors performance
@property
def result(self) -> PolicyResult:
    """Get policy-checking result

    Combines process, cached and static results according to the engine mode.
    Raises AssertionError when fewer results than expected were collected."""
    self.__processes.sort(key=lambda info: info.binding.order)
    all_results: list[PolicyResult] = [
        info.result for info in self.__processes if info.result
    ]
    all_results.extend(self.__cached_policies)
    if len(all_results) < self.__expected_result_count:  # pragma: no cover
        raise AssertionError("Got less results than polices")
    if self.__static_result:
        all_results.append(self.__static_result)
    # No results, no policies attached -> passing
    if not all_results:
        return PolicyResult(self.empty_result)
    if self.mode == PolicyEngineMode.MODE_ANY:
        passing = any(entry.passing for entry in all_results)
    elif self.mode == PolicyEngineMode.MODE_ALL:
        passing = all(entry.passing for entry in all_results)
    else:
        passing = False
    combined = PolicyResult(passing)
    combined.source_results = all_results
    combined.messages = tuple(
        message for entry in all_results for message in entry.messages
    )
    return combined
Get policy-checking result