authentik.lib.migrations

Migration helpers

 1"""Migration helpers"""
 2
 3from collections.abc import Iterable
 4from typing import TYPE_CHECKING
 5
 6from django.apps.registry import Apps
 7from django.db.backends.base.schema import BaseDatabaseSchemaEditor
 8
 9from authentik.events.utils import cleanse_dict, sanitize_dict
10
11if TYPE_CHECKING:
12    from authentik.events.models import EventAction
13
14
def fallback_names(app: str, model: str, field: str):
    """Factory function that checks all instances of `app`.`model` instance's `field`
    to prevent any duplicates.

    Returns a callable ``migrator(apps, schema_editor)`` suitable for a data
    migration. The first object seen with a given `field` value keeps it. Every
    later object with the same value is renamed to ``<value>_<n>``. Here ``n``
    starts at 2 and is the first suffix that no row in the database currently
    holds. Which duplicate keeps the original value depends on the queryset's
    iteration order.
    """

    def migrator(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
        db_alias = schema_editor.connection.alias

        klass = apps.get_model(app, model)
        # A set gives O(1) membership tests; a list made this loop quadratic.
        seen_names = set()
        separator = "_"
        for obj in klass.objects.using(db_alias).all():
            value = getattr(obj, field)
            if value not in seen_names:
                seen_names.add(value)
                continue
            suffix_index = 2
            # Ask the database rather than `seen_names`: the candidate must not
            # clash with rows not yet iterated, nor with rows renamed earlier.
            while (
                klass.objects.using(db_alias)
                .filter(**{field: f"{value}{separator}{suffix_index}"})
                .exists()
            ):
                suffix_index += 1
            setattr(obj, field, f"{value}{separator}{suffix_index}")
            # Write back to the same database the rows were read from.
            obj.save(using=db_alias)

    return migrator
42
43
def progress_bar(iterable: Iterable):
    """Yield the items of ``iterable`` while drawing a text progress bar on the terminal.

    Each item is yielded unchanged. After every item the bar is redrawn in place
    (each print starts with, and ends on, a carriage return). A final newline is
    printed once iteration completes. Inputs without ``len()``, such as
    generators, are first buffered into a list so the total is known. Empty
    inputs yield nothing and print nothing.

    Adapted from
    https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
    """

    prefix = "Writing: "
    suffix = " finished"
    decimals = 1
    length = 100
    fill = "█"
    print_end = "\r"

    try:
        total = len(iterable)
    except TypeError:
        # One-shot iterables have no len(); buffer them so the bar has a total.
        iterable = list(iterable)
        total = len(iterable)
    if total < 1:
        return

    def print_progress_bar(iteration: int):
        """Redraw the bar for ``iteration`` completed items out of ``total``."""
        percent = f"{100 * (iteration / total):.{decimals}f}"
        filled_length = length * iteration // total
        bar = fill * filled_length + "-" * (length - filled_length)
        print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end)

    # Draw the empty bar before the first item is consumed.
    print_progress_bar(0)
    for i, item in enumerate(iterable):
        yield item
        print_progress_bar(i + 1)
    # Move off the bar's line once finished.
    print()
74
75
def migration_event(
    apps: Apps,
    schema_editor: BaseDatabaseSchemaEditor,
    # Quoted: EventAction is imported only under TYPE_CHECKING, so an unquoted
    # annotation would raise NameError when this module is imported.
    action: "EventAction",
    **kwargs,
):
    """Create and save an ``authentik_events.Event`` from within a migration.

    :param apps: app registry used to look up the ``Event`` model (in a data
        migration, presumably the historical registry).
    :param schema_editor: schema editor whose connection alias selects the
        database the event is saved to.
    :param action: value stored in the event's ``action`` field.
    :param kwargs: extra data, passed through ``sanitize_dict`` and then
        ``cleanse_dict`` and stored as the event's ``context``.
    """
    db_alias = schema_editor.connection.alias
    Event = apps.get_model("authentik_events", "Event")
    event = Event(action=action, app="authentik", context=cleanse_dict(sanitize_dict(kwargs)))
    event.save(using=db_alias)
