Skip to content

Commit

Permalink
feat: add ChoiceListParameter
Browse files Browse the repository at this point in the history
  • Loading branch information
kitagry committed Sep 16, 2024
1 parent b5d1b96 commit 2fe6a09
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 6 deletions.
12 changes: 6 additions & 6 deletions luigi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
DateIntervalParameter, TimeDeltaParameter,
IntParameter, FloatParameter, BoolParameter, PathParameter,
TaskParameter, EnumParameter, DictParameter, ListParameter, TupleParameter, EnumListParameter,
NumericalParameter, ChoiceParameter, OptionalParameter, OptionalStrParameter,
OptionalIntParameter, OptionalFloatParameter, OptionalBoolParameter, OptionalPathParameter,
OptionalDictParameter, OptionalListParameter, OptionalTupleParameter,
NumericalParameter, ChoiceParameter, ChoiceListParameter, OptionalParameter,
OptionalStrParameter, OptionalIntParameter, OptionalFloatParameter, OptionalBoolParameter,
OptionalPathParameter, OptionalDictParameter, OptionalListParameter, OptionalTupleParameter,
OptionalChoiceParameter, OptionalNumericalParameter,
)

Expand All @@ -66,9 +66,9 @@
'FloatParameter', 'BoolParameter', 'PathParameter', 'TaskParameter',
'ListParameter', 'TupleParameter', 'EnumParameter', 'DictParameter', 'EnumListParameter',
'configuration', 'interface', 'local_target', 'run', 'build', 'event', 'Event',
'NumericalParameter', 'ChoiceParameter', 'OptionalParameter', 'OptionalStrParameter',
'OptionalIntParameter', 'OptionalFloatParameter', 'OptionalBoolParameter', 'OptionalPathParameter',
'OptionalDictParameter', 'OptionalListParameter', 'OptionalTupleParameter',
'NumericalParameter', 'ChoiceParameter', 'ChoiceListParameter', 'OptionalParameter',
'OptionalStrParameter', 'OptionalIntParameter', 'OptionalFloatParameter', 'OptionalBoolParameter',
'OptionalPathParameter', 'OptionalDictParameter', 'OptionalListParameter', 'OptionalTupleParameter',
'OptionalChoiceParameter', 'OptionalNumericalParameter', 'LuigiStatusCode',
'__version__',
]
Expand Down
47 changes: 47 additions & 0 deletions luigi/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,53 @@ def normalize(self, var):
var=var, choices=self._choices))


class ChoiceListParameter(ChoiceParameter):
"""
A parameter which takes two values:
1. an instance of :class:`~collections.Iterable` and
2. the class of the variables to convert to.
Values are taken to be a list, i.e. order is preserved, duplicates may occur, and empty list is possible.
In the task definition, use
.. code-block:: python
class MyTask(luigi.Task):
my_param = luigi.ChoiceListParameter(choices=['foo', 'bar', 'baz'], var_type=str)
At the command line, use
.. code-block:: console
$ luigi --module my_tasks MyTask --my-param foo,bar
Consider using :class:`~luigi.EnumListParameter` for a typed, structured
alternative. This class can perform the same role when all choices are the
same type and transparency of parameter value on the command line is
desired.
"""

_sep = ','

def __init__(self, *args, **kwargs):
super(ChoiceListParameter, self).__init__(*args, **kwargs)

def parse(self, s):
values = [] if s == '' else s.split(self._sep)
return self.normalize(map(self._var_type, values))

def normalize(self, var):
values = []
for v in var:
values.append(super().normalize(v))
return tuple(values)


def serialize(self, values):
return self._sep.join(values)


class OptionalChoiceParameter(OptionalParameterMixin, ChoiceParameter):
"""Class to parse optional choice parameters."""

Expand Down
26 changes: 26 additions & 0 deletions test/parameter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,25 @@ def test_enum_list_param_invalid(self):
def test_enum_list_param_missing(self):
self.assertRaises(ParameterException, lambda: luigi.parameter.EnumListParameter())

def test_choice_list_param_valid(self):
p = luigi.parameter.ChoiceListParameter(choices=["1", "2", "3"])
self.assertEqual((), p.parse(''))
self.assertEqual(("1",), p.parse('1'))
self.assertEqual(("1", "3"), p.parse('1,3'))

def test_choice_list_param_invalid(self):
p = luigi.parameter.ChoiceListParameter(choices=["1", "2", "3"])
self.assertRaises(ValueError, lambda: p.parse('1,4'))

def test_invalid_choice_type(self):
self.assertRaises(
AssertionError,
lambda: luigi.ChoiceListParameter(var_type=int, choices=[1, 2, "3"]),
)

def test_choice_list_param_missing(self):
self.assertRaises(ParameterException, lambda: luigi.parameter.ChoiceListParameter())

def test_tuple_serialize_parse(self):
a = luigi.TupleParameter()
b_tuple = ((1, 2), (3, 4))
Expand Down Expand Up @@ -469,6 +488,13 @@ class FooWithDefault(luigi.Task):

self.assertEqual(FooWithDefault().args, p.parse('C'))

def test_choice_list(self):
class Foo(luigi.Task):
args = luigi.ChoiceListParameter(var_type=str, choices=["1", "2", "3"])

p = luigi.ChoiceListParameter(var_type=str, choices=["3", "2", "1"])
self.assertEqual(hash(Foo(args=("3",)).args), hash(p.parse("3")))

def test_dict(self):
class Foo(luigi.Task):
args = luigi.parameter.DictParameter()
Expand Down

0 comments on commit 2fe6a09

Please sign in to comment.