diff --git a/doc/configuration.rst b/doc/configuration.rst index 0ceaaabad7..ed9cc334e7 100644 --- a/doc/configuration.rst +++ b/doc/configuration.rst @@ -442,6 +442,7 @@ marker-table Table in which to store status of table updates. This table will be created if it doesn't already exist. Defaults to "table_updates". +.. _resources-config: [resources] ----------- @@ -763,7 +764,7 @@ Luigi also supports defining retry-policy per task. ... -If none of retry-policy fields is defined per task, the field value will be **default** value which is defined in luigi config file. +If none of retry-policy fields is defined per task, the field value will be **default** value which is defined in luigi config file. To make luigi stick to the given retry-policy, be sure you run the luigi worker with the ``keep_alive`` config. Please check ``keep_alive`` config in :ref:`worker-config` section. @@ -774,4 +775,4 @@ The fields below are in retry-policy and they can be defined per task. * retry_count * disable_hard_timeout -* disable_window_seconds \ No newline at end of file +* disable_window_seconds diff --git a/doc/luigi_patterns.rst b/doc/luigi_patterns.rst index 0d98321945..fc53a57c2e 100644 --- a/doc/luigi_patterns.rst +++ b/doc/luigi_patterns.rst @@ -118,6 +118,91 @@ can have implications in how the scheduler and visualizer handle task instances. luigi RangeDaily --of MyTask --start 2014-10-31 --MyTask-my-param 123 +.. _batch_method: + +Batching multiple parameter values into a single run +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Sometimes it'll be faster to run multiple jobs together as a single +batch rather than running them each individually. When this is the case, +you can mark some parameters with a batch_method in their constructor +to tell the worker how to combine multiple values. One common way to do +this is by simply running the maximum value. This is good for tasks that +overwrite older data when a newer one runs. 
You accomplish this by +setting the batch_method to max, like so: + +.. code-block:: python + + class A(luigi.Task): + date = luigi.DateParameter(batch_method=max) + +What's exciting about this is that if you send multiple As to the +scheduler, it can combine them and return one. So if +``A(date=2016-07-28)``, ``A(date=2016-07-29)`` and +``A(date=2016-07-30)`` are all ready to run, you will start running +``A(date=2016-07-30)``. While this is running, the scheduler will show +``A(date=2016-07-28)``, ``A(date=2016-07-29)`` as batch running while +``A(date=2016-07-30)`` is running. When ``A(date=2016-07-30)`` is done +running and becomes FAILED or DONE, the other two tasks will be updated +to the same status. + +If you want to limit how big a batch can get, simply set max_batch_size. +So if you have + +.. code-block:: python + + class A(luigi.Task): + date = luigi.DateParameter(batch_method=max) + + max_batch_size = 10 + +then the scheduler will batch at most 10 jobs together. You probably do +not want to do this with the max batch method, but it can be helpful if +you use other methods. You can use any method that takes a list of +parameter values and returns a single parameter value. + +If you have two max batch parameters, you'll get the max values for both +of them. If you have parameters that don't have a batch method, they'll +be aggregated separately. So if you have a class like + +.. code-block:: python + + class A(luigi.Task): + p1 = luigi.IntParameter(batch_method=max) + p2 = luigi.IntParameter(batch_method=max) + p3 = luigi.IntParameter() + +and you create tasks ``A(p1=1, p2=2, p3=0)``, ``A(p1=2, p2=3, p3=0)``, +``A(p1=3, p2=4, p3=1)``, you'll get them batched as +``A(p1=2, p2=3, p3=0)`` and ``A(p1=3, p2=4, p3=1)``. + +Note that batched tasks do not take up :ref:`resources-config`, only the +task that ends up running will use resources. 
The scheduler only checks +that there are sufficient resources for each task individually before +batching them all together. + +Tasks that regularly overwrite the same data source +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +If you are overwriting the same data source with every run, you'll +need to ensure that two batches can't run at the same time. You can do +this pretty easily by setting batch_method to max and setting a unique +resource: + +.. code-block:: python + + class A(luigi.Task): + date = luigi.DateParameter(batch_method=max) + + resources = {'overwrite_resource': 1} + +Now if you have multiple tasks such as ``A(date=2016-06-01)``, +``A(date=2016-06-02)``, ``A(date=2016-06-03)``, the scheduler will just +tell you to run the highest available one and mark the lower ones as +batch_running. Using a unique resource will prevent multiple tasks from +writing to the same location at the same time if a new one becomes +available while others are running. + Monitoring task pipelines ~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/luigi/parameter.py b/luigi/parameter.py index 1ef13b68e5..fa60bea8ca 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -117,7 +117,7 @@ def run(self): _counter = 0 # non-atomically increasing counter used for ordering parameters. def __init__(self, default=_no_value, is_global=False, significant=True, description=None, - config_path=None, positional=True, always_in_help=False): + config_path=None, positional=True, always_in_help=False, batch_method=None): """ :param default: the default value for this parameter. This should match the type of the Parameter, i.e. ``datetime.date`` for ``DateParameter`` or ``int`` for @@ -140,8 +140,13 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip ``positional=False`` for abstract base classes and similar cases. :param bool always_in_help: For the --help option in the command line parsing. Set true to always show in --help. 
+ :param function(iterable[A])->A batch_method: Method to combine an iterable of parsed + parameter values into a single value. Used + when receiving batched parameter lists from + the scheduler. See :ref:`batch_method` """ self._default = default + self._batch_method = batch_method if is_global: warnings.warn("is_global support is removed. Assuming positional=False", DeprecationWarning, @@ -211,6 +216,9 @@ def task_value(self, task_name, param_name): else: return self.normalize(value) + def _is_batchable(self): + return self._batch_method is not None + def parse(self, x): """ Parse an individual value from the input. @@ -223,6 +231,23 @@ def parse(self, x): """ return x # default impl + def _parse_list(self, xs): + """ + Parse a list of values from the scheduler. + + Only possible if this is_batchable() is True. This will combine the list into a single + parameter value using batch method. This should never need to be overridden. + + :param xs: list of values to parse and combine + :return: the combined parsed values + """ + if not self._is_batchable(): + raise NotImplementedError('No batch method found') + elif not xs: + raise ValueError('Empty parameter list passed to parse_list') + else: + return self._batch_method(map(self.parse, xs)) + def serialize(self, x): """ Opposite of :py:meth:`parse`. 
diff --git a/luigi/scheduler.py b/luigi/scheduler.py index 7755a0c3e8..8938f565df 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -29,6 +29,7 @@ except ImportError: import pickle import functools +import hashlib import itertools import logging import os @@ -41,7 +42,8 @@ from luigi import notifications from luigi import parameter from luigi import task_history as history -from luigi.task_status import DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN +from luigi.task_status import DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN, \ + BATCH_RUNNING from luigi.task import Config logger = logging.getLogger(__name__) @@ -224,10 +226,20 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='', self.status_message = status_message self.scheduler_disable_time = None self.runnable = False + self.batchable = False + self.batch_id = None def __repr__(self): return "Task(%r)" % vars(self) + # TODO(2017-08-10) replace this function with direct calls to batchable + # this only exists for backward compatibility + def is_batchable(self): + try: + return self.batchable + except AttributeError: + return False + def add_failure(self): self.failures.add_failure() @@ -324,12 +336,15 @@ def __init__(self, state_path): self._tasks = {} # map from id to a Task object self._status_tasks = collections.defaultdict(dict) self._active_workers = {} # map from id to a Worker object + self._task_batchers = {} def get_state(self): - return self._tasks, self._active_workers + return self._tasks, self._active_workers, self._task_batchers def set_state(self, state): - self._tasks, self._active_workers = state + self._tasks, self._active_workers = state[:2] + if len(state) >= 3: + self._task_batchers = state[2] def dump(self): try: @@ -366,6 +381,13 @@ def get_active_tasks(self, status=None): for task in six.itervalues(self._tasks): yield task + def get_batch_running_tasks(self, batch_id): + assert batch_id is not None + return [ + task for 
task in self.get_active_tasks(BATCH_RUNNING) + if task.batch_id == batch_id + ] + def get_running_tasks(self): return six.itervalues(self._status_tasks[RUNNING]) @@ -373,6 +395,13 @@ def get_pending_tasks(self): return itertools.chain.from_iterable(six.itervalues(self._status_tasks[status]) for status in [PENDING, RUNNING]) + def set_batcher(self, worker_id, family, batcher_args, max_batch_size): + self._task_batchers.setdefault(worker_id, {}) + self._task_batchers[worker_id][family] = (batcher_args, max_batch_size) + + def get_batcher(self, worker_id, family): + return self._task_batchers.get(worker_id, {}).get(family, (None, 1)) + def num_pending_tasks(self): """ Return how many tasks are PENDING + RUNNING. O(1). @@ -397,13 +426,20 @@ def re_enable(self, task, config=None): self.set_status(task, FAILED, config) task.failures.clear() + def set_batch_running(self, task, batch_id, worker_id): + self.set_status(task, BATCH_RUNNING) + task.batch_id = batch_id + task.worker_running = worker_id + def set_status(self, task, new_status, config=None): if new_status == FAILED: assert config is not None - if new_status == DISABLED and task.status == RUNNING: + if new_status == DISABLED and task.status in (RUNNING, BATCH_RUNNING): return + remove_on_failure = task.batch_id is not None and not task.batchable + if task.status == DISABLED: if new_status == DONE: self.re_enable(task) @@ -412,6 +448,12 @@ def set_status(self, task, new_status, config=None): elif task.scheduler_disable_time is not None and new_status != DISABLED: return + if task.status == RUNNING and task.batch_id is not None: + for batch_task in self.get_batch_running_tasks(task.batch_id): + self.set_status(batch_task, new_status, config) + batch_task.batch_id = None + task.batch_id = None + if new_status == FAILED and task.status != DISABLED: task.add_failure() if task.has_excessive_failures(): @@ -435,6 +477,11 @@ def set_status(self, task, new_status, config=None): task.status = new_status task.updated = 
time.time() + if new_status == FAILED: + task.retry = time.time() + config.retry_delay + if remove_on_failure: + task.remove = time.time() + def fail_dead_worker_task(self, task, config, assistants): # If a running worker disconnects, tag all its jobs as FAILED and subject it to the same retry logic if task.status == RUNNING and task.worker_running and task.worker_running not in task.stakeholders | assistants: @@ -464,7 +511,7 @@ def update_status(self, task, config): self.set_status(task, PENDING, config) def may_prune(self, task): - return task.remove and time.time() > task.remove + return task.remove and time.time() >= task.remove def inactivate_tasks(self, delete_tasks): # The terminology is a bit confusing: we used to "delete" tasks when they became inactive, @@ -595,12 +642,16 @@ def _update_priority(self, task, prio, worker): if t is not None and prio > t.priority: self._update_priority(t, prio, worker) + @rpc_method() + def add_task_batcher(self, worker, task_family, batched_args, max_batch_size=float('inf')): + self._state.set_batcher(worker, task_family, batched_args, max_batch_size) + @rpc_method() def add_task(self, task_id=None, status=PENDING, runnable=True, deps=None, new_deps=None, expl=None, resources=None, priority=0, family='', module=None, params=None, - assistant=False, tracking_url=None, worker=None, - retry_policy_dict={}, **kwargs): + assistant=False, tracking_url=None, worker=None, batchable=None, + batch_id=None, retry_policy_dict={}, **kwargs): """ * add task identified by task_id if it doesn't exist * if deps is not None, update dependency list @@ -634,16 +685,30 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, if not task.params: task.params = _get_default(params, {}) + if batch_id is not None: + task.batch_id = batch_id + if status == RUNNING and not task.worker_running: + task.worker_running = worker_id + if tracking_url is not None or task.status != RUNNING: task.tracking_url = tracking_url + if task.batch_id is not 
None: + for batch_task in self._state.get_batch_running_tasks(task.batch_id): + batch_task.tracking_url = tracking_url + + if batchable is not None: + task.batchable = batchable if task.remove is not None: task.remove = None # unmark task for removal so it isn't removed after being added if expl is not None: task.expl = expl + if task.batch_id is not None: + for batch_task in self._state.get_batch_running_tasks(task.batch_id): + batch_task.expl = expl - if not (task.status == RUNNING and status == PENDING) or new_deps: + if not (task.status in (RUNNING, BATCH_RUNNING) and status == PENDING) or new_deps: # don't allow re-scheduling of task while it is running, it must either fail or succeed first if status == PENDING or status != task.status: # Update the DB only if there was a acctual change, to prevent noise. @@ -651,8 +716,6 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, # (so checking for status != task.status woule lie) self._update_task_history(task, status) self._state.set_status(task, PENDING if status == SUSPENDED else status, self._config) - if status == FAILED: - task.retry = self._retry_time(task, self._config) if deps is not None: task.deps = set(deps) @@ -739,8 +802,18 @@ def _schedulable(self, task): return False return True - def _retry_time(self, task, config): - return time.time() + config.retry_delay + def _reset_orphaned_batch_running_tasks(self, worker_id): + running_batch_ids = { + task.batch_id + for task in self._state.get_running_tasks() + if task.worker_running == worker_id + } + orphaned_tasks = [ + task for task in self._state.get_active_tasks(BATCH_RUNNING) + if task.worker_running == worker_id and task.batch_id not in running_batch_ids + ] + for task in orphaned_tasks: + self._state.set_status(task, PENDING) @rpc_method(allow_null=False) def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, **kwargs): @@ -767,6 +840,7 @@ def get_work(self, host=None, assistant=False, current_tasks=None, 
worker=None, if assistant: self.add_worker(worker_id, [('assistant', assistant)]) + batched_params, unbatched_params, batched_tasks, max_batch_size = None, None, [], 1 best_task = None if current_tasks is not None: ct_set = set(current_tasks) @@ -774,6 +848,10 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, if task.worker_running == worker_id and task.id not in ct_set: best_task = task + if current_tasks is not None: + # batch running tasks that weren't claimed since the last get_work go back in the pool + self._reset_orphaned_batch_running_tasks(worker_id) + locally_pending_tasks = 0 running_tasks = [] upstream_table = {} @@ -814,6 +892,12 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, if len(task.workers) == 1 and not assistant: n_unique_pending += 1 + if (best_task and batched_params and task.family == best_task.family and + len(batched_tasks) < max_batch_size and task.is_batchable() and all( + task.params.get(name) == value for name, value in unbatched_params.items())): + for name, params in batched_params.items(): + params.append(task.params.get(name)) + batched_tasks.append(task) if best_task: continue @@ -825,6 +909,20 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, if self._schedulable(task) and self._has_resources(task.resources, greedy_resources): if in_workers and self._has_resources(task.resources, used_resources): best_task = task + batch_param_names, max_batch_size = self._state.get_batcher( + worker_id, task.family) + if batch_param_names and task.is_batchable(): + try: + batched_params = { + name: [task.params[name]] for name in batch_param_names + } + unbatched_params = { + name: value for name, value in task.params.items() + if name not in batched_params + } + batched_tasks.append(task) + except KeyError: + batched_params, unbatched_params = None, None else: workers = itertools.chain(task.workers, [worker_id]) if assistant else task.workers 
for task_worker in workers: @@ -843,7 +941,23 @@ def get_work(self, host=None, assistant=False, current_tasks=None, worker=None, 'task_id': None, 'n_unique_pending': n_unique_pending} - if best_task: + if len(batched_tasks) > 1: + batch_string = '|'.join(task.id for task in batched_tasks) + batch_id = hashlib.md5(batch_string.encode('utf-8')).hexdigest() + for task in batched_tasks: + self._state.set_batch_running(task, batch_id, worker_id) + + combined_params = best_task.params.copy() + combined_params.update(batched_params) + + reply['task_id'] = None + reply['task_family'] = best_task.family + reply['task_module'] = getattr(best_task, 'module', None) + reply['task_params'] = combined_params + reply['batch_id'] = batch_id + reply['batch_task_ids'] = [task.id for task in batched_tasks] + + elif best_task: self._state.set_status(best_task, RUNNING, self._config) best_task.worker_running = worker_id best_task.time_running = time.time() @@ -1143,6 +1257,9 @@ def set_task_status_message(self, task_id, status_message): if self._state.has_task(task_id): task = self._state.get_task(task_id) task.status_message = status_message + if task.status == RUNNING and task.batch_id is not None: + for batch_task in self._state.get_batch_running_tasks(task.batch_id): + batch_task.status_message = status_message @rpc_method() def get_task_status_message(self, task_id): diff --git a/luigi/task.py b/luigi/task.py index cb085ac398..dca13e6684 100644 --- a/luigi/task.py +++ b/luigi/task.py @@ -146,6 +146,17 @@ class MyTask(luigi.Task): #: Only works when using multiple workers. worker_timeout = None + #: Maximum number of tasks to run together as a batch. Infinite by default + max_batch_size = float('inf') + + @property + def batchable(self): + """ + True if this instance can be run as part of a batch. 
By default, True + if it has any batched parameters + """ + return bool(self.batch_param_names()) + @property def retry_count(self): """ @@ -242,6 +253,10 @@ def get_params(cls): params.sort(key=lambda t: t[1]._counter) return params + @classmethod + def batch_param_names(cls): + return [name for name, p in cls.get_params() if p._is_batchable()] + @classmethod def get_param_names(cls, include_significant=False): return [name for name, p in cls.get_params() if include_significant or p.significant] @@ -332,7 +347,11 @@ def from_str_params(cls, params_str): kwargs = {} for param_name, param in cls.get_params(): if param_name in params_str: - kwargs[param_name] = param.parse(params_str[param_name]) + param_str = params_str[param_name] + if isinstance(param_str, list): + kwargs[param_name] = param._parse_list(param_str) + else: + kwargs[param_name] = param.parse(param_str) return cls(**kwargs) diff --git a/luigi/task_status.py b/luigi/task_status.py index 9aba081856..e0223bd5ae 100644 --- a/luigi/task_status.py +++ b/luigi/task_status.py @@ -22,6 +22,7 @@ FAILED = 'FAILED' DONE = 'DONE' RUNNING = 'RUNNING' +BATCH_RUNNING = 'BATCH_RUNNING' SUSPENDED = 'SUSPENDED' # Only kept for backward compatibility with old clients UNKNOWN = 'UNKNOWN' DISABLED = 'DISABLED' diff --git a/luigi/worker.py b/luigi/worker.py index 8641fdfc80..b0cf6f72be 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -389,11 +389,14 @@ def __init__(self, scheduler=None, worker_id=None, worker_processes=1, assistant self.host = socket.gethostname() self._scheduled_tasks = {} self._suspended_tasks = {} + self._batch_running_tasks = {} + self._batch_families_sent = set() self._first_task = None self.add_succeeded = True self.run_succeeded = True + self.unfulfilled_counts = collections.defaultdict(int) # note that ``signal.signal(signal.SIGUSR1, fn)`` only works inside the main execution thread, which is why we @@ -428,6 +431,11 @@ def _add_task(self, *args, **kwargs): if task: msg = (task, status, 
runnable) self._add_task_history.append(msg) + + if task_id in self._batch_running_tasks: + for batch_task in self._batch_running_tasks.pop(task_id): + self._add_task_history.append((batch_task, status, True)) + self._scheduler.add_task(*args, **kwargs) logger.info('Informed scheduler that task %s has status %s', task_id, status) @@ -572,6 +580,20 @@ def add(self, task, multiprocess=False): pool.join() return self.add_succeeded + def _add_task_batcher(self, task): + family = task.task_family + if family not in self._batch_families_sent: + task_class = type(task) + batch_param_names = task_class.batch_param_names() + if batch_param_names: + self._scheduler.add_task_batcher( + worker=self._id, + task_family=family, + batched_args=batch_param_names, + max_batch_size=task.max_batch_size, + ) + self._batch_families_sent.add(family) + def _add(self, task, is_complete): if self._config.task_limit is not None and len(self._scheduled_tasks) >= self._config.task_limit: logger.warning('Will not run %s or any dependencies due to exceeded task-limit of %d', task, self._config.task_limit) @@ -617,6 +639,7 @@ def _add(self, task, is_complete): else: try: deps = task.deps() + self._add_task_batcher(task) except Exception as ex: formatted_traceback = traceback.format_exc() self.add_succeeded = False @@ -642,14 +665,20 @@ def _add(self, task, is_complete): deps = [d.task_id for d in deps] self._scheduled_tasks[task.task_id] = task - self._add_task(worker=self._id, task_id=task.task_id, status=status, - deps=deps, runnable=runnable, priority=task.priority, - resources=task.process_resources(), - params=task.to_str_params(), - family=task.task_family, - module=task.task_module, - retry_policy_dict=_get_retry_policy_dict(task), - ) + self._add_task( + worker=self._id, + task_id=task.task_id, + status=status, + deps=deps, + runnable=runnable, + priority=task.priority, + resources=task.process_resources(), + params=task.to_str_params(), + family=task.task_family, + 
module=task.task_module, + batchable=task.batchable, + retry_policy_dict=_get_retry_policy_dict(task), + ) def _validate_dependency(self, dependency): if isinstance(dependency, Target): @@ -678,6 +707,26 @@ def _log_remote_tasks(self, running_tasks, n_pending_tasks, n_unique_pending): if n_unique_pending: logger.debug("There are %i pending tasks unique to this worker", n_unique_pending) + def _get_work_task_id(self, get_work_response): + if get_work_response['task_id'] is not None: + return get_work_response['task_id'] + elif 'batch_id' in get_work_response: + task = load_task( + module=get_work_response.get('task_module'), + task_name=get_work_response['task_family'], + params_str=get_work_response['task_params'], + ) + self._scheduler.add_task( + worker=self._id, + task_id=task.task_id, + module=get_work_response.get('task_module'), + family=get_work_response['task_family'], + params=task.to_str_params(), + status=RUNNING, + batch_id=get_work_response['batch_id'], + ) + return task.task_id + def _get_work(self): if self._stop_requesting_work: return None, 0, 0, 0 @@ -689,14 +738,14 @@ def _get_work(self): current_tasks=list(self._running_tasks.keys()), ) n_pending_tasks = r['n_pending_tasks'] - task_id = r['task_id'] running_tasks = r['running_tasks'] n_unique_pending = r['n_unique_pending'] + task_id = self._get_work_task_id(r) - self._get_work_response_history.append(dict( - task_id=task_id, - running_tasks=running_tasks, - )) + self._get_work_response_history.append({ + 'task_id': task_id, + 'running_tasks': running_tasks, + }) if task_id is not None and task_id not in self._scheduled_tasks: logger.info('Did not schedule %s, will load it dynamically', task_id) @@ -718,6 +767,11 @@ def _get_work(self): task_id = None self.run_succeeded = False + if task_id is not None and 'batch_task_ids' in r: + batch_tasks = filter(None, [ + self._scheduled_tasks.get(batch_id) for batch_id in r['batch_task_ids']]) + self._batch_running_tasks[task_id] = batch_tasks + return 
task_id, running_tasks, n_pending_tasks, n_unique_pending def _run_task(self, task_id): diff --git a/test/execution_summary_test.py b/test/execution_summary_test.py index 47770a2ba8..eb330c47b2 100644 --- a/test/execution_summary_test.py +++ b/test/execution_summary_test.py @@ -95,6 +95,59 @@ def requires(self): for i, line in enumerate(result): self.assertEqual(line, expected[i]) + def test_batch_complete(self): + ran_tasks = set() + + class MaxBatchTask(luigi.Task): + param = luigi.IntParameter(batch_method=max) + + def run(self): + ran_tasks.add(self.param) + + def complete(self): + return any(self.param <= ran_param for ran_param in ran_tasks) + + class MaxBatches(luigi.WrapperTask): + def requires(self): + return map(MaxBatchTask, range(5)) + + self.run_task(MaxBatches()) + d = self.summary_dict() + expected_completed = { + MaxBatchTask(0), + MaxBatchTask(1), + MaxBatchTask(2), + MaxBatchTask(3), + MaxBatchTask(4), + MaxBatches(), + } + self.assertEqual(expected_completed, d['completed']) + + def test_batch_fail(self): + class MaxBatchFailTask(luigi.Task): + param = luigi.IntParameter(batch_method=max) + + def run(self): + assert self.param < 4 + + def complete(self): + return False + + class MaxBatches(luigi.WrapperTask): + def requires(self): + return map(MaxBatchFailTask, range(5)) + + self.run_task(MaxBatches()) + d = self.summary_dict() + expected_failed = { + MaxBatchFailTask(0), + MaxBatchFailTask(1), + MaxBatchFailTask(2), + MaxBatchFailTask(3), + MaxBatchFailTask(4), + } + self.assertEqual(expected_failed, d['failed']) + def test_check_complete_error(self): class Bar(luigi.Task): def run(self): diff --git a/test/parameter_test.py b/test/parameter_test.py index 09dd7e9630..31989fbd11 100644 --- a/test/parameter_test.py +++ b/test/parameter_test.py @@ -285,6 +285,28 @@ def test_tuple_serialize_parse(self): b_tuple = ((1, 2), (3, 4)) self.assertEqual(b_tuple, a.parse(a.serialize(b_tuple))) + def test_parse_list_without_batch_method(self): + param = 
luigi.Parameter() + for xs in [], ['x'], ['x', 'y']: + self.assertRaises(NotImplementedError, param._parse_list, xs) + + def test_parse_empty_list_raises_value_error(self): + for batch_method in (max, min, tuple, ','.join): + param = luigi.Parameter(batch_method=batch_method) + self.assertRaises(ValueError, param._parse_list, []) + + def test_parse_int_list_max(self): + param = luigi.IntParameter(batch_method=max) + self.assertEqual(17, param._parse_list(['7', '17', '5'])) + + def test_parse_string_list_max(self): + param = luigi.Parameter(batch_method=max) + self.assertEqual('7', param._parse_list(['7', '17', '5'])) + + def test_parse_list_as_tuple(self): + param = luigi.IntParameter(batch_method=tuple) + self.assertEqual((7, 17, 5), param._parse_list(['7', '17', '5'])) + class TestParametersHashability(LuigiTestCase): def test_date(self): diff --git a/test/scheduler_api_test.py b/test/scheduler_api_test.py index 42f7562795..ea1193eef9 100644 --- a/test/scheduler_api_test.py +++ b/test/scheduler_api_test.py @@ -15,12 +15,13 @@ # limitations under the License. 
# +import itertools import time from helpers import unittest from nose.plugins.attrib import attr import luigi.notifications from luigi.scheduler import DISABLED, DONE, FAILED, PENDING, \ - UNKNOWN, RUNNING, Scheduler + UNKNOWN, RUNNING, BATCH_RUNNING, Scheduler luigi.notifications.DEBUG = True WORKER = 'myworker' @@ -155,6 +156,351 @@ def test_disconnect_running(self): self.assertEqual(self.sch.get_work(worker='Y')['task_id'], 'A') + + def test_get_work_single_batch_item(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task( + worker=WORKER, task_id='A_a_1', family='A', params={'a': '1'}, batchable=True) + + response = self.sch.get_work(worker=WORKER) + self.assertEqual('A_a_1', response['task_id']) + + param_values = response['task_params'].values() + self.assertFalse(any(isinstance(param, list) for param in param_values)) + + def test_get_work_multiple_batch_items(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task( + worker=WORKER, task_id='A_a_1', family='A', params={'a': '1'}, batchable=True) + self.sch.add_task( + worker=WORKER, task_id='A_a_2', family='A', params={'a': '2'}, batchable=True) + self.sch.add_task( + worker=WORKER, task_id='A_a_3', family='A', params={'a': '3'}, batchable=True) + + response = self.sch.get_work(worker=WORKER) + self.assertIsNone(response['task_id']) + self.assertEqual({'a': ['1', '2', '3']}, response['task_params']) + self.assertEqual('A', response['task_family']) + + def test_get_work_with_batch_items_with_resources(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task( + worker=WORKER, task_id='A_a_1', family='A', params={'a': '1'}, batchable=True, + resources={'r1': 1}) + self.sch.add_task( + worker=WORKER, task_id='A_a_2', family='A', params={'a': '2'}, batchable=True, + resources={'r1': 1}) + self.sch.add_task( + worker=WORKER, task_id='A_a_3', family='A', 
params={'a': '3'}, batchable=True, + resources={'r1': 1}) + + response = self.sch.get_work(worker=WORKER) + self.assertIsNone(response['task_id']) + self.assertEqual({'a': ['1', '2', '3']}, response['task_params']) + self.assertEqual('A', response['task_family']) + + def test_get_work_limited_batch_size(self): + self.sch.add_task_batcher( + worker=WORKER, task_family='A', batched_args=['a'], max_batch_size=2) + self.sch.add_task( + worker=WORKER, task_id='A_a_1', family='A', params={'a': '1'}, batchable=True, + priority=1) + self.sch.add_task( + worker=WORKER, task_id='A_a_2', family='A', params={'a': '2'}, batchable=True) + self.sch.add_task( + worker=WORKER, task_id='A_a_3', family='A', params={'a': '3'}, batchable=True, + priority=2) + + response = self.sch.get_work(worker=WORKER) + self.assertIsNone(response['task_id']) + self.assertEqual({'a': ['3', '1']}, response['task_params']) + self.assertEqual('A', response['task_family']) + + response2 = self.sch.get_work(worker=WORKER) + self.assertEqual('A_a_2', response2['task_id']) + + def test_get_work_do_not_batch_non_batchable_item(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task( + worker=WORKER, task_id='A_a_1', family='A', params={'a': '1'}, batchable=True, + priority=1) + self.sch.add_task( + worker=WORKER, task_id='A_a_2', family='A', params={'a': '2'}, batchable=True) + self.sch.add_task( + worker=WORKER, task_id='A_a_3', family='A', params={'a': '3'}, batchable=False, + priority=2) + + response = self.sch.get_work(worker=WORKER) + self.assertEqual('A_a_3', response['task_id']) + + response2 = self.sch.get_work(worker=WORKER) + self.assertIsNone(response2['task_id']) + self.assertEqual({'a': ['1', '2']}, response2['task_params']) + self.assertEqual('A', response2['task_family']) + + def test_get_work_group_on_non_batch_params(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['b']) + for a, b, c in itertools.product((1, 
2), repeat=3): + self.sch.add_task( + worker=WORKER, task_id='A_%i_%i_%i' % (a, b, c), family='A', + params={'a': str(a), 'b': str(b), 'c': str(c)}, batchable=True, + priority=9 * a + 3 * c + b) + + for a, c in [('2', '2'), ('2', '1'), ('1', '2'), ('1', '1')]: + response = self.sch.get_work(worker=WORKER) + self.assertIsNone(response['task_id']) + self.assertEqual({'a': a, 'b': ['2', '1'], 'c': c}, response['task_params']) + self.assertEqual('A', response['task_family']) + + def test_get_work_multiple_batched_params(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a', 'b']) + self.sch.add_task( + worker=WORKER, task_id='A_1_1', family='A', params={'a': '1', 'b': '1'}, priority=1, + batchable=True) + self.sch.add_task( + worker=WORKER, task_id='A_1_2', family='A', params={'a': '1', 'b': '2'}, priority=2, + batchable=True) + self.sch.add_task( + worker=WORKER, task_id='A_2_1', family='A', params={'a': '2', 'b': '1'}, priority=3, + batchable=True) + self.sch.add_task( + worker=WORKER, task_id='A_2_2', family='A', params={'a': '2', 'b': '2'}, priority=4, + batchable=True) + + response = self.sch.get_work(worker=WORKER) + self.assertIsNone(response['task_id']) + + expected_params = { + 'a': ['2', '2', '1', '1'], + 'b': ['2', '1', '2', '1'], + } + self.assertEqual(expected_params, response['task_params']) + + def test_get_work_with_unbatched_worker_on_batched_task(self): + self.sch.add_task_batcher(worker='batcher', task_family='A', batched_args=['a']) + for i in range(5): + self.sch.add_task( + worker=WORKER, task_id='A_%i' % i, family='A', params={'a': str(i)}, priority=i, + batchable=False) + self.sch.add_task( + worker='batcher', task_id='A_%i' % i, family='A', params={'a': str(i)}, priority=i, + batchable=True) + self.assertEqual('A_4', self.sch.get_work(worker=WORKER)['task_id']) + batch_response = self.sch.get_work(worker='batcher') + self.assertIsNone(batch_response['task_id']) + self.assertEqual({'a': ['3', '2', '1', '0']}, 
batch_response['task_params']) + + def test_batched_tasks_become_batch_running(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': 1}, batchable=True) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': 2}, batchable=True) + self.sch.get_work(worker=WORKER) + self.assertEqual({'A_1', 'A_2'}, set(self.sch.task_list('BATCH_RUNNING', '').keys())) + + def test_set_batch_runner_new_task(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, + batchable=True) + response = self.sch.get_work(worker=WORKER) + batch_id = response['batch_id'] + self.sch.add_task( + worker=WORKER, task_id='A_1_2', task_family='A', params={'a': '1,2'}, + batch_id=batch_id, status='RUNNING') + self.assertEqual({'A_1', 'A_2'}, set(self.sch.task_list('BATCH_RUNNING', '').keys())) + self.assertEqual({'A_1_2'}, set(self.sch.task_list('RUNNING', '').keys())) + + self.sch.add_task(worker=WORKER, task_id='A_1_2', status=DONE) + self.assertEqual({'A_1', 'A_2', 'A_1_2'}, set(self.sch.task_list(DONE, '').keys())) + + def test_set_batch_runner_max(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, + batchable=True) + response = self.sch.get_work(worker=WORKER) + batch_id = response['batch_id'] + self.sch.add_task( + worker=WORKER, task_id='A_2', task_family='A', params={'a': '2'}, + batch_id=batch_id, status='RUNNING') + self.assertEqual({'A_1'}, set(self.sch.task_list('BATCH_RUNNING', '').keys())) + self.assertEqual({'A_2'}, 
set(self.sch.task_list('RUNNING', '').keys())) + + self.sch.add_task(worker=WORKER, task_id='A_2', status=DONE) + self.assertEqual({'A_1', 'A_2'}, set(self.sch.task_list(DONE, '').keys())) + + def _start_simple_batch(self, use_max=False): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, + batchable=True) + response = self.sch.get_work(worker=WORKER) + batch_id = response['batch_id'] + task_id, params = ('A_2', {'a': '2'}) if use_max else ('A_1_2', {'a': '1,2'}) + self.sch.add_task( + worker=WORKER, task_id=task_id, task_family='A', params=params, batch_id=batch_id, + status='RUNNING') + + def test_batch_fail(self): + self._start_simple_batch() + self.sch.add_task(worker=WORKER, task_id='A_1_2', status=FAILED, expl='bad failure') + + task_ids = {'A_1', 'A_2'} + self.assertEqual(task_ids, set(self.sch.task_list(FAILED, '').keys())) + for task_id in task_ids: + expl = self.sch.fetch_error(task_id)['error'] + self.assertEqual('bad failure', expl) + + def test_batch_fail_max(self): + self._start_simple_batch(use_max=True) + self.sch.add_task(worker=WORKER, task_id='A_2', status=FAILED, expl='bad max failure') + + task_ids = {'A_1', 'A_2'} + self.assertEqual(task_ids, set(self.sch.task_list(FAILED, '').keys())) + for task_id in task_ids: + response = self.sch.fetch_error(task_id) + self.assertEqual('bad max failure', response['error']) + + def test_batch_fail_from_dead_worker(self): + self.setTime(1) + self._start_simple_batch() + self.setTime(10000) + self.sch.prune() + self.setTime(10001) + self.sch.prune() + self.assertEqual({'A_1', 'A_2'}, set(self.sch.task_list(FAILED, '').keys())) + + def test_batch_fail_max_from_dead_worker(self): + self.setTime(1) + self._start_simple_batch(use_max=True) + self.setTime(601) + self.sch.prune() + self.assertEqual({'A_1', 
'A_2'}, set(self.sch.task_list(FAILED, '').keys())) + + def test_batch_update_status(self): + self._start_simple_batch() + self.sch.set_task_status_message('A_1_2', 'test message') + for task_id in ('A_1', 'A_2', 'A_1_2'): + self.assertEqual('test message', self.sch.get_task_status_message(task_id)['statusMessage']) + + def test_batch_tracking_url(self): + self._start_simple_batch() + self.sch.add_task(worker=WORKER, task_id='A_1_2', tracking_url='http://test.tracking.url/') + + tasks = self.sch.task_list('', '') + for task_id in ('A_1', 'A_2', 'A_1_2'): + self.assertEqual('http://test.tracking.url/', tasks[task_id]['tracking_url']) + + def test_finish_batch(self): + self._start_simple_batch() + self.sch.add_task(worker=WORKER, task_id='A_1_2', status=DONE) + self.assertEqual({'A_1', 'A_2', 'A_1_2'}, set(self.sch.task_list(DONE, '').keys())) + + def test_reschedule_max_batch(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task( + worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, batchable=True) + self.sch.add_task( + worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, batchable=True) + response = self.sch.get_work(worker=WORKER) + batch_id = response['batch_id'] + self.sch.add_task( + worker=WORKER, task_id='A_2', task_family='A', params={'a': '2'}, batch_id=batch_id, + status='RUNNING') + self.sch.add_task(worker=WORKER, task_id='A_2', status=DONE) + self.sch.add_task( + worker=WORKER, task_id='A_2', task_family='A', params={'a': '2'}, batchable=True) + + self.assertEqual({'A_2'}, set(self.sch.task_list(PENDING, '').keys())) + self.assertEqual({'A_1'}, set(self.sch.task_list(DONE, '').keys())) + + def test_resend_batch_on_get_work_retry(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', 
params={'a': '2'}, + batchable=True) + response = self.sch.get_work(worker=WORKER) + response2 = self.sch.get_work(worker=WORKER, current_tasks=()) + self.assertEqual(response['task_id'], response2['task_id']) + self.assertEqual(response['task_family'], response2.get('task_family')) + self.assertEqual(response['task_params'], response2.get('task_params')) + + def test_resend_batch_runner_on_get_work_retry(self): + self._start_simple_batch() + get_work = self.sch.get_work(worker=WORKER, current_tasks=()) + self.assertEqual('A_1_2', get_work['task_id']) + + def test_resend_max_batch_runner_on_get_work_retry(self): + self._start_simple_batch(use_max=True) + get_work = self.sch.get_work(worker=WORKER, current_tasks=()) + self.assertEqual('A_2', get_work['task_id']) + + def test_do_not_resend_batch_runner_on_get_work(self): + self._start_simple_batch() + get_work = self.sch.get_work(worker=WORKER, current_tasks=('A_1_2',)) + self.assertIsNone(get_work['task_id']) + + def test_do_not_resend_max_batch_runner_on_get_work(self): + self._start_simple_batch(use_max=True) + get_work = self.sch.get_work(worker=WORKER, current_tasks=('A_2',)) + self.assertIsNone(get_work['task_id']) + + def test_rescheduled_batch_running_tasks_stay_batch_running_before_runner(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, + batchable=True) + self.sch.get_work(worker=WORKER) + + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, + batchable=True) + self.assertEqual({'A_1', 'A_2'}, set(self.sch.task_list(BATCH_RUNNING, '').keys())) + + def test_rescheduled_batch_running_tasks_stay_batch_running_after_runner(self): + self._start_simple_batch() + 
self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, + batchable=True) + self.assertEqual({'A_1', 'A_2'}, set(self.sch.task_list(BATCH_RUNNING, '').keys())) + + def test_disabled_batch_running_tasks_stay_batch_running_before_runner(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, + batchable=True) + self.sch.get_work(worker=WORKER) + + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True, status=DISABLED) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, + batchable=True, status=DISABLED) + self.assertEqual({'A_1', 'A_2'}, set(self.sch.task_list(BATCH_RUNNING, '').keys())) + + def test_get_work_returns_batch_task_id_list(self): + self.sch.add_task_batcher(worker=WORKER, task_family='A', batched_args=['a']) + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, + batchable=True) + response = self.sch.get_work(worker=WORKER) + self.assertEqual({'A_1', 'A_2'}, set(response['batch_task_ids'])) + + def test_disabled_batch_running_tasks_stay_batch_running_after_runner(self): + self._start_simple_batch() + self.sch.add_task(worker=WORKER, task_id='A_1', family='A', params={'a': '1'}, + batchable=True, status=DISABLED) + self.sch.add_task(worker=WORKER, task_id='A_2', family='A', params={'a': '2'}, + batchable=True, status=DISABLED) + self.assertEqual({'A_1', 'A_2'}, set(self.sch.task_list(BATCH_RUNNING, '').keys())) + def test_do_not_overwrite_tracking_url_while_running(self): self.sch.add_task(task_id='A', worker='X', status='RUNNING', 
tracking_url='trackme') self.assertEqual('trackme', self.sch.task_list('RUNNING', '')['A']['tracking_url']) diff --git a/test/worker_test.py b/test/worker_test.py index 7034662d1e..792427c21b 100644 --- a/test/worker_test.py +++ b/test/worker_test.py @@ -689,6 +689,128 @@ def requires(self): self.assertTrue(d.has_run) self.assertFalse(a.has_run) + def test_run_csv_batch_job(self): + completed = set() + + class CsvBatchJob(luigi.Task): + values = luigi.parameter.Parameter(batch_method=','.join) + has_run = False + + def run(self): + completed.update(self.values.split(',')) + self.has_run = True + + def complete(self): + return all(value in completed for value in self.values.split(',')) + + tasks = [CsvBatchJob(str(i)) for i in range(10)] + for task in tasks: + self.assertTrue(self.w.add(task)) + self.assertTrue(self.w.run()) + + for task in tasks: + self.assertTrue(task.complete()) + self.assertFalse(task.has_run) + + def test_run_max_batch_job(self): + completed = set() + + class MaxBatchJob(luigi.Task): + value = luigi.IntParameter(batch_method=max) + has_run = False + + def run(self): + completed.add(self.value) + self.has_run = True + + def complete(self): + return any(self.value <= ran for ran in completed) + + tasks = [MaxBatchJob(i) for i in range(10)] + for task in tasks: + self.assertTrue(self.w.add(task)) + self.assertTrue(self.w.run()) + + for task in tasks: + self.assertTrue(task.complete()) + # only task number 9 should run + self.assertFalse(task.has_run and task.value < 9) + + def test_run_batch_job_unbatched(self): + completed = set() + + class MaxNonBatchJob(luigi.Task): + value = luigi.IntParameter(batch_method=max) + has_run = False + + batchable = False + + def run(self): + completed.add(self.value) + self.has_run = True + + def complete(self): + return self.value in completed + + tasks = [MaxNonBatchJob((i,)) for i in range(10)] + for task in tasks: + self.assertTrue(self.w.add(task)) + self.assertTrue(self.w.run()) + + for task in tasks: + 
self.assertTrue(task.complete()) + self.assertTrue(task.has_run) + + def test_run_batch_job_limit_batch_size(self): + completed = set() + runs = [] + + class CsvLimitedBatchJob(luigi.Task): + value = luigi.parameter.Parameter(batch_method=','.join) + has_run = False + + max_batch_size = 4 + + def run(self): + completed.update(self.value.split(',')) + runs.append(self) + + def complete(self): + return all(value in completed for value in self.value.split(',')) + + tasks = [CsvLimitedBatchJob(str(i)) for i in range(11)] + for task in tasks: + self.assertTrue(self.w.add(task)) + self.assertTrue(self.w.run()) + + for task in tasks: + self.assertTrue(task.complete()) + + self.assertEqual(3, len(runs)) + + def test_fail_max_batch_job(self): + class MaxBatchFailJob(luigi.Task): + value = luigi.IntParameter(batch_method=max) + has_run = False + + def run(self): + self.has_run = True + assert False + + def complete(self): + return False + + tasks = [MaxBatchFailJob(i) for i in range(10)] + for task in tasks: + self.assertTrue(self.w.add(task)) + self.assertFalse(self.w.run()) + + for task in tasks: + # only task number 9 should run + self.assertFalse(task.has_run and task.value < 9) + + self.assertEqual({task.task_id for task in tasks}, set(self.sch.task_list('FAILED', ''))) + class DynamicDependenciesTest(unittest.TestCase): n_workers = 1