authentik.stages.prompt.stage
Prompt Stage Logic
"""Prompt Stage Logic"""

from collections.abc import Callable

# NOTE(review): this is the stdlib `email.policy.Policy`; it is only used as the
# element type in `ListPolicyEngine.__init__`. Presumably authentik's own Policy
# model was intended - confirm before changing.
from email.policy import Policy
from types import MethodType
from typing import Any

from django.contrib.messages import INFO, add_message
from django.db.models.query import QuerySet
from django.http import HttpRequest, HttpResponse
from django.http.request import QueryDict
from django.utils.translation import gettext_lazy as _
from rest_framework.fields import (
    BooleanField,
    CharField,
    ChoiceField,
    IntegerField,
    ListField,
    empty,
)
from rest_framework.serializers import ValidationError

from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User
from authentik.flows.challenge import Challenge, ChallengeResponse
from authentik.flows.planner import FlowPlan
from authentik.flows.stage import ChallengeStageView
from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyBindingModel, PolicyEngineMode
from authentik.stages.prompt.models import FieldTypes, Prompt, PromptStage
from authentik.stages.prompt.signals import password_validate

# Key in the flow plan context under which validated prompt data is stored
PLAN_CONTEXT_PROMPT = "prompt_data"


class PromptChoiceSerializer(PassiveSerializer):
    """Serializer for a single Choice field"""

    value = CharField(required=True)
    label = CharField(required=True)


class StagePromptSerializer(PassiveSerializer):
    """Serializer for a single Prompt field"""

    field_key = CharField()
    label = CharField(allow_blank=True)
    type = ChoiceField(choices=FieldTypes.choices)
    required = BooleanField()
    placeholder = CharField(allow_blank=True)
    initial_value = CharField(allow_blank=True)
    order = IntegerField()
    sub_text = CharField(allow_blank=True)
    # List of value/label pairs; may be empty or null
    choices = ListField(child=PromptChoiceSerializer(), allow_empty=True, allow_null=True)


class PromptChallenge(Challenge):
    """Initial challenge being sent, define fields"""

    fields = StagePromptSerializer(many=True)
    component = CharField(default="ak-stage-prompt")


