authentik.enterprise.audit.middleware

Enterprise audit middleware

  1"""Enterprise audit middleware"""
  2
  3from copy import deepcopy
  4from functools import partial
  5from typing import Any
  6
  7from django.apps.registry import apps
  8from django.core.files import File
  9from django.db import connection
 10from django.db.models import ManyToManyRel, Model
 11from django.db.models.expressions import BaseExpression, Combinable
 12from django.db.models.signals import post_init
 13from django.http import HttpRequest
 14
 15from authentik.enterprise.audit.apps import AuditIncludeExpandedDiff
 16from authentik.events.middleware import AuditMiddleware, should_log_model
 17from authentik.events.utils import cleanse_dict, sanitize_item
 18
 19
 20class EnterpriseAuditMiddleware(AuditMiddleware):
 21    """Enterprise audit middleware"""
 22
 23    @property
 24    def enabled(self):
 25        """Check if audit logging is enabled"""
 26        return apps.get_app_config("authentik_enterprise").enabled()
 27
 28    def connect(self, request: HttpRequest):
 29        super().connect(request)
 30        if not self.enabled:
 31            return
 32        if not hasattr(request, "request_id"):
 33            return
 34        post_init.connect(
 35            partial(self.post_init_handler, request=request),
 36            dispatch_uid=request.request_id,
 37            weak=False,
 38        )
 39
 40    def disconnect(self, request: HttpRequest):
 41        super().disconnect(request)
 42        if not self.enabled:
 43            return
 44        if not hasattr(request, "request_id"):
 45            return
 46        post_init.disconnect(dispatch_uid=request.request_id)
 47
 48    def serialize_simple(self, model: Model) -> dict:
 49        """Serialize a model in a very simple way. No ForeignKeys or other relationships are
 50        resolved"""
 51        data = {}
 52        deferred_fields = model.get_deferred_fields()
 53        for field in model._meta.concrete_fields:
 54            value = None
 55            if field.get_attname() in deferred_fields:
 56                continue
 57
 58            field_value = getattr(model, field.attname)
 59            if isinstance(value, File):
 60                field_value = value.name
 61
 62            # If current field value is an expression, we are not evaluating it
 63            if isinstance(field_value, BaseExpression | Combinable):
 64                continue
 65            field_value = field.to_python(field_value)
 66            data[field.name] = deepcopy(field_value)
 67        return cleanse_dict(data)
 68
 69    def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
 70        """Generate diff between dicts"""
 71        diff = {}
 72        for key, value in before.items():
 73            if update_fields and key not in update_fields:
 74                continue
 75            if after.get(key) != value:
 76                diff[key] = {"previous_value": value, "new_value": after.get(key)}
 77        for key, value in after.items():
 78            if update_fields and key not in update_fields:
 79                continue
 80            if key not in before and key not in diff and before.get(key) != value:
 81                diff[key] = {"previous_value": before.get(key), "new_value": value}
 82        return sanitize_item(diff)
 83
 84    def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
 85        """post_init django model handler"""
 86        if not should_log_model(instance):
 87            return
 88        if hasattr(instance, "_previous_state"):
 89            return
 90        before = len(connection.queries)
 91        instance._previous_state = self.serialize_simple(instance)
 92        after = len(connection.queries)
 93        if after > before:
 94            raise AssertionError("More queries generated by serialize_simple")
 95
 96    def post_save_handler(
 97        self,
 98        request: HttpRequest,
 99        sender,
100        instance: Model,
101        created: bool,
102        thread_kwargs: dict | None = None,
103        update_fields: list[str] | None = None,
104        **_,
105    ):
106        if not self.enabled:
107            return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
108        if not should_log_model(instance):
109            return None
110        thread_kwargs = {}
111        if hasattr(instance, "_previous_state") or created:
112            prev_state = getattr(instance, "_previous_state", {})
113            if created:
114                prev_state = {}
115            # Get current state
116            new_state = self.serialize_simple(instance)
117            diff = self.diff(prev_state, new_state, update_fields)
118            thread_kwargs["diff"] = diff
119        return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
120
121    def m2m_changed_handler(  # noqa: PLR0913
122        self,
123        request: HttpRequest,
124        sender,
125        instance: Model,
126        action: str,
127        pk_set: set[Any],
128        thread_kwargs: dict | None = None,
129        **_,
130    ):
131        thread_kwargs = {}
132        m2m_field = None
133        if not self.enabled:
134            return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
135        # For the audit log we don't care about `pre_` or `post_` so we trim that part off
136        _, _, action_direction = action.partition("_")
137        # resolve the "through" model to an actual field
138        for field in instance._meta.get_fields():
139            if not isinstance(field, ManyToManyRel):
140                continue
141            if field.through == sender:
142                m2m_field = field
143        if m2m_field:
144            # If we're clearing we just set the "flag" to True
145            if action_direction == "clear":
146                pk_set = True
147            elif AuditIncludeExpandedDiff.get():
148                related_model: type[Model] = m2m_field.related_model
149                instances = related_model.objects.filter(pk__in=pk_set)
150                pk_set = [self.serialize_simple(instance) for instance in instances]
151            thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: pk_set}}
152        return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
class EnterpriseAuditMiddleware(authentik.events.middleware.AuditMiddleware):
 21class EnterpriseAuditMiddleware(AuditMiddleware):
 22    """Enterprise audit middleware"""
 23
 24    @property
 25    def enabled(self):
 26        """Check if audit logging is enabled"""
 27        return apps.get_app_config("authentik_enterprise").enabled()
 28
 29    def connect(self, request: HttpRequest):
 30        super().connect(request)
 31        if not self.enabled:
 32            return
 33        if not hasattr(request, "request_id"):
 34            return
 35        post_init.connect(
 36            partial(self.post_init_handler, request=request),
 37            dispatch_uid=request.request_id,
 38            weak=False,
 39        )
 40
 41    def disconnect(self, request: HttpRequest):
 42        super().disconnect(request)
 43        if not self.enabled:
 44            return
 45        if not hasattr(request, "request_id"):
 46            return
 47        post_init.disconnect(dispatch_uid=request.request_id)
 48
 49    def serialize_simple(self, model: Model) -> dict:
 50        """Serialize a model in a very simple way. No ForeignKeys or other relationships are
 51        resolved"""
 52        data = {}
 53        deferred_fields = model.get_deferred_fields()
 54        for field in model._meta.concrete_fields:
 55            value = None
 56            if field.get_attname() in deferred_fields:
 57                continue
 58
 59            field_value = getattr(model, field.attname)
 60            if isinstance(value, File):
 61                field_value = value.name
 62
 63            # If current field value is an expression, we are not evaluating it
 64            if isinstance(field_value, BaseExpression | Combinable):
 65                continue
 66            field_value = field.to_python(field_value)
 67            data[field.name] = deepcopy(field_value)
 68        return cleanse_dict(data)
 69
 70    def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
 71        """Generate diff between dicts"""
 72        diff = {}
 73        for key, value in before.items():
 74            if update_fields and key not in update_fields:
 75                continue
 76            if after.get(key) != value:
 77                diff[key] = {"previous_value": value, "new_value": after.get(key)}
 78        for key, value in after.items():
 79            if update_fields and key not in update_fields:
 80                continue
 81            if key not in before and key not in diff and before.get(key) != value:
 82                diff[key] = {"previous_value": before.get(key), "new_value": value}
 83        return sanitize_item(diff)
 84
 85    def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
 86        """post_init django model handler"""
 87        if not should_log_model(instance):
 88            return
 89        if hasattr(instance, "_previous_state"):
 90            return
 91        before = len(connection.queries)
 92        instance._previous_state = self.serialize_simple(instance)
 93        after = len(connection.queries)
 94        if after > before:
 95            raise AssertionError("More queries generated by serialize_simple")
 96
 97    def post_save_handler(
 98        self,
 99        request: HttpRequest,
100        sender,
101        instance: Model,
102        created: bool,
103        thread_kwargs: dict | None = None,
104        update_fields: list[str] | None = None,
105        **_,
106    ):
107        if not self.enabled:
108            return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
109        if not should_log_model(instance):
110            return None
111        thread_kwargs = {}
112        if hasattr(instance, "_previous_state") or created:
113            prev_state = getattr(instance, "_previous_state", {})
114            if created:
115                prev_state = {}
116            # Get current state
117            new_state = self.serialize_simple(instance)
118            diff = self.diff(prev_state, new_state, update_fields)
119            thread_kwargs["diff"] = diff
120        return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
121
122    def m2m_changed_handler(  # noqa: PLR0913
123        self,
124        request: HttpRequest,
125        sender,
126        instance: Model,
127        action: str,
128        pk_set: set[Any],
129        thread_kwargs: dict | None = None,
130        **_,
131    ):
132        thread_kwargs = {}
133        m2m_field = None
134        if not self.enabled:
135            return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
136        # For the audit log we don't care about `pre_` or `post_` so we trim that part off
137        _, _, action_direction = action.partition("_")
138        # resolve the "through" model to an actual field
139        for field in instance._meta.get_fields():
140            if not isinstance(field, ManyToManyRel):
141                continue
142            if field.through == sender:
143                m2m_field = field
144        if m2m_field:
145            # If we're clearing we just set the "flag" to True
146            if action_direction == "clear":
147                pk_set = True
148            elif AuditIncludeExpandedDiff.get():
149                related_model: type[Model] = m2m_field.related_model
150                instances = related_model.objects.filter(pk__in=pk_set)
151                pk_set = [self.serialize_simple(instance) for instance in instances]
152            thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: pk_set}}
153        return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)

