@@ -51,8 +51,14 @@ def forward(self, *args, **kwargs):
51
51
52
52
53
53
class TestSegmentAnything :
54
+ @pytest .fixture
55
+ def mocker_load_state_dict (self , mocker ) -> None :
56
+ return mocker .patch (
57
+ "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_state_dict"
58
+ )
59
+
54
60
@pytest .fixture (autouse = True )
55
- def setup (self , monkeypatch ) -> None :
61
+ def setup (self , mocker , monkeypatch ) -> None :
56
62
monkeypatch .setattr (
57
63
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SAMImageEncoder" ,
58
64
MockImageEncoder ,
@@ -65,6 +71,12 @@ def setup(self, monkeypatch) -> None:
65
71
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SAMMaskDecoder" ,
66
72
MockMaskDecoder ,
67
73
)
74
+ monkeypatch .setattr ("torch.hub.load_state_dict_from_url" , lambda * args , ** kwargs : OrderedDict ())
75
+ monkeypatch .setattr ("torch.load" , lambda * args , ** kwargs : None )
76
+
77
+ self .mocker_load_from_checkpoint = mocker .patch (
78
+ "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_from_checkpoint"
79
+ )
68
80
69
81
self .base_config = DictConfig (
70
82
dict (
@@ -84,7 +96,7 @@ def setup(self, monkeypatch) -> None:
84
96
85
97
@e2e_pytest_unit
86
98
@pytest .mark .parametrize ("backbone" , ["vit_b" , "resnet" ])
87
- def test_set_models (self , mocker , backbone : str ) -> None :
99
+ def test_set_models (self , mocker , mocker_load_state_dict , backbone : str ) -> None :
88
100
"""Test set_models."""
89
101
mocker .patch (
90
102
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.freeze_networks"
@@ -114,7 +126,7 @@ def test_set_models(self, mocker, backbone: str) -> None:
114
126
@pytest .mark .parametrize ("freeze_prompt_encoder" , [True , False ])
115
127
@pytest .mark .parametrize ("freeze_mask_decoder" , [True , False ])
116
128
def test_freeze_networks (
117
- self , mocker , freeze_image_encoder : bool , freeze_prompt_encoder : bool , freeze_mask_decoder : bool
129
+ self , mocker , mocker_load_state_dict , freeze_image_encoder : bool , freeze_prompt_encoder : bool , freeze_mask_decoder : bool
118
130
):
119
131
"""Test freeze_networks."""
120
132
mocker .patch (
@@ -154,7 +166,7 @@ def test_freeze_networks(
154
166
155
167
@e2e_pytest_unit
156
168
@pytest .mark .parametrize ("loss_type" , ["sam" , "medsam" ])
157
- def test_set_metrics (self , mocker , loss_type : str ):
169
+ def test_set_metrics (self , mocker , mocker_load_state_dict , loss_type : str ):
158
170
"""Test set_metrics."""
159
171
mocker .patch (
160
172
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.set_models"
@@ -231,8 +243,8 @@ def test_load_checkpoint_with_state_dict(self, mocker, is_backbone_arg: bool, st
231
243
assert v == sam_state_dict [k ]
232
244
233
245
@e2e_pytest_unit
234
- @pytest .mark .parametrize ("checkpoint" , [None , "checkpoint" , "http://checkpoint" ])
235
- def test_load_checkpoint (self , mocker , monkeypatch , checkpoint : str ):
246
+ @pytest .mark .parametrize ("checkpoint" , [None , "pretrained" , " checkpoint.pth " , "http://checkpoint" ])
247
+ def test_load_checkpoint (self , mocker , mocker_load_state_dict , checkpoint : str ):
236
248
"""Test load_checkpoint."""
237
249
mocker .patch (
238
250
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.freeze_networks"
@@ -241,21 +253,8 @@ def test_load_checkpoint(self, mocker, monkeypatch, checkpoint: str):
241
253
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.set_metrics"
242
254
)
243
255
if checkpoint is not None :
244
- monkeypatch .setattr ("torch.hub.load_state_dict_from_url" , lambda * args , ** kwargs : OrderedDict ())
245
- monkeypatch .setattr ("torch.load" , lambda * args , ** kwargs : None )
246
-
247
- mocker_load_state_dict = mocker .patch (
248
- "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_state_dict"
249
- )
250
256
if checkpoint .startswith ("http" ):
251
- mocker_load_from_checkpoint = mocker .patch (
252
- "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_from_checkpoint" ,
253
- side_effect = ValueError (),
254
- )
255
- else :
256
- mocker_load_from_checkpoint = mocker .patch (
257
- "otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.load_from_checkpoint"
258
- )
257
+ self .mocker_load_from_checkpoint .side_effect = ValueError ()
259
258
260
259
config = self .base_config .copy ()
261
260
config .model .update (dict (checkpoint = checkpoint ))
@@ -264,19 +263,19 @@ def test_load_checkpoint(self, mocker, monkeypatch, checkpoint: str):
264
263
265
264
if checkpoint is None :
266
265
assert True
267
- elif checkpoint .startswith ("http" ):
266
+ elif checkpoint .startswith ("http" ) or checkpoint == "pretrained" :
268
267
mocker_load_state_dict .assert_called_once ()
269
268
else :
270
- mocker_load_from_checkpoint .assert_called_once ()
269
+ self . mocker_load_from_checkpoint .assert_called_once ()
271
270
272
271
@e2e_pytest_unit
273
- def test_forward (self ) -> None :
274
- """Test forward ."""
272
+ def test_forward_train (self , mocker_load_state_dict ) -> None :
273
+ """Test forward_train ."""
275
274
sam = SegmentAnything (config = self .base_config )
276
275
images = torch .zeros ((1 ))
277
276
bboxes = torch .zeros ((1 ))
278
277
279
- results = sam .forward (images = images , bboxes = bboxes , points = None )
278
+ results = sam .forward_train (images = images , bboxes = bboxes , points = None )
280
279
pred_masks , ious = results
281
280
282
281
assert len (bboxes ) == len (pred_masks ) == len (ious )
@@ -285,7 +284,7 @@ def test_forward(self) -> None:
285
284
@pytest .mark .parametrize (
286
285
"loss_type,expected" , [("sam" , torch .tensor (2.4290099144 )), ("medsam" , torch .tensor (0.9650863409 ))]
287
286
)
288
- def test_training_step (self , mocker , loss_type : str , expected : Tensor ) -> None :
287
+ def test_training_step (self , mocker , mocker_load_state_dict , loss_type : str , expected : Tensor ) -> None :
289
288
"""Test training_step."""
290
289
mocker .patch (
291
290
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.forward" ,
@@ -311,7 +310,7 @@ def test_training_step(self, mocker, loss_type: str, expected: Tensor) -> None:
311
310
assert torch .equal (results , expected )
312
311
313
312
@e2e_pytest_unit
314
- def test_training_epoch_end (self ) -> None :
313
+ def test_training_epoch_end (self , mocker_load_state_dict ) -> None :
315
314
"""Test training_epoch_end."""
316
315
sam = SegmentAnything (config = self .base_config )
317
316
for k , v in sam .train_metrics .items ():
@@ -328,7 +327,7 @@ def test_training_epoch_end(self) -> None:
328
327
assert sam .train_metrics ["train_loss_iou" ].compute ().isnan ()
329
328
330
329
@e2e_pytest_unit
331
- def test_validation_step (self , mocker ) -> None :
330
+ def test_validation_step (self , mocker , mocker_load_state_dict ) -> None :
332
331
"""Test validation_step."""
333
332
mocker .patch (
334
333
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.forward" ,
@@ -357,7 +356,7 @@ def test_validation_step(self, mocker) -> None:
357
356
assert torch .equal (results ["val_IoU" ].compute (), torch .tensor (1.0 ))
358
357
359
358
@e2e_pytest_unit
360
- def test_validation_epoch_end (self ) -> None :
359
+ def test_validation_epoch_end (self , mocker_load_state_dict ) -> None :
361
360
"""Test validation_epoch_end."""
362
361
sam = SegmentAnything (config = self .base_config )
363
362
for k , v in sam .val_metrics .items ():
@@ -377,7 +376,7 @@ def test_validation_epoch_end(self) -> None:
377
376
(False , torch .Tensor ([[False for _ in range (4 )] for _ in range (4 )])),
378
377
],
379
378
)
380
- def test_predict_step (self , mocker , return_logits : bool , expected : Tensor ) -> None :
379
+ def test_predict_step (self , mocker , mocker_load_state_dict , return_logits : bool , expected : Tensor ) -> None :
381
380
"""Test predict_step."""
382
381
mocker .patch (
383
382
"otx.algorithms.visual_prompting.adapters.pytorch_lightning.models.visual_prompters.segment_anything.SegmentAnything.forward" ,
@@ -414,7 +413,7 @@ def test_predict_step(self, mocker, return_logits: bool, expected: Tensor) -> No
414
413
],
415
414
)
416
415
def test_postprocess_masks (
417
- self , input_size : Tuple [int ], original_size : Tuple [int ], is_predict : bool , expected : Tuple [int ]
416
+ self , mocker_load_state_dict , input_size : Tuple [int ], original_size : Tuple [int ], is_predict : bool , expected : Tuple [int ]
418
417
) -> None :
419
418
"""Test postprocess_masks."""
420
419
sam = SegmentAnything (config = self .base_config )
@@ -433,7 +432,7 @@ def test_postprocess_masks(
433
432
(Tensor ([[0 , 0 , 0.3 , 0.3 , 0 , 0 ]]), Tensor ([[0 , 0 , 1 , 1 , 0 , 0 ]]), Tensor ([0.3888888359 ])),
434
433
],
435
434
)
436
- def test_calculate_dice_loss (self , inputs : Tensor , targets : Tensor , expected : Tensor ) -> None :
435
+ def test_calculate_dice_loss (self , mocker_load_state_dict , inputs : Tensor , targets : Tensor , expected : Tensor ) -> None :
437
436
"""Test calculate_dice_loss."""
438
437
sam = SegmentAnything (config = self .base_config )
439
438
@@ -450,7 +449,7 @@ def test_calculate_dice_loss(self, inputs: Tensor, targets: Tensor, expected: Te
450
449
(Tensor ([[0 , 0 , 0.3 , 0.3 , 0 , 0 ]]), Tensor ([[0 , 0 , 1 , 1 , 0 , 0 ]]), Tensor ([0.0226361733 ])),
451
450
],
452
451
)
453
- def test_calculate_sigmoid_ce_focal_loss (self , inputs : Tensor , targets : Tensor , expected : Tensor ) -> None :
452
+ def test_calculate_sigmoid_ce_focal_loss (self , mocker_load_state_dict , inputs : Tensor , targets : Tensor , expected : Tensor ) -> None :
454
453
"""Test calculate_sigmoid_ce_focal_loss."""
455
454
sam = SegmentAnything (config = self .base_config )
456
455
@@ -467,7 +466,7 @@ def test_calculate_sigmoid_ce_focal_loss(self, inputs: Tensor, targets: Tensor,
467
466
(Tensor ([[0 , 0 , 0.3 , 0.3 , 0 , 0 ]]), Tensor ([[0 , 0 , 1 , 1 , 0 , 0 ]]), Tensor ([0.0 ])),
468
467
],
469
468
)
470
- def test_calculate_iou (self , inputs : Tensor , targets : Tensor , expected : Tensor ) -> None :
469
+ def test_calculate_iou (self , mocker_load_state_dict , inputs : Tensor , targets : Tensor , expected : Tensor ) -> None :
471
470
"""Test calculate_iou."""
472
471
sam = SegmentAnything (config = self .base_config )
473
472
0 commit comments