authentik.providers.oauth2.tests.test_device_backchannel

Device backchannel tests

  1"""Device backchannel tests"""
  2
  3from base64 import b64encode
  4from json import loads
  5from urllib.parse import quote
  6
  7from django.urls import reverse
  8
  9from authentik.blueprints.tests import apply_blueprint
 10from authentik.core.models import Application
 11from authentik.core.tests.utils import create_test_flow
 12from authentik.lib.generators import generate_id
 13from authentik.providers.oauth2.models import DeviceToken, GrantType, OAuth2Provider, ScopeMapping
 14from authentik.providers.oauth2.tests.utils import OAuthTestCase
 15
 16
 17class TesOAuth2DeviceBackchannel(OAuthTestCase):
 18    """Test device back channel"""
 19
 20    def setUp(self) -> None:
 21        self.provider = OAuth2Provider.objects.create(
 22            name=generate_id(),
 23            client_id="test",
 24            authorization_flow=create_test_flow(),
 25            grant_types=[GrantType.DEVICE_CODE],
 26        )
 27        self.application = Application.objects.create(
 28            name=generate_id(),
 29            slug=generate_id(),
 30            provider=self.provider,
 31        )
 32
 33    def test_backchannel_invalid_client_id_via_post_body(self):
 34        """Test backchannel"""
 35        res = self.client.post(
 36            reverse("authentik_providers_oauth2:device"),
 37            data={
 38                "client_id": "foo",
 39            },
 40        )
 41        self.assertEqual(res.status_code, 400)
 42        res = self.client.post(
 43            reverse("authentik_providers_oauth2:device"),
 44        )
 45        self.assertEqual(res.status_code, 400)
 46
 47    def test_backchannel_invalid_no_grant(self):
 48        """Test backchannel"""
 49        self.provider.grant_types = []
 50        self.provider.save()
 51        res = self.client.post(
 52            reverse("authentik_providers_oauth2:device"),
 53            data={
 54                "client_id": "test",
 55            },
 56        )
 57        self.assertEqual(res.status_code, 400)
 58
 59    def test_backchannel_invalid_no_app(self):
 60        """Test backchannel"""
 61        # test without application
 62        self.application.provider = None
 63        self.application.save()
 64        res = self.client.post(
 65            reverse("authentik_providers_oauth2:device"),
 66            data={
 67                "client_id": "test",
 68            },
 69        )
 70        self.assertEqual(res.status_code, 400)
 71
 72    def test_backchannel_client_id_via_post_body(self):
 73        """Test backchannel"""
 74        res = self.client.post(
 75            reverse("authentik_providers_oauth2:device"),
 76            data={
 77                "client_id": self.provider.client_id,
 78            },
 79        )
 80        self.assertEqual(res.status_code, 200)
 81        body = loads(res.content.decode())
 82        self.assertEqual(body["expires_in"], 60)
 83
 84    def test_backchannel_invalid_client_id_via_auth_header(self):
 85        """Test backchannel"""
 86        creds = b64encode(b"foo:").decode()
 87        res = self.client.post(
 88            reverse("authentik_providers_oauth2:device"),
 89            HTTP_AUTHORIZATION=f"Basic {creds}",
 90        )
 91        self.assertEqual(res.status_code, 400)
 92        res = self.client.post(
 93            reverse("authentik_providers_oauth2:device"),
 94        )
 95        self.assertEqual(res.status_code, 400)
 96        # test without application
 97        self.application.provider = None
 98        self.application.save()
 99        res = self.client.post(
100            reverse("authentik_providers_oauth2:device"),
101            data={
102                "client_id": "test",
103            },
104        )
105        self.assertEqual(res.status_code, 400)
106
107    def test_backchannel_client_id_via_auth_header(self):
108        """Test backchannel"""
109        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
110        res = self.client.post(
111            reverse("authentik_providers_oauth2:device"),
112            HTTP_AUTHORIZATION=f"Basic {creds}",
113        )
114        self.assertEqual(res.status_code, 200)
115        body = loads(res.content.decode())
116        self.assertEqual(body["expires_in"], 60)
117
118    def test_backchannel_client_id_via_auth_header_urlencoded(self):
119        """Test URL-encoded client IDs in Basic auth"""
120        self.provider.client_id = "test/client+id"
121        self.provider.save()
122        creds = b64encode(f"{quote(self.provider.client_id, safe='')}:".encode()).decode()
123        res = self.client.post(
124            reverse("authentik_providers_oauth2:device"),
125            HTTP_AUTHORIZATION=f"Basic {creds}",
126        )
127        self.assertEqual(res.status_code, 200)
128        body = loads(res.content.decode())
129        self.assertEqual(body["expires_in"], 60)
130
131    @apply_blueprint("system/providers-oauth2.yaml")
132    def test_backchannel_scopes(self):
133        """Test backchannel"""
134        self.provider.property_mappings.set(
135            ScopeMapping.objects.filter(
136                managed__in=[
137                    "goauthentik.io/providers/oauth2/scope-openid",
138                    "goauthentik.io/providers/oauth2/scope-email",
139                    "goauthentik.io/providers/oauth2/scope-profile",
140                ]
141            )
142        )
143        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
144        res = self.client.post(
145            reverse("authentik_providers_oauth2:device"),
146            HTTP_AUTHORIZATION=f"Basic {creds}",
147            data={"scope": "openid email"},
148        )
149        self.assertEqual(res.status_code, 200)
150        body = loads(res.content.decode())
151        self.assertEqual(body["expires_in"], 60)
152        token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
153        self.assertIsNotNone(token)
154        self.assertEqual(len(token.scope), 2)
155        self.assertIn("openid", token.scope)
156        self.assertIn("email", token.scope)
157
158    @apply_blueprint("system/providers-oauth2.yaml")
159    def test_backchannel_scopes_extra(self):
160        """Test backchannel"""
161        self.provider.property_mappings.set(
162            ScopeMapping.objects.filter(
163                managed__in=[
164                    "goauthentik.io/providers/oauth2/scope-openid",
165                    "goauthentik.io/providers/oauth2/scope-email",
166                    "goauthentik.io/providers/oauth2/scope-profile",
167                ]
168            )
169        )
170        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
171        res = self.client.post(
172            reverse("authentik_providers_oauth2:device"),
173            HTTP_AUTHORIZATION=f"Basic {creds}",
174            data={"scope": "openid email foo"},
175        )
176        self.assertEqual(res.status_code, 200)
177        body = loads(res.content.decode())
178        self.assertEqual(body["expires_in"], 60)
179        token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
180        self.assertIsNotNone(token)
181        self.assertEqual(len(token.scope), 2)
182        self.assertIn("openid", token.scope)
183        self.assertIn("email", token.scope)
class TesOAuth2DeviceBackchannel(authentik.providers.oauth2.tests.utils.OAuthTestCase):
 18class TesOAuth2DeviceBackchannel(OAuthTestCase):
 19    """Test device back channel"""
 20
 21    def setUp(self) -> None:
 22        self.provider = OAuth2Provider.objects.create(
 23            name=generate_id(),
 24            client_id="test",
 25            authorization_flow=create_test_flow(),
 26            grant_types=[GrantType.DEVICE_CODE],
 27        )
 28        self.application = Application.objects.create(
 29            name=generate_id(),
 30            slug=generate_id(),
 31            provider=self.provider,
 32        )
 33
 34    def test_backchannel_invalid_client_id_via_post_body(self):
 35        """Test backchannel"""
 36        res = self.client.post(
 37            reverse("authentik_providers_oauth2:device"),
 38            data={
 39                "client_id": "foo",
 40            },
 41        )
 42        self.assertEqual(res.status_code, 400)
 43        res = self.client.post(
 44            reverse("authentik_providers_oauth2:device"),
 45        )
 46        self.assertEqual(res.status_code, 400)
 47
 48    def test_backchannel_invalid_no_grant(self):
 49        """Test backchannel"""
 50        self.provider.grant_types = []
 51        self.provider.save()
 52        res = self.client.post(
 53            reverse("authentik_providers_oauth2:device"),
 54            data={
 55                "client_id": "test",
 56            },
 57        )
 58        self.assertEqual(res.status_code, 400)
 59
 60    def test_backchannel_invalid_no_app(self):
 61        """Test backchannel"""
 62        # test without application
 63        self.application.provider = None
 64        self.application.save()
 65        res = self.client.post(
 66            reverse("authentik_providers_oauth2:device"),
 67            data={
 68                "client_id": "test",
 69            },
 70        )
 71        self.assertEqual(res.status_code, 400)
 72
 73    def test_backchannel_client_id_via_post_body(self):
 74        """Test backchannel"""
 75        res = self.client.post(
 76            reverse("authentik_providers_oauth2:device"),
 77            data={
 78                "client_id": self.provider.client_id,
 79            },
 80        )
 81        self.assertEqual(res.status_code, 200)
 82        body = loads(res.content.decode())
 83        self.assertEqual(body["expires_in"], 60)
 84
 85    def test_backchannel_invalid_client_id_via_auth_header(self):
 86        """Test backchannel"""
 87        creds = b64encode(b"foo:").decode()
 88        res = self.client.post(
 89            reverse("authentik_providers_oauth2:device"),
 90            HTTP_AUTHORIZATION=f"Basic {creds}",
 91        )
 92        self.assertEqual(res.status_code, 400)
 93        res = self.client.post(
 94            reverse("authentik_providers_oauth2:device"),
 95        )
 96        self.assertEqual(res.status_code, 400)
 97        # test without application
 98        self.application.provider = None
 99        self.application.save()