class PromptChallengeResponse(ChallengeResponse):
    """Validate response, fields are dynamically created based
    on the stage"""

    stage_instance: PromptStage

    component = CharField(default="ak-stage-prompt")

    def __init__(self, *args, **kwargs):
        """Create one serializer field per prompt of the stage.

        The keyword arguments `stage_instance`, `plan`, `request` and `user` are
        removed from `kwargs` before the parent constructor is called. Without a
        `stage_instance` no dynamic fields are created."""
        stage: PromptStage = kwargs.pop("stage_instance", None)
        plan: FlowPlan = kwargs.pop("plan", None)
        request: HttpRequest = kwargs.pop("request", None)
        user: User = kwargs.pop("user", None)
        super().__init__(*args, **kwargs)
        self.stage_instance = stage
        self.plan = plan
        self.request = request
        if not self.stage_instance:
            return
        # list() is called so we only load the fields once
        fields = list(self.stage_instance.fields.all())
        for field in fields:
            field: Prompt
            choices = field.get_choices(
                plan.context.get(PLAN_CONTEXT_PROMPT, {}), user, self.request
            )
            current = field.get_initial_value(
                plan.context.get(PLAN_CONTEXT_PROMPT, {}), user, self.request
            )
            self.fields[field.field_key] = field.field(current, choices)
            # Special handling for fields with username type
            # these check for existing users with the same username
            if field.type == FieldTypes.USERNAME:
                setattr(
                    self,
                    f"validate_{field.field_key}",
                    MethodType(username_field_validator_factory(), self),
                )
            # Check if we have a password field, add a handler that sends a signal
            # to validate it
            if field.type == FieldTypes.PASSWORD:
                setattr(
                    self,
                    f"validate_{field.field_key}",
                    MethodType(password_single_validator_factory(), self),
                )

        # Prompts sorted by their `order` attribute
        self.field_order = sorted(fields, key=lambda x: x.order)

    def _validate_password_fields(self, *field_names):
        """Check if the value of all password fields match by merging them into a set
        and checking the length"""
        # Values are read from the raw submitted data, not from validated attrs
        all_passwords = {self.initial_data[x] for x in field_names}
        if len(all_passwords) > 1:
            raise ValidationError(_("Passwords don't match."))

    def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
        """Run the cross-field checks and the stage's validation policies.

        Raises `ValidationError` when two password fields differ or when the
        policy result is not passing; otherwise returns `attrs`."""
        # Check if we have any static or hidden fields, and ensure they
        # still have the same value
        static_hidden_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
            type__in=[
                FieldTypes.HIDDEN,
                FieldTypes.STATIC,
                FieldTypes.TEXT_READ_ONLY,
                FieldTypes.TEXT_AREA_READ_ONLY,
            ]
        )
        for static_hidden in static_hidden_fields:
            field = self.fields[static_hidden.field_key]
            default = field.default
            # Prevent rest_framework.fields.empty from ending up in policies and events
            if default == empty:
                default = ""
            # Whatever was submitted for this key is replaced by the field default
            attrs[static_hidden.field_key] = default

        # Check if we have two password fields, and make sure they are the same
        password_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
            type=FieldTypes.PASSWORD
        )
        if password_fields.exists() and password_fields.count() == 2:  # noqa: PLR2004
            self._validate_password_fields(*[field.field_key for field in password_fields])

        engine = ListPolicyEngine(
            self.stage_instance.validation_policies.all(),
            self.stage.get_pending_user(),
            self.request,
        )
        engine.mode = PolicyEngineMode.MODE_ALL
        # Expose the submitted data to the policies under the prompt context key
        engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
        engine.use_cache = False
        engine.build()
        result = engine.result
        if not result.passing:
            raise ValidationError(list(result.messages))
        else:
            # Messages of a passing result are shown to the user as INFO messages
            for msg in result.messages:
                add_message(self.request, INFO, msg)
        return attrs


def username_field_validator_factory() -> Callable[[PromptChallengeResponse, str], Any]:
    """Return a `validate_<field_key>` method for a username field.
    The method checks if the username is taken already."""

    def username_field_validator(self: PromptChallengeResponse, value: str) -> Any:
        """Check for duplicate usernames"""
        pending_user = self.stage.get_pending_user()
        query = User.objects.all()
        # A pending user that has a primary key may keep its current username
        if pending_user.pk:
            query = query.exclude(username=pending_user.username)
        if query.filter(username=value).exists():
            raise ValidationError("Username is already taken.")
        return value

    return username_field_validator


def password_single_validator_factory() -> Callable[[PromptChallengeResponse, str], Any]:
    """Return a `validate_<field_key>` method for a password field.
    The method sends the `password_validate` signal and returns the value unchanged."""

    def password_single_clean(self: PromptChallengeResponse, value: str) -> Any:
        """Send password validation signals for e.g. LDAP Source"""
        password_validate.send(sender=self, password=value, plan_context=self.plan.context)
        return value

    return password_single_clean


