authentik.enterprise.audit.middleware
Enterprise audit middleware
"""Enterprise audit middleware"""

from copy import deepcopy
from functools import partial
from typing import Any

from django.apps.registry import apps
from django.core.files import File
from django.db import connection
from django.db.models import ManyToManyRel, Model
from django.db.models.expressions import BaseExpression, Combinable
from django.db.models.signals import post_init
from django.http import HttpRequest

from authentik.enterprise.audit.apps import AuditIncludeExpandedDiff
from authentik.events.middleware import AuditMiddleware, should_log_model
from authentik.events.utils import cleanse_dict, sanitize_item


class EnterpriseAuditMiddleware(AuditMiddleware):
    """Enterprise audit middleware

    On top of the base audit middleware this stores a serialized snapshot on every
    loggable model instance when it is initialised, and adds a `diff` entry to the
    kwargs handed to the parent's save / m2m handlers."""

    @property
    def enabled(self):
        """Check if audit logging is enabled"""
        return apps.get_app_config("authentik_enterprise").enabled()

    def connect(self, request: HttpRequest):
        """Connect signal for automatic logging

        In addition to the parent's signals, `post_init` is connected for requests that
        carry a `request_id`, which is used as the dispatch uid."""
        super().connect(request)
        if not self.enabled:
            return
        if not hasattr(request, "request_id"):
            return
        post_init.connect(
            partial(self.post_init_handler, request=request),
            dispatch_uid=request.request_id,
            # The partial is a temporary object, so the signal has to hold a strong reference
            weak=False,
        )

    def disconnect(self, request: HttpRequest):
        """Disconnect signals

        Mirrors `connect`: removes the `post_init` receiver registered under the request's
        `request_id`."""
        super().disconnect(request)
        if not self.enabled:
            return
        if not hasattr(request, "request_id"):
            return
        post_init.disconnect(dispatch_uid=request.request_id)

    def serialize_simple(self, model: Model) -> dict:
        """Serialize a model in a very simple way. No ForeignKeys or other relationships are
        resolved

        Deferred fields and fields whose current value is an unevaluated expression are
        left out. File values are represented by their name. The result is passed through
        `cleanse_dict` before being returned."""
        data = {}
        deferred_fields = model.get_deferred_fields()
        for field in model._meta.concrete_fields:
            # Deferred fields are skipped rather than read
            if field.get_attname() in deferred_fields:
                continue

            field_value = getattr(model, field.attname)
            # Check the value that was actually read from the model; files are stored
            # by name only
            if isinstance(field_value, File):
                field_value = field_value.name

            # If current field value is an expression, we are not evaluating it
            if isinstance(field_value, BaseExpression | Combinable):
                continue
            field_value = field.to_python(field_value)
            # Copy so later in-place changes on the instance do not alter the snapshot
            data[field.name] = deepcopy(field_value)
        return cleanse_dict(data)

    def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
        """Generate diff between dicts

        Each changed key maps to a dict with `previous_value` and `new_value`. When
        `update_fields` is non-empty, only keys listed in it are considered. The result is
        passed through `sanitize_item`."""
        diff = {}
        # Keys known before: changed or removed values
        for key, value in before.items():
            if update_fields and key not in update_fields:
                continue
            if after.get(key) != value:
                diff[key] = {"previous_value": value, "new_value": after.get(key)}
        # Keys that only exist afterwards (new keys whose value equals None are not recorded)
        for key, value in after.items():
            if update_fields and key not in update_fields:
                continue
            if key not in before and key not in diff and before.get(key) != value:
                diff[key] = {"previous_value": before.get(key), "new_value": value}
        return sanitize_item(diff)

    def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
        """post_init django model handler

        Stores the serialized state of the instance as `instance._previous_state`, unless
        the model is not logged or a snapshot already exists.

        Raises AssertionError if the length of `connection.queries` grew while
        serializing."""
        if not should_log_model(instance):
            return
        if hasattr(instance, "_previous_state"):
            return
        before = len(connection.queries)
        instance._previous_state = self.serialize_simple(instance)
        after = len(connection.queries)
        if after > before:
            raise AssertionError("More queries generated by serialize_simple")

    def post_save_handler(
        self,
        request: HttpRequest,
        sender,
        instance: Model,
        created: bool,
        thread_kwargs: dict | None = None,
        update_fields: list[str] | None = None,
        **_,
    ):
        """Signal handler for all object's post_save

        When enabled, adds a `diff` between the snapshot taken at init (empty for newly
        created objects) and the current state to `thread_kwargs` before delegating to the
        parent handler."""
        if not self.enabled:
            return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
        if not should_log_model(instance):
            return None
        # Keep whatever the caller passed in, but never mutate their dict
        thread_kwargs = dict(thread_kwargs or {})
        if created or hasattr(instance, "_previous_state"):
            prev_state = {} if created else getattr(instance, "_previous_state", {})
            # Get current state
            new_state = self.serialize_simple(instance)
            thread_kwargs["diff"] = self.diff(prev_state, new_state, update_fields)
        return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)

    def m2m_changed_handler(  # noqa: PLR0913
        self,
        request: HttpRequest,
        sender,
        instance: Model,
        action: str,
        pk_set: set[Any],
        thread_kwargs: dict | None = None,
        **_,
    ):
        """Signal handler for all object's m2m_changed

        When enabled and `sender` is the through model of one of the instance's
        `ManyToManyRel` fields, adds a `diff` of the form
        `{related_name: {direction: value}}` to `thread_kwargs`, where `value` is `True`
        for a clear, the serialized related objects if `AuditIncludeExpandedDiff` is set,
        or `pk_set` otherwise."""
        # Keep whatever the caller passed in, but never mutate their dict
        thread_kwargs = dict(thread_kwargs or {})
        if not self.enabled:
            return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
        # For the audit log we don't care about `pre_` or `post_` so we trim that part off
        action_direction = action.partition("_")[2]
        # resolve the "through" model to an actual field (the last matching field wins)
        m2m_field = None
        for field in instance._meta.get_fields():
            if not isinstance(field, ManyToManyRel):
                continue
            if field.through == sender:
                m2m_field = field
        if m2m_field:
            changed: Any = pk_set
            # If we're clearing we just set the "flag" to True
            if action_direction == "clear":
                changed = True
            elif AuditIncludeExpandedDiff.get():
                related_model: type[Model] = m2m_field.related_model
                related_objects = related_model.objects.filter(pk__in=pk_set)
                changed = [self.serialize_simple(related) for related in related_objects]
            thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: changed}}
        return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
