Skip to content

Commit

Permalink
add
Browse files Browse the repository at this point in the history
  • Loading branch information
mski-iksm committed Nov 9, 2023
1 parent 96703a4 commit 16a8a40
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 0 deletions.
6 changes: 6 additions & 0 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.redis_lock import make_redis_params
from gokart.target import TargetOnKart
from gokart.task_complete_check import task_complete_check_wrapper

logger = getLogger(__name__)

Expand Down Expand Up @@ -76,6 +77,8 @@ class TaskOnKart(luigi.Task):
description='Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. \
Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.',
significant=False)
complete_check_at_run: bool = ExplicitBoolParameter(
default=False, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False)

def __init__(self, *args, **kwargs):
self._add_configuration(kwargs, 'TaskOnKart')
Expand All @@ -86,6 +89,9 @@ def __init__(self, *args, **kwargs):
self._rerun_state = self.rerun
self._lock_at_dump = True

if self.complete_check_at_run:
self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete)

def output(self):
return self.make_target()

Expand Down
9 changes: 9 additions & 0 deletions gokart/task_complete_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Callable


def task_complete_check_wrapper(run_func: Callable, complete_check_func: Callable):
def wrapper(*args, **kwargs):
if not complete_check_func():
run_func(*args, **kwargs)

return wrapper
60 changes: 60 additions & 0 deletions test/test_task_on_kart.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,5 +561,65 @@ def test_serialize_and_deserialize_default_values(self):
self.assertDictEqual(task.to_str_params(), deserialized.to_str_params())


class _DummyTaskWithNonCompleted(gokart.TaskOnKart):

def dump(self, obj):
# overrive dump() to do nothing.
pass

def run(self):
self.dump('hello')

def complete(self):
return False


class _DummyTaskWithCompleted(gokart.TaskOnKart):

def dump(self, obj):
# overrive dump() to do nothing.
pass

def run(self):
self.dump('hello')

def complete(self):
return True


class TestCompleteCheckAtRun(unittest.TestCase):
def test_run_when_complete_check_at_run_is_false_and_task_is_not_completed(self):
task = _DummyTaskWithNonCompleted(complete_check_at_run=False)
task.dump = MagicMock()
task.run()

# since run() is called, dump() should be called.
task.dump.assert_called_once()

def test_run_when_complete_check_at_run_is_false_and_task_is_completed(self):
task = _DummyTaskWithCompleted(complete_check_at_run=False)
task.dump = MagicMock()
task.run()

# even task is completed, since run() is called, dump() should be called.
task.dump.assert_called_once()

def test_run_when_complete_check_at_run_is_true_and_task_is_not_completed(self):
task = _DummyTaskWithNonCompleted(complete_check_at_run=True)
task.dump = MagicMock()
task.run()

# since task is not completed, when run() is called, dump() should be called.
task.dump.assert_called_once()

def test_run_when_complete_check_at_run_is_true_and_task_is_completed(self):
task = _DummyTaskWithCompleted(complete_check_at_run=True)
task.dump = MagicMock()
task.run()

# since task is completed, even when run() is called, dump() should not be called.
task.dump.assert_not_called()


if __name__ == '__main__':
unittest.main()

0 comments on commit 16a8a40

Please sign in to comment.