Skip to content

Commit

Permalink
Gem Mobilenet-v2 pretrained backbone (#226)
Browse files Browse the repository at this point in the history
* Add pretrained mobilenetv2 backbone

* Update gem.md

Co-authored-by: Jelle Luijkx <j.d.luijkx@tudelft.nl>
  • Loading branch information
jelledouwe and jelledouwe authored Feb 21, 2022
1 parent 822dbc7 commit fe91adf
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 84 deletions.
2 changes: 1 addition & 1 deletion docs/reference/gem.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ Parameters:
Valid values are: "weights_detr", "pretrained_detr", "pretrained_gem", "test_data_l515" and "test_data_sample_images".
In case of "weights_detr", the weigths for single modal DETR with *resnet50* backbone are downloaded.
In case of "pretrained_detr", the weigths for single modal pretrained DETR with *resnet50* backbone are downloaded.
In case of "pretrained_gem", the weights from *'gem_scavg_e294_mAP0983_rn50_l515_7cls.pth'* (backbone: *'resnet50'*, fusion_method: *'scalar averaged'*, trained on *RGB-Infrared l515_dataset* are downloaded.
In case of "pretrained_gem", the weights (backbone: *'resnet50' or 'mobilenetv2'*, fusion_method: *'scalar averaged'*, trained on *RGB-Infrared l515_dataset*) are downloaded.
In case of "test_data_l515", the *RGB-Infrared l515* dataset is downloaded from the OpenDR server.
In case of "test_data_sample images", two sample images for testing the *infer* function are downloaded.
- **verbose** : *bool, default=False*
Expand Down
168 changes: 85 additions & 83 deletions tests/sources/tools/perception/object_detection_2d/gem/test_gem.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,43 +44,42 @@ def rmdir(_dir):


class TestGemLearner(unittest.TestCase):
temp_dir = os.path.join("tests",
"sources",
"tools",
"perception",
"object_detection_2d",
"gem",
"gem_temp",
)
dataset_location = os.path.join(temp_dir, 'sample_dataset')
learners = {}
model_backbones = ["resnet50", "mobilenetv2"]

@classmethod
def setUpClass(cls):
print("\n\n*********************************\nTEST Object Detection GEM Learner\n"
"*********************************")
cls.temp_dir = os.path.join("tests", "sources", "tools",
"perception", "object_detection_2d",
"gem", "gem_temp")

cls.model_backbone = "resnet50"
for backbone in cls.model_backbones:
cls.learners[backbone] = GemLearner(iters=1,
temp_path=cls.temp_dir,
backbone=backbone,
num_classes=7,
device=DEVICE,
)

cls.learner = GemLearner(iters=1,
temp_path=cls.temp_dir,
backbone=cls.model_backbone,
num_classes=7,
device=DEVICE,
)

cls.learner.download(mode='pretrained_gem')
for learner in cls.learners.values():
learner.download(mode='pretrained_gem')

print("Model downloaded", file=sys.stderr)

cls.learner.download(mode='test_data_sample_dataset')
cls.learners['resnet50'].download(mode='test_data_sample_dataset')

cls.learner.download(mode='test_data_sample_images')
cls.learners['resnet50'].download(mode='test_data_sample_images')

print("Data downloaded", file=sys.stderr)
cls.dataset_location = os.path.join(cls.temp_dir,
'sample_dataset',
)
cls.m1_dataset = ExternalDataset(
cls.dataset_location,
"coco",
)
cls.m2_dataset = ExternalDataset(
cls.dataset_location,
"coco",
)
cls.m1_dataset = ExternalDataset(cls.dataset_location, "coco")
cls.m2_dataset = ExternalDataset(cls.dataset_location, "coco")

@classmethod
def tearDownClass(cls):
Expand All @@ -99,35 +98,36 @@ def test_fit(self):
# version)
warnings.simplefilter("ignore", ResourceWarning)
warnings.simplefilter("ignore", DeprecationWarning)
self.learner.model = None
self.learner.ort_session = None

self.learner.download(mode='pretrained_gem')

m = list(self.learner.model.parameters())[0].clone()

self.learner.fit(
m1_train_edataset=self.m1_dataset,
m2_train_edataset=self.m2_dataset,
annotations_folder='annotations',
m1_train_annotations_file='RGB_26May2021_14h19m_coco.json',
m2_train_annotations_file='Thermal_26May2021_14h19m_coco.json',
m1_train_images_folder='train/m1',
m2_train_images_folder='train/m2',
out_dir=os.path.join(self.temp_dir, "outputs"),
trial_dir=os.path.join(self.temp_dir, "trial"),
logging_path='',
verbose=False,
m1_val_edataset=self.m1_dataset,
m2_val_edataset=self.m2_dataset,
m1_val_annotations_file='RGB_26May2021_14h19m_coco.json',
m2_val_annotations_file='Thermal_26May2021_14h19m_coco.json',
m1_val_images_folder='val/m1',
m2_val_images_folder='val/m2',
)

self.assertFalse(torch.equal(m, list(self.learner.model.parameters())[0]),
msg="Model parameters did not change after running fit.")
for backbone in self.model_backbones:
self.learners[backbone].model = None
self.learners[backbone].ort_session = None

self.learners[backbone].download(mode='pretrained_gem')

m = list(self.learners[backbone].model.parameters())[0].clone()

self.learners[backbone].fit(m1_train_edataset=self.m1_dataset,
m2_train_edataset=self.m2_dataset,
annotations_folder='annotations',
m1_train_annotations_file='RGB_26May2021_14h19m_coco.json',
m2_train_annotations_file='Thermal_26May2021_14h19m_coco.json',
m1_train_images_folder='train/m1',
m2_train_images_folder='train/m2',
out_dir=os.path.join(self.temp_dir, "outputs"),
trial_dir=os.path.join(self.temp_dir, "trial"),
logging_path='',
verbose=False,
m1_val_edataset=self.m1_dataset,
m2_val_edataset=self.m2_dataset,
m1_val_annotations_file='RGB_26May2021_14h19m_coco.json',
m2_val_annotations_file='Thermal_26May2021_14h19m_coco.json',
m1_val_images_folder='val/m1',
m2_val_images_folder='val/m2',
)

self.assertFalse(torch.equal(m, list(self.learners[backbone].model.parameters())[0]),
msg="Model parameters did not change after running fit.")

# Cleanup
warnings.simplefilter("default", ResourceWarning)
Expand All @@ -139,58 +139,60 @@ def test_eval(self):
# version)
warnings.simplefilter("ignore", ResourceWarning)
warnings.simplefilter("ignore", DeprecationWarning)
self.learner.model = None
self.learner.ort_session = None

self.learner.download(mode='pretrained_gem')

result = self.learner.eval(
m1_edataset=self.m1_dataset,
m2_edataset=self.m2_dataset,
m1_images_folder='val/m1',
m2_images_folder='val/m2',
annotations_folder='annotations',
m1_annotations_file='RGB_26May2021_14h19m_coco.json',
m2_annotations_file='Thermal_26May2021_14h19m_coco.json',
verbose=False,
)

self.assertGreater(len(result), 0)
for backbone in self.model_backbones:
self.learners[backbone].model = None
self.learners[backbone].ort_session = None

self.learners[backbone].download(mode='pretrained_gem')

result = self.learners[backbone].eval(
m1_edataset=self.m1_dataset,
m2_edataset=self.m2_dataset,
m1_images_folder='val/m1',
m2_images_folder='val/m2',
annotations_folder='annotations',
m1_annotations_file='RGB_26May2021_14h19m_coco.json',
m2_annotations_file='Thermal_26May2021_14h19m_coco.json',
verbose=False,
)

self.assertGreater(len(result), 0)

# Cleanup
warnings.simplefilter("default", ResourceWarning)
warnings.simplefilter("default", DeprecationWarning)

def test_infer(self):
self.learner.model = None
self.learner.ort_session = None

self.learner.download(mode='pretrained_gem')

m1_image = Image.open(os.path.join(self.temp_dir, "sample_images/rgb/2021_04_22_21_35_47_852516.jpg"))
m2_image = Image.open(os.path.join(self.temp_dir, 'sample_images/aligned_infra/2021_04_22_21_35_47_852516.jpg'))

result, _, _ = self.learner.infer(m1_image, m2_image)

self.assertGreater(len(result), 0)
for backbone in self.model_backbones:
self.learners[backbone].model = None
self.learners[backbone].ort_session = None
self.learners[backbone].download(mode='pretrained_gem')
result, _, _ = self.learners[backbone].infer(m1_image, m2_image)
self.assertGreater(len(result), 0)

def test_save(self):
self.learner.model = None
self.learner.ort_session = None
backbone = 'resnet50'
self.learners[backbone].model = None
self.learners[backbone].ort_session = None

model_dir = os.path.join(self.temp_dir, "test_model")

self.learner.download(mode='pretrained_detr')
self.learners[backbone].download(mode='pretrained_detr')

self.learner.save(model_dir)
self.learners[backbone].save(model_dir)

starting_param_1 = list(self.learner.model.parameters())[0].clone()
starting_param_1 = list(self.learners[backbone].model.parameters())[0].clone()

learner2 = GemLearner(
iters=1,
temp_path=self.temp_dir,
device=DEVICE,
num_classes=7,
backbone=backbone,
)
learner2.load(model_dir)

Expand Down

0 comments on commit fe91adf

Please sign in to comment.