class ListPolicyEngine(PolicyEngine):
    """Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""

    def __init__(self, policies: list[Policy], user: User, request: HttpRequest = None) -> None:
        """Store `policies` and disable caching; the parent is given an empty
        `PolicyBindingModel()`."""
        super().__init__(PolicyBindingModel(), user, request)
        self.__list = policies
        self.use_cache = False

    def bindings(self):
        """Yield a `PolicyBinding` for every stored policy, using the position
        in the list as `order`."""
        for idx, policy in enumerate(self.__list):
            yield PolicyBinding(
                policy=policy,
                order=idx,
            )


class PromptStageView(ChallengeStageView):
    """Prompt Stage, save form data in plan context."""

    response_class = PromptChallengeResponse

    def get_prompt_challenge_fields(self, fields: list[Prompt], context: dict, dry_run=False):
        """Get serializers for all fields in `fields`, using the context `context`.
        If `dry_run` is set, property mapping expression errors are raised, otherwise they
        are logged and events are created"""
        serializers = []
        for field in fields:
            data = StagePromptSerializer(field).data
            # Ensure all placeholders and initial values are str, as
            # otherwise further in we can fail serializer validation if we return
            # some types such as bool
            # choices can be a dict with value and label
            choices = field.get_choices(context, self.get_pending_user(), self.request, dry_run)
            if choices:
                data["choices"] = list(self.clean_choices(choices))
            else:
                data["choices"] = None
            data["placeholder"] = str(
                field.get_placeholder(context, self.get_pending_user(), self.request, dry_run)
            )
            data["initial_value"] = str(
                field.get_initial_value(context, self.get_pending_user(), self.request, dry_run)
            )
            serializers.append(data)
        return serializers

    def clean_choices(self, choices):
        """Yield a `{"label": str, "value": str}` dict for every choice.

        Dict choices supply `label` and `value` themselves (missing keys become
        an empty string); any other choice is used as both label and value."""
        for choice in choices:
            label, value = choice, choice
            if isinstance(choice, dict):
                label = choice.get("label", "")
                value = choice.get("value", "")
            yield {"label": str(label), "value": str(value)}

    def get_challenge(self, *args, **kwargs) -> Challenge:
        """Build the `PromptChallenge` from the current stage's fields, ordered by `order`."""
        fields: list[Prompt] = list(self.executor.current_stage.fields.all().order_by("order"))
        context_prompt = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {})
        serializers = self.get_prompt_challenge_fields(fields, context_prompt)
        challenge = PromptChallenge(
            data={
                "fields": serializers,
            },
        )
        return challenge

    def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
        """Return a `PromptChallengeResponse` for `data`, bound to the current
        stage, plan, request and pending user. Raises `ValueError` without a plan."""
        if not self.executor.plan:  # pragma: no cover
            raise ValueError
        return PromptChallengeResponse(
            instance=None,
            data=data,
            request=self.request,
            stage_instance=self.executor.current_stage,
            stage=self,
            plan=self.executor.plan,
            user=self.get_pending_user(),
        )

    def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
        """Merge the validated data into the plan context and finish the stage."""
        if PLAN_CONTEXT_PROMPT not in self.executor.plan.context:
            self.executor.plan.context[PLAN_CONTEXT_PROMPT] = {}
        self.executor.plan.context[PLAN_CONTEXT_PROMPT].update(response.validated_data)
        return self.executor.stage_ok()
class PromptChoiceSerializer(PassiveSerializer):
    """Serializer describing one selectable choice of a prompt.

    Both members are mandatory character fields."""

    # Value of the choice
    value = CharField(
        required=True,
    )
    # Label of the choice
    label = CharField(
        required=True,
    )
Serializer for a single Choice field
Inherited Members
class StagePromptSerializer(PassiveSerializer):
    """Serializer describing one prompt as it is sent in a challenge.

    The textual members accept blank strings; `choices` may be an empty list
    or null."""

    field_key = CharField()
    label = CharField(
        allow_blank=True,
    )
    # Restricted to the values declared by FieldTypes
    type = ChoiceField(
        choices=FieldTypes.choices,
    )
    required = BooleanField()
    placeholder = CharField(
        allow_blank=True,
    )
    initial_value = CharField(
        allow_blank=True,
    )
    order = IntegerField()
    sub_text = CharField(
        allow_blank=True,
    )
    # Each entry is a value/label pair
    choices = ListField(
        child=PromptChoiceSerializer(),
        allow_empty=True,
        allow_null=True,
    )
Serializer for a single Prompt field
Inherited Members
class PromptChallenge(Challenge):
    """Challenge sent first; it carries the definitions of all prompt fields."""

    # One serialized entry per prompt
    fields = StagePromptSerializer(
        many=True,
    )
    # Component name, defaulting to the prompt stage component
    component = CharField(
        default="ak-stage-prompt",
    )
