authentik.stages.authenticator.tests

Base authenticator tests

  1"""Base authenticator tests"""
  2
  3from datetime import timedelta
  4from threading import Thread
  5
  6from django.contrib.auth.models import AnonymousUser
  7from django.db import connection
  8from django.test import TestCase, TransactionTestCase
  9from django.test.utils import override_settings
 10from django.utils import timezone
 11from freezegun import freeze_time
 12
 13from authentik.core.tests.utils import create_test_admin_user
 14from authentik.lib.generators import generate_id
 15from authentik.stages.authenticator import match_token, user_has_device, verify_token
 16from authentik.stages.authenticator.models import Device, VerifyNotAllowed
 17
 18
 19class TestThread(Thread):
 20    "Django testing quirk: threads have to close their DB connections."
 21
 22    __test__ = False
 23
 24    def run(self):
 25        super().run()
 26        connection.close()
 27
 28
 29class ThrottlingTestMixin:
 30    """
 31    Generic tests for throttled devices.
 32
 33    Any concrete device implementation that uses throttling should define a
 34    TestCase subclass that includes this as a base class. This will help verify
 35    a correct integration of ThrottlingMixin.
 36
 37    Subclasses are responsible for populating self.device with a device to test
 38    as well as implementing methods to generate tokens to test with.
 39
 40    """
 41
 42    device: Device
 43
 44    def valid_token(self):
 45        """Returns a valid token to pass to our device under test."""
 46        raise NotImplementedError()
 47
 48    def invalid_token(self):
 49        """Returns an invalid token to pass to our device under test."""
 50        raise NotImplementedError()
 51
 52    #
 53    # Tests
 54    #
 55
 56    def test_delay_imposed_after_fail(self):
 57        """Test delay imposed after fail"""
 58        verified1 = self.device.verify_token(self.invalid_token())
 59        self.assertFalse(verified1)
 60        verified2 = self.device.verify_token(self.valid_token())
 61        self.assertFalse(verified2)
 62
 63    def test_delay_after_fail_expires(self):
 64        """Test delay after fail expires"""
 65        verified1 = self.device.verify_token(self.invalid_token())
 66        self.assertFalse(verified1)
 67        with freeze_time() as frozen_time:
 68            # With default settings initial delay is 1 second
 69            frozen_time.tick(delta=timedelta(seconds=1.1))
 70            verified2 = self.device.verify_token(self.valid_token())
 71            self.assertTrue(verified2)
 72
 73    def test_throttling_failure_count(self):
 74        """Test throttling failure count"""
 75        self.assertEqual(self.device.throttling_failure_count, 0)
 76        for _ in range(0, 5):
 77            self.device.verify_token(self.invalid_token())
 78            # Only the first attempt will increase throttling_failure_count,
 79            # the others will all be within 1 second of first
 80            # and therefore not count as attempts.
 81            self.assertEqual(self.device.throttling_failure_count, 1)
 82
 83    def test_verify_is_allowed(self):
 84        """Test verify allowed"""
 85        # Initially should be allowed
 86        verify_is_allowed1, data1 = self.device.verify_is_allowed()
 87        self.assertEqual(verify_is_allowed1, True)
 88        self.assertEqual(data1, None)
 89
 90        # After failure, verify is not allowed
 91        with freeze_time():
 92            self.device.verify_token(self.invalid_token())
 93            verify_is_allowed2, data2 = self.device.verify_is_allowed()
 94            self.assertEqual(verify_is_allowed2, False)
 95            self.assertEqual(
 96                data2,
 97                {
 98                    "reason": VerifyNotAllowed.N_FAILED_ATTEMPTS,
 99                    "failure_count": 1,
100                    "locked_until": timezone.now() + timezone.timedelta(seconds=1),
101                },
102            )
103
104        # After a successful attempt, should be allowed again
105        with freeze_time() as frozen_time:
106            frozen_time.tick(delta=timedelta(seconds=1.1))
107            self.device.verify_token(self.valid_token())
108
109            verify_is_allowed3, data3 = self.device.verify_is_allowed()
110            self.assertEqual(verify_is_allowed3, True)
111            self.assertEqual(data3, None)
112
113
@override_settings(OTP_STATIC_THROTTLE_FACTOR=0)
class APITestCase(TestCase):
    """Test API"""

    def setUp(self):
        # alice owns one static device with a single random token; bob owns none.
        self.alice = create_test_admin_user("alice")
        self.bob = create_test_admin_user("bob")
        alice_device = self.alice.staticdevice_set.create()
        self.valid = generate_id(length=16)
        alice_device.token_set.create(token=self.valid)

    def test_user_has_device(self):
        """Test user_has_device for an anonymous user, alice and bob."""
        cases = [
            ("anonymous", AnonymousUser(), self.assertFalse),
            ("alice", self.alice, self.assertTrue),
            ("bob", self.bob, self.assertFalse),
        ]
        for label, candidate, check in cases:
            with self.subTest(user=label):
                check(user_has_device(candidate))

    def test_verify_token(self):
        """Test verify_token with a wrong and a correct token."""
        persistent_id = self.alice.staticdevice_set.first().persistent_id

        self.assertIsNone(verify_token(self.alice, persistent_id, "bogus"))
        self.assertIsNotNone(verify_token(self.alice, persistent_id, self.valid))

    def test_match_token(self):
        """Test match_token with a wrong and a correct token."""
        self.assertIsNone(match_token(self.alice, "bogus"))

        matched_device = match_token(self.alice, self.valid)
        self.assertEqual(matched_device, self.alice.staticdevice_set.first())
