authentik.enterprise.providers.scim.tests.test_token

SCIM OAuth tests

  1"""SCIM OAuth tests"""
  2
  3from base64 import b64encode
  4from datetime import timedelta
  5from urllib.parse import parse_qs, urlencode, urlparse
  6
  7from django.urls import reverse
  8from django.utils.timezone import now
  9from requests_mock import Mocker
 10from rest_framework.test import APITestCase
 11
 12from authentik.blueprints.tests import apply_blueprint
 13from authentik.core.models import Application, Group, User
 14from authentik.lib.generators import generate_id
 15from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMMapping, SCIMProvider
 16from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
 17from authentik.tenants.models import Tenant
 18from tests.live import create_test_admin_user
 19
 20
 21class TestSCIMOAuthToken(APITestCase):
 22    """SCIM User tests"""
 23
 24    @apply_blueprint("system/providers-scim.yaml")
 25    def setUp(self) -> None:
 26        # Delete all users and groups as the mocked HTTP responses only return one ID
 27        # which will cause errors with multiple users
 28        Tenant.objects.update(avatars="none")
 29        User.objects.all().exclude_anonymous().delete()
 30        Group.objects.all().delete()
 31        self.source = OAuthSource.objects.create(
 32            name=generate_id(),
 33            slug=generate_id(),
 34            access_token_url="http://localhost/token",  # nosec
 35            consumer_key=generate_id(),
 36            consumer_secret=generate_id(),
 37            provider_type="openidconnect",
 38        )
 39        self.provider = SCIMProvider.objects.create(
 40            name=generate_id(),
 41            url="https://localhost",
 42            auth_mode=SCIMAuthenticationMode.OAUTH_SILENT,
 43            auth_oauth=self.source,
 44            auth_oauth_params={
 45                "foo": "bar",
 46            },
 47            exclude_users_service_account=True,
 48        )
 49        self.app: Application = Application.objects.create(
 50            name=generate_id(),
 51            slug=generate_id(),
 52        )
 53        self.app.backchannel_providers.add(self.provider)
 54        self.provider.property_mappings.add(
 55            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
 56        )
 57        self.provider.property_mappings_group.add(
 58            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
 59        )
 60        self.admin = create_test_admin_user()
 61
 62    def test_retrieve_token_silent(self):
 63        """Test token retrieval"""
 64        with Mocker() as mocker:
 65            token = generate_id()
 66            mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
 67            self.provider.scim_auth()
 68        conn = UserOAuthSourceConnection.objects.filter(
 69            source=self.source,
 70            user=self.provider.auth_oauth_user,
 71        ).first()
 72        self.assertIsNotNone(conn)
 73        self.assertTrue(conn.is_valid)
 74        auth = (
 75            b64encode(
 76                b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
 77            )
 78            .strip()
 79            .decode()
 80        )
 81        self.assertEqual(
 82            mocker.request_history[0].headers["Authorization"],
 83            f"Basic {auth}",
 84        )
 85        self.assertEqual(mocker.request_history[0].body, "grant_type=password&foo=bar")
 86
 87    def test_retrieve_token_interactive(self):
 88        """Test token retrieval"""
 89        self.provider.auth_mode = SCIMAuthenticationMode.OAUTH_INTERACTIVE
 90        self.provider.save()
 91        refresh_token = generate_id()
 92        access_token = generate_id()
 93        UserOAuthSourceConnection.objects.create(
 94            user=self.provider.auth_oauth_user,
 95            source=self.source,
 96            refresh_token=refresh_token,
 97            access_token=access_token,
 98        )
 99        with Mocker() as mocker:
