Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TaskOnKart.fix_random_seed_value cannnot handle None #287

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ class TaskOnKart(luigi.Task):
description='If this is false, this task is not treated as a part of dependent tasks for the unique id.',
significant=False)
fix_random_seed_methods = luigi.ListParameter(default=['random.seed', 'numpy.random.seed'], description='Fix random seed method list.', significant=False)
fix_random_seed_value = luigi.IntParameter(default=None, description='Fix random seed method value.', significant=False)
FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER = -42497368
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just random

fix_random_seed_value = luigi.IntParameter(
default=FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER, description='Fix random seed method value.',
significant=False) # FIXME: should fix with OptionalIntParameter after newer luigi (https://github.com/spotify/luigi/pull/3079) will be released

redis_host = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False)
redis_port = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False)
Expand Down Expand Up @@ -389,7 +392,7 @@ def try_set_seed(methods: List[str], random_seed: int) -> List[str]:
return success_methods

def _get_random_seed(self):
if self.fix_random_seed_value:
if self.fix_random_seed_value and (not self.fix_random_seed_value == self.FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER):
return self.fix_random_seed_value
return int(self.make_unique_id(), 16) % (2**32 - 1) # maximum numpy.random.seed

Expand Down
5 changes: 5 additions & 0 deletions test/test_task_on_kart.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,11 @@ def test_is_task_on_kart(self):
self.assertEqual(False, gokart.TaskOnKart.is_task_on_kart(list()))
self.assertEqual(True, gokart.TaskOnKart.is_task_on_kart((gokart.TaskOnKart(), gokart.TaskOnKart())))

def test_serialize_and_deserialize_default_values(self):
task = gokart.TaskOnKart()
deserialized: gokart.TaskOnKart = luigi.task_register.load_task(None, task.get_task_family(), task.to_str_params())
self.assertDictEqual(task.to_str_params(), deserialized.to_str_params())


if __name__ == '__main__':
unittest.main()