authentik.root.db.base

authentik database backend

 1"""authentik database backend"""
 2
 3from django.core.checks import Warning
 4from django.db.backends.base.validation import BaseDatabaseValidation
 5from django_tenants.postgresql_backend.base import DatabaseWrapper as BaseDatabaseWrapper
 6
 7from authentik.lib.config import CONFIG
 8
 9
class DatabaseValidation(BaseDatabaseValidation):
    """Database validation which checks the PostgreSQL encoding configuration.

    Reports ``ak.db.W001`` when the server and client encodings differ and
    ``ak.db.W002`` when the server encoding is not UTF8.
    """

    def check(self, **kwargs):
        """Run the inherited backend checks, then the encoding checks.

        Returns:
            list: check messages from the parent implementation followed by
            any encoding warnings from ``_check_encoding``.
        """
        # Keep anything the parent implementation reports instead of replacing
        # it, so upstream backend checks are not silently dropped.
        issues = list(super().check(**kwargs))
        issues.extend(self._check_encoding())
        return issues

    def _check_encoding(self):
        """Return warnings when the server_encoding is not UTF-8 or
        server_encoding and client_encoding are mismatched.

        Returns:
            list: zero, one or two ``django.core.checks.Warning`` instances.
        """
        with self.connection.cursor() as cursor:
            cursor.execute("SHOW server_encoding;")
            server_encoding = cursor.fetchone()[0]
            cursor.execute("SHOW client_encoding;")
            client_encoding = cursor.fetchone()[0]

        # The cursor is only needed for the two queries above; build the
        # messages after the cursor context has exited.
        messages = []
        if server_encoding != client_encoding:
            messages.append(
                Warning(
                    "PostgreSQL Server and Client encoding are mismatched: Server: "
                    f"{server_encoding}, Client: {client_encoding}",
                    id="ak.db.W001",
                )
            )
        if server_encoding != "UTF8":
            messages.append(
                Warning(
                    f"PostgreSQL Server encoding is not UTF8: {server_encoding}",
                    id="ak.db.W002",
                )
            )
        return messages
40
41
class DatabaseWrapper(BaseDatabaseWrapper):
    """Database backend which supports rotating credentials.

    Host, port, user and password are re-read from the configuration each
    time connection parameters are built.
    """

    validation_class = DatabaseValidation

    def get_connection_params(self):
        """Return connection params with credentials freshly read from the config.

        Read replicas (aliases starting with ``replica_``) read from
        ``postgresql.read_replicas.<name>`` and fall back to the primary
        ``postgresql`` value for any setting that comes back as ``None``.
        """
        params = super().get_connection_params()

        is_replica = self.alias.startswith("replica_")
        if is_replica:
            config_root = f"postgresql.read_replicas.{self.alias.removeprefix('replica_')}"
        else:
            config_root = "postgresql"

        for key in ("host", "port", "user", "password"):
            value = CONFIG.refresh(f"{config_root}.{key}")
            if value is None and is_replica:
                # Replica does not override this setting: use the primary's.
                value = CONFIG.refresh(f"postgresql.{key}")
            params[key] = value

        return params
class DatabaseValidation(django.db.backends.base.validation.BaseDatabaseValidation):
class DatabaseValidation(BaseDatabaseValidation):
    """Database validation which checks the PostgreSQL encoding configuration.

    Reports ``ak.db.W001`` when the server and client encodings differ and
    ``ak.db.W002`` when the server encoding is not UTF8.
    """

    def check(self, **kwargs):
        """Run the inherited backend checks, then the encoding checks.

        Returns:
            list: check messages from the parent implementation followed by
            any encoding warnings from ``_check_encoding``.
        """
        # Keep anything the parent implementation reports instead of replacing
        # it, so upstream backend checks are not silently dropped.
        issues = list(super().check(**kwargs))
        issues.extend(self._check_encoding())
        return issues

    def _check_encoding(self):
        """Return warnings when the server_encoding is not UTF-8 or
        server_encoding and client_encoding are mismatched.

        Returns:
            list: zero, one or two ``django.core.checks.Warning`` instances.
        """
        with self.connection.cursor() as cursor:
            cursor.execute("SHOW server_encoding;")
            server_encoding = cursor.fetchone()[0]
            cursor.execute("SHOW client_encoding;")
            client_encoding = cursor.fetchone()[0]

        # The cursor is only needed for the two queries above; build the
        # messages after the cursor context has exited.
        messages = []
        if server_encoding != client_encoding:
            messages.append(
                Warning(
                    "PostgreSQL Server and Client encoding are mismatched: Server: "
                    f"{server_encoding}, Client: {client_encoding}",
                    id="ak.db.W001",
                )
            )
        if server_encoding != "UTF8":
            messages.append(
                Warning(
                    f"PostgreSQL Server encoding is not UTF8: {server_encoding}",
                    id="ak.db.W002",
                )
            )
        return messages

Encapsulate backend-specific validation: warn when the PostgreSQL server encoding is not UTF8 or differs from the client encoding.

def check(self, **kwargs):
def check(self, **kwargs):
    """Run the inherited backend checks, then the encoding checks.

    Returns:
        list: check messages from the parent implementation followed by
        any encoding warnings from ``_check_encoding``.
    """
    # Keep anything the parent implementation reports instead of replacing
    # it, so upstream backend checks are not silently dropped.
    issues = list(super().check(**kwargs))
    issues.extend(self._check_encoding())
    return issues
class DatabaseWrapper(django_tenants.postgresql_backend.base.DatabaseWrapper):
class DatabaseWrapper(BaseDatabaseWrapper):
    """Database backend which supports rotating credentials.

    Host, port, user and password are re-read from the configuration each
    time connection parameters are built.
    """

    validation_class = DatabaseValidation

    def get_connection_params(self):
        """Return connection params with credentials freshly read from the config.

        Read replicas (aliases starting with ``replica_``) read from
        ``postgresql.read_replicas.<name>`` and fall back to the primary
        ``postgresql`` value for any setting that comes back as ``None``.
        """
        params = super().get_connection_params()

        is_replica = self.alias.startswith("replica_")
        if is_replica:
            config_root = f"postgresql.read_replicas.{self.alias.removeprefix('replica_')}"
        else:
            config_root = "postgresql"

        for key in ("host", "port", "user", "password"):
            value = CONFIG.refresh(f"{config_root}.{key}")
            if value is None and is_replica:
                # Replica does not override this setting: use the primary's.
                value = CONFIG.refresh(f"postgresql.{key}")
            params[key] = value

        return params

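Database backend which supports rotating credentials: the host, port, user and password are re-read from the configuration each time connection parameters are built.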

validation_class = <class 'DatabaseValidation'>
def get_connection_params(self):
def get_connection_params(self):
    """Return connection params with credentials freshly read from the config.

    Read replicas (aliases starting with ``replica_``) read from
    ``postgresql.read_replicas.<name>`` and fall back to the primary
    ``postgresql`` value for any setting that comes back as ``None``.
    """
    params = super().get_connection_params()

    is_replica = self.alias.startswith("replica_")
    if is_replica:
        config_root = f"postgresql.read_replicas.{self.alias.removeprefix('replica_')}"
    else:
        config_root = "postgresql"

    for key in ("host", "port", "user", "password"):
        value = CONFIG.refresh(f"{config_root}.{key}")
        if value is None and is_replica:
            # Replica does not override this setting: use the primary's.
            value = CONFIG.refresh(f"postgresql.{key}")
        params[key] = value

    return params

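Refresh DB credentials (host, port, user and password) before getting connection params. Read replicas (aliases prefixed with `replica_`) use the `postgresql.read_replicas.<name>` settings and fall back to the primary `postgresql` value for any setting that is unset (`None`).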