100            token = generate_id()
101            mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
102            self.provider.scim_auth()
103        conn = UserOAuthSourceConnection.objects.filter(
104            source=self.source,
105            user=self.provider.auth_oauth_user,
106        ).first()
107        self.assertIsNotNone(conn)
108        self.assertTrue(conn.is_valid)
109        auth = (
110            b64encode(
111                b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
112            )
113            .strip()
114            .decode()
115        )
116        self.assertEqual(
117            mocker.request_history[0].headers["Authorization"],
118            f"Basic {auth}",
119        )
120        self.assertEqual(
121            mocker.request_history[0].body,
122            f"grant_type=refresh_token&refresh_token={refresh_token}&foo=bar",
123        )
124
125    def test_existing_token(self):
126        """Test existing token"""
127        UserOAuthSourceConnection.objects.create(
128            source=self.source,
129            user=self.provider.auth_oauth_user,
130            access_token=generate_id(),
131            expires=now() + timedelta(hours=3),
132        )
133        with Mocker() as mocker:
134            self.provider.scim_auth()
135            self.assertEqual(len(mocker.request_history), 0)
136
137    def test_interactive_start(self):
138        self.client.force_login(self.admin)
139        res = self.client.get(
140            reverse(
141                "authentik_enterprise_providers_scim:start",
142                kwargs={
143                    "application_slug": self.app.slug,
144                },
145            )
146        )
147        self.assertEqual(res.status_code, 302)
148        query = parse_qs(urlparse(res.url).query)
149        self.assertEqual(query["client_id"], [self.source.consumer_key])
150        self.assertEqual(
151            query["redirect_uri"],
152            [f"http://testserver/application/scim/{self.app.slug}/oauth2/callback/"],
153        )
154        self.assertEqual(query["response_type"], ["code"])
155
156    def test_interactive_callback(self):
157        self.client.force_login(self.admin)
158        res = self.client.get(
159            reverse(
160                "authentik_enterprise_providers_scim:start",
161                kwargs={
162                    "application_slug": self.app.slug,
163                },
164            )
165        )
166        self.assertEqual(res.status_code, 302)
167        query = parse_qs(urlparse(res.url).query)
168
169        with Mocker() as mock:
170            token = generate_id()
171            mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
172
173            res = self.client.get(
174                reverse(
175                    "authentik_enterprise_providers_scim:callback",
176                    kwargs={
177                        "application_slug": self.app.slug,
178                    },
179                )
180                + "?"
181                + urlencode({"state": query["state"][0], "code": generate_id()})
182            )
183            self.assertEqual(res.status_code, 302)
184
185        conn = UserOAuthSourceConnection.objects.filter(source=self.source).first()
186        self.assertIsNotNone(conn)
187        self.assertTrue(conn.is_valid)
class TestSCIMOAuthToken(rest_framework.test.APITestCase):
 22class TestSCIMOAuthToken(APITestCase):
 23    """SCIM User tests"""
 24
 25    @apply_blueprint("system/providers-scim.yaml")
 26    def setUp(self) -> None:
 27        # Delete all users and groups as the mocked HTTP responses only return one ID
 28        # which will cause errors with multiple users
 29        Tenant.objects.update(avatars="none")
 30        User.objects.all().exclude_anonymous().delete()
 31        Group.objects.all().delete()
 32        self.source = OAuthSource.objects.create(
 33            name=generate_id(),
 34            slug=generate_id(),
 35            access_token_url="http://localhost/token",  # nosec
 36            consumer_key=generate_id(),
 37            consumer_secret=generate_id(),
 38            provider_type="openidconnect",
 39        )
 40        self.provider = SCIMProvider.objects.create(
 41            name=generate_id(),
 42            url="https://localhost",
 43            auth_mode=SCIMAuthenticationMode.OAUTH_SILENT,
 44            auth_oauth=self.source,
 45            auth_oauth_params={
 46                "foo": "bar",
 47            },
 48            exclude_users_service_account=True,
 49        )
 50        self.app: Application = Application.objects.create(
 51            name=generate_id(),
 52            slug=generate_id(),
 53        )
 54        self.app.backchannel_providers.add(self.provider)
 55        self.provider.property_mappings.add(
 56            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
 57        )
 58        self.provider.property_mappings_group.add(
 59            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
 60        )
 61        self.admin = create_test_admin_user()
 62
 63    def test_retrieve_token_silent(self):
 64        """Test token retrieval"""
 65        with Mocker() as mocker:
 66            token = generate_id()
 67            mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
 68            self.provider.scim_auth()
 69        conn = UserOAuthSourceConnection.objects.filter(
 70            source=self.source,
 71            user=self.provider.auth_oauth_user,
 72        ).first()
 73        self.assertIsNotNone(conn)
 74        self.assertTrue(conn.is_valid)
 75        auth = (
 76            b64encode(
 77                b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
 78            )
 79            .strip()
 80            .decode()
 81        )
 82        self.assertEqual(
 83            mocker.request_history[0].headers["Authorization"],
 84            f"Basic {auth}",
 85        )
 86        self.assertEqual(mocker.request_history[0].body, "grant_type=password&foo=bar")
 87
 88    def test_retrieve_token_interactive(self):
 89        """Test token retrieval"""
 90        self.provider.auth_mode = SCIMAuthenticationMode.OAUTH_INTERACTIVE
 91        self.provider.save()
 92        refresh_token = generate_id()
 93        access_token = generate_id()
 94        UserOAuthSourceConnection.objects.create(
 95            user=self.provider.auth_oauth_user,
 96            source=self.source,
 97            refresh_token=refresh_token,
 98            access_token=access_token,
 99        )
