authentik.root.test_runner

Integrate ./manage.py test with pytest

  1"""Integrate ./manage.py test with pytest"""
  2
  3import os
  4from argparse import ArgumentParser
  5from unittest import TestCase
  6from unittest.mock import patch
  7
  8import pytest
  9from django.conf import settings
 10from django.contrib.contenttypes.models import ContentType
 11from django.test.runner import DiscoverRunner
 12from structlog.stdlib import get_logger
 13
 14from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
 15from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
 16from authentik.lib.config import CONFIG
 17from authentik.lib.sentry import sentry_init
 18from authentik.root.signals import post_startup, pre_startup, startup
 19from authentik.tasks.test import use_test_broker
 20
 21# Globally disable unittest's diff truncation so failing assertions print the full diff
 22TestCase.maxDiff = None
 23
 24
 25def get_docker_tag() -> str:
 26    """Get docker-tag based off of CI variables"""
 27    env_pr_branch = "GITHUB_HEAD_REF"
 28    default_branch = "GITHUB_REF"
 29    branch_name = os.environ.get(default_branch, "main")
 30    if os.environ.get(env_pr_branch, "") != "":
 31        branch_name = os.environ[env_pr_branch]
 32    branch_name = branch_name.replace("refs/heads/", "").replace("/", "-")
 33    return f"gh-{branch_name}"
 34
 35
 36def patched__get_ct_cached(app_label, codename):
 37    """Uncached stand-in for guardian's `_get_ct_cached`: query the permission's `ContentType` on every call."""
 38    return ContentType.objects.get(app_label=app_label, permission__codename=codename)
 39
 40
 41class PytestTestRunner(DiscoverRunner):  # pragma: no cover
 42    """Runs pytest to discover and run tests."""
 43
 44    def __init__(self, **kwargs):
 45        super().__init__(**kwargs)
 46        self.logger = get_logger().bind(runner="pytest")
 47
 48        self.args = []
 49        if self.failfast:
 50            self.args.append("--exitfirst")
 51        if self.keepdb:
 52            self.args.append("--reuse-db")
 53
 54        if kwargs.get("randomly_seed", None):
 55            self.args.append(f"--randomly-seed={kwargs['randomly_seed']}")
 56        if kwargs.get("no_capture", False):
 57            self.args.append("--capture=no")
 58
 59        if kwargs.get("count", None):
 60            self.args.append("--flake-finder")
 61            self.args.append(f"--flake-runs={kwargs['count']}")
 62
 63        self._setup_test_environment()
 64
 65    def _setup_test_environment(self):
 66        """Configure test environment settings"""
 67        settings.TEST = True
 68        settings.DRAMATIQ["test"] = True
 69
 70        # Test-specific configuration
 71        test_config = {
 72            "events.context_processors.geoip": "tests/GeoLite2-City-Test.mmdb",
 73            "events.context_processors.asn": "tests/GeoLite2-ASN-Test.mmdb",
 74            "blueprints_dir": "./blueprints",
 75            "outposts.container_image_base": f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
 76            "tenants.enabled": False,
 77            "outposts.disable_embedded_outpost": False,
 78            "error_reporting.sample_rate": 0,
 79            "error_reporting.environment": "testing",
 80            "error_reporting.send_pii": True,
 81        }
 82
 83        for key, value in test_config.items():
 84            CONFIG.set(key, value)
 85
 86        ASN_CONTEXT_PROCESSOR.load()
 87        GEOIP_CONTEXT_PROCESSOR.load()
 88
 89        sentry_init()
 90        self.logger.debug("Test environment configured")
 91
 92        self.task_broker = use_test_broker()
 93
 94        # Send startup signals
 95        pre_startup.send(sender=self, mode="test")
 96        startup.send(sender=self, mode="test")
 97        post_startup.send(sender=self, mode="test")
 98
 99    @classmethod
