Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade pydantic #349

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions btrack/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import logging
import os
from pathlib import Path
from typing import ClassVar, Optional
from typing import Optional

import numpy as np
from pydantic import BaseModel, conlist, validator
from pydantic import BaseModel, ConfigDict, conlist, field_validator

from btrack import _version

Expand Down Expand Up @@ -97,11 +97,11 @@ class TrackerConfig(BaseModel):
]
enable_optimisation = True

@validator("volume", pre=True, always=True)
@field_validator("volume", mode="before", check_fields=True)
def _parse_volume(cls, v):
return ImagingVolume(*v) if isinstance(v, tuple) else v

@validator("tracking_updates", pre=True, always=True)
@field_validator("tracking_updates", mode="before", check_fields=True)
def _parse_tracking_updates(cls, v):
_tracking_updates = v
if all(isinstance(k, str) for k in _tracking_updates):
Expand All @@ -111,12 +111,13 @@ def _parse_tracking_updates(cls, v):
_tracking_updates = list(set(_tracking_updates))
return _tracking_updates

class Config:
arbitrary_types_allowed = True
validate_assignment = True
json_encoders: ClassVar[dict] = {
model_config = ConfigDict(
arbitrary_types_allowed=True,
validate_assignment=True,
json_encoders={
np.ndarray: lambda x: x.ravel().tolist(),
}
},
)


def load_config(filename: os.PathLike) -> TrackerConfig:
Expand Down Expand Up @@ -177,5 +178,5 @@ def save_config(filename: os.PathLike, cfg: TrackerConfig) -> None:
"""

with open(filename, "w") as json_file:
json_data = json.loads(cfg.json())
json_data = json.loads(cfg.model_dump_json())
json.dump(json_data, json_file, indent=2, separators=(",", ": "))
4 changes: 2 additions & 2 deletions btrack/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def configure(
self._config = configuration

# set all configuration options using setattr
for attr in configuration.__fields__:
for attr in configuration.model_fields:
setattr(self, attr, getattr(configuration, attr))

self._initialised = True
Expand All @@ -197,7 +197,7 @@ def __setattr__(self, attr, value):
if not attr.startswith("_") and self.configuration.verbose:
logger.info(f"Setting {attr} -> {value}")

if attr in config.TrackerConfig.__fields__:
if attr in config.TrackerConfig.model_fields:
setattr(self.configuration, attr, value)
else:
object.__setattr__(self, attr, value)
Expand Down
39 changes: 17 additions & 22 deletions btrack/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional

import numpy as np
from pydantic import BaseModel, root_validator, validator
from pydantic import BaseModel, ConfigDict, field_validator, model_validator

from . import constants
from .optimise.hypothesis import H_TYPES, PyHypothesisParams
Expand Down Expand Up @@ -90,54 +90,54 @@ class MotionModel(BaseModel):
prob_not_assign: float = constants.PROB_NOT_ASSIGN
name: str = "Default"

@validator("A", "H", "P", "R", "G", "Q", pre=True)
@field_validator("A", "H", "P", "R", "G", "Q", mode="before")
def parse_arrays(cls, v):
if isinstance(v, dict):
m = v.get("matrix", None)
s = v.get("sigma", 1.0)
return np.asarray(m, dtype=float) * s
return np.asarray(v, dtype=float)

@validator("A")
@field_validator("A")
def reshape_A(cls, a, values):
shape = (values["states"], values["states"])
return np.reshape(a, shape)

@validator("H")
@field_validator("H")
def reshape_H(cls, h, values):
shape = (values["measurements"], values["states"])
return np.reshape(h, shape)

@validator("P")
@field_validator("P")
def reshape_P(cls, p, values):
shape = (values["states"], values["states"])
p = np.reshape(p, shape)
if not _check_symmetric(p):
raise ValueError("Matrix `P` is not symmetric.")
return p

@validator("R")
@field_validator("R")
def reshape_R(cls, r, values):
shape = (values["measurements"], values["measurements"])
r = np.reshape(r, shape)
if not _check_symmetric(r):
raise ValueError("Matrix `R` is not symmetric.")
return r

@validator("G")
@field_validator("G")
def reshape_G(cls, g, values):
shape = (1, values["states"])
return np.reshape(g, shape)

@validator("Q")
@field_validator("Q")
def reshape_Q(cls, q, values):
shape = (values["states"], values["states"])
q = np.reshape(q, shape)
if not _check_symmetric(q):
raise ValueError("Matrix `Q` is not symmetric.")
return q

@root_validator
@model_validator
def validate_motion_model(cls, values):
if values["Q"] is None:
G = values.get("G", None)
Expand All @@ -146,9 +146,7 @@ def validate_motion_model(cls, values):
values["Q"] = G.T @ G
return values

class Config:
arbitrary_types_allowed = True
validate_assignment = True
model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)


class ObjectModel(BaseModel):
Expand Down Expand Up @@ -178,23 +176,21 @@ class ObjectModel(BaseModel):
start: np.ndarray
name: str = "Default"

@validator("emission", "transition", "start", pre=True)
@field_validator("emission", "transition", "start", mode="before")
def parse_array(cls, v, values):
return np.asarray(v, dtype=float)

@validator("emission", "transition", "start", pre=True)
@field_validator("emission", "transition", "start", mode="before")
def reshape_emission_transition(cls, v, values):
shape = (values["states"], values["states"])
return np.reshape(v, shape)

@validator("emission", "transition", "start", pre=True)
@field_validator("emission", "transition", "start", mode="before")
def reshape_start(cls, v, values):
shape = (1, values["states"])
return np.reshape(v, shape)

class Config:
arbitrary_types_allowed = True
validate_assignment = True
model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)


class HypothesisModel(BaseModel):
Expand Down Expand Up @@ -273,7 +269,7 @@ class HypothesisModel(BaseModel):
relax: bool
name: str = "Default"

@validator("hypotheses", pre=True)
@field_validator("hypotheses", mode="before")
def parse_hypotheses(cls, hypotheses):
if not all(h in H_TYPES for h in hypotheses):
raise ValueError("Unknown hypothesis type in `hypotheses`.")
Expand All @@ -289,13 +285,12 @@ def as_ctype(self) -> PyHypothesisParams:
h_params = PyHypothesisParams()
fields = [f[0] for f in h_params._fields_]

for k, v in self.dict().items():
for k, v in self.model_dump().items():
if k in fields:
setattr(h_params, k, v)

# set the hypotheses to generate
h_params.hypotheses_to_generate = self.hypotheses_to_generate()
return h_params

class Config:
validate_assignment = True
model_config = ConfigDict(validate_assignment=True)
15 changes: 7 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@ classifiers = [
]
dependencies = [
"cvxopt>=1.3.1",
"h5py>=2.10.0",
"numpy>=1.17.3",
"h5py>=3.9.0",
"numpy>=1.25.1",
"pandas>=2.0.3",
"pooch>=1.0.0",
"pydantic<2",
"scikit-image>=0.16.2",
"scipy>=1.3.1",
"pooch>=1.7.0",
"pydantic>=2.0.3",
"scikit-image>=0.21.0",
"tqdm>=4.65.0",
]
description = "A framework for Bayesian multi-object tracking"
Expand Down Expand Up @@ -144,8 +143,8 @@ isort.sections.napari = [
mccabe.max-complexity = 18
pep8-naming.classmethod-decorators = [
"classmethod",
"pydantic.root_validator",
"pydantic.validator",
"pydantic.field_validator",
"pydantic.model_validator",
]

[tool.setuptools]
Expand Down
8 changes: 4 additions & 4 deletions tests/napari/test_dock_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def test_config_to_widgets_round_trip(track_widget, config):
config objects and widgets works as expected.
"""

