Skip to content

Commit

Permalink
Improve optional parameters to allow any parameter type to be optional
Browse files Browse the repository at this point in the history
  • Loading branch information
adrien-berchet committed Oct 21, 2021
1 parent ad5ddc9 commit 3cef59d
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 16 deletions.
10 changes: 8 additions & 2 deletions luigi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
DateIntervalParameter, TimeDeltaParameter,
IntParameter, FloatParameter, BoolParameter,
TaskParameter, EnumParameter, DictParameter, ListParameter, TupleParameter,
NumericalParameter, ChoiceParameter, OptionalParameter
NumericalParameter, ChoiceParameter, OptionalParameter, OptionalStrParameter,
OptionalIntParameter, OptionalFloatParameter, OptionalBoolParameter,
OptionalDictParameter, OptionalListParameter, OptionalTupleParameter,
OptionalChoiceParameter, OptionalNumericalParameter,
)

from luigi import configuration
Expand All @@ -60,7 +63,10 @@
'FloatParameter', 'BoolParameter', 'TaskParameter',
'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter',
'configuration', 'interface', 'local_target', 'run', 'build', 'event', 'Event',
'NumericalParameter', 'ChoiceParameter', 'OptionalParameter', 'LuigiStatusCode',
'NumericalParameter', 'ChoiceParameter', 'OptionalParameter',
'OptionalIntParameter', 'OptionalFloatParameter', 'OptionalBoolParameter',
'OptionalDictParameter', 'OptionalListParameter', 'OptionalTupleParameter',
'OptionalChoiceParameter', 'OptionalNumericalParameter', 'LuigiStatusCode',
'__version__',
]

Expand Down
115 changes: 106 additions & 9 deletions luigi/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ class DuplicateParameterException(ParameterException):
pass


class OptionalParameterTypeWarning(UserWarning):
"""
Warning class for OptionalParameterMixin with wrong type.
"""
pass


class Parameter:
"""
Parameter whose value is a ``str``, and a base class for other parameter types.
Expand Down Expand Up @@ -323,24 +330,62 @@ def _parser_kwargs(cls, param_name, task_name=None):
}


class OptionalParameter(Parameter):
""" A Parameter that treats empty string as None """
class OptionalParameterMixin:
"""
Mixin to make a parameter class optional and treat empty string as None.
"""

expected_type = str

def serialize(self, x):
"""
Parse the given value if the value is not None else return an empty string.
"""
if x is None:
return ''
else:
return str(x)
return super().serialize(x)

def parse(self, x):
return x or None
"""
Parse the given value if it is a string (empty strings are parsed to None).
"""
if not isinstance(x, str):
return x
elif x:
return super().parse(x)
else:
return None

def normalize(self, x):
"""
Normalize the given value if it is not None.
"""
if x is None:
return None
return super().normalize(x)

def _warn_on_wrong_param_type(self, param_name, param_value):
if self.__class__ != OptionalParameter:
return
if not isinstance(param_value, str) and param_value is not None:
warnings.warn('OptionalParameter "{}" with value "{}" is not of type string or None.'.format(
param_name, param_value))
if not isinstance(param_value, self.expected_type) and param_value is not None:
warnings.warn(
(
f'{self.__class__.__name__} "{param_name}" with value '
f'"{param_value}" is not of type "{self.expected_type.__name__}" or None.'
),
OptionalParameterTypeWarning,
)


class OptionalParameter(OptionalParameterMixin, Parameter):
"""Class to parse optional parameters."""

expected_type = str


class OptionalStrParameter(OptionalParameterMixin, Parameter):
"""Class to parse optional str parameters."""

expected_type = str


_UNIX_EPOCH = datetime.datetime.utcfromtimestamp(0)
Expand Down Expand Up @@ -627,6 +672,12 @@ def next_in_enumeration(self, value):
return value + 1


class OptionalIntParameter(OptionalParameterMixin, IntParameter):
"""Class to parse optional int parameters."""

expected_type = int


class FloatParameter(Parameter):
"""
Parameter whose value is a ``float``.
Expand All @@ -639,6 +690,12 @@ def parse(self, s):
return float(s)


class OptionalFloatParameter(OptionalParameterMixin, FloatParameter):
"""Class to parse optional float parameters."""

expected_type = float


