Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve optional parameters #3079

Merged
merged 3 commits into from
Oct 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions luigi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
DateIntervalParameter, TimeDeltaParameter,
IntParameter, FloatParameter, BoolParameter,
TaskParameter, EnumParameter, DictParameter, ListParameter, TupleParameter,
NumericalParameter, ChoiceParameter, OptionalParameter
NumericalParameter, ChoiceParameter, OptionalParameter, OptionalStrParameter,
OptionalIntParameter, OptionalFloatParameter, OptionalBoolParameter,
OptionalDictParameter, OptionalListParameter, OptionalTupleParameter,
OptionalChoiceParameter, OptionalNumericalParameter,
)

from luigi import configuration
Expand All @@ -60,7 +63,10 @@
'FloatParameter', 'BoolParameter', 'TaskParameter',
'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter',
'configuration', 'interface', 'local_target', 'run', 'build', 'event', 'Event',
'NumericalParameter', 'ChoiceParameter', 'OptionalParameter', 'LuigiStatusCode',
'NumericalParameter', 'ChoiceParameter', 'OptionalParameter', 'OptionalStrParameter',
'OptionalIntParameter', 'OptionalFloatParameter', 'OptionalBoolParameter',
'OptionalDictParameter', 'OptionalListParameter', 'OptionalTupleParameter',
'OptionalChoiceParameter', 'OptionalNumericalParameter', 'LuigiStatusCode',
'__version__',
]

Expand Down
115 changes: 106 additions & 9 deletions luigi/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ class DuplicateParameterException(ParameterException):
pass


class OptionalParameterTypeWarning(UserWarning):
    """
    Warning category used by ``OptionalParameterMixin._warn_on_wrong_param_type``.

    It is emitted when a parameter value is neither ``None`` nor an instance of
    the parameter's ``expected_type``. Having a dedicated category lets callers
    filter these warnings separately from other ``UserWarning`` messages.
    """
    pass


class Parameter:
"""
Parameter whose value is a ``str``, and a base class for other parameter types.
Expand Down Expand Up @@ -323,24 +330,62 @@ def _parser_kwargs(cls, param_name, task_name=None):
}


class OptionalParameter(Parameter):
""" A Parameter that treats empty string as None """
class OptionalParameterMixin:
"""
Mixin to make a parameter class optional and treat empty string as None.
"""

expected_type = type(None)

def serialize(self, x):
"""
Parse the given value if the value is not None else return an empty string.
"""
if x is None:
return ''
else:
return str(x)
return super().serialize(x)

def parse(self, x):
return x or None
"""
Parse the given value if it is a string (empty strings are parsed to None).
"""
if not isinstance(x, str):
return x
elif x:
return super().parse(x)
else:
return None

def normalize(self, x):
"""
Normalize the given value if it is not None.
"""
if x is None:
return None
return super().normalize(x)

def _warn_on_wrong_param_type(self, param_name, param_value):
if self.__class__ != OptionalParameter:
return
if not isinstance(param_value, str) and param_value is not None:
warnings.warn('OptionalParameter "{}" with value "{}" is not of type string or None.'.format(
param_name, param_value))
if not isinstance(param_value, self.expected_type) and param_value is not None:
warnings.warn(
(
f'{self.__class__.__name__} "{param_name}" with value '
f'"{param_value}" is not of type "{self.expected_type.__name__}" or None.'
),
OptionalParameterTypeWarning,
)


class OptionalParameter(OptionalParameterMixin, Parameter):
    """
    Optional string parameter.

    Combines ``OptionalParameterMixin`` with ``Parameter``: an empty string is
    parsed to ``None`` and ``None`` is serialized back to an empty string.
    Has the same bases and ``expected_type`` as ``OptionalStrParameter``.
    """

    # Type checked by _warn_on_wrong_param_type; None is always accepted too.
    expected_type = str


class OptionalStrParameter(OptionalParameterMixin, Parameter):
    """
    Optional ``str`` parameter.

    Combines ``OptionalParameterMixin`` with ``Parameter``: an empty string is
    parsed to ``None`` and ``None`` is serialized back to an empty string.
    """

    # Type checked by _warn_on_wrong_param_type; None is always accepted too.
    expected_type = str


_UNIX_EPOCH = datetime.datetime.utcfromtimestamp(0)
Expand Down Expand Up @@ -627,6 +672,12 @@ def next_in_enumeration(self, value):
return value + 1


class OptionalIntParameter(OptionalParameterMixin, IntParameter):
    """
    Optional ``int`` parameter.

    Non-empty strings are parsed by ``IntParameter``; an empty string is parsed
    to ``None`` and ``None`` is serialized back to an empty string.
    """

    # Type checked by _warn_on_wrong_param_type; None is always accepted too.
    expected_type = int


class FloatParameter(Parameter):
"""
Parameter whose value is a ``float``.
Expand All @@ -639,6 +690,12 @@ def parse(self, s):
return float(s)


