Skip to content

Commit

Permalink
Pydantic models for parameter validators.
Browse files Browse the repository at this point in the history
  • Loading branch information
jmchilton committed Oct 8, 2024
1 parent fbdd054 commit b5b504c
Show file tree
Hide file tree
Showing 28 changed files with 1,577 additions and 304 deletions.
11 changes: 11 additions & 0 deletions lib/galaxy/tool_util/parameters/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

from typing import (
Any,
cast,
List,
Optional,
Expand All @@ -15,6 +16,7 @@

# https://stackoverflow.com/questions/56832881/check-if-a-field-is-typing-optional
from typing_extensions import (
Annotated,
get_args,
get_origin,
)
Expand Down Expand Up @@ -46,3 +48,12 @@ def cast_as_type(arg) -> Type:

def is_optional(field) -> bool:
return get_origin(field) is Union and type(None) in get_args(field)


def expand_annotation(field: Type, new_annotations: List[Any]) -> Type:
is_annotation = get_origin(field) is Annotated
if is_annotation:
args = get_args(field) # noqa: F841
return Annotated[tuple([args[0], *args[1:], *new_annotations])] # type: ignore[return-value]
else:
return Annotated[tuple([field, *new_annotations])] # type: ignore[return-value]
76 changes: 73 additions & 3 deletions lib/galaxy/tool_util/parameters/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,19 @@
PagesSource,
ToolSource,
)
from galaxy.tool_util.parser.util import parse_profile_version
from galaxy.tool_util.parser.parameter_validators import (
EmptyFieldParameterValidatorModel,
ExpressionParameterValidatorModel,
InRangeParameterValidatorModel,
LengthParameterValidatorModel,
NoOptionsParameterValidatorModel,
RegexParameterValidatorModel,
static_validators,
)
from galaxy.tool_util.parser.util import (
parse_profile_version,
text_input_is_optional,
)
from galaxy.util import string_as_bool
from .models import (
BaseUrlParameterModel,
Expand Down Expand Up @@ -42,10 +54,13 @@
HiddenParameterModel,
IntegerParameterModel,
LabelValue,
NumberCompatiableValidators,
RepeatParameterModel,
RulesParameterModel,
SectionParameterModel,
SelectCompatiableValidators,
SelectParameterModel,
TextCompatiableValidators,
TextParameterModel,
ToolParameterBundle,
ToolParameterBundleModel,
Expand Down Expand Up @@ -82,7 +97,23 @@ def _from_input_source_galaxy(input_source: InputSource, profile: float) -> Tool
int_value = None
else:
raise ParameterDefinitionError()
return IntegerParameterModel(name=input_source.parse_name(), optional=optional, value=int_value)
static_validator_models = static_validators(input_source.parse_validators())
int_validators: List[NumberCompatiableValidators] = []
for static_validator in static_validator_models:
if static_validator.type == "in_range":
int_validators.append(cast(InRangeParameterValidatorModel, static_validator))
min_raw = input_source.get("min", None)
max_raw = input_source.get("max", None)
min_int = int(min_raw) if min_raw is not None else None
max_int = int(max_raw) if max_raw is not None else None
return IntegerParameterModel(
name=input_source.parse_name(),
optional=optional,
value=int_value,
min=min_int,
max=max_int,
validators=int_validators,
)
elif param_type == "boolean":
nullable = input_source.parse_optional()
value = input_source.get_bool_or_none("checked", None if nullable else False)
Expand All @@ -92,10 +123,12 @@ def _from_input_source_galaxy(input_source: InputSource, profile: float) -> Tool
value=value,
)
elif param_type == "text":
optional = input_source.parse_optional()
optional, optionality_inferred = text_input_is_optional(input_source)
text_validators: List[TextCompatiableValidators] = _text_validators(input_source)
return TextParameterModel(
name=input_source.parse_name(),
optional=optional,
validators=text_validators,
)
elif param_type == "float":
optional = input_source.parse_optional()
Expand All @@ -107,18 +140,32 @@ def _from_input_source_galaxy(input_source: InputSource, profile: float) -> Tool
float_value = None
else:
raise ParameterDefinitionError()
static_validator_models = static_validators(input_source.parse_validators())
float_validators: List[NumberCompatiableValidators] = []
for static_validator in static_validator_models:
if static_validator.type == "in_range":
float_validators.append(cast(InRangeParameterValidatorModel, static_validator))
min_raw = input_source.get("min", None)
max_raw = input_source.get("max", None)
min_float = float(min_raw) if min_raw is not None else None
max_float = float(max_raw) if max_raw is not None else None
return FloatParameterModel(
name=input_source.parse_name(),
optional=optional,
value=float_value,
min=min_float,
max=max_float,
validators=float_validators,
)
elif param_type == "hidden":
optional = input_source.parse_optional()
value = input_source.get("value")
text_validators: List[TextCompatiableValidators] = _text_validators(input_source)
return HiddenParameterModel(
name=input_source.parse_name(),
optional=optional,
value=value,
validators=text_validators,
)
elif param_type == "color":
optional = input_source.parse_optional()
Expand Down Expand Up @@ -158,11 +205,17 @@ def _from_input_source_galaxy(input_source: InputSource, profile: float) -> Tool
options = []
for option_label, option_value, selected in input_source.parse_static_options():
options.append(LabelValue(label=option_label, value=option_value, selected=selected))
static_validator_models = static_validators(input_source.parse_validators())
select_validators: List[SelectCompatiableValidators] = []
for static_validator in static_validator_models:
if static_validator.type == "no_options":
select_validators.append(cast(NoOptionsParameterValidatorModel, static_validator))
return SelectParameterModel(
name=input_source.parse_name(),
optional=optional,
options=options,
multiple=multiple,
validators=select_validators,
)
elif param_type == "drill_down":
multiple = input_source.get_bool("multiple", False)
Expand Down Expand Up @@ -206,8 +259,10 @@ def _from_input_source_galaxy(input_source: InputSource, profile: float) -> Tool
multiple=multiple,
)
elif param_type == "directory_uri":
text_validators: List[TextCompatiableValidators] = _text_validators(input_source)
return DirectoryUriParameterModel(
name=input_source.parse_name(),
validators=text_validators,
)
else:
raise UnknownParameterTypeError(f"Unknown Galaxy parameter type {param_type}")
Expand Down Expand Up @@ -304,6 +359,21 @@ def _simple_cwl_type_to_model(simple_type: str, input_source: CwlInputSource):
)


def _text_validators(input_source: InputSource) -> List[TextCompatiableValidators]:
static_validator_models = static_validators(input_source.parse_validators())
text_validators: List[TextCompatiableValidators] = []
for static_validator in static_validator_models:
if static_validator.type == "length":
text_validators.append(cast(LengthParameterValidatorModel, static_validator))
elif static_validator.type == "regex":
text_validators.append(cast(RegexParameterValidatorModel, static_validator))
elif static_validator.type == "expression":
text_validators.append(cast(ExpressionParameterValidatorModel, static_validator))
elif static_validator.type == "empty_field":
text_validators.append(cast(EmptyFieldParameterValidatorModel, static_validator))
return text_validators


def _from_input_source_cwl(input_source: CwlInputSource) -> ToolParameterT:
schema_salad_field = input_source.field
if schema_salad_field is None:
Expand Down
Loading

0 comments on commit b5b504c

Please sign in to comment.