Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a network CellSamWrapper #7981

Merged
merged 13 commits into from
Aug 10, 2024
92 changes: 92 additions & 0 deletions monai/networks/nets/cell_sam_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
from torch import nn
from torch.nn import functional as F

from monai.utils import optional_import

build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")

_all__ = ["CellSamWrapper"]


class CellSamWrapper(torch.nn.Module):
"""
CellSamWrapper is thin wrapper around SAM model https://github.com/facebookresearch/segment-anything
with an image only decoder, that can be used for segmentation tasks.


Args:
auto_resize_inputs: whether to resize inputs before passing to the network.
(usually they need be resized, unless they are already at the expected size)
network_resize_roi: expected input size for the network.
(currently SAM expects 1024x1024)
checkpoint: checkpoint file to load the SAM weights from.
(this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
return_features: whether to return features from SAM encoder
(without using decoder/upsampling to the original input size)

"""

def __init__(
self,
auto_resize_inputs=True,
network_resize_roi=(1024, 1024),
checkpoint="sam_vit_b_01ec64.pth",
myron marked this conversation as resolved.
Show resolved Hide resolved
return_features=False,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)

self.network_resize_roi = network_resize_roi
self.auto_resize_inputs = auto_resize_inputs
self.return_features = return_features

if not has_sam:
raise ValueError(
"SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git"
)

model = build_sam_vit_b(checkpoint=checkpoint)

model.prompt_encoder = None
model.mask_decoder = None

model.mask_decoder = nn.Sequential(
nn.BatchNorm2d(num_features=256),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
myron marked this conversation as resolved.
Show resolved Hide resolved
nn.BatchNorm2d(num_features=128),
nn.ReLU(inplace=True),
nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
)

self.model = model

def forward(self, x):
sh = x.shape[2:]

if self.auto_resize_inputs:
x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear")
myron marked this conversation as resolved.
Show resolved Hide resolved

x = self.model.image_encoder(x)

if not self.return_features:
myron marked this conversation as resolved.
Show resolved Hide resolved
x = self.model.mask_decoder(x)
if self.auto_resize_inputs:
x = F.interpolate(x, size=sh, mode="bilinear")

return x
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ nvidia-ml-py
huggingface_hub
pyamg>=5.0.0
git+https://github.com/Project-MONAI/GenerativeModels.git@7428fce193771e9564f29b91d29e523dd1b6b4cd
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
58 changes: 58 additions & 0 deletions tests/test_cell_sam_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets.cell_sam_wrapper import CellSamWrapper
from monai.utils import optional_import

build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")

device = "cuda" if torch.cuda.is_available() else "cpu"
TEST_CASE_CELLSEGWRAPPER = []
for dims in [128, 256, 512, 1024]:
test_case = [
{"auto_resize_inputs": True, "network_resize_roi": [1024, 1024], "checkpoint": None},
(1, 3, *([dims] * 2)),
(1, 3, *([dims] * 2)),
]
TEST_CASE_CELLSEGWRAPPER.append(test_case)


@unittest.skipUnless(has_sam, "Requires SAM installation")
class TestResNetDS(unittest.TestCase):

@parameterized.expand(TEST_CASE_CELLSEGWRAPPER)
def test_shape(self, input_param, input_shape, expected_shape):
net = CellSamWrapper(**input_param).to(device)
with eval_mode(net):
result = net(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape, msg=str(input_param))

def test_ill_arg0(self):
with self.assertRaises(RuntimeError):
net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device)
net(torch.randn([1, 3, 256, 256]).to(device))

def test_ill_arg1(self):
with self.assertRaises(RuntimeError):
net = CellSamWrapper(network_resize_roi=[256, 256], checkpoint=None).to(device)
net(torch.randn([1, 3, 1024, 1024]).to(device))


if __name__ == "__main__":
unittest.main()
Loading