Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to parameter models for test case inputs #18743

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/galaxy/tool_util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def parse_tool(tool_source: ToolSource) -> ParsedTool:
version = tool_source.parse_version()
name = tool_source.parse_name()
description = tool_source.parse_description()
inputs = input_models_for_tool_source(tool_source).input_models
inputs = input_models_for_tool_source(tool_source).parameters
outputs = from_tool_source(tool_source)
citations = tool_source.parse_citations()
license = tool_source.parse_license()
Expand Down
2 changes: 2 additions & 0 deletions lib/galaxy/tool_util/parameters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
repeat_inputs_to_array,
validate_explicit_conditional_test_value,
visit_input_values,
VISITOR_NO_REPLACEMENT,
)

__all__ = (
Expand Down Expand Up @@ -116,6 +117,7 @@
"keys_starting_with",
"visit_input_values",
"repeat_inputs_to_array",
"VISITOR_NO_REPLACEMENT",
"decode",
"encode",
"WorkflowStepToolState",
Expand Down
26 changes: 24 additions & 2 deletions lib/galaxy/tool_util/parameters/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@
from packaging.version import Version

from galaxy.tool_util.parser.interface import (
TestCollectionDef,
ToolSource,
ToolSourceTest,
ToolSourceTestInput,
ToolSourceTestInputs,
xml_data_input_to_json,
XmlTestCollectionDefDict,
)
from galaxy.util import asbool
from .factory import input_models_for_tool_source
Expand All @@ -25,6 +28,7 @@
ConditionalWhen,
DataCollectionParameterModel,
DataColumnParameterModel,
DataParameterModel,
FloatParameterModel,
IntegerParameterModel,
RepeatParameterModel,
Expand Down Expand Up @@ -249,8 +253,26 @@ def _merge_into_state(
else:
test_input = _input_for(state_path, inputs)
if test_input is not None:
input_value: Any
if isinstance(tool_input, (DataCollectionParameterModel,)):
input_value = test_input.get("attributes", {}).get("collection")
input_value = TestCollectionDef.from_dict(
cast(XmlTestCollectionDefDict, test_input.get("attributes", {}).get("collection"))
).test_format_to_dict()
elif isinstance(tool_input, (DataParameterModel,)):
data_tool_input = cast(DataParameterModel, tool_input)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this cast necessary ? mypy has no problems with this

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not know that - good to know. I think I've gotten in the habit of doing that with parameters because I so often am dispatching on the parameter_type string maybe. I think it kind of reads better this way but I will admit it is more code. I will defer to you and remove the casts in my working branch if you'd like.

if data_tool_input.multiple:
value = test_input["value"]
input_value_list = []
if value:
test_input_values = cast(str, value).split(",")
for test_input_value in test_input_values:
instance_test_input = test_input.copy()
instance_test_input["value"] = test_input_value
input_value = xml_data_input_to_json(test_input)
input_value_list.append(input_value)
input_value = input_value_list
else:
input_value = xml_data_input_to_json(test_input)
else:
input_value = test_input["value"]
input_value = legacy_from_string(tool_input, input_value, warnings, profile)
Expand Down Expand Up @@ -299,6 +321,6 @@ def validate_test_cases_for_tool_source(
test_cases: List[ToolSourceTest] = tool_source.parse_tests_to_dict()["tests"]
results_by_test: List[TestCaseStateValidationResult] = []
for test_case in test_cases:
validation_result = test_case_validation(test_case, tool_parameter_bundle.input_models, profile)
validation_result = test_case_validation(test_case, tool_parameter_bundle.parameters, profile)
results_by_test.append(validation_result)
return results_by_test
4 changes: 2 additions & 2 deletions lib/galaxy/tool_util/parameters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def _from_input_source_cwl(input_source: CwlInputSource) -> ToolParameterT:


def input_models_from_json(json: List[Dict[str, Any]]) -> ToolParameterBundle:
return ToolParameterBundleModel(input_models=json)
return ToolParameterBundleModel(parameters=json)


def tool_parameter_bundle_from_json(json: Dict[str, Any]) -> ToolParameterBundleModel:
Expand All @@ -328,7 +328,7 @@ def tool_parameter_bundle_from_json(json: Dict[str, Any]) -> ToolParameterBundle

def input_models_for_tool_source(tool_source: ToolSource) -> ToolParameterBundleModel:
pages = tool_source.parse_input_pages()
return ToolParameterBundleModel(input_models=input_models_for_pages(pages))
return ToolParameterBundleModel(parameters=input_models_for_pages(pages))


def input_models_for_pages(pages: PagesSource) -> List[ToolParameterT]:
Expand Down
32 changes: 15 additions & 17 deletions lib/galaxy/tool_util/parameters/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
from galaxy.exceptions import RequestParameterInvalidException
from galaxy.tool_util.parser.interface import (
DrillDownOptionsDict,
TestCollectionDefDict,
JsonTestCollectionDefDict,
JsonTestDatasetDefDict,
)
from ._types import (
cast_as_type,
Expand Down Expand Up @@ -312,9 +313,9 @@ def py_type_internal(self) -> Type:
def py_type_test_case(self) -> Type:
base_model: Type
if self.multiple:
base_model = str
base_model = list_type(JsonTestDatasetDefDict)
else:
base_model = str
base_model = JsonTestDatasetDefDict
return optional_if_needed(base_model, self.optional)

def pydantic_template(self, state_representation: StateRepresentationT) -> DynamicModelInformation:
Expand Down Expand Up @@ -372,7 +373,7 @@ def pydantic_template(self, state_representation: StateRepresentationT) -> Dynam
elif state_representation == "workflow_step_linked":
return dynamic_model_information_from_py_type(self, ConnectedValue)
elif state_representation == "test_case_xml":
return dynamic_model_information_from_py_type(self, TestCollectionDefDict)
return dynamic_model_information_from_py_type(self, JsonTestCollectionDefDict)
else:
raise NotImplementedError(
f"Have not implemented data collection parameter models for state representation {state_representation}"
Expand Down Expand Up @@ -1164,12 +1165,11 @@ class ToolParameterModel(RootModel):
class ToolParameterBundle(Protocol):
"""An object having a dictionary of input models (i.e. a 'Tool')"""

# TODO: rename to parameters to align with ConditionalWhen and Repeat.
input_models: List[ToolParameterT]
parameters: List[ToolParameterT]


class ToolParameterBundleModel(BaseModel):
input_models: List[ToolParameterT]
parameters: List[ToolParameterT]


def to_simple_model(input_parameter: Union[ToolParameterModel, ToolParameterT]) -> ToolParameterT:
Expand All @@ -1180,10 +1180,8 @@ def to_simple_model(input_parameter: Union[ToolParameterModel, ToolParameterT])
return cast(ToolParameterT, input_parameter)


def simple_input_models(
input_models: Union[List[ToolParameterModel], List[ToolParameterT]]
) -> Iterable[ToolParameterT]:
return [to_simple_model(m) for m in input_models]
def simple_input_models(parameters: Union[List[ToolParameterModel], List[ToolParameterT]]) -> Iterable[ToolParameterT]:
return [to_simple_model(m) for m in parameters]


def create_model_strict(*args, **kwd) -> Type[BaseModel]:
Expand All @@ -1194,27 +1192,27 @@ def create_model_strict(*args, **kwd) -> Type[BaseModel]:


def create_request_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "request")
return create_field_model(tool.parameters, name, "request")


def create_request_internal_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "request_internal")
return create_field_model(tool.parameters, name, "request_internal")


def create_job_internal_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "job_internal")
return create_field_model(tool.parameters, name, "job_internal")


def create_test_case_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "test_case_xml")
return create_field_model(tool.parameters, name, "test_case_xml")


def create_workflow_step_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "workflow_step")
return create_field_model(tool.parameters, name, "workflow_step")


