authentik.tasks.middleware
"""Dramatiq middleware used by authentik's task workers."""

from collections.abc import Callable
from typing import Any, cast

from django.conf import settings
from django.db import OperationalError
from django_dramatiq_postgres.middleware import (
    CurrentTask as BaseCurrentTask,
)
from django_dramatiq_postgres.middleware import (
    MetricsMiddleware as BaseMetricsMiddleware,
)
from dramatiq.broker import Broker
from dramatiq.message import Message
from dramatiq.middleware import Middleware
from psycopg.errors import Error
from structlog.stdlib import get_logger

from authentik.events.models import Event, EventAction
from authentik.lib.sentry import should_ignore_exception
from authentik.lib.utils.reflection import class_to_path
from authentik.root.signals import post_startup, pre_startup, startup
from authentik.tasks.models import Task, TaskLog, TaskStatus
from authentik.tenants.models import Tenant
from authentik.tenants.utils import get_current_tenant

LOGGER = get_logger()
HEALTHCHECK_LOGGER = get_logger("authentik.worker").bind()
# Database error types: Django's OperationalError plus psycopg's base Error.
DB_ERRORS = (OperationalError, Error)


class StartupSignalsMiddleware(Middleware):
    """Send authentik's startup signals once a worker subprocess has booted."""

    def after_process_boot(self, broker: Broker):
        """Send ``pre_startup``, ``startup`` and ``post_startup``, in that order.

        A freshly created ``WorkerStartup`` class is used as the signal sender.
        """
        _startup_sender = type("WorkerStartup", (object,), {})
        pre_startup.send(sender=_startup_sender)
        startup.send(sender=_startup_sender)
        post_startup.send(sender=_startup_sender)


class CurrentTask(BaseCurrentTask):
    """``BaseCurrentTask`` with ``get_task`` narrowed to authentik's ``Task`` model."""

    @classmethod
    def get_task(cls) -> Task:
        """Return the current task, typed as ``Task`` (a ``cast`` only; no runtime check)."""
        return cast(Task, super().get_task())


class TenantMiddleware(Middleware):
    """Carry the active tenant from the enqueuing context into task execution."""

    def before_enqueue(self, broker: Broker, message: Message, delay: int | None):
        """Store the current tenant in the message's ``model_create_defaults``."""
        message.options["model_create_defaults"]["tenant"] = get_current_tenant()

    def before_process_message(self, broker: Broker, message: Message):
        """Activate the tenant of the task attached to the message."""
        task: Task = message.options["task"]
        task.tenant.activate()

    def after_process_message(self, *args, **kwargs):
        """Deactivate the tenant once the message was processed (or skipped)."""
        Tenant.deactivate()

    # Skipped messages emit after_skip_message instead of after_process_message,
    # but still need the tenant activated in before_process_message torn down.
    after_skip_message = after_process_message


class ModelDataMiddleware(Middleware):
    """Move model-related actor options into the task's ``model_defaults``."""

    @property
    def actor_options(self):
        """Option names registered by this middleware: ``rel_obj`` and ``uid``."""
        return {"rel_obj", "uid"}

    def before_enqueue(self, broker: Broker, message: Message, delay: int | None):
        """Pop ``rel_obj``/``uid`` from the message options into ``model_defaults``.

        ``uid`` is stored under the ``_uid`` key. Absent options are left alone.
        """
        if "rel_obj" in message.options:
            message.options["model_defaults"]["rel_obj"] = message.options.pop("rel_obj")
        if "uid" in message.options:
            message.options["model_defaults"]["_uid"] = message.options.pop("uid")