Enterprise audit middleware

enabled
24    @property
25    def enabled(self):
26        """Check if audit logging is enabled"""
27        return apps.get_app_config("authentik_enterprise").enabled()

Check if audit logging is enabled

def connect(self, request: django.http.request.HttpRequest):
29    def connect(self, request: HttpRequest):
30        super().connect(request)
31        if not self.enabled:
32            return
33        if not hasattr(request, "request_id"):
34            return
35        post_init.connect(
36            partial(self.post_init_handler, request=request),
37            dispatch_uid=request.request_id,
38            weak=False,
39        )

Connect signals for automatic logging

def disconnect(self, request: django.http.request.HttpRequest):
41    def disconnect(self, request: HttpRequest):
42        super().disconnect(request)
43        if not self.enabled:
44            return
45        if not hasattr(request, "request_id"):
46            return
47        post_init.disconnect(dispatch_uid=request.request_id)

Disconnect signals

def serialize_simple(self, model: django.db.models.base.Model) -> dict:
49    def serialize_simple(self, model: Model) -> dict:
50        """Serialize a model in a very simple way. No ForeignKeys or other relationships are
51        resolved"""
52        data = {}
53        deferred_fields = model.get_deferred_fields()
54        for field in model._meta.concrete_fields:
55            value = None
56            if field.get_attname() in deferred_fields:
57                continue
58
59            field_value = getattr(model, field.attname)
60            if isinstance(value, File):
61                field_value = value.name
62
63            # If current field value is an expression, we are not evaluating it
64            if isinstance(field_value, BaseExpression | Combinable):
65                continue
66            field_value = field.to_python(field_value)
67            data[field.name] = deepcopy(field_value)
68        return cleanse_dict(data)

