authentik.enterprise.providers.scim.tests

SCIM OAuth tests

  1"""SCIM OAuth tests"""
  2
  3from base64 import b64encode
  4from datetime import timedelta
  5from unittest.mock import MagicMock, PropertyMock, patch
  6
  7from django.urls import reverse
  8from django.utils.timezone import now
  9from requests_mock import Mocker
 10from rest_framework.test import APITestCase
 11
 12from authentik.blueprints.tests import apply_blueprint
 13from authentik.core.models import Application, Group, User
 14from authentik.core.tests.utils import create_test_admin_user
 15from authentik.enterprise.license import LicenseKey
 16from authentik.enterprise.models import License
 17from authentik.enterprise.tests.test_license import expiry_valid
 18from authentik.lib.generators import generate_id
 19from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMMapping, SCIMProvider
 20from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
 21from authentik.tenants.models import Tenant
 22
 23
 24class SCIMOAuthTests(APITestCase):
 25    """SCIM User tests"""
 26
 27    @apply_blueprint("system/providers-scim.yaml")
 28    def setUp(self) -> None:
 29        # Delete all users and groups as the mocked HTTP responses only return one ID
 30        # which will cause errors with multiple users
 31        Tenant.objects.update(avatars="none")
 32        User.objects.all().exclude_anonymous().delete()
 33        Group.objects.all().delete()
 34        self.source = OAuthSource.objects.create(
 35            name=generate_id(),
 36            slug=generate_id(),
 37            access_token_url="http://localhost/token",  # nosec
 38            consumer_key=generate_id(),
 39            consumer_secret=generate_id(),
 40            provider_type="openidconnect",
 41        )
 42        self.provider = SCIMProvider.objects.create(
 43            name=generate_id(),
 44            url="https://localhost",
 45            auth_mode=SCIMAuthenticationMode.OAUTH,
 46            auth_oauth=self.source,
 47            auth_oauth_params={
 48                "foo": "bar",
 49            },
 50            exclude_users_service_account=True,
 51        )
 52        self.app: Application = Application.objects.create(
 53            name=generate_id(),
 54            slug=generate_id(),
 55        )
 56        self.app.backchannel_providers.add(self.provider)
 57        self.provider.property_mappings.add(
 58            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
 59        )
 60        self.provider.property_mappings_group.add(
 61            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
 62        )
 63
 64    def test_retrieve_token(self):
 65        """Test token retrieval"""
 66        with Mocker() as mocker:
 67            token = generate_id()
 68            mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
 69            self.provider.scim_auth()
 70        conn = UserOAuthSourceConnection.objects.filter(
 71            source=self.source,
 72            user=self.provider.auth_oauth_user,
 73        ).first()
 74        self.assertIsNotNone(conn)
 75        self.assertTrue(conn.is_valid)
 76        auth = (
 77            b64encode(
 78                b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
 79            )
 80            .strip()
 81            .decode()
 82        )
 83        self.assertEqual(
 84            mocker.request_history[0].headers["Authorization"],
 85            f"Basic {auth}",
 86        )
 87        self.assertEqual(mocker.request_history[0].body, "grant_type=password&foo=bar")
 88
 89    def test_existing_token(self):
 90        """Test existing token"""
 91        UserOAuthSourceConnection.objects.create(
 92            source=self.source,
 93            user=self.provider.auth_oauth_user,
 94            access_token=generate_id(),
 95            expires=now() + timedelta(hours=3),
 96        )
 97        with Mocker() as mocker:
 98            self.provider.scim_auth()
 99            self.assertEqual(len(mocker.request_history), 0)
100
101    @Mocker()
102    def test_user_create(self, mock: Mocker):
103        """Test user creation"""
104        scim_id = generate_id()
105        token = generate_id()
106        mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
107        mock.get(
108            "https://localhost/ServiceProviderConfig",
109            json={},
110        )
111        mock.post(
112            "https://localhost/Users",
113            json={
114                "id": scim_id,
115            },
116        )
117        uid = generate_id()
118        user = User.objects.create(
119            username=uid,
120            name=f"{uid} {uid}",
121            email=f"{uid}@goauthentik.io",
122        )
123        self.assertEqual(mock.call_count, 3)
124        self.assertEqual(mock.request_history[1].method, "GET")
125        self.assertEqual(mock.request_history[2].method, "POST")
126        self.assertJSONEqual(
127            mock.request_history[2].body,
128            {
129                "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
130                "active": True,
131                "emails": [
132                    {
133                        "primary": True,
134                        "type": "other",
135                        "value": f"{uid}@goauthentik.io",
136                    }
137                ],
138                "externalId": user.uid,
139                "name": {
140                    "familyName": uid,
141                    "formatted": f"{uid} {uid}",
142                    "givenName": uid,
143                },
144                "displayName": f"{uid} {uid}",
145                "userName": uid,
146            },
147        )
148
149    @patch(
150        "authentik.enterprise.license.LicenseKey.validate",
151        MagicMock(
152            return_value=LicenseKey(
153                aud="",
154                exp=expiry_valid,
155                name=generate_id(),
156                internal_users=100,
157                external_users=100,
158            )
159        ),
160    )
161    def test_api_create(self):
162        License.objects.create(key=generate_id())
163        self.client.force_login(create_test_admin_user())
164        res = self.client.post(
165            reverse("authentik_api:scimprovider-list"),
166            {
167                "name": generate_id(),
168                "url": "http://localhost",
169                "auth_mode": "oauth",
170                "auth_oauth": str(self.source.pk),
171            },
172        )
173        self.assertEqual(res.status_code, 201)
174
175    @patch(
176        "authentik.enterprise.models.LicenseUsageStatus.is_valid",
177        PropertyMock(return_value=False),
178    )
179    def test_api_create_no_license(self):
180        self.client.force_login(create_test_admin_user())
181        res = self.client.post(
182            reverse("authentik_api:scimprovider-list"),
183            {
184                "name": generate_id(),
185                "url": "http://localhost",
186                "auth_mode": "oauth",
187                "auth_oauth": str(self.source.pk),
188            },
189        )
190        self.assertEqual(res.status_code, 400)
191        self.assertJSONEqual(
192            res.content, {"auth_mode": ["Enterprise is required to use the OAuth mode."]}
193        )
class SCIMOAuthTests(rest_framework.test.APITestCase):
 25class SCIMOAuthTests(APITestCase):
 26    """SCIM User tests"""
 27
 28    @apply_blueprint("system/providers-scim.yaml")
 29    def setUp(self) -> None:
 30        # Delete all users and groups as the mocked HTTP responses only return one ID
 31        # which will cause errors with multiple users
 32        Tenant.objects.update(avatars="none")
 33        User.objects.all().exclude_anonymous().delete()
 34        Group.objects.all().delete()
 35        self.source = OAuthSource.objects.create(
 36            name=generate_id(),
 37            slug=generate_id(),
 38            access_token_url="http://localhost/token",  # nosec
 39            consumer_key=generate_id(),
 40            consumer_secret=generate_id(),
 41            provider_type="openidconnect",
 42        )
 43        self.provider = SCIMProvider.objects.create(
 44            name=generate_id(),
 45            url="https://localhost",
 46            auth_mode=SCIMAuthenticationMode.OAUTH,
 47            auth_oauth=self.source,
 48            auth_oauth_params={
 49                "foo": "bar",
 50            },
 51            exclude_users_service_account=True,
 52        )
 53        self.app: Application = Application.objects.create(
 54            name=generate_id(),
 55            slug=generate_id(),
 56        )
 57        self.app.backchannel_providers.add(self.provider)
 58        self.provider.property_mappings.add(
 59            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
 60        )
 61        self.provider.property_mappings_group.add(
 62            SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
 63        )
 64
 65    def test_retrieve_token(self):
 66        """Test token retrieval"""
 67        with Mocker() as mocker:
 68            token = generate_id()
 69            mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
 70            self.provider.scim_auth()
 71        conn = UserOAuthSourceConnection.objects.filter(
 72            source=self.source,
 73            user=self.provider.auth_oauth_user,
 74        ).first()
 75        self.assertIsNotNone(conn)
 76        self.assertTrue(conn.is_valid)
 77        auth = (
 78            b64encode(
 79                b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
 80            )
 81            .strip()
 82            .decode()
 83        )
 84        self.assertEqual(
 85            mocker.request_history[0].headers["Authorization"],
 86            f"Basic {auth}",
 87        )
 88        self.assertEqual(mocker.request_history[0].body, "grant_type=password&foo=bar")
 89
 90    def test_existing_token(self):
 91        """Test existing token"""
 92        UserOAuthSourceConnection.objects.create(
 93            source=self.source,
 94            user=self.provider.auth_oauth_user,
 95            access_token=generate_id(),
 96            expires=now() + timedelta(hours=3),
 97        )
 98        with Mocker() as mocker:
 99            self.provider.scim_auth()
100            self.assertEqual(len(mocker.request_history), 0)
101
102    @Mocker()
103    def test_user_create(self, mock: Mocker):
104        """Test user creation"""
105        scim_id = generate_id()
106        token = generate_id()
107        mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
108        mock.get(
109            "https://localhost/ServiceProviderConfig",
110            json={},
111        )
112        mock.post(
113            "https://localhost/Users",
114            json={
115                "id": scim_id,
116            },
117        )
118        uid = generate_id()
119        user = User.objects.create(
120            username=uid,
121            name=f"{uid} {uid}",
122            email=f"{uid}@goauthentik.io",
123        )
124        self.assertEqual(mock.call_count, 3)
125        self.assertEqual(mock.request_history[1].method, "GET")
126        self.assertEqual(mock.request_history[2].method, "POST")
127        self.assertJSONEqual(
128            mock.request_history[2].body,
129            {
130                "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
131                "active": True,
132                "emails": [
133                    {
134                        "primary": True,
135                        "type": "other",
136                        "value": f"{uid}@goauthentik.io",
137                    }
138                ],
139                "externalId": user.uid,
140                "name": {
141                    "familyName": uid,
142                    "formatted": f"{uid} {uid}",
143                    "givenName": uid,
144                },
145                "displayName": f"{uid} {uid}",
146                "userName": uid,
147            },
148        )
149
150    @patch(
151        "authentik.enterprise.license.LicenseKey.validate",
152        MagicMock(
153            return_value=LicenseKey(
154                aud="",
155                exp=expiry_valid,
156                name=generate_id(),
157                internal_users=100,
158                external_users=100,
159            )
160        ),
161    )
162    def test_api_create(self):
163        License.objects.create(key=generate_id())
164        self.client.force_login(create_test_admin_user())
165        res = self.client.post(
166            reverse("authentik_api:scimprovider-list"),
167            {
168                "name": generate_id(),
169                "url": "http://localhost",
170                "auth_mode": "oauth",
171                "auth_oauth": str(self.source.pk),
172            },
173        )
174        self.assertEqual(res.status_code, 201)
175
176    @patch(
177        "authentik.enterprise.models.LicenseUsageStatus.is_valid",
178        PropertyMock(return_value=False),
179    )
180    def test_api_create_no_license(self):
181        self.client.force_login(create_test_admin_user())
182        res = self.client.post(
183            reverse("authentik_api:scimprovider-list"),
184            {
185                "name": generate_id(),
186                "url": "http://localhost",
187                "auth_mode": "oauth",
188                "auth_oauth": str(self.source.pk),
189            },
190        )
191        self.assertEqual(res.status_code, 400)
192        self.assertJSONEqual(
193            res.content, {"auth_mode": ["Enterprise is required to use the OAuth mode."]}
194        )

SCIM OAuth tests

@apply_blueprint('system/providers-scim.yaml')
def setUp(self) -> None:
@apply_blueprint("system/providers-scim.yaml")
def setUp(self) -> None:
    """Create an OAuth source, a SCIM provider authenticating with it, and an application

    The SCIM provider is attached to the application as a backchannel provider and gets
    the managed SCIM user and group mappings (looked up by their `managed` identifiers).
    """
    # Delete all users and groups as the mocked HTTP responses only return one ID
    # which will cause errors with multiple users
    Tenant.objects.update(avatars="none")
    User.objects.all().exclude_anonymous().delete()
    Group.objects.all().delete()
    self.source = OAuthSource.objects.create(
        name=generate_id(),
        slug=generate_id(),
        access_token_url="http://localhost/token",  # nosec
        consumer_key=generate_id(),
        consumer_secret=generate_id(),
        provider_type="openidconnect",
    )
    self.provider = SCIMProvider.objects.create(
        name=generate_id(),
        url="https://localhost",
        auth_mode=SCIMAuthenticationMode.OAUTH,
        auth_oauth=self.source,
        auth_oauth_params={
            "foo": "bar",
        },
        exclude_users_service_account=True,
    )
    self.app: Application = Application.objects.create(
        name=generate_id(),
        slug=generate_id(),
    )
    self.app.backchannel_providers.add(self.provider)
    self.provider.property_mappings.add(
        SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
    )
    self.provider.property_mappings_group.add(
        SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
    )

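Hook method for setting up the test fixture before exercising it. Here it applies the `system/providers-scim.yaml` blueprint, sets tenant avatars to "none", and deletes all non-anonymous users and all groups. It then creates an OpenID Connect OAuth source, a SCIM provider that authenticates through that source (with the extra token parameter `foo=bar`), and an application with the provider attached as a backchannel provider. Finally it adds the managed SCIM user and group mappings to the provider.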

def test_retrieve_token(self):
def test_retrieve_token(self):
    """Test token retrieval

    Fetching a token must store a valid connection for the provider's OAuth user and
    authenticate to the token endpoint with the source's client credentials.
    """
    with Mocker() as mocker:
        token = generate_id()
        mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
        self.provider.scim_auth()
    conn = UserOAuthSourceConnection.objects.filter(
        source=self.source,
        user=self.provider.auth_oauth_user,
    ).first()
    self.assertIsNotNone(conn)
    self.assertTrue(conn.is_valid)
    # The token request must use HTTP Basic auth built from consumer_key:consumer_secret
    auth = (
        b64encode(
            b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
        )
        .strip()
        .decode()
    )
    # request_history stays readable after the Mocker context has exited
    self.assertEqual(
        mocker.request_history[0].headers["Authorization"],
        f"Basic {auth}",
    )
    # auth_oauth_params configured in setUp are sent along with the grant type
    self.assertEqual(mocker.request_history[0].body, "grant_type=password&foo=bar")

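Test token retrieval: fetching a token stores a valid connection for the provider's OAuth user. The token endpoint is called with HTTP Basic auth built from the source's client credentials.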

def test_existing_token(self):
 90    def test_existing_token(self):
 91        """Test existing token"""
 92        UserOAuthSourceConnection.objects.create(
 93            source=self.source,
 94            user=self.provider.auth_oauth_user,
 95            access_token=generate_id(),
 96            expires=now() + timedelta(hours=3),
 97        )
 98        with Mocker() as mocker:
 99            self.provider.scim_auth()
100            self.assertEqual(len(mocker.request_history), 0)

Test that an existing, unexpired token is reused without sending any HTTP request.

@Mocker()
def test_user_create(self, mock: requests_mock.mocker.Mocker):
@Mocker()
def test_user_create(self, mock: Mocker):
    """Test user creation is pushed to the SCIM endpoint with the expected payload"""
    scim_id = generate_id()
    token = generate_id()
    mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
    mock.get(
        "https://localhost/ServiceProviderConfig",
        json={},
    )
    mock.post(
        "https://localhost/Users",
        json={
            "id": scim_id,
        },
    )
    uid = generate_id()
    user = User.objects.create(
        username=uid,
        name=f"{uid} {uid}",
        email=f"{uid}@goauthentik.io",
    )
    # Requests 1 and 2 are the ServiceProviderConfig GET and the Users POST
    # (request 0 is presumably the token fetch)
    self.assertEqual(mock.call_count, 3)
    self.assertEqual(mock.request_history[1].method, "GET")
    self.assertEqual(mock.request_history[2].method, "POST")
    self.assertJSONEqual(
        mock.request_history[2].body,
        {
            "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
            "active": True,
            "emails": [
                {
                    "primary": True,
                    "type": "other",
                    "value": f"{uid}@goauthentik.io",
                }
            ],
            "externalId": user.uid,
            "name": {
                "familyName": uid,
                "formatted": f"{uid} {uid}",
                "givenName": uid,
            },
            "displayName": f"{uid} {uid}",
            "userName": uid,
        },
    )

Test user creation: creating a local user sends the expected SCIM user payload to the provider's Users endpoint.

@patch('authentik.enterprise.license.LicenseKey.validate', MagicMock(return_value=LicenseKey(aud='', exp=expiry_valid, name=generate_id(), internal_users=100, external_users=100)))
def test_api_create(self):
@patch(
    "authentik.enterprise.license.LicenseKey.validate",
    MagicMock(
        return_value=LicenseKey(
            aud="",
            exp=expiry_valid,
            name=generate_id(),
            internal_users=100,
            external_users=100,
        )
    ),
)
def test_api_create(self):
    """Test creating an OAuth-mode SCIM provider via the API with a valid license"""
    License.objects.create(key=generate_id())
    self.client.force_login(create_test_admin_user())
    res = self.client.post(
        reverse("authentik_api:scimprovider-list"),
        {
            "name": generate_id(),
            "url": "http://localhost",
            "auth_mode": "oauth",
            "auth_oauth": str(self.source.pk),
        },
    )
    self.assertEqual(res.status_code, 201)
@patch('authentik.enterprise.models.LicenseUsageStatus.is_valid', PropertyMock(return_value=False))
def test_api_create_no_license(self):
@patch(
    "authentik.enterprise.models.LicenseUsageStatus.is_valid",
    PropertyMock(return_value=False),
)
def test_api_create_no_license(self):
    """Test creating an OAuth-mode SCIM provider via the API fails without a valid license"""
    self.client.force_login(create_test_admin_user())
    res = self.client.post(
        reverse("authentik_api:scimprovider-list"),
        {
            "name": generate_id(),
            "url": "http://localhost",
            "auth_mode": "oauth",
            "auth_oauth": str(self.source.pk),
        },
    )
    self.assertEqual(res.status_code, 400)
    self.assertJSONEqual(
        res.content, {"auth_mode": ["Enterprise is required to use the OAuth mode."]}
    )