authentik.stages.authenticator.tests

Base authenticator tests

  1"""Base authenticator tests"""
  2
  3from datetime import timedelta
  4from threading import Thread
  5
  6from django.contrib.auth.models import AnonymousUser
  7from django.db import connection
  8from django.test import TestCase, TransactionTestCase
  9from django.utils import timezone
 10from freezegun import freeze_time
 11
 12from authentik.core.tests.utils import create_test_admin_user
 13from authentik.lib.generators import generate_id
 14from authentik.stages.authenticator import match_token, user_has_device, verify_token
 15from authentik.stages.authenticator.models import Device, VerifyNotAllowed
 16
 17
 18class TestThread(Thread):
 19    "Django testing quirk: threads have to close their DB connections."
 20
 21    __test__ = False
 22
 23    def run(self):
 24        super().run()
 25        connection.close()
 26
 27
 28class ThrottlingTestMixin:
 29    """
 30    Generic tests for throttled devices.
 31
 32    Any concrete device implementation that uses throttling should define a
 33    TestCase subclass that includes this as a base class. This will help verify
 34    a correct integration of ThrottlingMixin.
 35
 36    Subclasses are responsible for populating self.device with a device to test
 37    as well as implementing methods to generate tokens to test with.
 38
 39    """
 40
 41    device: Device
 42
 43    def valid_token(self):
 44        """Returns a valid token to pass to our device under test."""
 45        raise NotImplementedError()
 46
 47    def invalid_token(self):
 48        """Returns an invalid token to pass to our device under test."""
 49        raise NotImplementedError()
 50
 51    #
 52    # Tests
 53    #
 54
 55    def test_delay_imposed_after_fail(self):
 56        """Test delay imposed after fail"""
 57        verified1 = self.device.verify_token(self.invalid_token())
 58        self.assertFalse(verified1)
 59        verified2 = self.device.verify_token(self.valid_token())
 60        self.assertFalse(verified2)
 61
 62    def test_delay_after_fail_expires(self):
 63        """Test delay after fail expires"""
 64        verified1 = self.device.verify_token(self.invalid_token())
 65        self.assertFalse(verified1)
 66        with freeze_time() as frozen_time:
 67            # With default settings initial delay is 1 second
 68            frozen_time.tick(delta=timedelta(seconds=1.1))
 69            verified2 = self.device.verify_token(self.valid_token())
 70            self.assertTrue(verified2)
 71
 72    def test_throttling_failure_count(self):
 73        """Test throttling failure count"""
 74        self.assertEqual(self.device.throttling_failure_count, 0)
 75        for _ in range(0, 5):
 76            self.device.verify_token(self.invalid_token())
 77            # Only the first attempt will increase throttling_failure_count,
 78            # the others will all be within 1 second of first
 79            # and therefore not count as attempts.
 80            self.assertEqual(self.device.throttling_failure_count, 1)
 81
 82    def test_verify_is_allowed(self):
 83        """Test verify allowed"""
 84        # Initially should be allowed
 85        verify_is_allowed1, data1 = self.device.verify_is_allowed()
 86        self.assertEqual(verify_is_allowed1, True)
 87        self.assertEqual(data1, None)
 88
 89        # After failure, verify is not allowed
 90        with freeze_time():
 91            self.device.verify_token(self.invalid_token())
 92            verify_is_allowed2, data2 = self.device.verify_is_allowed()
 93            self.assertEqual(verify_is_allowed2, False)
 94            self.assertEqual(
 95                data2,
 96                {
 97                    "reason": VerifyNotAllowed.N_FAILED_ATTEMPTS,
 98                    "failure_count": 1,
 99                    "locked_until": timezone.now() + timezone.timedelta(seconds=1),
100                },
101            )
102
103        # After a successful attempt, should be allowed again
104        with freeze_time() as frozen_time:
105            frozen_time.tick(delta=timedelta(seconds=1.1))
106            self.device.verify_token(self.valid_token())
107
108            verify_is_allowed3, data3 = self.device.verify_is_allowed()
109            self.assertEqual(verify_is_allowed3, True)
110            self.assertEqual(data3, None)
111
112    def test_set_throttle_factor_is_reflected(self):
113        """`set_throttle_factor` must drive `get_throttle_factor`."""
114        self.device.set_throttle_factor(5.5)
115        self.assertEqual(self.device.get_throttle_factor(), 5.5)
116        self.device.set_throttle_factor(0)
117        self.assertEqual(self.device.get_throttle_factor(), 0)
118
119    def test_throttling_disabled_by_factor_zero(self):
120        """Setting the throttle factor to 0 must actually disable throttling.
121
122        A failed attempt followed by a successful one must succeed. The lockout
123        path must not kick in when the factor is 0.
124        """
125        self.device.set_throttle_factor(0)
126        self.assertFalse(self.device.verify_token(self.invalid_token()))
127        self.assertTrue(self.device.verify_token(self.valid_token()))
128
129
class APITestCase(TestCase):
    """Tests for the user_has_device, verify_token and match_token helpers."""

    def setUp(self):
        """Create alice and bob; only alice gets a static device with one known token."""
        self.alice = create_test_admin_user("alice")
        self.bob = create_test_admin_user("bob")
        self.valid = generate_id(length=16)

        alice_device = self.alice.staticdevice_set.create()
        alice_device.set_throttle_factor(0)
        alice_device.token_set.create(token=self.valid)

    def test_user_has_device(self):
        """Only the user who owns a device is reported as having one."""
        cases = (
            ("anonymous", AnonymousUser(), self.assertFalse),
            ("alice", self.alice, self.assertTrue),
            ("bob", self.bob, self.assertFalse),
        )
        for label, user, check in cases:
            with self.subTest(user=label):
                check(user_has_device(user))

    def test_verify_token(self):
        """verify_token yields None for a bogus token and a result for the valid one."""
        device = self.alice.staticdevice_set.first()

        self.assertIsNone(verify_token(self.alice, device.persistent_id, "bogus"))

        # Reset throttling on a freshly fetched device between the two attempts
        # (presumably clearing the failure left by the bogus token above).
        self.alice.staticdevice_set.get().throttle_reset()

        self.assertIsNotNone(verify_token(self.alice, device.persistent_id, self.valid))

    def test_match_token(self):
        """match_token yields None for a bogus token and alice's device for the valid one."""
        self.assertIsNone(match_token(self.alice, "bogus"))

        # Same reset as in test_verify_token, between the failing and passing attempt.
        self.alice.staticdevice_set.get().throttle_reset()

        matched = match_token(self.alice, self.valid)
        self.assertEqual(matched, self.alice.staticdevice_set.first())