class EnterpriseAuditMiddleware(AuditMiddleware):
    """Enterprise audit middleware

    On top of the base audit middleware this stores a serialized snapshot on every
    loggable model instance when it is initialised, and adds a `diff` entry to the
    kwargs handed to the parent's save / m2m handlers."""

    @property
    def enabled(self):
        """Check if audit logging is enabled"""
        return apps.get_app_config("authentik_enterprise").enabled()

    def connect(self, request: HttpRequest):
        """Connect signal for automatic logging

        In addition to the parent's signals, `post_init` is connected for requests that
        carry a `request_id`, which is used as the dispatch uid."""
        super().connect(request)
        if not self.enabled:
            return
        if not hasattr(request, "request_id"):
            return
        post_init.connect(
            partial(self.post_init_handler, request=request),
            dispatch_uid=request.request_id,
            # The partial is a temporary object, so the signal has to hold a strong reference
            weak=False,
        )

    def disconnect(self, request: HttpRequest):
        """Disconnect signals

        Mirrors `connect`: removes the `post_init` receiver registered under the request's
        `request_id`."""
        super().disconnect(request)
        if not self.enabled:
            return
        if not hasattr(request, "request_id"):
            return
        post_init.disconnect(dispatch_uid=request.request_id)

    def serialize_simple(self, model: Model) -> dict:
        """Serialize a model in a very simple way. No ForeignKeys or other relationships are
        resolved

        Deferred fields and fields whose current value is an unevaluated expression are
        left out. File values are represented by their name. The result is passed through
        `cleanse_dict` before being returned."""
        data = {}
        deferred_fields = model.get_deferred_fields()
        for field in model._meta.concrete_fields:
            # Deferred fields are skipped rather than read
            if field.get_attname() in deferred_fields:
                continue

            field_value = getattr(model, field.attname)
            # Check the value that was actually read from the model; files are stored
            # by name only
            if isinstance(field_value, File):
                field_value = field_value.name

            # If current field value is an expression, we are not evaluating it
            if isinstance(field_value, BaseExpression | Combinable):
                continue
            field_value = field.to_python(field_value)
            # Copy so later in-place changes on the instance do not alter the snapshot
            data[field.name] = deepcopy(field_value)
        return cleanse_dict(data)

    def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
        """Generate diff between dicts

        Each changed key maps to a dict with `previous_value` and `new_value`. When
        `update_fields` is non-empty, only keys listed in it are considered. The result is
        passed through `sanitize_item`."""
        diff = {}
        # Keys known before: changed or removed values
        for key, value in before.items():
            if update_fields and key not in update_fields:
                continue
            if after.get(key) != value:
                diff[key] = {"previous_value": value, "new_value": after.get(key)}
        # Keys that only exist afterwards (new keys whose value equals None are not recorded)
        for key, value in after.items():
            if update_fields and key not in update_fields:
                continue
            if key not in before and key not in diff and before.get(key) != value:
                diff[key] = {"previous_value": before.get(key), "new_value": value}
        return sanitize_item(diff)

    def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
        """post_init django model handler

        Stores the serialized state of the instance as `instance._previous_state`, unless
        the model is not logged or a snapshot already exists.

        Raises AssertionError if the length of `connection.queries` grew while
        serializing."""
        if not should_log_model(instance):
            return
        if hasattr(instance, "_previous_state"):
            return
        before = len(connection.queries)
        instance._previous_state = self.serialize_simple(instance)
        after = len(connection.queries)
        if after > before:
            raise AssertionError("More queries generated by serialize_simple")

    def post_save_handler(
        self,
        request: HttpRequest,
        sender,
        instance: Model,
        created: bool,
        thread_kwargs: dict | None = None,
        update_fields: list[str] | None = None,
        **_,
    ):
        """Signal handler for all object's post_save

        When enabled, adds a `diff` between the snapshot taken at init (empty for newly
        created objects) and the current state to `thread_kwargs` before delegating to the
        parent handler."""
        if not self.enabled:
            return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
        if not should_log_model(instance):
            return None
        # Keep whatever the caller passed in, but never mutate their dict
        thread_kwargs = dict(thread_kwargs or {})
        if created or hasattr(instance, "_previous_state"):
            prev_state = {} if created else getattr(instance, "_previous_state", {})
            # Get current state
            new_state = self.serialize_simple(instance)
            thread_kwargs["diff"] = self.diff(prev_state, new_state, update_fields)
        return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)

    def m2m_changed_handler(  # noqa: PLR0913
        self,
        request: HttpRequest,
        sender,
        instance: Model,
        action: str,
        pk_set: set[Any],
        thread_kwargs: dict | None = None,
        **_,
    ):
        """Signal handler for all object's m2m_changed

        When enabled and `sender` is the through model of one of the instance's
        `ManyToManyRel` fields, adds a `diff` of the form
        `{related_name: {direction: value}}` to `thread_kwargs`, where `value` is `True`
        for a clear, the serialized related objects if `AuditIncludeExpandedDiff` is set,
        or `pk_set` otherwise."""
        # Keep whatever the caller passed in, but never mutate their dict
        thread_kwargs = dict(thread_kwargs or {})
        if not self.enabled:
            return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
        # For the audit log we don't care about `pre_` or `post_` so we trim that part off
        action_direction = action.partition("_")[2]
        # resolve the "through" model to an actual field (the last matching field wins)
        m2m_field = None
        for field in instance._meta.get_fields():
            if not isinstance(field, ManyToManyRel):
                continue
            if field.through == sender:
                m2m_field = field
        if m2m_field:
            changed: Any = pk_set
            # If we're clearing we just set the "flag" to True
            if action_direction == "clear":
                changed = True
            elif AuditIncludeExpandedDiff.get():
                related_model: type[Model] = m2m_field.related_model
                related_objects = related_model.objects.filter(pk__in=pk_set)
                changed = [self.serialize_simple(related) for related in related_objects]
            thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: changed}}
        return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