100    def add_arguments(cls, parser: ArgumentParser):
101        """Add more pytest-specific arguments"""
102        DiscoverRunner.add_arguments(parser)
103        default_seed = None
104        if seed := os.getenv("CI_TEST_SEED"):
105            default_seed = int(seed)
106        parser.add_argument(
107            "--randomly-seed",
108            type=int,
109            help="Set the seed that pytest-randomly uses (int), or pass the special value 'last'"
110            "to reuse the seed from the previous run."
111            "Default behaviour: use random.Random().getrandbits(32), so the seed is"
112            "different on each run.",
113            default=default_seed,
114        )
115        parser.add_argument(
116            "--no-capture",
117            action="store_true",
118            help="Disable any capturing of stdout/stderr during tests.",
119        )
120        parser.add_argument("--count", type=int, help="Re-run selected tests n times")
121
122    def _validate_test_label(self, label: str) -> bool:
123        """Validate test label format"""
124        if not label:
125            return False
126
127        # Check for invalid characters, but allow forward slashes and colons
128        # for paths and pytest markers
129        invalid_chars = set('\\*?"<>|')
130        if any(c in label for c in invalid_chars):
131            self.logger.error("Invalid characters in test label", label=label)
132            return False
133
134        return True
135
136    def run_tests(self, test_labels: list[str], extra_tests=None, **kwargs):
137        """Run pytest and return the exitcode.
138
139        It translates some of Django's test command option to pytest's.
140        It is supported to only run specific classes and methods using
141        a dotted module name i.e. foo.bar[.Class[.method]]
142
143        The extra_tests argument has been deprecated since Django 5.x
144        It is kept for compatibility with PyCharm's Django test runner.
145        """
146        if not test_labels:
147            self.logger.error("No test files specified")
148            return 1
149
150        for label in test_labels:
151            if not self._validate_test_label(label):
152                return 1
153
154            valid_label_found = False
155            label_as_path = os.path.abspath(label)
156
157            # File path has been specified
158            if os.path.exists(label_as_path):
159                self.args.append(label_as_path)
160                valid_label_found = True
161            elif "::" in label:
162                self.args.append(label)
163                valid_label_found = True
164            else:
165                # Check if the label is a dotted module path
166                path_pieces = label.split(".")
167                for i in range(-1, -3, -1):
168                    try:
169                        path = os.path.join(*path_pieces[:i]) + ".py"
170                        if os.path.exists(path):
171                            if i < -1:
172                                path_method = path + "::" + "::".join(path_pieces[i:])
173                                self.args.append(path_method)
174                            else:
175                                self.args.append(path)
176                            valid_label_found = True
177                            break
178                    except TypeError, IndexError:
179                        continue
180
181            if not valid_label_found:
182                self.logger.error("Test file not found", label=label)
183                return 1
184
185        self.logger.info("Running tests", test_files=self.args)
186        with patch("guardian.shortcuts._get_ct_cached", patched__get_ct_cached):
187            try:
188                ret = pytest.main(self.args)
189                self.task_broker.close()
190                return ret
191            except Exception as exc:  # noqa
192                self.logger.error("Error running tests", exc=exc, test_files=self.args)
193                return 1
def get_docker_tag() -> str:
def get_docker_tag() -> str:
    """Return the ``gh-<branch>`` docker tag derived from CI environment variables.

    ``GITHUB_HEAD_REF`` wins when it is set to a non-empty value; otherwise
    ``GITHUB_REF`` is used, falling back to ``"main"`` when it is unset.
    Occurrences of ``refs/heads/`` are removed and every remaining ``/`` is
    replaced by ``-``.
    """
    head_ref = os.environ.get("GITHUB_HEAD_REF", "")
    if head_ref != "":
        ref = head_ref
    else:
        ref = os.environ.get("GITHUB_REF", "main")
    tag_suffix = ref.replace("refs/heads/", "").replace("/", "-")
    return "gh-" + tag_suffix

Get the Docker tag based on CI environment variables.

def patched__get_ct_cached(app_label, codename):
37def patched__get_ct_cached(app_label, codename):
38    """Uncached stand-in for guardian's `_get_ct_cached`: query the permission's `ContentType` on every call."""
39    return ContentType.objects.get(app_label=app_label, permission__codename=codename)

Uncached stand-in for guardian's _get_ct_cached: queries the permission's ContentType on every call.