Serialize a model in a very simple way. No ForeignKeys or other relationships are resolved

def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
70    def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
71        """Generate diff between dicts"""
72        diff = {}
73        for key, value in before.items():
74            if update_fields and key not in update_fields:
75                continue
76            if after.get(key) != value:
77                diff[key] = {"previous_value": value, "new_value": after.get(key)}
78        for key, value in after.items():
79            if update_fields and key not in update_fields:
80                continue
81            if key not in before and key not in diff and before.get(key) != value:
82                diff[key] = {"previous_value": before.get(key), "new_value": value}
83        return sanitize_item(diff)

Generate a diff between two dicts

def post_init_handler(self, request: django.http.request.HttpRequest, sender, instance: django.db.models.base.Model, **_):
85    def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
86        """post_init django model handler"""
87        if not should_log_model(instance):
88            return
89        if hasattr(instance, "_previous_state"):
90            return
91        before = len(connection.queries)
92        instance._previous_state = self.serialize_simple(instance)
93        after = len(connection.queries)
94        if after > before:
95            raise AssertionError("More queries generated by serialize_simple")

Django post_init model signal handler

def post_save_handler(self, request: django.http.request.HttpRequest, sender, instance: django.db.models.base.Model, created: bool, thread_kwargs: dict | None = None, update_fields: list[str] | None = None, **_):
 97    def post_save_handler(
 98        self,
 99        request: HttpRequest,
100        sender,
101        instance: Model,
102        created: bool,
103        thread_kwargs: dict | None = None,
104        update_fields: list[str] | None = None,
105        **_,
106    ):
107        if not self.enabled:
108            return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
109        if not should_log_model(instance):
110            return None
111        thread_kwargs = {}
112        if hasattr(instance, "_previous_state") or created:
113            prev_state = getattr(instance, "_previous_state", {})
114            if created:
115                prev_state = {}
116            # Get current state
117            new_state = self.serialize_simple(instance)
118            diff = self.diff(prev_state, new_state, update_fields)
119            thread_kwargs["diff"] = diff
120        return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)

Signal handler for every object's post_save

def m2m_changed_handler(self, request: django.http.request.HttpRequest, sender, instance: django.db.models.base.Model, action: str, pk_set: set[typing.Any], thread_kwargs: dict | None = None, **_):
122    def m2m_changed_handler(  # noqa: PLR0913
123        self,
124        request: HttpRequest,
125        sender,
126        instance: Model,
127        action: str,
128        pk_set: set[Any],
129        thread_kwargs: dict | None = None,
130        **_,
131    ):
132        thread_kwargs = {}
133        m2m_field = None
134        if not self.enabled:
135            return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
136        # For the audit log we don't care about `pre_` or `post_` so we trim that part off
137        _, _, action_direction = action.partition("_")
138        # resolve the "through" model to an actual field
139        for field in instance._meta.get_fields():
140            if not isinstance(field, ManyToManyRel):
141                continue
142            if field.through == sender:
143                m2m_field = field
144        if m2m_field:
145            # If we're clearing we just set the "flag" to True
146            if action_direction == "clear":
147                pk_set = True
148            elif AuditIncludeExpandedDiff.get():
149                related_model: type[Model] = m2m_field.related_model
150                instances = related_model.objects.filter(pk__in=pk_set)
151                pk_set = [self.serialize_simple(instance) for instance in instances]
152            thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: pk_set}}
153        return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)

Signal handler for every object's m2m_changed