Skip to content

Commit b6d22b6

Browse files
committed
Enable easy loading of pretrained SAM weights
1 parent f98b321 commit b6d22b6

File tree

3 files changed

+42
-36
lines changed

3 files changed

+42
-36
lines changed

otx/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/segment_anything.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,15 @@ def replace_state_dict_keys(state_dict, revise_keys):
162162
# state_dict from args.load_from
163163
state_dict = replace_state_dict_keys(state_dict, revise_keys)
164164
self.load_state_dict(state_dict)
165-
elif self.config.model.checkpoint:
165+
elif self.config.model.checkpoint == "pretrained" or self.config.model.checkpoint is None:
166+
# load SAM pretrained weights
167+
state_dict = torch.hub.load_state_dict_from_url(CKPT_PATHS[self.config.model.backbone])
168+
state_dict = replace_state_dict_keys(state_dict, revise_keys)
169+
self.load_state_dict(state_dict)
170+
else:
171+
# load custom weights
166172
try:
173+
# load checkpoint trained by pytorch lightning
167174
self.load_from_checkpoint(self.config.model.checkpoint)
168175
except Exception:
169176
if str(self.config.model.checkpoint).startswith("http"):

otx/algorithms/visual_prompting/configs/sam_vit_b/config.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ model:
1717
freeze_image_encoder: true
1818
freeze_prompt_encoder: true
1919
freeze_mask_decoder: false
20-
checkpoint: https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
20+
checkpoint: pretrained
2121

2222
optimizer:
2323
name: Adam

tests/unit/algorithms/visual_prompting/adapters/pytorch_lightning/models/visual_prompters/test_segment_anything.py

+33-34
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,14 @@ def forward(self, *args, **kwargs):
5151

5252

5353
class TestSegmentAnything:
54+
@pytest.fixture
55+
def mocker_load_state_dict(self, mocker) -> None:
56+
return mocker.patch(
57+
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_state_dict"
58+
)
59+
5460
@pytest.fixture(autouse=True)
55-
def setup(self, monkeypatch) -> None:
61+
def setup(self, mocker, monkeypatch) -> None:
5662
monkeypatch.setattr(
5763
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SAMImageEncoder",
5864
MockImageEncoder,
@@ -65,6 +71,12 @@ def setup(self, monkeypatch) -> None:
6571
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SAMMaskDecoder",
6672
MockMaskDecoder,
6773
)
74+
monkeypatch.setattr("torch.hub.load_state_dict_from_url", lambda *args, **kwargs: OrderedDict())
75+
monkeypatch.setattr("torch.load", lambda *args, **kwargs: None)
76+
77+
self.mocker_load_from_checkpoint = mocker.patch(
78+
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_from_checkpoint"
79+
)
6880