class BoolParameter(Parameter):
"""
A Parameter whose value is a ``bool``. This parameter has an implicit default value of
Expand Down Expand Up @@ -709,6 +766,12 @@ def _parser_kwargs(self, *args, **kwargs):
return parser_kwargs


class OptionalBoolParameter(OptionalParameterMixin, BoolParameter):
"""Class to parse optional bool parameters."""

expected_type = bool


class DateIntervalParameter(Parameter):
"""
A Parameter whose value is a :py:class:`~luigi.date_interval.DateInterval`.
Expand Down Expand Up @@ -1007,6 +1070,12 @@ def serialize(self, x):
return json.dumps(x, cls=_DictParamEncoder)


class OptionalDictParameter(OptionalParameterMixin, DictParameter):
"""Class to parse optional dict parameters."""

expected_type = FrozenOrderedDict


class ListParameter(Parameter):
"""
Parameter whose value is a ``list``.
Expand Down Expand Up @@ -1070,6 +1139,12 @@ def serialize(self, x):
return json.dumps(x, cls=_DictParamEncoder)


class OptionalListParameter(OptionalParameterMixin, ListParameter):
"""Class to parse optional list parameters."""

expected_type = tuple


class TupleParameter(ListParameter):
"""
Parameter whose value is a ``tuple`` or ``tuple`` of tuples.
Expand Down Expand Up @@ -1125,6 +1200,12 @@ def parse(self, x):
return tuple(literal_eval(x)) # if this causes an error, let that error be raised.


class OptionalTupleParameter(OptionalParameterMixin, TupleParameter):
"""Class to parse optional tuple parameters."""

expected_type = tuple


class NumericalParameter(Parameter):
"""
Parameter whose value is a number of the specified type, e.g. ``int`` or
Expand Down Expand Up @@ -1201,6 +1282,14 @@ def parse(self, s):
s=s, permitted_range=self._permitted_range))


class OptionalNumericalParameter(OptionalParameterMixin, NumericalParameter):
"""Class to parse optional numerical parameters."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.expected_type = self._var_type


class ChoiceParameter(Parameter):
"""
A parameter which takes two values:
Expand Down Expand Up @@ -1257,3 +1346,11 @@ def normalize(self, var):
else:
raise ValueError("{var} is not a valid choice from {choices}".format(
var=var, choices=self._choices))


class OptionalChoiceParameter(OptionalParameterMixin, ChoiceParameter):
"""Class to parse optional choice parameters."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.expected_type = self._var_type
104 changes: 104 additions & 0 deletions test/optional_parameter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import luigi
import mock

from helpers import LuigiTestCase, with_config


class OptionalParameterTest(LuigiTestCase):

def actual_test(current_test, cls, default, expected_value, expected_type, bad_data, **kwargs):

class TestConfig(luigi.Config):
param = cls(default=default, **kwargs)
empty_param = cls(default=default, **kwargs)

def run(self):
assert self.param == expected_value
assert self.empty_param is None

# Test parsing empty string (should be None)
current_test.assertIsNone(cls(**kwargs).parse(''))

# Test that warning is raised only with bad type
with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig()
warnings.warn.assert_not_called()

if cls != luigi.OptionalChoiceParameter:
with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig(param=None)
warnings.warn.assert_not_called()

with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig(param=bad_data)
if cls == luigi.OptionalBoolParameter:
warnings.warn.assert_not_called()
else:
warnings.warn.assert_called_with(
'{} "param" with value "{}" is not of type "{}" or None.'.format(
cls.__name__,
bad_data,
expected_type
),
luigi.parameter.OptionalParameterTypeWarning
)

# Test with value from config
current_test.assertTrue(luigi.build([TestConfig()], local_scheduler=True))

@with_config({"TestConfig": {"param": "expected value", "empty_param": ""}})
def test_optional_parameter(self):
self.actual_test(luigi.OptionalParameter, None, "expected value", "str", 0)
self.actual_test(luigi.OptionalParameter, "default value", "expected value", "str", 0)

@with_config({"TestConfig": {"param": "10", "empty_param": ""}})
def test_optional_int_parameter(self):
self.actual_test(luigi.OptionalIntParameter, None, 10, "int", "bad data")
self.actual_test(luigi.OptionalIntParameter, 1, 10, "int", "bad data")

@with_config({"TestConfig": {"param": "true", "empty_param": ""}})
def test_optional_bool_parameter(self):
self.actual_test(luigi.OptionalBoolParameter, None, True, "bool", "bad data")
self.actual_test(luigi.OptionalBoolParameter, False, True, "bool", "bad data")