151
152
@override_settings(OTP_STATIC_THROTTLE_FACTOR=0)
class ConcurrencyTestCase(TransactionTestCase):
    """Test concurrent verifications

    Several threads race to use the same static token; exactly one of them
    is expected to succeed.
    """

    def setUp(self):
        self.alice = create_test_admin_user("alice")
        self.bob = create_test_admin_user("bob")
        self.valid = generate_id(length=16)
        # Both users get a static device holding the same token value.
        for user in [self.alice, self.bob]:
            device = user.staticdevice_set.create()
            device.token_set.create(token=self.valid)

    def _count_concurrent_successes(self, verify, thread_count=10):
        """Call ``verify`` from ``thread_count`` threads at once.

        Returns the number of calls that returned something other than None.
        """
        results = [None] * thread_count

        def worker(index):
            try:
                results[index] = verify()
            finally:
                # Django testing quirk: each thread has to close its own DB
                # connection; do it even if ``verify`` raised.
                connection.close()

        threads = [Thread(target=worker, args=(index,)) for index in range(thread_count)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        return sum(1 for result in results if result is not None)

    def test_verify_token(self):
        """Test verify_token in a thread"""
        device = self.alice.staticdevice_set.get()
        # Resolve user and device id in the main thread, as before, so the
        # worker threads only perform the verification itself.
        user, device_id = device.user, device.persistent_id
        successes = self._count_concurrent_successes(
            lambda: verify_token(user, device_id, self.valid)
        )
        self.assertEqual(successes, 1)

    def test_match_token(self):
        """Test match_token in a thread"""
        successes = self._count_concurrent_successes(lambda: match_token(self.alice, self.valid))
        self.assertEqual(successes, 1)
class TestThread(threading.Thread):
class TestThread(Thread):
    """Thread that closes its Django DB connection when it finishes.

    Django testing quirk: threads have to close their DB connections.
    """

    # The class name starts with "Test"; opt out of test collection.
    __test__ = False

    def run(self):
        """Run the thread's target, then close this thread's DB connection.

        The close happens in ``finally`` so the connection is released even
        when the target raises.
        """
        try:
            super().run()
        finally:
            connection.close()

Django testing quirk: threads have to close their DB connections.

def run(self):
25    def run(self):
26        super().run()
27        connection.close()

Method representing the thread's activity.

You may override this method in a subclass. The standard run() method invokes the callable object passed to the object's constructor as the target argument, if any, with positional and keyword arguments taken from the args and kwargs arguments, respectively.

class ThrottlingTestMixin:
 30class ThrottlingTestMixin:
 31    """
 32    Generic tests for throttled devices.
 33
 34    Any concrete device implementation that uses throttling should define a
 35    TestCase subclass that includes this as a base class. This will help verify
 36    a correct integration of ThrottlingMixin.
 37
 38    Subclasses are responsible for populating self.device with a device to test
 39    as well as implementing methods to generate tokens to test with.
 40
 41    """
 42
 43    device: Device
 44
 45    def valid_token(self):
 46        """Returns a valid token to pass to our device under test."""
 47        raise NotImplementedError()
 48
 49    def invalid_token(self):
 50        """Returns an invalid token to pass to our device under test."""
 51        raise NotImplementedError()
 52
 53    #
 54    # Tests
 55    #
 56
 57    def test_delay_imposed_after_fail(self):
 58        """Test delay imposed after fail"""
 59        verified1 = self.device.verify_token(self.invalid_token())
 60        self.assertFalse(verified1)
 61        verified2 = self.device.verify_token(self.valid_token())
 62        self.assertFalse(verified2)
 63
 64    def test_delay_after_fail_expires(self):
 65        """Test delay after fail expires"""
 66        verified1 = self.device.verify_token(self.invalid_token())
 67        self.assertFalse(verified1)
 68        with freeze_time() as frozen_time:
 69            # With default settings initial delay is 1 second
 70            frozen_time.tick(delta=timedelta(seconds=1.1))
 71            verified2 = self.device.verify_token(self.valid_token())
 72            self.assertTrue(verified2)
 73
 74    def test_throttling_failure_count(self):
 75        """Test throttling failure count"""
 76        self.assertEqual(self.device.throttling_failure_count, 0)
 77        for _ in range(0, 5):
 78            self.device.verify_token(self.invalid_token())
 79            # Only the first attempt will increase throttling_failure_count,
 80            # the others will all be within 1 second of first
 81            # and therefore not count as attempts.
 82            self.assertEqual(self.device.throttling_failure_count, 1)
 83
 84    def test_verify_is_allowed(self):
 85        """Test verify allowed"""
 86        # Initially should be allowed
 87        verify_is_allowed1, data1 = self.device.verify_is_allowed()
 88        self.assertEqual(verify_is_allowed1, True)
 89        self.assertEqual(data1, None)
 90
 91        # After failure, verify is not allowed
 92        with freeze_time():
 93            self.device.verify_token(self.invalid_token())
 94            verify_is_allowed2, data2 = self.device.verify_is_allowed()
 95            self.assertEqual(verify_is_allowed2, False)
 96            self.assertEqual(
 97                data2,
 98                {
 99                    "reason": VerifyNotAllowed.N_FAILED_ATTEMPTS,
100                    "failure_count": 1,
101                    "locked_until": timezone.now() + timezone.timedelta(seconds=1),
102                },
103            )
104
105        # After a successful attempt, should be allowed again
106        with freeze_time() as frozen_time:
107            frozen_time.tick(delta=timedelta(seconds=1.1))
108            self.device.verify_token(self.valid_token())
109
110            verify_is_allowed3, data3 = self.device.verify_is_allowed()
111            self.assertEqual(verify_is_allowed3, True)
112            self.assertEqual(data3, None)

Generic tests for throttled devices.

Any concrete device implementation that uses throttling should define a TestCase subclass that includes this as a base class. This will help verify a correct integration of ThrottlingMixin.

Subclasses are responsible for populating self.device with a device to test as well as implementing methods to generate tokens to test with.

def valid_token(self):
45    def valid_token(self):
46        """Returns a valid token to pass to our device under test."""
47        raise NotImplementedError()

Returns a valid token to pass to our device under test.

def invalid_token(self):
49    def invalid_token(self):
50        """Returns an invalid token to pass to our device under test."""
51        raise NotImplementedError()

Returns an invalid token to pass to our device under test.

def test_delay_imposed_after_fail(self):
57    def test_delay_imposed_after_fail(self):
58        """Test delay imposed after fail"""
59        verified1 = self.device.verify_token(self.invalid_token())
60        self.assertFalse(verified1)
61        verified2 = self.device.verify_token(self.valid_token())
62        self.assertFalse(verified2)

Test that a delay is imposed after a failed attempt.

def test_delay_after_fail_expires(self):
64    def test_delay_after_fail_expires(self):
65        """Test delay after fail expires"""
66        verified1 = self.device.verify_token(self.invalid_token())
67        self.assertFalse(verified1)
68        with freeze_time() as frozen_time:
69            # With default settings initial delay is 1 second
70            frozen_time.tick(delta=timedelta(seconds=1.1))
71            verified2 = self.device.verify_token(self.valid_token())
72            self.assertTrue(verified2)

Test that the delay imposed after a failed attempt expires.

def test_throttling_failure_count(self):
74    def test_throttling_failure_count(self):
75        """Test throttling failure count"""
76        self.assertEqual(self.device.throttling_failure_count, 0)
77        for _ in range(0, 5):
78            self.device.verify_token(self.invalid_token())
79            # Only the first attempt will increase throttling_failure_count,
80            # the others will all be within 1 second of first
81            # and therefore not count as attempts.
82            self.assertEqual(self.device.throttling_failure_count, 1)

Test throttling failure count

def test_verify_is_allowed(self):
 84    def test_verify_is_allowed(self):
 85        """Test verify allowed"""
 86        # Initially should be allowed
 87        verify_is_allowed1, data1 = self.device.verify_is_allowed()
 88        self.assertEqual(verify_is_allowed1, True)
 89        self.assertEqual(data1, None)
 90
 91        # After failure, verify is not allowed
 92        with freeze_time():
 93            self.device.verify_token(self.invalid_token())
 94            verify_is_allowed2, data2 = self.device.verify_is_allowed()
 95            self.assertEqual(verify_is_allowed2, False)
 96            self.assertEqual(
 97                data2,
 98                {
 99                    "reason": VerifyNotAllowed.N_FAILED_ATTEMPTS,
100                    "failure_count": 1,
101                    "locked_until": timezone.now() + timezone.timedelta(seconds=1),
102                },
103            )
104
105        # After a successful attempt, should be allowed again
106        with freeze_time() as frozen_time:
107            frozen_time.tick(delta=timedelta(seconds=1.1))
108            self.device.verify_token(self.valid_token())
109
110            verify_is_allowed3, data3 = self.device.verify_is_allowed()
111            self.assertEqual(verify_is_allowed3, True)
112            self.assertEqual(data3, None)

Test verify_is_allowed before a failure, during the resulting lockout, and after a later successful attempt.

@override_settings(OTP_STATIC_THROTTLE_FACTOR=0)
class APITestCase(django.test.testcases.TestCase):
@override_settings(OTP_STATIC_THROTTLE_FACTOR=0)
class APITestCase(TestCase):
    """Test API"""

    def setUp(self):
        # alice owns one static device with a single random token; bob owns none.
        self.alice = create_test_admin_user("alice")
        self.bob = create_test_admin_user("bob")
        alice_device = self.alice.staticdevice_set.create()
        self.valid = generate_id(length=16)
        alice_device.token_set.create(token=self.valid)

    def test_user_has_device(self):
        """Test user_has_device for an anonymous user, alice and bob."""
        cases = [
            ("anonymous", AnonymousUser(), self.assertFalse),
            ("alice", self.alice, self.assertTrue),
            ("bob", self.bob, self.assertFalse),
        ]
        for label, candidate, check in cases:
            with self.subTest(user=label):
                check(user_has_device(candidate))

    def test_verify_token(self):
        """Test verify_token with a wrong and a correct token."""
        persistent_id = self.alice.staticdevice_set.first().persistent_id

        self.assertIsNone(verify_token(self.alice, persistent_id, "bogus"))
        self.assertIsNotNone(verify_token(self.alice, persistent_id, self.valid))

    def test_match_token(self):
        """Test match_token with a wrong and a correct token."""
        self.assertIsNone(match_token(self.alice, "bogus"))

        matched_device = match_token(self.alice, self.valid)
        self.assertEqual(matched_device, self.alice.staticdevice_set.first())

Test API

def setUp(self):
119    def setUp(self):
120        self.alice = create_test_admin_user("alice")
121        self.bob = create_test_admin_user("bob")
122        device = self.alice.staticdevice_set.create()
123        self.valid = generate_id(length=16)
124        device.token_set.create(token=self.valid)

Hook method for setting up the test fixture before exercising it.

def test_user_has_device(self):
126    def test_user_has_device(self):
127        """Test user_has_device"""
128        with self.subTest(user="anonymous"):
129            self.assertFalse(user_has_device(AnonymousUser()))
130        with self.subTest(user="alice"):
131            self.assertTrue(user_has_device(self.alice))
132        with self.subTest(user="bob"):
133            self.assertFalse(user_has_device(self.bob))

Test user_has_device

def test_verify_token(self):
135    def test_verify_token(self):
136        """Test verify_token"""
137        device = self.alice.staticdevice_set.first()
138
139        verified = verify_token(self.alice, device.persistent_id, "bogus")
140        self.assertIsNone(verified)
141
142        verified = verify_token(self.alice, device.persistent_id, self.valid)
143        self.assertIsNotNone(verified)

Test verify_token

def test_match_token(self):
145    def test_match_token(self):
146        """Test match_token"""
147        verified = match_token(self.alice, "bogus")
148        self.assertIsNone(verified)
149
150        verified = match_token(self.alice, self.valid)
151        self.assertEqual(verified, self.alice.staticdevice_set.first())

Test match_token

@override_settings(OTP_STATIC_THROTTLE_FACTOR=0)
class ConcurrencyTestCase(django.test.testcases.TransactionTestCase):
@override_settings(OTP_STATIC_THROTTLE_FACTOR=0)
class ConcurrencyTestCase(TransactionTestCase):
    """Test concurrent verifications

    Several threads race to use the same static token; exactly one of them
    is expected to succeed.
    """

    def setUp(self):
        self.alice = create_test_admin_user("alice")
        self.bob = create_test_admin_user("bob")
        self.valid = generate_id(length=16)
        # Both users get a static device holding the same token value.
        for user in [self.alice, self.bob]:
            device = user.staticdevice_set.create()
            device.token_set.create(token=self.valid)

    def _count_concurrent_successes(self, verify, thread_count=10):
        """Call ``verify`` from ``thread_count`` threads at once.

        Returns the number of calls that returned something other than None.
        """
        results = [None] * thread_count

        def worker(index):
            try:
                results[index] = verify()
            finally:
                # Django testing quirk: each thread has to close its own DB
                # connection; do it even if ``verify`` raised.
                connection.close()

        threads = [Thread(target=worker, args=(index,)) for index in range(thread_count)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        return sum(1 for result in results if result is not None)

    def test_verify_token(self):
        """Test verify_token in a thread"""
        device = self.alice.staticdevice_set.get()
        # Resolve user and device id in the main thread, as before, so the
        # worker threads only perform the verification itself.
        user, device_id = device.user, device.persistent_id
        successes = self._count_concurrent_successes(
            lambda: verify_token(user, device_id, self.valid)
        )
        self.assertEqual(successes, 1)

    def test_match_token(self):
        """Test match_token in a thread"""
        successes = self._count_concurrent_successes(lambda: match_token(self.alice, self.valid))
        self.assertEqual(successes, 1)

Test concurrent verifications

def setUp(self):
158    def setUp(self):
159        self.alice = create_test_admin_user("alice")
160        self.bob = create_test_admin_user("bob")
161        self.valid = generate_id(length=16)
162        for user in [self.alice, self.bob]:
163            device = user.staticdevice_set.create()
164            device.token_set.create(token=self.valid)

Hook method for setting up the test fixture before exercising it.

def test_verify_token(self):
166    def test_verify_token(self):
167        """Test verify_token in a thread"""
168
169        class VerifyThread(Thread):
170            """Verifier thread"""
171
172            __test__ = False
173
174            def __init__(self, user, device_id, token):
175                super().__init__()
176
177                self.user = user
178                self.device_id = device_id
179                self.token = token
180
181                self.verified = None
182
183            def run(self):
184                self.verified = verify_token(self.user, self.device_id, self.token)
185                connection.close()
186
187        device = self.alice.staticdevice_set.get()
188        threads = [VerifyThread(device.user, device.persistent_id, self.valid) for _ in range(10)]
189        for thread in threads:
190            thread.start()
191        for thread in threads:
192            thread.join()
193
194        self.assertEqual(sum(1 for t in threads if t.verified is not None), 1)

Test verify_token in a thread

def test_match_token(self):
196    def test_match_token(self):
197        """Test match_token in a thread"""
198
199        class VerifyThread(Thread):
200            """Verifier thread"""
201
202            __test__ = False
203
204            def __init__(self, user, token):
205                super().__init__()
206
207                self.user = user
208                self.token = token
209
210                self.verified = None
211
212            def run(self):
213                self.verified = match_token(self.user, self.token)
214                connection.close()
215
216        threads = [VerifyThread(self.alice, self.valid) for _ in range(10)]
217        for thread in threads:
218            thread.start()
219        for thread in threads:
220            thread.join()
221
222        self.assertEqual(sum(1 for t in threads if t.verified is not None), 1)

Test match_token in a thread