authentik.tasks.middleware

  1from collections.abc import Callable
  2from typing import Any, cast
  3
  4from django.conf import settings
  5from django.db import OperationalError
  6from django_dramatiq_postgres.middleware import (
  7    CurrentTask as BaseCurrentTask,
  8)
  9from django_dramatiq_postgres.middleware import (
 10    MetricsMiddleware as BaseMetricsMiddleware,
 11)
 12from dramatiq.broker import Broker
 13from dramatiq.message import Message
 14from dramatiq.middleware import Middleware
 15from psycopg.errors import Error
 16from structlog.stdlib import get_logger
 17
 18from authentik.events.models import Event, EventAction
 19from authentik.lib.sentry import should_ignore_exception
 20from authentik.lib.utils.reflection import class_to_path
 21from authentik.root.signals import post_startup, pre_startup, startup
 22from authentik.tasks.models import Task, TaskLog, TaskStatus
 23from authentik.tenants.models import Tenant
 24from authentik.tenants.utils import get_current_tenant
 25
 26LOGGER = get_logger()
 27HEALTHCHECK_LOGGER = get_logger("authentik.worker").bind()
 28DB_ERRORS = (OperationalError, Error)
 29
 30
 31class StartupSignalsMiddleware(Middleware):
 32    def after_process_boot(self, broker: Broker):
 33        _startup_sender = type("WorkerStartup", (object,), {})
 34        pre_startup.send(sender=_startup_sender)
 35        startup.send(sender=_startup_sender)
 36        post_startup.send(sender=_startup_sender)
 37
 38
 39class CurrentTask(BaseCurrentTask):
 40    @classmethod
 41    def get_task(cls) -> Task:
 42        return cast(Task, super().get_task())
 43
 44
 45class TenantMiddleware(Middleware):
 46    def before_enqueue(self, broker: Broker, message: Message, delay: int):
 47        message.options["model_create_defaults"]["tenant"] = get_current_tenant()
 48
 49    def before_process_message(self, broker: Broker, message: Message):
 50        task: Task = message.options["task"]
 51        task.tenant.activate()
 52
 53    def after_process_message(self, *args, **kwargs):
 54        Tenant.deactivate()
 55
 56    after_skip_message = after_process_message
 57
 58
 59class ModelDataMiddleware(Middleware):
 60    @property
 61    def actor_options(self):
 62        return {"rel_obj", "uid"}
 63
 64    def before_enqueue(self, broker: Broker, message: Message, delay: int):
 65        if "rel_obj" in message.options:
 66            message.options["model_defaults"]["rel_obj"] = message.options.pop("rel_obj")
 67        if "uid" in message.options:
 68            message.options["model_defaults"]["_uid"] = message.options.pop("uid")
 69
 70
 71class TaskLogMiddleware(Middleware):
 72    def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
 73        task: Task = message.options["task"]
 74        task_created: bool = message.options["task_created"]
 75        if task_created:
 76            TaskLog.create_from_log_event(
 77                task,
 78                Task._make_log(
 79                    class_to_path(type(self)),
 80                    TaskStatus.INFO,
 81                    "Task has been queued",
 82                    delay=delay,
 83                ),
 84            )
 85        else:
 86            TaskLog.objects.filter(task=task).update(previous=True)
 87            TaskLog.create_from_log_event(
 88                task,
 89                Task._make_log(
 90                    class_to_path(type(self)),
 91                    TaskStatus.INFO,
 92                    "Task will be retried",
 93                    delay=delay,
 94                ),
 95            )
 96
 97    def before_process_message(self, broker: Broker, message: Message):
 98        task: Task = message.options["task"]
 99        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task is being processed")
100
101    def after_process_message(
102        self,
103        broker: Broker,
104        message: Message,
105        *,
106        result: Any | None = None,
107        exception: Exception | None = None,
108    ):
109        task: Task = message.options["task"]
110        if exception is None:
111            task.log(
112                class_to_path(type(self)),
113                TaskStatus.INFO,
114                "Task finished processing without errors",
115            )
116            return
117        task.log(
118            class_to_path(type(self)),
119            TaskStatus.ERROR,
120            exception,
121        )
122        if should_ignore_exception(exception):
123            return
124        event_kwargs = {
125            "actor": task.actor_name,
126        }
127        if task.rel_obj:
128            event_kwargs["rel_obj"] = task.rel_obj
129        Event.new(
130            EventAction.SYSTEM_TASK_EXCEPTION,
131            message=f"Task {task.actor_name} encountered an error",
132            **event_kwargs,
133        ).with_exception(exception).save()
134
135    def after_skip_message(self, broker: Broker, message: Message):
136        task: Task = message.options["task"]
137        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task has been skipped")
138
139
140class LoggingMiddleware(Middleware):
141    def __init__(self):
142        self.logger = get_logger()
143
144    def after_enqueue(self, broker: Broker, message: Message, delay: int):
145        self.logger.info(
146            "Task enqueued",
147            task_id=message.message_id,
148            task_name=message.actor_name,
149        )
150
151    def before_process_message(self, broker: Broker, message: Message):
152        self.logger.info("Task started", task_id=message.message_id, task_name=message.actor_name)
153
154    def after_process_message(
155        self,
156        broker: Broker,
157        message: Message,
158        *,
159        result: Any | None = None,
160        exception: Exception | None = None,
161    ):
162        self.logger.info(
163            "Task finished",
164            task_id=message.message_id,
165            task_name=message.actor_name,
166            exc=exception,
167        )
168
169    def after_skip_message(self, broker: Broker, message: Message):
170        self.logger.info("Task skipped", task_id=message.message_id, task_name=message.actor_name)
171
172
173class DescriptionMiddleware(Middleware):
174    @property
175    def actor_options(self):
176        return {"description"}
177
178
179class MetricsMiddleware(BaseMetricsMiddleware):
180    @property
181    def forks(self) -> list[Callable[[], None]]:
182        return []
183
184    def before_worker_boot(self, broker: Broker, worker: Any) -> None:
185        if settings.TEST:
186            return super().before_worker_boot(broker, worker)
187
188        from prometheus_client import values
189        from prometheus_client.values import MultiProcessValue
190
191        values.ValueClass = MultiProcessValue(lambda: worker.worker_id)
192
193        return super().before_worker_boot(broker, worker)
194
195    def after_worker_shutdown(self, broker: Broker, worker: Any) -> None:
196        if settings.TEST:
197            return
198
199        from prometheus_client import multiprocess
200
201        multiprocess.mark_process_dead(worker.worker_id)
LOGGER = <BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None, context_class=None, initial_values={}, logger_factory_args=())>
HEALTHCHECK_LOGGER = <BoundLoggerFilteringAtDebug(context={}, processors=[<function add_log_level>, <function add_logger_name>, <function merge_contextvars>, <function add_process_id>, <function add_tenant_information>, <structlog.stdlib.PositionalArgumentsFormatter object>, <structlog.processors.TimeStamper object>, <structlog.processors.StackInfoRenderer object>, <structlog.processors.ExceptionRenderer object>, <function ProcessorFormatter.wrap_for_formatter>])>
DB_ERRORS = (<class 'django.db.utils.OperationalError'>, <class 'psycopg.Error'>)
class StartupSignalsMiddleware(dramatiq.middleware.middleware.Middleware):
32class StartupSignalsMiddleware(Middleware):
33    def after_process_boot(self, broker: Broker):
34        _startup_sender = type("WorkerStartup", (object,), {})
35        pre_startup.send(sender=_startup_sender)
36        startup.send(sender=_startup_sender)
37        post_startup.send(sender=_startup_sender)

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

def after_process_boot(self, broker: dramatiq.broker.Broker):
33    def after_process_boot(self, broker: Broker):
34        _startup_sender = type("WorkerStartup", (object,), {})
35        pre_startup.send(sender=_startup_sender)
36        startup.send(sender=_startup_sender)
37        post_startup.send(sender=_startup_sender)

Called immediately after the subprocess starts up.

class CurrentTask(django_dramatiq_postgres.middleware.CurrentTask):
40class CurrentTask(BaseCurrentTask):
41    @classmethod
42    def get_task(cls) -> Task:
43        return cast(Task, super().get_task())

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

@classmethod
def get_task(cls) -> authentik.tasks.models.Task:
41    @classmethod
42    def get_task(cls) -> Task:
43        return cast(Task, super().get_task())
class TenantMiddleware(dramatiq.middleware.middleware.Middleware):
46class TenantMiddleware(Middleware):
47    def before_enqueue(self, broker: Broker, message: Message, delay: int):
48        message.options["model_create_defaults"]["tenant"] = get_current_tenant()
49
50    def before_process_message(self, broker: Broker, message: Message):
51        task: Task = message.options["task"]
52        task.tenant.activate()
53
54    def after_process_message(self, *args, **kwargs):
55        Tenant.deactivate()
56
57    after_skip_message = after_process_message

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

def before_enqueue(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, delay: int):
47    def before_enqueue(self, broker: Broker, message: Message, delay: int):
48        message.options["model_create_defaults"]["tenant"] = get_current_tenant()

Called before a message is enqueued (including retries).

def before_process_message(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message):
50    def before_process_message(self, broker: Broker, message: Message):
51        task: Task = message.options["task"]
52        task.tenant.activate()

Called before a message is processed.

Raises: SkipMessage: If the current message should be skipped. When this is raised, after_skip_message is emitted instead of after_process_message.

def after_process_message(self, *args, **kwargs):
54    def after_process_message(self, *args, **kwargs):
55        Tenant.deactivate()

Called after a message has been processed.

def after_skip_message(self, *args, **kwargs):
54    def after_process_message(self, *args, **kwargs):
55        Tenant.deactivate()

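Called after a message has been processed or skipped: `after_skip_message` is an alias of `after_process_message`, so a skipped message also deactivates the current tenant.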

class ModelDataMiddleware(dramatiq.middleware.middleware.Middleware):
60class ModelDataMiddleware(Middleware):
61    @property
62    def actor_options(self):
63        return {"rel_obj", "uid"}
64
65    def before_enqueue(self, broker: Broker, message: Message, delay: int):
66        if "rel_obj" in message.options:
67            message.options["model_defaults"]["rel_obj"] = message.options.pop("rel_obj")
68        if "uid" in message.options:
69            message.options["model_defaults"]["_uid"] = message.options.pop("uid")

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

actor_options
61    @property
62    def actor_options(self):
63        return {"rel_obj", "uid"}

The set of options that may be configured on each actor.

def before_enqueue(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, delay: int):
65    def before_enqueue(self, broker: Broker, message: Message, delay: int):
66        if "rel_obj" in message.options:
67            message.options["model_defaults"]["rel_obj"] = message.options.pop("rel_obj")
68        if "uid" in message.options:
69            message.options["model_defaults"]["_uid"] = message.options.pop("uid")

Called before a message is enqueued (including retries).

class TaskLogMiddleware(dramatiq.middleware.middleware.Middleware):
 72class TaskLogMiddleware(Middleware):
 73    def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
 74        task: Task = message.options["task"]
 75        task_created: bool = message.options["task_created"]
 76        if task_created:
 77            TaskLog.create_from_log_event(
 78                task,
 79                Task._make_log(
 80                    class_to_path(type(self)),
 81                    TaskStatus.INFO,
 82                    "Task has been queued",
 83                    delay=delay,
 84                ),
 85            )
 86        else:
 87            TaskLog.objects.filter(task=task).update(previous=True)
 88            TaskLog.create_from_log_event(
 89                task,
 90                Task._make_log(
 91                    class_to_path(type(self)),
 92                    TaskStatus.INFO,
 93                    "Task will be retried",
 94                    delay=delay,
 95                ),
 96            )
 97
 98    def before_process_message(self, broker: Broker, message: Message):
 99        task: Task = message.options["task"]
100        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task is being processed")
101
102    def after_process_message(
103        self,
104        broker: Broker,
105        message: Message,
106        *,
107        result: Any | None = None,
108        exception: Exception | None = None,
109    ):
110        task: Task = message.options["task"]
111        if exception is None:
112            task.log(
113                class_to_path(type(self)),
114                TaskStatus.INFO,
115                "Task finished processing without errors",
116            )
117            return
118        task.log(
119            class_to_path(type(self)),
120            TaskStatus.ERROR,
121            exception,
122        )
123        if should_ignore_exception(exception):
124            return
125        event_kwargs = {
126            "actor": task.actor_name,
127        }
128        if task.rel_obj:
129            event_kwargs["rel_obj"] = task.rel_obj
130        Event.new(
131            EventAction.SYSTEM_TASK_EXCEPTION,
132            message=f"Task {task.actor_name} encountered an error",
133            **event_kwargs,
134        ).with_exception(exception).save()
135
136    def after_skip_message(self, broker: Broker, message: Message):
137        task: Task = message.options["task"]
138        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task has been skipped")

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

def after_enqueue(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, delay: int | None):
73    def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
74        task: Task = message.options["task"]
75        task_created: bool = message.options["task_created"]
76        if task_created:
77            TaskLog.create_from_log_event(
78                task,
79                Task._make_log(
80                    class_to_path(type(self)),
81                    TaskStatus.INFO,
82                    "Task has been queued",
83                    delay=delay,
84                ),
85            )
86        else:
87            TaskLog.objects.filter(task=task).update(previous=True)
88            TaskLog.create_from_log_event(
89                task,
90                Task._make_log(
91                    class_to_path(type(self)),
92                    TaskStatus.INFO,
93                    "Task will be retried",
94                    delay=delay,
95                ),
96            )

Called after a message has been enqueued (including retries).

def before_process_message(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message):
 98    def before_process_message(self, broker: Broker, message: Message):
 99        task: Task = message.options["task"]
100        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task is being processed")

Called before a message is processed.

Raises: SkipMessage: If the current message should be skipped. When this is raised, after_skip_message is emitted instead of after_process_message.

def after_process_message(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, *, result: Any | None = None, exception: Exception | None = None):
102    def after_process_message(
103        self,
104        broker: Broker,
105        message: Message,
106        *,
107        result: Any | None = None,
108        exception: Exception | None = None,
109    ):
110        task: Task = message.options["task"]
111        if exception is None:
112            task.log(
113                class_to_path(type(self)),
114                TaskStatus.INFO,
115                "Task finished processing without errors",
116            )
117            return
118        task.log(
119            class_to_path(type(self)),
120            TaskStatus.ERROR,
121            exception,
122        )
123        if should_ignore_exception(exception):
124            return
125        event_kwargs = {
126            "actor": task.actor_name,
127        }
128        if task.rel_obj:
129            event_kwargs["rel_obj"] = task.rel_obj
130        Event.new(
131            EventAction.SYSTEM_TASK_EXCEPTION,
132            message=f"Task {task.actor_name} encountered an error",
133            **event_kwargs,
134        ).with_exception(exception).save()

Called after a message has been processed.

def after_skip_message(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message):
136    def after_skip_message(self, broker: Broker, message: Message):
137        task: Task = message.options["task"]
138        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task has been skipped")

Called instead of after_process_message after a message has been skipped.

class LoggingMiddleware(dramatiq.middleware.middleware.Middleware):
141class LoggingMiddleware(Middleware):
142    def __init__(self):
143        self.logger = get_logger()
144
145    def after_enqueue(self, broker: Broker, message: Message, delay: int):
146        self.logger.info(
147            "Task enqueued",
148            task_id=message.message_id,
149            task_name=message.actor_name,
150        )
151
152    def before_process_message(self, broker: Broker, message: Message):
153        self.logger.info("Task started", task_id=message.message_id, task_name=message.actor_name)
154
155    def after_process_message(
156        self,
157        broker: Broker,
158        message: Message,
159        *,
160        result: Any | None = None,
161        exception: Exception | None = None,
162    ):
163        self.logger.info(
164            "Task finished",
165            task_id=message.message_id,
166            task_name=message.actor_name,
167            exc=exception,
168        )
169
170    def after_skip_message(self, broker: Broker, message: Message):
171        self.logger.info("Task skipped", task_id=message.message_id, task_name=message.actor_name)

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

logger
def after_enqueue(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, delay: int):
145    def after_enqueue(self, broker: Broker, message: Message, delay: int):
146        self.logger.info(
147            "Task enqueued",
148            task_id=message.message_id,
149            task_name=message.actor_name,
150        )

Called after a message has been enqueued (including retries).

def before_process_message(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message):
152    def before_process_message(self, broker: Broker, message: Message):
153        self.logger.info("Task started", task_id=message.message_id, task_name=message.actor_name)

Called before a message is processed.

Raises: SkipMessage: If the current message should be skipped. When this is raised, after_skip_message is emitted instead of after_process_message.

def after_process_message(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, *, result: Any | None = None, exception: Exception | None = None):
155    def after_process_message(
156        self,
157        broker: Broker,
158        message: Message,
159        *,
160        result: Any | None = None,
161        exception: Exception | None = None,
162    ):
163        self.logger.info(
164            "Task finished",
165            task_id=message.message_id,
166            task_name=message.actor_name,
167            exc=exception,
168        )

Called after a message has been processed.

def after_skip_message(self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message):
170    def after_skip_message(self, broker: Broker, message: Message):
171        self.logger.info("Task skipped", task_id=message.message_id, task_name=message.actor_name)

Called instead of after_process_message after a message has been skipped.

class DescriptionMiddleware(dramatiq.middleware.middleware.Middleware):
174class DescriptionMiddleware(Middleware):
175    @property
176    def actor_options(self):
177        return {"description"}

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

actor_options
175    @property
176    def actor_options(self):
177        return {"description"}

The set of options that may be configured on each actor.

class MetricsMiddleware(django_dramatiq_postgres.middleware.MetricsMiddleware):
180class MetricsMiddleware(BaseMetricsMiddleware):
181    @property
182    def forks(self) -> list[Callable[[], None]]:
183        return []
184
185    def before_worker_boot(self, broker: Broker, worker: Any) -> None:
186        if settings.TEST:
187            return super().before_worker_boot(broker, worker)
188
189        from prometheus_client import values
190        from prometheus_client.values import MultiProcessValue
191
192        values.ValueClass = MultiProcessValue(lambda: worker.worker_id)
193
194        return super().before_worker_boot(broker, worker)
195
196    def after_worker_shutdown(self, broker: Broker, worker: Any) -> None:
197        if settings.TEST:
198            return
199
200        from prometheus_client import multiprocess
201
202        multiprocess.mark_process_dead(worker.worker_id)

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

forks: list[Callable[[], None]]
181    @property
182    def forks(self) -> list[Callable[[], None]]:
183        return []

A list of functions to run in separate forks of the main process.

def before_worker_boot(self, broker: dramatiq.broker.Broker, worker: Any) -> None:
185    def before_worker_boot(self, broker: Broker, worker: Any) -> None:
186        if settings.TEST:
187            return super().before_worker_boot(broker, worker)
188
189        from prometheus_client import values
190        from prometheus_client.values import MultiProcessValue
191
192        values.ValueClass = MultiProcessValue(lambda: worker.worker_id)
193
194        return super().before_worker_boot(broker, worker)

Called before the worker process starts up.

def after_worker_shutdown(self, broker: dramatiq.broker.Broker, worker: Any) -> None:
196    def after_worker_shutdown(self, broker: Broker, worker: Any) -> None:
197        if settings.TEST:
198            return
199
200        from prometheus_client import multiprocess
201
202        multiprocess.mark_process_dead(worker.worker_id)

Called after the worker process shuts down.