diff --git a/doc/parameters.rst b/doc/parameters.rst index 1e9f774416..a83d6d23da 100644 --- a/doc/parameters.rst +++ b/doc/parameters.rst @@ -88,6 +88,26 @@ are not the same instance: >>> hash(c) == hash(d) True ++Parameter visibility +^^^^^^^^^^^^^^^^^^^^ + +Using :class:`~luigi.parameter.ParameterVisibility` you can configure parameter visibility. By default, all +parameters are public, but you can also set them hidden or private. + +.. code:: python + + >>> import luigi + >>> from luigi.parameter import ParameterVisibility + + >>> luigi.Parameter(visibility=ParameterVisibility.PRIVATE) + +``ParameterVisibility.PUBLIC`` (default) - visible everywhere + +``ParameterVisibility.HIDDEN`` - ignored in WEB-view, but saved into database if save db_history is true + +``ParameterVisibility.PRIVATE`` - visible only inside task. + + Parameter types ^^^^^^^^^^^^^^^ diff --git a/luigi/__init__.py b/luigi/__init__.py index 2e858b18e4..4dd0878158 100644 --- a/luigi/__init__.py +++ b/luigi/__init__.py @@ -43,6 +43,7 @@ from luigi import interface from luigi.interface import run, build +from luigi.execution_summary import LuigiStatusCode from luigi import event from luigi.event import Event @@ -59,5 +60,5 @@ 'FloatParameter', 'BoolParameter', 'TaskParameter', 'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter', 'configuration', 'interface', 'local_target', 'run', 'build', 'event', 'Event', - 'NumericalParameter', 'ChoiceParameter', 'OptionalParameter' + 'NumericalParameter', 'ChoiceParameter', 'OptionalParameter', 'LuigiStatusCode' ] diff --git a/luigi/configuration.py b/luigi/configuration.py index 95e8a4c7c4..1e12d3a4a5 100644 --- a/luigi/configuration.py +++ b/luigi/configuration.py @@ -35,11 +35,17 @@ try: from ConfigParser import ConfigParser, NoOptionError, NoSectionError + Interpolation = object except ImportError: from configparser import ConfigParser, NoOptionError, NoSectionError + from configparser import Interpolation class 
LuigiConfigParser(ConfigParser): + + # for python2/3 compatibility + _DEFAULT_INTERPOLATION = Interpolation() + NO_DEFAULT = object() _instance = None _config_paths = [ diff --git a/luigi/execution_summary.py b/luigi/execution_summary.py index bc5bb19728..b8a621e811 100644 --- a/luigi/execution_summary.py +++ b/luigi/execution_summary.py @@ -24,6 +24,7 @@ import textwrap import collections import functools +import enum import luigi @@ -32,6 +33,63 @@ class execution_summary(luigi.Config): summary_length = luigi.IntParameter(default=5) +class LuigiStatusCode(enum.Enum): + """ + All possible status codes for the attribute ``status`` in :class:`~luigi.execution_summary.LuigiRunResult` when + the argument ``detailed_summary=True`` in *luigi.run() / luigi.build*. + Here are the codes and what they mean: + + ============================= ========================================================== + Status Code Name Meaning + ============================= ========================================================== + SUCCESS There were no failed tasks or missing dependencies + SUCCESS_WITH_RETRY There were failed tasks but they all succeeded in a retry + FAILED There were failed tasks + FAILED_AND_SCHEDULING_FAILED There were failed tasks and tasks whose scheduling failed + SCHEDULING_FAILED There were tasks whose scheduling failed + NOT_RUN There were tasks that were not granted run permission by the scheduler + MISSING_EXT There were missing external dependencies + ============================= ========================================================== + + """ + SUCCESS = (":)", "there were no failed tasks or missing dependencies") + SUCCESS_WITH_RETRY = (":)", "there were failed tasks but they all succeeded in a retry") + FAILED = (":(", "there were failed tasks") + FAILED_AND_SCHEDULING_FAILED = (":(", "there were failed tasks and tasks whose scheduling failed") + SCHEDULING_FAILED = (":(", "there were tasks whose scheduling failed") + NOT_RUN = (":|", "there were tasks 
that were not granted run permission by the scheduler") + MISSING_EXT = (":|", "there were missing external dependencies") + + +class LuigiRunResult(object): + """ + The result of a call to build/run when passing the detailed_summary=True argument. + + Attributes: + - one_line_summary (str): One line summary of the progress. + - summary_text (str): Detailed summary of the progress. + - status (LuigiStatusCode): Luigi Status Code. See :class:`~luigi.execution_summary.LuigiStatusCode` for what these codes mean. + - status_code_num (int): Numeric representation for status (LuigiStatusCode) + - worker (luigi.worker.worker): Worker object. See :class:`~luigi.worker.worker`. + - scheduling_succeeded (bool): Boolean which is *True* if all the tasks were scheduled without errors. + + """ + def __init__(self, worker, worker_add_run_status=True): + self.worker = worker + summary_dict = _summary_dict(worker) + self.summary_text = _summary_wrap(_summary_format(summary_dict, worker)) + self.status = _tasks_status(summary_dict) + self.status_code_num = _status_to_code_num(self.status) + self.one_line_summary = _create_one_line_summary(self.status) + self.scheduling_succeeded = worker_add_run_status + + def __str__(self): + return "LuigiRunResult with status {0}".format(self.status) + + def __repr__(self): + return "LuigiRunResult(status={0!r},worker={1!r},scheduling_succeeded={2!r})".format(self.status, self.worker, self.scheduling_succeeded) + + def _partition_tasks(worker): """ Takes a worker and sorts out tasks based on their status. 
@@ -377,33 +435,60 @@ def _summary_format(set_tasks, worker): if len(ext_workers) == 0: str_output += '\n' str_output += 'Did not run any tasks' - smiley = "" - reason = "" + one_line_summary = _create_one_line_summary(_tasks_status(set_tasks)) + str_output += "\n{0}".format(one_line_summary) + if num_all_tasks == 0: + str_output = 'Did not schedule any tasks' + return str_output + + +def _create_one_line_summary(status_code): + """ + Given a status_code of type LuigiStatusCode which has a tuple value, returns a one line summary + """ + return "This progress looks {0} because {1}".format(*status_code.value) + + +def _tasks_status(set_tasks): + """ + Given a grouped set of tasks, returns a LuigiStatusCode + """ if set_tasks["ever_failed"]: if not set_tasks["failed"]: - smiley = ":)" - reason = "there were failed tasks but they all succeeded in a retry" + return LuigiStatusCode.SUCCESS_WITH_RETRY else: - smiley = ":(" - reason = "there were failed tasks" if set_tasks["scheduling_error"]: - reason += " and tasks whose scheduling failed" + return LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED + return LuigiStatusCode.FAILED elif set_tasks["scheduling_error"]: - smiley = ":(" - reason = "there were tasks whose scheduling failed" + return LuigiStatusCode.SCHEDULING_FAILED elif set_tasks["not_run"]: - smiley = ":|" - reason = "there were tasks that were not granted run permission by the scheduler" + return LuigiStatusCode.NOT_RUN elif set_tasks["still_pending_ext"]: - smiley = ":|" - reason = "there were missing external dependencies" + return LuigiStatusCode.MISSING_EXT else: - smiley = ":)" - reason = "there were no failed tasks or missing external dependencies" - str_output += "\nThis progress looks {0} because {1}".format(smiley, reason) - if num_all_tasks == 0: - str_output = 'Did not schedule any tasks' - return str_output + return LuigiStatusCode.SUCCESS + + +def _status_to_code_num(status_code): + """ + Given a status_code of type LuigiStatusCode, returns a numeric 
value representing it + POSIX assigns special meanings to 1 and 2 so start from 3 + """ + if status_code == LuigiStatusCode.SUCCESS: + return 0 + elif status_code == LuigiStatusCode.SUCCESS_WITH_RETRY: + return 3 + elif status_code == LuigiStatusCode.FAILED: + return 4 + elif status_code == LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED: + return 5 + elif status_code == LuigiStatusCode.SCHEDULING_FAILED: + return 6 + elif status_code == LuigiStatusCode.NOT_RUN: + return 7 + elif status_code == LuigiStatusCode.MISSING_EXT: + return 8 def _summary_wrap(str_output): diff --git a/luigi/interface.py b/luigi/interface.py index 0b1c118c92..10a42bf043 100644 --- a/luigi/interface.py +++ b/luigi/interface.py @@ -36,7 +36,7 @@ from luigi import scheduler from luigi import task from luigi import worker -from luigi import execution_summary +from luigi.execution_summary import LuigiRunResult from luigi.cmdline_parser import CmdlineParser @@ -205,8 +205,9 @@ def _schedule_and_run(tasks, worker_scheduler_factory=None, override_defaults=No success &= worker.add(t, env_params.parallel_scheduling, env_params.parallel_scheduling_processes) logger.info('Done scheduling tasks') success &= worker.run() - logger.info(execution_summary.summary(worker)) - return dict(success=success, worker=worker) + luigi_run_result = LuigiRunResult(worker, success) + logger.info(luigi_run_result.summary_text) + return luigi_run_result class PidLockAlreadyTakenExit(SystemExit): @@ -217,11 +218,16 @@ class PidLockAlreadyTakenExit(SystemExit): def run(*args, **kwargs): - return _run(*args, **kwargs)['success'] + luigi_run_result = _run(*args, **kwargs) + if kwargs.get('detailed_summary'): + # return status code instead of entire class to keep interface similar (used to return bool) + return luigi_run_result.status_code_num + else: + return luigi_run_result.scheduling_succeeded def _run(cmdline_args=None, main_task_cls=None, - worker_scheduler_factory=None, use_dynamic_argparse=None, local_scheduler=False): + 
worker_scheduler_factory=None, use_dynamic_argparse=None, local_scheduler=False, detailed_summary=False): """ Please dont use. Instead use `luigi` binary. @@ -248,7 +254,7 @@ def _run(cmdline_args=None, main_task_cls=None, return _schedule_and_run([cp.get_task_obj()], worker_scheduler_factory) -def build(tasks, worker_scheduler_factory=None, **env_params): +def build(tasks, worker_scheduler_factory=None, detailed_summary=False, **env_params): """ Run internally, bypassing the cmdline parsing. @@ -271,4 +277,10 @@ def build(tasks, worker_scheduler_factory=None, **env_params): if "no_lock" not in env_params: env_params["no_lock"] = True - return _schedule_and_run(tasks, worker_scheduler_factory, override_defaults=env_params)['success'] + luigi_run_result = _schedule_and_run(tasks, worker_scheduler_factory, + override_defaults=env_params) + if detailed_summary: + # return status code instead of entire class to keep interface similar (used to return bool) + return luigi_run_result.status_code_num + else: + return luigi_run_result.scheduling_succeeded diff --git a/luigi/notifications.py b/luigi/notifications.py index 824429eae0..41c27e93d5 100644 --- a/luigi/notifications.py +++ b/luigi/notifications.py @@ -88,6 +88,10 @@ class email(luigi.Config): default=DEFAULT_CLIENT_EMAIL, config_path=dict(section='core', name='email-sender'), description='Address to send e-mails from') + region = luigi.parameter.Parameter( + default='', + config_path=dict(section='email', name='region'), + description='AWS region for SES if you want to override the default AWS region for boto3') class smtp(luigi.Config): @@ -219,7 +223,8 @@ def send_email_ses(sender, subject, message, recipients, image_png): """ from boto3 import client as boto3_client - client = boto3_client('ses') + region = email().region or None + client = boto3_client('ses', region_name=region) msg_root = generate_email(sender, subject, message, recipients, image_png) response = client.send_raw_email(Source=sender, diff --git 
class ParameterVisibility(IntEnum):
    """
    Visibility levels accepted by the ``visibility`` option of a parameter.

    ``PUBLIC`` is the level used when nothing else is requested.
    See :doc:`/parameters` for more info.
    """
    PUBLIC = 0
    HIDDEN = 1
    PRIVATE = 2

    @classmethod
    def has_value(cls, value):
        """
        Tell whether ``value`` compares equal to the value of one of the members.

        :param value: candidate value, typically an int or a member of this enum.
        :return: True if some member's value equals ``value``, otherwise False.
        """
        for member in cls:
            if value == member.value:
                return True
        return False

    def serialize(self):
        """
        :return: the plain value of this member (0, 1 or 2).
        """
        return self.value