6981
self.base_config = DictConfig(
7082
dict(
@@ -84,7 +96,7 @@ def setup(self, monkeypatch) -> None:
8496

8597
@e2e_pytest_unit
8698
@pytest.mark.parametrize("backbone", ["vit_b", "resnet"])
87-
def test_set_models(self, mocker, backbone: str) -> None:
99+
def test_set_models(self, mocker, mocker_load_state_dict, backbone: str) -> None:
88100
"""Test set_models."""
89101
mocker.patch(
90102
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.freeze_networks"
@@ -114,7 +126,7 @@ def test_set_models(self, mocker, backbone: str) -> None:
114126
@pytest.mark.parametrize("freeze_prompt_encoder", [True, False])
115127
@pytest.mark.parametrize("freeze_mask_decoder", [True, False])
116128
def test_freeze_networks(
117-
self, mocker, freeze_image_encoder: bool, freeze_prompt_encoder: bool, freeze_mask_decoder: bool
129+
self, mocker, mocker_load_state_dict, freeze_image_encoder: bool, freeze_prompt_encoder: bool, freeze_mask_decoder: bool
118130
):
119131
"""Test freeze_networks."""
120132
mocker.patch(
@@ -154,7 +166,7 @@ def test_freeze_networks(
154166

155167
@e2e_pytest_unit
156168
@pytest.mark.parametrize("loss_type", ["sam", "medsam"])
157-
def test_set_metrics(self, mocker, loss_type: str):
169+
def test_set_metrics(self, mocker, mocker_load_state_dict, loss_type: str):
158170
"""Test set_metrics."""
159171
mocker.patch(
160172
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.set_models"
@@ -231,8 +243,8 @@ def test_load_checkpoint_with_state_dict(self, mocker, is_backbone_arg: bool, st
231243
assert v == sam_state_dict[k]
232244

233245
@e2e_pytest_unit
234-
@pytest.mark.parametrize("checkpoint", [None, "checkpoint", "http://checkpoint"])
235-
def test_load_checkpoint(self, mocker, monkeypatch, checkpoint: str):
246+
@pytest.mark.parametrize("checkpoint", [None, "pretrained", "checkpoint.pth", "http://checkpoint"])
247+
def test_load_checkpoint(self, mocker, mocker_load_state_dict, checkpoint: str):
236248
"""Test load_checkpoint."""
237249
mocker.patch(
238250
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.freeze_networks"
@@ -241,21 +253,8 @@ def test_load_checkpoint(self, mocker, monkeypatch, checkpoint: str):
241253
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.set_metrics"
242254
)
243255
if checkpoint is not None:
244-
monkeypatch.setattr("torch.hub.load_state_dict_from_url", lambda *args, **kwargs: OrderedDict())
245-
monkeypatch.setattr("torch.load", lambda *args, **kwargs: None)
246-
247-
mocker_load_state_dict = mocker.patch(
248-
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_state_dict"
249-
)
250256
if checkpoint.startswith("http"):
251-
mocker_load_from_checkpoint = mocker.patch(
252-
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_from_checkpoint",
253-
side_effect=ValueError(),
254-
)
255-
else:
256-
mocker_load_from_checkpoint = mocker.patch(
257-
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_from_checkpoint"
258-
)
257+
self.mocker_load_from_checkpoint.side_effect = ValueError()
259258

260259
config = self.base_config.copy()
261260
config.model.update(dict(checkpoint=checkpoint))
@@ -264,19 +263,19 @@ def test_load_checkpoint(self, mocker, monkeypatch, checkpoint: str):
264263

265264
if checkpoint is None:
266265
assert True
267-
elif checkpoint.startswith("http"):
266+
elif checkpoint.startswith("http") or checkpoint == "pretrained":
268267
mocker_load_state_dict.assert_called_once()
269268
else:
270-
mocker_load_from_checkpoint.assert_called_once()
269+
self.mocker_load_from_checkpoint.assert_called_once()
271270

272271
@e2e_pytest_unit
273-
def test_forward(self) -> None:
274-
"""Test forward."""
272+
def test_forward_train(self, mocker_load_state_dict) -> None:
273+
"""Test forward_train."""
275274
sam = SegmentAnything(config=self.base_config)
276275
images = torch.zeros((1))
277276
bboxes = torch.zeros((1))
278277

279-
results = sam.forward(images=images, bboxes=bboxes, points=None)
278+
results = sam.forward_train(images=images, bboxes=bboxes, points=None)
280279
pred_masks, ious = results
281280

282281
assert len(bboxes) == len(pred_masks) == len(ious)
@@ -285,7 +284,7 @@ def test_forward(self) -> None:
285284
@pytest.mark.parametrize(
286285
"loss_type,expected", [("sam", torch.tensor(2.4290099144)), ("medsam", torch.tensor(0.9650863409))]
287286
)
288-
def test_training_step(self, mocker, loss_type: str, expected: Tensor) -> None:
287+
def test_training_step(self, mocker, mocker_load_state_dict, loss_type: str, expected: Tensor) -> None:
289288
"""Test training_step."""
290289
mocker.patch(
291290
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.forward",
@@ -311,7 +310,7 @@ def test_training_step(self, mocker, loss_type: str, expected: Tensor) -> None:
311310
assert torch.equal(results, expected)
312311

313312
@e2e_pytest_unit
314-
def test_training_epoch_end(self) -> None:
313+
def test_training_epoch_end(self, mocker_load_state_dict) -> None:
315314
"""Test training_epoch_end."""
316315
sam = SegmentAnything(config=self.base_config)
317316
for k, v in sam.train_metrics.items():
@@ -328,7 +327,7 @@ def test_training_epoch_end(self) -> None:
328327
assert sam.train_metrics["train_loss_iou"].compute().isnan()
329328

330329
@e2e_pytest_unit
331-
def test_validation_step(self, mocker) -> None:
330+
def test_validation_step(self, mocker, mocker_load_state_dict) -> None:
332331
"""Test validation_step."""
333332
mocker.patch(
334333
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.forward",
@@ -357,7 +356,7 @@ def test_validation_step(self, mocker) -> None:
357356
assert torch.equal(results["val_IoU"].compute(), torch.tensor(1.0))
358357

359358
@e2e_pytest_unit
360-
def test_validation_epoch_end(self) -> None:
359+
def test_validation_epoch_end(self, mocker_load_state_dict) -> None:
361360
"""Test validation_epoch_end."""
362361
sam = SegmentAnything(config=self.base_config)
363362
for k, v in sam.val_metrics.items():
@@ -377,7 +376,7 @@ def test_validation_epoch_end(self) -> None:
377376
(False, torch.Tensor([[False for _ in range(4)] for _ in range(4)])),
378377
],
379378
)
380-
def test_predict_step(self, mocker, return_logits: bool, expected: Tensor) -> None:
379+
def test_predict_step(self, mocker, mocker_load_state_dict, return_logits: bool, expected: Tensor) -> None:
381380
"""Test predict_step."""
382381
mocker.patch(
383382
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.forward",
@@ -414,7 +413,7 @@ def test_predict_step(self, mocker, return_logits: bool, expected: Tensor) -> No
414413
],
415414
)
416415
def test_postprocess_masks(
417-
self, input_size: Tuple[int], original_size: Tuple[int], is_predict: bool, expected: Tuple[int]
416+
self, mocker_load_state_dict, input_size: Tuple[int], original_size: Tuple[int], is_predict: bool, expected: Tuple[int]
418417
) -> None:
419418
"""Test postprocess_masks."""
420419
sam = SegmentAnything(config=self.base_config)
@@ -433,7 +432,7 @@ def test_postprocess_masks(
433432
(Tensor([[0, 0, 0.3, 0.3, 0, 0]]), Tensor([[0, 0, 1, 1, 0, 0]]), Tensor([0.3888888359])),
434433
],
435434
)
436-
def test_calculate_dice_loss(self, inputs: Tensor, targets: Tensor, expected: Tensor) -> None:
435+
def test_calculate_dice_loss(self, mocker_load_state_dict, inputs: Tensor, targets: Tensor, expected: Tensor) -> None:
437436
"""Test calculate_dice_loss."""
438437
sam = SegmentAnything(config=self.base_config)
439438

@@ -450,7 +449,7 @@ def test_calculate_dice_loss(self, inputs: Tensor, targets: Tensor, expected: Te
450449
(Tensor([[0, 0, 0.3, 0.3, 0, 0]]), Tensor([[0, 0, 1, 1, 0, 0]]), Tensor([0.0226361733])),
451450
],
452451
)
453-
def test_calculate_sigmoid_ce_focal_loss(self, inputs: Tensor, targets: Tensor, expected: Tensor) -> None:
452+
def test_calculate_sigmoid_ce_focal_loss(self, mocker_load_state_dict, inputs: Tensor, targets: Tensor, expected: Tensor) -> None:
454453
"""Test calculate_sigmoid_ce_focal_loss."""
455454
sam = SegmentAnything(config=self.base_config)
456455

@@ -467,7 +466,7 @@ def test_calculate_sigmoid_ce_focal_loss(self, inputs: Tensor, targets: Tensor,
467466
(Tensor([[0, 0, 0.3, 0.3, 0, 0]]), Tensor([[0, 0, 1, 1, 0, 0]]), Tensor([0.0])),
468467
],
469468
)
470-
def test_calculate_iou(self, inputs: Tensor, targets: Tensor, expected: Tensor) -> None:
469+
def test_calculate_iou(self, mocker_load_state_dict, inputs: Tensor, targets: Tensor, expected: Tensor) -> None:
471470
"""Test calculate_iou."""
472471
sam = SegmentAnything(config=self.base_config)
473472

0 commit comments

Comments
 (0)