From 16a8a407105b72b5b35776546d9ff5936d0adea9 Mon Sep 17 00:00:00 2001
From: mski_iksm <9000000000000000091e@gmail.com>
Date: Thu, 9 Nov 2023 23:36:40 +0900
Subject: [PATCH] add

---
NOTE(review): this patch was reconstructed from a copy whose line breaks and
indentation had been collapsed. Leading whitespace on context lines is inferred
from Python syntax (TODO confirm against gokart/task.py); if a hunk does not
match, apply with `git apply --ignore-whitespace`.
Review changes folded into the hunks below:
  - task_complete_check_wrapper now returns the result of run_func instead of
    discarding it, and uses functools.wraps to keep run_func's metadata.
  - typo fix in two test comments (overrive -> override).
The `index` blob ids of gokart/task_complete_check.py and
test/test_task_on_kart.py predate these changes and are therefore stale.

 gokart/task.py                |  6 ++++
 gokart/task_complete_check.py | 19 +++++++++++++
 test/test_task_on_kart.py     | 60 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 85 insertions(+)
 create mode 100644 gokart/task_complete_check.py

diff --git a/gokart/task.py b/gokart/task.py
index ac5d7752..88416ba1 100644
--- a/gokart/task.py
+++ b/gokart/task.py
@@ -16,6 +16,7 @@
 from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
 from gokart.redis_lock import make_redis_params
 from gokart.target import TargetOnKart
+from gokart.task_complete_check import task_complete_check_wrapper
 
 logger = getLogger(__name__)
 
@@ -76,6 +77,8 @@ class TaskOnKart(luigi.Task):
         description='Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. \
          Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.',
         significant=False)
+    complete_check_at_run: bool = ExplicitBoolParameter(
+        default=False, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False)
 
     def __init__(self, *args, **kwargs):
         self._add_configuration(kwargs, 'TaskOnKart')
@@ -86,6 +89,9 @@ def __init__(self, *args, **kwargs):
         self._rerun_state = self.rerun
         self._lock_at_dump = True
 
+        if self.complete_check_at_run:
+            self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete)
+
     def output(self):
         return self.make_target()
 
diff --git a/gokart/task_complete_check.py b/gokart/task_complete_check.py
new file mode 100644
index 00000000..f7d7fd78
--- /dev/null
+++ b/gokart/task_complete_check.py
@@ -0,0 +1,19 @@
+import functools
+from typing import Callable
+
+
+def task_complete_check_wrapper(run_func: Callable, complete_check_func: Callable):
+    """Return a wrapper that calls ``run_func`` only while ``complete_check_func()`` is falsy.
+
+    The wrapper forwards its arguments to ``run_func`` and returns its result, so a
+    ``run`` that returns a value (for example a generator) keeps working when wrapped.
+    """
+
+    @functools.wraps(run_func)
+    def wrapper(*args, **kwargs):
+        if complete_check_func():
+            # Already complete: skip the run entirely.
+            return None
+        return run_func(*args, **kwargs)
+
+    return wrapper
diff --git a/test/test_task_on_kart.py b/test/test_task_on_kart.py
index e31c5d73..623c6874 100644
--- a/test/test_task_on_kart.py
+++ b/test/test_task_on_kart.py
@@ -561,5 +561,65 @@ def test_serialize_and_deserialize_default_values(self):
         self.assertDictEqual(task.to_str_params(), deserialized.to_str_params())
 
 
+class _DummyTaskWithNonCompleted(gokart.TaskOnKart):
+
+    def dump(self, obj):
+        # override dump() to do nothing.
+        pass
+
+    def run(self):
+        self.dump('hello')
+
+    def complete(self):
+        return False
+
+
+class _DummyTaskWithCompleted(gokart.TaskOnKart):
+
+    def dump(self, obj):
+        # override dump() to do nothing.
+        pass
+
+    def run(self):
+        self.dump('hello')
+
+    def complete(self):
+        return True
+
+
+class TestCompleteCheckAtRun(unittest.TestCase):
+    def test_run_when_complete_check_at_run_is_false_and_task_is_not_completed(self):
+        task = _DummyTaskWithNonCompleted(complete_check_at_run=False)
+        task.dump = MagicMock()
+        task.run()
+
+        # since run() is called, dump() should be called.
+        task.dump.assert_called_once()
+
+    def test_run_when_complete_check_at_run_is_false_and_task_is_completed(self):
+        task = _DummyTaskWithCompleted(complete_check_at_run=False)
+        task.dump = MagicMock()
+        task.run()
+
+        # even task is completed, since run() is called, dump() should be called.
+        task.dump.assert_called_once()
+
+    def test_run_when_complete_check_at_run_is_true_and_task_is_not_completed(self):
+        task = _DummyTaskWithNonCompleted(complete_check_at_run=True)
+        task.dump = MagicMock()
+        task.run()
+
+        # since task is not completed, when run() is called, dump() should be called.
+        task.dump.assert_called_once()
+
+    def test_run_when_complete_check_at_run_is_true_and_task_is_completed(self):
+        task = _DummyTaskWithCompleted(complete_check_at_run=True)
+        task.dump = MagicMock()
+        task.run()
+
+        # since task is completed, even when run() is called, dump() should not be called.
+        task.dump.assert_not_called()
+
+
 if __name__ == '__main__':
     unittest.main()