Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update script to handle Optional and Union input parameters #1160

Merged
merged 14 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,4 @@ stop-argo: ## Stop the argo server
.PHONY: test-on-cluster
test-on-cluster: ## Run workflow tests (requires local argo cluster)
@(kubectl -n argo port-forward deployment/argo-server 2746:2746 &)
@poetry run python -m pytest tests/test_submission.py -m on_cluster
@poetry run python -m pytest tests/submissions -m on_cluster
16 changes: 12 additions & 4 deletions src/hera/workflows/_runner/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import json
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, cast
from typing import Any, Callable, Dict, List, Optional, Union, cast

from hera.shared._pydantic import _PYDANTIC_VERSION
from hera.shared.serialization import serialize
Expand Down Expand Up @@ -86,7 +86,7 @@ def _parse(value: str, key: str, f: Callable) -> Any:
The parsed value.

"""
if _is_str_kwarg_of(key, f) or _is_artifact_loaded(key, f) or _is_output_kwarg(key, f):
if _can_be_str_kwarg_of(key, f) or _is_artifact_loaded(key, f) or _is_output_kwarg(key, f):
jeongukjae marked this conversation as resolved.
Show resolved Hide resolved
return value
try:
if os.environ.get("hera__script_annotations", None) is None:
Expand Down Expand Up @@ -132,13 +132,21 @@ def _get_unannotated_type(key: str, f: Callable) -> Optional[type]:
return type_


def _is_str_kwarg_of(key: str, f: Callable) -> bool:
"""Check if param `key` of function `f` has a type annotation of a subclass of str."""
def _can_be_str_kwarg_of(key: str, f: Callable) -> bool:
jeongukjae marked this conversation as resolved.
Show resolved Hide resolved
"""Check if param `key` of function `f` has a type annotation that can be interpreted as a subclass of str."""
func_param_annotation = inspect.signature(f).parameters[key].annotation
if func_param_annotation is inspect.Parameter.empty:
return False

type_ = _get_type(func_param_annotation)
if type_ is Union:
# Checking only Union[X, None] or Union[None, X] for given X which is subclass of str.
# Note that Optional[X] is alias of Union[X, None], so Optional is also handled in here.
args = get_args(func_param_annotation)
return len(args) == 2 and (
(args[0] is type(None) and issubclass(args[1], str))
or (issubclass(args[0], str) and args[1] is type(None))
)
return issubclass(type_, str)


Expand Down
9 changes: 9 additions & 0 deletions src/hera/workflows/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,15 @@ class will be used as inputs, rather than the class itself.
else:
default = MISSING

type_ = get_origin(func_param.annotation)
args = get_args(func_param.annotation)
if type_ is Annotated:
type_ = get_origin(args[0])
args = get_args(args[0])

if (type_ is Union and len(args) == 2 and args[1] is type(None)) and default is MISSING:
jeongukjae marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(f"Optional parameter '{func_param.name}' doesn't have default value.")
jeongukjae marked this conversation as resolved.
Show resolved Hide resolved

parameters.append(Parameter(name=func_param.name, default=default))
else:
annotation = get_args(func_param.annotation)[1]
Expand Down
34 changes: 34 additions & 0 deletions tests/script_runner/parameter_with_complex_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys
from typing import Optional, Union

from hera.shared import global_config
from hera.workflows import script

global_config.experimental_features["script_annotations"] = True


@script(constructor="runner")
def optional_str_parameter(my_string: Optional[str] = None) -> Optional[str]:
return my_string


@script(constructor="runner")
def optional_str_parameter_using_union(my_string: Union[None, str] = None) -> Union[None, str]:
return my_string


if sys.version_info[0] >= 3 and sys.version_info[1] >= 10:
# Union types using OR operator are allowed since python 3.10.
@script(constructor="runner")
def optional_str_parameter_using_or(my_string: str | None = None) -> str | None:
return my_string


@script(constructor="runner")
def optional_int_parameter(my_int: Optional[int] = None) -> Optional[int]:
return my_int


@script(constructor="runner")
def union_parameter(my_param: Union[str, int] = None) -> Union[str, int]:
return my_param
Empty file added tests/submissions/__init__.py
Empty file.
File renamed without changes.
51 changes: 51 additions & 0 deletions tests/submissions/test_optional_input_parameter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Optional

import pytest

from hera.workflows import Parameter, Steps, Workflow, WorkflowsService, script
from hera.workflows.models import (
NodeStatus,
Parameter as ModelParameter,
)


@script(outputs=Parameter(name="message-out", value_from={"path": "/tmp/message-out"}))
def print_msg(message: Optional[str] = None):
with open("/tmp/message-out", "w") as f:
f.write("Got: {}".format(message))


def get_workflow() -> Workflow:
with Workflow(
generate_name="optional-param-",
entrypoint="steps",
namespace="argo",
workflows_service=WorkflowsService(
host="https://localhost:2746",
namespace="argo",
verify_ssl=False,
),
) as w:
with Steps(name="steps"):
print_msg(name="step-1", arguments={"message": "Hello world!"})
print_msg(name="step-2", arguments={})
print_msg(name="step-3")

return w


@pytest.mark.on_cluster
def test_create_workflow_with_optional_input_parameter():
model_workflow = get_workflow().create(wait=True)
assert model_workflow.status and model_workflow.status.phase == "Succeeded"

step_and_expected_output = {
"step-1": "Got: Hello world!",
"step-2": "Got: None",
"step-3": "Got: None",
}

for step, expected_output in step_and_expected_output.items():
node: NodeStatus = next(filter(lambda n: n.display_name == step, model_workflow.status.nodes.values()))
message_out: ModelParameter = next(filter(lambda n: n.name == "message-out", node.outputs.parameters))
assert message_out.value == expected_output
84 changes: 84 additions & 0 deletions tests/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import importlib
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Literal
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -978,3 +979,86 @@ def test_runner_pydantic_output_with_result(
for file in expected_files:
assert Path(tmp_path / file["subpath"]).is_file()
assert Path(tmp_path / file["subpath"]).read_text() == file["value"]


@pytest.mark.parametrize(
"entrypoint",
[
"tests.script_runner.parameter_with_complex_types:optional_str_parameter",
"tests.script_runner.parameter_with_complex_types:optional_str_parameter_using_union",
]
+ (
# Union types using OR operator are allowed since python 3.10.
["tests.script_runner.parameter_with_complex_types:optional_str_parameter_using_or"]
if sys.version_info[0] >= 3 and sys.version_info[1] >= 10
else []
),
)
@pytest.mark.parametrize(
"kwargs_list,expected_output",
[
pytest.param(
[{"name": "my_string", "value": "a string"}],
"a string",
),
pytest.param(
[{"name": "my_string", "value": None}],
"null",
),
],
)
def test_script_optional_parameter(
monkeypatch: pytest.MonkeyPatch,
entrypoint,
kwargs_list,
expected_output,
):
# GIVEN
monkeypatch.setenv("hera__script_annotations", "")

# WHEN
output = _runner(entrypoint, kwargs_list)

# THEN
assert serialize(output) == expected_output


@pytest.mark.parametrize(
"entrypoint,kwargs_list,expected_output",
[
[
"tests.script_runner.parameter_with_complex_types:optional_int_parameter",
[{"name": "my_int", "value": 123}],
"123",
],
[
"tests.script_runner.parameter_with_complex_types:optional_int_parameter",
[{"name": "my_int", "value": None}],
"null",
],
[
"tests.script_runner.parameter_with_complex_types:union_parameter",
[{"name": "my_param", "value": "a string"}],
"a string",
],
[
"tests.script_runner.parameter_with_complex_types:union_parameter",
[{"name": "my_param", "value": 123}],
"123",
],
],
)
def test_script_with_complex_types(
monkeypatch: pytest.MonkeyPatch,
entrypoint,
kwargs_list,
expected_output,
):
# GIVEN
monkeypatch.setenv("hera__script_annotations", "")

# WHEN
output = _runner(entrypoint, kwargs_list)

# THEN
assert serialize(output) == expected_output
41 changes: 41 additions & 0 deletions tests/test_unit/test_script.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from typing import Optional

import pytest

try:
from typing import Annotated # type: ignore
except ImportError:
Expand Down Expand Up @@ -117,3 +121,40 @@ def unknown_annotations_ignored(my_string: Annotated[str, "some metadata"]) -> s

assert parameter.name == "my_string"
assert parameter.default is None


def test_script_optional_parameter():
# GIVEN
@script()
def unknown_annotations_ignored(my_optional_string: Optional[str] = None) -> str:
return "Got: {}".format(my_optional_string)

# WHEN
params, artifacts = _get_inputs_from_callable(unknown_annotations_ignored)

# THEN
assert artifacts == []
assert isinstance(params, list)
assert len(params) == 1
parameter = params[0]

assert parameter.name == "my_optional_string"
assert parameter.default == "null"


def test_invalid_script_when_optional_parameter_does_not_have_default_value():
@script()
def unknown_annotations_ignored(my_optional_string: Optional[str]) -> str:
return "Got: {}".format(my_optional_string)

with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' doesn't have default value."):
_get_inputs_from_callable(unknown_annotations_ignored)


def test_invalid_script_when_optional_parameter_does_not_have_default_value_2():
@script()
def unknown_annotations_ignored(my_optional_string: Annotated[Optional[str], "123"]) -> str:
return "Got: {}".format(my_optional_string)

with pytest.raises(ValueError, match="Optional parameter 'my_optional_string' doesn't have default value."):
_get_inputs_from_callable(unknown_annotations_ignored)