Skip to content

Commit 30ea730

Browse files
committed
Modify tiliing unit test
1 parent d0640a8 commit 30ea730

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

tests/unit/algorithms/detection/tiling/test_tiling_detection.py

+1
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def test_load_tiling_parameters(self, tmp_dir_path):
225225
output_model = ModelEntity(self.otx_dataset, task_env.get_model_configuration())
226226
task = MMDetectionTask(task_env, output_path=str(tmp_dir_path))
227227
model_ckpt = os.path.join(tmp_dir_path, "maskrcnn.pth")
228+
task._init_task()
228229
torch.save(detector.state_dict(), model_ckpt)
229230
task._model_ckpt = model_ckpt
230231
task.save_model(output_model)

tests/unit/algorithms/detection/tiling/test_tiling_tile_classifier.py

+2
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def test_load_tile_classifier_parameters(self, tmp_dir_path):
129129
output_model = ModelEntity(self.dataset, task_env.get_model_configuration())
130130
task = MMDetectionTask(task_env, output_path=str(tmp_dir_path))
131131
task._model_ckpt = model_ckpt
132+
task._init_task()
132133
task.save_model(output_model)
133134
for filename, model_adapter in output_model.model_adapters.items():
134135
with open(os.path.join(tmp_dir_path, filename), "wb") as write_file:
@@ -160,6 +161,7 @@ def test_load_tile_classifier_parameters(self, tmp_dir_path):
160161
output_model = ModelEntity(self.dataset, task_env.get_model_configuration())
161162
task = MMDetectionTask(task_env, output_path=str(tmp_dir_path))
162163
task._model_ckpt = tile_classifier_ckpt
164+
task._init_task()
163165
task.save_model(output_model)
164166
for filename, model_adapter in output_model.model_adapters.items():
165167
with open(os.path.join(tmp_dir_path, filename), "wb") as write_file:

0 commit comments

Comments
 (0)