
Commit dc816ee

ziw-liu and edyoshikun committed
Masked autoencoder pre-training for virtual staining models (#67)
* refactor data loading into its own module
* update type annotations
* move the logging module out
* move old logging into utils
* rename tests to match module name
* bump torch
* draft fcmae encoder
* add stem to the encoder
* wip: masked stem layernorm
* wip: patchify masked features for linear
* use mlp from timm
* hack: POC training script for FCMAE
* fix mask for fitting
* remove training script
* default architecture
* fine-tuning options
* fix CLI for fine-tuning
* draft combined data module
* fix import
* manual validation loss reduction
* update linting (new Black version has different rules)
* update development guide
* update type hints
* bump iohub
* draft CTMC v1 dataset
* update tests
* move test_data
* remove path conversion
* configurable normalizations (#68)
  * initial commit adding the normalization
  * adding dataset_statistics to each FOV to facilitate the configurable augmentations
  * fix indentation
  * ruff
  * test preprocessing
  * remove redundant field
  * cleanup

  Co-authored-by: Ziwen Liu <ziwen.liu@czbiohub.org>
* fix CTMC dataloading
* add example CTMC v1 loading script
* change the normalization and augmentations default from None to an empty list
* invert intensity transform
* concatenated data module
* subsample videos
* livecell dataset
* all sample fields are optional
* fix multi-dataloader validation
* lint
* fix preprocessing for varying array shapes (i.e. the AICS dataset)
* update loading scripts
* fix CombineMode
* compose normalizations for predict and test stages
* black
* fix normalization in example config
* fix collate when multi-sample transform is not used
* DDP caching fixes
* fix caching when using combined loader
* move log values to GPU before syncing (Lightning-AI/pytorch-lightning#18803)
* remove normalize_source from configs
* typing fixes
* fix test data path
* fix test dataset
* add docstring for ConcatDataModule
* format

Co-authored-by: Eduardo Hirata-Miyasaki <edhiratam@gmail.com>
1 parent 435f659 commit dc816ee

36 files changed: +1527 −264 lines
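The headline change is a masked autoencoder (FCMAE) path for pre-training the virtual staining encoder, exposed through a new `FcmaeUNet` Lightning module. A minimal construction sketch, mirroring the test added in `tests/light/test_engine.py` later in this diff; the channel counts and the 0.6 mask ratio are illustrative values taken from that test, and reading `fit_mask_ratio` as the fraction of masked patches during fitting is an assumption:

```python
from viscy.light.engine import FcmaeUNet

# Construct the FCMAE variant of the virtual staining UNet for masked pre-training.
# Values mirror tests/light/test_engine.py; fit_mask_ratio is assumed to be the
# fraction of input patches that are masked while fitting.
model = FcmaeUNet(
    model_config=dict(in_channels=3, out_channels=1),
    fit_mask_ratio=0.6,
)
```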

CONTRIBUTING.md

+13-1
@@ -10,7 +10,19 @@ then make an editable installation with all the optional dependencies:
 pip install -e ".[dev,visual,metrics]"
 ```
 
-## Testing
+## CI requirements
+
+Lint with Ruff:
+
+```sh
+ruff check viscy
+```
+
+Format the code with Black:
+
+```sh
+black viscy
+```
 
 Run tests with `pytest`:

examples/configs/fit_example.yml

+14-2
@@ -37,6 +37,19 @@ data:
     batch_size: 32
     num_workers: 16
     yx_patch_size: [256, 256]
+    normalizations:
+      - class_path: viscy.transforms.NormalizeSampled
+        init_args:
+          keys: [source]
+          level: "fov_statistics"
+          subtrahend: "mean"
+          divisor: "std"
+      - class_path: viscy.transforms.NormalizeSampled
+        init_args:
+          keys: [target_1]
+          level: "fov_statistics"
+          subtrahend: "median"
+          divisor: "iqr"
     augmentations:
      - class_path: viscy.transforms.RandWeightedCropd
        init_args:
@@ -74,5 +87,4 @@ data:
          sigma_z: [0.25, 1.5]
          sigma_y: [0.25, 1.5]
          sigma_x: [0.25, 1.5]
-    caching: false
-    normalize_source: true
+    caching: false
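In Python, the two transforms above correspond to `viscy.transforms.NormalizeSampled` instances passed to the data module. A rough sketch, assuming `NormalizeSampled` accepts the same keyword arguments as the `init_args` shown in the YAML and that `normalizations` is an `HCSDataModule` argument (as the config layout suggests); the path and channel names are placeholders:

```python
from viscy.data.hcs import HCSDataModule
from viscy.transforms import NormalizeSampled

# Per-channel normalization transforms, matching the YAML block above.
normalizations = [
    NormalizeSampled(
        keys=["source"], level="fov_statistics", subtrahend="mean", divisor="std"
    ),
    NormalizeSampled(
        keys=["target_1"], level="fov_statistics", subtrahend="median", divisor="iqr"
    ),
]

# Placeholder data module wiring; "plate.zarr" and the channel names are hypothetical.
dm = HCSDataModule(
    data_path="plate.zarr",
    source_channel=["source"],
    target_channel=["target_1"],
    z_window_size=5,
    batch_size=32,
    num_workers=16,
    yx_patch_size=[256, 256],
    normalizations=normalizations,
    augmentations=[],
)
```

Per the commit message, the same list is composed into the test and predict pipelines, which is why the `normalize_source` flag is dropped from the predict and test configs below.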

examples/configs/predict_example.yml

-1
@@ -62,7 +62,6 @@ predict:
        - 256
        - 256
      caching: false
-      normalize_source: false
      predict_scale_source: null
  return_predictions: false
  ckpt_path: null

examples/configs/test_example.yml

-1
@@ -61,7 +61,6 @@ data:
      - 256
      - 256
    caching: false
-    normalize_source: false
    ground_truth_masks: null
 ckpt_path: null
 verbose: true

examples/demo_dlmbl/debug_log_graph.py

+1-1
@@ -19,7 +19,7 @@
 from torch.utils.tensorboard import SummaryWriter  # for logging to tensorboard
 
 # HCSDataModule makes it easy to load data during training.
-from viscy.light.data import HCSDataModule
+from viscy.data.hcs import HCSDataModule
 
 # Trainer class and UNet.
 from viscy.light.engine import VSUNet

examples/demo_dlmbl/solution.py

+1-1
@@ -83,7 +83,7 @@
 from torch.utils.tensorboard import SummaryWriter  # for logging to tensorboard
 
 # HCSDataModule makes it easy to load data during training.
-from viscy.light.data import HCSDataModule
+from viscy.data.hcs import HCSDataModule
 
 # training augmentations
 from viscy.transforms import (
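Existing scripts and notebooks that imported the data module from `viscy.light.data` only need the import path updated after the data loading refactor; the engine imports are unchanged, as shown in the two diffs above:

```python
# Old location (removed in this commit):
# from viscy.light.data import HCSDataModule

# New location after the data loading refactor:
from viscy.data.hcs import HCSDataModule

# Model and trainer imports are unchanged:
from viscy.light.engine import VSUNet
```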

pyproject.toml

+13-8
@@ -10,8 +10,8 @@ requires-python = ">=3.10"
 license = { file = "LICENSE" }
 authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }]
 dependencies = [
-    "iohub==0.1.0rc0",
-    "torch>=2.0.0",
+    "iohub==0.1.0",
+    "torch>=2.1.2",
     "timm>=0.9.5",
     "tensorboard>=2.13.0",
     "lightning>=2.0.1",
@@ -30,7 +30,15 @@ metrics = [
     "ptflops>=0.7",
 ]
 visual = ["ipykernel", "graphviz", "torchview"]
-dev = ["pytest", "pytest-cov", "hypothesis", "profilehooks", "onnxruntime"]
+dev = [
+    "pytest",
+    "pytest-cov",
+    "hypothesis",
+    "ruff",
+    "black",
+    "profilehooks",
+    "onnxruntime",
+]
 
 [project.scripts]
 viscy = "viscy.cli.cli:main"
@@ -39,12 +47,9 @@ viscy = "viscy.cli.cli:main"
 write_to = "viscy/_version.py"
 
 [tool.black]
-src = ["viscy"]
 line-length = 88
 
 [tool.ruff]
 src = ["viscy", "tests"]
-extend-select = ["I001"]
-
-[tool.ruff.isort]
-known-first-party = ["viscy"]
+lint.extend-select = ["I001"]
+lint.isort.known-first-party = ["viscy"]

tests/conftest.py

+2
@@ -36,6 +36,8 @@ def preprocessed_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path:
     norm_meta = {channel: {"dataset_statistics": expected} for channel in channel_names}
     with open_ome_zarr(dataset_path, mode="r+") as dataset:
         dataset.zattrs["normalization"] = norm_meta
+        for _, fov in dataset.positions():
+            fov.zattrs["normalization"] = norm_meta
     return dataset_path
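The fixture now writes the normalization metadata to every FOV as well as the plate, matching what the preprocessing step produces (see `test_preprocess` below). A hypothetical sketch of inspecting that layout with iohub; `"plate.zarr"` and the `"Phase"` channel name are placeholders:

```python
from iohub import open_ome_zarr

with open_ome_zarr("plate.zarr") as dataset:
    # dataset-level statistics stored in the plate metadata
    dataset_stats = dataset.zattrs["normalization"]["Phase"]["dataset_statistics"]
    # per-FOV statistics are now also stored on each position
    for name, fov in dataset.positions():
        fov_stats = fov.zattrs["normalization"]["Phase"]["fov_statistics"]
        print(name, fov_stats)
```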

tests/data/__init__.py

Whitespace-only changes.

tests/data/test_data.py

+105
from pathlib import Path

from iohub import open_ome_zarr
from monai.transforms import RandSpatialCropSamplesd
from pytest import mark

from viscy.data.hcs import HCSDataModule
from viscy.light.trainer import VSTrainer


@mark.parametrize("default_channels", [True, False])
def test_preprocess(small_hcs_dataset: Path, default_channels: bool):
    data_path = small_hcs_dataset
    if default_channels:
        channel_names = -1
    else:
        with open_ome_zarr(data_path) as dataset:
            channel_names = dataset.channel_names
    trainer = VSTrainer(accelerator="cpu")
    trainer.preprocess(data_path, channel_names=channel_names, num_workers=2)
    with open_ome_zarr(data_path) as dataset:
        channel_names = dataset.channel_names
        for channel in channel_names:
            assert "dataset_statistics" in dataset.zattrs["normalization"][channel]
        for _, fov in dataset.positions():
            norm_metadata = fov.zattrs["normalization"]
            for channel in channel_names:
                assert channel in norm_metadata
                assert "dataset_statistics" in norm_metadata[channel]
                assert "fov_statistics" in norm_metadata[channel]


@mark.parametrize("multi_sample_augmentation", [True, False])
def test_datamodule_setup_fit(preprocessed_hcs_dataset, multi_sample_augmentation):
    data_path = preprocessed_hcs_dataset
    z_window_size = 5
    channel_split = 2
    split_ratio = 0.8
    yx_patch_size = [128, 96]
    batch_size = 4
    with open_ome_zarr(data_path) as dataset:
        channel_names = dataset.channel_names
    if multi_sample_augmentation:
        transforms = [
            RandSpatialCropSamplesd(
                keys=channel_names,
                roi_size=[z_window_size, *yx_patch_size],
                num_samples=2,
            )
        ]
    else:
        transforms = []
    dm = HCSDataModule(
        data_path=data_path,
        source_channel=channel_names[:channel_split],
        target_channel=channel_names[channel_split:],
        z_window_size=z_window_size,
        batch_size=batch_size,
        num_workers=0,
        augmentations=transforms,
        architecture="3D",
        split_ratio=split_ratio,
        yx_patch_size=yx_patch_size,
    )
    dm.setup(stage="fit")
    for batch in dm.train_dataloader():
        assert batch["source"].shape == (
            batch_size,
            channel_split,
            z_window_size,
            *yx_patch_size,
        )
        assert batch["target"].shape == (
            batch_size,
            len(channel_names) - channel_split,
            z_window_size,
            *yx_patch_size,
        )


def test_datamodule_setup_predict(preprocessed_hcs_dataset):
    data_path = preprocessed_hcs_dataset
    z_window_size = 5
    channel_split = 2
    with open_ome_zarr(data_path) as dataset:
        channel_names = dataset.channel_names
        img = next(dataset.positions())[1][0]
        total_p = len(list(dataset.positions()))
    dm = HCSDataModule(
        data_path=data_path,
        source_channel=channel_names[:channel_split],
        target_channel=channel_names[channel_split:],
        z_window_size=z_window_size,
        batch_size=2,
        num_workers=0,
    )
    dm.setup(stage="predict")
    dataset = dm.predict_dataset
    assert len(dataset) == total_p * 2 * (img.slices - z_window_size + 1)
    assert dataset[0]["source"].shape == (
        channel_split,
        z_window_size,
        img.height,
        img.width,
    )

tests/light/test_data.py

-70
This file was deleted.

tests/light/test_engine.py

+7
from viscy.light.engine import FcmaeUNet


def test_fcmae_vsunet() -> None:
    model = FcmaeUNet(
        model_config=dict(in_channels=3, out_channels=1), fit_mask_ratio=0.6
    )

tests/unet/__init__.py

Whitespace-only changes.
File renamed without changes.