100        res = self.client.post(
101            reverse("authentik_providers_oauth2:device"),
102            data={
103                "client_id": "test",
104            },
105        )
106        self.assertEqual(res.status_code, 400)
107
108    def test_backchannel_client_id_via_auth_header(self):
109        """Test backchannel"""
110        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
111        res = self.client.post(
112            reverse("authentik_providers_oauth2:device"),
113            HTTP_AUTHORIZATION=f"Basic {creds}",
114        )
115        self.assertEqual(res.status_code, 200)
116        body = loads(res.content.decode())
117        self.assertEqual(body["expires_in"], 60)
118
119    def test_backchannel_client_id_via_auth_header_urlencoded(self):
120        """Test URL-encoded client IDs in Basic auth"""
121        self.provider.client_id = "test/client+id"
122        self.provider.save()
123        creds = b64encode(f"{quote(self.provider.client_id, safe='')}:".encode()).decode()
124        res = self.client.post(
125            reverse("authentik_providers_oauth2:device"),
126            HTTP_AUTHORIZATION=f"Basic {creds}",
127        )
128        self.assertEqual(res.status_code, 200)
129        body = loads(res.content.decode())
130        self.assertEqual(body["expires_in"], 60)
131
132    @apply_blueprint("system/providers-oauth2.yaml")
133    def test_backchannel_scopes(self):
134        """Test backchannel"""
135        self.provider.property_mappings.set(
136            ScopeMapping.objects.filter(
137                managed__in=[
138                    "goauthentik.io/providers/oauth2/scope-openid",
139                    "goauthentik.io/providers/oauth2/scope-email",
140                    "goauthentik.io/providers/oauth2/scope-profile",
141                ]
142            )
143        )
144        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
145        res = self.client.post(
146            reverse("authentik_providers_oauth2:device"),
147            HTTP_AUTHORIZATION=f"Basic {creds}",
148            data={"scope": "openid email"},
149        )
150        self.assertEqual(res.status_code, 200)
151        body = loads(res.content.decode())
152        self.assertEqual(body["expires_in"], 60)
153        token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
154        self.assertIsNotNone(token)
155        self.assertEqual(len(token.scope), 2)
156        self.assertIn("openid", token.scope)
157        self.assertIn("email", token.scope)
158
159    @apply_blueprint("system/providers-oauth2.yaml")
160    def test_backchannel_scopes_extra(self):
161        """Test backchannel"""
162        self.provider.property_mappings.set(
163            ScopeMapping.objects.filter(
164                managed__in=[
165                    "goauthentik.io/providers/oauth2/scope-openid",
166                    "goauthentik.io/providers/oauth2/scope-email",
167                    "goauthentik.io/providers/oauth2/scope-profile",
168                ]
169            )
170        )
171        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
172        res = self.client.post(
173            reverse("authentik_providers_oauth2:device"),
174            HTTP_AUTHORIZATION=f"Basic {creds}",
175            data={"scope": "openid email foo"},
176        )
177        self.assertEqual(res.status_code, 200)
178        body = loads(res.content.decode())
179        self.assertEqual(body["expires_in"], 60)
180        token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
181        self.assertIsNotNone(token)
182        self.assertEqual(len(token.scope), 2)
183        self.assertIn("openid", token.scope)
184        self.assertIn("email", token.scope)

