Skip to content

Commit

Permalink
A commit with all the changes now as signed commit. Also, removed the…
Browse files Browse the repository at this point in the history
… unit tests for segformer.
  • Loading branch information
Yael-Baron committed Aug 10, 2023
1 parent fb5b620 commit 27198e6
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 36 deletions.
5 changes: 1 addition & 4 deletions src/super_gradients/training/utils/segmentation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,8 @@ def to_one_hot(target: torch.Tensor, num_classes: int, ignore_index: int = None)
:param target: Class labels long tensor, with shape [N, H, W]
:param num_classes: num of classes in datasets excluding ignore label, this is the output channels of the one hot
result.
:param ignore_index: the index of the class in the dataset to ignore
:return: one hot tensor with shape [N, num_classes, H, W]
Parameters
----------
ignore_index
"""
num_classes = num_classes if ignore_index is None else num_classes + 1

Expand Down
34 changes: 2 additions & 32 deletions tests/unit_tests/pretrained_models_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training import Trainer
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader, segmentation_test_dataloader
from super_gradients.training.metrics import Accuracy, IoU
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
from super_gradients.training.metrics import Accuracy
import os
import shutil

Expand All @@ -29,36 +29,6 @@ def test_pretrained_repvgg_a0_imagenet(self):
model = models.get(Models.REPVGG_A0, pretrained_weights="imagenet", arch_params={"build_residual_branches": True})
trainer.test(model=model, test_loader=classification_test_dataloader(), test_metrics_list=[Accuracy()], metrics_progress_verbose=True)

def test_pretrained_segformer_b0_cityscapes(self):
trainer = Trainer("cityscapes_pretrained_segformer_b0_unit_test")
model = models.get(Models.SEGFORMER_B0, pretrained_weights="cityscapes")
trainer.test(model=model, test_loader=segmentation_test_dataloader(), test_metrics_list=[IoU(num_classes=20)], metrics_progress_verbose=True)

def test_pretrained_segformer_b1_cityscapes(self):
trainer = Trainer("cityscapes_pretrained_segformer_b1_unit_test")
model = models.get(Models.SEGFORMER_B1, pretrained_weights="cityscapes")
trainer.test(model=model, test_loader=segmentation_test_dataloader(), test_metrics_list=[IoU(num_classes=20)], metrics_progress_verbose=True)

def test_pretrained_segformer_b2_cityscapes(self):
trainer = Trainer("cityscapes_pretrained_segformer_b2_unit_test")
model = models.get(Models.SEGFORMER_B2, pretrained_weights="cityscapes")
trainer.test(model=model, test_loader=segmentation_test_dataloader(), test_metrics_list=[IoU(num_classes=20)], metrics_progress_verbose=True)

def test_pretrained_segformer_b3_cityscapes(self):
trainer = Trainer("cityscapes_pretrained_segformer_b3_unit_test")
model = models.get(Models.SEGFORMER_B3, pretrained_weights="cityscapes")
trainer.test(model=model, test_loader=segmentation_test_dataloader(), test_metrics_list=[IoU(num_classes=20)], metrics_progress_verbose=True)

def test_pretrained_segformer_b4_cityscapes(self):
trainer = Trainer("cityscapes_pretrained_segformer_b4_unit_test")
model = models.get(Models.SEGFORMER_B4, pretrained_weights="cityscapes")
trainer.test(model=model, test_loader=segmentation_test_dataloader(), test_metrics_list=[IoU(num_classes=20)], metrics_progress_verbose=True)

def test_pretrained_segformer_b5_cityscapes(self):
trainer = Trainer("cityscapes_pretrained_segformer_b5_unit_test")
model = models.get(Models.SEGFORMER_B5, pretrained_weights="cityscapes")
trainer.test(model=model, test_loader=segmentation_test_dataloader(), test_metrics_list=[IoU(num_classes=20)], metrics_progress_verbose=True)

def tearDown(self) -> None:
if os.path.exists("~/.cache/torch/hub/"):
shutil.rmtree("~/.cache/torch/hub/")
Expand Down

0 comments on commit 27198e6

Please sign in to comment.