Enterprise audit middleware
enabled
@property
def enabled(self):
    """Report whether the `authentik_enterprise` app config says it is enabled"""
    enterprise_config = apps.get_app_config("authentik_enterprise")
    return enterprise_config.enabled()
Check if audit logging is enabled
def
connect(self, request: django.http.request.HttpRequest):
def connect(self, request: HttpRequest):
    """Connect the parent's signals, then hook `post_init` for this request.

    The extra receiver is only registered while the middleware is enabled and the
    request carries a `request_id`, which doubles as the dispatch uid."""
    super().connect(request)
    if self.enabled and hasattr(request, "request_id"):
        receiver = partial(self.post_init_handler, request=request)
        post_init.connect(receiver, dispatch_uid=request.request_id, weak=False)
Connect signal for automatic logging
def
disconnect(self, request: django.http.request.HttpRequest):
def disconnect(self, request: HttpRequest):
    """Disconnect the parent's signals, then drop this request's `post_init` receiver.

    Nothing extra is removed when the middleware is disabled or the request has no
    `request_id`."""
    super().disconnect(request)
    if self.enabled and hasattr(request, "request_id"):
        post_init.disconnect(dispatch_uid=request.request_id)
Disconnect signals
def
serialize_simple(self, model: django.db.models.base.Model) -> dict:
def serialize_simple(self, model: Model) -> dict:
    """Serialize a model in a very simple way. No ForeignKeys or other relationships are
    resolved

    Deferred fields and fields whose current value is an unevaluated expression are
    left out. File values are represented by their name. The result is passed through
    `cleanse_dict` before being returned."""
    data = {}
    deferred_fields = model.get_deferred_fields()
    for field in model._meta.concrete_fields:
        # Deferred fields are skipped rather than read
        if field.get_attname() in deferred_fields:
            continue

        field_value = getattr(model, field.attname)
        # Check the value that was actually read from the model (the previous check
        # looked at a local that was always None); files are stored by name only
        if isinstance(field_value, File):
            field_value = field_value.name

        # If current field value is an expression, we are not evaluating it
        if isinstance(field_value, BaseExpression | Combinable):
            continue
        field_value = field.to_python(field_value)
        # Copy so later in-place changes on the instance do not alter the snapshot
        data[field.name] = deepcopy(field_value)
    return cleanse_dict(data)