Initial challenge being sent, define fields
class PromptChallengeResponse(ChallengeResponse):
    """Validate response, fields are dynamically created based
    on the stage"""

    stage_instance: PromptStage

    component = CharField(default="ak-stage-prompt")

    def __init__(self, *args, **kwargs):
        """Create one serializer field per prompt of the stage.

        The keyword arguments `stage_instance`, `plan`, `request` and `user` are
        removed from `kwargs` before the parent constructor is called. Without a
        `stage_instance` no dynamic fields are created."""
        stage: PromptStage = kwargs.pop("stage_instance", None)
        plan: FlowPlan = kwargs.pop("plan", None)
        request: HttpRequest = kwargs.pop("request", None)
        user: User = kwargs.pop("user", None)
        super().__init__(*args, **kwargs)
        self.stage_instance = stage
        self.plan = plan
        self.request = request
        if not self.stage_instance:
            return
        # list() is called so we only load the fields once
        fields = list(self.stage_instance.fields.all())
        for field in fields:
            field: Prompt
            choices = field.get_choices(
                plan.context.get(PLAN_CONTEXT_PROMPT, {}), user, self.request
            )
            current = field.get_initial_value(
                plan.context.get(PLAN_CONTEXT_PROMPT, {}), user, self.request
            )
            self.fields[field.field_key] = field.field(current, choices)
            # Special handling for fields with username type
            # these check for existing users with the same username
            if field.type == FieldTypes.USERNAME:
                setattr(
                    self,
                    f"validate_{field.field_key}",
                    MethodType(username_field_validator_factory(), self),
                )
            # Check if we have a password field, add a handler that sends a signal
            # to validate it
            if field.type == FieldTypes.PASSWORD:
                setattr(
                    self,
                    f"validate_{field.field_key}",
                    MethodType(password_single_validator_factory(), self),
                )

        # Prompts sorted by their `order` attribute
        self.field_order = sorted(fields, key=lambda x: x.order)

    def _validate_password_fields(self, *field_names):
        """Check if the value of all password fields match by merging them into a set
        and checking the length"""
        # Values are read from the raw submitted data, not from validated attrs
        all_passwords = {self.initial_data[x] for x in field_names}
        if len(all_passwords) > 1:
            raise ValidationError(_("Passwords don't match."))

    def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
        """Run the cross-field checks and the stage's validation policies.

        Static, hidden and read-only fields are reset to their field default,
        two password fields must match, and all validation policies of the
        stage must pass.

        Raises `ValidationError` when the passwords differ or when the policy
        result is not passing; otherwise returns `attrs`."""
        # Check if we have any static or hidden fields, and ensure they
        # still have the same value
        static_hidden_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
            type__in=[
                FieldTypes.HIDDEN,
                FieldTypes.STATIC,
                FieldTypes.TEXT_READ_ONLY,
                FieldTypes.TEXT_AREA_READ_ONLY,
            ]
        )
        for static_hidden in static_hidden_fields:
            default = self.fields[static_hidden.field_key].default
            # Prevent rest_framework.fields.empty from ending up in policies and events;
            # `empty` is a sentinel, so it is compared by identity
            if default is empty:
                default = ""
            attrs[static_hidden.field_key] = default

        # Check if we have two password fields, and make sure they are the same.
        # The keys are loaded with a single query instead of exists() + count()
        # followed by another iteration of the queryset.
        password_keys = [
            prompt.field_key
            for prompt in self.stage_instance.fields.filter(type=FieldTypes.PASSWORD)
        ]
        if len(password_keys) == 2:  # noqa: PLR2004
            self._validate_password_fields(*password_keys)

        engine = ListPolicyEngine(
            self.stage_instance.validation_policies.all(),
            self.stage.get_pending_user(),
            self.request,
        )
        engine.mode = PolicyEngineMode.MODE_ALL
        # Expose the submitted data to the policies under the prompt context key
        engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
        engine.use_cache = False
        engine.build()
        result = engine.result
        if not result.passing:
            raise ValidationError(list(result.messages))
        # Messages of a passing result are shown to the user as INFO messages
        for msg in result.messages:
            add_message(self.request, INFO, msg)
        return attrs
