Skip to content

Commit fb175fe

Browse files
author
Ashwin Vaidya
committed
Refactor configurable parameters
1 parent 14e35c0 commit fb175fe

File tree

15 files changed

+238
-356
lines changed

15 files changed

+238
-356
lines changed

external/anomaly/anomaly_classification/configs/padim/configuration.py

+2-29
Original file line numberDiff line numberDiff line change
@@ -17,38 +17,11 @@
1717
# and limitations under the License.
1818

1919
from attr import attrs
20-
from ote_anomalib.configs.configuration import BaseAnomalyConfig
21-
from ote_anomalib.configs.configuration_enums import ModelBackbone
22-
from ote_sdk.configuration.elements import (
23-
ParameterGroup,
24-
add_parameter_group,
25-
selectable,
26-
string_attribute,
27-
)
20+
from ote_anomalib.configs.padim import PadimAnomalyBaseConfig
2821

2922

3023
@attrs
31-
class PadimAnomalyClassificationConfig(BaseAnomalyConfig):
24+
class PadimAnomalyClassificationConfig(PadimAnomalyBaseConfig):
3225
"""
3326
Configurable parameters for PADIM anomaly classification task.
3427
"""
35-
36-
header = string_attribute("Configuration for Padim")
37-
description = header
38-
39-
@attrs
40-
class ModelParameters(ParameterGroup):
41-
"""
42-
Parameter Group for tuning the model
43-
"""
44-
45-
header = string_attribute("Model Parameters")
46-
description = header
47-
48-
backbone = selectable(
49-
default_value=ModelBackbone.RESNET18,
50-
header="Model Backbone",
51-
description="Pre-trained backbone used for feature extraction",
52-
)
53-
54-
model = add_parameter_group(ModelParameters)

external/anomaly/anomaly_classification/configs/stfpm/configuration.py

+2-86
Original file line numberDiff line numberDiff line change
@@ -17,95 +17,11 @@
1717
# and limitations under the License.
1818

1919
from attr import attrs
20-
from ote_anomalib.configs.configuration import BaseAnomalyConfig
21-
from ote_anomalib.configs.configuration_enums import EarlyStoppingMetrics, ModelBackbone
22-
from ote_sdk.configuration.elements import (
23-
ParameterGroup,
24-
add_parameter_group,
25-
configurable_float,
26-
configurable_integer,
27-
selectable,
28-
string_attribute,
29-
)
30-
from ote_sdk.configuration.model_lifecycle import ModelLifecycle
20+
from ote_anomalib.configs.stfpm import STFPMAnomalyBaseConfig
3121

3222

3323
@attrs
34-
class STFPMAnomalyClassificationConfig(BaseAnomalyConfig):
24+
class STFPMAnomalyClassificationConfig(STFPMAnomalyBaseConfig):
3525
"""
3626
Configurable parameters for STFPM anomaly classification task.
3727
"""
38-
39-
header = string_attribute("Configuration for STFPM")
40-
description = header
41-
42-
@attrs
43-
class ModelParameters(ParameterGroup):
44-
"""
45-
Parameter Group for training model
46-
"""
47-
48-
header = string_attribute("Model Parameters")
49-
description = header
50-
51-
backbone = selectable(
52-
default_value=ModelBackbone.RESNET18,
53-
header="Model Backbone",
54-
description="Pre-trained backbone used for teacher and student network",
55-
)
56-
57-
lr = configurable_float(
58-
default_value=0.4,
59-
header="Learning Rate",
60-
min_value=1e-3,
61-
max_value=1,
62-
description="Learning rate used for optimizing the Student network.",
63-
)
64-
65-
momentum = configurable_float(
66-
default_value=0.9,
67-
header="Momentum",
68-
min_value=0.1,
69-
max_value=1.0,
70-
description="Momentum used for SGD optimizer",
71-
)
72-
73-
weight_decay = configurable_float(
74-
default_value=0.0001,
75-
header="Weight Decay",
76-
min_value=1e-5,
77-
max_value=1,
78-
description="Decay for SGD optimizer",
79-
)
80-
81-
@attrs
82-
class EarlyStoppingParameters(ParameterGroup):
83-
"""
84-
Early stopping parameters
85-
"""
86-
87-
header = string_attribute("Early Stopping Parameters")
88-
description = header
89-
90-
metric = selectable(
91-
default_value=EarlyStoppingMetrics.IMAGE_F1,
92-
header="Early Stopping Metric",
93-
description="The metric used to determine if the model should stop training",
94-
)
95-
96-
patience = configurable_integer(
97-
default_value=10,
98-
min_value=1,
99-
max_value=100,
100-
header="Early Stopping Patience",
101-
description="Number of epochs to wait for an improvement in the monitored metric. If the metric has "
102-
"not improved for this many epochs, the training will stop and the best model will be "
103-
"returned.",
104-
warning="Setting this value too low might lead to underfitting. Setting the value too high will "
105-
"increase the training time and might lead to overfitting.",
106-
affects_outcome_of=ModelLifecycle.TRAINING,
107-
)
108-
109-
early_stopping = add_parameter_group(EarlyStoppingParameters)
110-
111-
model = add_parameter_group(ModelParameters)

