authentik.providers.oauth2.tests.test_device_backchannel

Device backchannel tests

  1"""Device backchannel tests"""
  2
  3from base64 import b64encode
  4from json import loads
  5from urllib.parse import quote
  6
  7from django.urls import reverse
  8
  9from authentik.core.models import Application
 10from authentik.core.tests.utils import create_test_flow
 11from authentik.lib.generators import generate_id
 12from authentik.providers.oauth2.models import OAuth2Provider
 13from authentik.providers.oauth2.tests.utils import OAuthTestCase
 14
 15
 16class TesOAuth2DeviceBackchannel(OAuthTestCase):
 17    """Test device back channel"""
 18
 19    def setUp(self) -> None:
 20        self.provider = OAuth2Provider.objects.create(
 21            name=generate_id(),
 22            client_id="test",
 23            authorization_flow=create_test_flow(),
 24        )
 25        self.application = Application.objects.create(
 26            name=generate_id(),
 27            slug=generate_id(),
 28            provider=self.provider,
 29        )
 30
 31    def test_backchannel_invalid_client_id_via_post_body(self):
 32        """Test backchannel"""
 33        res = self.client.post(
 34            reverse("authentik_providers_oauth2:device"),
 35            data={
 36                "client_id": "foo",
 37            },
 38        )
 39        self.assertEqual(res.status_code, 400)
 40        res = self.client.post(
 41            reverse("authentik_providers_oauth2:device"),
 42        )
 43        self.assertEqual(res.status_code, 400)
 44        # test without application
 45        self.application.provider = None
 46        self.application.save()
 47        res = self.client.post(
 48            reverse("authentik_providers_oauth2:device"),
 49            data={
 50                "client_id": "test",
 51            },
 52        )
 53        self.assertEqual(res.status_code, 400)
 54
 55    def test_backchannel_client_id_via_post_body(self):
 56        """Test backchannel"""
 57        res = self.client.post(
 58            reverse("authentik_providers_oauth2:device"),
 59            data={
 60                "client_id": self.provider.client_id,
 61            },
 62        )
 63        self.assertEqual(res.status_code, 200)
 64        body = loads(res.content.decode())
 65        self.assertEqual(body["expires_in"], 60)
 66
 67    def test_backchannel_invalid_client_id_via_auth_header(self):
 68        """Test backchannel"""
 69        creds = b64encode(b"foo:").decode()
 70        res = self.client.post(
 71            reverse("authentik_providers_oauth2:device"),
 72            HTTP_AUTHORIZATION=f"Basic {creds}",
 73        )
 74        self.assertEqual(res.status_code, 400)
 75        res = self.client.post(
 76            reverse("authentik_providers_oauth2:device"),
 77        )
 78        self.assertEqual(res.status_code, 400)
 79        # test without application
 80        self.application.provider = None
 81        self.application.save()
 82        res = self.client.post(
 83            reverse("authentik_providers_oauth2:device"),
 84            data={
 85                "client_id": "test",
 86            },
 87        )
 88        self.assertEqual(res.status_code, 400)
 89
 90    def test_backchannel_client_id_via_auth_header(self):
 91        """Test backchannel"""
 92        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
 93        res = self.client.post(
 94            reverse("authentik_providers_oauth2:device"),
 95            HTTP_AUTHORIZATION=f"Basic {creds}",
 96        )
 97        self.assertEqual(res.status_code, 200)
 98        body = loads(res.content.decode())
 99        self.assertEqual(body["expires_in"], 60)
100
101    def test_backchannel_client_id_via_auth_header_urlencoded(self):
102        """Test URL-encoded client IDs in Basic auth"""
103        self.provider.client_id = "test/client+id"
104        self.provider.save()
105        creds = b64encode(f"{quote(self.provider.client_id, safe='')}:".encode()).decode()
106        res = self.client.post(
107            reverse("authentik_providers_oauth2:device"),
108            HTTP_AUTHORIZATION=f"Basic {creds}",
109        )
110        self.assertEqual(res.status_code, 200)
111        body = loads(res.content.decode())
112        self.assertEqual(body["expires_in"], 60)
class TesOAuth2DeviceBackchannel(authentik.providers.oauth2.tests.utils.OAuthTestCase):
 17class TesOAuth2DeviceBackchannel(OAuthTestCase):
 18    """Test device back channel"""
 19
 20    def setUp(self) -> None:
 21        self.provider = OAuth2Provider.objects.create(
 22            name=generate_id(),
 23            client_id="test",
 24            authorization_flow=create_test_flow(),
 25        )
 26        self.application = Application.objects.create(
 27            name=generate_id(),
 28            slug=generate_id(),
 29            provider=self.provider,
 30        )
 31
 32    def test_backchannel_invalid_client_id_via_post_body(self):
 33        """Test backchannel"""
 34        res = self.client.post(
 35            reverse("authentik_providers_oauth2:device"),
 36            data={
 37                "client_id": "foo",
 38            },
 39        )
 40        self.assertEqual(res.status_code, 400)
 41        res = self.client.post(
 42            reverse("authentik_providers_oauth2:device"),
 43        )
 44        self.assertEqual(res.status_code, 400)
 45        # test without application
 46        self.application.provider = None
 47        self.application.save()
 48        res = self.client.post(
 49            reverse("authentik_providers_oauth2:device"),
 50            data={
 51                "client_id": "test",
 52            },
 53        )
 54        self.assertEqual(res.status_code, 400)
 55
 56    def test_backchannel_client_id_via_post_body(self):
 57        """Test backchannel"""
 58        res = self.client.post(
 59            reverse("authentik_providers_oauth2:device"),
 60            data={
 61                "client_id": self.provider.client_id,
 62            },
 63        )
 64        self.assertEqual(res.status_code, 200)
 65        body = loads(res.content.decode())
 66        self.assertEqual(body["expires_in"], 60)
 67
 68    def test_backchannel_invalid_client_id_via_auth_header(self):
 69        """Test backchannel"""
 70        creds = b64encode(b"foo:").decode()
 71        res = self.client.post(
 72            reverse("authentik_providers_oauth2:device"),
 73            HTTP_AUTHORIZATION=f"Basic {creds}",
 74        )
 75        self.assertEqual(res.status_code, 400)
 76        res = self.client.post(
 77            reverse("authentik_providers_oauth2:device"),
 78        )
 79        self.assertEqual(res.status_code, 400)
 80        # test without application
 81        self.application.provider = None
 82        self.application.save()
 83        res = self.client.post(
 84            reverse("authentik_providers_oauth2:device"),
 85            data={
 86                "client_id": "test",
 87            },
 88        )
 89        self.assertEqual(res.status_code, 400)
 90
 91    def test_backchannel_client_id_via_auth_header(self):
 92        """Test backchannel"""
 93        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
 94        res = self.client.post(
 95            reverse("authentik_providers_oauth2:device"),
 96            HTTP_AUTHORIZATION=f"Basic {creds}",
 97        )
 98        self.assertEqual(res.status_code, 200)
 99        body = loads(res.content.decode())
100        self.assertEqual(body["expires_in"], 60)
101
102    def test_backchannel_client_id_via_auth_header_urlencoded(self):
103        """Test URL-encoded client IDs in Basic auth"""
104        self.provider.client_id = "test/client+id"
105        self.provider.save()
106        creds = b64encode(f"{quote(self.provider.client_id, safe='')}:".encode()).decode()
107        res = self.client.post(
108            reverse("authentik_providers_oauth2:device"),
109            HTTP_AUTHORIZATION=f"Basic {creds}",
110        )
111        self.assertEqual(res.status_code, 200)
112        body = loads(res.content.decode())
113        self.assertEqual(body["expires_in"], 60)

Tests for the OAuth2 provider's device backchannel endpoint.

def setUp(self) -> None:
20    def setUp(self) -> None:
21        self.provider = OAuth2Provider.objects.create(
22            name=generate_id(),
23            client_id="test",
24            authorization_flow=create_test_flow(),
25        )
26        self.application = Application.objects.create(
27            name=generate_id(),
28            slug=generate_id(),
29            provider=self.provider,
30        )

Hook method for setting up the test fixture before exercising it.

def test_backchannel_invalid_client_id_via_post_body(self):
32    def test_backchannel_invalid_client_id_via_post_body(self):
33        """Test backchannel"""
34        res = self.client.post(
35            reverse("authentik_providers_oauth2:device"),
36            data={
37                "client_id": "foo",
38            },
39        )
40        self.assertEqual(res.status_code, 400)
41        res = self.client.post(
42            reverse("authentik_providers_oauth2:device"),
43        )
44        self.assertEqual(res.status_code, 400)
45        # test without application
46        self.application.provider = None
47        self.application.save()
48        res = self.client.post(
49            reverse("authentik_providers_oauth2:device"),
50            data={
51                "client_id": "test",
52            },
53        )
54        self.assertEqual(res.status_code, 400)

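Test that the device endpoint returns 400 when the POST body carries an unknown client ID, no client ID, or the client ID of a provider without an application.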

def test_backchannel_client_id_via_post_body(self):
56    def test_backchannel_client_id_via_post_body(self):
57        """Test backchannel"""
58        res = self.client.post(
59            reverse("authentik_providers_oauth2:device"),
60            data={
61                "client_id": self.provider.client_id,
62            },
63        )
64        self.assertEqual(res.status_code, 200)
65        body = loads(res.content.decode())
66        self.assertEqual(body["expires_in"], 60)

Test that a valid client ID sent in the POST body returns 200 with an `expires_in` of 60.

def test_backchannel_invalid_client_id_via_auth_header(self):
68    def test_backchannel_invalid_client_id_via_auth_header(self):
69        """Test backchannel"""
70        creds = b64encode(b"foo:").decode()
71        res = self.client.post(
72            reverse("authentik_providers_oauth2:device"),
73            HTTP_AUTHORIZATION=f"Basic {creds}",
74        )
75        self.assertEqual(res.status_code, 400)
76        res = self.client.post(
77            reverse("authentik_providers_oauth2:device"),
78        )
79        self.assertEqual(res.status_code, 400)
80        # test without application
81        self.application.provider = None
82        self.application.save()
83        res = self.client.post(
84            reverse("authentik_providers_oauth2:device"),
85            data={
86                "client_id": "test",
87            },
88        )
89        self.assertEqual(res.status_code, 400)

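Test that the device endpoint returns 400 for an unknown client ID in the Basic auth header, for a request without credentials, and for a provider without an application.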

def test_backchannel_client_id_via_auth_header(self):
 91    def test_backchannel_client_id_via_auth_header(self):
 92        """Test backchannel"""
 93        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
 94        res = self.client.post(
 95            reverse("authentik_providers_oauth2:device"),
 96            HTTP_AUTHORIZATION=f"Basic {creds}",
 97        )
 98        self.assertEqual(res.status_code, 200)
 99        body = loads(res.content.decode())
100        self.assertEqual(body["expires_in"], 60)

Test that a valid client ID sent in the Basic auth header returns 200 with an `expires_in` of 60.

def test_backchannel_client_id_via_auth_header_urlencoded(self):
102    def test_backchannel_client_id_via_auth_header_urlencoded(self):
103        """Test URL-encoded client IDs in Basic auth"""
104        self.provider.client_id = "test/client+id"
105        self.provider.save()
106        creds = b64encode(f"{quote(self.provider.client_id, safe='')}:".encode()).decode()
107        res = self.client.post(
108            reverse("authentik_providers_oauth2:device"),
109            HTTP_AUTHORIZATION=f"Basic {creds}",
110        )
111        self.assertEqual(res.status_code, 200)
112        body = loads(res.content.decode())
113        self.assertEqual(body["expires_in"], 60)

Test URL-encoded client IDs in Basic auth