class TaskLogMiddleware(Middleware):
    """Record task lifecycle steps as ``TaskLog`` entries and report failures."""

    def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
        """Log that the task was queued, or that it will be retried.

        ``message.options["task_created"]`` tells a new task apart from a
        re-enqueue of an existing one. On re-enqueue, the task's existing log
        rows are flagged ``previous=True`` before the retry entry is written.
        """
        task: Task = message.options["task"]
        task_created: bool = message.options["task_created"]
        if task_created:
            event = "Task has been queued"
        else:
            TaskLog.objects.filter(task=task).update(previous=True)
            event = "Task will be retried"
        TaskLog.create_from_log_event(
            task,
            Task._make_log(
                class_to_path(type(self)),
                TaskStatus.INFO,
                event,
                delay=delay,
            ),
        )

    def before_process_message(self, broker: Broker, message: Message):
        """Add an INFO entry saying the task is being processed."""
        task: Task = message.options["task"]
        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task is being processed")

    def after_process_message(
        self,
        broker: Broker,
        message: Message,
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ):
        """Log the task's outcome and report unexpected failures as events.

        On success an INFO entry is written. On failure an ERROR entry carrying
        the exception is written and, unless ``should_ignore_exception`` says to
        skip it, a ``SYSTEM_TASK_EXCEPTION`` event is saved with the actor name
        and, when set, the task's ``rel_obj``.
        """
        task: Task = message.options["task"]
        if exception is None:
            task.log(
                class_to_path(type(self)),
                TaskStatus.INFO,
                "Task finished processing without errors",
            )
            return
        task.log(
            class_to_path(type(self)),
            TaskStatus.ERROR,
            exception,
        )
        if should_ignore_exception(exception):
            return
        event_kwargs = {
            "actor": task.actor_name,
        }
        if task.rel_obj:
            event_kwargs["rel_obj"] = task.rel_obj
        Event.new(
            EventAction.SYSTEM_TASK_EXCEPTION,
            message=f"Task {task.actor_name} encountered an error",
            **event_kwargs,
        ).with_exception(exception).save()

    def after_skip_message(self, broker: Broker, message: Message):
        """Add an INFO entry saying the task was skipped."""
        task: Task = message.options["task"]
        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task has been skipped")


class LoggingMiddleware(Middleware):
    """Emit a structured log line at each stage of a task's lifecycle."""

    def __init__(self):
        super().__init__()
        self.logger = get_logger()

    @staticmethod
    def _task_fields(message: Message) -> dict[str, Any]:
        """Return the log fields that identify the message's task."""
        return {"task_id": message.message_id, "task_name": message.actor_name}

    def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
        """Log that the task was enqueued."""
        self.logger.info("Task enqueued", **self._task_fields(message))

    def before_process_message(self, broker: Broker, message: Message):
        """Log that the task is starting."""
        self.logger.info("Task started", **self._task_fields(message))

    def after_process_message(
        self,
        broker: Broker,
        message: Message,
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ):
        """Log that the task finished; ``exc`` is the raised exception or ``None``."""
        self.logger.info("Task finished", **self._task_fields(message), exc=exception)

    def after_skip_message(self, broker: Broker, message: Message):
        """Log that the task was skipped."""
        self.logger.info("Task skipped", **self._task_fields(message))


class DescriptionMiddleware(Middleware):
    """Register ``description`` as an accepted actor option."""

    @property
    def actor_options(self):
        """Option names registered by this middleware."""
        return {"description"}


class MetricsMiddleware(BaseMetricsMiddleware):
    """Metrics middleware storing prometheus_client values per worker id."""

    @property
    def forks(self) -> list[Callable[[], None]]:
        """Return no forks, so this subclass starts no extra forked processes."""
        return []

    def before_worker_boot(self, broker: Broker, worker: Any) -> None:
        """Outside tests, key prometheus values by ``worker.worker_id``; then defer to the base."""
        if settings.TEST:
            return super().before_worker_boot(broker, worker)

        from prometheus_client import values
        from prometheus_client.values import MultiProcessValue

        # Swap the global value class for a multiprocess one identified by worker_id.
        values.ValueClass = MultiProcessValue(lambda: worker.worker_id)

        return super().before_worker_boot(broker, worker)

    def after_worker_shutdown(self, broker: Broker, worker: Any) -> None:
        """Outside tests, mark this worker's prometheus multiprocess data as dead."""
        # NOTE(review): unlike before_worker_boot this does not call super(); confirm the
        # base class has no shutdown cleanup that is being skipped.
        if settings.TEST:
            return

        from prometheus_client import multiprocess

        multiprocess.mark_process_dead(worker.worker_id)
class StartupSignalsMiddleware(Middleware):
    """Send authentik's startup signals from each freshly booted worker subprocess."""

    def after_process_boot(self, broker: Broker):
        """Send ``pre_startup``, ``startup`` and ``post_startup`` in that order.

        Every signal gets the same newly created ``WorkerStartup`` class as sender.
        """
        sender = type("WorkerStartup", (object,), {})
        for signal in (pre_startup, startup, post_startup):
            signal.send(sender=sender)
Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.
def after_process_boot(self, broker: Broker):
    """Send ``pre_startup``, ``startup`` and ``post_startup`` in that order.

    Every signal gets the same newly created ``WorkerStartup`` class as sender.
    """
    sender = type("WorkerStartup", (object,), {})
    for signal in (pre_startup, startup, post_startup):
        signal.send(sender=sender)
