Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Anomaly Task] 🐞 Fix inference when model backbone changes #1242

Merged
merged 5 commits into the base branch on
Sep 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions external/anomaly/configs/base/padim/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@ class LearningParameters(BaseAnomalyConfig.LearningParameters):
header = string_attribute("Learning Parameters")
description = header

# Editable is set to false because WideResNet50 exceeds
# ONNX's 2 GB protobuf limit, which crashes the export.
backbone = selectable(
default_value=ModelBackbone.RESNET18,
header="Model Backbone",
description="Pre-trained backbone used for feature extraction",
editable=False,
visible_in_ui=False,
)

learning_parameters = add_parameter_group(LearningParameters)
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ learning_parameters:
auto_hpo_value: null
default_value: resnet18
description: Pre-trained backbone used for feature extraction
editable: true
editable: false
enum_name: ModelBackbone
header: Model Backbone
options:
Expand All @@ -48,7 +48,7 @@ learning_parameters:
rules: []
type: UI_RULES
value: resnet18
visible_in_ui: true
visible_in_ui: false
warning: null
description: Learning Parameters
header: Learning Parameters
Expand Down
4 changes: 2 additions & 2 deletions external/anomaly/configs/detection/padim/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ learning_parameters:
auto_hpo_value: null
default_value: resnet18
description: Pre-trained backbone used for feature extraction
editable: true
editable: false
enum_name: ModelBackbone
header: Model Backbone
options:
Expand All @@ -48,7 +48,7 @@ learning_parameters:
rules: []
type: UI_RULES
value: resnet18
visible_in_ui: true
visible_in_ui: false
warning: null
description: Learning Parameters
header: Learning Parameters
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ learning_parameters:
auto_hpo_value: null
default_value: resnet18
description: Pre-trained backbone used for feature extraction
editable: true
editable: false
enum_name: ModelBackbone
header: Model Backbone
options:
Expand All @@ -48,7 +48,7 @@ learning_parameters:
rules: []
type: UI_RULES
value: resnet18
visible_in_ui: true
visible_in_ui: false
warning: null
description: Learning Parameters
header: Learning Parameters
Expand Down
10 changes: 8 additions & 2 deletions external/anomaly/tasks/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ def load_model(self, ote_model: Optional[ModelEntity]) -> AnomalyModule:
AnomalyModule: Anomalib
classification or segmentation model with/without weights.
"""
model = get_model(config=self.config)
if ote_model is None:
model = get_model(config=self.config)
logger.info(
"No trained model in project yet. Created new model with '%s'",
self.model_name,
Expand All @@ -130,10 +130,16 @@ def load_model(self, ote_model: Optional[ModelEntity]) -> AnomalyModule:
buffer = io.BytesIO(ote_model.get_data("weights.pth"))
model_data = torch.load(buffer, map_location=torch.device("cpu"))

if model_data["config"]["model"]["backbone"] != self.config["model"]["backbone"]:
logger.warning(
"Backbone of the model in the Task Environment is different from the one in the template. "
f"creating model with backbone={model_data['config']['model']['backbone']}"
)
self.config["model"]["backbone"] = model_data["config"]["model"]["backbone"]
try:
model = get_model(config=self.config)
model.load_state_dict(model_data["model"])
logger.info("Loaded model weights from Task Environment")

except BaseException as exception:
raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception

Expand Down
42 changes: 42 additions & 0 deletions external/anomaly/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
# See the License for the specific language governing permissions
# and limitations under the License.

import io
from typing import Optional

import torch
from adapters.anomalib.callbacks import ProgressCallback
from adapters.anomalib.data import OTEAnomalyDataModule
from adapters.anomalib.logger import get_logger
from anomalib.models import AnomalyModule, get_model
from anomalib.utils.callbacks import (
MetricsConfigurationCallback,
MinMaxNormalizationCallback,
Expand Down Expand Up @@ -83,3 +86,42 @@ def train(
self.save_model(output_model)

logger.info("Training completed.")

def load_model(self, ote_model: Optional[ModelEntity]) -> AnomalyModule:
"""Create and Load Anomalib Module from OTE Model.

This method checks if the task environment has a saved OTE Model,
and creates one. If the OTE model already exists, it returns the
the model with the saved weights.

Args:
ote_model (Optional[ModelEntity]): OTE Model from the
task environment.

Returns:
AnomalyModule: Anomalib
classification or segmentation model with/without weights.
"""
model = get_model(config=self.config)
if ote_model is None:
logger.info(
"No trained model in project yet. Created new model with '%s'",
self.model_name,
)
else:
buffer = io.BytesIO(ote_model.get_data("weights.pth"))
model_data = torch.load(buffer, map_location=torch.device("cpu"))

try:
if model_data["config"]["model"]["backbone"] == self.config["model"]["backbone"]:
model.load_state_dict(model_data["model"])
logger.info("Loaded model weights from Task Environment")
else:
logger.info(
"Model backbone does not match. Created new model with '%s'",
self.model_name,
)
except BaseException as exception:
raise ValueError("Could not load the saved model. The model file structure is invalid.") from exception

return model