external/anomaly/anomaly_classification/configs/stfpm/configuration.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ model:
5858
auto_hpo_state: not_possible
5959
auto_hpo_value: null
6060
default_value: resnet18
61-
description: Pre-trained backbone used for teacher and student network
61+
description: Pre-trained backbone used for feature extraction
6262
editable: true
6363
enum_name: ModelBackbone
6464
header: Model Backbone
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Configurable parameters for Padim anomaly detection task
2+
Configurable parameters for Padim anomaly segmentation task
33
"""
44

55
# Copyright (C) 2021 Intel Corporation
@@ -17,38 +17,11 @@
1717
# and limitations under the License.
1818

1919
from attr import attrs
20-
from ote_anomalib.configs.configuration import BaseAnomalyConfig
21-
from ote_anomalib.configs.configuration_enums import ModelBackbone
22-
from ote_sdk.configuration.elements import (
23-
ParameterGroup,
24-
add_parameter_group,
25-
selectable,
26-
string_attribute,
27-
)
20+
from ote_anomalib.configs.padim import PadimAnomalyBaseConfig
2821

2922

3023
@attrs
31-
class PadimAnomalyDetectionConfig(BaseAnomalyConfig):
24+
class PadimAnomalyDetectionConfig(PadimAnomalyBaseConfig):
3225
"""
33-
Configurable parameters for PADIM anomaly classification task.
26+
Configurable parameters for PADIM anomaly segmentation task.
3427
"""
35-
36-
header = string_attribute("Configuration for Padim")
37-
description = header
38-
39-
@attrs
40-
class ModelParameters(ParameterGroup):
41-
"""
42-
Parameter Group for tuning the model
43-
"""
44-
45-
header = string_attribute("Model Parameters")
46-
description = header
47-
48-
backbone = selectable(
49-
default_value=ModelBackbone.RESNET18,
50-
header="Model Backbone",
51-
description="Pre-trained backbone used for feature extraction",
52-
)
53-
54-
model = add_parameter_group(ModelParameters)

external/anomaly/anomaly_detection/configs/stfpm/configuration.py

+3-87
Original file line numberDiff line numberDiff line change
@@ -17,95 +17,11 @@
1717
# and limitations under the License.
1818

1919
from attr import attrs
20-
from ote_anomalib.configs.configuration import BaseAnomalyConfig
21-
from ote_anomalib.configs.configuration_enums import EarlyStoppingMetrics, ModelBackbone
22-
from ote_sdk.configuration.elements import (
23-
ParameterGroup,
24-
add_parameter_group,
25-
configurable_float,
26-
configurable_integer,
27-
selectable,
28-
string_attribute,
29-
)
30-
from ote_sdk.configuration.model_lifecycle import ModelLifecycle
20+
from ote_anomalib.configs.stfpm import STFPMAnomalyBaseConfig
3121

3222

3323
@attrs
34-
class STFPMAnomalyDetectionConfig(BaseAnomalyConfig):
24+
class STFPMAnomalyDetectionConfig(STFPMAnomalyBaseConfig):
3525
"""
36-
Configurable parameters for STFPM anomaly classification task.
26+
Configurable parameters for STFPM anomaly detection task.
3727
"""
38-
39-
header = string_attribute("Configuration for STFPM")
40-
description = header
41-
42-
@attrs
43-
class ModelParameters(ParameterGroup):
44-
"""
45-
Parameter Group for training model
46-
"""
47-
48-
header = string_attribute("Model Parameters")
49-
description = header
50-
51-
backbone = selectable(
52-
default_value=ModelBackbone.RESNET18,
53-
header="Model Backbone",
54-
description="Pre-trained backbone used for teacher and student network",
55-
)
56-
57-
lr = configurable_float(
58-
default_value=0.4,
59-
header="Learning Rate",
60-
min_value=1e-3,
61-
max_value=1,
62-
description="Learning rate used for optimizing the Student network.",
63-
)
64-
65-
momentum = configurable_float(
66-
default_value=0.9,
67-
header="Momentum",
68-
min_value=0.1,
69-
max_value=1.0,
70-
description="Momentum used for SGD optimizer",
71-
)
72-
73-
weight_decay = configurable_float(
74-
default_value=0.0001,
75-
header="Weight Decay",
76-
min_value=1e-5,
77-
max_value=1,
78-
description="Decay for SGD optimizer",
79-
)
80-
81-
@attrs
82-
class EarlyStoppingParameters(ParameterGroup):
83-
"""
84-
Early stopping parameters
85-
"""
86-
87-
header = string_attribute("Early Stopping Parameters")
88-
description = header
89-
90-
metric = selectable(
91-
default_value=EarlyStoppingMetrics.IMAGE_F1,
92-
header="Early Stopping Metric",
93-
description="The metric used to determine if the model should stop training",
94-
)
95-
96-
patience = configurable_integer(
97-
default_value=10,
98-
min_value=1,
99-
max_value=100,
100-
header="Early Stopping Patience",
101-
description="Number of epochs to wait for an improvement in the monitored metric. If the metric has "
102-
"not improved for this many epochs, the training will stop and the best model will be "
103-
"returned.",
104-
warning="Setting this value too low might lead to underfitting. Setting the value too high will "
105-
"increase the training time and might lead to overfitting.",
106-
affects_outcome_of=ModelLifecycle.TRAINING,
107-
)
108-
109-
early_stopping = add_parameter_group(EarlyStoppingParameters)
110-
111-
model = add_parameter_group(ModelParameters)

external/anomaly/anomaly_detection/configs/stfpm/configuration.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ model:
5858
auto_hpo_state: not_possible
5959
auto_hpo_value: null
6060
default_value: resnet18
61-
description: Pre-trained backbone used for teacher and student network
61+
description: Pre-trained backbone used for feature extraction
6262
editable: true
6363
enum_name: ModelBackbone
6464
header: Model Backbone
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Configurable parameters for Padim anomaly classification task
2+
Configurable parameters for Padim anomaly segmentation task
33
"""
44