Called immediately after subprocess start up.
class CurrentTask(BaseCurrentTask):
    """``BaseCurrentTask`` whose ``get_task`` is typed as returning authentik's ``Task``."""

    @classmethod
    def get_task(cls) -> Task:
        """Return the base class's current task; ``cast`` only narrows the static type."""
        current = super().get_task()
        return cast(Task, current)
Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.
class TenantMiddleware(Middleware):
    """Propagate the active tenant from enqueue time into task execution."""

    def before_enqueue(self, broker: Broker, message: Message, delay: int | None):
        """Save the current tenant into the message's ``model_create_defaults``.

        ``delay`` is typed ``int | None`` to match ``TaskLogMiddleware.after_enqueue``.
        """
        message.options["model_create_defaults"]["tenant"] = get_current_tenant()

    def before_process_message(self, broker: Broker, message: Message):
        """Make the tenant of the message's task the active one."""
        message.options["task"].tenant.activate()

    def after_process_message(self, *args, **kwargs):
        """Drop the active tenant after processing; the arguments are ignored."""
        Tenant.deactivate()

    # A skipped message emits after_skip_message instead of after_process_message,
    # yet the tenant activated in before_process_message must still be torn down.
    after_skip_message = after_process_message
Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.
def before_enqueue(self, broker: Broker, message: Message, delay: int | None):
    """Save the current tenant into the message's ``model_create_defaults``.

    ``delay`` is typed ``int | None`` to match ``TaskLogMiddleware.after_enqueue``.
    """
    message.options["model_create_defaults"]["tenant"] = get_current_tenant()
Called before a message is enqueued (including retries).
def before_process_message(self, broker: Broker, message: Message):
    """Make the tenant of the message's task the active one."""
    message.options["task"].tenant.activate()
Called before a message is processed.
Raises:
SkipMessage: If the current message should be skipped. When
this is raised, after_skip_message is emitted instead
of after_process_message.
class ModelDataMiddleware(Middleware):
    """Move model-related message options into the task's ``model_defaults``."""

    @property
    def actor_options(self):
        """Option names registered by this middleware: ``rel_obj`` and ``uid``."""
        return {"rel_obj", "uid"}

    def before_enqueue(self, broker: Broker, message: Message, delay: int | None):
        """Pop ``rel_obj``/``uid`` message options into ``model_defaults``.

        ``uid`` is stored under the ``_uid`` key; absent options are left alone.
        ``delay`` is typed ``int | None`` to match ``TaskLogMiddleware.after_enqueue``.
        """
        # (message option name, model_defaults key), applied in this order.
        for option, field in (("rel_obj", "rel_obj"), ("uid", "_uid")):
            if option in message.options:
                message.options["model_defaults"][field] = message.options.pop(option)
Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.
def before_enqueue(self, broker: Broker, message: Message, delay: int | None):
    """Pop ``rel_obj``/``uid`` message options into ``model_defaults``.

    ``uid`` is stored under the ``_uid`` key; absent options are left alone.
    ``delay`` is typed ``int | None`` to match ``TaskLogMiddleware.after_enqueue``.
    """
    # (message option name, model_defaults key), applied in this order.
    for option, field in (("rel_obj", "rel_obj"), ("uid", "_uid")):
        if option in message.options:
            message.options["model_defaults"][field] = message.options.pop(option)