171
172
class ConcurrencyTestCase(TransactionTestCase):
    """Test concurrent verifications

    Each test fires the same valid token from several threads at once and
    expects exactly one of the attempts to be accepted.
    """

    def setUp(self):
        """Create alice and bob, each with a static device holding the same token."""
        self.alice = create_test_admin_user("alice")
        self.bob = create_test_admin_user("bob")
        self.valid = generate_id(length=16)
        for user in [self.alice, self.bob]:
            device = user.staticdevice_set.create()
            device.token_set.create(token=self.valid)

    def _count_concurrent_successes(self, verify, attempts=10):
        """Call ``verify`` from ``attempts`` threads and count the non-None results.

        ``verify`` is a zero-argument callable; its return value is stored per
        thread. All threads are started before any is joined so the calls can
        overlap.
        """

        class VerifyThread(Thread):
            """Verifier thread"""

            __test__ = False

            def __init__(self):
                super().__init__()

                self.verified = None

            def run(self):
                # Close the connection even if verification raises, so a
                # failing worker does not leave its DB connection open.
                try:
                    self.verified = verify()
                finally:
                    connection.close()

        threads = [VerifyThread() for _ in range(attempts)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()

        return sum(1 for thread in threads if thread.verified is not None)

    def test_verify_token(self):
        """Test verify_token in a thread"""
        device = self.alice.staticdevice_set.get()
        # Resolve the arguments here, in the main thread, before the workers start.
        user, device_id, token = device.user, device.persistent_id, self.valid

        successes = self._count_concurrent_successes(
            lambda: verify_token(user, device_id, token)
        )

        self.assertEqual(successes, 1)

    def test_match_token(self):
        """Test match_token in a thread"""
        user, token = self.alice, self.valid

        successes = self._count_concurrent_successes(lambda: match_token(user, token))

        self.assertEqual(successes, 1)
class TestThread(threading.Thread):
19class TestThread(Thread):
20    "Django testing quirk: threads have to close their DB connections."
21
22    __test__ = False
23
24    def run(self):
25        super().run()
26        connection.close()

Django testing quirk: threads have to close their DB connections.

def run(self):
24    def run(self):
25        super().run()
26        connection.close()

Method representing the thread's activity.

You may override this method in a subclass. The standard run() method invokes the callable object passed to the object's constructor as the target argument, if any, with sequential and keyword arguments taken from the args and kwargs arguments, respectively.

class ThrottlingTestMixin:
 29class ThrottlingTestMixin:
 30    """
 31    Generic tests for throttled devices.
 32
 33    Any concrete device implementation that uses throttling should define a
 34    TestCase subclass that includes this as a base class. This will help verify
 35    a correct integration of ThrottlingMixin.
 36
 37    Subclasses are responsible for populating self.device with a device to test
 38    as well as implementing methods to generate tokens to test with.
 39
 40    """
 41
 42    device: Device
 43
 44    def valid_token(self):
 45        """Returns a valid token to pass to our device under test."""
 46        raise NotImplementedError()
 47
 48    def invalid_token(self):
 49        """Returns an invalid token to pass to our device under test."""
 50        raise NotImplementedError()
 51
 52    #
 53    # Tests
 54    #
 55
 56    def test_delay_imposed_after_fail(self):
 57        """Test delay imposed after fail"""
 58        verified1 = self.device.verify_token(self.invalid_token())
 59        self.assertFalse(verified1)
 60        verified2 = self.device.verify_token(self.valid_token())
 61        self.assertFalse(verified2)
 62
 63    def test_delay_after_fail_expires(self):
 64        """Test delay after fail expires"""
 65        verified1 = self.device.verify_token(self.invalid_token())
 66        self.assertFalse(verified1)
 67        with freeze_time() as frozen_time:
 68            # With default settings initial delay is 1 second
 69            frozen_time.tick(delta=timedelta(seconds=1.1))
 70            verified2 = self.device.verify_token(self.valid_token())
 71            self.assertTrue(verified2)
 72
 73    def test_throttling_failure_count(self):
 74        """Test throttling failure count"""
 75        self.assertEqual(self.device.throttling_failure_count, 0)
 76        for _ in range(0, 5):
 77            self.device.verify_token(self.invalid_token())
 78            # Only the first attempt will increase throttling_failure_count,
 79            # the others will all be within 1 second of first
 80            # and therefore not count as attempts.
 81            self.assertEqual(self.device.throttling_failure_count, 1)
 82
 83    def test_verify_is_allowed(self):
 84        """Test verify allowed"""
 85        # Initially should be allowed
 86        verify_is_allowed1, data1 = self.device.verify_is_allowed()
 87        self.assertEqual(verify_is_allowed1, True)
 88        self.assertEqual(data1, None)
 89
 90        # After failure, verify is not allowed
 91        with freeze_time():
 92            self.device.verify_token(self.invalid_token())
 93            verify_is_allowed2, data2 = self.device.verify_is_allowed()
 94            self.assertEqual(verify_is_allowed2, False)
 95            self.assertEqual(
 96                data2,
 97                {
 98                    "reason": VerifyNotAllowed.N_FAILED_ATTEMPTS,
 99                    "failure_count": 1,
100                    "locked_until": timezone.now() + timezone.timedelta(seconds=1),
101                },
102            )
103
104        # After a successful attempt, should be allowed again
105        with freeze_time() as frozen_time:
106            frozen_time.tick(delta=timedelta(seconds=1.1))
107            self.device.verify_token(self.valid_token())
108
109            verify_is_allowed3, data3 = self.device.verify_is_allowed()
110            self.assertEqual(verify_is_allowed3, True)
111            self.assertEqual(data3, None)
112
113    def test_set_throttle_factor_is_reflected(self):
114        """`set_throttle_factor` must drive `get_throttle_factor`."""
115        self.device.set_throttle_factor(5.5)
116        self.assertEqual(self.device.get_throttle_factor(), 5.5)
117        self.device.set_throttle_factor(0)
118        self.assertEqual(self.device.get_throttle_factor(), 0)
119
120    def test_throttling_disabled_by_factor_zero(self):
121        """Setting the throttle factor to 0 must actually disable throttling.
122
123        A failed attempt followed by a successful one must succeed. The lockout
124        path must not kick in when the factor is 0.
125        """
126        self.device.set_throttle_factor(0)
127        self.assertFalse(self.device.verify_token(self.invalid_token()))
128        self.assertTrue(self.device.verify_token(self.valid_token()))

Generic tests for throttled devices.

Any concrete device implementation that uses throttling should define a TestCase subclass that includes this as a base class. This will help verify a correct integration of ThrottlingMixin.

Subclasses are responsible for populating self.device with a device to test as well as implementing methods to generate tokens to test with.

def valid_token(self):
44    def valid_token(self):
45        """Returns a valid token to pass to our device under test."""
46        raise NotImplementedError()

Returns a valid token to pass to our device under test.

def invalid_token(self):
48    def invalid_token(self):
49        """Returns an invalid token to pass to our device under test."""
50        raise NotImplementedError()

Returns an invalid token to pass to our device under test.

def test_delay_imposed_after_fail(self):
56    def test_delay_imposed_after_fail(self):
57        """Test delay imposed after fail"""
58        verified1 = self.device.verify_token(self.invalid_token())
59        self.assertFalse(verified1)
60        verified2 = self.device.verify_token(self.valid_token())
61        self.assertFalse(verified2)

Test delay imposed after fail

def test_delay_after_fail_expires(self):
63    def test_delay_after_fail_expires(self):
64        """Test delay after fail expires"""
65        verified1 = self.device.verify_token(self.invalid_token())
66        self.assertFalse(verified1)
67        with freeze_time() as frozen_time:
68            # With default settings initial delay is 1 second
69            frozen_time.tick(delta=timedelta(seconds=1.1))
70            verified2 = self.device.verify_token(self.valid_token())
71            self.assertTrue(verified2)

Test delay after fail expires

def test_throttling_failure_count(self):
73    def test_throttling_failure_count(self):
74        """Test throttling failure count"""
75        self.assertEqual(self.device.throttling_failure_count, 0)
76        for _ in range(0, 5):
77            self.device.verify_token(self.invalid_token())
78            # Only the first attempt will increase throttling_failure_count,
79            # the others will all be within 1 second of first
80            # and therefore not count as attempts.
81            self.assertEqual(self.device.throttling_failure_count, 1)

Test throttling failure count

def test_verify_is_allowed(self):
 83    def test_verify_is_allowed(self):
 84        """Test verify allowed"""
 85        # Initially should be allowed
 86        verify_is_allowed1, data1 = self.device.verify_is_allowed()
 87        self.assertEqual(verify_is_allowed1, True)
 88        self.assertEqual(data1, None)
 89
 90        # After failure, verify is not allowed
 91        with freeze_time():
 92            self.device.verify_token(self.invalid_token())
 93            verify_is_allowed2, data2 = self.device.verify_is_allowed()
 94            self.assertEqual(verify_is_allowed2, False)
 95            self.assertEqual(
 96                data2,
 97                {
 98                    "reason": VerifyNotAllowed.N_FAILED_ATTEMPTS,
 99                    "failure_count": 1,
100                    "locked_until": timezone.now() + timezone.timedelta(seconds=1),
101                },
102            )
103
104        # After a successful attempt, should be allowed again
105        with freeze_time() as frozen_time:
106            frozen_time.tick(delta=timedelta(seconds=1.1))
107            self.device.verify_token(self.valid_token())
108
109            verify_is_allowed3, data3 = self.device.verify_is_allowed()
110            self.assertEqual(verify_is_allowed3, True)
111            self.assertEqual(data3, None)

Test verify allowed

def test_set_throttle_factor_is_reflected(self):
113    def test_set_throttle_factor_is_reflected(self):
114        """`set_throttle_factor` must drive `get_throttle_factor`."""
115        self.device.set_throttle_factor(5.5)
116        self.assertEqual(self.device.get_throttle_factor(), 5.5)
117        self.device.set_throttle_factor(0)
118        self.assertEqual(self.device.get_throttle_factor(), 0)

set_throttle_factor must drive get_throttle_factor.

def test_throttling_disabled_by_factor_zero(self):
120    def test_throttling_disabled_by_factor_zero(self):
121        """Setting the throttle factor to 0 must actually disable throttling.
122
123        A failed attempt followed by a successful one must succeed. The lockout
124        path must not kick in when the factor is 0.
125        """
126        self.device.set_throttle_factor(0)
127        self.assertFalse(self.device.verify_token(self.invalid_token()))
128        self.assertTrue(self.device.verify_token(self.valid_token()))

Setting the throttle factor to 0 must actually disable throttling.

A failed attempt followed by a successful one must succeed. The lockout path must not kick in when the factor is 0.

class APITestCase(django.test.testcases.TestCase):
131class APITestCase(TestCase):
132    """Test API"""
133
134    def setUp(self):
135        self.alice = create_test_admin_user("alice")
136        self.bob = create_test_admin_user("bob")
137        device = self.alice.staticdevice_set.create()
138        device.set_throttle_factor(0)
139        self.valid = generate_id(length=16)
140        device.token_set.create(token=self.valid)
141
142    def test_user_has_device(self):
143        """Test user_has_device"""
144        with self.subTest(user="anonymous"):
145            self.assertFalse(user_has_device(AnonymousUser()))
146        with self.subTest(user="alice"):
147            self.assertTrue(user_has_device(self.alice))
148        with self.subTest(user="bob"):
149            self.assertFalse(user_has_device(self.bob))
150
151    def test_verify_token(self):
152        """Test verify_token"""
153        device = self.alice.staticdevice_set.first()
154
155        verified = verify_token(self.alice, device.persistent_id, "bogus")
156        self.assertIsNone(verified)
157
158        self.alice.staticdevice_set.get().throttle_reset()
159
160        verified = verify_token(self.alice, device.persistent_id, self.valid)
161        self.assertIsNotNone(verified)
162
163    def test_match_token(self):
164        """Test match_token"""
165        verified = match_token(self.alice, "bogus")
166        self.assertIsNone(verified)
167
168        self.alice.staticdevice_set.get().throttle_reset()
169
170        verified = match_token(self.alice, self.valid)
171        self.assertEqual(verified, self.alice.staticdevice_set.first())

Test API

def setUp(self):
134    def setUp(self):
135        self.alice = create_test_admin_user("alice")
136        self.bob = create_test_admin_user("bob")
137        device = self.alice.staticdevice_set.create()
138        device.set_throttle_factor(0)
139        self.valid = generate_id(length=16)
140        device.token_set.create(token=self.valid)

Hook method for setting up the test fixture before exercising it.

def test_user_has_device(self):
142    def test_user_has_device(self):
143        """Test user_has_device"""
144        with self.subTest(user="anonymous"):
145            self.assertFalse(user_has_device(AnonymousUser()))
146        with self.subTest(user="alice"):
147            self.assertTrue(user_has_device(self.alice))
148        with self.subTest(user="bob"):
149            self.assertFalse(user_has_device(self.bob))

Test user_has_device

def test_verify_token(self):
151    def test_verify_token(self):
152        """Test verify_token"""
153        device = self.alice.staticdevice_set.first()
154
155        verified = verify_token(self.alice, device.persistent_id, "bogus")
156        self.assertIsNone(verified)
157
158        self.alice.staticdevice_set.get().throttle_reset()
159
160        verified = verify_token(self.alice, device.persistent_id, self.valid)
161        self.assertIsNotNone(verified)

Test verify_token

def test_match_token(self):
163    def test_match_token(self):
164        """Test match_token"""
165        verified = match_token(self.alice, "bogus")
166        self.assertIsNone(verified)
167
168        self.alice.staticdevice_set.get().throttle_reset()
169
170        verified = match_token(self.alice, self.valid)
171        self.assertEqual(verified, self.alice.staticdevice_set.first())

Test match_token

class ConcurrencyTestCase(django.test.testcases.TransactionTestCase):
174class ConcurrencyTestCase(TransactionTestCase):
175    """Test concurrent verifications"""
176
177    def setUp(self):
178        self.alice = create_test_admin_user("alice")
179        self.bob = create_test_admin_user("bob")
180        self.valid = generate_id(length=16)
181        for user in [self.alice, self.bob]:
182            device = user.staticdevice_set.create()
183            device.token_set.create(token=self.valid)
184
185    def test_verify_token(self):
186        """Test verify_token in a thread"""
187
188        class VerifyThread(Thread):
189            """Verifier thread"""
190
191            __test__ = False
192
193            def __init__(self, user, device_id, token):
194                super().__init__()
195
196                self.user = user
197                self.device_id = device_id
198                self.token = token
199
200                self.verified = None
201
202            def run(self):
203                self.verified = verify_token(self.user, self.device_id, self.token)
204                connection.close()
205
206        device = self.alice.staticdevice_set.get()
207        threads = [VerifyThread(device.user, device.persistent_id, self.valid) for _ in range(10)]
208        for thread in threads:
209            thread.start()
210        for thread in threads:
211            thread.join()
212
213        self.assertEqual(sum(1 for t in threads if t.verified is not None), 1)
214
215    def test_match_token(self):
216        """Test match_token in a thread"""
217
218        class VerifyThread(Thread):
219            """Verifier thread"""
220
221            __test__ = False
222
223            def __init__(self, user, token):
224                super().__init__()
225
226                self.user = user
227                self.token = token
228
229                self.verified = None
230
231            def run(self):
232                self.verified = match_token(self.user, self.token)
233                connection.close()
234
235        threads = [VerifyThread(self.alice, self.valid) for _ in range(10)]
236        for thread in threads:
237            thread.start()
238        for thread in threads:
239            thread.join()
240
241        self.assertEqual(sum(1 for t in threads if t.verified is not None), 1)

Test concurrent verifications

def setUp(self):
177    def setUp(self):
178        self.alice = create_test_admin_user("alice")
179        self.bob = create_test_admin_user("bob")
180        self.valid = generate_id(length=16)
181        for user in [self.alice, self.bob]:
182            device = user.staticdevice_set.create()
183            device.token_set.create(token=self.valid)

Hook method for setting up the test fixture before exercising it.

def test_verify_token(self):
185    def test_verify_token(self):
186        """Test verify_token in a thread"""
187
188        class VerifyThread(Thread):
189            """Verifier thread"""
190
191            __test__ = False
192
193            def __init__(self, user, device_id, token):
194                super().__init__()
195
196                self.user = user
197                self.device_id = device_id
198                self.token = token
199
200                self.verified = None
201
202            def run(self):
203                self.verified = verify_token(self.user, self.device_id, self.token)
204                connection.close()
205
206        device = self.alice.staticdevice_set.get()
207        threads = [VerifyThread(device.user, device.persistent_id, self.valid) for _ in range(10)]
208        for thread in threads:
209            thread.start()
210        for thread in threads:
211            thread.join()
212
213        self.assertEqual(sum(1 for t in threads if t.verified is not None), 1)

Test verify_token in a thread

def test_match_token(self):
215    def test_match_token(self):
216        """Test match_token in a thread"""
217
218        class VerifyThread(Thread):
219            """Verifier thread"""
220
221            __test__ = False
222
223            def __init__(self, user, token):
224                super().__init__()
225
226                self.user = user
227                self.token = token
228
229                self.verified = None
230
231            def run(self):
232                self.verified = match_token(self.user, self.token)
233                connection.close()
234
235        threads = [VerifyThread(self.alice, self.valid) for _ in range(10)]
236        for thread in threads:
237            thread.start()
238        for thread in threads:
239            thread.join()
240
241        self.assertEqual(sum(1 for t in threads if t.verified is not None), 1)

Test match_token in a thread