Validate response, fields are dynamically created based on the stage
def __init__(self, *args, **kwargs):
    """Set up the response and add one serializer field per prompt.

    `stage_instance`, `plan`, `request` and `user` are taken out of `kwargs`
    before the parent constructor runs. When no stage instance is given,
    nothing beyond storing the attributes happens."""
    stage_instance: PromptStage = kwargs.pop("stage_instance", None)
    flow_plan: FlowPlan = kwargs.pop("plan", None)
    http_request: HttpRequest = kwargs.pop("request", None)
    pending_user: User = kwargs.pop("user", None)
    super().__init__(*args, **kwargs)

    self.stage_instance = stage_instance
    self.plan = flow_plan
    self.request = http_request
    if not self.stage_instance:
        return

    # Evaluate the queryset a single time; the result is reused for sorting below
    prompts = list(self.stage_instance.fields.all())
    for prompt in prompts:
        prompt: Prompt
        available_choices = prompt.get_choices(
            flow_plan.context.get(PLAN_CONTEXT_PROMPT, {}),
            pending_user,
            self.request,
        )
        initial = prompt.get_initial_value(
            flow_plan.context.get(PLAN_CONTEXT_PROMPT, {}),
            pending_user,
            self.request,
        )
        self.fields[prompt.field_key] = prompt.field(initial, available_choices)

        validator_name = f"validate_{prompt.field_key}"
        # Username prompts get a validator that looks for existing users
        # with the same username
        if prompt.type == FieldTypes.USERNAME:
            username_validator = MethodType(username_field_validator_factory(), self)
            setattr(self, validator_name, username_validator)
        # Password prompts get a validator that sends the validation signal
        if prompt.type == FieldTypes.PASSWORD:
            password_validator = MethodType(password_single_validator_factory(), self)
            setattr(self, validator_name, password_validator)

    self.field_order = sorted(prompts, key=lambda prompt: prompt.order)
def validate(self, attrs: dict[str, Any]) -> dict[str, Any]:
    """Run the cross-field checks and the stage's validation policies.

    Static, hidden and read-only fields are reset to their field default,
    two password fields must match, and all validation policies of the
    stage must pass.

    Raises `ValidationError` when the passwords differ or when the policy
    result is not passing; otherwise returns `attrs`."""
    # Check if we have any static or hidden fields, and ensure they
    # still have the same value
    static_hidden_fields: QuerySet[Prompt] = self.stage_instance.fields.filter(
        type__in=[
            FieldTypes.HIDDEN,
            FieldTypes.STATIC,
            FieldTypes.TEXT_READ_ONLY,
            FieldTypes.TEXT_AREA_READ_ONLY,
        ]
    )
    for static_hidden in static_hidden_fields:
        default = self.fields[static_hidden.field_key].default
        # Prevent rest_framework.fields.empty from ending up in policies and events;
        # `empty` is a sentinel, so it is compared by identity
        if default is empty:
            default = ""
        attrs[static_hidden.field_key] = default

    # Check if we have two password fields, and make sure they are the same.
    # The keys are loaded with a single query instead of exists() + count()
    # followed by another iteration of the queryset.
    password_keys = [
        prompt.field_key
        for prompt in self.stage_instance.fields.filter(type=FieldTypes.PASSWORD)
    ]
    if len(password_keys) == 2:  # noqa: PLR2004
        self._validate_password_fields(*password_keys)

    engine = ListPolicyEngine(
        self.stage_instance.validation_policies.all(),
        self.stage.get_pending_user(),
        self.request,
    )
    engine.mode = PolicyEngineMode.MODE_ALL
    # Expose the submitted data to the policies under the prompt context key
    engine.request.context[PLAN_CONTEXT_PROMPT] = attrs
    engine.use_cache = False
    engine.build()
    result = engine.result
    if not result.passing:
        raise ValidationError(list(result.messages))
    # Messages of a passing result are shown to the user as INFO messages
    for msg in result.messages:
        add_message(self.request, INFO, msg)
    return attrs