Serialize a model in a very simple way. No ForeignKeys or other relationships are resolved
def
diff( self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
def diff(self, before: dict, after: dict, update_fields: list[str] | None = None) -> dict:
    """Build a per-key change record between two serialized states.

    Every differing key maps to `{"previous_value": ..., "new_value": ...}`. A non-empty
    `update_fields` restricts the comparison to the keys it lists. The collected
    changes are run through `sanitize_item` before being returned."""

    def is_tracked(name) -> bool:
        # An empty / missing update_fields means every key is compared
        return not update_fields or name in update_fields

    changes = {}
    # First pass: keys present in the old state
    for name, old in before.items():
        if not is_tracked(name):
            continue
        new = after.get(name)
        if new != old:
            changes[name] = {"previous_value": old, "new_value": new}
    # Second pass: keys that only exist in the new state
    for name, new in after.items():
        if not is_tracked(name):
            continue
        if name in before or name in changes:
            continue
        old = before.get(name)
        if old != new:
            changes[name] = {"previous_value": old, "new_value": new}
    return sanitize_item(changes)
Generate diff between dicts
def
post_init_handler( self, request: django.http.request.HttpRequest, sender, instance: django.db.models.base.Model, **_):
def post_init_handler(self, request: HttpRequest, sender, instance: Model, **_):
    """Snapshot a freshly initialised model instance for later diffing.

    The serialized state is stored as `instance._previous_state`. Instances that are
    not logged, or that already carry a snapshot, are left untouched.

    Raises AssertionError if `connection.queries` grew while serializing."""
    if not should_log_model(instance) or hasattr(instance, "_previous_state"):
        return
    queries_before = len(connection.queries)
    instance._previous_state = self.serialize_simple(instance)
    if len(connection.queries) > queries_before:
        raise AssertionError("More queries generated by serialize_simple")
post_init django model handler
def
post_save_handler( self, request: django.http.request.HttpRequest, sender, instance: django.db.models.base.Model, created: bool, thread_kwargs: dict | None = None, update_fields: list[str] | None = None, **_):
def post_save_handler(
    self,
    request: HttpRequest,
    sender,
    instance: Model,
    created: bool,
    thread_kwargs: dict | None = None,
    update_fields: list[str] | None = None,
    **_,
):
    """Signal handler for all object's post_save

    When disabled, the call is handed to the parent handler unchanged. Otherwise a
    `diff` between the snapshot taken at init (empty for newly created objects) and
    the current state is added to `thread_kwargs` before delegating to the parent
    handler. Returns None without delegating when the model is not logged."""
    if not self.enabled:
        return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
    if not should_log_model(instance):
        return None
    # Keep whatever the caller passed in (previously discarded), without mutating their dict
    thread_kwargs = dict(thread_kwargs or {})
    if created or hasattr(instance, "_previous_state"):
        prev_state = {} if created else getattr(instance, "_previous_state", {})
        # Get current state
        new_state = self.serialize_simple(instance)
        thread_kwargs["diff"] = self.diff(prev_state, new_state, update_fields)
    return super().post_save_handler(request, sender, instance, created, thread_kwargs, **_)
Signal handler for all object's post_save
def
m2m_changed_handler( self, request: django.http.request.HttpRequest, sender, instance: django.db.models.base.Model, action: str, pk_set: set[typing.Any], thread_kwargs: dict | None = None, **_):
def m2m_changed_handler(  # noqa: PLR0913
    self,
    request: HttpRequest,
    sender,
    instance: Model,
    action: str,
    pk_set: set[Any],
    thread_kwargs: dict | None = None,
    **_,
):
    """Signal handler for all object's m2m_changed

    When enabled and `sender` is the through model of one of the instance's
    `ManyToManyRel` fields, adds a `diff` of the form
    `{related_name: {direction: value}}` to `thread_kwargs`, where `value` is `True`
    for a clear, the serialized related objects if `AuditIncludeExpandedDiff` is set,
    or `pk_set` otherwise. The parent handler is always called last."""
    # Keep whatever the caller passed in (previously discarded), without mutating their dict
    thread_kwargs = dict(thread_kwargs or {})
    if not self.enabled:
        return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
    # For the audit log we don't care about `pre_` or `post_` so we trim that part off
    action_direction = action.partition("_")[2]
    # resolve the "through" model to an actual field (the last matching field wins)
    m2m_field = None
    for field in instance._meta.get_fields():
        if not isinstance(field, ManyToManyRel):
            continue
        if field.through == sender:
            m2m_field = field
    if m2m_field:
        # Use a separate local so the `pk_set` parameter keeps its declared type
        changed: Any = pk_set
        # If we're clearing we just set the "flag" to True
        if action_direction == "clear":
            changed = True
        elif AuditIncludeExpandedDiff.get():
            related_model: type[Model] = m2m_field.related_model
            related_objects = related_model.objects.filter(pk__in=pk_set)
            changed = [self.serialize_simple(related) for related in related_objects]
        thread_kwargs["diff"] = {m2m_field.related_name: {action_direction: changed}}
    return super().m2m_changed_handler(request, sender, instance, action, thread_kwargs)
Signal handler for all object's m2m_changed