
Commit 75db771

ziw-liu and edyoshikun committed
2D FCMAE (#71)
* refactor data loading into its own module
* update type annotations
* move the logging module out
* move old logging into utils
* rename tests to match module name
* bump torch
* draft fcmae encoder
* add stem to the encoder
* wip: masked stem layernorm
* wip: patchify masked features for linear
* use mlp from timm
* hack: POC training script for FCMAE
* fix mask for fitting
* remove training script
* default architecture
* fine-tuning options
* fix cli for finetuning
* draft combined data module
* fix import
* manual validation loss reduction
* update linting (new black version has different rules)
* update development guide
* update type hints
* bump iohub
* draft ctmc v1 dataset
* update tests
* move test_data
* remove path conversion
* configurable normalizations (#68)
  * initial commit adding the normalization
  * add dataset_statistics to each FOV to facilitate the configurable augmentations
  * fix indentation
  * ruff
  * test preprocessing
  * remove redundant field
  * cleanup
* fix ctmc dataloading
* add example ctmc v1 loading script
* change the normalization and augmentations default from None to empty list
* invert intensity transform
* concatenated data module
* subsample videos
* livecell dataset
* all sample fields are optional
* fix multi-dataloader validation
* lint
* fix preprocessing for varying array shapes (i.e. the AICS dataset)
* update loading scripts
* fix CombineMode
* always use untrainable head for FCMAE
* move log values to GPU before syncing (Lightning-AI/pytorch-lightning#18803)
* custom head
* ddp caching fixes
* fix caching when using combined loader
* compose normalizations for predict and test stages
* black
* fix normalization in example config
* fix normalization in example config
* prefetch more in validation
* fix collate when multi-sample transform is not used
* ddp caching fixes
* fix caching when using combined loader
* typing fixes
* fix test dataset
* fix invert transform
* add ddp prepare flag for combined data module
* remove redundant operations
* filter empty detections
* pass trainer to underlying data modules in concatenated
* hack: add test dataloader for LiveCell dataset
* test datasets for livecell and ctmc
* fix merge error
* fix merge error
* fix mAP default for over 100 detections
* bump torchmetrics
* fix combined loader training for virtual staining task
* fix non-combined data loader training
* add fcmae to graph script
* fix type hint
* format
* add back convolution option for fcmae head

Co-authored-by: Ziwen Liu <ziwen.liu@czbiohub.org>
Co-authored-by: Eduardo Hirata-Miyasaki <edhiratam@gmail.com>
1 parent 503a416 · commit 75db771

11 files changed: +265 -58 lines changed

pyproject.toml (+1 -1)

@@ -26,7 +26,7 @@ dynamic = ["version"]
 metrics = [
     "cellpose==2.1.0",
     "scikit-learn>=1.1.3",
-    "torchmetrics[detection]>=1.0.0",
+    "torchmetrics[detection]>=1.3.1",
     "ptflops>=0.7",
 ]
 visual = ["ipykernel", "graphviz", "torchview"]

tests/unet/test_fcmae.py (+21)

@@ -6,6 +6,7 @@
     MaskedConvNeXtV2Block,
     MaskedConvNeXtV2Stage,
     MaskedMultiscaleEncoder,
+    PixelToVoxelShuffleHead,
     generate_mask,
     masked_patchify,
     masked_unpatchify,
@@ -104,6 +105,13 @@ def test_masked_multiscale_encoder():
         assert afeat.shape[2] == afeat.shape[3] == xy_size // stride


+def test_pixel_to_voxel_shuffle_head():
+    head = PixelToVoxelShuffleHead(240, 3, out_stack_depth=5, xy_scaling=4)
+    x = torch.rand(2, 240, 16, 16)
+    y = head(x)
+    assert y.shape == (2, 3, 5, 64, 64)
+
+
 def test_fcmae():
     x = torch.rand(2, 3, 5, 128, 128)
     model = FullyConvolutionalMAE(3, 3)
@@ -113,3 +121,16 @@ def test_fcmae():
     y, m = model(x, mask_ratio=0.6)
     assert y.shape == x.shape
     assert m.shape == (2, 1, 128, 128)
+
+
+def test_fcmae_head_conv():
+    x = torch.rand(2, 3, 5, 128, 128)
+    model = FullyConvolutionalMAE(
+        3, 3, head_conv=True, head_conv_expansion_ratio=4, head_conv_pool=True
+    )
+    y, m = model(x)
+    assert y.shape == x.shape
+    assert m is None
+    y, m = model(x, mask_ratio=0.6)
+    assert y.shape == x.shape
+    assert m.shape == (2, 1, 128, 128)
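
The new `PixelToVoxelShuffleHead` test pins down a channel arithmetic: 240 input channels are redistributed into 3 output channels × 5 Z-slices × 4² upscaled pixels (240 = 3 × 5 × 16). A minimal sketch of that reshuffling under the same shape contract; an illustration only, not the library's implementation:

import torch

def pixel_to_voxel_shuffle(x, out_channels=3, out_stack_depth=5, xy_scaling=4):
    # channels must factor into (C_out, D, s, s): 240 == 3 * 5 * 4 * 4
    b, c, h, w = x.shape
    assert c == out_channels * out_stack_depth * xy_scaling**2
    x = x.view(b, out_channels, out_stack_depth, xy_scaling, xy_scaling, h, w)
    # interleave each scaling factor with its spatial dimension (H, then W)
    x = x.permute(0, 1, 2, 5, 3, 6, 4)
    return x.reshape(b, out_channels, out_stack_depth, h * xy_scaling, w * xy_scaling)

y = pixel_to_voxel_shuffle(torch.rand(2, 240, 16, 16))
assert y.shape == (2, 3, 5, 64, 64)  # matches the shapes asserted in the test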

viscy/data/combined.py (+2 -1)

@@ -79,7 +79,6 @@ class ConcatDataModule(LightningDataModule):
     The concatenated data module will have the same
     batch size and number of workers as the first data module.
     Each element will be sampled uniformly regardless of their original data module.
-
     :param Sequence[LightningDataModule] data_modules: data modules to concatenate
     """

@@ -93,9 +92,11 @@ def __init__(self, data_modules: Sequence[LightningDataModule]):
                 raise ValueError("Inconsistent number of workers")
             if dm.batch_size != self.batch_size:
                 raise ValueError("Inconsistent batch size")
+        self.prepare_data_per_node = True

     def prepare_data(self):
         for dm in self.data_modules:
+            dm.trainer = self.trainer
             dm.prepare_data()

     def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
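
The two added lines matter for multi-GPU runs: Lightning only attaches a `trainer` to the top-level data module, so the wrapped modules would otherwise see `dm.trainer is None` inside `prepare_data`, and `prepare_data_per_node = True` asks Lightning to run data preparation once on each node's local rank 0. A minimal sketch of the delegation pattern, assuming children follow the standard `LightningDataModule` interface:

from lightning.pytorch import LightningDataModule

class DelegatingDataModule(LightningDataModule):
    """Sketch: wrap child data modules and forward the trainer reference."""

    def __init__(self, children: list[LightningDataModule]) -> None:
        super().__init__()
        self.children_dms = children
        # run prepare_data() on each node's rank 0 (DDP caching)
        self.prepare_data_per_node = True

    def prepare_data(self) -> None:
        for dm in self.children_dms:
            dm.trainer = self.trainer  # children are not attached by Lightning
            dm.prepare_data()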

viscy/data/ctmc_v1.py (+1 -2)

@@ -10,9 +10,8 @@


 class CTMCv1ValidationDataset(SlidingWindowDataset):
-    subsample_rate: int = 30

-    def __len__(self) -> int:
+    def __len__(self, subsample_rate: int = 30) -> int:
         # sample every 30th frame in the videos
         return super().__len__() // self.subsample_rate
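
The intent, per the comment, is temporal subsampling: report one thirtieth of the underlying sliding-window length so validation only sees every 30th frame. A self-contained sketch of the idea, using a hypothetical `SubsampledFrames` class rather than the repository's `SlidingWindowDataset`:

class SubsampledFrames:
    """Sketch: expose every `rate`-th frame of a longer sequence."""

    def __init__(self, frames: list, rate: int = 30) -> None:
        self.frames = frames
        self.rate = rate

    def __len__(self) -> int:
        return len(self.frames) // self.rate

    def __getitem__(self, i: int):
        return self.frames[i * self.rate]  # stride through the video

assert len(SubsampledFrames(list(range(90)))) == 3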

viscy/data/hcs.py (-2)

@@ -191,8 +191,6 @@ def __getitem__(self, index: int) -> Sample:
             sample_images["norm_meta"] = norm_meta
         if self.transform:
             sample_images = self.transform(sample_images)
-        # if isinstance(sample_images, list):
-        #     sample_images = sample_images[0]
         if "weight" in sample_images:
             del sample_images["weight"]
         sample = {

viscy/data/livecell.py (+90 -11)

@@ -3,9 +3,11 @@

 import torch
 from lightning.pytorch import LightningDataModule
-from monai.transforms import Compose, Transform
+from monai.transforms import Compose, MapTransform
+from pycocotools.coco import COCO
 from tifffile import imread
 from torch.utils.data import DataLoader, Dataset
+from torchvision.ops import box_convert

 from viscy.data.typing import Sample

@@ -15,10 +17,10 @@ class LiveCellDataset(Dataset):
     LiveCell dataset.

     :param list[Path] images: List of paths to single-page, single-channel TIFF files.
-    :param Transform | Compose transform: Transform to apply to the dataset
+    :param MapTransform | Compose transform: Transform to apply to the dataset
     """

-    def __init__(self, images: list[Path], transform: Transform | Compose) -> None:
+    def __init__(self, images: list[Path], transform: MapTransform | Compose) -> None:
         self.images = images
         self.transform = transform

@@ -32,36 +34,100 @@ def __getitem__(self, idx: int) -> Sample:
         return {"source": image, "target": image}


+class LiveCellTestDataset(Dataset):
+    """
+    LiveCell dataset.
+
+    :param list[Path] images: List of paths to single-page, single-channel TIFF files.
+    :param MapTransform | Compose transform: Transform to apply to the dataset
+    """
+
+    def __init__(
+        self,
+        image_dir: Path,
+        transform: MapTransform | Compose,
+        annotations: Path,
+        load_target: bool = False,
+        load_labels: bool = False,
+    ) -> None:
+        self.image_dir = image_dir
+        self.transform = transform
+        self.coco = COCO(str(annotations))
+        self.image_ids = list(self.coco.imgs.keys())
+        self.load_target = load_target
+        self.load_labels = load_labels
+
+    def __len__(self) -> int:
+        return len(self.image_ids)
+
+    def __getitem__(self, idx: int) -> Sample:
+        image_id = self.image_ids[idx]
+        file_name = self.coco.imgs[image_id]["file_name"]
+        image_path = self.image_dir / file_name
+        image = imread(image_path)[None, None]
+        image = torch.from_numpy(image).to(torch.float32)
+        sample = Sample(source=image)
+        if self.load_target:
+            sample["target"] = image
+        if self.load_labels:
+            anns = self.coco.loadAnns(self.coco.getAnnIds(image_id)) or []
+            boxes = [torch.tensor(ann["bbox"]).to(torch.float32) for ann in anns]
+            masks = [
+                torch.from_numpy(self.coco.annToMask(ann)).to(torch.bool)
+                for ann in anns
+            ]
+            dets = {
+                "boxes": box_convert(torch.stack(boxes), in_fmt="xywh", out_fmt="xyxy"),
+                "labels": torch.zeros(len(anns)).to(torch.uint8),
+                "masks": torch.stack(masks),
+            }
+            sample["detections"] = dets
+            sample["file_name"] = file_name
+        self.transform(sample)
+        return sample
+
+
 class LiveCellDataModule(LightningDataModule):
     def __init__(
         self,
-        train_val_images: Path,
-        train_annotations: Path,
-        val_annotations: Path,
-        train_transforms: list[Transform],
-        val_transforms: list[Transform],
+        train_val_images: Path | None = None,
+        test_images: Path | None = None,
+        train_annotations: Path | None = None,
+        val_annotations: Path | None = None,
+        test_annotations: Path | None = None,
+        train_transforms: list[MapTransform] = [],
+        val_transforms: list[MapTransform] = [],
+        test_transforms: list[MapTransform] = [],
         batch_size: int = 16,
         num_workers: int = 8,
     ) -> None:
         super().__init__()
         self.train_val_images = Path(train_val_images)
         if not self.train_val_images.is_dir():
             raise NotADirectoryError(str(train_val_images))
+        self.test_images = Path(test_images)
+        if not self.test_images.is_dir():
+            raise NotADirectoryError(str(test_images))
         self.train_annotations = Path(train_annotations)
         if not self.train_annotations.is_file():
             raise FileNotFoundError(str(train_annotations))
         self.val_annotations = Path(val_annotations)
         if not self.val_annotations.is_file():
             raise FileNotFoundError(str(val_annotations))
+        self.test_annotations = Path(test_annotations)
+        if not self.test_annotations.is_file():
+            raise FileNotFoundError(str(test_annotations))
         self.train_transforms = Compose(train_transforms)
         self.val_transforms = Compose(val_transforms)
+        self.test_transforms = Compose(test_transforms)
         self.batch_size = batch_size
         self.num_workers = num_workers

     def setup(self, stage: str) -> None:
-        if stage != "fit":
-            raise NotImplementedError("Only fit stage is supported")
-        self._setup_fit()
+        if stage == "fit":
+            self._setup_fit()
+        elif stage == "test":
+            self._setup_test()

     def _parse_image_names(self, annotations: Path) -> list[Path]:
         with open(annotations) as f:
@@ -80,6 +146,14 @@ def _setup_fit(self) -> None:
             transform=self.val_transforms,
         )

+    def _setup_test(self) -> None:
+        self.test_dataset = LiveCellTestDataset(
+            self.test_images,
+            transform=self.test_transforms,
+            annotations=self.test_annotations,
+            load_labels=True,
+        )
+
     def train_dataloader(self) -> DataLoader:
         return DataLoader(
             self.train_dataset,
@@ -96,3 +170,8 @@ def val_dataloader(self) -> DataLoader:
             num_workers=self.num_workers,
             persistent_workers=bool(self.num_workers),
         )
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(
+            self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers
+        )
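
A sketch of driving the new test stage. The paths below are hypothetical and used only for illustration; per the diff above, the constructor validates that every one of them exists before any stage runs:

from pathlib import Path
from viscy.data.livecell import LiveCellDataModule

# Hypothetical dataset layout for illustration only.
dm = LiveCellDataModule(
    train_val_images=Path("LIVECell/images/train_val"),
    test_images=Path("LIVECell/images/test"),
    train_annotations=Path("LIVECell/ann_train.json"),
    val_annotations=Path("LIVECell/ann_val.json"),
    test_annotations=Path("LIVECell/ann_test.json"),
)
dm.setup(stage="test")  # builds a LiveCellTestDataset with load_labels=True
loader = dm.test_dataloader()  # COCO boxes are converted from xywh to xyxy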

viscy/evaluation/evaluation_metrics.py (+7 -2)

@@ -9,7 +9,7 @@
 from monai.metrics.regression import compute_ssim_and_cs
 from scipy.optimize import linear_sum_assignment
 from skimage.measure import label, regionprops
-from torchmetrics.detection import MeanAveragePrecision
+from torchmetrics.detection.mean_ap import MeanAveragePrecision
 from torchvision.ops import masks_to_boxes


@@ -172,7 +172,12 @@ def mean_average_precision(
     :py:class:`torchmetrics.detection.MeanAveragePrecision`
     :return dict[str, torch.Tensor]: COCO-style metrics
     """
-    map_metric = MeanAveragePrecision(box_format="xyxy", iou_type="segm", **kwargs)
+    defaults = dict(
+        iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000]
+    )
+    if not kwargs:
+        kwargs = {}
+    map_metric = MeanAveragePrecision(**(defaults | kwargs))
     map_metric.update(
         [labels_to_detection(pred_labels)], [labels_to_detection(target_labels)]
     )
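
Two things changed here: the detection cap is lifted so that images with more than 100 detections are fully scored (dense cell segmentations routinely exceed COCO's default threshold of 100, matching the "fix mAP default for over 100 detections" commit), and caller kwargs are merged over the defaults with the dict-union operator, where the right-hand operand wins:

# Dict-union defaults (Python 3.9+): right-hand keys take precedence.
defaults = dict(
    iou_type="segm", box_format="xyxy", max_detection_thresholds=[1, 100, 10000]
)
kwargs = dict(iou_type="bbox")  # a caller override
merged = defaults | kwargs
assert merged["iou_type"] == "bbox"                          # overridden
assert merged["max_detection_thresholds"] == [1, 100, 10000]  # kept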

viscy/light/engine.py (+40 -25)

@@ -153,6 +153,7 @@
         self.log_batches_per_epoch = log_batches_per_epoch
         self.log_samples_per_batch = log_samples_per_batch
         self.training_step_outputs = []
+        self.validation_losses = []
         self.validation_step_outputs = []
         # required to log the graph
         if architecture == "2D":
@@ -179,32 +180,49 @@
     def forward(self, x: Tensor) -> Tensor:
         return self.model(x)

-    def training_step(self, batch: Sample, batch_idx: int):
-        source = batch["source"]
-        target = batch["target"]
-        pred = self.forward(source)
-        loss = self.loss_function(pred, target)
+    def training_step(self, batch: Sample | Sequence[Sample], batch_idx: int):
+        losses = []
+        batch_size = 0
+        if not isinstance(batch, Sequence):
+            batch = [batch]
+        for b in batch:
+            source = b["source"]
+            target = b["target"]
+            pred = self.forward(source)
+            loss = self.loss_function(pred, target)
+            losses.append(loss)
+            batch_size += source.shape[0]
+            if batch_idx < self.log_batches_per_epoch:
+                self.training_step_outputs.extend(
+                    self._detach_sample((source, target, pred))
+                )
+        loss_step = torch.stack(losses).mean()
         self.log(
             "loss/train",
-            loss,
+            loss_step.to(self.device),
             on_step=True,
             on_epoch=True,
             prog_bar=True,
             logger=True,
             sync_dist=True,
+            batch_size=batch_size,
         )
-        if batch_idx < self.log_batches_per_epoch:
-            self.training_step_outputs.extend(
-                self._detach_sample((source, target, pred))
-            )
-        return loss
+        return loss_step

     def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
-        source = batch["source"]
-        target = batch["target"]
+        source: Tensor = batch["source"]
+        target: Tensor = batch["target"]
         pred = self.forward(source)
         loss = self.loss_function(pred, target)
-        self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False)
+        if dataloader_idx + 1 > len(self.validation_losses):
+            self.validation_losses.append([])
+        self.validation_losses[dataloader_idx].append(loss.detach())
+        self.log(
+            f"loss/val/{dataloader_idx}",
+            loss.to(self.device),
+            sync_dist=True,
+            batch_size=source.shape[0],
+        )
         if batch_idx < self.log_batches_per_epoch:
             self.validation_step_outputs.extend(
                 self._detach_sample((source, target, pred))
@@ -364,8 +382,16 @@ def on_train_epoch_end(self):
         self.training_step_outputs = []

     def on_validation_epoch_end(self):
+        super().on_validation_epoch_end()
         self._log_samples("val_samples", self.validation_step_outputs)
         self.validation_step_outputs = []
+        # average within each dataloader
+        loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses]
+        self.log(
+            "loss/validate",
+            torch.tensor(loss_means).mean().to(self.device),
+            sync_dist=True,
+        )

     def on_test_start(self):
         """Load CellPose model for segmentation."""
@@ -477,7 +503,6 @@ class FcmaeUNet(VSUNet):
     def __init__(self, fit_mask_ratio: float = 0.0, **kwargs):
         super().__init__(architecture="fcmae", **kwargs)
         self.fit_mask_ratio = fit_mask_ratio
-        self.validation_losses = []

     def forward(self, x: Tensor, mask_ratio: float = 0.0):
         return self.model(x, mask_ratio)
@@ -529,13 +554,3 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0
             self.validation_step_outputs.extend(
                 self._detach_sample((source, target * mask.unsqueeze(2), pred))
             )
-
-    def on_validation_epoch_end(self):
-        super().on_validation_epoch_end()
-        # average within each dataloader
-        loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses]
-        self.log(
-            "loss/validate",
-            torch.tensor(loss_means).mean().to(self.device),
-            sync_dist=True,
-        )
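
The manual reduction that used to live on `FcmaeUNet` now applies to all `VSUNet` validation: losses are collected per dataloader and averaged within each loader before averaging across loaders, so a loader with more batches does not dominate the epoch-level "loss/validate" metric. A worked example of the difference (the diff builds the means with `torch.tensor(losses)`; this sketch uses `torch.stack`, which is equivalent for scalar losses):

import torch

validation_losses = [
    [torch.tensor(0.5), torch.tensor(1.5)],  # dataloader 0: two batches, mean 1.0
    [torch.tensor(0.5)],                     # dataloader 1: one batch, mean 0.5
]
# average within each dataloader first, then across dataloaders
loss_means = [torch.stack(losses).mean() for losses in validation_losses]
overall = torch.stack(loss_means).mean()
assert overall.item() == 0.75  # not the unweighted (0.5 + 1.5 + 0.5) / 3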