def username_field_validator_factory() -> Callable[[PromptChallengeResponse, str], Any]:
    """Return a `validate_<field_key>` method for a username field.

    The returned function takes the response instance and the submitted value,
    raises `ValidationError` when another user already has that username, and
    otherwise returns the value unchanged."""

    def username_field_validator(self: PromptChallengeResponse, value: str) -> Any:
        """Check for duplicate usernames"""
        pending_user = self.stage.get_pending_user()
        query = User.objects.all()
        # A pending user that has a primary key may keep its current username
        if pending_user.pk:
            query = query.exclude(username=pending_user.username)
        if query.filter(username=value).exists():
            # Marked for translation, like the other user-facing validation
            # message in this module ("Passwords don't match.")
            raise ValidationError(_("Username is already taken."))
        return value

    return username_field_validator
Return a `validate_<field_key>` method for a username field. The method checks whether the username is already taken.
def password_single_validator_factory() -> Callable[[PromptChallengeResponse, str], Any]:
    """Build the `validate_<field_key>` method used for password fields.

    The returned function sends the `password_validate` signal with the
    submitted password and the plan context, then returns the value as-is."""

    def password_single_clean(self: PromptChallengeResponse, value: str) -> Any:
        """Send password validation signals for e.g. LDAP Source"""
        plan_context = self.plan.context
        password_validate.send(
            sender=self,
            password=value,
            plan_context=plan_context,
        )
        return value

    return password_single_clean
Return a `validate_<field_key>` method for a password field. The method sends the `password_validate`
signal (e.g. for an LDAP Source) and returns the value unchanged.
class ListPolicyEngine(PolicyEngine):
    """Policy engine variant that is fed a plain list of policies rather than
    the bindings of a PolicyBindingModel"""

    # NOTE(review): `Policy` in the annotation below resolves to the stdlib
    # `email.policy.Policy` imported by this module; presumably authentik's
    # Policy model was meant - confirm.
    def __init__(self, policies: list[Policy], user: User, request: HttpRequest = None) -> None:
        """Keep `policies` for `bindings()` and switch caching off.

        The parent engine is initialised with an empty `PolicyBindingModel()`."""
        placeholder_model = PolicyBindingModel()
        super().__init__(placeholder_model, user, request)
        self.__list = policies
        self.use_cache = False

    def bindings(self):
        """Yield one `PolicyBinding` per stored policy; its `order` is the
        policy's position in the list."""
        yield from (
            PolicyBinding(policy=item, order=position)
            for position, item in enumerate(self.__list)
        )
Slightly modified policy engine, which uses a list instead of a PolicyBindingModel
class PromptStageView(ChallengeStageView):
    """Stage view for prompts; validated form data ends up in the plan context."""

    response_class = PromptChallengeResponse

    def get_prompt_challenge_fields(self, fields: list[Prompt], context: dict, dry_run=False):
        """Serialize every prompt in `fields`, evaluated against `context`.

        `dry_run` is passed through to the prompt's `get_choices`,
        `get_placeholder` and `get_initial_value` calls. Returns a list with
        one serialized entry per prompt."""
        serialized = []
        for prompt in fields:
            entry = StagePromptSerializer(prompt).data
            # Choices may be plain values or dicts with value and label
            raw_choices = prompt.get_choices(
                context, self.get_pending_user(), self.request, dry_run
            )
            entry["choices"] = list(self.clean_choices(raw_choices)) if raw_choices else None
            # Placeholder and initial value are forced to str; other types
            # such as bool could fail serializer validation further in
            placeholder = prompt.get_placeholder(
                context, self.get_pending_user(), self.request, dry_run
            )
            entry["placeholder"] = str(placeholder)
            initial_value = prompt.get_initial_value(
                context, self.get_pending_user(), self.request, dry_run
            )
            entry["initial_value"] = str(initial_value)
            serialized.append(entry)
        return serialized

    def clean_choices(self, choices):
        """Yield a dict with string `label` and `value` for every choice.

        Dict choices provide both keys themselves (a missing key becomes an
        empty string); any other choice serves as label and value alike."""
        for entry in choices:
            if isinstance(entry, dict):
                label = entry.get("label", "")
                value = entry.get("value", "")
            else:
                label = value = entry
            yield {"label": str(label), "value": str(value)}

    def get_challenge(self, *args, **kwargs) -> Challenge:
        """Create the `PromptChallenge` for the current stage, fields sorted by `order`."""
        ordered_prompts: list[Prompt] = list(
            self.executor.current_stage.fields.all().order_by("order")
        )
        prompt_context = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {})
        payload = {
            "fields": self.get_prompt_challenge_fields(ordered_prompts, prompt_context),
        }
        return PromptChallenge(data=payload)

    def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
        """Create the `PromptChallengeResponse` that validates `data`.

        Raises `ValueError` when the executor has no plan."""
        if not self.executor.plan:  # pragma: no cover
            raise ValueError
        # Evaluated in the same order as the keyword arguments were originally given
        request = self.request
        stage_instance = self.executor.current_stage
        plan = self.executor.plan
        user = self.get_pending_user()
        return PromptChallengeResponse(
            instance=None,
            data=data,
            request=request,
            stage_instance=stage_instance,
            stage=self,
            plan=plan,
            user=user,
        )

    def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
        """Store the validated data in the plan context and report the stage as done."""
        plan_context = self.executor.plan.context
        # Create the prompt data dict on first use, then merge the new values
        plan_context.setdefault(PLAN_CONTEXT_PROMPT, {}).update(response.validated_data)
        return self.executor.stage_ok()