Test device back channel

def setUp(self) -> None:
21    def setUp(self) -> None:
22        self.provider = OAuth2Provider.objects.create(
23            name=generate_id(),
24            client_id="test",
25            authorization_flow=create_test_flow(),
26            grant_types=[GrantType.DEVICE_CODE],
27        )
28        self.application = Application.objects.create(
29            name=generate_id(),
30            slug=generate_id(),
31            provider=self.provider,
32        )

Hook method for setting up the test fixture before exercising it.

def test_backchannel_invalid_client_id_via_post_body(self):
34    def test_backchannel_invalid_client_id_via_post_body(self):
35        """Test backchannel"""
36        res = self.client.post(
37            reverse("authentik_providers_oauth2:device"),
38            data={
39                "client_id": "foo",
40            },
41        )
42        self.assertEqual(res.status_code, 400)
43        res = self.client.post(
44            reverse("authentik_providers_oauth2:device"),
45        )
46        self.assertEqual(res.status_code, 400)

Test that an unknown or missing client_id in the POST body is rejected with HTTP 400

def test_backchannel_invalid_no_grant(self):
48    def test_backchannel_invalid_no_grant(self):
49        """Test backchannel"""
50        self.provider.grant_types = []
51        self.provider.save()
52        res = self.client.post(
53            reverse("authentik_providers_oauth2:device"),
54            data={
55                "client_id": "test",
56            },
57        )
58        self.assertEqual(res.status_code, 400)

