
Commit

Merge c4a234b into fc4a5e6
mzouink authored Nov 15, 2024
2 parents fc4a5e6 + c4a234b commit 97abfa0
Showing 20 changed files with 664 additions and 50 deletions.
7 changes: 4 additions & 3 deletions dacapo/compute_context/local_torch.py
@@ -56,9 +56,10 @@ def device(self):
if self._device is None:
if torch.cuda.is_available():
# TODO: make this more sophisticated, for multiple GPUs for instance
free = torch.cuda.mem_get_info()[0] / 1024**3
if free < self.oom_limit: # less than 1 GB free, decrease chance of OOM
return torch.device("cpu")
# the commented-out code below checked free memory and fell back to the CPU; when the model was on the GPU and memory ran low, the model got moved to the CPU
# free = torch.cuda.mem_get_info()[0] / 1024**3
# if free < self.oom_limit: # less than 1 GB free, decrease chance of OOM
# return torch.device("cpu")
return torch.device("cuda")
# Multiple MPS ops are not available yet : https://github.com/pytorch/pytorch/issues/77764
# got error aten::max_pool3d_with_indices
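For context, the disabled fallback relies on torch.cuda.mem_get_info(), which returns the free and total device memory in bytes. A minimal sketch of the check, with oom_limit_gib standing in for the config's oom_limit threshold:

import torch

def pick_device(oom_limit_gib: float = 1.0) -> torch.device:
    if torch.cuda.is_available():
        free_bytes, _total_bytes = torch.cuda.mem_get_info()
        if free_bytes / 1024**3 < oom_limit_gib:
            # too little free GPU memory; fall back to the CPU to reduce OOM risk
            return torch.device("cpu")
        return torch.device("cuda")
    return torch.device("cpu")

print(pick_device())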
23 changes: 12 additions & 11 deletions dacapo/experiments/architectures/cnnectome_unet.py
@@ -3,6 +3,8 @@
import torch
import torch.nn as nn

from funlib.geometry import Coordinate

import math


@@ -176,7 +178,7 @@ def __init__(self, architecture_config):
self.unet = self.module()

@property
def eval_shape_increase(self):
def eval_shape_increase(self) -> Coordinate:
"""
The increase in shape due to the U-Net.
@@ -192,7 +194,7 @@ def eval_shape_increase(self):
"""
if self._eval_shape_increase is None:
return super().eval_shape_increase
return self._eval_shape_increase
return Coordinate(self._eval_shape_increase)

def module(self):
"""
@@ -235,16 +237,15 @@ def module(self):
"""
fmaps_in = self.fmaps_in
levels = len(self.downsample_factors) + 1
dims = len(self.downsample_factors[0])

if hasattr(self, "kernel_size_down"):
if self.kernel_size_down is not None:
kernel_size_down = self.kernel_size_down
else:
kernel_size_down = [[(3,) * dims, (3,) * dims]] * levels
if hasattr(self, "kernel_size_up"):
kernel_size_down = [[(3,) * self.dims, (3,) * self.dims]] * levels
if self.kernel_size_up is not None:
kernel_size_up = self.kernel_size_up
else:
kernel_size_up = [[(3,) * dims, (3,) * dims]] * (levels - 1)
kernel_size_up = [[(3,) * self.dims, (3,) * self.dims]] * (levels - 1)

# downsample factors has to be a list of tuples
downsample_factors = [tuple(x) for x in self.downsample_factors]
@@ -280,7 +281,7 @@ def module(self):
conv = ConvPass(
self.fmaps_out,
self.fmaps_out,
[(3,) * len(upsample_factor)] * 2,
kernel_size_up[-1],
activation="ReLU",
batch_norm=self.batch_norm,
)
@@ -306,11 +307,11 @@ def scale(self, voxel_size):
The voxel size should be given as a tuple ``(z, y, x)``.
"""
for upsample_factor in self.upsample_factors:
voxel_size = voxel_size / upsample_factor
voxel_size = voxel_size / Coordinate(upsample_factor)
return voxel_size

@property
def input_shape(self):
def input_shape(self) -> Coordinate:
"""
Return the input shape of the U-Net.
@@ -324,7 +325,7 @@ def input_shape(self):
Note:
The input shape should be given as a tuple ``(batch, channels, [length,] depth, height, width)``.
"""
return self._input_shape
return Coordinate(self._input_shape)

@property
def num_in_channels(self) -> int:
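The return values and the scale computation are now wrapped in funlib.geometry.Coordinate, whose arithmetic is element-wise. A small sketch of the assumed behavior, with illustrative shapes borrowed from the test fixtures further down:

from funlib.geometry import Coordinate

voxel_size = Coordinate((8, 8, 8))
upsample_factor = (2, 2, 2)

# Coordinate arithmetic is element-wise, so each axis scales independently
scaled_voxel_size = voxel_size / Coordinate(upsample_factor)

input_shape = Coordinate((188, 188, 188))
eval_shape_increase = Coordinate((72, 72, 72))
print(scaled_voxel_size, input_shape + eval_shape_increase)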
2 changes: 1 addition & 1 deletion dacapo/experiments/architectures/cnnectome_unet_config.py
@@ -128,6 +128,6 @@ class CNNectomeUNetConfig(ArchitectureConfig):
},
)
batch_norm: bool = attr.ib(
default=True,
default=False,
metadata={"help_text": "Whether to use batch normalization."},
)
@@ -18,6 +18,9 @@
from funlib.persistence import Array

from typing import Iterable
import logging

logger = logging.getLogger(__name__)


class ThresholdPostProcessor(PostProcessor):
@@ -135,7 +138,7 @@ def process_block(block):
data = input_array[write_roi] > parameters.threshold
data = data.astype(np.uint8)
if int(data.max()) == 0:
print("No data in block", write_roi)
logger.debug("No data in block %s", write_roi)
return
output_array[write_roi] = data

14 changes: 9 additions & 5 deletions dacapo/experiments/tasks/predictors/distance_predictor.py
@@ -231,9 +231,11 @@ def create_distance_mask(
)
slices = tmp.ndim * (slice(1, -1),)
tmp[slices] = channel_mask
sampling = tuple(float(v) / 2 for v in voxel_size)
sampling = sampling[-len(tmp.shape) :]
boundary_distance = distance_transform_edt(
tmp,
sampling=voxel_size,
sampling=sampling,
)
if self.epsilon is None:
add = 0
@@ -315,13 +317,15 @@ def process(
distances = np.ones(channel.shape, dtype=np.float32) * max_distance
else:
# get distances (voxel_size/2 because image is doubled)
distances = distance_transform_edt(
boundaries, sampling=tuple(float(v) / 2 for v in voxel_size)
)
sampling = tuple(float(v) / 2 for v in voxel_size)
# fixing the sampling for 2D images
if len(boundaries.shape) < len(sampling):
sampling = sampling[-len(boundaries.shape) :]
distances = distance_transform_edt(boundaries, sampling=sampling)
distances = distances.astype(np.float32)

# restore original shape
downsample = (slice(None, None, 2),) * len(voxel_size)
downsample = (slice(None, None, 2),) * distances.ndim
distances = distances[downsample]

# todo: inverted distance
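The trimming makes the sampling argument match the rank of the boundary image, since scipy.ndimage.distance_transform_edt expects one sampling entry per input dimension. A minimal sketch of the 2D case, with illustrative shapes and voxel sizes:

import numpy as np
from scipy.ndimage import distance_transform_edt

boundaries = np.ones((64, 64), dtype=bool)   # doubled 2D boundary image
boundaries[32, 32] = False
voxel_size = (40, 4, 4)                      # (z, y, x): one entry too many for 2D data

# image is doubled upstream, hence voxel_size / 2
sampling = tuple(float(v) / 2 for v in voxel_size)
if boundaries.ndim < len(sampling):
    sampling = sampling[-boundaries.ndim:]   # keep the trailing (y, x) entries

distances = distance_transform_edt(boundaries, sampling=sampling).astype(np.float32)
print(distances.shape, distances.max())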
6 changes: 2 additions & 4 deletions dacapo/experiments/tasks/predictors/dummy_predictor.py
@@ -71,9 +71,8 @@ def create_target(self, gt):
# zeros
return np_to_funlib_array(
np.zeros((self.embedding_dims,) + gt.data.shape[-gt.dims :]),
gt.roi,
gt.roi.offset,
gt.voxel_size,
["c^"] + gt.axis_names,
)

def create_weight(self, gt, target, mask, moving_class_counts=None):
Expand All @@ -96,9 +95,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
return (
np_to_funlib_array(
np.ones(target.data.shape),
target.roi,
target.roi.offset,
target.voxel_size,
target.axis_names,
),
None,
)
8 changes: 3 additions & 5 deletions dacapo/experiments/tasks/predictors/hot_distance_predictor.py
@@ -141,12 +141,11 @@ def create_target(self, gt):
Examples:
>>> target = predictor.create_target(gt)
"""
target = self.process(gt.data, gt.voxel_size, self.norm, self.dt_scale_factor)
target = self.process(gt[:], gt.voxel_size, self.norm, self.dt_scale_factor)
return np_to_funlib_array(
target,
gt.roi,
gt.roi.offset,
gt.voxel_size,
gt.axis_names,
)

def create_weight(self, gt, target, mask, moving_class_counts=None):
Expand Down Expand Up @@ -209,9 +208,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
return (
np_to_funlib_array(
weights,
gt.roi,
gt.roi.offset,
gt.voxel_size,
gt.axis_names,
),
moving_class_counts,
)
@@ -120,9 +120,8 @@ def create_target(self, gt):
)
return np_to_funlib_array(
distances,
gt.roi,
gt.roi.offset,
gt.voxel_size,
gt.axis_names,
)

def create_weight(self, gt, target, mask, moving_class_counts=None):
Expand Down Expand Up @@ -155,9 +154,8 @@ def create_weight(self, gt, target, mask, moving_class_counts=None):
return (
np_to_funlib_array(
weights,
gt.roi,
gt.roi.offset,
gt.voxel_size,
gt.axis_names,
),
moving_class_counts,
)
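Across these predictor files, the np_to_funlib_array calls now pass an offset plus axis names instead of a Roi, matching the newer funlib.persistence Array constructor. A hedged sketch of building such an array directly (keyword names assumed from that API; the "^" suffix marking a channel axis follows the diff's ["c^"] usage):

import numpy as np
from funlib.persistence import Array

data = np.ones((2, 64, 64, 64), dtype=np.float32)   # channel + 3 spatial dims
array = Array(
    data,
    offset=(0, 0, 0),                   # world offset of the spatial axes
    voxel_size=(8, 8, 8),               # size of a voxel along each spatial axis
    axis_names=["c^", "z", "y", "x"],   # "^" marks a non-spatial (channel) axis
)
print(array.roi)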
2 changes: 2 additions & 0 deletions dacapo/experiments/trainers/gunpowder_trainer.py
@@ -173,6 +173,8 @@ def build_batch_provider(self, datasets, model, task, snapshot_container=None):
assert isinstance(dataset.weight, int), dataset

raw_source = gp.ArraySource(raw_key, dataset.raw)
if dataset.raw.channel_dims == 0:
raw_source += gp.Unsqueeze([raw_key], axis=0)
if self.clip_raw:
raw_source += gp.Crop(
raw_key, dataset.gt.roi.snap_to_grid(dataset.raw.voxel_size)
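The added node gives channel-less raw data a singleton channel axis before it enters the pipeline. A minimal standalone sketch, assuming gunpowder's ArraySource and Unsqueeze nodes and a funlib.persistence Array with the channel_dims property used above:

import numpy as np
import gunpowder as gp
from funlib.persistence import Array

raw_key = gp.ArrayKey("RAW")

# a raw volume without a channel dimension: shape (z, y, x)
raw = Array(np.zeros((32, 32, 32), dtype=np.float32), offset=(0, 0, 0), voxel_size=(4, 4, 4))

pipeline = gp.ArraySource(raw_key, raw)
if raw.channel_dims == 0:
    # add a leading channel axis so downstream nodes see (c, z, y, x)
    pipeline += gp.Unsqueeze([raw_key], axis=0)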
6 changes: 6 additions & 0 deletions dacapo/predict_local.py
@@ -71,6 +71,12 @@ def predict(
compute_context = create_compute_context()
device = compute_context.device

model_device = str(next(model.parameters()).device).split(":")[0]

assert model_device == str(
device
), f"Model is not on the right device, Model: {model_device}, Compute device: {device}"

def predict_fn(block):
raw_input = raw_array.to_ndarray(block.read_roi)

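The assertion compares only the device type, since next(model.parameters()).device may carry an index such as cuda:0. A standalone sketch of the same check:

import torch

model = torch.nn.Linear(4, 4)      # stand-in model
device = torch.device("cpu")       # e.g. the device chosen by create_compute_context()

# the parameter device may be "cuda:0"; keep only the device type
model_device = str(next(model.parameters()).device).split(":")[0]
assert model_device == str(device), (
    f"Model is not on the right device, Model: {model_device}, Compute device: {device}"
)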
7 changes: 3 additions & 4 deletions dacapo/validate.py
@@ -246,6 +246,9 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
# validation_dataset.name,
# criterion,
# )
dataset_iteration_scores.append(
[getattr(scores, criterion) for criterion in scores.criteria]
)
except:
logger.error(
f"Could not evaluate run {run.name} on dataset {validation_dataset.name} with parameters {parameters}.",
@@ -257,10 +260,6 @@ def validate_run(run: Run, iteration: int, datasets_config=None):
# the evaluator
# array_store.remove(output_array_identifier)

dataset_iteration_scores.append(
[getattr(scores, criterion) for criterion in scores.criteria]
)

iteration_scores.append(dataset_iteration_scores)
# array_store.remove(prediction_array_identifier)

25 changes: 21 additions & 4 deletions tests/fixtures/__init__.py
@@ -1,11 +1,28 @@
from .db import options
from .architectures import dummy_architecture
from .architectures import (
dummy_architecture,
unet_architecture,
unet_3d_architecture,
unet_architecture_builder,
)
from .arrays import dummy_array, zarr_array, cellmap_array
from .datasplits import dummy_datasplit, twelve_class_datasplit, six_class_datasplit
from .datasplits import (
dummy_datasplit,
twelve_class_datasplit,
six_class_datasplit,
upsample_six_class_datasplit,
)
from .evaluators import binary_3_channel_evaluator
from .losses import dummy_loss
from .post_processors import argmax, threshold
from .predictors import distance_predictor, onehot_predictor
from .runs import dummy_run, distance_run, onehot_run
from .tasks import dummy_task, distance_task, onehot_task
from .runs import (
dummy_run,
distance_run,
onehot_run,
unet_2d_distance_run,
unet_3d_distance_run,
hot_distance_run,
)
from .tasks import dummy_task, distance_task, onehot_task, six_onehot_task, hot_distance_task
from .trainers import dummy_trainer, gunpowder_trainer
79 changes: 78 additions & 1 deletion tests/fixtures/architectures.py
@@ -1,4 +1,7 @@
from dacapo.experiments.architectures import DummyArchitectureConfig
from dacapo.experiments.architectures import (
DummyArchitectureConfig,
CNNectomeUNetConfig,
)

import pytest

@@ -8,3 +11,77 @@ def dummy_architecture():
yield DummyArchitectureConfig(
name="dummy_architecture", num_in_channels=1, num_out_channels=12
)


@pytest.fixture()
def unet_architecture():
yield CNNectomeUNetConfig(
name="tmp_unet_architecture",
input_shape=(132, 132),
eval_shape_increase=(32, 32),
fmaps_in=1,
num_fmaps=8,
fmaps_out=8,
fmap_inc_factor=2,
downsample_factors=[(4, 4), (4, 4)],
kernel_size_down=[[(3, 3)] * 2] * 3,
kernel_size_up=[[(3, 3)] * 2] * 2,
constant_upsample=True,
padding="valid",
)


@pytest.fixture()
def unet_3d_architecture():
yield CNNectomeUNetConfig(
name="tmp_unet_3d_architecture",
input_shape=(188, 188, 188),
eval_shape_increase=(72, 72, 72),
fmaps_in=1,
num_fmaps=6,
fmaps_out=6,
fmap_inc_factor=2,
downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
constant_upsample=True,
)


def unet_architecture_builder(batch_norm, upsample, use_attention, three_d):
name = "3d_unet" if three_d else "2d_unet"
name = f"{name}_bn" if batch_norm else name
name = f"{name}_up" if upsample else name
name = f"{name}_att" if use_attention else name

if three_d:
return CNNectomeUNetConfig(
name=name,
input_shape=(188, 188, 188),
eval_shape_increase=(72, 72, 72),
fmaps_in=1,
num_fmaps=6,
fmaps_out=6,
fmap_inc_factor=2,
downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
constant_upsample=True,
upsample_factors=[(2, 2, 2)] if upsample else [],
batch_norm=batch_norm,
use_attention=use_attention,
)
else:
return CNNectomeUNetConfig(
name=name,
input_shape=(132, 132),
eval_shape_increase=(32, 32),
fmaps_in=1,
num_fmaps=8,
fmaps_out=8,
fmap_inc_factor=2,
downsample_factors=[(4, 4), (4, 4)],
kernel_size_down=[[(3, 3)] * 2] * 3,
kernel_size_up=[[(3, 3)] * 2] * 2,
constant_upsample=True,
padding="valid",
batch_norm=batch_norm,
use_attention=use_attention,
upsample_factors=[(2, 2)] if upsample else [],
)
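
Unlike the fixtures above, unet_architecture_builder is a plain function meant to be called from parametrized tests. A hedged usage sketch (the test name, parameter grid, and import path are illustrative):

import pytest

from tests.fixtures import unet_architecture_builder   # import path assumed


@pytest.mark.parametrize("batch_norm", [True, False])
@pytest.mark.parametrize("upsample", [True, False])
@pytest.mark.parametrize("use_attention", [True, False])
@pytest.mark.parametrize("three_d", [True, False])
def test_unet_config_builds(batch_norm, upsample, use_attention, three_d):
    config = unet_architecture_builder(batch_norm, upsample, use_attention, three_d)
    # the builder encodes the flags into the config name, e.g. "3d_unet_bn_up_att"
    assert config.name.startswith("3d_unet" if three_d else "2d_unet")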