100        with Mocker() as mocker:
101            token = generate_id()
102            mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
103            self.provider.scim_auth()
104        conn = UserOAuthSourceConnection.objects.filter(
105            source=self.source,
106            user=self.provider.auth_oauth_user,
107        ).first()
108        self.assertIsNotNone(conn)
109        self.assertTrue(conn.is_valid)
110        auth = (
111            b64encode(
112                b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
113            )
114            .strip()
115            .decode()
116        )
117        self.assertEqual(
118            mocker.request_history[0].headers["Authorization"],
119            f"Basic {auth}",
120        )
121        self.assertEqual(
122            mocker.request_history[0].body,
123            f"grant_type=refresh_token&refresh_token={refresh_token}&foo=bar",
124        )
125
126    def test_existing_token(self):
127        """Test existing token"""
128        UserOAuthSourceConnection.objects.create(
129            source=self.source,
130            user=self.provider.auth_oauth_user,
131            access_token=generate_id(),
132            expires=now() + timedelta(hours=3),
133        )
134        with Mocker() as mocker:
135            self.provider.scim_auth()
136            self.assertEqual(len(mocker.request_history), 0)
137
138    def test_interactive_start(self):
139        self.client.force_login(self.admin)
140        res = self.client.get(
141            reverse(
142                "authentik_enterprise_providers_scim:start",
143                kwargs={
144                    "application_slug": self.app.slug,
145                },
146            )
147        )
148        self.assertEqual(res.status_code, 302)
149        query = parse_qs(urlparse(res.url).query)
150        self.assertEqual(query["client_id"], [self.source.consumer_key])
151        self.assertEqual(
152            query["redirect_uri"],
153            [f"http://testserver/application/scim/{self.app.slug}/oauth2/callback/"],
154        )
155        self.assertEqual(query["response_type"], ["code"])
156
157    def test_interactive_callback(self):
158        self.client.force_login(self.admin)
159        res = self.client.get(
160            reverse(
161                "authentik_enterprise_providers_scim:start",
162                kwargs={
163                    "application_slug": self.app.slug,
164                },
165            )
166        )
167        self.assertEqual(res.status_code, 302)
168        query = parse_qs(urlparse(res.url).query)
169
170        with Mocker() as mock:
171            token = generate_id()
172            mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
173
174            res = self.client.get(
175                reverse(
176                    "authentik_enterprise_providers_scim:callback",
177                    kwargs={
178                        "application_slug": self.app.slug,
179                    },
180                )
181                + "?"
182                + urlencode({"state": query["state"][0], "code": generate_id()})
183            )
184            self.assertEqual(res.status_code, 302)
185
186        conn = UserOAuthSourceConnection.objects.filter(source=self.source).first()
187        self.assertIsNotNone(conn)
188        self.assertTrue(conn.is_valid)

SCIM OAuth token tests

@apply_blueprint('system/providers-scim.yaml')
def setUp(self) -> None:
25    @apply_blueprint("system/providers-scim.yaml")
26    def setUp(self) -> None:
27        # Delete all users and groups as the mocked HTTP responses only return one ID
28        # which will cause errors with multiple users
29        Tenant.objects.update(avatars="none")
30        User.objects.all().exclude_anonymous().delete()
31        Group.objects.all().delete()
32        self.source = OAuthSource.objects.create(
33            name=generate_id(),
34            slug=generate_id(),
35            access_token_url="http://localhost/token",  # nosec
36            consumer_key=generate_id(),
37            consumer_secret=generate_id(),
38            provider_type="openidconnect",
39        )
40        self.provider = SCIMProvider.objects.create(
41            name=generate_id(),
42            url="https://localhost",
43            auth_mode=SCIMAuthenticationMode.OAUTH_SILENT,
44            auth_oauth=self.source,
45            auth_oauth_params={
46                "foo": "bar",
47            },
48            exclude_users_service_account=True,
49        )
50        self.app: Application = Application.objects.create(
51            name=generate_id(),
52            slug=generate_id(),
53        )
54        self.app.backchannel_providers.add(self.provider)
55        self.provider.property_mappings.add(
56            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
57        )
58        self.provider.property_mappings_group.add(
59            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
60        )
61        self.admin = create_test_admin_user()