+ Default value is ParameterVisibility.PUBLIC + """ self._default = default self._batch_method = batch_method @@ -150,6 +173,7 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip positional = False self.significant = significant # Whether different values for this parameter will differentiate otherwise equal tasks self.positional = positional + self.visibility = visibility if ParameterVisibility.has_value(visibility) else ParameterVisibility.PUBLIC self.description = description self.always_in_help = always_in_help @@ -195,11 +219,11 @@ def _value_iterator(self, task_name, param_name): yield (self._get_value_from_config(task_name, param_name), None) yield (self._get_value_from_config(task_name, param_name.replace('_', '-')), 'Configuration [{}] {} (with dashes) should be avoided. Please use underscores.'.format( - task_name, param_name)) + task_name, param_name)) if self._config_path: yield (self._get_value_from_config(self._config_path['section'], self._config_path['name']), 'The use of the configuration [{}] {} is deprecated. Please use [{}] {}'.format( - self._config_path['section'], self._config_path['name'], task_name, param_name)) + self._config_path['section'], self._config_path['name'], task_name, param_name)) yield (self._default, None) def has_task_value(self, task_name, param_name): @@ -694,7 +718,8 @@ def field(key): def optional_field(key): return "(%s)?" 
class BobaPKIHTTPAdapter(requests.adapters.HTTPAdapter):
    """HTTP adapter which disables hostname validation on HTTPS connections.

    Copied from $ATT/rpc2/affirm/rpc2/transport/http.py

    NOTE(security): ``assert_hostname=False`` is forced on the pool manager, so the
    host name presented by the server is deliberately not validated. Only mount this
    adapter where that is an accepted trade-off.
    """

    def init_poolmanager(self, *args, **kwargs):
        """
        Build the pool manager with ``assert_hostname`` forced to ``False``.

        The flag is written into ``kwargs`` rather than passed as a separate keyword
        next to ``**kwargs``: the latter raises ``TypeError`` (multiple values for the
        keyword argument) as soon as a caller supplies ``assert_hostname`` itself.
        """
        kwargs['assert_hostname'] = False
        super(BobaPKIHTTPAdapter, self).init_poolmanager(*args, **kwargs)


def get_requests_session():
    """
    Create a ``requests.Session``.

    When the environment variable ``BOBAPKI_CACERT_VERIFY`` is set, ``https://`` URLs
    are served by :class:`BobaPKIHTTPAdapter` and the variable's value is assigned to
    ``session.verify``. When it is not set, a session with default settings is returned.

    :return: the configured session.
    """
    session = requests.Session()
    cert_verify = os.environ.get('BOBAPKI_CACERT_VERIFY', None)
    # "is not None" on purpose: an empty value still switches the adapter on
    if cert_verify is not None:
        session.mount('https://', BobaPKIHTTPAdapter())
        session.verify = cert_verify

    return session
if os.getpid() != self.process_id: - self.session = requests.Session() + self.session = get_requests_session() self.process_id = os.getpid() def fetch(self, full_url, body, timeout): @@ -118,7 +138,7 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None): self._rpc_retry_wait = config.getint('core', 'rpc-retry-wait', 30) if HAS_REQUESTS: - self._fetcher = RequestsFetcher(requests.Session()) + self._fetcher = RequestsFetcher(get_requests_session()) else: self._fetcher = URLLibFetcher() diff --git a/luigi/scheduler.py b/luigi/scheduler.py index f83893e3c0..2023a4cebd 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -48,6 +48,7 @@ from luigi.task_status import DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN, \ BATCH_RUNNING from luigi.task import Config +from luigi.parameter import ParameterVisibility logger = logging.getLogger(__name__) @@ -275,7 +276,8 @@ def __eq__(self, other): class Task(object): def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None, - params=None, tracking_url=None, status_message=None, progress_percentage=None, retry_policy='notoptional'): + params=None, param_visibilities=None, tracking_url=None, status_message=None, + progress_percentage=None, retry_policy='notoptional'): self.id = task_id self.stakeholders = set() # workers ids that are somehow related to this task (i.e. 
def set_params(self, params):
    """
    Store ``params`` on the task and split them by visibility.

    ``self.params`` receives the result of ``_get_default(params, {})``. Each entry is
    then looked up in ``self.param_visibilities``; names without an entry count as
    PUBLIC. PUBLIC entries end up in ``self.public_params``, HIDDEN entries in
    ``self.hidden_params``, anything else in neither.

    :param params: mapping of parameter name to value, or None.
    """
    self.params = _get_default(params, {})
    public, hidden = {}, {}
    for name, value in self.params.items():
        visibility = self.param_visibilities.get(name, ParameterVisibility.PUBLIC)
        if visibility == ParameterVisibility.PUBLIC:
            public[name] = value
        elif visibility == ParameterVisibility.HIDDEN:
            hidden[name] = value
    self.public_params = public
    self.hidden_params = hidden
priority=0, family='', module=None, params=None, param_visibilities=None, assistant=False, tracking_url=None, worker=None, batchable=None, batch_id=None, retry_policy_dict=None, owners=None, **kwargs): """ @@ -794,7 +806,7 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, if worker.enabled: _default_task = self._make_task( task_id=task_id, status=PENDING, deps=deps, resources=resources, - priority=priority, family=family, module=module, params=params, + priority=priority, family=family, module=module, params=params, param_visibilities=param_visibilities ) else: _default_task = None @@ -809,8 +821,10 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, task.family = family if not getattr(task, 'module', None): task.module = module + if not task.param_visibilities: + task.param_visibilities = _get_default(param_visibilities, {}) if not task.params: - task.params = _get_default(params, {}) + task.set_params(params) if batch_id is not None: task.batch_id = batch_id @@ -1033,9 +1047,19 @@ def count_pending(self, worker): for task in worker.get_tasks(self._state, PENDING, FAILED): if self._upstream_status(task.id, upstream_status_table) == UPSTREAM_DISABLED: continue - num_pending += 1 - num_unique_pending += int(len(task.workers) == 1) - num_pending_last_scheduled += int(task.workers.peek(last=True) == worker_id) + has_failed_dependency = False + for dep in task.deps: + dep_task = self._state.get_task(dep, default=None) + if dep_task.status == UNKNOWN: + # consider this task as not pending since these dependencies have broken + # requires. 
@property
def batched_params(self):
    """
    Get the batched over values for the parameters with a defined batching_method

    The mapping is read from the instance attribute named by ``TASK_BATCHED_PARAMS_VAR``.
    There is no fallback: ``AttributeError`` propagates if that attribute was never set.

    :returns a dict of (name, value) where name is the original param_name and the value is the batched over list
    """
    # Single definition on purpose: this property used to be declared twice, verbatim,
    # and the second declaration silently replaced the first.
    return getattr(self, TASK_BATCHED_PARAMS_VAR)
