From 45c29b34101f6fb40b5cc93dfe4f0e1017aa2d91 Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 3 Dec 2024 21:26:31 +0100 Subject: [PATCH] Store problem configuration in Problem Introduces Problem.config which contains the info from the PEtab yaml file. Sometimes it is convenient to have the original filenames around. Closes #324. --- petab/v1/problem.py | 119 ++++++++++++++++++++++++++++++----------- pyproject.toml | 1 + tests/v1/test_petab.py | 13 +++-- 3 files changed, 98 insertions(+), 35 deletions(-) diff --git a/petab/v1/problem.py b/petab/v1/problem.py index 4a5577eb..2caaf23e 100644 --- a/petab/v1/problem.py +++ b/petab/v1/problem.py @@ -10,6 +10,7 @@ from warnings import warn import pandas as pd +from pydantic import AnyUrl, BaseModel, Field, RootModel from . import ( conditions, @@ -78,6 +79,7 @@ def __init__( observable_df: pd.DataFrame = None, mapping_df: pd.DataFrame = None, extensions_config: dict = None, + config: ProblemConfig = None, ): self.condition_df: pd.DataFrame | None = condition_df self.measurement_df: pd.DataFrame | None = measurement_df @@ -112,6 +114,7 @@ def __init__( self.model: Model | None = model self.extensions_config = extensions_config or {} + self.config = config def __getattr__(self, name): # For backward-compatibility, allow access to SBML model related @@ -251,21 +254,32 @@ def from_files( ) @staticmethod - def from_yaml(yaml_config: dict | Path | str) -> Problem: + def from_yaml( + yaml_config: dict | Path | str, base_path: str | Path = None + ) -> Problem: """ Factory method to load model and tables as specified by YAML file. Arguments: yaml_config: PEtab configuration as dictionary or YAML file name + base_path: Base directory or URL to resolve relative paths """ + # path to the yaml file + filepath = None + if isinstance(yaml_config, Path): yaml_config = str(yaml_config) - get_path = lambda filename: filename # noqa: E731 if isinstance(yaml_config, str): - path_prefix = get_path_prefix(yaml_config) + filepath = yaml_config + if base_path is None: + base_path = get_path_prefix(yaml_config) yaml_config = yaml.load_yaml(yaml_config) - get_path = lambda filename: f"{path_prefix}/{filename}" # noqa: E731 + + def get_path(filename): + if base_path is None: + return filename + return f"{base_path}/{filename}" if yaml.is_composite_problem(yaml_config): raise ValueError( @@ -289,24 +303,25 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem: DeprecationWarning, stacklevel=2, ) + config = ProblemConfig( + **yaml_config, base_path=base_path, filepath=filepath + ) + problem0 = config.problems[0] + # currently required for handling PEtab v2 in here + problem0_ = yaml_config["problems"][0] - problem0 = yaml_config["problems"][0] - - if isinstance(yaml_config[PARAMETER_FILE], list): + if isinstance(config.parameter_file, list): parameter_df = parameters.get_parameter_df( - [get_path(f) for f in yaml_config[PARAMETER_FILE]] + [get_path(f) for f in config.parameter_file] ) else: parameter_df = ( - parameters.get_parameter_df( - get_path(yaml_config[PARAMETER_FILE]) - ) - if yaml_config[PARAMETER_FILE] + parameters.get_parameter_df(get_path(config.parameter_file)) + if config.parameter_file else None ) - - if yaml_config[FORMAT_VERSION] in [1, "1", "1.0.0"]: - if len(problem0[SBML_FILES]) > 1: + if config.format_version.root in [1, "1", "1.0.0"]: + if len(problem0.sbml_files) > 1: # TODO https://github.com/PEtab-dev/libpetab-python/issues/6 raise NotImplementedError( "Support for multiple models is not yet implemented." @@ -314,24 +329,24 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem: model = ( model_factory( - get_path(problem0[SBML_FILES][0]), + get_path(problem0.sbml_files[0]), MODEL_TYPE_SBML, model_id=None, ) - if problem0[SBML_FILES] + if problem0.sbml_files else None ) else: - if len(problem0[MODEL_FILES]) > 1: + if len(problem0_[MODEL_FILES]) > 1: # TODO https://github.com/PEtab-dev/libpetab-python/issues/6 raise NotImplementedError( "Support for multiple models is not yet implemented." ) - if not problem0[MODEL_FILES]: + if not problem0_[MODEL_FILES]: model = None else: model_id, model_info = next( - iter(problem0[MODEL_FILES].items()) + iter(problem0_[MODEL_FILES].items()) ) model = model_factory( get_path(model_info[MODEL_LOCATION]), @@ -339,9 +354,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem: model_id=model_id, ) - measurement_files = [ - get_path(f) for f in problem0.get(MEASUREMENT_FILES, []) - ] + measurement_files = [get_path(f) for f in problem0.measurement_files] # If there are multiple tables, we will merge them measurement_df = ( core.concat_tables( @@ -351,9 +364,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem: else None ) - condition_files = [ - get_path(f) for f in problem0.get(CONDITION_FILES, []) - ] + condition_files = [get_path(f) for f in problem0.condition_files] # If there are multiple tables, we will merge them condition_df = ( core.concat_tables(condition_files, conditions.get_condition_df) @@ -362,7 +373,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem: ) visualization_files = [ - get_path(f) for f in problem0.get(VISUALIZATION_FILES, []) + get_path(f) for f in problem0.visualization_files ] # If there are multiple tables, we will merge them visualization_df = ( @@ -371,9 +382,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem: else None ) - observable_files = [ - get_path(f) for f in problem0.get(OBSERVABLE_FILES, []) - ] + observable_files = [get_path(f) for f in problem0.observable_files] # If there are multiple tables, we will merge them observable_df = ( core.concat_tables(observable_files, observables.get_observable_df) @@ -381,7 +390,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem: else None ) - mapping_files = [get_path(f) for f in problem0.get(MAPPING_FILES, [])] + mapping_files = [get_path(f) for f in problem0_.get(MAPPING_FILES, [])] # If there are multiple tables, we will merge them mapping_df = ( core.concat_tables(mapping_files, mapping.get_mapping_df) @@ -398,6 +407,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem: visualization_df=visualization_df, mapping_df=mapping_df, extensions_config=yaml_config.get(EXTENSIONS, {}), + config=config, ) @staticmethod @@ -998,3 +1008,50 @@ def n_priors(self) -> int: return 0 return self.parameter_df[OBJECTIVE_PRIOR_PARAMETERS].notna().sum() + + +class VersionNumber(RootModel): + root: str | int + + +class ListOfFiles(RootModel): + """List of files.""" + + root: list[str | AnyUrl] = Field(..., description="List of files.") + + def __iter__(self): + return iter(self.root) + + def __len__(self): + return len(self.root) + + def __getitem__(self, index): + return self.root[index] + + +class SubProblem(BaseModel): + """A `problems` object in the PEtab problem configuration.""" + + sbml_files: ListOfFiles = [] + measurement_files: ListOfFiles = [] + condition_files: ListOfFiles = [] + observable_files: ListOfFiles = [] + visualization_files: ListOfFiles = [] + + +class ProblemConfig(BaseModel): + """The PEtab problem configuration.""" + + filepath: str | AnyUrl | None = Field( + None, + description="The path to the PEtab problem configuration.", + exclude=True, + ) + base_path: str | AnyUrl | None = Field( + None, + description="The base path to resolve relative paths.", + exclude=True, + ) + format_version: VersionNumber = 1 + parameter_file: str | AnyUrl | None = None + problems: list[SubProblem] = [] diff --git a/pyproject.toml b/pyproject.toml index 1758476a..2c8a1757 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "pyyaml", "jsonschema", "antlr4-python3-runtime==4.13.1", + "pydantic>=2.10", ] license = {text = "MIT License"} authors = [ diff --git a/tests/v1/test_petab.py b/tests/v1/test_petab.py index 65700af5..1a3f3344 100644 --- a/tests/v1/test_petab.py +++ b/tests/v1/test_petab.py @@ -862,11 +862,16 @@ def test_problem_from_yaml_v1_multiple_files(): observables_df, Path(tmpdir, f"observables{i}.tsv") ) - petab_problem = petab.Problem.from_yaml(yaml_path) + petab_problem1 = petab.Problem.from_yaml(yaml_path) - assert petab_problem.measurement_df.shape[0] == 2 - assert petab_problem.observable_df.shape[0] == 2 - assert petab_problem.condition_df.shape[0] == 2 + # test that we can load the problem from a dict with a custom base path + yaml_config = petab.v1.load_yaml(yaml_path) + petab_problem2 = petab.Problem.from_yaml(yaml_config, base_path=tmpdir) + + for petab_problem in (petab_problem1, petab_problem2): + assert petab_problem.measurement_df.shape[0] == 2 + assert petab_problem.observable_df.shape[0] == 2 + assert petab_problem.condition_df.shape[0] == 2 def test_get_required_parameters_for_parameter_table(petab_problem):