Called before a message is enqueued (including retries).
class TaskLogMiddleware(Middleware):
    """Record task lifecycle steps as ``TaskLog`` entries and report failures."""

    def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
        """Log that the task was queued, or that it will be retried.

        ``message.options["task_created"]`` tells a new task apart from a
        re-enqueue of an existing one. On re-enqueue, the task's existing log
        rows are flagged ``previous=True`` before the retry entry is written.
        """
        task: Task = message.options["task"]
        task_created: bool = message.options["task_created"]
        if task_created:
            event = "Task has been queued"
        else:
            TaskLog.objects.filter(task=task).update(previous=True)
            event = "Task will be retried"
        TaskLog.create_from_log_event(
            task,
            Task._make_log(
                class_to_path(type(self)),
                TaskStatus.INFO,
                event,
                delay=delay,
            ),
        )

    def before_process_message(self, broker: Broker, message: Message):
        """Add an INFO entry saying the task is being processed."""
        task: Task = message.options["task"]
        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task is being processed")

    def after_process_message(
        self,
        broker: Broker,
        message: Message,
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ):
        """Log the task's outcome and report unexpected failures as events.

        On success an INFO entry is written. On failure an ERROR entry carrying
        the exception is written and, unless ``should_ignore_exception`` says to
        skip it, a ``SYSTEM_TASK_EXCEPTION`` event is saved with the actor name
        and, when set, the task's ``rel_obj``.
        """
        task: Task = message.options["task"]
        source = class_to_path(type(self))
        if exception is None:
            task.log(source, TaskStatus.INFO, "Task finished processing without errors")
            return
        task.log(source, TaskStatus.ERROR, exception)
        if should_ignore_exception(exception):
            return
        event_kwargs = {"actor": task.actor_name}
        if task.rel_obj:
            event_kwargs["rel_obj"] = task.rel_obj
        Event.new(
            EventAction.SYSTEM_TASK_EXCEPTION,
            message=f"Task {task.actor_name} encountered an error",
            **event_kwargs,
        ).with_exception(exception).save()

    def after_skip_message(self, broker: Broker, message: Message):
        """Add an INFO entry saying the task was skipped."""
        task: Task = message.options["task"]
        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task has been skipped")
Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.
def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
    """Log that the task was queued, or that it will be retried.

    ``message.options["task_created"]`` tells a new task apart from a
    re-enqueue of an existing one. On re-enqueue, the task's existing log
    rows are flagged ``previous=True`` before the retry entry is written.
    """
    task: Task = message.options["task"]
    task_created: bool = message.options["task_created"]
    if task_created:
        event = "Task has been queued"
    else:
        TaskLog.objects.filter(task=task).update(previous=True)
        event = "Task will be retried"
    # Single creation path for both cases; only the message text differs.
    TaskLog.create_from_log_event(
        task,
        Task._make_log(
            class_to_path(type(self)),
            TaskStatus.INFO,
            event,
            delay=delay,
        ),
    )
Called after a message has been enqueued (including retries).
def before_process_message(self, broker: Broker, message: Message):
    """Append an INFO entry to the task log noting that processing began."""
    task: Task = message.options["task"]
    origin = class_to_path(type(self))
    task.log(origin, TaskStatus.INFO, "Task is being processed")
Called before a message is processed.
Raises:
SkipMessage: If the current message should be skipped. When
this is raised, after_skip_message is emitted instead
of after_process_message.
def after_process_message(
    self,
    broker: Broker,
    message: Message,
    *,
    result: Any | None = None,
    exception: Exception | None = None,
):
    """Log the task's outcome and report unexpected failures as events.

    On success an INFO entry is written. On failure an ERROR entry carrying
    the exception is written and, unless ``should_ignore_exception`` says to
    skip it, a ``SYSTEM_TASK_EXCEPTION`` event is saved with the actor name
    and, when set, the task's ``rel_obj``.
    """
    task: Task = message.options["task"]
    source = class_to_path(type(self))
    if exception is None:
        task.log(source, TaskStatus.INFO, "Task finished processing without errors")
        return
    task.log(source, TaskStatus.ERROR, exception)
    if should_ignore_exception(exception):
        return
    extra = {"actor": task.actor_name}
    if task.rel_obj:
        extra["rel_obj"] = task.rel_obj
    event = Event.new(
        EventAction.SYSTEM_TASK_EXCEPTION,
        message=f"Task {task.actor_name} encountered an error",
        **extra,
    )
    event.with_exception(exception).save()
Called after a message has been processed.
def after_skip_message(self, broker: Broker, message: Message):
    """Append an INFO entry to the task log noting that the task was skipped."""
    task: Task = message.options["task"]
    origin = class_to_path(type(self))
    task.log(origin, TaskStatus.INFO, "Task has been skipped")
