diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000..f83f67ab75 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,12 @@ +# The following patterns are used to auto-assign review requests +# to specific individuals. Order is important; the last matching +# pattern takes the most precedence. + +# These owners will be the default owners for everything in +# the repo, unless a later match takes precedence. +* @dlstadther @Tarrasch @ulzha + +# Specific files, directories, paths, or file types can be +# assigned more specifically. +contrib/redshift*.py @dlstadther + diff --git a/README.rst b/README.rst index 3a76eb03a4..f3a02646ab 100644 --- a/README.rst +++ b/README.rst @@ -149,6 +149,7 @@ or held presentations about Luigi: * `Leipzig University Library `_ `(presentation, 2016) `__ / `(project) `__ * `Synetiq `_ `(presentation, 2017) `__ * `Glossier `_ `(blog, 2018) `__ +* `Data Revenue `_ `(blog, 2018) `__ Some more companies are using Luigi but haven't had a chance yet to write about it: diff --git a/doc/configuration.rst b/doc/configuration.rst index 5cf649d8bb..dec55c3679 100644 --- a/doc/configuration.rst +++ b/doc/configuration.rst @@ -1,18 +1,35 @@ Configuration ============= -All configuration can be done by adding configuration files. They are looked for in: +All configuration can be done by adding configuration files. - * ``/etc/luigi/client.cfg`` - * ``luigi.cfg`` (or its legacy name ``client.cfg``) in your current working directory - * ``LUIGI_CONFIG_PATH`` environment variable +Supported config parsers: +* ``cfg`` (default) +* ``toml`` -in increasing order of preference. The order only matters in case of key conflicts (see docs for ConfigParser.read_). These files are meant for both the client and ``luigid``. If you decide to specify your own configuration you should make sure that both the client and ``luigid`` load it properly. +You can choose the right parser via the ``LUIGI_CONFIG_PARSER`` environment variable. 
For example, ``LUIGI_CONFIG_PARSER=toml``. + +Config files for the default (cfg) parser are looked for in: + +* ``/etc/luigi/client.cfg`` (deprecated) +* ``/etc/luigi/luigi.cfg`` +* ``client.cfg`` (deprecated) +* ``luigi.cfg`` +* ``LUIGI_CONFIG_PATH`` environment variable + +Config files for the `TOML `_ parser are looked for in: + +* ``/etc/luigi/luigi.toml`` +* ``luigi.toml`` +* ``LUIGI_CONFIG_PATH`` environment variable + +Both config lists increase in priority (from low to high). The order only matters in case of key conflicts (see docs for ConfigParser.read_). These files are meant for both the client and ``luigid``. If you decide to specify your own configuration you should make sure that both the client and ``luigid`` load it properly. .. _ConfigParser.read: https://docs.python.org/3.6/library/configparser.html#configparser.ConfigParser.read -The config file is broken into sections, each controlling a different part of the config. Example configuration file: +The config file is broken into sections, each controlling a different part of the config. +Example cfg config: .. code:: ini @@ -23,6 +40,17 @@ The config file is broken into sections, each controlling a different part of th [core] scheduler_host=luigi-host.mycompany.foo +Example toml config: + +.. code:: python + + [hadoop] + version = "cdh4" + streaming-jar = "/usr/lib/hadoop-xyz/hadoop-streaming-xyz-123.jar" + + [core] + scheduler_host = "luigi-host.mycompany.foo" + .. _ParamConfigIngestion: diff --git a/doc/parameters.rst b/doc/parameters.rst index 1a4a8a721b..6dca716c30 100644 --- a/doc/parameters.rst +++ b/doc/parameters.rst @@ -88,6 +88,25 @@ are not the same instance: >>> hash(c) == hash(d) True +Parameter visibility +^^^^^^^^^^^^^^^^^^^^ + +Using :class:`~luigi.parameter.ParameterVisibility` you can configure parameter visibility. By default, all +parameters are public, but you can also set them to hidden or private. + +.. 
code:: python + + >>> import luigi + >>> from luigi.parameter import ParameterVisibility + + >>> luigi.Parameter(visibility=ParameterVisibility.PRIVATE) + +``ParameterVisibility.PUBLIC`` (default) - visible everywhere + +``ParameterVisibility.HIDDEN`` - ignored in WEB-view, but saved into database if save db_history is true + +``ParameterVisibility.PRIVATE`` - visible only inside task. + Parameter types ^^^^^^^^^^^^^^^ diff --git a/luigi/configuration/__init__.py b/luigi/configuration/__init__.py new file mode 100644 index 0000000000..21ff657fd8 --- /dev/null +++ b/luigi/configuration/__init__.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .cfg_parser import LuigiConfigParser +from .core import get_config, add_config_path +from .toml_parser import LuigiTomlParser + + +__all__ = [ + 'add_config_path', + 'get_config', + 'LuigiConfigParser', + 'LuigiTomlParser', +] diff --git a/luigi/configuration/base_parser.py b/luigi/configuration/base_parser.py new file mode 100644 index 0000000000..9b70a78155 --- /dev/null +++ b/luigi/configuration/base_parser.py @@ -0,0 +1,41 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging + + +# IMPORTANT: don't inherit from `object`! +# ConfigParser have some troubles in this case. +# More info: https://stackoverflow.com/a/19323238 +class BaseParser: + @classmethod + def instance(cls, *args, **kwargs): + """ Singleton getter """ + if cls._instance is None: + cls._instance = cls(*args, **kwargs) + loaded = cls._instance.reload() + logging.getLogger('luigi-interface').info('Loaded %r', loaded) + + return cls._instance + + @classmethod + def add_config_path(cls, path): + cls._config_paths.append(path) + cls.reload() + + @classmethod + def reload(cls): + return cls.instance().read(cls._config_paths) diff --git a/luigi/configuration.py b/luigi/configuration/cfg_parser.py similarity index 80% rename from luigi/configuration.py rename to luigi/configuration/cfg_parser.py index 6d3ddc2a33..e0df87f10a 100644 --- a/luigi/configuration.py +++ b/luigi/configuration/cfg_parser.py @@ -29,7 +29,6 @@ See :doc:`/configuration` for more info. 
""" -import logging import os import warnings @@ -38,9 +37,12 @@ except ImportError: from configparser import ConfigParser, NoOptionError, NoSectionError +from .base_parser import BaseParser -class LuigiConfigParser(ConfigParser): + +class LuigiConfigParser(BaseParser, ConfigParser): NO_DEFAULT = object() + enabled = True _instance = None _config_paths = [ '/etc/luigi/client.cfg', # Deprecated old-style global luigi config @@ -48,27 +50,6 @@ class LuigiConfigParser(ConfigParser): 'client.cfg', # Deprecated old-style local luigi config 'luigi.cfg', ] - if 'LUIGI_CONFIG_PATH' in os.environ: - config_file = os.environ['LUIGI_CONFIG_PATH'] - if not os.path.isfile(config_file): - warnings.warn("LUIGI_CONFIG_PATH points to a file which does not exist. Invalid file: {path}".format(path=config_file)) - else: - _config_paths.append(config_file) - - @classmethod - def add_config_path(cls, path): - cls._config_paths.append(path) - cls.reload() - - @classmethod - def instance(cls, *args, **kwargs): - """ Singleton getter """ - if cls._instance is None: - cls._instance = cls(*args, **kwargs) - loaded = cls._instance.reload() - logging.getLogger('luigi-interface').info('Loaded %r', loaded) - - return cls._instance @classmethod def reload(cls): @@ -124,10 +105,3 @@ def set(self, section, option, value=None): ConfigParser.add_section(self, section) return ConfigParser.set(self, section, option, value) - - -def get_config(): - """ - Convenience method (for backwards compatibility) for accessing config singleton. - """ - return LuigiConfigParser.instance() diff --git a/luigi/configuration/core.py b/luigi/configuration/core.py new file mode 100644 index 0000000000..7ca0d6e673 --- /dev/null +++ b/luigi/configuration/core.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import logging +import os +import warnings + +from .cfg_parser import LuigiConfigParser +from .toml_parser import LuigiTomlParser + + +logger = logging.getLogger('luigi-interface') + + +PARSERS = { + 'cfg': LuigiConfigParser, + 'conf': LuigiConfigParser, + 'ini': LuigiConfigParser, + 'toml': LuigiTomlParser, +} + +# select parser via env var +DEFAULT_PARSER = 'cfg' +PARSER = os.environ.get('LUIGI_CONFIG_PARSER', DEFAULT_PARSER) +if PARSER not in PARSERS: + warnings.warn("Invalid parser: {parser}".format(parser=PARSER)) + PARSER = DEFAULT_PARSER + + +def get_config(parser=PARSER): + """Get configs singleton for parser + """ + + parser_class = PARSERS[parser] + if not parser_class.enabled: + logger.error(( + "Parser not installed yet. " + "Please, install luigi with required parser:\n" + "pip install luigi[{parser}]" + ).format(parser=parser) + ) + + return parser_class.instance() + + +def add_config_path(path): + """Select config parser by file extension and add path into parser. 
+ """ + if not os.path.isfile(path): + warnings.warn("Config file does not exist: {path}".format(path=path)) + return False + + # select parser by file extension + _base, ext = os.path.splitext(path) + if ext and ext[1:] in PARSERS: + parser_class = PARSERS[ext[1:]] + else: + parser_class = PARSERS[PARSER] + + # add config path to parser + parser_class.add_config_path(path) + return True + + +if 'LUIGI_CONFIG_PATH' in os.environ: + add_config_path(os.environ['LUIGI_CONFIG_PATH']) diff --git a/luigi/configuration/toml_parser.py b/luigi/configuration/toml_parser.py new file mode 100644 index 0000000000..8e6fa3923b --- /dev/null +++ b/luigi/configuration/toml_parser.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2018 Cindicator Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os.path + +try: + import toml +except ImportError: + toml = False + +from .base_parser import BaseParser + + +class LuigiTomlParser(BaseParser): + NO_DEFAULT = object() + enabled = bool(toml) + data = dict() + _instance = None + _config_paths = [ + '/etc/luigi/luigi.toml', + 'luigi.toml', + ] + + @staticmethod + def _update_data(data, new_data): + if not new_data: + return data + if not data: + return new_data + for section, content in new_data.items(): + if section not in data: + data[section] = dict() + data[section].update(content) + return data + + def read(self, config_paths): + self.data = dict() + for path in config_paths: + if os.path.isfile(path): + self.data = self._update_data(self.data, toml.load(path)) + return self.data + + def get(self, section, option, default=NO_DEFAULT, **kwargs): + try: + return self.data[section][option] + except KeyError: + if default is self.NO_DEFAULT: + raise + return default + + def getboolean(self, section, option, default=NO_DEFAULT): + return self.get(section, option, default) + + def getint(self, section, option, default=NO_DEFAULT): + return self.get(section, option, default) + + def getfloat(self, section, option, default=NO_DEFAULT): + return self.get(section, option, default) + + def getintdict(self, section): + return self.data.get(section, {}) + + def set(self, section, option, value=None): + if section not in self.data: + self.data[section] = {} + self.data[section][option] = value + + def __getitem__(self, name): + return self.data[name] diff --git a/luigi/contrib/postgres.py b/luigi/contrib/postgres.py index d06b2a1f69..363cde70b0 100644 --- a/luigi/contrib/postgres.py +++ b/luigi/contrib/postgres.py @@ -106,11 +106,14 @@ class PostgresTarget(luigi.Target): """ marker_table = luigi.configuration.get_config().get('postgres', 'marker-table', 'table_updates') + # if not supplied, fall back to default Postgres port + DEFAULT_DB_PORT = 5432 + # Use DB side timestamps or client side timestamps in the 
marker_table use_db_timestamps = True def __init__( - self, host, database, user, password, table, update_id, port=5432 + self, host, database, user, password, table, update_id, port=None ): """ Args: @@ -126,7 +129,7 @@ def __init__( self.host, self.port = host.split(':') else: self.host = host - self.port = port + self.port = port or self.DEFAULT_DB_PORT self.database = database self.user = user self.password = password @@ -323,6 +326,8 @@ def run(self): self.init_copy(connection) self.copy(cursor, tmp_file) self.post_copy(connection) + if self.enable_metadata_columns: + self.post_copy_metacolumns(cursor) except psycopg2.ProgrammingError as e: if e.pgcode == psycopg2.errorcodes.UNDEFINED_TABLE and attempt == 0: # if first attempt fails with "relation not found", try creating table diff --git a/luigi/contrib/rdbms.py b/luigi/contrib/rdbms.py index d7b82d4c9d..a955c87275 100644 --- a/luigi/contrib/rdbms.py +++ b/luigi/contrib/rdbms.py @@ -27,7 +27,125 @@ logger = logging.getLogger('luigi-interface') -class CopyToTable(luigi.task.MixinNaiveBulkComplete, luigi.Task): +class _MetadataColumnsMixin(object): + """Provide an additional behavior that adds columns and values to tables + + This mixin is used to provide an additional behavior that allow a task to + add generic metadata columns to every table created for both PSQL and + Redshift. + + Example: + + This is a use-case example of how this mixin could come handy and how + to use it. + + .. 
code:: python + + class CommonMetaColumnsBehavior(object): + def update_report_execution_date_query(self): + query = "UPDATE {0} " \ + "SET date_param = DATE '{1}' " \ + "WHERE date_param IS NULL".format(self.table, self.date) + + return query + + @property + def metadata_columns(self): + if self.date: + cols.append(('date_param', 'VARCHAR')) + + return cols + + @property + def metadata_queries(self): + queries = [self.update_created_tz_query()] + if self.date: + queries.append(self.update_report_execution_date_query()) + + return queries + + + class RedshiftCopyCSVToTableFromS3(CommonMetaColumnsBehavior, redshift.S3CopyToTable): + "We have some business override here that would only add noise to the + example, so let's assume that this is only a shell." + pass + + + class UpdateTableA(RedshiftCopyCSVToTableFromS3): + date = luigi.Parameter() + table = 'tableA' + + def queries(): + return [query_content_for('/queries/deduplicate_dupes.sql')] + + + class UpdateTableB(RedshiftCopyCSVToTableFromS3): + date = luigi.Parameter() + table = 'tableB' + """ + @property + def metadata_columns(self): + """Returns the default metadata columns. + + Those columns are columns that we want each tables to have by default. + """ + return [] + + @property + def metadata_queries(self): + return [] + + @property + def enable_metadata_columns(self): + return False + + def _add_metadata_columns(self, connection): + cursor = connection.cursor() + + for column in self.metadata_columns: + if len(column) == 0: + raise ValueError("_add_metadata_columns is unable to infer column information from column {column} for {table}".format(column=column, + table=self.table)) + + column_name = column[0] + if not self._column_exists(cursor, column_name): + logger.info('Adding missing metadata column {column} to {table}'.format(column=column, table=self.table)) + self._add_column_to_table(cursor, column) + + def _column_exists(self, cursor, column_name): + if '.' 
in self.table: + schema, table = self.table.split('.') + query = "SELECT 1 AS column_exists " \ + "FROM information_schema.columns " \ + "WHERE table_schema = LOWER('{0}') AND table_name = LOWER('{1}') AND column_name = LOWER('{2}') LIMIT 1;".format(schema, table, column_name) + else: + query = "SELECT 1 AS column_exists " \ + "FROM information_schema.columns " \ + "WHERE table_name = LOWER('{0}') AND column_name = LOWER('{1}') LIMIT 1;".format(self.table, column_name) + + cursor.execute(query) + result = cursor.fetchone() + return bool(result) + + def _add_column_to_table(self, cursor, column): + if len(column) == 1: + raise ValueError("_add_column_to_table() column type not specified for {column}".format(column=column[0])) + elif len(column) == 2: + query = "ALTER TABLE {table} ADD COLUMN {column};".format(table=self.table, column=' '.join(column)) + elif len(column) == 3: + query = "ALTER TABLE {table} ADD COLUMN {column} ENCODE {encoding};".format(table=self.table, column=' '.join(column[0:2]), encoding=column[2]) + else: + raise ValueError("_add_column_to_table() found no matching behavior for {column}".format(column=column)) + + cursor.execute(query) + + def post_copy_metacolumns(self, cursor): + logger.info('Executing post copy metadata queries') + for query in self.metadata_queries: + cursor.execute(query) + + +class CopyToTable(luigi.task.MixinNaiveBulkComplete, _MetadataColumnsMixin, luigi.Task): """ An abstract task for inserting a data set into RDBMS. @@ -120,6 +238,9 @@ def init_copy(self, connection): if hasattr(self, "clear_table"): raise Exception("The clear_table attribute has been removed. Override init_copy instead!") + if self.enable_metadata_columns: + self._add_metadata_columns(connection) + def post_copy(self, connection): """ Override to perform custom queries. 
diff --git a/luigi/contrib/redshift.py b/luigi/contrib/redshift.py index 5302c3d5f0..6792efff09 100644 --- a/luigi/contrib/redshift.py +++ b/luigi/contrib/redshift.py @@ -135,6 +135,9 @@ class RedshiftTarget(postgres.PostgresTarget): 'marker-table', 'table_updates') + # if not supplied, fall back to default Redshift port + DEFAULT_DB_PORT = 5439 + use_db_timestamps = False @@ -370,6 +373,9 @@ def run(self): self.copy(cursor, path) self.post_copy(cursor) + if self.enable_metadata_columns: + self.post_copy_metacolumns(cursor) + # update marker table output.touch(connection) connection.commit() @@ -469,6 +475,9 @@ def init_copy(self, connection): logger.info("Creating table %s", self.table) self.create_table(connection) + if self.enable_metadata_columns: + self._add_metadata_columns(connection) + if self.do_truncate_table: logger.info("Truncating table %s", self.table) self.truncate_table(connection) @@ -485,6 +494,14 @@ def post_copy(self, cursor): for query in self.queries: cursor.execute(query) + def post_copy_metacolumns(self, cursor): + """ + Performs post-copy to fill metadata columns. + """ + logger.info('Executing post copy metadata queries') + for query in self.metadata_queries: + cursor.execute(query) + class S3CopyJSONToTable(S3CopyToTable, _CredentialsMixin): """ diff --git a/luigi/contrib/s3.py b/luigi/contrib/s3.py index fb5fbbb83f..56fc655b0d 100644 --- a/luigi/contrib/s3.py +++ b/luigi/contrib/s3.py @@ -226,6 +226,8 @@ def remove(self, path, recursive=True): def move(self, source_path, destination_path, **kwargs): """ Rename/move an object from one S3 location to another. 
+ :param source_path: The `s3://` path of the directory or key to copy from + :param destination_path: The `s3://` path of the directory or key to copy to :param kwargs: Keyword arguments are passed to the boto3 function `copy` """ self.copy(source_path, destination_path, **kwargs) @@ -243,12 +245,11 @@ def get_key(self, path): def put(self, local_path, destination_s3_path, **kwargs): """ Put an object stored locally to an S3 path. - + :param local_path: Path to source local file + :param destination_s3_path: URL for target S3 location :param kwargs: Keyword arguments are passed to the boto function `put_object` """ - if 'encrypt_key' in kwargs: - raise DeprecatedBotoClientException( - 'encrypt_key deprecated in boto3. Please refer to boto3 documentation for encryption details.') + self._check_deprecated_argument(**kwargs) # put the file self.put_multipart(local_path, destination_s3_path, **kwargs) @@ -256,11 +257,11 @@ def put(self, local_path, destination_s3_path, **kwargs): def put_string(self, content, destination_s3_path, **kwargs): """ Put a string to an S3 path. + :param content: Data str + :param destination_s3_path: URL for target S3 location :param kwargs: Keyword arguments are passed to the boto3 function `put_object` """ - if 'encrypt_key' in kwargs: - raise DeprecatedBotoClientException( - 'encrypt_key deprecated in boto3. Please refer to boto3 documentation for encryption details.') + self._check_deprecated_argument(**kwargs) (bucket, key) = self._path_to_bucket_and_key(destination_s3_path) # validate the bucket @@ -279,9 +280,7 @@ def put_multipart(self, local_path, destination_s3_path, part_size=8388608, **kw :param part_size: Part size in bytes. Default: 8388608 (8MB) :param kwargs: Keyword arguments are passed to the boto function `upload_fileobj` as ExtraArgs """ - if 'encrypt_key' in kwargs: - raise DeprecatedBotoClientException( - 'encrypt_key deprecated in boto3. 
Please refer to boto3 documentation for encryption details.') + self._check_deprecated_argument(**kwargs) from boto3.s3.transfer import TransferConfig # default part size for boto3 is 8Mb, changing it to fit part_size @@ -446,6 +445,7 @@ def listdir(self, path, start_time=None, end_time=None, return_key=False): """ Get an iterable with S3 folder contents. Iterable contains paths relative to queried path. + :param path: URL for target S3 location :param start_time: Optional argument to list files with modified (offset aware) datetime after start_time :param end_time: Optional argument to list files with modified (offset aware) datetime before end_time :param return_key: Optional argument, when set to True will return boto3's ObjectSummary (instead of the filename) @@ -482,11 +482,12 @@ def list(self, path, start_time=None, end_time=None, return_key=False): # backw else: yield item[key_path_len:] - def _get_s3_config(self, key=None): + @staticmethod + def _get_s3_config(key=None): defaults = dict(configuration.get_config().defaults()) try: config = dict(configuration.get_config().items('s3')) - except NoSectionError: + except (NoSectionError, KeyError): return {} # So what ports etc can be read without us having to specify all dtypes for k, v in six.iteritems(config): @@ -500,17 +501,30 @@ def _get_s3_config(self, key=None): return section_only - def _path_to_bucket_and_key(self, path): + @staticmethod + def _path_to_bucket_and_key(path): (scheme, netloc, path, query, fragment) = urlsplit(path) path_without_initial_slash = path[1:] return netloc, path_without_initial_slash - def _is_root(self, key): + @staticmethod + def _is_root(key): return (len(key) == 0) or (key == '/') - def _add_path_delimiter(self, key): + @staticmethod + def _add_path_delimiter(key): return key if key[-1:] == '/' or key == '' else key + '/' + @staticmethod + def _check_deprecated_argument(**kwargs): + """ + If `encrypt_key` is part of the arguments raise an exception + :return: None + """ + 
if 'encrypt_key' in kwargs: + raise DeprecatedBotoClientException( + 'encrypt_key deprecated in boto3. Please refer to boto3 documentation for encryption details.') + def _validate_bucket(self, bucket_name): exists = True @@ -525,18 +539,15 @@ def _validate_bucket(self, bucket_name): return exists def _exists(self, bucket, key): - s3_key = False try: self.s3.Object(bucket, key).load() except botocore.exceptions.ClientError as e: if e.response['Error']['Code'] in ['NoSuchKey', '404']: - s3_key = False + return False else: raise - else: - s3_key = True - if s3_key: - return True + + return True class AtomicS3File(AtomicLocalFile): diff --git a/luigi/lock.py b/luigi/lock.py index 1b31ed0c90..e1a604f540 100644 --- a/luigi/lock.py +++ b/luigi/lock.py @@ -21,6 +21,7 @@ """ from __future__ import print_function +import errno import hashlib import os import sys @@ -102,10 +103,14 @@ def acquire_for(pid_dir, num_available=1, kill_signal=None): my_pid, my_cmd, pid_file = get_info(pid_dir) - # Check if there is a pid file corresponding to this name - if not os.path.exists(pid_dir): + # Create a pid file if it does not exist + try: os.mkdir(pid_dir) os.chmod(pid_dir, 0o777) + except OSError as exc: + if exc.errno != errno.EEXIST: + raise + pass # Let variable "pids" be all pids who exist in the .pid-file who are still # about running the same command. diff --git a/luigi/parameter.py b/luigi/parameter.py index 7485d09f61..4c4c3853a0 100644 --- a/luigi/parameter.py +++ b/luigi/parameter.py @@ -23,6 +23,7 @@ import abc import datetime import warnings +from enum import IntEnum import json from json import JSONEncoder from collections import OrderedDict, Mapping @@ -40,10 +41,26 @@ from luigi import configuration from luigi.cmdline_parser import CmdlineParser - _no_value = object() +class ParameterVisibility(IntEnum): + """ + Possible values for the parameter visibility option. Public is the default. + See :doc:`/parameters` for more info. 
+ """ + PUBLIC = 0 + HIDDEN = 1 + PRIVATE = 2 + + @classmethod + def has_value(cls, value): + return any(value == item.value for item in cls) + + def serialize(self): + return self.value + + class ParameterException(Exception): """ Base exception. @@ -113,7 +130,8 @@ def run(self): _counter = 0 # non-atomically increasing counter used for ordering parameters. def __init__(self, default=_no_value, is_global=False, significant=True, description=None, - config_path=None, positional=True, always_in_help=False, batch_method=None): + config_path=None, positional=True, always_in_help=False, batch_method=None, + visibility=ParameterVisibility.PUBLIC): """ :param default: the default value for this parameter. This should match the type of the Parameter, i.e. ``datetime.date`` for ``DateParameter`` or ``int`` for @@ -140,6 +158,10 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip parameter values into a single value. Used when receiving batched parameter lists from the scheduler. See :ref:`batch_method` + + :param visibility: A Parameter whose value is a :py:class:`~luigi.parameter.ParameterVisibility`. 
+ Default value is ParameterVisibility.PUBLIC + """ self._default = default self._batch_method = batch_method @@ -150,6 +172,7 @@ def __init__(self, default=_no_value, is_global=False, significant=True, descrip positional = False self.significant = significant # Whether different values for this parameter will differentiate otherwise equal tasks self.positional = positional + self.visibility = visibility if ParameterVisibility.has_value(visibility) else ParameterVisibility.PUBLIC self.description = description self.always_in_help = always_in_help @@ -168,7 +191,7 @@ def _get_value_from_config(self, section, name): try: value = conf.get(section, name) - except (NoSectionError, NoOptionError): + except (NoSectionError, NoOptionError, KeyError): return _no_value return self.parse(value) @@ -195,11 +218,11 @@ def _value_iterator(self, task_name, param_name): yield (self._get_value_from_config(task_name, param_name), None) yield (self._get_value_from_config(task_name, param_name.replace('_', '-')), 'Configuration [{}] {} (with dashes) should be avoided. Please use underscores.'.format( - task_name, param_name)) + task_name, param_name)) if self._config_path: yield (self._get_value_from_config(self._config_path['section'], self._config_path['name']), 'The use of the configuration [{}] {} is deprecated. Please use [{}] {}'.format( - self._config_path['section'], self._config_path['name'], task_name, param_name)) + self._config_path['section'], self._config_path['name'], task_name, param_name)) yield (self._default, None) def has_task_value(self, task_name, param_name): @@ -689,6 +712,7 @@ class DateIntervalParameter(Parameter): (eg. "2015-W35"). In addition, it also supports arbitrary date intervals provided as two dates separated with a dash (eg. "2015-11-04-2015-12-04"). """ + def parse(self, s): """ Parses a :py:class:`~luigi.date_interval.DateInterval` from the input. @@ -740,8 +764,10 @@ def field(key): def optional_field(key): return "(%s)?" 
% field(key) + # A little loose: ISO 8601 does not allow weeks in combination with other fields, but this regex does (as does python timedelta) - regex = "P(%s|%s(T%s)?)" % (field("weeks"), optional_field("days"), "".join([optional_field(key) for key in ["hours", "minutes", "seconds"]])) + regex = "P(%s|%s(T%s)?)" % (field("weeks"), optional_field("days"), + "".join([optional_field(key) for key in ["hours", "minutes", "seconds"]])) return self._apply_regex(regex, input) def _parseSimple(self, input): @@ -905,6 +931,7 @@ class _DictParamEncoder(JSONEncoder): """ JSON encoder for :py:class:`~DictParameter`, which makes :py:class:`~_FrozenOrderedDict` JSON serializable. """ + def default(self, obj): if isinstance(obj, _FrozenOrderedDict): return obj.get_wrapped() @@ -943,6 +970,7 @@ def run(self): tags, that are dynamically constructed outside Luigi), or you have a complex parameter containing logically related values (like a database connection config). """ + def normalize(self, value): """ Ensure that dictionary parameter is converted to a _FrozenOrderedDict so it can be hashed. @@ -996,6 +1024,7 @@ def run(self): $ luigi --module my_tasks MyTask --grades '[100,70]' """ + def normalize(self, x): """ Ensure that struct is recursively converted to a tuple so it can be hashed. @@ -1053,6 +1082,7 @@ def run(self): $ luigi --module my_tasks MyTask --book_locations '((12,3),(4,15),(52,1))' """ + def parse(self, x): """ Parse an individual value from the input. @@ -1100,6 +1130,7 @@ class MyTask(luigi.Task): $ luigi --module my_tasks MyTask --my-param-1 -3 --my-param-2 -2 """ + def __init__(self, left_op=operator.le, right_op=operator.lt, *args, **kwargs): """ :param function var_type: The type of the input variable, e.g. int or float. @@ -1178,6 +1209,7 @@ class MyTask(luigi.Task): same type and transparency of parameter value on the command line is desired. 
""" + def __init__(self, var_type=str, *args, **kwargs): """ :param function var_type: The type of the input variable, e.g. str, int, diff --git a/luigi/rpc.py b/luigi/rpc.py index a18bd58ded..1c4580a46e 100644 --- a/luigi/rpc.py +++ b/luigi/rpc.py @@ -116,6 +116,7 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None): self._rpc_retry_attempts = config.getint('core', 'rpc-retry-attempts', 3) self._rpc_retry_wait = config.getint('core', 'rpc-retry-wait', 30) + self._rpc_log_retries = config.getboolean('core', 'rpc-log-retries', True) if HAS_REQUESTS: self._fetcher = RequestsFetcher(requests.Session()) @@ -123,24 +124,26 @@ def __init__(self, url='http://localhost:8082/', connect_timeout=None): self._fetcher = URLLibFetcher() def _wait(self): - logger.info("Wait for %d seconds" % self._rpc_retry_wait) + if self._rpc_log_retries: + logger.info("Wait for %d seconds" % self._rpc_retry_wait) time.sleep(self._rpc_retry_wait) - def _fetch(self, url_suffix, body, log_exceptions=True): + def _fetch(self, url_suffix, body): full_url = _urljoin(self._url, url_suffix) last_exception = None attempt = 0 while attempt < self._rpc_retry_attempts: attempt += 1 if last_exception: - logger.info("Retrying attempt %r of %r (max)" % (attempt, self._rpc_retry_attempts)) + if self._rpc_log_retries: + logger.info("Retrying attempt %r of %r (max)" % (attempt, self._rpc_retry_attempts)) self._wait() # wait for a bit and retry try: response = self._fetcher.fetch(full_url, body, self._connect_timeout) break except self._fetcher.raises as e: last_exception = e - if log_exceptions: + if self._rpc_log_retries: logger.warning("Failed connecting to remote scheduler %r", self._url, exc_info=True) continue @@ -152,11 +155,11 @@ def _fetch(self, url_suffix, body, log_exceptions=True): ) return response - def _request(self, url, data, log_exceptions=True, attempts=3, allow_null=True): + def _request(self, url, data, attempts=3, allow_null=True): body = {'data': json.dumps(data)} for 
_ in range(attempts): - page = self._fetch(url, body, log_exceptions) + page = self._fetch(url, body) response = json.loads(page)["response"] if allow_null or response is not None: return response diff --git a/luigi/scheduler.py b/luigi/scheduler.py index b7993c760b..fbc01a838d 100644 --- a/luigi/scheduler.py +++ b/luigi/scheduler.py @@ -49,6 +49,7 @@ from luigi.task_status import DISABLED, DONE, FAILED, PENDING, RUNNING, SUSPENDED, UNKNOWN, \ BATCH_RUNNING from luigi.task import Config +from luigi.parameter import ParameterVisibility logger = logging.getLogger(__name__) @@ -280,7 +281,7 @@ def __eq__(self, other): class Task(object): def __init__(self, task_id, status, deps, resources=None, priority=0, family='', module=None, - params=None, accepts_messages=False, tracking_url=None, status_message=None, + params=None, param_visibilities=None, accepts_messages=False, tracking_url=None, status_message=None, progress_percentage=None, retry_policy='notoptional'): self.id = task_id self.stakeholders = set() # workers ids that are somehow related to this task (i.e. 
don't prune while any of these workers are still active) @@ -301,8 +302,11 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='', self.resources = _get_default(resources, {}) self.family = family self.module = module - self.params = _get_default(params, {}) - + self.param_visibilities = _get_default(param_visibilities, {}) + self.params = {} + self.public_params = {} + self.hidden_params = {} + self.set_params(params) self.accepts_messages = accepts_messages self.retry_policy = retry_policy self.failures = Failures(self.retry_policy.disable_window) @@ -318,6 +322,13 @@ def __init__(self, task_id, status, deps, resources=None, priority=0, family='', def __repr__(self): return "Task(%r)" % vars(self) + def set_params(self, params): + self.params = _get_default(params, {}) + self.public_params = {key: value for key, value in self.params.items() if + self.param_visibilities.get(key, ParameterVisibility.PUBLIC) == ParameterVisibility.PUBLIC} + self.hidden_params = {key: value for key, value in self.params.items() if + self.param_visibilities.get(key, ParameterVisibility.PUBLIC) == ParameterVisibility.HIDDEN} + # TODO(2017-08-10) replace this function with direct calls to batchable # this only exists for backward compatibility def is_batchable(self): @@ -343,7 +354,7 @@ def has_excessive_failures(self): @property def pretty_id(self): - param_str = ', '.join(u'{}={}'.format(key, value) for key, value in sorted(self.params.items())) + param_str = ', '.join(u'{}={}'.format(key, value) for key, value in sorted(self.public_params.items())) return u'{}({})'.format(self.family, param_str) @@ -778,7 +789,7 @@ def forgive_failures(self, task_id=None): @rpc_method() def add_task(self, task_id=None, status=PENDING, runnable=True, deps=None, new_deps=None, expl=None, resources=None, - priority=0, family='', module=None, params=None, accepts_messages=False, + priority=0, family='', module=None, params=None, param_visibilities=None, accepts_messages=False, 
assistant=False, tracking_url=None, worker=None, batchable=None, batch_id=None, retry_policy_dict=None, owners=None, **kwargs): """ @@ -802,7 +813,7 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, if worker.enabled: _default_task = self._make_task( task_id=task_id, status=PENDING, deps=deps, resources=resources, - priority=priority, family=family, module=module, params=params, + priority=priority, family=family, module=module, params=params, param_visibilities=param_visibilities, ) else: _default_task = None @@ -817,8 +828,10 @@ def add_task(self, task_id=None, status=PENDING, runnable=True, task.family = family if not getattr(task, 'module', None): task.module = module + if not task.param_visibilities: + task.param_visibilities = _get_default(param_visibilities, {}) if not task.params: - task.params = _get_default(params, {}) + task.set_params(params) if batch_id is not None: task.batch_id = batch_id @@ -1272,6 +1285,7 @@ def _upstream_status(self, task_id, upstream_status_table): def _serialize_task(self, task_id, include_deps=True, deps=None): task = self._state.get_task(task_id) + ret = { 'display_name': task.pretty_id, 'status': task.status, @@ -1280,7 +1294,7 @@ def _serialize_task(self, task_id, include_deps=True, deps=None): 'time_running': getattr(task, "time_running", None), 'start_time': task.time, 'last_updated': getattr(task, "updated", task.time), - 'params': task.params, + 'params': task.public_params, 'name': task.family, 'priority': task.priority, 'resources': task.resources, diff --git a/luigi/task.py b/luigi/task.py index 4340e513dc..08f27b8179 100644 --- a/luigi/task.py +++ b/luigi/task.py @@ -39,6 +39,7 @@ from luigi import parameter from luigi.task_register import Register +from luigi.parameter import ParameterVisibility Parameter = parameter.Parameter logger = logging.getLogger('luigi-interface') @@ -441,7 +442,7 @@ def __init__(self, *args, **kwargs): self.param_kwargs = dict(param_values) self._warn_on_wrong_param_types() 
- self.task_id = task_id_str(self.get_task_family(), self.to_str_params(only_significant=True)) + self.task_id = task_id_str(self.get_task_family(), self.to_str_params(only_significant=True, only_public=True)) self.__hash = hash(self.task_id) self.set_tracking_url = None @@ -482,18 +483,29 @@ def from_str_params(cls, params_str): return cls(**kwargs) - def to_str_params(self, only_significant=False): + def to_str_params(self, only_significant=False, only_public=False): """ Convert all parameters to a str->str hash. """ params_str = {} params = dict(self.get_params()) for param_name, param_value in six.iteritems(self.param_kwargs): - if (not only_significant) or params[param_name].significant: + if (((not only_significant) or params[param_name].significant) + and ((not only_public) or params[param_name].visibility == ParameterVisibility.PUBLIC) + and params[param_name].visibility != ParameterVisibility.PRIVATE): params_str[param_name] = params[param_name].serialize(param_value) return params_str + def _get_param_visibilities(self): + param_visibilities = {} + params = dict(self.get_params()) + for param_name, param_value in six.iteritems(self.param_kwargs): + if params[param_name].visibility != ParameterVisibility.PRIVATE: + param_visibilities[param_name] = params[param_name].visibility.serialize() + + return param_visibilities + def clone(self, cls=None, **kwargs): """ Creates a new instance from an existing instance where some of the args have changed. diff --git a/luigi/util.py b/luigi/util.py index f6be20021b..784a778602 100644 --- a/luigi/util.py +++ b/luigi/util.py @@ -52,7 +52,7 @@ def requires(self): more burdensome than the last. Refactoring becomes more difficult. There are several ways one might try and avoid the problem. -**Approach 1**: Parameters via command line or config instead of ``requires``. +**Approach 1**: Parameters via command line or config instead of :func:`~luigi.task.Task.requires`. .. 
code-block:: python @@ -132,13 +132,13 @@ def requires(self): specified in the wrong order. This contrived example is easy to fix (by swapping the ordering of the parents of ``TaskA``), but real world cases can be more difficult to both spot and fix. Inheriting from multiple classes -derived from ``luigi.Task`` should be undertaken with caution and avoided +derived from :class:`~luigi.task.Task` should be undertaken with caution and avoided where possible. -**Approach 3**: Use ``inherits`` and ``requires`` +**Approach 3**: Use :class:`~luigi.util.inherits` and :class:`~luigi.util.requires` -The ``inherits`` class decorator in this module copies parameters (and +The :class:`~luigi.util.inherits` class decorator in this module copies parameters (and nothing else) from one task class to another, and avoids direct pythonic inheritance. @@ -185,11 +185,12 @@ def requires(self): issues, and keeps the task command line interface as simple (as it can be, anyway). Refactoring task parameters is also much easier. -The ``requires`` helper function can reduce this pattern even further. It -does everything ``inherits`` does, and also attaches a ``requires`` method +The :class:`~luigi.util.requires` helper function can reduce this pattern even further. It +does everything :class:`~luigi.util.inherits` does, +and also attaches a :class:`~luigi.util.requires` method to your task (still all without pythonic inheritance). -But how does it know how to invoke the upstream task? It uses ``clone`` +But how does it know how to invoke the upstream task? It uses :func:`~luigi.task.Task.clone` behind the scenes! .. code-block:: python @@ -251,59 +252,91 @@ class inherits(object): """ Task inheritance. + *New after Luigi 2.7.6:* multiple arguments support. + Usage: .. code-block:: python class AnotherTask(luigi.Task): + m = luigi.IntParameter() + + class YetAnotherTask(luigi.Task): n = luigi.IntParameter() - # ... 
@inherits(AnotherTask): - class MyTask(luigi.Task): + class MyFirstTask(luigi.Task): def requires(self): return self.clone_parent() + def run(self): + print self.m # this will be defined + # ... + + @inherits(AnotherTask, YetAnotherTask): + class MySecondTask(luigi.Task): + def requires(self): + return self.clone_parents() + def run(self): print self.n # this will be defined # ... """ - def __init__(self, task_to_inherit): + def __init__(self, *tasks_to_inherit): super(inherits, self).__init__() - self.task_to_inherit = task_to_inherit + if not tasks_to_inherit: + raise TypeError("tasks_to_inherit cannot be empty") + + self.tasks_to_inherit = tasks_to_inherit def __call__(self, task_that_inherits): - # Get all parameter objects from the underlying task - for param_name, param_obj in self.task_to_inherit.get_params(): - # Check if the parameter exists in the inheriting task - if not hasattr(task_that_inherits, param_name): - # If not, add it to the inheriting task - setattr(task_that_inherits, param_name, param_obj) + # Get all parameter objects from each of the underlying tasks + for task_to_inherit in self.tasks_to_inherit: + for param_name, param_obj in task_to_inherit.get_params(): + # Check if the parameter exists in the inheriting task + if not hasattr(task_that_inherits, param_name): + # If not, add it to the inheriting task + setattr(task_that_inherits, param_name, param_obj) # Modify task_that_inherits by adding methods - def clone_parent(_self, **args): - return _self.clone(cls=self.task_to_inherit, **args) + def clone_parent(_self, **kwargs): + return _self.clone(cls=self.tasks_to_inherit[0], **kwargs) task_that_inherits.clone_parent = clone_parent + def clone_parents(_self, **kwargs): + return [ + _self.clone(cls=task_to_inherit, **kwargs) + for task_to_inherit in self.tasks_to_inherit + ] + task_that_inherits.clone_parents = clone_parents + return task_that_inherits class requires(object): """ - Same as @inherits, but also auto-defines the requires 
method. + Same as :class:`~luigi.util.inherits`, but also auto-defines the requires method. + + *New after Luigi 2.7.6:* multiple arguments support. + """ - def __init__(self, task_to_require): + def __init__(self, *tasks_to_require): super(requires, self).__init__() - self.inherit_decorator = inherits(task_to_require) + if not tasks_to_require: + raise TypeError("tasks_to_require cannot be empty") + + self.tasks_to_require = tasks_to_require def __call__(self, task_that_requires): - task_that_requires = self.inherit_decorator(task_that_requires) + task_that_requires = inherits(*self.tasks_to_require)(task_that_requires) - # Modify task_that_requres by adding methods + # Modify task_that_requires by adding requires method. + # If only one task is required, this single task is returned. + # Otherwise, list of tasks is returned def requires(_self): - return _self.clone_parent() + return _self.clone_parent() if len(self.tasks_to_require) == 1 else _self.clone_parents() task_that_requires.requires = requires return task_that_requires diff --git a/luigi/worker.py b/luigi/worker.py index 5c76bbc3de..6cdaff1884 100644 --- a/luigi/worker.py +++ b/luigi/worker.py @@ -37,6 +37,7 @@ import signal import subprocess import sys +import contextlib try: import Queue @@ -135,16 +136,8 @@ def __init__(self, task, worker_id, result_queue, status_reporter, self.check_unfulfilled_deps = check_unfulfilled_deps def _run_get_new_deps(self): - # forward some attributes before running - for reporter_attr, task_attr in six.iteritems(self.forward_reporter_attributes): - setattr(self.task, task_attr, getattr(self.status_reporter, reporter_attr)) - task_gen = self.task.run() - # reset attributes again - for reporter_attr, task_attr in six.iteritems(self.forward_reporter_attributes): - setattr(self.task, task_attr, None) - if not isinstance(task_gen, types.GeneratorType): return None @@ -202,7 +195,8 @@ def run(self): expl = 'Task is an external data dependency ' \ 'and data does not exist 
(yet?).' else: - new_deps = self._run_get_new_deps() + with self._forward_attributes(): + new_deps = self._run_get_new_deps() status = DONE if not new_deps else PENDING if new_deps: @@ -258,6 +252,18 @@ def terminate(self): except ImportError: return super(TaskProcess, self).terminate() + @contextlib.contextmanager + def _forward_attributes(self): + # forward configured attributes to the task + for reporter_attr, task_attr in six.iteritems(self.forward_reporter_attributes): + setattr(self.task, task_attr, getattr(self.status_reporter, reporter_attr)) + try: + yield self + finally: + # reset attributes again + for reporter_attr, task_attr in six.iteritems(self.forward_reporter_attributes): + setattr(self.task, task_attr, None) + # This code and the task_process_context config key currently feels a bit ad-hoc. # Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897 @@ -565,6 +571,9 @@ def _add_task(self, *args, **kwargs): for batch_task in self._batch_running_tasks.pop(task_id): self._add_task_history.append((batch_task, status, True)) + if task and kwargs.get('params'): + kwargs['param_visibilities'] = task._get_param_visibilities() + self._scheduler.add_task(*args, **kwargs) logger.info('Informed scheduler that task %s has status %s', task_id, status) diff --git a/setup.py b/setup.py index 85f7dba8fa..89cffbcb1d 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ # the License. 
import os +import sys from setuptools import setup @@ -48,6 +49,9 @@ def get_static_files(path): install_requires.remove('python-daemon<3.0') install_requires.append('sphinx>=1.4.4') # Value mirrored in doc/conf.py +if sys.version_info < (3, 4): + install_requires.append('enum34>1.1.0') + setup( name='luigi', version='2.7.6', @@ -58,6 +62,7 @@ def get_static_files(path): license='Apache License 2.0', packages=[ 'luigi', + 'luigi.configuration', 'luigi.contrib', 'luigi.contrib.hdfs', 'luigi.tools' @@ -75,6 +80,9 @@ def get_static_files(path): ] }, install_requires=install_requires, + extras_require={ + 'toml': ['toml<2.0.0'], + }, classifiers=[ 'Development Status :: 5 - Production/Stable', 'Environment :: Console', diff --git a/test/config_toml_test.py b/test/config_toml_test.py new file mode 100644 index 0000000000..e0211c60ff --- /dev/null +++ b/test/config_toml_test.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2018 Cindicator Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from luigi.configuration import LuigiTomlParser, get_config, add_config_path + + +from helpers import LuigiTestCase + + +class TomlConfigParserTest(LuigiTestCase): + @classmethod + def setUpClass(cls): + add_config_path('test/testconfig/luigi.toml') + add_config_path('test/testconfig/luigi_local.toml') + + def setUp(self): + LuigiTomlParser._instance = None + super(TomlConfigParserTest, self).setUp() + + def test_get_config(self): + config = get_config('toml') + self.assertIsInstance(config, LuigiTomlParser) + + def test_file_reading(self): + config = get_config('toml') + self.assertIn('hdfs', config.data) + + def test_get(self): + config = get_config('toml') + + # test getting + self.assertEqual(config.get('hdfs', 'client'), 'hadoopcli') + self.assertEqual(config.get('hdfs', 'client', 'test'), 'hadoopcli') + + # test default + self.assertEqual(config.get('hdfs', 'test', 'check'), 'check') + with self.assertRaises(KeyError): + config.get('hdfs', 'test') + + # test override + self.assertEqual(config.get('hdfs', 'namenode_host'), 'localhost') + # test non-string values + self.assertEqual(config.get('hdfs', 'namenode_port'), 50030) + + def test_set(self): + config = get_config('toml') + + self.assertEqual(config.get('hdfs', 'client'), 'hadoopcli') + config.set('hdfs', 'client', 'test') + self.assertEqual(config.get('hdfs', 'client'), 'test') + config.set('hdfs', 'check', 'test me') + self.assertEqual(config.get('hdfs', 'check'), 'test me') diff --git a/test/contrib/postgres_test.py b/test/contrib/postgres_test.py index 5df6888343..eadd6d1018 100644 --- a/test/contrib/postgres_test.py +++ b/test/contrib/postgres_test.py @@ -121,3 +121,55 @@ def test_bulk_complete(self, mock_connect): 'DummyPostgresQuery_2015_01_06_f91a47ec40', ]) self.assertFalse(task.complete()) + + +@attr('postgres') +class TestCopyToTableWithMetaColumns(unittest.TestCase): + @mock.patch("luigi.contrib.postgres.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, 
return_value=True) + @mock.patch("luigi.contrib.postgres.CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.postgres.CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.postgres.CopyToTable.rows", return_value=['row1', 'row2']) + @mock.patch("luigi.contrib.postgres.PostgresTarget") + @mock.patch('psycopg2.connect') + def test_copy_with_metadata_columns_enabled(self, + mock_connect, + mock_redshift_target, + mock_rows, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + + task = DummyPostgresImporter(date=datetime.datetime(1991, 3, 24)) + + mock_cursor = MockPostgresCursor([task.task_id]) + mock_connect.return_value.cursor.return_value = mock_cursor + + task = DummyPostgresImporter(date=datetime.datetime(1991, 3, 24)) + task.run() + + self.assertTrue(mock_add_columns.called) + self.assertTrue(mock_update_columns.called) + + @mock.patch("luigi.contrib.postgres.CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False) + @mock.patch("luigi.contrib.postgres.CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.postgres.CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.postgres.CopyToTable.rows", return_value=['row1', 'row2']) + @mock.patch("luigi.contrib.postgres.PostgresTarget") + @mock.patch('psycopg2.connect') + def test_copy_with_metadata_columns_disabled(self, + mock_connect, + mock_redshift_target, + mock_rows, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + + task = DummyPostgresImporter(date=datetime.datetime(1991, 3, 24)) + + mock_cursor = MockPostgresCursor([task.task_id]) + mock_connect.return_value.cursor.return_value = mock_cursor + + task.run() + + self.assertFalse(mock_add_columns.called) + self.assertFalse(mock_update_columns.called) diff --git a/test/contrib/rdbms_test.py b/test/contrib/rdbms_test.py new file mode 100644 index 0000000000..3127cb2e8d --- /dev/null +++ b/test/contrib/rdbms_test.py @@ -0,0 +1,255 @@ 
+# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +We're using Redshift as the test bed since Redshift implements RDBMS. We could +have opted for PSQL but we're less familiar with that contrib and there are +less examples on how to test it. +""" + +import luigi +import luigi.contrib.redshift +import mock + +import unittest + + +# Fake AWS and S3 credentials taken from `../redshift_test.py`. +AWS_ACCESS_KEY = 'key' +AWS_SECRET_KEY = 'secret' + +AWS_ACCOUNT_ID = '0123456789012' +AWS_ROLE_NAME = 'MyRedshiftRole' + +BUCKET = 'bucket' +KEY = 'key' + + +class DummyS3CopyToTableBase(luigi.contrib.redshift.S3CopyToTable): + # Class attributes taken from `DummyPostgresImporter` in + # `../postgres_test.py`. 
+ host = 'dummy_host' + database = 'dummy_database' + user = 'dummy_user' + password = 'dummy_password' + table = luigi.Parameter(default='dummy_table') + columns = luigi.TupleParameter( + default=( + ('some_text', 'varchar(255)'), + ('some_int', 'int'), + ) + ) + + copy_options = '' + prune_table = '' + prune_column = '' + prune_date = '' + + def s3_load_path(self): + return 's3://%s/%s' % (BUCKET, KEY) + + +class DummyS3CopyToTableKey(DummyS3CopyToTableBase): + aws_access_key_id = AWS_ACCESS_KEY + aws_secret_access_key = AWS_SECRET_KEY + + +class TestS3CopyToTableWithMetaColumns(unittest.TestCase): + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP')]) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_check_meta_columns_to_table_if_exists(self, + mock_redshift_target, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey(table='my_test_table') + task.run() + + mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + executed_query = mock_cursor.execute.call_args_list[1][0][0] + + expected_output = "SELECT 1 AS column_exists FROM information_schema.columns " \ + "WHERE table_name = LOWER('{table}') " \ + "AND column_name = LOWER('{column}') " \ + "LIMIT 1;".format(table='my_test_table', column='created_tz') + + self.assertEqual(executed_query, expected_output) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP')]) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def 
test_copy_check_meta_columns_to_schematable_if_exists(self, + mock_redshift_target, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey(table='test.my_test_table') + task.run() + + mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + executed_query = mock_cursor.execute.call_args_list[2][0][0] + + expected_output = "SELECT 1 AS column_exists FROM information_schema.columns " \ + "WHERE table_schema = LOWER('{schema}') " \ + "AND table_name = LOWER('{table}') " \ + "AND column_name = LOWER('{column}') " \ + "LIMIT 1;".format(schema='test', table='my_test_table', column='created_tz') + + self.assertEqual(executed_query, expected_output) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_column_to_table") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_not_add_if_meta_columns_already_exists(self, + mock_redshift_target, + mock_add_to_table, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + task.run() + + self.assertFalse(mock_add_to_table.called) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_column_to_table") + 
@mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_add_if_meta_columns_not_already_exists(self, + mock_redshift_target, + mock_add_to_table, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + task.run() + + self.assertTrue(mock_add_to_table.called) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_add_regular_column(self, + mock_redshift_target, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey(table='my_test_table') + task.run() + + mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + executed_query = mock_cursor.execute.call_args_list[1][0][0] + + expected_output = "ALTER TABLE {table} " \ + "ADD COLUMN {column} {type};".format(table='my_test_table', column='created_tz', type='TIMESTAMP') + + self.assertEqual(executed_query, expected_output) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz', 'TIMESTAMP', 'bytedict')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_add_encoded_column(self, + mock_redshift_target, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey(table='my_test_table') + task.run() + + 
mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + executed_query = mock_cursor.execute.call_args_list[1][0][0] + + expected_output = "ALTER TABLE {table} " \ + "ADD COLUMN {column} {type} ENCODE {encoding};".format(table='my_test_table', column='created_tz', + type='TIMESTAMP', + encoding='bytedict') + + self.assertEqual(executed_query, expected_output) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, return_value=[('created_tz')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_raise_error_on_no_column_type(self, + mock_redshift_target, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + + with self.assertRaises(ValueError): + task.run() + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_columns", new_callable=mock.PropertyMock, + return_value=[('created_tz', 'TIMESTAMP', 'bytedict', '42')]) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._column_exists", return_value=False) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_raise_error_on_invalid_column(self, + mock_redshift_target, + mock_columns_exists, + mock_metadata_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + + with self.assertRaises(ValueError): + task.run() + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable.metadata_queries", new_callable=mock.PropertyMock, 
return_value=['SELECT 1 FROM X', 'SELECT 2 FROM Y']) + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_post_copy_metacolumns(self, + mock_redshift_target, + mock_metadata_queries, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + task.run() + + mock_cursor = (mock_redshift_target.return_value + .connect + .return_value + .cursor + .return_value) + + executed_query = mock_cursor.execute.call_args_list[2][0][0] + expected_output = "SELECT 1 FROM X" + self.assertEqual(executed_query, expected_output) + + executed_query = mock_cursor.execute.call_args_list[3][0][0] + expected_output = "SELECT 2 FROM Y" + self.assertEqual(executed_query, expected_output) diff --git a/test/contrib/redshift_test.py b/test/contrib/redshift_test.py index c6b23bf2b1..5433c6d186 100644 --- a/test/contrib/redshift_test.py +++ b/test/contrib/redshift_test.py @@ -80,6 +80,36 @@ def s3_load_path(self): return 's3://%s/%s' % (BUCKET, KEY) +class DummyS3CopyJSONToTableBase(luigi.contrib.redshift.S3CopyJSONToTable): + # Class attributes taken from `DummyPostgresImporter` in + # `../postgres_test.py`. 
+ aws_access_key_id = AWS_ACCESS_KEY + aws_secret_access_key = AWS_SECRET_KEY + + host = 'dummy_host' + database = 'dummy_database' + user = 'dummy_user' + password = 'dummy_password' + table = luigi.Parameter(default='dummy_table') + columns = luigi.TupleParameter( + default=( + ('some_text', 'varchar(255)'), + ('some_int', 'int'), + ) + ) + + copy_options = '' + prune_table = '' + prune_column = '' + prune_date = '' + + jsonpath = '' + copy_json_options = '' + + def s3_load_path(self): + return 's3://%s/%s' % (BUCKET, KEY) + + class DummyS3CopyToTableKey(DummyS3CopyToTableBase): aws_access_key_id = AWS_ACCESS_KEY aws_secret_access_key = AWS_SECRET_KEY @@ -130,6 +160,68 @@ def test_from_config(self): self.assertEqual(self.aws_secret_access_key, "config_secret") +class TestS3CopyToTableWithMetaColumns(unittest.TestCase): + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_with_metadata_columns_enabled(self, + mock_redshift_target, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + task.run() + + self.assertTrue(mock_add_columns.called) + self.assertTrue(mock_update_columns.called) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_copy_with_metadata_columns_disabled(self, + mock_redshift_target, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyToTableKey() + task.run() + 
+ self.assertFalse(mock_add_columns.called) + self.assertFalse(mock_update_columns.called) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=True) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_json_copy_with_metadata_columns_enabled(self, + mock_redshift_target, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyJSONToTableBase() + task.run() + + self.assertTrue(mock_add_columns.called) + self.assertTrue(mock_update_columns.called) + + @mock.patch("luigi.contrib.redshift.S3CopyToTable.enable_metadata_columns", new_callable=mock.PropertyMock, return_value=False) + @mock.patch("luigi.contrib.redshift.S3CopyToTable._add_metadata_columns") + @mock.patch("luigi.contrib.redshift.S3CopyToTable.post_copy_metacolumns") + @mock.patch("luigi.contrib.redshift.RedshiftTarget") + def test_json_copy_with_metadata_columns_disabled(self, + mock_redshift_target, + mock_add_columns, + mock_update_columns, + mock_metadata_columns_enabled): + task = DummyS3CopyJSONToTableBase() + task.run() + + self.assertFalse(mock_add_columns.called) + self.assertFalse(mock_update_columns.called) + + class TestS3CopyToTable(unittest.TestCase): @mock.patch("luigi.contrib.redshift.RedshiftTarget") def test_copy_missing_creds(self, mock_redshift_target): diff --git a/test/contrib/s3_test.py b/test/contrib/s3_test.py index 93b24c4c69..97ea5cfc2e 100644 --- a/test/contrib/s3_test.py +++ b/test/contrib/s3_test.py @@ -41,6 +41,13 @@ AWS_SECRET_KEY = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" +def create_bucket(): + conn = boto3.resource('s3', region_name='us-east-1') + # We need to create the bucket since this is all in Moto's 'virtual' AWS account + conn.create_bucket(Bucket='mybucket') + return conn + + class 
TestS3Target(unittest.TestCase, FileSystemTargetTestMixin): def setUp(self): @@ -57,20 +64,14 @@ def setUp(self): self.mock_s3.start() self.addCleanup(self.mock_s3.stop) - def create_bucket(self): - conn = boto3.resource('s3', region_name='us-east-1') - # We need to create the bucket since this is all in Moto's 'virtual' AWS account - conn.create_bucket(Bucket='mybucket') - return conn - def create_target(self, format=None, **kwargs): client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) - self.create_bucket() + create_bucket() return S3Target('s3://mybucket/test_file', client=client, format=format, **kwargs) def test_read(self): client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) - self.create_bucket() + create_bucket() client.put(self.tempFilePath, 's3://mybucket/tempfile') t = S3Target('s3://mybucket/tempfile', client=client) read_file = t.open() @@ -99,7 +100,7 @@ def test_read_iterator_long(self): tempf.close() client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) - self.create_bucket() + create_bucket() client.put(temppath, 's3://mybucket/largetempfile') t = S3Target('s3://mybucket/largetempfile', client=client) with t.open() as read_file: @@ -167,33 +168,27 @@ def test_init_with_config_and_roles(self, sts_mock, s3_mock): sts_mock.client.assume_role.called_with( RoleArn='role', RoleSessionName='name') - def create_bucket(self): - conn = boto3.resource('s3', region_name='us-east-1') - # We need to create the bucket since this is all in Moto's 'virtual' AWS account - conn.create_bucket(Bucket='mybucket') - return conn - def test_put(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put(self.tempFilePath, 's3://mybucket/putMe') self.assertTrue(s3_client.exists('s3://mybucket/putMe')) def test_put_sse_deprecated(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) with self.assertRaises(DeprecatedBotoClientException): s3_client.put(self.tempFilePath, 
's3://mybucket/putMe', encrypt_key=True) def test_put_string(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_string("SOMESTRING", 's3://mybucket/putString') self.assertTrue(s3_client.exists('s3://mybucket/putString')) def test_put_string_sse_deprecated(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) with self.assertRaises(DeprecatedBotoClientException): s3_client.put('SOMESTRING', @@ -243,7 +238,7 @@ def test_put_multipart_less_than_split_size(self): self._run_multipart_test(part_size, file_size) def test_exists(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) self.assertTrue(s3_client.exists('s3://mybucket/')) @@ -266,7 +261,7 @@ def test_exists(self): self.assertFalse(s3_client.exists('s3://mybucket/tempdir')) def test_get(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put(self.tempFilePath, 's3://mybucket/putMe') @@ -280,7 +275,7 @@ def test_get(self): tmp_file.close() def test_get_as_string(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put(self.tempFilePath, 's3://mybucket/putMe') @@ -289,14 +284,14 @@ def test_get_as_string(self): self.assertEquals(contents, self.tempFileContents.decode("utf-8")) def test_get_key(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put(self.tempFilePath, 's3://mybucket/key_to_find') self.assertTrue(s3_client.get_key('s3://mybucket/key_to_find').key) self.assertFalse(s3_client.get_key('s3://mybucket/does_not_exist')) def test_isdir(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) self.assertTrue(s3_client.isdir('s3://mybucket')) @@ -310,7 +305,7 @@ def test_isdir(self): self.assertFalse(s3_client.isdir('s3://mybucket/key')) 
def test_mkdir(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) self.assertTrue(s3_client.isdir('s3://mybucket')) s3_client.mkdir('s3://mybucket') @@ -324,7 +319,7 @@ def test_mkdir(self): self.assertFalse(s3_client.isdir('s3://mybucket/dir/foo/bar')) def test_listdir(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_string("", 's3://mybucket/hello/frank') @@ -334,7 +329,7 @@ def test_listdir(self): list(s3_client.listdir('s3://mybucket/hello'))) def test_list(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_string("", 's3://mybucket/hello/frank') @@ -344,7 +339,7 @@ def test_list(self): list(s3_client.list('s3://mybucket/hello'))) def test_listdir_key(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_string("", 's3://mybucket/hello/frank') @@ -354,7 +349,7 @@ def test_listdir_key(self): [s3_client.exists('s3://' + x.bucket_name + '/' + x.key) for x in s3_client.listdir('s3://mybucket/hello', return_key=True)]) def test_list_key(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) s3_client.put_string("", 's3://mybucket/hello/frank') @@ -364,7 +359,7 @@ def test_list_key(self): [s3_client.exists('s3://' + x.bucket_name + '/' + x.key) for x in s3_client.listdir('s3://mybucket/hello', return_key=True)]) def test_remove(self): - self.create_bucket() + create_bucket() s3_client = S3Client(AWS_ACCESS_KEY, AWS_SECRET_KEY) self.assertRaises( @@ -436,7 +431,7 @@ def test_copy_dir(self): """ Test copying 20 files from one folder to another """ - self.create_bucket() + create_bucket() n = 20 copy_part_size = (1024 ** 2) * 5 @@ -468,7 +463,7 @@ def test_copy_dir(self): @mock_s3 def _run_multipart_copy_test(self, put_method): - self.create_bucket() + create_bucket() # Run the method to 
put the file into s3 into the first place put_method() @@ -493,7 +488,7 @@ def _run_multipart_copy_test(self, put_method): @mock_s3 def _run_copy_test(self, put_method): - self.create_bucket() + create_bucket() # Run the method to put the file into s3 into the first place put_method() @@ -514,7 +509,7 @@ def _run_copy_test(self, put_method): @mock_s3 def _run_multipart_test(self, part_size, file_size, **kwargs): - self.create_bucket() + create_bucket() file_contents = b"a" * file_size s3_path = 's3://mybucket/putMe' diff --git a/test/db_task_history_test.py b/test/db_task_history_test.py index 8b162d282e..d302bed292 100644 --- a/test/db_task_history_test.py +++ b/test/db_task_history_test.py @@ -24,6 +24,7 @@ from luigi.db_task_history import DbTaskHistory from luigi.task_status import DONE, PENDING, RUNNING import luigi.scheduler +from luigi.parameter import ParameterVisibility class DummyTask(luigi.Task): @@ -32,7 +33,8 @@ class DummyTask(luigi.Task): class ParamTask(luigi.Task): param1 = luigi.Parameter() - param2 = luigi.IntParameter() + param2 = luigi.IntParameter(visibility=ParameterVisibility.HIDDEN) + param3 = luigi.Parameter(default="empty", visibility=ParameterVisibility.PRIVATE) class DbTaskHistoryTest(unittest.TestCase): diff --git a/test/decorator_test.py b/test/decorator_test.py index 0e113caaf6..e9851a269e 100644 --- a/test/decorator_test.py +++ b/test/decorator_test.py @@ -53,9 +53,14 @@ class D_null(luigi.Task): param1 = None +@inherits(A, B) +class E(luigi.Task): + param4 = luigi.Parameter("class E-specific default") + + @inherits(A) @inherits(B) -class E(luigi.Task): +class E_stacked(luigi.Task): param4 = luigi.Parameter("class E-specific default") @@ -69,6 +74,7 @@ def setUp(self): self.d = D() self.d_null = D_null() self.e = E() + self.e_stacked = E_stacked() def test_has_param(self): b_params = dict(self.b.get_params()).keys() @@ -91,11 +97,22 @@ def test_overwriting_defaults(self): self.assertNotEqual(self.d.param1, self.a.param1) 
self.assertEqual(self.d.param1, "class D overwriting class A's default") - def test_stacked_inheritance(self): + def test_multiple_inheritance(self): self.assertEqual(self.e.param1, self.a.param1) self.assertEqual(self.e.param1, self.b.param1) self.assertEqual(self.e.param2, self.b.param2) + def test_stacked_inheritance(self): + self.assertEqual(self.e_stacked.param1, self.a.param1) + self.assertEqual(self.e_stacked.param1, self.b.param1) + self.assertEqual(self.e_stacked.param2, self.b.param2) + + def test_empty_inheritance(self): + with self.assertRaises(TypeError): + @inherits() + class shouldfail(luigi.Task): + pass + def test_removing_parameter(self): self.assertFalse("param1" in dict(self.d_null.get_params()).keys()) @@ -226,53 +243,75 @@ def test_wrong_common_params_order(self): self.assertRaises(TypeError, self.k_wrongparamsorder.requires) -class X(luigi.Task): +class V(luigi.Task): n = luigi.IntParameter(default=42) -@inherits(X) -class Y(luigi.Task): +@inherits(V) +class W(luigi.Task): def requires(self): return self.clone_parent() -@requires(X) -class Y2(luigi.Task): +@requires(V) +class W2(luigi.Task): pass -@requires(X) -class Y3(luigi.Task): +@requires(V) +class W3(luigi.Task): n = luigi.IntParameter(default=43) +class X(luigi.Task): + m = luigi.IntParameter(default=56) + + +@requires(V, X) +class Y(luigi.Task): + pass + + class CloneParentTest(unittest.TestCase): def test_clone_parent(self): - y = Y() - x = X() - self.assertEqual(y.requires(), x) - self.assertEqual(y.n, 42) + w = W() + v = V() + self.assertEqual(w.requires(), v) + self.assertEqual(w.n, 42) def test_requires(self): - y2 = Y2() - x = X() - self.assertEqual(y2.requires(), x) - self.assertEqual(y2.n, 42) + w2 = W2() + v = V() + self.assertEqual(w2.requires(), v) + self.assertEqual(w2.n, 42) def test_requires_override_default(self): - y3 = Y3() + w3 = W3() + v = V() + self.assertNotEqual(w3.requires(), v) + self.assertEqual(w3.n, 43) + self.assertEqual(w3.requires().n, 43) + + def 
test_multiple_requires(self): + y = Y() + v = V() x = X() - self.assertNotEqual(y3.requires(), x) - self.assertEqual(y3.n, 43) - self.assertEqual(y3.requires().n, 43) + self.assertEqual(y.requires()[0], v) + self.assertEqual(y.requires()[1], x) + + def test_empty_requires(self): + with self.assertRaises(TypeError): + @requires() + class shouldfail(luigi.Task): + pass def test_names(self): # Just make sure the decorators retain the original class names - x = X() - self.assertEqual(str(x), 'X(n=42)') - self.assertEqual(x.__class__.__name__, 'X') + v = V() + self.assertEqual(str(v), 'V(n=42)') + self.assertEqual(v.__class__.__name__, 'V') class P(luigi.Task): diff --git a/test/rpc_test.py b/test/rpc_test.py index 044e1c14fe..cfb55a1ca1 100644 --- a/test/rpc_test.py +++ b/test/rpc_test.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from helpers import unittest +from helpers import unittest, with_config try: from unittest import mock except ImportError: @@ -52,7 +52,7 @@ def _wait(self): scheduler = ShorterWaitRemoteScheduler('http://zorg.com', 42) with mock.patch.object(scheduler, '_fetcher') as fetcher: - fetcher.raises = socket.timeout + fetcher.raises = socket.timeout, socket.gaierror fetcher.fetch.side_effect = fetcher_side_effect return scheduler.get_work("fake_worker") @@ -72,6 +72,36 @@ def test_retry_rpc_limited(self): fetch_results = [socket.timeout, socket.timeout, socket.timeout] self.assertRaises(luigi.rpc.RPCError, self.get_work, fetch_results) + @mock.patch('luigi.rpc.logger') + def test_log_rpc_retries_enabled(self, mock_logger): + """ + Tests that each retry of an RPC method is logged + """ + + fetch_results = [socket.timeout, socket.timeout, '{"response":{}}'] + self.get_work(fetch_results) + self.assertEqual([ + mock.call.warning('Failed connecting to remote scheduler %r', 'http://zorg.com', exc_info=True), + mock.call.info('Retrying attempt 2 of 3 (max)'), + 
mock.call.warning('Failed connecting to remote scheduler %r', 'http://zorg.com', exc_info=True), + mock.call.info('Retrying attempt 3 of 3 (max)'), + ], mock_logger.mock_calls) + + @with_config({'core': {'rpc-log-retries': 'false'}}) + @mock.patch('luigi.rpc.logger') + def test_log_rpc_retries_disabled(self, mock_logger): + """ + Tests that retries of an RPC method are not logged + """ + + fetch_results = [socket.timeout, socket.timeout, socket.gaierror] + try: + self.get_work(fetch_results) + self.fail("get_work should have thrown RPCError") + except luigi.rpc.RPCError as e: + self.assertTrue(isinstance(e.sub_exception, socket.gaierror)) + self.assertEqual([], mock_logger.mock_calls) + def test_get_work_retries_on_null(self): """ Tests that get_work will retry if the response is null diff --git a/test/scheduler_parameter_visibilities_test.py b/test/scheduler_parameter_visibilities_test.py new file mode 100644 index 0000000000..b3cae1f579 --- /dev/null +++ b/test/scheduler_parameter_visibilities_test.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from helpers import LuigiTestCase, RunOnceTask +import server_test + +import luigi +import luigi.scheduler +import luigi.worker +from luigi.parameter import ParameterVisibility +import json +import time + + +class SchedulerParameterVisibilitiesTest(LuigiTestCase): + def test_task_with_deps(self): + s = luigi.scheduler.Scheduler(send_messages=True) + with luigi.worker.Worker(scheduler=s) as w: + class DynamicTask(RunOnceTask): + dynamic_public = luigi.Parameter(default="dynamic_public") + dynamic_hidden = luigi.Parameter(default="dynamic_hidden", visibility=ParameterVisibility.HIDDEN) + dynamic_private = luigi.Parameter(default="dynamic_private", visibility=ParameterVisibility.PRIVATE) + + class RequiredTask(RunOnceTask): + required_public = luigi.Parameter(default="required_param") + required_hidden = luigi.Parameter(default="required_hidden", visibility=ParameterVisibility.HIDDEN) + required_private = luigi.Parameter(default="required_private", visibility=ParameterVisibility.PRIVATE) + + class Task(RunOnceTask): + a = luigi.Parameter(default="a") + b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) + c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) + d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) + + def requires(self): + return required_task + + def run(self): + yield dynamic_task + + dynamic_task = DynamicTask() + required_task = RequiredTask() + task = Task() + + w.add(task) + w.run() + + time.sleep(1) + task_deps = s.dep_graph(task_id=task.task_id) + required_task_deps = s.dep_graph(task_id=required_task.task_id) + dynamic_task_deps = s.dep_graph(task_id=dynamic_task.task_id) + + self.assertEqual('Task(a=a, d=d)', task_deps[task.task_id]['display_name']) + self.assertEqual('RequiredTask(required_public=required_param)', + required_task_deps[required_task.task_id]['display_name']) + self.assertEqual('DynamicTask(dynamic_public=dynamic_public)', + 
dynamic_task_deps[dynamic_task.task_id]['display_name']) + + self.assertEqual({'a': 'a', 'd': 'd'}, task_deps[task.task_id]['params']) + self.assertEqual({'required_public': 'required_param'}, + required_task_deps[required_task.task_id]['params']) + self.assertEqual({'dynamic_public': 'dynamic_public'}, + dynamic_task_deps[dynamic_task.task_id]['params']) + + def test_public_and_hidden_params(self): + s = luigi.scheduler.Scheduler(send_messages=True) + with luigi.worker.Worker(scheduler=s) as w: + class Task(RunOnceTask): + a = luigi.Parameter(default="a") + b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) + c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) + d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) + + task = Task() + + w.add(task) + w.run() + + time.sleep(1) + t = s._state.get_task(task.task_id) + self.assertEqual({'b': 'b'}, t.hidden_params) + self.assertEqual({'a': 'a', 'd': 'd'}, t.public_params) + self.assertEqual({'a': 0, 'b': 1, 'd': 0}, t.param_visibilities) + + +class Task(RunOnceTask): + a = luigi.Parameter(default="a") + b = luigi.Parameter(default="b", visibility=ParameterVisibility.HIDDEN) + c = luigi.Parameter(default="c", visibility=ParameterVisibility.PRIVATE) + d = luigi.Parameter(default="d", visibility=ParameterVisibility.PUBLIC) + + +class RemoteSchedulerParameterVisibilitiesTest(server_test.ServerTestBase): + def test_public_params(self): + task = Task() + luigi.build(tasks=[task], workers=2, scheduler_port=self.get_http_port()) + + time.sleep(1) + + response = self.fetch('/api/graph') + + body = response.body + decoded = body.decode('utf8').replace("'", '"') + data = json.loads(decoded) + + self.assertEqual({'a': 'a', 'd': 'd'}, data['response'][task.task_id]['params']) diff --git a/test/scheduler_visualisation_test.py b/test/scheduler_visualisation_test.py index 1d35f69ffc..4edb668e51 100644 --- a/test/scheduler_visualisation_test.py +++ 
b/test/scheduler_visualisation_test.py @@ -19,7 +19,7 @@ import os import tempfile import time -from helpers import unittest +from helpers import unittest, RunOnceTask import luigi import luigi.notifications @@ -33,7 +33,7 @@ class DummyTask(luigi.Task): - task_id = luigi.Parameter() + task_id = luigi.IntParameter() def run(self): f = self.output().open('w') @@ -44,7 +44,7 @@ def output(self): class FactorTask(luigi.Task): - product = luigi.Parameter() + product = luigi.IntParameter() def requires(self): for factor in range(2, self.product): @@ -77,7 +77,10 @@ def complete(self): class FailingTask(luigi.Task): task_namespace = __name__ - task_id = luigi.Parameter() + task_id = luigi.IntParameter() + + def complete(self): + return False def run(self): raise Exception("Error Message") @@ -100,7 +103,6 @@ def run(self): class SchedulerVisualisationTest(unittest.TestCase): - def setUp(self): self.scheduler = luigi.scheduler.Scheduler() @@ -190,7 +192,7 @@ def complete(self): six.assertCountEqual(self, expected_nodes, graph) def test_truncate_graph_with_full_levels(self): - class BinaryTreeTask(luigi.Task): + class BinaryTreeTask(RunOnceTask): idx = luigi.IntParameter() def requires(self): @@ -226,7 +228,7 @@ def complete(self): graph = self.scheduler.dep_graph(root_task.task_id) self.assertEqual(10, len(graph)) - expected_nodes = [LinearTask(i).task_id for i in range(100, 91, -1)] +\ + expected_nodes = [LinearTask(i).task_id for i in range(100, 91, -1)] + \ [LinearTask(0).task_id] self.maxDiff = None six.assertCountEqual(self, expected_nodes, graph) @@ -387,30 +389,29 @@ def test_task_list_failed(self): def test_task_list_upstream_status(self): class A(luigi.ExternalTask): - pass + def complete(self): + return False class B(luigi.ExternalTask): - def complete(self): return True - class C(luigi.Task): - + class C(RunOnceTask): def requires(self): return [A(), B()] class F(luigi.Task): + def complete(self): + return False def run(self): raise Exception() - class 
D(luigi.Task): - + class D(RunOnceTask): def requires(self): return [F()] - class E(luigi.Task): - + class E(RunOnceTask): def requires(self): return [C(), D()] @@ -478,22 +479,20 @@ def test_fetch_error(self): self.assertTrue("Traceback" in error["error"]) def test_inverse_deps(self): - class X(luigi.Task): + class X(RunOnceTask): pass - class Y(luigi.Task): - + class Y(RunOnceTask): def requires(self): return [X()] - class Z(luigi.Task): - id = luigi.Parameter() + class Z(RunOnceTask): + id = luigi.IntParameter() def requires(self): return [Y()] - class ZZ(luigi.Task): - + class ZZ(RunOnceTask): def requires(self): return [Z(1), Z(2)] @@ -513,7 +512,6 @@ def assert_has_deps(task_id, deps): def test_simple_worker_list(self): class X(luigi.Task): - def run(self): self._complete = True @@ -536,12 +534,10 @@ def complete(self): def test_worker_list_pending_uniques(self): class X(luigi.Task): - def complete(self): return False class Y(X): - def requires(self): return X() @@ -562,7 +558,7 @@ class Z(Y): self.assertEqual(0, worker['num_running']) def test_worker_list_running(self): - class X(luigi.Task): + class X(RunOnceTask): n = luigi.IntParameter() w = luigi.worker.Worker(worker_id='w', scheduler=self.scheduler, worker_processes=3) @@ -584,7 +580,7 @@ class X(luigi.Task): self.assertEqual(1, worker['num_uniques']) def test_worker_list_disabled_worker(self): - class X(luigi.Task): + class X(RunOnceTask): pass with luigi.worker.Worker(worker_id='w', scheduler=self.scheduler) as w: diff --git a/test/task_forwarded_attributes_test.py b/test/task_forwarded_attributes_test.py new file mode 100644 index 0000000000..48ef319136 --- /dev/null +++ b/test/task_forwarded_attributes_test.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2012-2015 Spotify AB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from helpers import LuigiTestCase, RunOnceTask + +import luigi +import luigi.scheduler +import luigi.worker + + +FORWARDED_ATTRIBUTES = set(luigi.worker.TaskProcess.forward_reporter_attributes.values()) + + +class NonYieldingTask(RunOnceTask): + + # need to accept messages in order for the "scheduler_message" attribute to be not None + accepts_messages = True + + def gather_forwarded_attributes(self): + """ + Returns a set of names of attributes that are forwarded by the TaskProcess and that are not + *None*. The tests in this file check if and which attributes are present at different times, + e.g. while running, or before and after a dynamic dependency was yielded. 
+ """ + attrs = set() + for attr in FORWARDED_ATTRIBUTES: + if getattr(self, attr, None) is not None: + attrs.add(attr) + return attrs + + def run(self): + # store names of forwarded attributes which are only available within the run method + self.attributes_while_running = self.gather_forwarded_attributes() + + # invoke the run method of the RunOnceTask which marks this task as complete + RunOnceTask.run(self) + + +class YieldingTask(NonYieldingTask): + + def run(self): + # as TaskProcess._run_get_new_deps handles generators in a specific way, store names of + # forwarded attributes before and after yielding a dynamic dependency, so we can explicitly + # validate the attribute forwarding implementation + self.attributes_before_yield = self.gather_forwarded_attributes() + yield RunOnceTask() + self.attributes_after_yield = self.gather_forwarded_attributes() + + # invoke the run method of the RunOnceTask which marks this task as complete + RunOnceTask.run(self) + + +class TaskForwardedAttributesTest(LuigiTestCase): + + def run_task(self, task): + sch = luigi.scheduler.Scheduler() + with luigi.worker.Worker(scheduler=sch) as w: + w.add(task) + w.run() + return task + + def test_non_yielding_task(self): + task = self.run_task(NonYieldingTask()) + + self.assertEqual(task.attributes_while_running, FORWARDED_ATTRIBUTES) + + def test_yielding_task(self): + task = self.run_task(YieldingTask()) + + self.assertEqual(task.attributes_before_yield, FORWARDED_ATTRIBUTES) + self.assertEqual(task.attributes_after_yield, FORWARDED_ATTRIBUTES) diff --git a/test/testconfig/luigi.toml b/test/testconfig/luigi.toml new file mode 100644 index 0000000000..6c8e3409a3 --- /dev/null +++ b/test/testconfig/luigi.toml @@ -0,0 +1,7 @@ +[core] +logging_conf_file = "test/testconfig/logging.cfg" + +[hdfs] +client = "hadoopcli" +snakebite_autoconfig = false +namenode_host = "must be overridden in local config" diff --git a/test/testconfig/luigi_local.toml b/test/testconfig/luigi_local.toml new file
mode 100644 index 0000000000..21330c1040 --- /dev/null +++ b/test/testconfig/luigi_local.toml @@ -0,0 +1,3 @@ +[hdfs] +namenode_host = "localhost" +namenode_port = 50030 diff --git a/test/visible_parameters_test.py b/test/visible_parameters_test.py new file mode 100644 index 0000000000..e644aa7cb0 --- /dev/null +++ b/test/visible_parameters_test.py @@ -0,0 +1,95 @@ +import luigi +from luigi.parameter import ParameterVisibility +from helpers import unittest +import json + + +class TestTask1(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.HIDDEN, significant=True) + param_two = luigi.Parameter(default='2', significant=True) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.PRIVATE, significant=True) + + +class TestTask2(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.PRIVATE) + param_two = luigi.Parameter(default='2', visibility=ParameterVisibility.PRIVATE) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.PRIVATE) + + +class TestTask3(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.HIDDEN, significant=True) + param_two = luigi.Parameter(default='2', visibility=ParameterVisibility.HIDDEN, significant=False) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.HIDDEN, significant=True) + + +class TestTask4(luigi.Task): + param_one = luigi.Parameter(default='1', visibility=ParameterVisibility.PUBLIC, significant=True) + param_two = luigi.Parameter(default='2', visibility=ParameterVisibility.PUBLIC, significant=False) + param_three = luigi.Parameter(default='3', visibility=ParameterVisibility.PUBLIC, significant=True) + + +class Test(unittest.TestCase): + def test_to_str_params(self): + task = TestTask1() + + self.assertEqual(task.to_str_params(), {'param_one': '1', 'param_two': '2'}) + + task = TestTask2() + + self.assertEqual(task.to_str_params(), {}) + + task = TestTask3() + + 
self.assertEqual(task.to_str_params(), {'param_one': '1', 'param_two': '2', 'param_three': '3'}) + + def test_all_public_equals_all_hidden(self): + hidden = TestTask3() + public = TestTask4() + + self.assertEqual(public.to_str_params(), hidden.to_str_params()) + + def test_all_public_equals_all_hidden_using_significant(self): + hidden = TestTask3() + public = TestTask4() + + self.assertEqual(public.to_str_params(only_significant=True), hidden.to_str_params(only_significant=True)) + + def test_private_params_and_significant(self): + task = TestTask1() + + self.assertEqual(task.to_str_params(), task.to_str_params(only_significant=True)) + + def test_param_visibilities(self): + task = TestTask1() + + self.assertEqual(task._get_param_visibilities(), {'param_one': 1, 'param_two': 0}) + + def test_incorrect_visibility_value(self): + class Task(luigi.Task): + a = luigi.Parameter(default='val', visibility=5) + + task = Task() + + self.assertEqual(task._get_param_visibilities(), {'a': 0}) + + def test_task_id_exclude_hidden_and_private_params(self): + task = TestTask1() + + self.assertEqual({'param_two': '2'}, task.to_str_params(only_public=True)) + + def test_json_dumps(self): + public = json.dumps(ParameterVisibility.PUBLIC.serialize()) + hidden = json.dumps(ParameterVisibility.HIDDEN.serialize()) + private = json.dumps(ParameterVisibility.PRIVATE.serialize()) + + self.assertEqual('0', public) + self.assertEqual('1', hidden) + self.assertEqual('2', private) + + public = json.loads(public) + hidden = json.loads(hidden) + private = json.loads(private) + + self.assertEqual(0, public) + self.assertEqual(1, hidden) + self.assertEqual(2, private) diff --git a/tox.ini b/tox.ini index 48f3d6bb0c..7ee65ee9df 100644 --- a/tox.ini +++ b/tox.ini @@ -35,6 +35,7 @@ deps= hypothesis[datetime] selenium==3.0.2 pymongo==3.4.0 + toml<2.0.0 passenv = USER JAVA_HOME POSTGRES_USER DATAPROC_TEST_PROJECT_ID GCS_TEST_PROJECT_ID GCS_TEST_BUCKET GOOGLE_APPLICATION_CREDENTIALS TRAVIS_BUILD_ID TRAVIS 
TRAVIS_BRANCH TRAVIS_JOB_NUMBER TRAVIS_PULL_REQUEST TRAVIS_JOB_ID TRAVIS_REPO_SLUG TRAVIS_COMMIT CI setenv = @@ -109,6 +110,7 @@ deps = boto3 Sphinx>=1.4.4,<1.5 sphinx_rtd_theme + enum34>1.1.0 commands = # build API docs sphinx-apidoc -o doc/api -T luigi --separate