def create_workflow_step_linked_model(tool: ToolParameterBundle, name: str = "DynamicModelForTool") -> Type[BaseModel]:
return create_field_model(tool.input_models, name, "workflow_step_linked")
return create_field_model(tool.parameters, name, "workflow_step_linked")


def create_field_model(
Expand Down
38 changes: 19 additions & 19 deletions lib/galaxy/tool_util/parameters/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def __init__(self, input_state: Dict[str, Any]):
def _validate(self, pydantic_model: Type[BaseModel]) -> None:
validate_against_model(pydantic_model, self.input_state)

def validate(self, input_models: HasToolParameters) -> None:
base_model = self.parameter_model_for(input_models)
def validate(self, parameters: HasToolParameters) -> None:
base_model = self.parameter_model_for(parameters)
if base_model is None:
raise NotImplementedError(
f"Validating tool state against state representation {self.state_representation} is not implemented."
Expand All @@ -53,64 +53,64 @@ def state_representation(self) -> StateRepresentationT:
"""Get state representation of the inputs."""

@classmethod
def parameter_model_for(cls, input_models: HasToolParameters) -> Type[BaseModel]:
def parameter_model_for(cls, parameters: HasToolParameters) -> Type[BaseModel]:
bundle: ToolParameterBundle
if isinstance(input_models, list):
bundle = ToolParameterBundleModel(input_models=input_models)
if isinstance(parameters, list):
bundle = ToolParameterBundleModel(parameters=parameters)
else:
bundle = input_models
bundle = parameters
return cls._parameter_model_for(bundle)

@classmethod
@abstractmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
"""Return a model type for this tool state kind."""


class RequestToolState(ToolState):
state_representation: Literal["request"] = "request"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
return create_request_model(input_models)
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_request_model(parameters)


class RequestInternalToolState(ToolState):
state_representation: Literal["request_internal"] = "request_internal"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
return create_request_internal_model(input_models)
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_request_internal_model(parameters)


class JobInternalToolState(ToolState):
state_representation: Literal["job_internal"] = "job_internal"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
return create_job_internal_model(input_models)
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_job_internal_model(parameters)


class TestCaseToolState(ToolState):
state_representation: Literal["test_case_xml"] = "test_case_xml"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
# implement a test case model...
return create_test_case_model(input_models)
return create_test_case_model(parameters)


class WorkflowStepToolState(ToolState):
state_representation: Literal["workflow_step"] = "workflow_step"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
return create_workflow_step_model(input_models)
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_workflow_step_model(parameters)


class WorkflowStepLinkedToolState(ToolState):
state_representation: Literal["workflow_step_linked"] = "workflow_step_linked"

@classmethod
def _parameter_model_for(cls, input_models: ToolParameterBundle) -> Type[BaseModel]:
return create_workflow_step_linked_model(input_models)
def _parameter_model_for(cls, parameters: ToolParameterBundle) -> Type[BaseModel]:
return create_workflow_step_linked_model(parameters)
2 changes: 1 addition & 1 deletion lib/galaxy/tool_util/parameters/visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def visit_input_values(
no_replacement_value=VISITOR_NO_REPLACEMENT,
) -> Dict[str, Any]:
return _visit_input_values(
simple_input_models(input_models.input_models),
simple_input_models(input_models.parameters),
tool_state.input_state,
callback=callback,
no_replacement_value=no_replacement_value,
Expand Down
Loading
Loading