Skip to content

Commit d9ad2d7

Browse files
[Anomaly Task] 🐞 Fix inference when model backbone changes (#1242)
* Add check for change in model backbone * Add load model to train * Limit padim to only resnet18 * Fix comment
1 parent 5f622b4 commit d9ad2d7

File tree

6 files changed

+60
-8
lines changed

6 files changed

+60
-8
lines changed

external/anomaly/configs/base/padim/configuration.py

+4
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,14 @@ class LearningParameters(BaseAnomalyConfig.LearningParameters):
3838
header = string_attribute("Learning Parameters")
3939
description = header
4040

41+
# Editable is set to false because WideResNet50 is too large for
42+
# ONNX's 2 GB protobuf limit, which ends up crashing the export.
4143
backbone = selectable(
4244
default_value=ModelBackbone.RESNET18,
4345
header="Model Backbone",
4446
description="Pre-trained backbone used for feature extraction",
47+
editable=False,
48+
visible_in_ui=False,
4549
)
4650

4751
learning_parameters = add_parameter_group(LearningParameters)

external/anomaly/configs/classification/padim/configuration.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ learning_parameters:
3535
auto_hpo_value: null
3636
default_value: resnet18
3737
description: Pre-trained backbone used for feature extraction
38-
editable: true
38+
editable: false
3939
enum_name: ModelBackbone
4040
header: Model Backbone
4141
options:
@@ -48,7 +48,7 @@ learning_parameters:
4848
rules: []
4949
type: UI_RULES
5050
value: resnet18
51-
visible_in_ui: true
51+
visible_in_ui: false
5252
warning: null
5353
description: Learning Parameters
5454
header: Learning Parameters

external/anomaly/configs/detection/padim/configuration.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ learning_parameters:
3535
auto_hpo_value: null
3636
default_value: resnet18
3737
description: Pre-trained backbone used for feature extraction
38-
editable: true
38+
editable: false
3939
enum_name: ModelBackbone
4040
header: Model Backbone
4141
options:
@@ -48,7 +48,7 @@ learning_parameters:
4848
rules: []
4949
type: UI_RULES
5050
value: resnet18
51-
visible_in_ui: true
51+
visible_in_ui: false
5252
warning: null
5353
description: Learning Parameters
5454
header: Learning Parameters

external/anomaly/configs/segmentation/padim/configuration.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ learning_parameters:
3535
auto_hpo_value: null
3636
default_value: resnet18
3737
description: Pre-trained backbone used for feature extraction
38-
editable: true
38+
editable: false
3939
enum_name: ModelBackbone
4040
header: Model Backbone
4141
options:
@@ -48,7 +48,7 @@ learning_parameters:
4848
rules: []
4949
type: UI_RULES
5050
value: resnet18
51-
visible_in_ui: true
51+
visible_in_ui: false
5252
warning: null
5353
description: Learning Parameters
5454
header: Learning Parameters

external/anomaly/tasks/inference.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ def load_model(self, ote_model: Optional[ModelEntity]) -> AnomalyModule:
120120
AnomalyModule: Anomalib
121121
classification or segmentation model with/without weights.
122122
"""
123-
model = get_model(config=self.config)
124123
if ote_model is None:
124+
model = get_model(config=self.config)
125125
logger.info(
126126
"No trained model in project yet. Created new model with '%s'",
127127
self.model_name,
@@ -130,10 +130,16 @@ def load_model(self, ote_model: Optional[ModelEntity]) -> AnomalyModule:
130130
buffer = io.BytesIO(ote_model.get_data("weights.pth"))
131131
model_data = torch.load(buffer, map_location=torch.device("cpu"))
132132

133+
if model_data["config"]["model"]["backbone"] != self.config["model"]["backbone"]:
134+
logger.warning(
135+
"Backbone of the model in the Task Environment is different from the one in the template. "
136+
f"creating model with backbone={model_data['config']['model']['backbone']}"
137+
)
138+
self.config["model"]["backbone"] = model_data["config"]["model"]["backbone"]
133139
try:
140+
model = get_model(config=self.config)
134141
model.load_state_dict(model_data["model"])
135142
logger.info("Loaded model weights from Task Environment")
136-
137143
except BaseException as exception:
138144
raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception
139145

external/anomaly/tasks/train.py

+42
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
# See the License for the specific language governing permissions
1515
# and limitations under the License.
1616

17+
import io
1718
from typing import Optional
1819

20+
import torch
1921
from adapters.anomalib.callbacks import ProgressCallback
2022
from adapters.anomalib.data import OTEAnomalyDataModule
2123
from adapters.anomalib.logger import get_logger
24+
from anomalib.models import AnomalyModule, get_model
2225
from anomalib.utils.callbacks import (
2326
MetricsConfigurationCallback,
2427
MinMaxNormalizationCallback,
@@ -83,3 +86,42 @@ def train(
8386
self.save_model(output_model)
8487

8588
logger.info("Training completed.")
89+
90+
def load_model(self, ote_model: Optional[ModelEntity]) -> AnomalyModule:
    """Create an Anomalib module and, when possible, load saved OTE weights.

    A fresh model is always built from ``self.config``. If the task
    environment holds a saved OTE model whose stored backbone matches the
    configured backbone, the saved weights are loaded into that model.
    If the backbones differ, the saved weights are ignored and the freshly
    created model is returned as-is.

    Args:
        ote_model (Optional[ModelEntity]): OTE Model from the
            task environment, or ``None`` when no model has been saved yet.

    Returns:
        AnomalyModule: Anomalib
            classification or segmentation model with/without weights.

    Raises:
        ValueError: If the saved checkpoint lacks the expected
            ``config``/``model`` entries or its weights cannot be loaded.
    """
    model = get_model(config=self.config)

    if ote_model is None:
        logger.info(
            "No trained model in project yet. Created new model with '%s'",
            self.model_name,
        )
        return model

    # NOTE(review): torch.load unpickles the payload; this assumes the weights
    # stored in the task environment come from a trusted source - confirm.
    buffer = io.BytesIO(ote_model.get_data("weights.pth"))
    model_data = torch.load(buffer, map_location=torch.device("cpu"))

    try:
        saved_backbone = model_data["config"]["model"]["backbone"]
        if saved_backbone == self.config["model"]["backbone"]:
            model.load_state_dict(model_data["model"])
            logger.info("Loaded model weights from Task Environment")
        else:
            # Weights trained for a different backbone cannot be loaded into
            # this architecture, so training starts from the fresh model.
            logger.info(
                "Model backbone does not match. Created new model with '%s'",
                self.model_name,
            )
    except Exception as exception:
        # Catch Exception rather than BaseException so that KeyboardInterrupt
        # and SystemExit propagate instead of being reported as a bad file.
        raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception

    return model

0 commit comments

Comments
 (0)