From 44a6e94c05e11930863f1908298e657e4041b594 Mon Sep 17 00:00:00 2001
From: Hironori Yamamoto
Date: Wed, 27 Nov 2024 19:13:39 +0900
Subject: [PATCH] feat: add serializable parameter

---
Notes:
    Reviewer commentary only; this section sits after the three-dash marker
    and is not part of the commit message or of any hunk.

    * gokart/__init__.py: additionally exports SerializableObjectParameter.
    * gokart/parameter.py: adds a `Serializable` Protocol (an instance
      method `serialize() -> str` and a classmethod `deserialize(s: str)`)
      and `SerializableObjectParameter`, a `luigi.Parameter` subclass that
      stores `object_type`, whose `parse(s)` returns
      `object_type.deserialize(s)` and whose `serialize(x)` returns
      `x.serialize()`.
    * test/test_serializable_object_parameter.py: new tests covering the
      default value, parsing `--config` from the command line, a missing
      parameter, a value that fails to deserialize (ValueError), `--config`
      without an argument (SystemExit), and a mypy run asserting that a
      class without the protocol methods is rejected as type variable "S".

    NOTE(review): the unchanged context line in gokart/parameter.py,
    `luigi.Parameter._parser_kwargs(*args, *kwargs)`, unpacks the keys of
    `kwargs` positionally; it looks like `**kwargs` was intended. That is
    outside this patch and should be confirmed separately.

    NOTE(review): line breaks and indentation below were reconstructed from
    a whitespace-collapsed copy of this patch (the line counts match every
    hunk header); verify the whitespace against the original commit before
    applying.

 gokart/__init__.py                         |  2 +-
 gokart/parameter.py                        | 26 +++++++
 test/test_serializable_object_parameter.py | 83 ++++++++++++++++++++++
 3 files changed, 110 insertions(+), 1 deletion(-)
 create mode 100644 test/test_serializable_object_parameter.py

diff --git a/gokart/__init__.py b/gokart/__init__.py
index 25e54f41..e2ad78e8 100644
--- a/gokart/__init__.py
+++ b/gokart/__init__.py
@@ -1,7 +1,7 @@
 from gokart.build import WorkerSchedulerFactory, build  # noqa:F401
 from gokart.info import make_tree_info, tree_info  # noqa:F401
 from gokart.pandas_type_config import PandasTypeConfig  # noqa:F401
-from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter  # noqa:F401
+from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, SerializableObjectParameter, TaskInstanceParameter  # noqa:F401
 from gokart.run import run  # noqa:F401
 from gokart.task import TaskOnKart  # noqa:F401
 from gokart.testing import test_run  # noqa:F401
diff --git a/gokart/parameter.py b/gokart/parameter.py
index de8c9556..1b87b300 100644
--- a/gokart/parameter.py
+++ b/gokart/parameter.py
@@ -1,6 +1,7 @@
 import bz2
 import json
 from logging import getLogger
+from typing import Generic, Protocol, TypeVar
 
 import luigi
 from luigi import task_register
@@ -87,3 +88,28 @@ def __init__(self, *args, **kwargs):
 
     def _parser_kwargs(self, *args, **kwargs):  # type: ignore
         return luigi.Parameter._parser_kwargs(*args, *kwargs)
+
+
+T = TypeVar('T')
+
+
+class Serializable(Protocol):
+    def serialize(self) -> str: ...
+
+    @classmethod
+    def deserialize(cls: type[T], s: str) -> T: ...
+
+
+S = TypeVar('S', bound=Serializable)
+
+
+class SerializableObjectParameter(luigi.Parameter, Generic[S]):
+    def __init__(self, object_type: type[S], *args, **kwargs):
+        self._object_type = object_type
+        luigi.Parameter.__init__(self, *args, **kwargs)
+
+    def parse(self, s: str) -> S:
+        return self._object_type.deserialize(s)
+
+    def serialize(self, x: S) -> str:
+        return x.serialize()
diff --git a/test/test_serializable_object_parameter.py b/test/test_serializable_object_parameter.py
new file mode 100644
index 00000000..c9c397f2
--- /dev/null
+++ b/test/test_serializable_object_parameter.py
@@ -0,0 +1,83 @@
+import json
+import tempfile
+from dataclasses import asdict, dataclass
+
+import luigi
+import pytest
+from luigi.cmdline_parser import CmdlineParser
+from mypy import api
+
+from gokart import SerializableObjectParameter, TaskOnKart
+from test.config import PYPROJECT_TOML
+
+
+@dataclass(frozen=True)
+class Config:
+    foo: int
+    bar: str
+
+    def serialize(self) -> str:
+        # dict is ordered in Python 3.7+
+        return json.dumps(asdict(self))
+
+    @classmethod
+    def deserialize(cls, s: str) -> 'Config':
+        return cls(**json.loads(s))
+
+
+class SerializableObjectParameterWithOutDefault(TaskOnKart):
+    task_namespace = __name__
+    config: Config = SerializableObjectParameter(object_type=Config)
+
+    def run(self):
+        self.dump(self.config)
+
+
+class SerializableObjectParameterWithDefault(TaskOnKart):
+    task_namespace = __name__
+    config: Config = SerializableObjectParameter(object_type=Config, default=Config(foo=1, bar='bar'))
+
+    def run(self):
+        self.dump(self.config)
+
+
+class TestSerializableObjectParameter:
+    def test_default(self):
+        with CmdlineParser.global_instance([f'{__name__}.SerializableObjectParameterWithDefault']) as cp:
+            assert cp.get_task_obj().config == Config(foo=1, bar='bar')
+
+    def test_parse_param(self):
+        with CmdlineParser.global_instance([f'{__name__}.SerializableObjectParameterWithOutDefault', '--config', '{"foo": 100, "bar": "val"}']) as cp:
+            assert cp.get_task_obj().config == Config(foo=100, bar='val')
+
+    def test_missing_parameter(self):
+        with pytest.raises(luigi.parameter.MissingParameterException):
+            with CmdlineParser.global_instance([f'{__name__}.SerializableObjectParameterWithOutDefault']) as cp:
+                cp.get_task_obj()
+
+    def test_value_error(self):
+        with pytest.raises(ValueError):
+            with CmdlineParser.global_instance([f'{__name__}.SerializableObjectParameterWithOutDefault', '--config', 'Foo']) as cp:
+                cp.get_task_obj()
+
+    def test_expected_one_argument_error(self):
+        with pytest.raises(SystemExit):
+            with CmdlineParser.global_instance([f'{__name__}.SerializableObjectParameterWithOutDefault', '--config']) as cp:
+                cp.get_task_obj()
+
+    def test_mypy(self):
+        """check invalid object cannot used for SerializableObjectParameter"""
+
+        test_code = """
+import gokart
+
+class InvalidClass:
+    ...
+
+gokart.SerializableObjectParameter(object_type=InvalidClass)
+        """
+        with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
+            test_file.write(test_code.encode('utf-8'))
+            test_file.flush()
+            result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
+            assert 'Value of type variable "S" of "SerializableObjectParameter" cannot be "InvalidClass"  [type-var]' in result[0]