Test that a request is rejected with HTTP 400 when the provider has no grant types

def test_backchannel_invalid_no_app(self):
60    def test_backchannel_invalid_no_app(self):
61        """Test backchannel"""
62        # test without application
63        self.application.provider = None
64        self.application.save()
65        res = self.client.post(
66            reverse("authentik_providers_oauth2:device"),
67            data={
68                "client_id": "test",
69            },
70        )
71        self.assertEqual(res.status_code, 400)

Test that a request is rejected with HTTP 400 when the provider has no application

def test_backchannel_client_id_via_post_body(self):
73    def test_backchannel_client_id_via_post_body(self):
74        """Test backchannel"""
75        res = self.client.post(
76            reverse("authentik_providers_oauth2:device"),
77            data={
78                "client_id": self.provider.client_id,
79            },
80        )
81        self.assertEqual(res.status_code, 200)
82        body = loads(res.content.decode())
83        self.assertEqual(body["expires_in"], 60)

Test that a valid client_id in the POST body yields HTTP 200 and an expires_in of 60

def test_backchannel_invalid_client_id_via_auth_header(self):
 85    def test_backchannel_invalid_client_id_via_auth_header(self):
 86        """Test backchannel"""
 87        creds = b64encode(b"foo:").decode()
 88        res = self.client.post(
 89            reverse("authentik_providers_oauth2:device"),
 90            HTTP_AUTHORIZATION=f"Basic {creds}",
 91        )
 92        self.assertEqual(res.status_code, 400)
 93        res = self.client.post(
 94            reverse("authentik_providers_oauth2:device"),
 95        )
 96        self.assertEqual(res.status_code, 400)
 97        # test without application
 98        self.application.provider = None
 99        self.application.save()
