Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add serializable parameter #411

Merged
merged 1 commit into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gokart/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from gokart.build import WorkerSchedulerFactory, build # noqa:F401
from gokart.info import make_tree_info, tree_info # noqa:F401
from gokart.pandas_type_config import PandasTypeConfig # noqa:F401
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter # noqa:F401
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, SerializableParameter, TaskInstanceParameter # noqa:F401
from gokart.run import run # noqa:F401
from gokart.task import TaskOnKart # noqa:F401
from gokart.testing import test_run # noqa:F401
Expand Down
32 changes: 32 additions & 0 deletions gokart/parameter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import bz2
import json
from logging import getLogger
from typing import Generic, Protocol, TypeVar

import luigi
from luigi import task_register
Expand Down Expand Up @@ -87,3 +88,34 @@ def __init__(self, *args, **kwargs):

def _parser_kwargs(self, *args, **kwargs): # type: ignore
return luigi.Parameter._parser_kwargs(*args, *kwargs)


T = TypeVar('T')


class Serializable(Protocol):
def gokart_serialize(self) -> str:
"""Implement this method to serialize the object as an parameter
You can omit some fields from results of serialization if you want to ignore changes of them
"""
...

@classmethod
def gokart_deserialize(cls: type[T], s: str) -> T:
"""Implement this method to deserialize the object from a string"""
...


S = TypeVar('S', bound=Serializable)


class SerializableParameter(luigi.Parameter, Generic[S]):
def __init__(self, object_type: type[S], *args, **kwargs):
self._object_type = object_type
super().__init__(*args, **kwargs)

def parse(self, s: str) -> S:
return self._object_type.gokart_deserialize(s)

def serialize(self, x: S) -> str:
return x.gokart_serialize()
83 changes: 83 additions & 0 deletions test/test_serializable_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import json
import tempfile
from dataclasses import asdict, dataclass

import luigi
import pytest
from luigi.cmdline_parser import CmdlineParser
from mypy import api

from gokart import SerializableParameter, TaskOnKart
from test.config import PYPROJECT_TOML


@dataclass(frozen=True)
class Config:
foo: int
bar: str

def gokart_serialize(self) -> str:
# dict is ordered in Python 3.7+
return json.dumps(asdict(self))

@classmethod
def gokart_deserialize(cls, s: str) -> 'Config':
return cls(**json.loads(s))


class SerializableParameterWithOutDefault(TaskOnKart):
task_namespace = __name__
config: Config = SerializableParameter(object_type=Config)

def run(self):
self.dump(self.config)


class SerializableParameterWithDefault(TaskOnKart):
task_namespace = __name__
config: Config = SerializableParameter(object_type=Config, default=Config(foo=1, bar='bar'))

def run(self):
self.dump(self.config)


class TestSerializableParameter:
def test_default(self):
with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithDefault']) as cp:
assert cp.get_task_obj().config == Config(foo=1, bar='bar')

def test_parse_param(self):
with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config', '{"foo": 100, "bar": "val"}']) as cp:
assert cp.get_task_obj().config == Config(foo=100, bar='val')

def test_missing_parameter(self):
with pytest.raises(luigi.parameter.MissingParameterException):
with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault']) as cp:
cp.get_task_obj()

def test_value_error(self):
with pytest.raises(ValueError):
with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config', 'Foo']) as cp:
cp.get_task_obj()

def test_expected_one_argument_error(self):
with pytest.raises(SystemExit):
with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config']) as cp:
cp.get_task_obj()

def test_mypy(self):
"""check invalid object cannot used for SerializableParameter"""

test_code = """
import gokart

class InvalidClass:
...

gokart.SerializableParameter(object_type=InvalidClass)
"""
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
test_file.write(test_code.encode('utf-8'))
test_file.flush()
result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
assert 'Value of type variable "S" of "SerializableParameter" cannot be "InvalidClass" [type-var]' in result[0]