@with_config({"TestConfig": {"param": "10.5", "empty_param": ""}})
def test_optional_float_parameter(self):
self.actual_test(luigi.OptionalFloatParameter, None, 10.5, "float", "bad data")
self.actual_test(luigi.OptionalFloatParameter, 1.5, 10.5, "float", "bad data")

@with_config({"TestConfig": {"param": '{"a": 10}', "empty_param": ""}})
def test_optional_dict_parameter(self):
self.actual_test(luigi.OptionalDictParameter, None, {"a": 10}, "FrozenOrderedDict", "bad data")
self.actual_test(luigi.OptionalDictParameter, {"a": 1}, {"a": 10}, "FrozenOrderedDict", "bad data")

@with_config({"TestConfig": {"param": "[10.5]", "empty_param": ""}})
def test_optional_list_parameter(self):
self.actual_test(luigi.OptionalListParameter, None, (10.5, ), "tuple", "bad data")
self.actual_test(luigi.OptionalListParameter, (1.5, ), (10.5, ), "tuple", "bad data")

@with_config({"TestConfig": {"param": "[10.5]", "empty_param": ""}})
def test_optional_tuple_parameter(self):
self.actual_test(luigi.OptionalTupleParameter, None, (10.5, ), "tuple", "bad data")
self.actual_test(luigi.OptionalTupleParameter, (1.5, ), (10.5, ), "tuple", "bad data")

@with_config({"TestConfig": {"param": "10.5", "empty_param": ""}})
def test_optional_numerical_parameter_float(self):
self.actual_test(luigi.OptionalNumericalParameter, None, 10.5, "float", "bad data", var_type=float, min_value=0, max_value=100)
self.actual_test(luigi.OptionalNumericalParameter, 1.5, 10.5, "float", "bad data", var_type=float, min_value=0, max_value=100)

@with_config({"TestConfig": {"param": "10", "empty_param": ""}})
def test_optional_numerical_parameter_int(self):
self.actual_test(luigi.OptionalNumericalParameter, None, 10, "int", "bad data", var_type=int, min_value=0, max_value=100)
self.actual_test(luigi.OptionalNumericalParameter, 1, 10, "int", "bad data", var_type=int, min_value=0, max_value=100)

@with_config({"TestConfig": {"param": "expected value", "empty_param": ""}})
def test_optional_choice_parameter(self):
choices = ["default value", "expected value"]
self.actual_test(luigi.OptionalChoiceParameter, None, "expected value", "str", "bad data", choices=choices)
self.actual_test(luigi.OptionalChoiceParameter, "default value", "expected value", "str", "bad data", choices=choices)

@with_config({"TestConfig": {"param": "1", "empty_param": ""}})
def test_optional_choice_parameter_int(self):
choices = [0, 1, 2]
self.actual_test(luigi.OptionalChoiceParameter, None, 1, "int", "bad data", var_type=int, choices=choices)
self.actual_test(luigi.OptionalChoiceParameter, "default value", 1, "int", "bad data", var_type=int, choices=choices)
18 changes: 13 additions & 5 deletions test/parameter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,13 +357,21 @@ class TestConfig(luigi.Config):
TestConfig(param="str")
warnings.warn.assert_not_called()

@mock.patch('luigi.parameter.warnings')
def test_no_warn_on_none_in_optional(self, warnings):
def test_no_warn_on_none_in_optional(self):
class TestConfig(luigi.Config):
param = luigi.OptionalParameter(default=None)

TestConfig()
warnings.warn.assert_not_called()
with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig()
warnings.warn.assert_not_called()

with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig(param=None)
warnings.warn.assert_not_called()

with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig(param="")
warnings.warn.assert_not_called()

@mock.patch('luigi.parameter.warnings')
def test_no_warn_on_string_in_optional(self, warnings):
Expand All @@ -379,7 +387,7 @@ class TestConfig(luigi.Config):
param = luigi.OptionalParameter()

TestConfig(param=1)
warnings.warn.assert_called_once_with('OptionalParameter "param" with value "1" is not of type string or None.')
warnings.warn.assert_called_once_with('OptionalParameter "param" with value "1" is not of type "str" or None.', luigi.parameter.OptionalParameterTypeWarning)

def test_optional_parameter_parse_none(self):
self.assertIsNone(luigi.OptionalParameter().parse(''))
Expand Down

0 comments on commit 3cef59d

Please sign in to comment.