Prompt Stage, save form data in plan context.
def get_prompt_challenge_fields(self, fields: list[Prompt], context: dict, dry_run=False):
    """Serialize every prompt in `fields`, evaluated against `context`.

    `dry_run` is passed through to the prompt's `get_choices`,
    `get_placeholder` and `get_initial_value` calls. Returns a list with
    one serialized entry per prompt."""
    serialized = []
    for prompt in fields:
        entry = StagePromptSerializer(prompt).data
        # Choices may be plain values or dicts with value and label
        raw_choices = prompt.get_choices(
            context, self.get_pending_user(), self.request, dry_run
        )
        entry["choices"] = list(self.clean_choices(raw_choices)) if raw_choices else None
        # Placeholder and initial value are forced to str; other types
        # such as bool could fail serializer validation further in
        placeholder = prompt.get_placeholder(
            context, self.get_pending_user(), self.request, dry_run
        )
        entry["placeholder"] = str(placeholder)
        initial_value = prompt.get_initial_value(
            context, self.get_pending_user(), self.request, dry_run
        )
        entry["initial_value"] = str(initial_value)
        serialized.append(entry)
    return serialized
Get serializers for all fields in `fields`, using the context `context`.
If `dry_run` is set, property mapping expression errors are raised, otherwise they
are logged and events are created.
def get_challenge(self, *args, **kwargs) -> Challenge:
    """Create the `PromptChallenge` for the current stage, fields sorted by `order`."""
    ordered_prompts: list[Prompt] = list(
        self.executor.current_stage.fields.all().order_by("order")
    )
    prompt_context = self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {})
    payload = {
        "fields": self.get_prompt_challenge_fields(ordered_prompts, prompt_context),
    }
    return PromptChallenge(data=payload)
Return the challenge that the client should solve
def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
    """Create the `PromptChallengeResponse` that validates `data`.

    Raises `ValueError` when the executor has no plan."""
    if not self.executor.plan:  # pragma: no cover
        raise ValueError
    # Evaluated in the same order as the keyword arguments were originally given
    request = self.request
    stage_instance = self.executor.current_stage
    plan = self.executor.plan
    user = self.get_pending_user()
    return PromptChallengeResponse(
        instance=None,
        data=data,
        request=request,
        stage_instance=stage_instance,
        stage=self,
        plan=plan,
        user=user,
    )
Return the `PromptChallengeResponse` instance used to validate the submitted data
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:
    """Store the validated data in the plan context and report the stage as done."""
    plan_context = self.executor.plan.context
    # Create the prompt data dict on first use, then merge the new values
    plan_context.setdefault(PLAN_CONTEXT_PROMPT, {}).update(response.validated_data)
    return self.executor.stage_ok()
Callback when the challenge has the correct format