class PytestTestRunner(django.test.runner.DiscoverRunner):
 42class PytestTestRunner(DiscoverRunner):  # pragma: no cover
 43    """Runs pytest to discover and run tests."""
 44
 45    def __init__(self, **kwargs):
 46        super().__init__(**kwargs)
 47        self.logger = get_logger().bind(runner="pytest")
 48
 49        self.args = []
 50        if self.failfast:
 51            self.args.append("--exitfirst")
 52        if self.keepdb:
 53            self.args.append("--reuse-db")
 54
 55        if kwargs.get("randomly_seed", None):
 56            self.args.append(f"--randomly-seed={kwargs['randomly_seed']}")
 57        if kwargs.get("no_capture", False):
 58            self.args.append("--capture=no")
 59
 60        if kwargs.get("count", None):
 61            self.args.append("--flake-finder")
 62            self.args.append(f"--flake-runs={kwargs['count']}")
 63
 64        self._setup_test_environment()
 65
 66    def _setup_test_environment(self):
 67        """Configure test environment settings"""
 68        settings.TEST = True
 69        settings.DRAMATIQ["test"] = True
 70
 71        # Test-specific configuration
 72        test_config = {
 73            "events.context_processors.geoip": "tests/GeoLite2-City-Test.mmdb",
 74            "events.context_processors.asn": "tests/GeoLite2-ASN-Test.mmdb",
 75            "blueprints_dir": "./blueprints",
 76            "outposts.container_image_base": f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
 77            "tenants.enabled": False,
 78            "outposts.disable_embedded_outpost": False,
 79            "error_reporting.sample_rate": 0,
 80            "error_reporting.environment": "testing",
 81            "error_reporting.send_pii": True,
 82        }
 83
 84        for key, value in test_config.items():
 85            CONFIG.set(key, value)
 86
 87        ASN_CONTEXT_PROCESSOR.load()
 88        GEOIP_CONTEXT_PROCESSOR.load()
 89
 90        sentry_init()
 91        self.logger.debug("Test environment configured")
 92
 93        self.task_broker = use_test_broker()
 94
 95        # Send startup signals
 96        pre_startup.send(sender=self, mode="test")
 97        startup.send(sender=self, mode="test")
 98        post_startup.send(sender=self, mode="test")
 99
100    @classmethod
101    def add_arguments(cls, parser: ArgumentParser):
102        """Add more pytest-specific arguments"""
103        DiscoverRunner.add_arguments(parser)
104        default_seed = None
105        if seed := os.getenv("CI_TEST_SEED"):
106            default_seed = int(seed)
107        parser.add_argument(
108            "--randomly-seed",
109            type=int,
110            help="Set the seed that pytest-randomly uses (int), or pass the special value 'last'"
111            "to reuse the seed from the previous run."
112            "Default behaviour: use random.Random().getrandbits(32), so the seed is"
113            "different on each run.",
114            default=default_seed,
115        )
116        parser.add_argument(
117            "--no-capture",
118            action="store_true",
119            help="Disable any capturing of stdout/stderr during tests.",
120        )
121        parser.add_argument("--count", type=int, help="Re-run selected tests n times")
122
123    def _validate_test_label(self, label: str) -> bool:
124        """Validate test label format"""
125        if not label:
126            return False
127
128        # Check for invalid characters, but allow forward slashes and colons
129        # for paths and pytest markers
130        invalid_chars = set('\\*?"<>|')
131        if any(c in label for c in invalid_chars):
132            self.logger.error("Invalid characters in test label", label=label)
133            return False
134
135        return True
136
137    def run_tests(self, test_labels: list[str], extra_tests=None, **kwargs):
138        """Run pytest and return the exitcode.
139
140        It translates some of Django's test command option to pytest's.
141        It is supported to only run specific classes and methods using
142        a dotted module name i.e. foo.bar[.Class[.method]]
143
144        The extra_tests argument has been deprecated since Django 5.x
145        It is kept for compatibility with PyCharm's Django test runner.
146        """
147        if not test_labels:
148            self.logger.error("No test files specified")
149            return 1
150
151        for label in test_labels:
152            if not self._validate_test_label(label):
153                return 1
154
155            valid_label_found = False
156            label_as_path = os.path.abspath(label)
157
158            # File path has been specified
159            if os.path.exists(label_as_path):
160                self.args.append(label_as_path)
161                valid_label_found = True
162            elif "::" in label:
163                self.args.append(label)
164                valid_label_found = True
165            else:
166                # Check if the label is a dotted module path
167                path_pieces = label.split(".")
168                for i in range(-1, -3, -1):
169                    try:
170                        path = os.path.join(*path_pieces[:i]) + ".py"
171                        if os.path.exists(path):
172                            if i < -1:
173                                path_method = path + "::" + "::".join(path_pieces[i:])
174                                self.args.append(path_method)
175                            else:
176                                self.args.append(path)
177                            valid_label_found = True
178                            break
179                    except TypeError, IndexError:
180                        continue
181
182            if not valid_label_found:
183                self.logger.error("Test file not found", label=label)
184                return 1
185
186        self.logger.info("Running tests", test_files=self.args)
187        with patch("guardian.shortcuts._get_ct_cached", patched__get_ct_cached):
188            try:
189                ret = pytest.main(self.args)
190                self.task_broker.close()
191                return ret
192            except Exception as exc:  # noqa
193                self.logger.error("Error running tests", exc=exc, test_files=self.args)
194                return 1

Runs pytest to discover and run tests.