expected_config = btrack.config.load_config(config).json()
expected_config = btrack.config.load_config(config).model_dump_json()

unscaled_config = btrack.napari.config.UnscaledTrackerConfig(config)
btrack.napari.sync.update_widgets_from_config(unscaled_config, track_widget)
btrack.napari.sync.update_config_from_widgets(unscaled_config, track_widget)

actual_config = unscaled_config.scale_config().json()
actual_config = unscaled_config.scale_config().model_dump_json()

# use json.loads to avoid failure in string comparison because e.g "100.0" != "100"
assert json.loads(actual_config) == json.loads(expected_config)
Expand All @@ -60,15 +60,15 @@ def test_save_button(track_widget):
)
# this is done in in the gui too
unscaled_config.tracker_config.name = "cell"
expected_config = unscaled_config.scale_config().json()
expected_config = unscaled_config.scale_config().model_dump_json()

with patch(
"btrack.napari.widgets.save_path_dialogue_box"
) as save_path_dialogue_box:
save_path_dialogue_box.return_value = "user_config.json"
track_widget.save_config_button.click()

actual_config = btrack.config.load_config("user_config.json").json()
actual_config = btrack.config.load_config("user_config.json").model_dump_json()

# use json.loads to avoid failure in string comparison because e.g "100.0" != "100"
assert json.loads(expected_config) == json.loads(actual_config)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_import_config():
def test_config_to_json():
"""Test that a config can be converted to json format without raising an error"""
cfg = btrack.config.load_config(CONFIG_FILE)
cfg.json()
cfg.model_dump_json()


def test_config_tracker_setters():
Expand All @@ -82,7 +82,7 @@ def test_config_tracker_setters():
def _cfg_dict() -> tuple[dict, dict]:
cfg_raw = btrack.config.load_config(CONFIG_FILE)
cfg = _random_config()
cfg.update(cfg_raw.dict())
cfg.update(cfg_raw.model_dump())
assert isinstance(cfg, dict)
return cfg, cfg

Expand All @@ -91,7 +91,7 @@ def _cfg_file() -> tuple[str, dict]:
filename = CONFIG_FILE
assert isinstance(filename, str)
cfg = btrack.config.load_config(filename)
return filename, cfg.dict()
return filename, cfg.model_dump()


def _cfg_pydantic() -> tuple[btrack.config.TrackerConfig, dict]:
Expand All @@ -100,7 +100,7 @@ def _cfg_pydantic() -> tuple[btrack.config.TrackerConfig, dict]:
for key, value in options.items():
setattr(cfg, key, value)
assert isinstance(cfg, btrack.config.TrackerConfig)
return cfg, cfg.dict()
return cfg, cfg.model_dump()


@pytest.mark.parametrize("get_cfg", [_cfg_file, _cfg_dict, _cfg_pydantic])
Expand Down