Called instead of after_process_message after a message
has been skipped.
class LoggingMiddleware(Middleware):
    """Emit a structured log line at each stage of a task's lifecycle."""

    def __init__(self):
        super().__init__()
        self.logger = get_logger()

    @staticmethod
    def _task_fields(message: Message) -> dict[str, Any]:
        """Return the log fields that identify the message's task."""
        return {"task_id": message.message_id, "task_name": message.actor_name}

    def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
        """Log that the task was enqueued (``delay`` may be ``None``)."""
        self.logger.info("Task enqueued", **self._task_fields(message))

    def before_process_message(self, broker: Broker, message: Message):
        """Log that the task is starting."""
        self.logger.info("Task started", **self._task_fields(message))

    def after_process_message(
        self,
        broker: Broker,
        message: Message,
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ):
        """Log that the task finished; ``exc`` is the raised exception or ``None``."""
        self.logger.info("Task finished", **self._task_fields(message), exc=exception)

    def after_skip_message(self, broker: Broker, message: Message):
        """Log that the task was skipped."""
        self.logger.info("Task skipped", **self._task_fields(message))
Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.
def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
    """Log that the task was enqueued.

    ``delay`` is typed ``int | None`` to match ``TaskLogMiddleware.after_enqueue``;
    it is not included in the log line.
    """
    self.logger.info(
        "Task enqueued",
        task_id=message.message_id,
        task_name=message.actor_name,
    )
Called after a message has been enqueued (including retries).
def before_process_message(self, broker: Broker, message: Message):
    """Log that a worker is starting on the message's task."""
    task_id, task_name = message.message_id, message.actor_name
    self.logger.info("Task started", task_id=task_id, task_name=task_name)
Called before a message is processed.
Raises:
SkipMessage: If the current message should be skipped. When
this is raised, after_skip_message is emitted instead
of after_process_message.
def after_process_message(
    self,
    broker: Broker,
    message: Message,
    *,
    result: Any | None = None,
    exception: Exception | None = None,
):
    """Log that the task finished; ``exc`` holds the raised exception or ``None``."""
    fields = {
        "task_id": message.message_id,
        "task_name": message.actor_name,
        "exc": exception,
    }
    self.logger.info("Task finished", **fields)
Called after a message has been processed.
def after_skip_message(self, broker: Broker, message: Message):
    """Log that the message's task was skipped."""
    task_id, task_name = message.message_id, message.actor_name
    self.logger.info("Task skipped", task_id=task_id, task_name=task_name)
Called instead of after_process_message after a message
has been skipped.
class DescriptionMiddleware(Middleware):
    """Declare ``description`` as an accepted actor option."""

    @property
    def actor_options(self):
        """Return the option names registered by this middleware."""
        supported = {"description"}
        return supported
Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.
class MetricsMiddleware(BaseMetricsMiddleware):
    """Metrics middleware that stores prometheus_client values per worker id."""

    @property
    def forks(self) -> list[Callable[[], None]]:
        """Declare no forked helper processes for this subclass."""
        no_forks: list[Callable[[], None]] = []
        return no_forks

    def before_worker_boot(self, broker: Broker, worker: Any) -> None:
        """Outside tests, key prometheus values by ``worker.worker_id``; then defer to the base."""
        if not settings.TEST:
            from prometheus_client import values
            from prometheus_client.values import MultiProcessValue

            # Global value class becomes a multiprocess one identified by worker_id.
            values.ValueClass = MultiProcessValue(lambda: worker.worker_id)
        return super().before_worker_boot(broker, worker)

    def after_worker_shutdown(self, broker: Broker, worker: Any) -> None:
        """Outside tests, mark this worker's prometheus multiprocess data as dead."""
        # NOTE(review): unlike before_worker_boot this never calls super(); confirm the
        # base class has no shutdown cleanup that is being skipped.
        if not settings.TEST:
            from prometheus_client import multiprocess

            multiprocess.mark_process_dead(worker.worker_id)
Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.
def before_worker_boot(self, broker: Broker, worker: Any) -> None:
    """Outside tests, key prometheus values by ``worker.worker_id``; then defer to the base."""
    if not settings.TEST:
        from prometheus_client import values
        from prometheus_client.values import MultiProcessValue

        # Global value class becomes a multiprocess one identified by worker_id.
        values.ValueClass = MultiProcessValue(lambda: worker.worker_id)
    return super().before_worker_boot(broker, worker)
Called before the worker process starts up.
def after_worker_shutdown(self, broker: Broker, worker: Any) -> None:
    """Outside tests, mark this worker's prometheus multiprocess data as dead."""
    # NOTE(review): this never calls super(); confirm the base class has no
    # shutdown cleanup that is being skipped.
    if not settings.TEST:
        from prometheus_client import multiprocess

        multiprocess.mark_process_dead(worker.worker_id)
Called after the worker process shuts down.