Hook method for setting up the test fixture before exercising it.

def test_retrieve_token_silent(self):
63    def test_retrieve_token_silent(self):
64        """Test token retrieval"""
65        with Mocker() as mocker:
66            token = generate_id()
67            mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
68            self.provider.scim_auth()
69        conn = UserOAuthSourceConnection.objects.filter(
70            source=self.source,
71            user=self.provider.auth_oauth_user,
72        ).first()
73        self.assertIsNotNone(conn)
74        self.assertTrue(conn.is_valid)
75        auth = (
76            b64encode(
77                b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
78            )
79            .strip()
80            .decode()
81        )
82        self.assertEqual(
83            mocker.request_history[0].headers["Authorization"],
84            f"Basic {auth}",
85        )
86        self.assertEqual(mocker.request_history[0].body, "grant_type=password&foo=bar")

Test token retrieval (silent OAuth mode)

def test_retrieve_token_interactive(self):
 88    def test_retrieve_token_interactive(self):
 89        """Test token retrieval"""
 90        self.provider.auth_mode = SCIMAuthenticationMode.OAUTH_INTERACTIVE
 91        self.provider.save()
 92        refresh_token = generate_id()
 93        access_token = generate_id()
 94        UserOAuthSourceConnection.objects.create(
 95            user=self.provider.auth_oauth_user,
 96            source=self.source,
 97            refresh_token=refresh_token,
 98            access_token=access_token,
 99        )
100        with Mocker() as mocker:
101            token = generate_id()
102            mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
103            self.provider.scim_auth()
104        conn = UserOAuthSourceConnection.objects.filter(
105            source=self.source,
106            user=self.provider.auth_oauth_user,
107        ).first()
108        self.assertIsNotNone(conn)
109        self.assertTrue(conn.is_valid)
110        auth = (
111            b64encode(
112                b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
113            )
114            .strip()
115            .decode()
116        )
117        self.assertEqual(
118            mocker.request_history[0].headers["Authorization"],
119            f"Basic {auth}",
120        )
121        self.assertEqual(
122            mocker.request_history[0].body,
123            f"grant_type=refresh_token&refresh_token={refresh_token}&foo=bar",
124        )

Test token retrieval (interactive OAuth mode, using a stored refresh token)

def test_existing_token(self):
126    def test_existing_token(self):
127        """Test existing token"""
128        UserOAuthSourceConnection.objects.create(
129            source=self.source,
130            user=self.provider.auth_oauth_user,
131            access_token=generate_id(),
132            expires=now() + timedelta(hours=3),
133        )
134        with Mocker() as mocker:
135            self.provider.scim_auth()
136            self.assertEqual(len(mocker.request_history), 0)

Test existing token

def test_interactive_start(self):
138    def test_interactive_start(self):
139        self.client.force_login(self.admin)
140        res = self.client.get(
141            reverse(
142                "authentik_enterprise_providers_scim:start",
143                kwargs={
144                    "application_slug": self.app.slug,
145                },
146            )
147        )
148        self.assertEqual(res.status_code, 302)
149        query = parse_qs(urlparse(res.url).query)
150        self.assertEqual(query["client_id"], [self.source.consumer_key])
151        self.assertEqual(
152            query["redirect_uri"],
153            [f"http://testserver/application/scim/{self.app.slug}/oauth2/callback/"],
154        )
155        self.assertEqual(query["response_type"], ["code"])
def test_interactive_callback(self):
157    def test_interactive_callback(self):
158        self.client.force_login(self.admin)
159        res = self.client.get(
160            reverse(
161                "authentik_enterprise_providers_scim:start",
162                kwargs={
163                    "application_slug": self.app.slug,
164                },
165            )
166        )
167        self.assertEqual(res.status_code, 302)
168        query = parse_qs(urlparse(res.url).query)
169
170        with Mocker() as mock:
171            token = generate_id()
172            mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
173
174            res = self.client.get(
175                reverse(
176                    "authentik_enterprise_providers_scim:callback",
177                    kwargs={
178                        "application_slug": self.app.slug,
179                    },
180                )
181                + "?"
182                + urlencode({"state": query["state"][0], "code": generate_id()})
183            )
184            self.assertEqual(res.status_code, 302)
185
186        conn = UserOAuthSourceConnection.objects.filter(source=self.source).first()
187        self.assertIsNotNone(conn)
188        self.assertTrue(conn.is_valid)