Skip to content

Commit

Permalink
add ids (filter) to seed_all method
Browse files Browse the repository at this point in the history
  • Loading branch information
suliman-99 committed Apr 21, 2024
1 parent d0bdec7 commit 15989df
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
22 changes: 13 additions & 9 deletions django_seeding/seeder_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ def register(cls, seeder):
""" Method and decorator to register the seeder-class in the seeders list to be seeded when the server is run """
if not issubclass(seeder, Seeder):
raise TypeError('Only subclasses of Seeder class can be registered with SeederRegistry.register')

if seeder().get_id() in [obj.get_id() for obj in cls.seeders]:
return

for cur_seeder in cls.seeders:
if cur_seeder._get_id() == seeder()._get_id():
return

cls.seeders.append(seeder())

Expand All @@ -43,24 +44,27 @@ def import_all(cls):
spec.loader.exec_module(module)

@classmethod
def seed_all(cls, debug=None):
def seed_all(cls, debug=None, ids=None):
"""
Method that call seed methods for all registered seeders
sort the seeders depending on the `priority` (less is applied earlier)
"""
seeders = cls.seeders
if ids is not None:
seeders = filter(lambda seeder: seeder._get_id() in ids, seeders)

if AppliedSeeder.objects.filter(id__in=[seeder._get_id() for seeder in cls.seeders]).count() != len(cls.seeders):
if AppliedSeeder.objects.filter(id__in=[seeder._get_id() for seeder in seeders]).count() != len(seeders):
BLUE_COLOR = "\033[94m"
WHITE_COLOR = "\033[0m"
print(BLUE_COLOR + "Running Seeders: " + WHITE_COLOR)

cls.seeders.sort(key=lambda seeder: seeder._get_priority())
for seeder in cls.seeders:
seeders.sort(key=lambda seeder: seeder._get_priority())
for seeder in seeders:
seeder._seed(debug=debug)

@classmethod
def import_all_then_seed_all(cls, debug=None):
def import_all_then_seed_all(cls, debug=None, ids=None):
"""
Note: the decorator `@SeederRegistry.register` will be applied when the file is imported
Expand All @@ -75,7 +79,7 @@ def import_all_then_seed_all(cls, debug=None):
cls.import_all()

# call the `seed_all()` method to apply all the registered seeders
cls.seed_all(debug=debug)
cls.seed_all(debug=debug, ids=ids)

@classmethod
def on_run(cls):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = 'django-seeding'
version = '1.1.2'
version = '1.2.1'
description = 'Simple Django Package that helps developer to seed data from files and codes into the database automatically'
readme = "README.md"
authors = [
Expand Down

0 comments on commit 15989df

Please sign in to comment.