authentik.root.test_runner

Integrate ./manage.py test with pytest

  1"""Integrate ./manage.py test with pytest"""
  2
  3import os
  4from argparse import ArgumentParser
  5from unittest import TestCase
  6from unittest.mock import patch
  7
  8import freezegun
  9import pytest
 10from django.conf import settings
 11from django.contrib.contenttypes.models import ContentType
 12from django.test.runner import DiscoverRunner
 13from structlog.stdlib import get_logger
 14
 15from authentik.events.context_processors.asn import ASN_CONTEXT_PROCESSOR
 16from authentik.events.context_processors.geoip import GEOIP_CONTEXT_PROCESSOR
 17from authentik.lib.config import CONFIG
 18from authentik.lib.sentry import sentry_init
 19from authentik.root.signals import post_startup, pre_startup, startup
 20from authentik.tasks.test import use_test_broker
 21
 22# globally set maxDiff to none to show full assert error
 23TestCase.maxDiff = None
 24
 25
 26def get_docker_tag() -> str:
 27    """Get docker-tag based off of CI variables"""
 28    env_pr_branch = "GITHUB_HEAD_REF"
 29    default_branch = "GITHUB_REF"
 30    branch_name = os.environ.get(default_branch, "main")
 31    if os.environ.get(env_pr_branch, "") != "":
 32        branch_name = os.environ[env_pr_branch]
 33    branch_name = branch_name.replace("refs/heads/", "").replace("/", "-")
 34    return f"gh-{branch_name}"
 35
 36
 37def patched__get_ct_cached(app_label, codename):
 38    """Caches `ContentType` instances like its `QuerySet` does."""
 39    return ContentType.objects.get(app_label=app_label, permission__codename=codename)
 40
 41
 42class PytestTestRunner(DiscoverRunner):  # pragma: no cover
 43    """Runs pytest to discover and run tests."""
 44
 45    def __init__(self, **kwargs):
 46        super().__init__(**kwargs)
 47        self.logger = get_logger().bind(runner="pytest")
 48
 49        self.args = []
 50        if self.failfast:
 51            self.args.append("--exitfirst")
 52        if self.keepdb:
 53            self.args.append("--reuse-db")
 54
 55        if kwargs.get("randomly_seed", None):
 56            self.args.append(f"--randomly-seed={kwargs['randomly_seed']}")
 57        if kwargs.get("no_capture", False):
 58            self.args.append("--capture=no")
 59
 60        if kwargs.get("count", None):
 61            self.args.append("--flake-finder")
 62            self.args.append(f"--flake-runs={kwargs['count']}")
 63
 64        self._setup_test_environment()
 65
 66    def _setup_test_environment(self):
 67        """Configure test environment settings"""
 68        settings.TEST = True
 69        settings.DRAMATIQ["test"] = True
 70
 71        # Test-specific configuration
 72        test_config = {
 73            "events.context_processors.geoip": "tests/geoip/GeoLite2-City-Test.mmdb",
 74            "events.context_processors.asn": "tests/geoip/GeoLite2-ASN-Test.mmdb",
 75            "blueprints_dir": "./blueprints",
 76            "outposts.container_image_base": f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
 77            "tenants.enabled": False,
 78            "outposts.disable_embedded_outpost": False,
 79            "error_reporting.sample_rate": 0,
 80            "error_reporting.environment": "testing",
 81            "error_reporting.send_pii": True,
 82        }
 83
 84        for key, value in test_config.items():
 85            CONFIG.set(key, value)
 86
 87        ASN_CONTEXT_PROCESSOR.load()
 88        GEOIP_CONTEXT_PROCESSOR.load()
 89
 90        sentry_init()
 91        self.logger.debug("Test environment configured")
 92
 93        self.task_broker = use_test_broker()
 94
 95        freezegun.configure(extend_ignore_list=["cryptography"])
 96
 97        # Send startup signals
 98        pre_startup.send(sender=self, mode="test")
 99        startup.send(sender=self, mode="test")
100        post_startup.send(sender=self, mode="test")
101
102    @classmethod
103    def add_arguments(cls, parser: ArgumentParser):
104        """Add more pytest-specific arguments"""
105        DiscoverRunner.add_arguments(parser)
106        default_seed = None
107        if seed := os.getenv("CI_TEST_SEED"):
108            default_seed = int(seed)
109        parser.add_argument(
110            "--randomly-seed",
111            type=int,
112            help="Set the seed that pytest-randomly uses (int), or pass the special value 'last'"
113            "to reuse the seed from the previous run."
114            "Default behaviour: use random.Random().getrandbits(32), so the seed is"
115            "different on each run.",
116            default=default_seed,
117        )
118        parser.add_argument(
119            "--no-capture",
120            action="store_true",
121            help="Disable any capturing of stdout/stderr during tests.",
122        )
123        parser.add_argument("--count", type=int, help="Re-run selected tests n times")
124
125    def _validate_test_label(self, label: str) -> bool:
126        """Validate test label format"""
127        if not label:
128            return False
129
130        # Check for invalid characters, but allow forward slashes and colons
131        # for paths and pytest markers
132        invalid_chars = set('\\*?"<>|')
133        if any(c in label for c in invalid_chars):
134            self.logger.error("Invalid characters in test label", label=label)
135            return False
136
137        return True
138
139    def run_tests(self, test_labels: list[str], extra_tests=None, **kwargs):
140        """Run pytest and return the exitcode.
141
142        It translates some of Django's test command option to pytest's.
143        It is supported to only run specific classes and methods using
144        a dotted module name i.e. foo.bar[.Class[.method]]
145
146        The extra_tests argument has been deprecated since Django 5.x
147        It is kept for compatibility with PyCharm's Django test runner.
148        """
149        if not test_labels:
150            self.logger.error("No test files specified")
151            return 1
152
153        for label in test_labels:
154            if not self._validate_test_label(label):
155                return 1
156
157            valid_label_found = False
158            label_as_path = os.path.abspath(label)
159
160            # File path has been specified
161            if os.path.exists(label_as_path):
162                self.args.append(label_as_path)
163                valid_label_found = True
164            elif "::" in label:
165                self.args.append(label)
166                valid_label_found = True
167            else:
168                # Check if the label is a dotted module path
169                path_pieces = label.split(".")
170                for i in range(-1, -3, -1):
171                    try:
172                        path = os.path.join(*path_pieces[:i]) + ".py"
173                        if os.path.exists(path):
174                            if i < -1:
175                                path_method = path + "::" + "::".join(path_pieces[i:])
176                                self.args.append(path_method)
177                            else:
178                                self.args.append(path)
179                            valid_label_found = True
180                            break
181                    except TypeError, IndexError:
182                        continue
183
184            if not valid_label_found:
185                self.logger.error("Test file not found", label=label)
186                return 1
187
188        self.logger.info("Running tests", test_files=self.args)
189        with patch("guardian.shortcuts._get_ct_cached", patched__get_ct_cached):
190            try:
191                ret = pytest.main(self.args)
192                self.task_broker.close()
193                return ret
194            except Exception as exc:  # noqa
195                self.logger.error("Error running tests", exc=exc, test_files=self.args)
196                return 1
def get_docker_tag() -> str:
def get_docker_tag() -> str:
    """Derive the ``gh-<branch>`` Docker image tag from CI environment variables.

    The branch comes from ``GITHUB_HEAD_REF`` when it is set and non-empty, otherwise
    from ``GITHUB_REF`` (``main`` if that is unset). Every ``refs/heads/`` substring is
    removed and slashes are turned into dashes.
    """
    pr_branch = os.environ.get("GITHUB_HEAD_REF", "")
    ref = pr_branch or os.environ.get("GITHUB_REF", "main")
    tag_suffix = ref.replace("refs/heads/", "").replace("/", "-")
    return "gh-" + tag_suffix

Get the Docker image tag based on CI environment variables.

def patched__get_ct_cached(app_label, codename):
def patched__get_ct_cached(app_label, codename):
    """Look up the ``ContentType`` owning a permission, without any caching.

    Used by the test runner as a replacement for ``guardian.shortcuts._get_ct_cached``
    while pytest runs. Every call queries the database for the ``ContentType`` in
    ``app_label`` that has a permission with the given ``codename``.

    Args:
        app_label: Django app label of the content type.
        codename: Codename of a permission attached to that content type.

    Returns:
        The matching ``ContentType`` instance. Errors raised by ``QuerySet.get``
        (e.g. ``ContentType.DoesNotExist``) propagate to the caller.
    """
    return ContentType.objects.get(app_label=app_label, permission__codename=codename)

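Looks up the ContentType that owns a permission codename directly from the database, without caching. The test runner patches it in as a replacement for guardian's `_get_ct_cached`.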

class PytestTestRunner(django.test.runner.DiscoverRunner):
 43class PytestTestRunner(DiscoverRunner):  # pragma: no cover
 44    """Runs pytest to discover and run tests."""
 45
 46    def __init__(self, **kwargs):
 47        super().__init__(**kwargs)
 48        self.logger = get_logger().bind(runner="pytest")
 49
 50        self.args = []
 51        if self.failfast:
 52            self.args.append("--exitfirst")
 53        if self.keepdb:
 54            self.args.append("--reuse-db")
 55
 56        if kwargs.get("randomly_seed", None):
 57            self.args.append(f"--randomly-seed={kwargs['randomly_seed']}")
 58        if kwargs.get("no_capture", False):
 59            self.args.append("--capture=no")
 60
 61        if kwargs.get("count", None):
 62            self.args.append("--flake-finder")
 63            self.args.append(f"--flake-runs={kwargs['count']}")
 64
 65        self._setup_test_environment()
 66
 67    def _setup_test_environment(self):
 68        """Configure test environment settings"""
 69        settings.TEST = True
 70        settings.DRAMATIQ["test"] = True
 71
 72        # Test-specific configuration
 73        test_config = {
 74            "events.context_processors.geoip": "tests/geoip/GeoLite2-City-Test.mmdb",
 75            "events.context_processors.asn": "tests/geoip/GeoLite2-ASN-Test.mmdb",
 76            "blueprints_dir": "./blueprints",
 77            "outposts.container_image_base": f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
 78            "tenants.enabled": False,
 79            "outposts.disable_embedded_outpost": False,
 80            "error_reporting.sample_rate": 0,
 81            "error_reporting.environment": "testing",
 82            "error_reporting.send_pii": True,
 83        }
 84
 85        for key, value in test_config.items():
 86            CONFIG.set(key, value)
 87
 88        ASN_CONTEXT_PROCESSOR.load()
 89        GEOIP_CONTEXT_PROCESSOR.load()
 90
 91        sentry_init()
 92        self.logger.debug("Test environment configured")
 93
 94        self.task_broker = use_test_broker()
 95
 96        freezegun.configure(extend_ignore_list=["cryptography"])
 97
 98        # Send startup signals
 99        pre_startup.send(sender=self, mode="test")
100        startup.send(sender=self, mode="test")
101        post_startup.send(sender=self, mode="test")
102
103    @classmethod
104    def add_arguments(cls, parser: ArgumentParser):
105        """Add more pytest-specific arguments"""
106        DiscoverRunner.add_arguments(parser)
107        default_seed = None
108        if seed := os.getenv("CI_TEST_SEED"):
109            default_seed = int(seed)
110        parser.add_argument(
111            "--randomly-seed",
112            type=int,
113            help="Set the seed that pytest-randomly uses (int), or pass the special value 'last'"
114            "to reuse the seed from the previous run."
115            "Default behaviour: use random.Random().getrandbits(32), so the seed is"
116            "different on each run.",
117            default=default_seed,
118        )
119        parser.add_argument(
120            "--no-capture",
121            action="store_true",
122            help="Disable any capturing of stdout/stderr during tests.",
123        )
124        parser.add_argument("--count", type=int, help="Re-run selected tests n times")
125
126    def _validate_test_label(self, label: str) -> bool:
127        """Validate test label format"""
128        if not label:
129            return False
130
131        # Check for invalid characters, but allow forward slashes and colons
132        # for paths and pytest markers
133        invalid_chars = set('\\*?"<>|')
134        if any(c in label for c in invalid_chars):
135            self.logger.error("Invalid characters in test label", label=label)
136            return False
137
138        return True
139
140    def run_tests(self, test_labels: list[str], extra_tests=None, **kwargs):
141        """Run pytest and return the exitcode.
142
143        It translates some of Django's test command option to pytest's.
144        It is supported to only run specific classes and methods using
145        a dotted module name i.e. foo.bar[.Class[.method]]
146
147        The extra_tests argument has been deprecated since Django 5.x
148        It is kept for compatibility with PyCharm's Django test runner.
149        """
150        if not test_labels:
151            self.logger.error("No test files specified")
152            return 1
153
154        for label in test_labels:
155            if not self._validate_test_label(label):
156                return 1
157
158            valid_label_found = False
159            label_as_path = os.path.abspath(label)
160
161            # File path has been specified
162            if os.path.exists(label_as_path):
163                self.args.append(label_as_path)
164                valid_label_found = True
165            elif "::" in label:
166                self.args.append(label)
167                valid_label_found = True
168            else:
169                # Check if the label is a dotted module path
170                path_pieces = label.split(".")
171                for i in range(-1, -3, -1):
172                    try:
173                        path = os.path.join(*path_pieces[:i]) + ".py"
174                        if os.path.exists(path):
175                            if i < -1:
176                                path_method = path + "::" + "::".join(path_pieces[i:])
177                                self.args.append(path_method)
178                            else:
179                                self.args.append(path)
180                            valid_label_found = True
181                            break
182                    except TypeError, IndexError:
183                        continue
184
185            if not valid_label_found:
186                self.logger.error("Test file not found", label=label)
187                return 1
188
189        self.logger.info("Running tests", test_files=self.args)
190        with patch("guardian.shortcuts._get_ct_cached", patched__get_ct_cached):
191            try:
192                ret = pytest.main(self.args)
193                self.task_broker.close()
194                return ret
195            except Exception as exc:  # noqa
196                self.logger.error("Error running tests", exc=exc, test_files=self.args)
197                return 1

Runs pytest to discover and run tests.

PytestTestRunner(**kwargs)
46    def __init__(self, **kwargs):
47        super().__init__(**kwargs)
48        self.logger = get_logger().bind(runner="pytest")
49
50        self.args = []
51        if self.failfast:
52            self.args.append("--exitfirst")
53        if self.keepdb:
54            self.args.append("--reuse-db")
55
56        if kwargs.get("randomly_seed", None):
57            self.args.append(f"--randomly-seed={kwargs['randomly_seed']}")
58        if kwargs.get("no_capture", False):
59            self.args.append("--capture=no")
60
61        if kwargs.get("count", None):
62            self.args.append("--flake-finder")
63            self.args.append(f"--flake-runs={kwargs['count']}")
64
65        self._setup_test_environment()
logger
args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
103    @classmethod
104    def add_arguments(cls, parser: ArgumentParser):
105        """Add more pytest-specific arguments"""
106        DiscoverRunner.add_arguments(parser)
107        default_seed = None
108        if seed := os.getenv("CI_TEST_SEED"):
109            default_seed = int(seed)
110        parser.add_argument(
111            "--randomly-seed",
112            type=int,
113            help="Set the seed that pytest-randomly uses (int), or pass the special value 'last'"
114            "to reuse the seed from the previous run."
115            "Default behaviour: use random.Random().getrandbits(32), so the seed is"
116            "different on each run.",
117            default=default_seed,
118        )
119        parser.add_argument(
120            "--no-capture",
121            action="store_true",
122            help="Disable any capturing of stdout/stderr during tests.",
123        )
124        parser.add_argument("--count", type=int, help="Re-run selected tests n times")

Add more pytest-specific arguments

def run_tests(self, test_labels: list[str], extra_tests=None, **kwargs):
140    def run_tests(self, test_labels: list[str], extra_tests=None, **kwargs):
141        """Run pytest and return the exitcode.
142
143        It translates some of Django's test command option to pytest's.
144        It is supported to only run specific classes and methods using
145        a dotted module name i.e. foo.bar[.Class[.method]]
146
147        The extra_tests argument has been deprecated since Django 5.x
148        It is kept for compatibility with PyCharm's Django test runner.
149        """
150        if not test_labels:
151            self.logger.error("No test files specified")
152            return 1
153
154        for label in test_labels:
155            if not self._validate_test_label(label):
156                return 1
157
158            valid_label_found = False
159            label_as_path = os.path.abspath(label)
160
161            # File path has been specified
162            if os.path.exists(label_as_path):
163                self.args.append(label_as_path)
164                valid_label_found = True
165            elif "::" in label:
166                self.args.append(label)
167                valid_label_found = True
168            else:
169                # Check if the label is a dotted module path
170                path_pieces = label.split(".")
171                for i in range(-1, -3, -1):
172                    try:
173                        path = os.path.join(*path_pieces[:i]) + ".py"
174                        if os.path.exists(path):
175                            if i < -1:
176                                path_method = path + "::" + "::".join(path_pieces[i:])
177                                self.args.append(path_method)
178                            else:
179                                self.args.append(path)
180                            valid_label_found = True
181                            break
182                    except TypeError, IndexError:
183                        continue
184
185            if not valid_label_found:
186                self.logger.error("Test file not found", label=label)
187                return 1
188
189        self.logger.info("Running tests", test_files=self.args)
190        with patch("guardian.shortcuts._get_ct_cached", patched__get_ct_cached):
191            try:
192                ret = pytest.main(self.args)
193                self.task_broker.close()
194                return ret
195            except Exception as exc:  # noqa
196                self.logger.error("Error running tests", exc=exc, test_files=self.args)
197                return 1

Run pytest and return the exitcode.

It translates some of Django's test command options to pytest's. Specific classes and methods can be selected with a dotted module name, i.e. foo.bar[.Class[.method]].

The extra_tests argument has been deprecated since Django 3.2. It is kept for compatibility with PyCharm's Django test runner.