55
# Copyright (C) 2021 Intel Corporation
@@ -17,38 +17,11 @@
1717
# and limitations under the License.
1818

1919
from attr import attrs
20-
from ote_anomalib.configs.configuration import BaseAnomalyConfig
21-
from ote_anomalib.configs.configuration_enums import ModelBackbone
22-
from ote_sdk.configuration.elements import (
23-
ParameterGroup,
24-
add_parameter_group,
25-
selectable,
26-
string_attribute,
27-
)
20+
from ote_anomalib.configs.padim import PadimAnomalyBaseConfig
2821

2922

3023
@attrs
31-
class PadimAnomalySegmentationConfig(BaseAnomalyConfig):
24+
class PadimAnomalySegmentationConfig(PadimAnomalyBaseConfig):
3225
"""
33-
Configurable parameters for PADIM anomaly classification task.
26+
Configurable parameters for PADIM anomaly segmentation task.
3427
"""
35-
36-
header = string_attribute("Configuration for Padim")
37-
description = header
38-
39-
@attrs
40-
class ModelParameters(ParameterGroup):
41-
"""
42-
Parameter Group for tuning the model
43-
"""
44-
45-
header = string_attribute("Model Parameters")
46-
description = header
47-
48-
backbone = selectable(
49-
default_value=ModelBackbone.RESNET18,
50-
header="Model Backbone",
51-
description="Pre-trained backbone used for feature extraction",
52-
)
53-
54-
model = add_parameter_group(ModelParameters)

0 commit comments

Comments
 (0)