Skip to content

Commit

Permalink
Store problem configuration in Problem
Browse files Browse the repository at this point in the history
Introduces Problem.config which contains the info from the PEtab yaml file.

Sometimes it is convenient to have the original filenames around.

Closes #324.
  • Loading branch information
dweindl committed Dec 3, 2024
1 parent 0b77d7f commit 45c29b3
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 35 deletions.
119 changes: 88 additions & 31 deletions petab/v1/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from warnings import warn

import pandas as pd
from pydantic import AnyUrl, BaseModel, Field, RootModel

from . import (
conditions,
Expand Down Expand Up @@ -78,6 +79,7 @@ def __init__(
observable_df: pd.DataFrame = None,
mapping_df: pd.DataFrame = None,
extensions_config: dict = None,
config: ProblemConfig = None,
):
self.condition_df: pd.DataFrame | None = condition_df
self.measurement_df: pd.DataFrame | None = measurement_df
Expand Down Expand Up @@ -112,6 +114,7 @@ def __init__(

self.model: Model | None = model
self.extensions_config = extensions_config or {}
self.config = config

def __getattr__(self, name):
# For backward-compatibility, allow access to SBML model related
Expand Down Expand Up @@ -251,21 +254,32 @@ def from_files(
)

@staticmethod
def from_yaml(yaml_config: dict | Path | str) -> Problem:
def from_yaml(
yaml_config: dict | Path | str, base_path: str | Path = None
) -> Problem:
"""
Factory method to load model and tables as specified by YAML file.
Arguments:
yaml_config: PEtab configuration as dictionary or YAML file name
base_path: Base directory or URL to resolve relative paths
"""
# path to the yaml file
filepath = None

if isinstance(yaml_config, Path):
yaml_config = str(yaml_config)

get_path = lambda filename: filename # noqa: E731
if isinstance(yaml_config, str):
path_prefix = get_path_prefix(yaml_config)
filepath = yaml_config
if base_path is None:
base_path = get_path_prefix(yaml_config)
yaml_config = yaml.load_yaml(yaml_config)
get_path = lambda filename: f"{path_prefix}/{filename}" # noqa: E731

def get_path(filename):
if base_path is None:
return filename
return f"{base_path}/{filename}"

if yaml.is_composite_problem(yaml_config):
raise ValueError(
Expand All @@ -289,59 +303,58 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem:
DeprecationWarning,
stacklevel=2,
)
config = ProblemConfig(
**yaml_config, base_path=base_path, filepath=filepath
)
problem0 = config.problems[0]
# currently required for handling PEtab v2 in here
problem0_ = yaml_config["problems"][0]

problem0 = yaml_config["problems"][0]

if isinstance(yaml_config[PARAMETER_FILE], list):
if isinstance(config.parameter_file, list):
parameter_df = parameters.get_parameter_df(
[get_path(f) for f in yaml_config[PARAMETER_FILE]]
[get_path(f) for f in config.parameter_file]
)
else:
parameter_df = (
parameters.get_parameter_df(
get_path(yaml_config[PARAMETER_FILE])
)
if yaml_config[PARAMETER_FILE]
parameters.get_parameter_df(get_path(config.parameter_file))
if config.parameter_file
else None
)

if yaml_config[FORMAT_VERSION] in [1, "1", "1.0.0"]:
if len(problem0[SBML_FILES]) > 1:
if config.format_version.root in [1, "1", "1.0.0"]:
if len(problem0.sbml_files) > 1:
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
raise NotImplementedError(
"Support for multiple models is not yet implemented."
)

model = (
model_factory(
get_path(problem0[SBML_FILES][0]),
get_path(problem0.sbml_files[0]),
MODEL_TYPE_SBML,
model_id=None,
)
if problem0[SBML_FILES]
if problem0.sbml_files
else None
)
else:
if len(problem0[MODEL_FILES]) > 1:
if len(problem0_[MODEL_FILES]) > 1:
# TODO https://github.com/PEtab-dev/libpetab-python/issues/6
raise NotImplementedError(
"Support for multiple models is not yet implemented."
)
if not problem0[MODEL_FILES]:
if not problem0_[MODEL_FILES]:
model = None
else:
model_id, model_info = next(
iter(problem0[MODEL_FILES].items())
iter(problem0_[MODEL_FILES].items())
)
model = model_factory(
get_path(model_info[MODEL_LOCATION]),
model_info[MODEL_LANGUAGE],
model_id=model_id,
)

measurement_files = [
get_path(f) for f in problem0.get(MEASUREMENT_FILES, [])
]
measurement_files = [get_path(f) for f in problem0.measurement_files]
# If there are multiple tables, we will merge them
measurement_df = (
core.concat_tables(
Expand All @@ -351,9 +364,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem:
else None
)

condition_files = [
get_path(f) for f in problem0.get(CONDITION_FILES, [])
]
condition_files = [get_path(f) for f in problem0.condition_files]
# If there are multiple tables, we will merge them
condition_df = (
core.concat_tables(condition_files, conditions.get_condition_df)
Expand All @@ -362,7 +373,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem:
)

visualization_files = [
get_path(f) for f in problem0.get(VISUALIZATION_FILES, [])
get_path(f) for f in problem0.visualization_files
]
# If there are multiple tables, we will merge them
visualization_df = (
Expand All @@ -371,17 +382,15 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem:
else None
)

observable_files = [
get_path(f) for f in problem0.get(OBSERVABLE_FILES, [])
]
observable_files = [get_path(f) for f in problem0.observable_files]
# If there are multiple tables, we will merge them
observable_df = (
core.concat_tables(observable_files, observables.get_observable_df)
if observable_files
else None
)

mapping_files = [get_path(f) for f in problem0.get(MAPPING_FILES, [])]
mapping_files = [get_path(f) for f in problem0_.get(MAPPING_FILES, [])]
# If there are multiple tables, we will merge them
mapping_df = (
core.concat_tables(mapping_files, mapping.get_mapping_df)
Expand All @@ -398,6 +407,7 @@ def from_yaml(yaml_config: dict | Path | str) -> Problem:
visualization_df=visualization_df,
mapping_df=mapping_df,
extensions_config=yaml_config.get(EXTENSIONS, {}),
config=config,
)

@staticmethod
Expand Down Expand Up @@ -998,3 +1008,50 @@ def n_priors(self) -> int:
return 0

return self.parameter_df[OBJECTIVE_PRIOR_PARAMETERS].notna().sum()


class VersionNumber(RootModel):
root: str | int


class ListOfFiles(RootModel):
"""List of files."""

root: list[str | AnyUrl] = Field(..., description="List of files.")

def __iter__(self):
return iter(self.root)

def __len__(self):
return len(self.root)

def __getitem__(self, index):
return self.root[index]


class SubProblem(BaseModel):
"""A `problems` object in the PEtab problem configuration."""

sbml_files: ListOfFiles = []
measurement_files: ListOfFiles = []
condition_files: ListOfFiles = []
observable_files: ListOfFiles = []
visualization_files: ListOfFiles = []


class ProblemConfig(BaseModel):
"""The PEtab problem configuration."""

filepath: str | AnyUrl | None = Field(
None,
description="The path to the PEtab problem configuration.",
exclude=True,
)
base_path: str | AnyUrl | None = Field(
None,
description="The base path to resolve relative paths.",
exclude=True,
)
format_version: VersionNumber = 1
parameter_file: str | AnyUrl | None = None
problems: list[SubProblem] = []
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ dependencies = [
"pyyaml",
"jsonschema",
"antlr4-python3-runtime==4.13.1",
"pydantic>=2.10",
]
license = {text = "MIT License"}
authors = [
Expand Down
13 changes: 9 additions & 4 deletions tests/v1/test_petab.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,11 +862,16 @@ def test_problem_from_yaml_v1_multiple_files():
observables_df, Path(tmpdir, f"observables{i}.tsv")
)

petab_problem = petab.Problem.from_yaml(yaml_path)
petab_problem1 = petab.Problem.from_yaml(yaml_path)

assert petab_problem.measurement_df.shape[0] == 2
assert petab_problem.observable_df.shape[0] == 2
assert petab_problem.condition_df.shape[0] == 2
# test that we can load the problem from a dict with a custom base path
yaml_config = petab.v1.load_yaml(yaml_path)
petab_problem2 = petab.Problem.from_yaml(yaml_config, base_path=tmpdir)

for petab_problem in (petab_problem1, petab_problem2):
assert petab_problem.measurement_df.shape[0] == 2
assert petab_problem.observable_df.shape[0] == 2
assert petab_problem.condition_df.shape[0] == 2


def test_get_required_parameters_for_parameter_table(petab_problem):
Expand Down

0 comments on commit 45c29b3

Please sign in to comment.