authentik.policies.engine

authentik policy engine

  1"""authentik policy engine"""
  2
  3from collections.abc import Iterable
  4from multiprocessing import Pipe, current_process
  5from multiprocessing.connection import Connection
  6
  7from django.core.cache import cache
  8from django.db.models import Count, Q, QuerySet
  9from django.http import HttpRequest
 10from sentry_sdk import start_span
 11from sentry_sdk.tracing import Span
 12from structlog.stdlib import BoundLogger, get_logger
 13
 14from authentik.core.models import User
 15from authentik.lib.utils.reflection import class_to_path
 16from authentik.policies.apps import HIST_POLICIES_ENGINE_TOTAL_TIME, HIST_POLICIES_EXECUTION_TIME
 17from authentik.policies.exceptions import PolicyEngineException
 18from authentik.policies.models import Policy, PolicyBinding, PolicyBindingModel, PolicyEngineMode
 19from authentik.policies.process import PolicyProcess, cache_key
 20from authentik.policies.types import PolicyRequest, PolicyResult
 21
 22CURRENT_PROCESS = current_process()
 23
 24
 25class PolicyProcessInfo:
 26    """Dataclass to hold all information and communication channels to a process"""
 27
 28    process: PolicyProcess
 29    connection: Connection
 30    result: PolicyResult | None
 31    binding: PolicyBinding
 32
 33    def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding):
 34        self.process = process
 35        self.connection = connection
 36        self.binding = binding
 37        self.result = None
 38
 39
 40class PolicyEngine:
 41    """Orchestrate policy checking, launch tasks and return result"""
 42
 43    use_cache: bool
 44    request: PolicyRequest
 45
 46    logger: BoundLogger
 47    mode: PolicyEngineMode
 48    # Allow objects with no policies attached to pass
 49    empty_result: bool
 50
 51    def __init__(self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None):
 52        self.logger = get_logger().bind()
 53        self.mode = pbm.policy_engine_mode
 54        # For backwards compatibility, set empty_result to true
 55        # objects with no policies attached will pass.
 56        self.empty_result = True
 57        if not isinstance(pbm, PolicyBindingModel):  # pragma: no cover
 58            raise PolicyEngineException(f"{pbm} is not instance of PolicyBindingModel")
 59        if not user:
 60            raise PolicyEngineException("User must be set")
 61        self.__pbm = pbm
 62        self.request = PolicyRequest(user)
 63        self.request.obj = pbm
 64        if request:
 65            self.request.set_http_request(request)
 66        self.__cached_policies: list[PolicyResult] = []
 67        self.__processes: list[PolicyProcessInfo] = []
 68        self.use_cache = True
 69        self.__expected_result_count = 0
 70        self.__static_result: PolicyResult | None = None
 71
 72    def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
 73        """Make sure all Policies are their respective classes"""
 74        return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order")
 75
 76    def _check_policy_type(self, binding: PolicyBinding):
 77        """Check policy type, make sure it's not the root class as that has no logic implemented"""
 78        if binding.policy is not None and binding.policy.__class__ == Policy:
 79            raise PolicyEngineException(f"Policy '{binding.policy}' is root type")
 80
 81    def _check_cache(self, binding: PolicyBinding):
 82        if not self.use_cache:
 83            return False
 84        # It's a bit silly to time this, but
 85        with HIST_POLICIES_EXECUTION_TIME.labels(
 86            binding_order=binding.order,
 87            binding_target_type=binding.target_type,
 88            binding_target_name=binding.target_name,
 89            object_type=class_to_path(self.request.obj.__class__),
 90            mode="cache_retrieve",
 91        ).time():
 92            key = cache_key(binding, self.request)
 93            cached_policy = cache.get(key, None)
 94            if not cached_policy:
 95                return False
 96        self.logger.debug(
 97            "P_ENG: Taking result from cache",
 98            binding=binding,
 99            cache_key=key,
100            request=self.request,
101        )
102        self.__cached_policies.append(cached_policy)
103        return True
104
105    def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
106        """Check static bindings if possible"""
107        aggrs = {
108            "total": Count(
109                "pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
110            ),
111        }
112        if self.request.user.pk:
113            all_groups = self.request.user.all_groups()
114            aggrs["passing"] = Count(
115                "pk",
116                filter=Q(
117                    Q(
118                        Q(user=self.request.user) | Q(group__in=all_groups),
119                        negate=False,
120                    )
121                    | Q(
122                        Q(~Q(user=self.request.user), user__isnull=False)
123                        | Q(~Q(group__in=all_groups), group__isnull=False),
124                        negate=True,
125                    ),
126                    enabled=True,
127                ),
128            )
129        matched_bindings = bindings.aggregate(**aggrs)
130        passing = False
131        if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0:
132            # If we didn't find any static bindings, do nothing
133            return
134        self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
135        if self.mode == PolicyEngineMode.MODE_ANY:
136            if matched_bindings.get("passing", 0) > 0:
137                # Any passing static binding -> passing
138                passing = True
139        elif self.mode == PolicyEngineMode.MODE_ALL:
140            if matched_bindings.get("passing", 0) == matched_bindings["total"]:
141                # All static bindings are passing -> passing
142                passing = True
143        elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
144            # No matching static bindings but at least one is configured -> not passing
145            passing = False
146        self.__static_result = PolicyResult(passing)
147
148    def build(self) -> PolicyEngine:
149        """Build wrapper which monitors performance"""
150        with (
151            start_span(
152                op="authentik.policy.engine.build",
153                name=self.__pbm,
154            ) as span,
155            HIST_POLICIES_ENGINE_TOTAL_TIME.labels(
156                obj_type=class_to_path(self.__pbm.__class__),
157                obj_pk=str(self.__pbm.pk),
158            ).time(),
159        ):
160            span: Span
161            span.set_data("pbm", self.__pbm)
162            span.set_data("request", self.request)
163            bindings = self.bindings()
164            policy_bindings = bindings
165            if isinstance(bindings, QuerySet):
166                self.compute_static_bindings(bindings)
167                policy_bindings = [x for x in bindings if x.policy]
168            for binding in policy_bindings:
169                self.__expected_result_count += 1
170
171                self._check_policy_type(binding)
172                if self._check_cache(binding):
173                    continue
174                self.logger.debug("P_ENG: Evaluating policy", binding=binding, request=self.request)
175                our_end, task_end = Pipe(False)
176                task = PolicyProcess(binding, self.request, task_end)
177                task.daemon = False
178                self.logger.debug("P_ENG: Starting Process", binding=binding, request=self.request)
179                if not CURRENT_PROCESS._config.get("daemon"):
180                    task.run()
181                else:
182                    task.start()
183                self.__processes.append(
184                    PolicyProcessInfo(process=task, connection=our_end, binding=binding)
185                )
186            # If all policies are cached, we have an empty list here.
187            for proc_info in self.__processes:
188                if proc_info.process.is_alive():
189                    proc_info.process.join(proc_info.binding.timeout)
190                # Only call .recv() if no result is saved, otherwise we just deadlock here
191                if not proc_info.result:
192                    proc_info.result = proc_info.connection.recv()
193                if proc_info.result and proc_info.result._exec_time:
194                    HIST_POLICIES_EXECUTION_TIME.labels(
195                        binding_order=proc_info.binding.order,
196                        binding_target_type=proc_info.binding.target_type,
197                        binding_target_name=proc_info.binding.target_name,
198                        object_type=(
199                            class_to_path(self.request.obj.__class__) if self.request.obj else ""
200                        ),
201                        mode="execute_process",
202                    ).observe(proc_info.result._exec_time)
203            return self
204
205    @property
206    def result(self) -> PolicyResult:
207        """Get policy-checking result"""
208        self.__processes.sort(key=lambda x: x.binding.order)
209        process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
210        all_results = list(process_results + self.__cached_policies)
211        if len(all_results) < self.__expected_result_count:  # pragma: no cover
212            raise AssertionError("Got less results than polices")
213        if self.__static_result:
214            all_results.append(self.__static_result)
215        # No results, no policies attached -> passing
216        if len(all_results) == 0:
217            return PolicyResult(self.empty_result)
218        passing = False
219        if self.mode == PolicyEngineMode.MODE_ALL:
220            passing = all(x.passing for x in all_results)
221        if self.mode == PolicyEngineMode.MODE_ANY:
222            passing = any(x.passing for x in all_results)
223        result = PolicyResult(passing)
224        result.source_results = all_results
225        result.messages = tuple(y for x in all_results for y in x.messages)
226        return result
227
228    @property
229    def passing(self) -> bool:
230        """Only get true/false if user passes"""
231        return self.result.passing
CURRENT_PROCESS = <_MainProcess name='MainProcess' parent=None started>
class PolicyProcessInfo:
class PolicyProcessInfo:
    """Bookkeeping record for one policy evaluation.

    Holds the ``PolicyProcess`` evaluating ``binding``, the connection its result
    is received on, and the received ``PolicyResult`` (``None`` until then).
    This is a plain class, not a ``@dataclass``.
    """

    process: PolicyProcess
    connection: Connection
    result: PolicyResult | None
    binding: PolicyBinding

    def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding):
        self.process, self.connection, self.binding = process, connection, binding
        # Nothing has been received from ``connection`` yet.
        self.result = None

Container for a policy process, the connection its result is received on, its binding, and the received result (a plain class, not a `@dataclass`).

PolicyProcessInfo( process: authentik.policies.process.PolicyProcess, connection: multiprocessing.connection.Connection, binding: authentik.policies.models.PolicyBinding)
34    def __init__(self, process: PolicyProcess, connection: Connection, binding: PolicyBinding):
35        self.process = process
36        self.connection = connection
37        self.binding = binding
38        self.result = None
connection: multiprocessing.connection.Connection
class PolicyEngine:
 41class PolicyEngine:
 42    """Orchestrate policy checking, launch tasks and return result"""
 43
 44    use_cache: bool
 45    request: PolicyRequest
 46
 47    logger: BoundLogger
 48    mode: PolicyEngineMode
 49    # Allow objects with no policies attached to pass
 50    empty_result: bool
 51
 52    def __init__(self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None):
 53        self.logger = get_logger().bind()
 54        self.mode = pbm.policy_engine_mode
 55        # For backwards compatibility, set empty_result to true
 56        # objects with no policies attached will pass.
 57        self.empty_result = True
 58        if not isinstance(pbm, PolicyBindingModel):  # pragma: no cover
 59            raise PolicyEngineException(f"{pbm} is not instance of PolicyBindingModel")
 60        if not user:
 61            raise PolicyEngineException("User must be set")
 62        self.__pbm = pbm
 63        self.request = PolicyRequest(user)
 64        self.request.obj = pbm
 65        if request:
 66            self.request.set_http_request(request)
 67        self.__cached_policies: list[PolicyResult] = []
 68        self.__processes: list[PolicyProcessInfo] = []
 69        self.use_cache = True
 70        self.__expected_result_count = 0
 71        self.__static_result: PolicyResult | None = None
 72
 73    def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
 74        """Make sure all Policies are their respective classes"""
 75        return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order")
 76
 77    def _check_policy_type(self, binding: PolicyBinding):
 78        """Check policy type, make sure it's not the root class as that has no logic implemented"""
 79        if binding.policy is not None and binding.policy.__class__ == Policy:
 80            raise PolicyEngineException(f"Policy '{binding.policy}' is root type")
 81
 82    def _check_cache(self, binding: PolicyBinding):
 83        if not self.use_cache:
 84            return False
 85        # It's a bit silly to time this, but
 86        with HIST_POLICIES_EXECUTION_TIME.labels(
 87            binding_order=binding.order,
 88            binding_target_type=binding.target_type,
 89            binding_target_name=binding.target_name,
 90            object_type=class_to_path(self.request.obj.__class__),
 91            mode="cache_retrieve",
 92        ).time():
 93            key = cache_key(binding, self.request)
 94            cached_policy = cache.get(key, None)
 95            if not cached_policy:
 96                return False
 97        self.logger.debug(
 98            "P_ENG: Taking result from cache",
 99            binding=binding,
100            cache_key=key,
101            request=self.request,
102        )
103        self.__cached_policies.append(cached_policy)
104        return True
105
106    def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
107        """Check static bindings if possible"""
108        aggrs = {
109            "total": Count(
110                "pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
111            ),
112        }
113        if self.request.user.pk:
114            all_groups = self.request.user.all_groups()
115            aggrs["passing"] = Count(
116                "pk",
117                filter=Q(
118                    Q(
119                        Q(user=self.request.user) | Q(group__in=all_groups),
120                        negate=False,
121                    )
122                    | Q(
123                        Q(~Q(user=self.request.user), user__isnull=False)
124                        | Q(~Q(group__in=all_groups), group__isnull=False),
125                        negate=True,
126                    ),
127                    enabled=True,
128                ),
129            )
130        matched_bindings = bindings.aggregate(**aggrs)
131        passing = False
132        if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0:
133            # If we didn't find any static bindings, do nothing
134            return
135        self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
136        if self.mode == PolicyEngineMode.MODE_ANY:
137            if matched_bindings.get("passing", 0) > 0:
138                # Any passing static binding -> passing
139                passing = True
140        elif self.mode == PolicyEngineMode.MODE_ALL:
141            if matched_bindings.get("passing", 0) == matched_bindings["total"]:
142                # All static bindings are passing -> passing
143                passing = True
144        elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
145            # No matching static bindings but at least one is configured -> not passing
146            passing = False
147        self.__static_result = PolicyResult(passing)
148
149    def build(self) -> PolicyEngine:
150        """Build wrapper which monitors performance"""
151        with (
152            start_span(
153                op="authentik.policy.engine.build",
154                name=self.__pbm,
155            ) as span,
156            HIST_POLICIES_ENGINE_TOTAL_TIME.labels(
157                obj_type=class_to_path(self.__pbm.__class__),
158                obj_pk=str(self.__pbm.pk),
159            ).time(),
160        ):
161            span: Span
162            span.set_data("pbm", self.__pbm)
163            span.set_data("request", self.request)
164            bindings = self.bindings()
165            policy_bindings = bindings
166            if isinstance(bindings, QuerySet):
167                self.compute_static_bindings(bindings)
168                policy_bindings = [x for x in bindings if x.policy]
169            for binding in policy_bindings:
170                self.__expected_result_count += 1
171
172                self._check_policy_type(binding)
173                if self._check_cache(binding):
174                    continue
175                self.logger.debug("P_ENG: Evaluating policy", binding=binding, request=self.request)
176                our_end, task_end = Pipe(False)
177                task = PolicyProcess(binding, self.request, task_end)
178                task.daemon = False
179                self.logger.debug("P_ENG: Starting Process", binding=binding, request=self.request)
180                if not CURRENT_PROCESS._config.get("daemon"):
181                    task.run()
182                else:
183                    task.start()
184                self.__processes.append(
185                    PolicyProcessInfo(process=task, connection=our_end, binding=binding)
186                )
187            # If all policies are cached, we have an empty list here.
188            for proc_info in self.__processes:
189                if proc_info.process.is_alive():
190                    proc_info.process.join(proc_info.binding.timeout)
191                # Only call .recv() if no result is saved, otherwise we just deadlock here
192                if not proc_info.result:
193                    proc_info.result = proc_info.connection.recv()
194                if proc_info.result and proc_info.result._exec_time:
195                    HIST_POLICIES_EXECUTION_TIME.labels(
196                        binding_order=proc_info.binding.order,
197                        binding_target_type=proc_info.binding.target_type,
198                        binding_target_name=proc_info.binding.target_name,
199                        object_type=(
200                            class_to_path(self.request.obj.__class__) if self.request.obj else ""
201                        ),
202                        mode="execute_process",
203                    ).observe(proc_info.result._exec_time)
204            return self
205
206    @property
207    def result(self) -> PolicyResult:
208        """Get policy-checking result"""
209        self.__processes.sort(key=lambda x: x.binding.order)
210        process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
211        all_results = list(process_results + self.__cached_policies)
212        if len(all_results) < self.__expected_result_count:  # pragma: no cover
213            raise AssertionError("Got less results than polices")
214        if self.__static_result:
215            all_results.append(self.__static_result)
216        # No results, no policies attached -> passing
217        if len(all_results) == 0:
218            return PolicyResult(self.empty_result)
219        passing = False
220        if self.mode == PolicyEngineMode.MODE_ALL:
221            passing = all(x.passing for x in all_results)
222        if self.mode == PolicyEngineMode.MODE_ANY:
223            passing = any(x.passing for x in all_results)
224        result = PolicyResult(passing)
225        result.source_results = all_results
226        result.messages = tuple(y for x in all_results for y in x.messages)
227        return result
228
229    @property
230    def passing(self) -> bool:
231        """Only get true/false if user passes"""
232        return self.result.passing

Orchestrate policy checking, launch tasks and return result

PolicyEngine( pbm: authentik.policies.models.PolicyBindingModel, user: authentik.core.models.User, request: django.http.request.HttpRequest = None)
52    def __init__(self, pbm: PolicyBindingModel, user: User, request: HttpRequest = None):
53        self.logger = get_logger().bind()
54        self.mode = pbm.policy_engine_mode
55        # For backwards compatibility, set empty_result to true
56        # objects with no policies attached will pass.
57        self.empty_result = True
58        if not isinstance(pbm, PolicyBindingModel):  # pragma: no cover
59            raise PolicyEngineException(f"{pbm} is not instance of PolicyBindingModel")
60        if not user:
61            raise PolicyEngineException("User must be set")
62        self.__pbm = pbm
63        self.request = PolicyRequest(user)
64        self.request.obj = pbm
65        if request:
66            self.request.set_http_request(request)
67        self.__cached_policies: list[PolicyResult] = []
68        self.__processes: list[PolicyProcessInfo] = []
69        self.use_cache = True
70        self.__expected_result_count = 0
71        self.__static_result: PolicyResult | None = None
use_cache: bool
logger: structlog.stdlib.BoundLogger
empty_result: bool
def bindings(self) -> django.db.models.query.QuerySet | Iterable[authentik.policies.models.PolicyBinding]:
73    def bindings(self) -> QuerySet[PolicyBinding] | Iterable[PolicyBinding]:
74        """Make sure all Policies are their respective classes"""
75        return PolicyBinding.objects.filter(target=self.__pbm, enabled=True).order_by("order")

Return the enabled policy bindings of the target object, ordered by `order`.

def compute_static_bindings(self, bindings: django.db.models.query.QuerySet):
106    def compute_static_bindings(self, bindings: QuerySet[PolicyBinding]):
107        """Check static bindings if possible"""
108        aggrs = {
109            "total": Count(
110                "pk", filter=Q(Q(group__isnull=False) | Q(user__isnull=False), policy=None)
111            ),
112        }
113        if self.request.user.pk:
114            all_groups = self.request.user.all_groups()
115            aggrs["passing"] = Count(
116                "pk",
117                filter=Q(
118                    Q(
119                        Q(user=self.request.user) | Q(group__in=all_groups),
120                        negate=False,
121                    )
122                    | Q(
123                        Q(~Q(user=self.request.user), user__isnull=False)
124                        | Q(~Q(group__in=all_groups), group__isnull=False),
125                        negate=True,
126                    ),
127                    enabled=True,
128                ),
129            )
130        matched_bindings = bindings.aggregate(**aggrs)
131        passing = False
132        if matched_bindings["total"] == 0 and matched_bindings.get("passing", 0) == 0:
133            # If we didn't find any static bindings, do nothing
134            return
135        self.logger.debug("P_ENG: Found static bindings", **matched_bindings)
136        if self.mode == PolicyEngineMode.MODE_ANY:
137            if matched_bindings.get("passing", 0) > 0:
138                # Any passing static binding -> passing
139                passing = True
140        elif self.mode == PolicyEngineMode.MODE_ALL:
141            if matched_bindings.get("passing", 0) == matched_bindings["total"]:
142                # All static bindings are passing -> passing
143                passing = True
144        elif matched_bindings["total"] > 0 and matched_bindings.get("passing", 0) < 1:
145            # No matching static bindings but at least one is configured -> not passing
146            passing = False
147        self.__static_result = PolicyResult(passing)

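Evaluate static bindings (a user or group bound directly, without a policy) with a single aggregate query, and store their combined result.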

def build(self) -> PolicyEngine:
149    def build(self) -> PolicyEngine:
150        """Build wrapper which monitors performance"""
151        with (
152            start_span(
153                op="authentik.policy.engine.build",
154                name=self.__pbm,
155            ) as span,
156            HIST_POLICIES_ENGINE_TOTAL_TIME.labels(
157                obj_type=class_to_path(self.__pbm.__class__),
158                obj_pk=str(self.__pbm.pk),
159            ).time(),
160        ):
161            span: Span
162            span.set_data("pbm", self.__pbm)
163            span.set_data("request", self.request)
164            bindings = self.bindings()
165            policy_bindings = bindings
166            if isinstance(bindings, QuerySet):
167                self.compute_static_bindings(bindings)
168                policy_bindings = [x for x in bindings if x.policy]
169            for binding in policy_bindings:
170                self.__expected_result_count += 1
171
172                self._check_policy_type(binding)
173                if self._check_cache(binding):
174                    continue
175                self.logger.debug("P_ENG: Evaluating policy", binding=binding, request=self.request)
176                our_end, task_end = Pipe(False)
177                task = PolicyProcess(binding, self.request, task_end)
178                task.daemon = False
179                self.logger.debug("P_ENG: Starting Process", binding=binding, request=self.request)
180                if not CURRENT_PROCESS._config.get("daemon"):
181                    task.run()
182                else:
183                    task.start()
184                self.__processes.append(
185                    PolicyProcessInfo(process=task, connection=our_end, binding=binding)
186                )
187            # If all policies are cached, we have an empty list here.
188            for proc_info in self.__processes:
189                if proc_info.process.is_alive():
190                    proc_info.process.join(proc_info.binding.timeout)
191                # Only call .recv() if no result is saved, otherwise we just deadlock here
192                if not proc_info.result:
193                    proc_info.result = proc_info.connection.recv()
194                if proc_info.result and proc_info.result._exec_time:
195                    HIST_POLICIES_EXECUTION_TIME.labels(
196                        binding_order=proc_info.binding.order,
197                        binding_target_type=proc_info.binding.target_type,
198                        binding_target_name=proc_info.binding.target_name,
199                        object_type=(
200                            class_to_path(self.request.obj.__class__) if self.request.obj else ""
201                        ),
202                        mode="execute_process",
203                    ).observe(proc_info.result._exec_time)
204            return self

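Evaluate all bindings of the target object (static, cached, or via a policy process) and collect their results, inside a tracing span and a total-time histogram. Returns the engine itself so calls can be chained.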

206    @property
207    def result(self) -> PolicyResult:
208        """Get policy-checking result"""
209        self.__processes.sort(key=lambda x: x.binding.order)
210        process_results: list[PolicyResult] = [x.result for x in self.__processes if x.result]
211        all_results = list(process_results + self.__cached_policies)
212        if len(all_results) < self.__expected_result_count:  # pragma: no cover
213            raise AssertionError("Got less results than polices")
214        if self.__static_result:
215            all_results.append(self.__static_result)
216        # No results, no policies attached -> passing
217        if len(all_results) == 0:
218            return PolicyResult(self.empty_result)
219        passing = False
220        if self.mode == PolicyEngineMode.MODE_ALL:
221            passing = all(x.passing for x in all_results)
222        if self.mode == PolicyEngineMode.MODE_ANY:
223            passing = any(x.passing for x in all_results)
224        result = PolicyResult(passing)
225        result.source_results = all_results
226        result.messages = tuple(y for x in all_results for y in x.messages)
227        return result

Get policy-checking result

passing: bool
229    @property
230    def passing(self) -> bool:
231        """Only get true/false if user passes"""
232        return self.result.passing

Return only whether the user passes (shorthand for `result.passing`).