def to_str_params(self, only_significant=False, only_public=False):
    """
    Convert all parameters to a str->str hash.

    :param only_significant: when true, parameters that are not ``significant`` are left out.
    :param only_public: when true, only parameters whose visibility is PUBLIC are kept.
    :return: dict of parameter name -> serialized value. Parameters with PRIVATE
             visibility are never part of the result, whatever the flags say.
    """
    declared = dict(self.get_params())
    serialized = {}
    for name, value in six.iteritems(self.param_kwargs):
        param = declared[name]
        if only_significant and not param.significant:
            continue
        if only_public and not (param.visibility == ParameterVisibility.PUBLIC):
            continue
        if not (param.visibility != ParameterVisibility.PRIVATE):
            continue
        serialized[name] = param.serialize(value)
    return serialized
- cls_with_params = functools.partial(self.of, **self.of_params) - complete_parameters = self.of.bulk_complete.__func__(cls_with_params, list(map(self.datetime_to_parameter, finite_datetimes))) + complete_parameters = self.of.bulk_complete.__func__(self._instantiate_task_cls, list(map(self.datetime_to_parameter, finite_datetimes))) return set(finite_datetimes) - set(map(self.parameter_to_datetime, complete_parameters)) except NotImplementedError: return infer_bulk_complete_from_fs( @@ -716,8 +714,7 @@ class RangeByMinutes(RangeByMinutesBase): def missing_datetimes(self, finite_datetimes): try: - cls_with_params = functools.partial(self.of, **self.of_params) - complete_parameters = self.of.bulk_complete.__func__(cls_with_params, map(self.datetime_to_parameter, finite_datetimes)) + complete_parameters = self.of.bulk_complete.__func__(self._instantiate_task_cls, map(self.datetime_to_parameter, finite_datetimes)) return set(finite_datetimes) - set(map(self.parameter_to_datetime, complete_parameters)) except NotImplementedError: return infer_bulk_complete_from_fs( diff --git a/luigi/worker.py b/luigi/worker.py index aec97d32b9..4f35edadf4 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -507,6 +507,9 @@ def _add_task(self, *args, **kwargs): for batch_task in self._batch_running_tasks.pop(task_id): self._add_task_history.append((batch_task, status, True)) + if task and kwargs.get('params'): + kwargs['param_visibilities'] = task._get_param_visibilities() + self._scheduler.add_task(*args, **kwargs) logger.info('Informed scheduler that task %s has status %s', task_id, status) @@ -656,42 +659,44 @@ def add(self, task, multiprocess=False, processes=0): if self._first_task is None and hasattr(task, 'task_id'): self._first_task = task.task_id self.add_succeeded = True - if multiprocess: - queue = multiprocessing.Manager().Queue() - pool = multiprocessing.Pool(processes=processes if processes > 0 else None) - else: - queue = DequeQueue() - pool = SingleProcessPool() - 
self._validate_task(task) - pool.apply_async(check_complete, [task, queue]) - # we track queue size ourselves because len(queue) won't work for multiprocessing - queue_size = 1 - try: - seen = {task.task_id} - while queue_size: - current = queue.get() - queue_size -= 1 - item, is_complete = current - for next in self._add(item, is_complete): - if next.task_id not in seen: - self._validate_task(next) - seen.add(next.task_id) - pool.apply_async(check_complete, [next, queue]) - queue_size += 1 - except (KeyboardInterrupt, TaskException): - raise - except Exception as ex: - self.add_succeeded = False - formatted_traceback = traceback.format_exc() - self._log_unexpected_error(task) - task.trigger_event(Event.BROKEN_TASK, task, ex) - self._email_unexpected_error(task, formatted_traceback) - raise - finally: - pool.close() - pool.join() - return self.add_succeeded + with fork_lock: + if multiprocess: + queue = multiprocessing.Manager().Queue() + pool = multiprocessing.Pool(processes=processes if processes > 0 else None) + else: + queue = DequeQueue() + pool = SingleProcessPool() + self._validate_task(task) + pool.apply_async(check_complete, [task, queue]) + + # we track queue size ourselves because len(queue) won't work for multiprocessing + queue_size = 1 + try: + seen = {task.task_id} + while queue_size: + current = queue.get() + queue_size -= 1 + item, is_complete = current + for next in self._add(item, is_complete): + if next.task_id not in seen: + self._validate_task(next) + seen.add(next.task_id) + pool.apply_async(check_complete, [next, queue]) + queue_size += 1 + except (KeyboardInterrupt, TaskException): + raise + except Exception as ex: + self.add_succeeded = False + formatted_traceback = traceback.format_exc() + self._log_unexpected_error(task) + task.trigger_event(Event.BROKEN_TASK, task, ex) + self._email_unexpected_error(task, formatted_traceback) + raise + finally: + pool.close() + pool.join() + return self.add_succeeded def _add_task_batcher(self, task): 
family = task.task_family @@ -898,6 +903,10 @@ def _get_work(self): batch_tasks = filter(None, [ self._scheduled_tasks.get(batch_id) for batch_id in r['batch_task_ids']]) self._batch_running_tasks[task_id] = batch_tasks + self._scheduled_tasks[task_id] = \ + load_task(module=r.get('task_module'), + task_name=r['task_family'], + params_str=r['task_params']) return GetWorkResponse( task_id=task_id, diff --git a/setup.py b/setup.py index 14dfe0c2eb..d76d473f3d 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ # the License. import os +import sys from setuptools import setup @@ -39,6 +40,7 @@ def get_static_files(path): install_requires = [ 'tornado>=4.0,<5', 'python-daemon<3.0', + 'enum34>1.1.0 ; python_version < "3.4"' ] if os.environ.get('READTHEDOCS', None) == 'True': @@ -50,7 +52,7 @@ def get_static_files(path): setup( name='luigi', - version='2.7.5', + version='2.7.5+affirm.1.4.3', description='Workflow mgmgt + task scheduling + dependency resolution', long_description=long_description, author='The Luigi Authors', diff --git a/test/db_task_history_test.py b/test/db_task_history_test.py index 8b162d282e..d302bed292 100644 --- a/test/db_task_history_test.py +++ b/test/db_task_history_test.py @@ -24,6 +24,7 @@ from luigi.db_task_history import DbTaskHistory from luigi.task_status import DONE, PENDING, RUNNING import luigi.scheduler +from luigi.parameter import ParameterVisibility class DummyTask(luigi.Task): @@ -32,7 +33,8 @@ class DummyTask(luigi.Task): class ParamTask(luigi.Task): param1 = luigi.Parameter() - param2 = luigi.IntParameter() + param2 = luigi.IntParameter(visibility=ParameterVisibility.HIDDEN) + param3 = luigi.Parameter(default="empty", visibility=ParameterVisibility.PRIVATE) class DbTaskHistoryTest(unittest.TestCase): diff --git a/test/interface_test.py b/test/interface_test.py index 217c716b05..a0c6ceb0d5 100644 --- a/test/interface_test.py +++ b/test/interface_test.py @@ -23,6 +23,7 @@ from luigi.interface import _WorkerSchedulerFactory from 
luigi.worker import Worker from luigi.interface import core +from luigi.execution_summary import LuigiStatusCode from mock import Mock, patch, MagicMock from helpers import LuigiTestCase, with_config @@ -46,12 +47,39 @@ class NoOpTask(luigi.Task): self.task_a = NoOpTask("a") self.task_b = NoOpTask("b") + def _create_summary_dict_with(self, updates={}): + summary_dict = { + 'completed': set(), + 'already_done': set(), + 'ever_failed': set(), + 'failed': set(), + 'scheduling_error': set(), + 'still_pending_ext': set(), + 'still_pending_not_ext': set(), + 'run_by_other_worker': set(), + 'upstream_failure': set(), + 'upstream_missing_dependency': set(), + 'upstream_run_by_other_worker': set(), + 'upstream_scheduling_error': set(), + 'not_run': set() + } + summary_dict.update(updates) + return summary_dict + + def _summary_dict_module_path(): + return 'luigi.execution_summary._summary_dict' + def test_interface_run_positive_path(self): self.worker.add = Mock(side_effect=[True, True]) self.worker.run = Mock(return_value=True) self.assertTrue(self._run_interface()) + def test_interface_run_positive_path_with_detailed_summary_enabled(self): + self.worker.add = Mock(side_effect=[True, True]) + self.worker.run = Mock(return_value=True) + self.assertEqual(self._run_interface(detailed_summary=True), 0) + def test_interface_run_with_add_failure(self): self.worker.add = Mock(side_effect=[True, False]) self.worker.run = Mock(return_value=True) @@ -64,6 +92,70 @@ def test_interface_run_with_run_failure(self): self.assertFalse(self._run_interface()) + @patch(_summary_dict_module_path()) + def test_that_status_is_success(self, fake_summary_dict): + # Nothing in failed tasks so, should succeed + fake_summary_dict.return_value = self._create_summary_dict_with() + luigi_run_result = self._run_interface(detailed_summary=True) + self.assertEqual(luigi_run_result, 0) + + @patch(_summary_dict_module_path()) + def test_that_status_is_success_with_retry(self, fake_summary_dict): + # Nothing 
in failed tasks (only an entry in ever_failed) so, should succeed with retry + fake_summary_dict.return_value = self._create_summary_dict_with({ + 'ever_failed': [self.task_a] + }) + luigi_run_result = self._run_interface(detailed_summary=True) + self.assertEqual(luigi_run_result, 3) + + @patch(_summary_dict_module_path()) + def test_that_status_is_failed_when_there_is_one_failed_task(self, fake_summary_dict): + # Should fail because a task failed + fake_summary_dict.return_value = self._create_summary_dict_with({ + 'ever_failed': [self.task_a], + 'failed': [self.task_a] + }) + luigi_run_result = self._run_interface(detailed_summary=True) + self.assertEqual(luigi_run_result, 4) + + @patch(_summary_dict_module_path()) + def test_that_status_is_failed_with_scheduling_failure(self, fake_summary_dict): + # Failed task and also a scheduling error + fake_summary_dict.return_value = self._create_summary_dict_with({ + 'ever_failed': [self.task_a], + 'failed': [self.task_a], + 'scheduling_error': [self.task_b] + }) + luigi_run_result = self._run_interface(detailed_summary=True) + self.assertEqual(luigi_run_result, 5) + + @patch(_summary_dict_module_path()) + def test_that_status_is_scheduling_failed_with_one_scheduling_error(self, fake_summary_dict): + # Scheduling error for at least one task + fake_summary_dict.return_value = self._create_summary_dict_with({ + 'scheduling_error': [self.task_b] + }) + luigi_run_result = self._run_interface(detailed_summary=True) + self.assertEqual(luigi_run_result, 6) + + @patch(_summary_dict_module_path()) + def test_that_status_is_not_run_with_one_task_not_run(self, fake_summary_dict): + # At least one of the tasks was not run + fake_summary_dict.return_value = self._create_summary_dict_with({ + 'not_run': [self.task_a] + }) + luigi_run_result = self._run_interface(detailed_summary=True) + self.assertEqual(luigi_run_result, 7) + + @patch(_summary_dict_module_path()) + def 
test_that_status_is_missing_ext_with_one_task_with_missing_external_dependency(self, fake_summary_dict): + # Missing external dependency for at least one task + fake_summary_dict.return_value = self._create_summary_dict_with({ + 'still_pending_ext': [self.task_a] + }) + luigi_run_result = self._run_interface(detailed_summary=True) + self.assertEqual(luigi_run_result, 8) + def test_stops_worker_on_add_exception(self): worker = MagicMock() self.worker_scheduler_factory.create_worker = Mock(return_value=worker) @@ -94,8 +186,10 @@ class MyOtherTestTask(luigi.Task): with patch.object(sys, 'argv', ['my_module.py', '--no-lock', '--my-param', 'my_value', '--local-scheduler']): luigi.run(main_task_cls=MyOtherTestTask) - def _run_interface(self): - return luigi.interface.build([self.task_a, self.task_b], worker_scheduler_factory=self.worker_scheduler_factory) + def _run_interface(self, **env_params): + return luigi.interface.build([self.task_a, self.task_b], + worker_scheduler_factory=self.worker_scheduler_factory, + **env_params) class CoreConfigTest(LuigiTestCase): diff --git a/test/notifications_test.py b/test/notifications_test.py index 3c8aaaad53..bba5636693 100644 --- a/test/notifications_test.py +++ b/test/notifications_test.py @@ -359,6 +359,27 @@ def test_sends_ses_email(self): Destinations=self.recipients, RawMessage={'Data': self.mocked_email_msg}) + @with_config({'email': {'region': 'whatever'}}) + def test_sends_ses_email_with_reguon(self): + """ + Call notifications.send_email_ses with fixture parameters + and check that boto is properly called. 
+ """ + + with mock.patch('boto3.client') as boto_client: + with mock.patch('luigi.notifications.generate_email') as generate_email: + generate_email.return_value\ + .as_string.return_value = self.mocked_email_msg + + notifications.send_email_ses(*self.notification_args) + + boto_client.assert_called_once_with('ses', region_name='whatever') + SES = boto_client.return_value + SES.send_raw_email.assert_called_once_with( + Source=self.sender, + Destinations=self.recipients, + RawMessage={'Data': self.mocked_email_msg}) + class TestSNSNotification(unittest.TestCase, NotificationFixture): """ diff --git a/test/range_test.py b/test/range_test.py index 9d56a24dfa..8c0c1f97d1 100755 --- a/test/range_test.py +++ b/test/range_test.py @@ -22,6 +22,7 @@ import luigi import mock from luigi.mock import MockTarget, MockFileSystem +from luigi.task import MixinNaiveBulkComplete from luigi.tools.range import (RangeDaily, RangeDailyBase, RangeEvent, RangeHourly, RangeHourlyBase, RangeByMinutes, RangeByMinutesBase, @@ -1185,6 +1186,22 @@ def complete(self): expected_task = MyTask('woo', datetime.date(2015, 12, 1)) self.assertEqual(expected_task, list(range_task._requires())[0]) + def test_param_name_with_mixinnaivebulkcomplete(self): + class MyTask(MixinNaiveBulkComplete, luigi.Task): + some_non_range_param = luigi.Parameter(default='woo') + date_param = luigi.DateParameter() + + def complete(self): + return False + + range_task = RangeDaily(now=datetime_to_epoch(datetime.datetime(2015, 12, 2)), + of=MyTask, + start=datetime.date(2015, 12, 1), + stop=datetime.date(2015, 12, 2), + param_name='date_param') + expected_task = MyTask('woo', datetime.date(2015, 12, 1)) + self.assertEqual(expected_task, list(range_task._requires())[0]) + def test_param_name_with_inferred_fs(self): class MyTask(luigi.Task): some_non_range_param = luigi.Parameter(default='woo') diff --git a/test/scheduler_parameter_visibilities_test.py b/test/scheduler_parameter_visibilities_test.py new file mode 100644 index 
0000000000..cab2bb7364 --- /dev/null +++ b/test/scheduler_parameter_visibilities_test.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from helpers import LuigiTestCase, RunOnceTask +import server_test + +import luigi +import luigi.scheduler +import luigi.worker +from luigi.parameter import ParameterVisibility +import json +import time + + +class SchedulerParameterVisibilitiesTest(LuigiTestCase): + def test_task_with_deps(self): + s = luigi.scheduler.Scheduler() + with luigi.worker.Worker(scheduler=s) as w: + class DynamicTask(RunOnceTask): + dynamic_public = luigi.Parameter(default="dynamic_public") + dynamic_hidden = luigi.Parameter(default="dynamic_hidden", visibility=ParameterVisibility.HIDDEN) + dynamic_private = luigi.Parameter(default="dynamic_private", visibility=ParameterVisibility.PRIVATE) + + class RequiredTask(RunOnceTask): + required_public = luigi.Parameter(default="required_param") + required_hidden = luigi.Parameter(default="required_hidden", visibility=ParameterVisibility.HIDDEN) + required_private = luigi.Parameter(default="required_private", visibility=ParameterVisibility.PRIVATE) + + class Task(RunOnceTask): + a = luigi.Parameter(default="a") + b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) + c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) + d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) + 
+ def requires(self): + return required_task + + def run(self): + yield dynamic_task + + dynamic_task = DynamicTask() + required_task = RequiredTask() + task = Task() + + w.add(task) + w.run() + + time.sleep(1) + task_deps = s.dep_graph(task_id=task.task_id) + required_task_deps = s.dep_graph(task_id=required_task.task_id) + dynamic_task_deps = s.dep_graph(task_id=dynamic_task.task_id) + + self.assertEqual('Task(a=a, d=d)', task_deps[task.task_id]['display_name']) + self.assertEqual('RequiredTask(required_public=required_param)', + required_task_deps[required_task.task_id]['display_name']) + self.assertEqual('DynamicTask(dynamic_public=dynamic_public)', + dynamic_task_deps[dynamic_task.task_id]['display_name']) + + self.assertEqual({'a': 'a', 'd': 'd'}, task_deps[task.task_id]['params']) + self.assertEqual({'required_public': 'required_param'}, + required_task_deps[required_task.task_id]['params']) + self.assertEqual({'dynamic_public': 'dynamic_public'}, + dynamic_task_deps[dynamic_task.task_id]['params']) + + def test_public_and_hidden_params(self): + s = luigi.scheduler.Scheduler() + with luigi.worker.Worker(scheduler=s) as w: + class Task(RunOnceTask): + a = luigi.Parameter(default="a") + b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) + c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) + d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) + + task = Task() + + w.add(task) + w.run() + + time.sleep(1) + t = s._state.get_task(task.task_id) + self.assertEqual({'b': 'b'}, t.hidden_params) + self.assertEqual({'a': 'a', 'd': 'd'}, t.public_params) + self.assertEqual({'a': 0, 'b': 1, 'd': 0}, t.param_visibilities) + + +class Task(RunOnceTask): + a = luigi.Parameter(default="a") + b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) + c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) + d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) + + 
+class RemoteSchedulerParameterVisibilitiesTest(server_test.ServerTestBase): + def test_public_params(self): + task = Task() + luigi.build(tasks=[task], workers=2, scheduler_port=self.get_http_port()) + + time.sleep(1) + + response = self.fetch('/api/graph') + + body = response.body + decoded = body.decode('utf8').replace("'", '"') + data = json.loads(decoded) + + self.assertEqual({'a': 'a', 'd': 'd'}, data['response'][task.task_id]['params']) diff --git a/test/scheduler_test.py b/test/scheduler_test.py index b4f4f6f548..6c83eb03e6 100644 --- a/test/scheduler_test.py +++ b/test/scheduler_test.py @@ -19,10 +19,16 @@ import pickle import tempfile import time +import os +import shutil +from multiprocessing import Process from helpers import unittest import luigi.scheduler +import luigi.server +import luigi.configuration from helpers import with_config +from luigi.target import FileAlreadyExists class SchedulerIoTest(unittest.TestCase): @@ -247,3 +253,69 @@ def test_get_pending_tasks_with_many_done_tasks(self): non_trivial_worker = scheduler_state.get_worker('NON_TRIVIAL') self.assertEqual({'A'}, self.get_pending_ids(non_trivial_worker, scheduler_state)) + + +class FailingOnDoubleRunTask(luigi.Task): + time_to_check_secs = 1 + time_to_run_secs = 2 + output_dir = luigi.Parameter(default="") + + def __init__(self, *args, **kwargs): + super(FailingOnDoubleRunTask, self).__init__(*args, **kwargs) + self.file_name = os.path.join(self.output_dir, "AnyTask") + + def complete(self): + time.sleep(self.time_to_check_secs) # e.g., establish connection + exists = os.path.exists(self.file_name) + time.sleep(self.time_to_check_secs) # e.g., close connection + return exists + + def run(self): + time.sleep(self.time_to_run_secs) + if os.path.exists(self.file_name): + raise FileAlreadyExists(self.file_name) + open(self.file_name, 'w').close() + + +class StableDoneCooldownSecsTest(unittest.TestCase): + + def setUp(self): + self.p = tempfile.mkdtemp() + + def tearDown(self): + 
shutil.rmtree(self.p) + + def run_task(self): + return luigi.build([FailingOnDoubleRunTask(output_dir=self.p)], + detailed_summary=True, + parallel_scheduling=True, + parallel_scheduling_processes=2) + + @with_config({'worker': {'keep_alive': 'false'}}) + def get_second_run_result_on_double_run(self): + server_process = Process(target=luigi.server.run) + process = Process(target=self.run_task) + try: + # scheduler is started + server_process.start() + # first run is started + process.start() + time.sleep(FailingOnDoubleRunTask.time_to_run_secs + FailingOnDoubleRunTask.time_to_check_secs) + # second run of the same task is started + second_run_result = self.run_task() + return second_run_result + finally: + process.join(1) + server_process.terminate() + server_process.join(1) + + @with_config({'scheduler': {'stable_done_cooldown_secs': '5'}}) + def test_sending_same_task_twice_with_cooldown_does_not_lead_to_double_run(self): + second_run_result = self.get_second_run_result_on_double_run() + self.assertEqual(second_run_result.scheduling_succeeded, True) + + @with_config({'scheduler': {'stable_done_cooldown_secs': '0'}}) + def test_sending_same_task_twice_without_cooldown_leads_to_double_run(self): + second_run_result = self.get_second_run_result_on_double_run() + self.assertEqual(second_run_result.scheduling_succeeded, False) + diff --git a/test/unknown_state_handling_test.py b/test/unknown_state_handling_test.py new file mode 100644 index 0000000000..54e7a2f842 --- /dev/null +++ b/test/unknown_state_handling_test.py @@ -0,0 +1,88 @@ +from helpers import LuigiTestCase + + +import luigi +import luigi.worker +import luigi.execution_summary + + +class DummyRequires(luigi.Task): + def run(self): + print('just a dummy task') + + +class FailInRun(luigi.Task): + def run(self): + print('failing in run') + raise Exception + + +class FailInRequires(luigi.Task): + def requires(self): + print('failing') + raise Exception + + def run(self): + print('running') + + +class 
FailInDepRequires(luigi.Task): + def requires(self): + return [FailInRequires()] + + def run(self): + print('doing a thing') + + +class FailInDepRun(luigi.Task): + def requires(self): + return [FailInRun()] + + def run(self): + print('doing a thing') + + +class UnknownStateTest(LuigiTestCase): + def setUp(self): + super(UnknownStateTest, self).setUp() + self.scheduler = luigi.scheduler.Scheduler( + prune_on_get_work=False, + retry_count=1 + ) + self.worker = luigi.worker.Worker( + scheduler=self.scheduler, + keep_alive=True + ) + + def run_task(self, task): + self.worker.add(task) # schedule + self.worker.run() # run + + def summary_dict(self): + return luigi.execution_summary._summary_dict(self.worker) + + def test_fail_in_run(self): + self.run_task(FailInRun()) + summary_dict = self.summary_dict() + + self.assertEqual({FailInRun()}, summary_dict['failed']) + + def test_fail_in_requires(self): + self.run_task(FailInRequires()) + summary_dict = self.summary_dict() + + self.assertEqual({FailInRequires()}, summary_dict['scheduling_error']) + + def test_fail_in_dep_run(self): + self.run_task(FailInDepRun()) + summary_dict = self.summary_dict() + + self.assertEqual({FailInRun()}, summary_dict['failed']) + self.assertEqual({FailInDepRun()}, summary_dict['still_pending_not_ext']) + + def test_fail_in_dep_requires(self): + self.run_task(FailInDepRequires()) + summary_dict = self.summary_dict() + + self.assertEqual({FailInRequires()}, summary_dict['scheduling_error']) + self.assertEqual({FailInDepRequires()}, summary_dict['still_pending_not_ext']) diff --git a/test/visible_parameters_test.py b/test/visible_parameters_test.py new file mode 100644 index 0000000000..e644aa7cb0 --- /dev/null +++ b/test/visible_parameters_test.py @@ -0,0 +1,95 @@ +import luigi +from luigi.parameter import ParameterVisibility +from helpers import unittest +import json + + +class TestTask1(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.HIDDEN, 
significant=True) + param_two = luigi.Parameter(default='2', significant=True) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.PRIVATE, significant=True) + + +class TestTask2(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.PRIVATE) + param_two = luigi.Parameter(default='2', visibility=ParameterVisibility.PRIVATE) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.PRIVATE) + + +class TestTask3(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.HIDDEN, significant=True) + param_two = luigi.Parameter(default='2', visibility=ParameterVisibility.HIDDEN, significant=False) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.HIDDEN, significant=True) + + +class TestTask4(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.PUBLIC, significant=True) + param_two = luigi.Parameter(default='2', visibility=ParameterVisibility.PUBLIC, significant=False) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.PUBLIC, significant=True) + + +class Test(unittest.TestCase): + def test_to_str_params(self): + task = TestTask1() + + self.assertEqual(task.to_str_params(), {'param_one': '1', 'param_two': '2'}) + + task = TestTask2() + + self.assertEqual(task.to_str_params(), {}) + + task = TestTask3() + + self.assertEqual(task.to_str_params(), {'param_one': '1', 'param_two': '2', 'param_three': '3'}) + + def test_all_public_equals_all_hidden(self): + hidden = TestTask3() + public = TestTask4() + + self.assertEqual(public.to_str_params(), hidden.to_str_params()) + + def test_all_public_equals_all_hidden_using_significant(self): + hidden = TestTask3() + public = TestTask4() + + self.assertEqual(public.to_str_params(only_significant=True), hidden.to_str_params(only_significant=True)) + + def test_private_params_and_significant(self): + task = TestTask1() + + 
self.assertEqual(task.to_str_params(), task.to_str_params(only_significant=True)) + + def test_param_visibilities(self): + task = TestTask1() + + self.assertEqual(task._get_param_visibilities(), {'param_one': 1, 'param_two': 0}) + + def test_incorrect_visibility_value(self): + class Task(luigi.Task): + a = luigi.Parameter(default='val', visibility=5) + + task = Task() + + self.assertEqual(task._get_param_visibilities(), {'a': 0}) + + def test_task_id_exclude_hidden_and_private_params(self): + task = TestTask1() + + self.assertEqual({'param_two': '2'}, task.to_str_params(only_public=True)) + + def test_json_dumps(self): + public = json.dumps(ParameterVisibility.PUBLIC.serialize()) + hidden = json.dumps(ParameterVisibility.HIDDEN.serialize()) + private = json.dumps(ParameterVisibility.PRIVATE.serialize()) + + self.assertEqual('0', public) + self.assertEqual('1', hidden) + self.assertEqual('2', private) + + public = json.loads(public) + hidden = json.loads(hidden) + private = json.loads(private) + + self.assertEqual(0, public) + self.assertEqual(1, hidden) + self.assertEqual(2, private) diff --git a/test/worker_test.py b/test/worker_test.py index 5658a0531c..48edecb0c0 100644 --- a/test/worker_test.py +++ b/test/worker_test.py @@ -736,10 +736,94 @@ def complete(self): self.assertTrue(self.w.add(task)) self.assertTrue(self.w.run()) - for task in tasks: + for i, task in enumerate(tasks): self.assertTrue(task.complete()) # only task number 9 should run self.assertFalse(task.has_run and task.value < 9) + # only task number 9 should have more than default batched_params + self.assertFalse(task.batched_params != {'value': [i]} and task.value < 9) + + # Task number 9 should have batched_params of all tasks values + self.assertEquals(tasks[-1].batched_params, {'value': list(range(10))}) + + def test_run_batch_jobs_which_overlap_subset_batch(self): + completed = set() + + class MaxBatchJob(luigi.Task): + value = luigi.IntParameter(batch_method=max) + has_run = False + + def 
run(self): + completed.add(self.value) + self.has_run = True + + def complete(self): + return any(self.value <= ran for ran in completed) + + tasks = [MaxBatchJob(i) for i in range(5)] + tasks_batch_2 = [MaxBatchJob(i) for i in range(5, 10)] + tasks_which_overlap = tasks + tasks_batch_2 + for task in tasks: + self.assertTrue(self.w.add(task)) + # "Duplicate" tasks added to w2 + for task in tasks_which_overlap: + self.assertTrue(self.w2.add(task)) + + #Run tasks on w "scheduled" first + self.assertTrue(self.w.run()) + + #Run tasks on w2 + self.assertTrue(self.w2.run()) + + for i, task in enumerate(tasks_which_overlap): + #Only 4 and 9 run + self.assertFalse(task.has_run and task.value not in (4,9)) + #Only 4 and 9 have more than default batched_params (content tested below) + self.assertFalse(task.batched_params != {'value': [i]} and task.value not in (4,9)) + + #Task number 4 should have batched_params of the first batch + self.assertEquals(tasks[-1].batched_params, {'value' : list(range(5))}) + + #Task number 9 should have batched_params of all remaining tasks + self.assertEquals(tasks_batch_2[-1].batched_params, {'value' : list(range(5, 10))}) + + def test_run_batch_jobs_which_overlap_superset_batch(self): + completed = set() + + class MaxBatchJob(luigi.Task): + value = luigi.IntParameter(batch_method=max) + has_run = False + + def run(self): + completed.add(self.value) + self.has_run = True + + def complete(self): + return any(self.value <= ran for ran in completed) + + tasks = [MaxBatchJob(i) for i in range(5)] + tasks_batch_2 = [MaxBatchJob(i) for i in range(5, 10)] + tasks_which_overlap = tasks + tasks_batch_2 + for task in tasks: + self.assertTrue(self.w.add(task)) + # "Duplicate" tasks added to w2 + for task in tasks_which_overlap: + self.assertTrue(self.w2.add(task)) + + #Run all tasks in one batch + self.assertTrue(self.w2.run()) + + #Run tasks on w (should be a no op) + self.assertTrue(self.w.run()) + + for i, task in enumerate(tasks_which_overlap): + #Only 9
ran + self.assertFalse(task.has_run and task.value != 9) + #Only 9 has more than default batched_params (content tested below) + self.assertFalse(task.batched_params != {'value': [i]} and task.value != 9) + + #Task number 9 should have batched_params of all tasks + self.assertEquals(tasks_batch_2[-1].batched_params, {'value' : list(range(10))}) def test_run_batch_job_unbatched(self): completed = set() diff --git a/tox.ini b/tox.ini index 2b5de252e2..22ae1a2c9f 100644 --- a/tox.ini +++ b/tox.ini @@ -104,6 +104,7 @@ deps = sqlalchemy Sphinx>=1.4.4,<1.5 sphinx_rtd_theme + enum34>1.1.0 commands = # build API docs sphinx-apidoc -o doc/api -T luigi --separate