class OptionalFloatParameter(OptionalParameterMixin, FloatParameter):
    """
    Optional ``float`` parameter.

    Non-empty strings are parsed by ``FloatParameter``; an empty string is
    parsed to ``None`` and ``None`` is serialized back to an empty string.
    """

    # Type checked by _warn_on_wrong_param_type; None is always accepted too.
    expected_type = float


class BoolParameter(Parameter):
"""
A Parameter whose value is a ``bool``. This parameter has an implicit default value of
Expand Down Expand Up @@ -709,6 +766,12 @@ def _parser_kwargs(self, *args, **kwargs):
return parser_kwargs


class OptionalBoolParameter(OptionalParameterMixin, BoolParameter):
    """
    Optional ``bool`` parameter.

    Non-empty strings are parsed by ``BoolParameter``; an empty string is
    parsed to ``None`` and ``None`` is serialized back to an empty string.

    NOTE(review): the accompanying test expects no type warning when a
    non-bool value is passed, presumably because ``BoolParameter.normalize``
    coerces the value before the type check runs -- confirm.
    """

    # Type checked by _warn_on_wrong_param_type; None is always accepted too.
    expected_type = bool


class DateIntervalParameter(Parameter):
"""
A Parameter whose value is a :py:class:`~luigi.date_interval.DateInterval`.
Expand Down Expand Up @@ -1007,6 +1070,12 @@ def serialize(self, x):
return json.dumps(x, cls=_DictParamEncoder)


class OptionalDictParameter(OptionalParameterMixin, DictParameter):
    """
    Optional ``dict`` parameter.

    Non-empty strings are parsed by ``DictParameter``; an empty string is
    parsed to ``None`` and ``None`` is serialized back to an empty string.
    """

    # The check is against FrozenOrderedDict rather than dict; presumably
    # DictParameter.normalize freezes the value before the type check runs
    # -- TODO confirm.
    expected_type = FrozenOrderedDict


class ListParameter(Parameter):
"""
Parameter whose value is a ``list``.
Expand Down Expand Up @@ -1070,6 +1139,12 @@ def serialize(self, x):
return json.dumps(x, cls=_DictParamEncoder)


class OptionalListParameter(OptionalParameterMixin, ListParameter):
    """
    Optional ``list`` parameter.

    Non-empty strings are parsed by ``ListParameter``; an empty string is
    parsed to ``None`` and ``None`` is serialized back to an empty string.
    """

    # Intentionally tuple, not list: the type check runs after normalize,
    # which freezes the list into a tuple before the type is inspected.
    expected_type = tuple
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the expected type of OptionalListParameter intentionally a tuple, as opposed to a list?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is intentional: `_warn_on_wrong_param_type` is called after `normalize`, which calls `recursively_freeze` on the value, so the list is converted into a tuple before its type is checked. That's why we expect a tuple here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, right. I vaguely recall that confusion. I believe I added ListParameter and TupleParameter separately and later wished I hadn't added both. But alas, we have them both and treat them nearly the same.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, it would be simpler to have only one of these two.



class TupleParameter(ListParameter):
"""
Parameter whose value is a ``tuple`` or ``tuple`` of tuples.
Expand Down Expand Up @@ -1125,6 +1200,12 @@ def parse(self, x):
return tuple(literal_eval(x)) # if this causes an error, let that error be raised.


class OptionalTupleParameter(OptionalParameterMixin, TupleParameter):
    """
    Optional ``tuple`` parameter.

    Non-empty strings are parsed by ``TupleParameter``; an empty string is
    parsed to ``None`` and ``None`` is serialized back to an empty string.
    """

    # Type checked by _warn_on_wrong_param_type; None is always accepted too.
    expected_type = tuple