PytestTestRunner(**kwargs)
def __init__(self, **kwargs):
    """Translate Django test-command options into pytest arguments.

    ``failfast`` and ``keepdb`` are read from the attributes set by
    ``DiscoverRunner.__init__``; ``randomly_seed``, ``no_capture`` and
    ``count`` are the options registered in ``add_arguments``. The test
    environment is configured as the last step.
    """
    super().__init__(**kwargs)
    self.logger = get_logger().bind(runner="pytest")

    # Arguments passed to pytest.main(); run_tests() appends the test targets.
    self.args = []
    if self.failfast:
        self.args.append("--exitfirst")
    if self.keepdb:
        self.args.append("--reuse-db")

    # Compare against None rather than truthiness: 0 is a valid seed.
    if (seed := kwargs.get("randomly_seed")) is not None:
        self.args.append(f"--randomly-seed={seed}")
    if kwargs.get("no_capture", False):
        self.args.append("--capture=no")

    if count := kwargs.get("count"):
        self.args.append("--flake-finder")
        self.args.append(f"--flake-runs={count}")

    self._setup_test_environment()
logger
args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
@classmethod
def add_arguments(cls, parser: ArgumentParser):
    """Add more pytest-specific arguments.

    Registers ``--randomly-seed`` (defaulting to ``CI_TEST_SEED`` when that
    environment variable is set), ``--no-capture`` and ``--count`` on top of
    Django's ``DiscoverRunner`` options.
    """
    DiscoverRunner.add_arguments(parser)

    def randomly_seed(value: str) -> int | str:
        # The help text advertises the special value "last", so pass it
        # through verbatim instead of letting int() reject it.
        return value if value == "last" else int(value)

    default_seed = None
    if seed := os.getenv("CI_TEST_SEED"):
        default_seed = int(seed)
    parser.add_argument(
        "--randomly-seed",
        type=randomly_seed,
        help=(
            "Set the seed that pytest-randomly uses (int), or pass the special value 'last' "
            "to reuse the seed from the previous run. "
            "Default behaviour: use random.Random().getrandbits(32), so the seed is "
            "different on each run."
        ),
        default=default_seed,
    )
    parser.add_argument(
        "--no-capture",
        action="store_true",
        help="Disable any capturing of stdout/stderr during tests.",
    )
    parser.add_argument("--count", type=int, help="Re-run selected tests n times")

Add more pytest-specific arguments

def run_tests(self, test_labels: list[str], extra_tests=None, **kwargs):
def run_tests(self, test_labels: list[str], extra_tests=None, **kwargs):
    """Run pytest and return the exitcode.

    Each label may be an existing file path, a pytest node id containing
    ``::``, or a dotted name ``foo.bar[.Class[.method]]`` which is
    translated to ``foo/bar.py[::Class[::method]]``. Some of Django's test
    command options are translated to pytest's in ``__init__``.

    The extra_tests argument is unused; it is kept for compatibility with
    PyCharm's Django test runner.

    Returns 1 when no labels are given, a label is invalid or cannot be
    resolved, or pytest raises; otherwise pytest's exit code.
    """
    if not test_labels:
        self.logger.error("No test files specified")
        return 1

    for label in test_labels:
        if not self._validate_test_label(label):
            return 1

        valid_label_found = False
        label_as_path = os.path.abspath(label)

        # File path has been specified
        if os.path.exists(label_as_path):
            self.args.append(label_as_path)
            valid_label_found = True
        # Already a pytest node id (file_path::Class::method)
        elif "::" in label:
            self.args.append(label)
            valid_label_found = True
        else:
            # Dotted name: try the whole label as a module first, then treat
            # the last one or two pieces as Class / Class::method. split_at
            # never reaches 0, so os.path.join always gets an argument.
            path_pieces = label.split(".")
            for split_at in range(len(path_pieces), max(len(path_pieces) - 3, 0), -1):
                path = os.path.join(*path_pieces[:split_at]) + ".py"
                if os.path.exists(path):
                    # Keep the remaining pieces; dropping them would run the whole module.
                    self.args.append("::".join([path, *path_pieces[split_at:]]))
                    valid_label_found = True
                    break

        if not valid_label_found:
            self.logger.error("Test file not found", label=label)
            return 1

    self.logger.info("Running tests", test_files=self.args)
    with patch("guardian.shortcuts._get_ct_cached", patched__get_ct_cached):
        try:
            return pytest.main(self.args)
        except Exception as exc:  # noqa
            self.logger.error("Error running tests", exc=exc, test_files=self.args)
            return 1
        finally:
            # Close the broker on every path, not only when pytest.main() succeeds.
            self.task_broker.close()

Run pytest and return the exitcode.

It translates some of Django's test command options to pytest's. Specific classes and methods can be run using a dotted module name, e.g. foo.bar[.Class[.method]].

The extra_tests argument has been deprecated since Django 5.x. It is kept for compatibility with PyCharm's Django test runner.