100        res = self.client.post(
101            reverse("authentik_providers_oauth2:device"),
102            data={
103                "client_id": "test",
104            },
105        )
106        self.assertEqual(res.status_code, 400)

Test that invalid or missing client credentials in the Authorization header are rejected with HTTP 400

def test_backchannel_client_id_via_auth_header(self):
108    def test_backchannel_client_id_via_auth_header(self):
109        """Test backchannel"""
110        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
111        res = self.client.post(
112            reverse("authentik_providers_oauth2:device"),
113            HTTP_AUTHORIZATION=f"Basic {creds}",
114        )
115        self.assertEqual(res.status_code, 200)
116        body = loads(res.content.decode())
117        self.assertEqual(body["expires_in"], 60)

Test that a valid client_id in the Authorization header yields HTTP 200 and an expires_in of 60

def test_backchannel_client_id_via_auth_header_urlencoded(self):
119    def test_backchannel_client_id_via_auth_header_urlencoded(self):
120        """Test URL-encoded client IDs in Basic auth"""
121        self.provider.client_id = "test/client+id"
122        self.provider.save()
123        creds = b64encode(f"{quote(self.provider.client_id, safe='')}:".encode()).decode()
124        res = self.client.post(
125            reverse("authentik_providers_oauth2:device"),
126            HTTP_AUTHORIZATION=f"Basic {creds}",
127        )
128        self.assertEqual(res.status_code, 200)
129        body = loads(res.content.decode())
130        self.assertEqual(body["expires_in"], 60)

Test URL-encoded client IDs in Basic auth

@apply_blueprint('system/providers-oauth2.yaml')
def test_backchannel_scopes(self):
132    @apply_blueprint("system/providers-oauth2.yaml")
133    def test_backchannel_scopes(self):
134        """Test backchannel"""
135        self.provider.property_mappings.set(
136            ScopeMapping.objects.filter(
137                managed__in=[
138                    "goauthentik.io/providers/oauth2/scope-openid",
139                    "goauthentik.io/providers/oauth2/scope-email",
140                    "goauthentik.io/providers/oauth2/scope-profile",
141                ]
142            )
143        )
144        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
145        res = self.client.post(
146            reverse("authentik_providers_oauth2:device"),
147            HTTP_AUTHORIZATION=f"Basic {creds}",
148            data={"scope": "openid email"},
149        )
150        self.assertEqual(res.status_code, 200)
151        body = loads(res.content.decode())
152        self.assertEqual(body["expires_in"], 60)
153        token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
154        self.assertIsNotNone(token)
155        self.assertEqual(len(token.scope), 2)
156        self.assertIn("openid", token.scope)
157        self.assertIn("email", token.scope)

Test that requesting the openid and email scopes stores exactly those two scopes on the device token

@apply_blueprint('system/providers-oauth2.yaml')
def test_backchannel_scopes_extra(self):
159    @apply_blueprint("system/providers-oauth2.yaml")
160    def test_backchannel_scopes_extra(self):
161        """Test backchannel"""
162        self.provider.property_mappings.set(
163            ScopeMapping.objects.filter(
164                managed__in=[
165                    "goauthentik.io/providers/oauth2/scope-openid",
166                    "goauthentik.io/providers/oauth2/scope-email",
167                    "goauthentik.io/providers/oauth2/scope-profile",
168                ]
169            )
170        )
171        creds = b64encode(f"{self.provider.client_id}:".encode()).decode()
172        res = self.client.post(
173            reverse("authentik_providers_oauth2:device"),
174            HTTP_AUTHORIZATION=f"Basic {creds}",
175            data={"scope": "openid email foo"},
176        )
177        self.assertEqual(res.status_code, 200)
178        body = loads(res.content.decode())
179        self.assertEqual(body["expires_in"], 60)
180        token = DeviceToken.objects.filter(device_code=body["device_code"]).first()
181        self.assertIsNotNone(token)
182        self.assertEqual(len(token.scope), 2)
183        self.assertIn("openid", token.scope)
184        self.assertIn("email", token.scope)

Test that requesting an additional unknown scope still stores only the openid and email scopes on the device token