class NumericalParameter(Parameter):
"""
Parameter whose value is a number of the specified type, e.g. ``int`` or
Expand Down Expand Up @@ -1201,6 +1282,14 @@ def parse(self, s):
s=s, permitted_range=self._permitted_range))


class OptionalNumericalParameter(OptionalParameterMixin, NumericalParameter):
    """
    Optional numerical parameter.

    Non-empty strings are parsed by ``NumericalParameter``; an empty string is
    parsed to ``None`` and ``None`` is serialized back to an empty string.
    All constructor arguments are forwarded unchanged to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The expected type is per instance rather than per class: it mirrors
        # ``_var_type``, which is presumably set by NumericalParameter.__init__
        # from the ``var_type`` argument, so it must be read after super().
        self.expected_type = self._var_type


class ChoiceParameter(Parameter):
"""
A parameter which takes two values:
Expand Down Expand Up @@ -1257,3 +1346,11 @@ def normalize(self, var):
else:
raise ValueError("{var} is not a valid choice from {choices}".format(
var=var, choices=self._choices))


class OptionalChoiceParameter(OptionalParameterMixin, ChoiceParameter):
    """
    Optional choice parameter.

    Non-empty strings are parsed by ``ChoiceParameter``; an empty string is
    parsed to ``None`` and ``None`` is serialized back to an empty string.
    All constructor arguments are forwarded unchanged to the parent class.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The expected type is per instance rather than per class: it mirrors
        # ``_var_type``, which is presumably set by ChoiceParameter.__init__
        # from the ``var_type`` argument, so it must be read after super().
        self.expected_type = self._var_type
104 changes: 104 additions & 0 deletions test/optional_parameter_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import luigi
import mock

from helpers import LuigiTestCase, with_config


class OptionalParameterTest(LuigiTestCase):

def actual_test(self, cls, default, expected_value, expected_type, bad_data, **kwargs):

class TestConfig(luigi.Config):
param = cls(default=default, **kwargs)
empty_param = cls(default=default, **kwargs)

def run(self):
assert self.param == expected_value
assert self.empty_param is None

# Test parsing empty string (should be None)
self.assertIsNone(cls(**kwargs).parse(''))

# Test that warning is raised only with bad type
with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig()
warnings.warn.assert_not_called()

if cls != luigi.OptionalChoiceParameter:
with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig(param=None)
warnings.warn.assert_not_called()

with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig(param=bad_data)
if cls == luigi.OptionalBoolParameter:
warnings.warn.assert_not_called()
else:
warnings.warn.assert_called_with(
'{} "param" with value "{}" is not of type "{}" or None.'.format(
cls.__name__,
bad_data,
expected_type
),
luigi.parameter.OptionalParameterTypeWarning
)

# Test with value from config
self.assertTrue(luigi.build([TestConfig()], local_scheduler=True))

@with_config({"TestConfig": {"param": "expected value", "empty_param": ""}})
def test_optional_parameter(self):
self.actual_test(luigi.OptionalParameter, None, "expected value", "str", 0)
self.actual_test(luigi.OptionalParameter, "default value", "expected value", "str", 0)

@with_config({"TestConfig": {"param": "10", "empty_param": ""}})
def test_optional_int_parameter(self):
self.actual_test(luigi.OptionalIntParameter, None, 10, "int", "bad data")
self.actual_test(luigi.OptionalIntParameter, 1, 10, "int", "bad data")

@with_config({"TestConfig": {"param": "true", "empty_param": ""}})
def test_optional_bool_parameter(self):
self.actual_test(luigi.OptionalBoolParameter, None, True, "bool", "bad data")
self.actual_test(luigi.OptionalBoolParameter, False, True, "bool", "bad data")

@with_config({"TestConfig": {"param": "10.5", "empty_param": ""}})
def test_optional_float_parameter(self):
self.actual_test(luigi.OptionalFloatParameter, None, 10.5, "float", "bad data")
self.actual_test(luigi.OptionalFloatParameter, 1.5, 10.5, "float", "bad data")

@with_config({"TestConfig": {"param": '{"a": 10}', "empty_param": ""}})
def test_optional_dict_parameter(self):
self.actual_test(luigi.OptionalDictParameter, None, {"a": 10}, "FrozenOrderedDict", "bad data")
self.actual_test(luigi.OptionalDictParameter, {"a": 1}, {"a": 10}, "FrozenOrderedDict", "bad data")

