authentik.tasks.middleware

  1import socket
  2from collections.abc import Callable
  3from http.server import BaseHTTPRequestHandler
  4from threading import Event as TEvent
  5from threading import Thread, current_thread
  6from typing import Any, cast
  7
  8import pglock
  9from django.db import OperationalError, connections, transaction
 10from django.utils.timezone import now
 11from django_dramatiq_postgres.middleware import (
 12    CurrentTask as BaseCurrentTask,
 13)
 14from django_dramatiq_postgres.middleware import (
 15    HTTPServer,
 16    HTTPServerThread,
 17)
 18from django_dramatiq_postgres.middleware import (
 19    MetricsMiddleware as BaseMetricsMiddleware,
 20)
 21from django_dramatiq_postgres.middleware import (
 22    _MetricsHandler as BaseMetricsHandler,
 23)
 24from dramatiq import Worker
 25from dramatiq.broker import Broker
 26from dramatiq.message import Message
 27from dramatiq.middleware import Middleware
 28from psycopg.errors import Error
 29from setproctitle import setthreadtitle
 30from structlog.stdlib import get_logger
 31
 32from authentik import authentik_full_version
 33from authentik.events.models import Event, EventAction
 34from authentik.lib.config import CONFIG
 35from authentik.lib.sentry import should_ignore_exception
 36from authentik.lib.utils.reflection import class_to_path
 37from authentik.root.monitoring import monitoring_set
 38from authentik.root.signals import post_startup, pre_startup, startup
 39from authentik.tasks.models import Task, TaskLog, TaskStatus, WorkerStatus
 40from authentik.tenants.models import Tenant
 41from authentik.tenants.utils import get_current_tenant
 42
 43LOGGER = get_logger()
 44HEALTHCHECK_LOGGER = get_logger("authentik.worker").bind()
 45DB_ERRORS = (OperationalError, Error)
 46
 47
 48class StartupSignalsMiddleware(Middleware):
 49    def after_process_boot(self, broker: Broker):
 50        _startup_sender = type("WorkerStartup", (object,), {})
 51        pre_startup.send(sender=_startup_sender)
 52        startup.send(sender=_startup_sender)
 53        post_startup.send(sender=_startup_sender)
 54
 55
 56class CurrentTask(BaseCurrentTask):
 57    @classmethod
 58    def get_task(cls) -> Task:
 59        return cast(Task, super().get_task())
 60
 61
 62class TenantMiddleware(Middleware):
 63    def before_enqueue(self, broker: Broker, message: Message, delay: int):
 64        message.options["model_create_defaults"]["tenant"] = get_current_tenant()
 65
 66    def before_process_message(self, broker: Broker, message: Message):
 67        task: Task = message.options["task"]
 68        task.tenant.activate()
 69
 70    def after_process_message(self, *args, **kwargs):
 71        Tenant.deactivate()
 72
 73    after_skip_message = after_process_message
 74
 75
 76class ModelDataMiddleware(Middleware):
 77    @property
 78    def actor_options(self):
 79        return {"rel_obj", "uid"}
 80
 81    def before_enqueue(self, broker: Broker, message: Message, delay: int):
 82        if "rel_obj" in message.options:
 83            message.options["model_defaults"]["rel_obj"] = message.options.pop("rel_obj")
 84        if "uid" in message.options:
 85            message.options["model_defaults"]["_uid"] = message.options.pop("uid")
 86
 87
 88class TaskLogMiddleware(Middleware):
 89    def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
 90        task: Task = message.options["task"]
 91        task_created: bool = message.options["task_created"]
 92        if task_created:
 93            TaskLog.create_from_log_event(
 94                task,
 95                Task._make_log(
 96                    class_to_path(type(self)),
 97                    TaskStatus.INFO,
 98                    "Task has been queued",
 99                    delay=delay,
100                ),
101            )
102        else:
103            TaskLog.objects.filter(task=task).update(previous=True)
104            TaskLog.create_from_log_event(
105                task,
106                Task._make_log(
107                    class_to_path(type(self)),
108                    TaskStatus.INFO,
109                    "Task will be retried",
110                    delay=delay,
111                ),
112            )
113
114    def before_process_message(self, broker: Broker, message: Message):
115        task: Task = message.options["task"]
116        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task is being processed")
117
118    def after_process_message(
119        self,
120        broker: Broker,
121        message: Message,
122        *,
123        result: Any | None = None,
124        exception: Exception | None = None,
125    ):
126        task: Task = message.options["task"]
127        if exception is None:
128            task.log(
129                class_to_path(type(self)),
130                TaskStatus.INFO,
131                "Task finished processing without errors",
132            )
133            return
134        task.log(
135            class_to_path(type(self)),
136            TaskStatus.ERROR,
137            exception,
138        )
139        if should_ignore_exception(exception):
140            return
141        event_kwargs = {
142            "actor": task.actor_name,
143        }
144        if task.rel_obj:
145            event_kwargs["rel_obj"] = task.rel_obj
146        Event.new(
147            EventAction.SYSTEM_TASK_EXCEPTION,
148            message=f"Task {task.actor_name} encountered an error",
149            **event_kwargs,
150        ).with_exception(exception).save()
151
152    def after_skip_message(self, broker: Broker, message: Message):
153        task: Task = message.options["task"]
154        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task has been skipped")
155
156
class LoggingMiddleware(Middleware):
    """Emit one structured log line per message lifecycle stage."""

    def __init__(self):
        self.logger = get_logger()

    def _log_stage(self, text: str, message: Message, **extra: Any):
        """Log *text* tagged with the message id and actor name, plus any *extra* fields."""
        self.logger.info(
            text,
            task_id=message.message_id,
            task_name=message.actor_name,
            **extra,
        )

    def after_enqueue(self, broker: Broker, message: Message, delay: int):
        self._log_stage("Task enqueued", message)

    def before_process_message(self, broker: Broker, message: Message):
        self._log_stage("Task started", message)

    def after_process_message(
        self,
        broker: Broker,
        message: Message,
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ):
        # exc is None on success.
        self._log_stage("Task finished", message, exc=exception)

    def after_skip_message(self, broker: Broker, message: Message):
        self._log_stage("Task skipped", message)
188
189
class DescriptionMiddleware(Middleware):
    """Allow actors to declare a ``description`` option; it adds no hooks."""

    @property
    def actor_options(self):
        allowed = {"description"}
        return allowed
194
195
class _healthcheck_handler(BaseHTTPRequestHandler):
    """Healthcheck endpoint: 200 when every configured database connects, 503 otherwise.

    The response never has a body; GET is answered exactly like HEAD.
    """

    def log_request(self, code="-", size="-"):
        # Send access lines to structlog instead of BaseHTTPRequestHandler's stderr default.
        HEALTHCHECK_LOGGER.info(self.path, method=self.command, status=code)

    def log_error(self, format, *args):
        # The name `format` mirrors the base-class signature.
        HEALTHCHECK_LOGGER.warning(format, *args)

    def do_HEAD(self):
        status = 200
        try:
            for db_conn in connections.all():
                # Open a new connection rather than reusing a cached one.
                db_conn.connect()
                _ = db_conn.cursor()
        except DB_ERRORS:  # pragma: no cover
            status = 503
        self.send_response(status)
        self.send_header("Content-Type", "text/plain; charset=utf-8")
        self.send_header("Content-Length", "0")
        self.end_headers()

    do_GET = do_HEAD
221
222
class WorkerHealthcheckMiddleware(Middleware):
    """Run the worker healthcheck HTTP server in a background thread.

    The address is taken from the first entry of ``listen.http``
    (default ``[::]:9000``; a string value is split on commas). If the port
    part is not an integer, an error is logged and no server is started.
    """

    thread: HTTPServerThread | None

    def __init__(self):
        # No server thread until after_worker_boot starts one.
        self.thread = None
        listen = CONFIG.get("listen.http", ["[::]:9000"])
        if isinstance(listen, str):
            listen = listen.split(",")
        host, _, raw_port = listen[0].rpartition(":")

        port: int | None
        try:
            port = int(raw_port)
        except ValueError:
            LOGGER.error(f"Invalid port entered: {raw_port}")
            # Previously the unparsed string was kept, which made the server
            # thread die with a TypeError during bind. Mark "no server" instead.
            port = None

        self.host, self.port = host, port

    def after_worker_boot(self, broker: Broker, worker: Worker):
        if self.port is None:
            # Invalid configuration was already reported in __init__.
            return
        self.thread = HTTPServerThread(
            target=WorkerHealthcheckMiddleware.run, args=(self.host, self.port)
        )
        self.thread.start()

    def before_worker_shutdown(self, broker: Broker, worker: Worker):
        # getattr: tolerate instances where no thread was ever started.
        thread = getattr(self, "thread", None)
        if thread is None:
            return
        server = thread.server
        # server is unset when run() failed to bind.
        if server:
            server.shutdown()
        LOGGER.debug("Stopping WorkerHealthcheckMiddleware")
        thread.join()

    @staticmethod
    def run(addr: str, port: int):
        """Thread target: bind the healthcheck server and serve until shut down.

        The server is stored on the current ``HTTPServerThread`` so that
        ``before_worker_shutdown`` can stop it. A bind failure (``OSError``)
        is logged as a warning and the thread exits.
        """
        setthreadtitle("authentik Worker Healthcheck server")
        try:
            server = HTTPServer((addr, port), _healthcheck_handler)
            thread = cast(HTTPServerThread, current_thread())
            thread.server = server
            server.serve_forever()
        except OSError as exc:
            # The former extra argument `type(WorkerHealthcheckMiddleware)`
            # evaluated to the builtin `type`; only the name is meaningful.
            get_logger(__name__).warning(
                "Port is already in use, not starting healthcheck server",
                exc=exc,
            )
265
266
class WorkerStatusMiddleware(Middleware):
    """Keep this host's ``WorkerStatus`` row fresh from a background thread."""

    thread: Thread | None
    thread_event: TEvent | None

    def after_worker_boot(self, broker: Broker, worker: Worker):
        stop_event = TEvent()
        self.thread_event = stop_event
        self.thread = Thread(target=WorkerStatusMiddleware.run, args=(stop_event,))
        self.thread.start()

    def before_worker_shutdown(self, broker: Broker, worker: Worker):
        # Signal the loop to stop, then wait for the thread to exit.
        self.thread_event.set()
        LOGGER.debug("Stopping WorkerStatusMiddleware")
        self.thread.join()

    @staticmethod
    def run(event: TEvent):
        """Thread target: register this host, then heartbeat until *event* is set."""
        setthreadtitle("authentik Worker status")
        with transaction.atomic():
            hostname = socket.gethostname()
            # Replace any stale row(s) for this hostname with a fresh one.
            WorkerStatus.objects.filter(hostname=hostname).delete()
            status, _ = WorkerStatus.objects.update_or_create(
                hostname=hostname,
                version=authentik_full_version(),
            )
        while not event.is_set():
            try:
                WorkerStatusMiddleware.keep(event, status)
            except DB_ERRORS:  # pragma: no cover
                # Back off, then drop possibly-broken connections before retrying.
                event.wait(10)
                try:
                    connections.close_all()
                except DB_ERRORS:
                    pass

    @staticmethod
    def keep(event: TEvent, status: WorkerStatus):
        """Update ``status.last_seen`` every 30 seconds while holding an advisory lock.

        The lock is keyed on the status primary key; ``pglock.Raise`` is used as
        the side effect when it cannot be acquired.
        """
        lock_id = f"goauthentik.io/worker/status/{status.pk}"
        with pglock.advisory(lock_id, side_effect=pglock.Raise):
            while not event.is_set():
                status.refresh_from_db()
                previous = status.last_seen
                status.last_seen = now()
                if status.last_seen != previous:
                    status.save(update_fields=("last_seen",))
                event.wait(30)
312
313
class _MetricsHandler(BaseMetricsHandler):
    """Metrics handler that fires ``monitoring_set`` before delegating to the base GET."""

    def do_GET(self) -> None:
        # send_robust: a failing receiver is caught by the signal and does not abort the scrape.
        monitoring_set.send_robust(self)
        response = super().do_GET()
        return response
318
319
class MetricsMiddleware(BaseMetricsMiddleware):
    """Serve worker metrics from an HTTP server thread instead of a forked process.

    The address is taken from the first entry of ``listen.metrics``
    (default ``[::]:9300``; a string value is split on commas). If the port
    part is not an integer, an error is logged and no server is started.
    """

    thread: HTTPServerThread | None
    handler_class = _MetricsHandler

    @property
    def forks(self) -> list[Callable[[], None]]:
        # No forked exporter process; the server thread is started in after_worker_boot.
        return []

    def after_worker_boot(self, broker: Broker, worker: Worker):
        listen = CONFIG.get("listen.metrics", ["[::]:9300"])
        if isinstance(listen, str):
            listen = listen.split(",")
        addr, _, raw_port = listen[0].rpartition(":")

        try:
            port = int(raw_port)
        except ValueError:
            LOGGER.error(f"Invalid port entered: {raw_port}")
            # Previously the string port was still handed to the server thread,
            # which then failed while binding. Don't start a server at all.
            return
        self.thread = HTTPServerThread(target=MetricsMiddleware.run, args=(addr, port))
        self.thread.start()

    def before_worker_shutdown(self, broker: Broker, worker: Worker):
        # getattr: no thread exists if after_worker_boot never started one.
        thread = getattr(self, "thread", None)
        if thread is None:
            return
        server = thread.server
        if server:
            server.shutdown()
        LOGGER.debug("Stopping MetricsMiddleware")
        thread.join()
LOGGER = <BoundLoggerLazyProxy(logger=None, wrapper_class=None, processors=None, context_class=None, initial_values={}, logger_factory_args=())>
HEALTHCHECK_LOGGER = <BoundLoggerFilteringAtDebug(context={}, processors=[<function add_log_level>, <function add_logger_name>, <function merge_contextvars>, <function add_process_id>, <function add_tenant_information>, <structlog.stdlib.PositionalArgumentsFormatter object>, <structlog.processors.TimeStamper object>, <structlog.processors.StackInfoRenderer object>, <structlog.processors.ExceptionRenderer object>, <function ProcessorFormatter.wrap_for_formatter>])>
DB_ERRORS = (<class 'django.db.utils.OperationalError'>, <class 'psycopg.Error'>)
class StartupSignalsMiddleware(dramatiq.middleware.middleware.Middleware):
49class StartupSignalsMiddleware(Middleware):
50    def after_process_boot(self, broker: Broker):
51        _startup_sender = type("WorkerStartup", (object,), {})
52        pre_startup.send(sender=_startup_sender)
53        startup.send(sender=_startup_sender)
54        post_startup.send(sender=_startup_sender)

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

def after_process_boot(self, broker: dramatiq.broker.Broker):
50    def after_process_boot(self, broker: Broker):
51        _startup_sender = type("WorkerStartup", (object,), {})
52        pre_startup.send(sender=_startup_sender)
53        startup.send(sender=_startup_sender)
54        post_startup.send(sender=_startup_sender)

Called immediately after subprocess start up.

class CurrentTask(django_dramatiq_postgres.middleware.CurrentTask):
57class CurrentTask(BaseCurrentTask):
58    @classmethod
59    def get_task(cls) -> Task:
60        return cast(Task, super().get_task())

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

@classmethod
def get_task(cls) -> authentik.tasks.models.Task:
58    @classmethod
59    def get_task(cls) -> Task:
60        return cast(Task, super().get_task())
class TenantMiddleware(dramatiq.middleware.middleware.Middleware):
63class TenantMiddleware(Middleware):
64    def before_enqueue(self, broker: Broker, message: Message, delay: int):
65        message.options["model_create_defaults"]["tenant"] = get_current_tenant()
66
67    def before_process_message(self, broker: Broker, message: Message):
68        task: Task = message.options["task"]
69        task.tenant.activate()
70
71    def after_process_message(self, *args, **kwargs):
72        Tenant.deactivate()
73
74    after_skip_message = after_process_message

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

def before_enqueue( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, delay: int):
64    def before_enqueue(self, broker: Broker, message: Message, delay: int):
65        message.options["model_create_defaults"]["tenant"] = get_current_tenant()

Called before a message is enqueued.

def before_process_message( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message):
67    def before_process_message(self, broker: Broker, message: Message):
68        task: Task = message.options["task"]
69        task.tenant.activate()

Called before a message is processed.

Raises: SkipMessage: If the current message should be skipped. When this is raised, after_skip_message is emitted instead of after_process_message.

def after_process_message(self, *args, **kwargs):
71    def after_process_message(self, *args, **kwargs):
72        Tenant.deactivate()

Called after a message has been processed.

def after_skip_message(self, *args, **kwargs):
71    def after_process_message(self, *args, **kwargs):
72        Tenant.deactivate()

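Called after a message has been processed. (In `TenantMiddleware`, `after_skip_message` is an alias of `after_process_message`, so the same code also runs when a message is skipped.)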

class ModelDataMiddleware(dramatiq.middleware.middleware.Middleware):
77class ModelDataMiddleware(Middleware):
78    @property
79    def actor_options(self):
80        return {"rel_obj", "uid"}
81
82    def before_enqueue(self, broker: Broker, message: Message, delay: int):
83        if "rel_obj" in message.options:
84            message.options["model_defaults"]["rel_obj"] = message.options.pop("rel_obj")
85        if "uid" in message.options:
86            message.options["model_defaults"]["_uid"] = message.options.pop("uid")

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

actor_options
78    @property
79    def actor_options(self):
80        return {"rel_obj", "uid"}

The set of options that may be configured on each actor.

def before_enqueue( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, delay: int):
82    def before_enqueue(self, broker: Broker, message: Message, delay: int):
83        if "rel_obj" in message.options:
84            message.options["model_defaults"]["rel_obj"] = message.options.pop("rel_obj")
85        if "uid" in message.options:
86            message.options["model_defaults"]["_uid"] = message.options.pop("uid")

Called before a message is enqueued.

class TaskLogMiddleware(dramatiq.middleware.middleware.Middleware):
 89class TaskLogMiddleware(Middleware):
 90    def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
 91        task: Task = message.options["task"]
 92        task_created: bool = message.options["task_created"]
 93        if task_created:
 94            TaskLog.create_from_log_event(
 95                task,
 96                Task._make_log(
 97                    class_to_path(type(self)),
 98                    TaskStatus.INFO,
 99                    "Task has been queued",
100                    delay=delay,
101                ),
102            )
103        else:
104            TaskLog.objects.filter(task=task).update(previous=True)
105            TaskLog.create_from_log_event(
106                task,
107                Task._make_log(
108                    class_to_path(type(self)),
109                    TaskStatus.INFO,
110                    "Task will be retried",
111                    delay=delay,
112                ),
113            )
114
115    def before_process_message(self, broker: Broker, message: Message):
116        task: Task = message.options["task"]
117        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task is being processed")
118
119    def after_process_message(
120        self,
121        broker: Broker,
122        message: Message,
123        *,
124        result: Any | None = None,
125        exception: Exception | None = None,
126    ):
127        task: Task = message.options["task"]
128        if exception is None:
129            task.log(
130                class_to_path(type(self)),
131                TaskStatus.INFO,
132                "Task finished processing without errors",
133            )
134            return
135        task.log(
136            class_to_path(type(self)),
137            TaskStatus.ERROR,
138            exception,
139        )
140        if should_ignore_exception(exception):
141            return
142        event_kwargs = {
143            "actor": task.actor_name,
144        }
145        if task.rel_obj:
146            event_kwargs["rel_obj"] = task.rel_obj
147        Event.new(
148            EventAction.SYSTEM_TASK_EXCEPTION,
149            message=f"Task {task.actor_name} encountered an error",
150            **event_kwargs,
151        ).with_exception(exception).save()
152
153    def after_skip_message(self, broker: Broker, message: Message):
154        task: Task = message.options["task"]
155        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task has been skipped")

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

def after_enqueue( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, delay: int | None):
 90    def after_enqueue(self, broker: Broker, message: Message, delay: int | None):
 91        task: Task = message.options["task"]
 92        task_created: bool = message.options["task_created"]
 93        if task_created:
 94            TaskLog.create_from_log_event(
 95                task,
 96                Task._make_log(
 97                    class_to_path(type(self)),
 98                    TaskStatus.INFO,
 99                    "Task has been queued",
100                    delay=delay,
101                ),
102            )
103        else:
104            TaskLog.objects.filter(task=task).update(previous=True)
105            TaskLog.create_from_log_event(
106                task,
107                Task._make_log(
108                    class_to_path(type(self)),
109                    TaskStatus.INFO,
110                    "Task will be retried",
111                    delay=delay,
112                ),
113            )

Called after a message has been enqueued.

def before_process_message( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message):
115    def before_process_message(self, broker: Broker, message: Message):
116        task: Task = message.options["task"]
117        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task is being processed")

Called before a message is processed.

Raises: SkipMessage: If the current message should be skipped. When this is raised, after_skip_message is emitted instead of after_process_message.

def after_process_message( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, *, result: Any | None = None, exception: Exception | None = None):
119    def after_process_message(
120        self,
121        broker: Broker,
122        message: Message,
123        *,
124        result: Any | None = None,
125        exception: Exception | None = None,
126    ):
127        task: Task = message.options["task"]
128        if exception is None:
129            task.log(
130                class_to_path(type(self)),
131                TaskStatus.INFO,
132                "Task finished processing without errors",
133            )
134            return
135        task.log(
136            class_to_path(type(self)),
137            TaskStatus.ERROR,
138            exception,
139        )
140        if should_ignore_exception(exception):
141            return
142        event_kwargs = {
143            "actor": task.actor_name,
144        }
145        if task.rel_obj:
146            event_kwargs["rel_obj"] = task.rel_obj
147        Event.new(
148            EventAction.SYSTEM_TASK_EXCEPTION,
149            message=f"Task {task.actor_name} encountered an error",
150            **event_kwargs,
151        ).with_exception(exception).save()

Called after a message has been processed.

def after_skip_message( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message):
153    def after_skip_message(self, broker: Broker, message: Message):
154        task: Task = message.options["task"]
155        task.log(class_to_path(type(self)), TaskStatus.INFO, "Task has been skipped")

Called instead of after_process_message after a message has been skipped.

class LoggingMiddleware(dramatiq.middleware.middleware.Middleware):
158class LoggingMiddleware(Middleware):
159    def __init__(self):
160        self.logger = get_logger()
161
162    def after_enqueue(self, broker: Broker, message: Message, delay: int):
163        self.logger.info(
164            "Task enqueued",
165            task_id=message.message_id,
166            task_name=message.actor_name,
167        )
168
169    def before_process_message(self, broker: Broker, message: Message):
170        self.logger.info("Task started", task_id=message.message_id, task_name=message.actor_name)
171
172    def after_process_message(
173        self,
174        broker: Broker,
175        message: Message,
176        *,
177        result: Any | None = None,
178        exception: Exception | None = None,
179    ):
180        self.logger.info(
181            "Task finished",
182            task_id=message.message_id,
183            task_name=message.actor_name,
184            exc=exception,
185        )
186
187    def after_skip_message(self, broker: Broker, message: Message):
188        self.logger.info("Task skipped", task_id=message.message_id, task_name=message.actor_name)

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

logger
def after_enqueue( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, delay: int):
162    def after_enqueue(self, broker: Broker, message: Message, delay: int):
163        self.logger.info(
164            "Task enqueued",
165            task_id=message.message_id,
166            task_name=message.actor_name,
167        )

Called after a message has been enqueued.

def before_process_message( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message):
169    def before_process_message(self, broker: Broker, message: Message):
170        self.logger.info("Task started", task_id=message.message_id, task_name=message.actor_name)

Called before a message is processed.

Raises: SkipMessage: If the current message should be skipped. When this is raised, after_skip_message is emitted instead of after_process_message.

def after_process_message( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message, *, result: Any | None = None, exception: Exception | None = None):
172    def after_process_message(
173        self,
174        broker: Broker,
175        message: Message,
176        *,
177        result: Any | None = None,
178        exception: Exception | None = None,
179    ):
180        self.logger.info(
181            "Task finished",
182            task_id=message.message_id,
183            task_name=message.actor_name,
184            exc=exception,
185        )

Called after a message has been processed.

def after_skip_message( self, broker: dramatiq.broker.Broker, message: dramatiq.message.Message):
187    def after_skip_message(self, broker: Broker, message: Message):
188        self.logger.info("Task skipped", task_id=message.message_id, task_name=message.actor_name)

Called instead of after_process_message after a message has been skipped.

class DescriptionMiddleware(dramatiq.middleware.middleware.Middleware):
191class DescriptionMiddleware(Middleware):
192    @property
193    def actor_options(self):
194        return {"description"}

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

actor_options
192    @property
193    def actor_options(self):
194        return {"description"}

The set of options that may be configured on each actor.

class WorkerHealthcheckMiddleware(dramatiq.middleware.middleware.Middleware):
224class WorkerHealthcheckMiddleware(Middleware):
225    thread: HTTPServerThread | None
226
227    def __init__(self):
228        listen = CONFIG.get("listen.http", ["[::]:9000"])
229        if isinstance(listen, str):
230            listen = listen.split(",")
231        host, _, port = listen[0].rpartition(":")
232
233        try:
234            port = int(port)
235        except ValueError:
236            LOGGER.error(f"Invalid port entered: {port}")
237
238        self.host, self.port = host, port
239
240    def after_worker_boot(self, broker: Broker, worker: Worker):
241        self.thread = HTTPServerThread(
242            target=WorkerHealthcheckMiddleware.run, args=(self.host, self.port)
243        )
244        self.thread.start()
245
246    def before_worker_shutdown(self, broker: Broker, worker: Worker):
247        server = self.thread.server
248        if server:
249            server.shutdown()
250        LOGGER.debug("Stopping WorkerHealthcheckMiddleware")
251        self.thread.join()
252
253    @staticmethod
254    def run(addr: str, port: int):
255        setthreadtitle("authentik Worker Healthcheck server")
256        try:
257            server = HTTPServer((addr, port), _healthcheck_handler)
258            thread = cast(HTTPServerThread, current_thread())
259            thread.server = server
260            server.serve_forever()
261        except OSError as exc:
262            get_logger(__name__, type(WorkerHealthcheckMiddleware)).warning(
263                "Port is already in use, not starting healthcheck server",
264                exc=exc,
265            )

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

thread: django_dramatiq_postgres.middleware.HTTPServerThread | None
def after_worker_boot(self, broker: dramatiq.broker.Broker, worker: dramatiq.worker.Worker):
240    def after_worker_boot(self, broker: Broker, worker: Worker):
241        self.thread = HTTPServerThread(
242            target=WorkerHealthcheckMiddleware.run, args=(self.host, self.port)
243        )
244        self.thread.start()

Called after the worker process has started up.

def before_worker_shutdown(self, broker: dramatiq.broker.Broker, worker: dramatiq.worker.Worker):
246    def before_worker_shutdown(self, broker: Broker, worker: Worker):
247        server = self.thread.server
248        if server:
249            server.shutdown()
250        LOGGER.debug("Stopping WorkerHealthcheckMiddleware")
251        self.thread.join()

Called before the worker process shuts down.

@staticmethod
def run(addr: str, port: int):
253    @staticmethod
254    def run(addr: str, port: int):
255        setthreadtitle("authentik Worker Healthcheck server")
256        try:
257            server = HTTPServer((addr, port), _healthcheck_handler)
258            thread = cast(HTTPServerThread, current_thread())
259            thread.server = server
260            server.serve_forever()
261        except OSError as exc:
262            get_logger(__name__, type(WorkerHealthcheckMiddleware)).warning(
263                "Port is already in use, not starting healthcheck server",
264                exc=exc,
265            )
class WorkerStatusMiddleware(dramatiq.middleware.middleware.Middleware):
class WorkerStatusMiddleware(Middleware):
    """Maintain a ``WorkerStatus`` heartbeat row for the current host.

    On worker boot a background thread registers this host (hostname plus
    version). It then refreshes the row's ``last_seen`` timestamp, waiting
    30 seconds between refreshes, until the worker shuts down.
    """

    thread: Thread | None
    thread_event: TEvent | None

    def after_worker_boot(self, broker: Broker, worker: Worker):
        """Create the stop event and start the status thread with it."""
        stop_event = TEvent()
        self.thread_event = stop_event
        self.thread = Thread(target=WorkerStatusMiddleware.run, args=(stop_event,))
        self.thread.start()

    def before_worker_shutdown(self, broker: Broker, worker: Worker):
        """Signal the status thread to stop and wait for it to finish."""
        self.thread_event.set()
        LOGGER.debug("Stopping WorkerStatusMiddleware")
        self.thread.join()

    @staticmethod
    def run(event: TEvent):
        """Register this host, then keep its heartbeat alive until ``event`` is set.

        Registration deletes existing rows for this hostname and creates a
        new one, all in a single transaction. If a database error escapes
        ``keep``, the loop waits 10 seconds, closes all DB connections
        (best effort) and retries.
        """
        setthreadtitle("authentik Worker status")
        with transaction.atomic():
            host = socket.gethostname()
            # Remove any existing rows for this hostname (presumably stale
            # entries from an earlier process) before registering again.
            WorkerStatus.objects.filter(hostname=host).delete()
            worker_status = WorkerStatus.objects.update_or_create(
                hostname=host,
                version=authentik_full_version(),
            )[0]
        while not event.is_set():
            try:
                WorkerStatusMiddleware.keep(event, worker_status)
            except DB_ERRORS:  # pragma: no cover
                event.wait(10)
                # Closing an already-broken connection may itself fail, so
                # this step is best effort.
                try:
                    connections.close_all()
                except DB_ERRORS:
                    pass

    @staticmethod
    def keep(event: TEvent, status: WorkerStatus):
        """Refresh ``status.last_seen`` until ``event`` is set, pausing 30s per round.

        The loop runs while holding a Postgres advisory lock keyed on the
        status row's primary key. ``side_effect=pglock.Raise`` is passed,
        presumably so that failing to acquire the lock raises (confirm
        against the django-pglock docs).
        """
        with pglock.advisory(
            f"goauthentik.io/worker/status/{status.pk}", side_effect=pglock.Raise
        ):
            while not event.is_set():
                status.refresh_from_db()
                previous_seen = status.last_seen
                status.last_seen = now()
                if status.last_seen != previous_seen:
                    status.save(update_fields=("last_seen",))
                event.wait(30)

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

thread: threading.Thread | None
thread_event: threading.Event | None
def after_worker_boot(self, broker: dramatiq.broker.Broker, worker: dramatiq.worker.Worker):
def after_worker_boot(self, broker: Broker, worker: Worker):
    """Create the stop event and start the worker-status thread with it.

    Both the event and the thread are kept on the instance so
    ``before_worker_shutdown`` can signal and join the thread.
    """
    stop_event = TEvent()
    self.thread_event = stop_event
    status_thread = Thread(target=WorkerStatusMiddleware.run, args=(stop_event,))
    self.thread = status_thread
    status_thread.start()

Called after the worker process has started up.

def before_worker_shutdown(self, broker: dramatiq.broker.Broker, worker: dramatiq.worker.Worker):
def before_worker_shutdown(self, broker: Broker, worker: Worker):
    """Ask the worker-status thread to stop, then block until it has exited."""
    stop_event, status_thread = self.thread_event, self.thread
    stop_event.set()
    LOGGER.debug("Stopping WorkerStatusMiddleware")
    status_thread.join()

Called before the worker process shuts down.

@staticmethod
def run(event: threading.Event):
@staticmethod
def run(event: TEvent):
    """Register this host, then keep its heartbeat alive until ``event`` is set.

    Registration deletes existing ``WorkerStatus`` rows for this hostname
    and creates a new one, all in a single transaction. If a database error
    escapes ``keep``, the loop waits 10 seconds, closes all DB connections
    (best effort) and retries.
    """
    setthreadtitle("authentik Worker status")
    with transaction.atomic():
        host = socket.gethostname()
        # Remove any existing rows for this hostname (presumably stale
        # entries from an earlier process) before registering again.
        WorkerStatus.objects.filter(hostname=host).delete()
        worker_status = WorkerStatus.objects.update_or_create(
            hostname=host,
            version=authentik_full_version(),
        )[0]
    while not event.is_set():
        try:
            WorkerStatusMiddleware.keep(event, worker_status)
        except DB_ERRORS:  # pragma: no cover
            event.wait(10)
            # Closing an already-broken connection may itself fail.
            try:
                connections.close_all()
            except DB_ERRORS:
                pass
@staticmethod
def keep(event: threading.Event, status: authentik.tasks.models.WorkerStatus):
@staticmethod
def keep(event: TEvent, status: WorkerStatus):
    """Refresh ``status.last_seen`` until ``event`` is set, pausing 30s per round.

    All refreshes happen while holding a Postgres advisory lock keyed on the
    status row's primary key. ``side_effect=pglock.Raise`` is passed,
    presumably so that failing to acquire the lock raises (confirm against
    the django-pglock docs).
    """
    with pglock.advisory(
        f"goauthentik.io/worker/status/{status.pk}", side_effect=pglock.Raise
    ):
        while not event.is_set():
            status.refresh_from_db()
            previous_seen = status.last_seen
            status.last_seen = now()
            # Only write when the timestamp actually moved.
            if status.last_seen != previous_seen:
                status.save(update_fields=("last_seen",))
            event.wait(30)
class MetricsMiddleware(django_dramatiq_postgres.middleware.MetricsMiddleware):
class MetricsMiddleware(BaseMetricsMiddleware):
    """Serve worker metrics over HTTP from a thread inside the worker process.

    ``forks`` is overridden to be empty, so no forked processes are used.
    Instead, ``after_worker_boot`` starts an ``HTTPServerThread`` on the
    first address configured in ``listen.metrics``.
    """

    thread: HTTPServerThread | None
    handler_class = _MetricsHandler

    @property
    def forks(self) -> list[Callable[[], None]]:
        """Return no fork targets; the server runs on a worker thread instead."""
        return []

    def after_worker_boot(self, broker: Broker, worker: Worker):
        """Start the metrics HTTP server thread.

        ``listen.metrics`` may be a list or a comma-separated string. Only the
        first entry is used, split on its last ``:`` into address and port.
        If that entry is missing or has no integer port, an error is logged
        and no server thread is started (``self.thread`` stays ``None``).
        """
        # Set up front so shutdown can tell that no server was started.
        self.thread = None
        listen = CONFIG.get("listen.metrics", ["[::]:9300"])
        if isinstance(listen, str):
            listen = listen.split(",")
        # An empty list would raise IndexError; treat it like an invalid entry.
        first_entry = listen[0] if listen else ""
        addr, _, port = first_entry.rpartition(":")

        try:
            port = int(port)
        except ValueError:
            LOGGER.error(f"Invalid port entered: {port}")
            # A non-numeric port cannot be bound, so do not start the server.
            return
        self.thread = HTTPServerThread(target=MetricsMiddleware.run, args=(addr, port))
        self.thread.start()

    def before_worker_shutdown(self, broker: Broker, worker: Worker):
        """Stop the metrics server, if one was started, and join its thread."""
        if self.thread is None:
            # after_worker_boot rejected the configuration; nothing to stop.
            return
        server = self.thread.server
        if server:
            server.shutdown()
        LOGGER.debug("Stopping MetricsMiddleware")
        self.thread.join()

Base class for broker middleware. The default implementations for all hooks are no-ops and subclasses may implement whatever subset of hooks they like.

thread: django_dramatiq_postgres.middleware.HTTPServerThread | None
handler_class = <class 'authentik.tasks.middleware._MetricsHandler'>
forks: list[Callable[[], None]]
@property
def forks(self) -> list[Callable[[], None]]:
    """Return an empty fork list: metrics are served from a worker thread instead."""
    no_forks: list[Callable[[], None]] = []
    return no_forks

A list of functions to run in separate forks of the main process.

def after_worker_boot(self, broker: dramatiq.broker.Broker, worker: dramatiq.worker.Worker):
def after_worker_boot(self, broker: Broker, worker: Worker):
    """Start the metrics HTTP server thread on the first ``listen.metrics`` entry.

    The setting may be a list or a comma-separated string. The first entry
    is split on its last ``:`` into host and port.
    """
    configured = CONFIG.get("listen.metrics", ["[::]:9300"])
    entries = configured.split(",") if isinstance(configured, str) else configured
    host, _sep, raw_port = entries[0].rpartition(":")

    port = raw_port
    try:
        port = int(raw_port)
    except ValueError:
        # NOTE(review): after logging, the thread is still started with the
        # unparsed string port, which presumably fails to bind. Consider
        # returning here instead.
        LOGGER.error(f"Invalid port entered: {port}")
    self.thread = HTTPServerThread(target=MetricsMiddleware.run, args=(host, port))
    self.thread.start()

Called after the worker process has started up.

def before_worker_shutdown(self, broker: dramatiq.broker.Broker, worker: dramatiq.worker.Worker):
def before_worker_shutdown(self, broker: Broker, worker: Worker):
    """Shut down the metrics HTTP server (if it was attached) and join its thread."""
    metrics_thread = self.thread
    running_server = metrics_thread.server
    if running_server:
        running_server.shutdown()
    LOGGER.debug("Stopping MetricsMiddleware")
    metrics_thread.join()

Called before the worker process shuts down.