def fallback_names(app: str, model: str, field: str):
def fallback_names(app: str, model: str, field: str):
    """Factory function that checks all instances of `app`.`model` instance's `field`
    to prevent any duplicates.

    Returns a callable ``migrator(apps, schema_editor)`` suitable for a data
    migration. The first object seen with a given `field` value keeps it. Every
    later object with the same value is renamed to ``<value>_<n>``. Here ``n``
    starts at 2 and is the first suffix that no row in the database currently
    holds. Which duplicate keeps the original value depends on the queryset's
    iteration order.
    """

    def migrator(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
        db_alias = schema_editor.connection.alias

        klass = apps.get_model(app, model)
        # A set gives O(1) membership tests; a list made this loop quadratic.
        seen_names = set()
        separator = "_"
        for obj in klass.objects.using(db_alias).all():
            value = getattr(obj, field)
            if value not in seen_names:
                seen_names.add(value)
                continue
            suffix_index = 2
            # Ask the database rather than `seen_names`: the candidate must not
            # clash with rows not yet iterated, nor with rows renamed earlier.
            while (
                klass.objects.using(db_alias)
                .filter(**{field: f"{value}{separator}{suffix_index}"})
                .exists()
            ):
                suffix_index += 1
            setattr(obj, field, f"{value}{separator}{suffix_index}")
            # Write back to the same database the rows were read from.
            obj.save(using=db_alias)

    return migrator

Factory function that checks the `field` value of every `app`.`model` instance and renames duplicates (appending `_2`, `_3`, …) so that no two instances share a value.

def progress_bar(iterable: Iterable):
def progress_bar(iterable: Iterable):
    """Yield the items of ``iterable`` while drawing a text progress bar on the terminal.

    Each item is yielded unchanged. After every item the bar is redrawn in place
    (each print starts with, and ends on, a carriage return). A final newline is
    printed once iteration completes. Inputs without ``len()``, such as
    generators, are first buffered into a list so the total is known. Empty
    inputs yield nothing and print nothing.

    Adapted from
    https://stackoverflow.com/questions/3173320/text-progress-bar-in-the-console
    """

    prefix = "Writing: "
    suffix = " finished"
    decimals = 1
    length = 100
    fill = "█"
    print_end = "\r"

    try:
        total = len(iterable)
    except TypeError:
        # One-shot iterables have no len(); buffer them so the bar has a total.
        iterable = list(iterable)
        total = len(iterable)
    if total < 1:
        return

    def print_progress_bar(iteration: int):
        """Redraw the bar for ``iteration`` completed items out of ``total``."""
        percent = f"{100 * (iteration / total):.{decimals}f}"
        filled_length = length * iteration // total
        bar = fill * filled_length + "-" * (length - filled_length)
        print(f"\r{prefix} |{bar}| {percent}% {suffix}", end=print_end)

    # Draw the empty bar before the first item is consumed.
    print_progress_bar(0)
    for i, item in enumerate(iterable):
        yield item
        print_progress_bar(i + 1)
    # Move off the bar's line once finished.
    print()
def migration_event(apps: Apps, schema_editor: BaseDatabaseSchemaEditor, action: EventAction, **kwargs):
def migration_event(
    apps: Apps,
    schema_editor: BaseDatabaseSchemaEditor,
    # Quoted: EventAction is imported only under TYPE_CHECKING, so an unquoted
    # annotation would raise NameError when this module is imported.
    action: "EventAction",
    **kwargs,
):
    """Create and save an ``authentik_events.Event`` from within a migration.

    :param apps: app registry used to look up the ``Event`` model (in a data
        migration, presumably the historical registry).
    :param schema_editor: schema editor whose connection alias selects the
        database the event is saved to.
    :param action: value stored in the event's ``action`` field.
    :param kwargs: extra data, passed through ``sanitize_dict`` and then
        ``cleanse_dict`` and stored as the event's ``context``.
    """
    db_alias = schema_editor.connection.alias
    Event = apps.get_model("authentik_events", "Event")
    event = Event(action=action, app="authentik", context=cleanse_dict(sanitize_dict(kwargs)))
    event.save(using=db_alias)