@with_config({"TestConfig": {"param": "[10.5]", "empty_param": ""}})
def test_optional_list_parameter(self):
self.actual_test(luigi.OptionalListParameter, None, (10.5, ), "tuple", "bad data")
self.actual_test(luigi.OptionalListParameter, (1.5, ), (10.5, ), "tuple", "bad data")

@with_config({"TestConfig": {"param": "[10.5]", "empty_param": ""}})
def test_optional_tuple_parameter(self):
self.actual_test(luigi.OptionalTupleParameter, None, (10.5, ), "tuple", "bad data")
self.actual_test(luigi.OptionalTupleParameter, (1.5, ), (10.5, ), "tuple", "bad data")

@with_config({"TestConfig": {"param": "10.5", "empty_param": ""}})
def test_optional_numerical_parameter_float(self):
self.actual_test(luigi.OptionalNumericalParameter, None, 10.5, "float", "bad data", var_type=float, min_value=0, max_value=100)
self.actual_test(luigi.OptionalNumericalParameter, 1.5, 10.5, "float", "bad data", var_type=float, min_value=0, max_value=100)

@with_config({"TestConfig": {"param": "10", "empty_param": ""}})
def test_optional_numerical_parameter_int(self):
self.actual_test(luigi.OptionalNumericalParameter, None, 10, "int", "bad data", var_type=int, min_value=0, max_value=100)
self.actual_test(luigi.OptionalNumericalParameter, 1, 10, "int", "bad data", var_type=int, min_value=0, max_value=100)

@with_config({"TestConfig": {"param": "expected value", "empty_param": ""}})
def test_optional_choice_parameter(self):
choices = ["default value", "expected value"]
self.actual_test(luigi.OptionalChoiceParameter, None, "expected value", "str", "bad data", choices=choices)
self.actual_test(luigi.OptionalChoiceParameter, "default value", "expected value", "str", "bad data", choices=choices)

@with_config({"TestConfig": {"param": "1", "empty_param": ""}})
def test_optional_choice_parameter_int(self):
    """
    Check OptionalChoiceParameter with integer choices.

    The config supplies ``param = "1"``, which must be parsed to the int ``1``,
    and an empty ``empty_param``, which must be parsed to ``None``. The check
    runs once with no default and once with a non-None default.
    """
    choices = [0, 1, 2]
    self.actual_test(luigi.OptionalChoiceParameter, None, 1, "int", "bad data", var_type=int, choices=choices)
    # The default must itself be a valid choice of the declared var_type; the
    # previous string default "default value" was copied from the str-choice
    # test and was neither an int nor a member of ``choices``.
    self.actual_test(luigi.OptionalChoiceParameter, 0, 1, "int", "bad data", var_type=int, choices=choices)
21 changes: 16 additions & 5 deletions test/parameter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,13 +357,21 @@ class TestConfig(luigi.Config):
TestConfig(param="str")
warnings.warn.assert_not_called()

@mock.patch('luigi.parameter.warnings')
def test_no_warn_on_none_in_optional(self, warnings):
def test_no_warn_on_none_in_optional(self):
class TestConfig(luigi.Config):
param = luigi.OptionalParameter(default=None)

TestConfig()
warnings.warn.assert_not_called()
with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig()
warnings.warn.assert_not_called()

with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig(param=None)
warnings.warn.assert_not_called()

with mock.patch('luigi.parameter.warnings') as warnings:
TestConfig(param="")
warnings.warn.assert_not_called()

@mock.patch('luigi.parameter.warnings')
def test_no_warn_on_string_in_optional(self, warnings):
Expand All @@ -379,7 +387,10 @@ class TestConfig(luigi.Config):
param = luigi.OptionalParameter()

TestConfig(param=1)
warnings.warn.assert_called_once_with('OptionalParameter "param" with value "1" is not of type string or None.')
warnings.warn.assert_called_once_with(
'OptionalParameter "param" with value "1" is not of type "str" or None.',
luigi.parameter.OptionalParameterTypeWarning
)

def test_optional_parameter_parse_none(self):
self.assertIsNone(luigi.OptionalParameter().parse(''))
Expand Down