|
17 | 17 | # and limitations under the License.
|
18 | 18 |
|
19 | 19 | from attr import attrs
|
20 |
| -from ote_anomalib.configs.configuration import BaseAnomalyConfig |
21 |
| -from ote_anomalib.configs.configuration_enums import EarlyStoppingMetrics, ModelBackbone |
22 |
| -from ote_sdk.configuration.elements import ( |
23 |
| - ParameterGroup, |
24 |
| - add_parameter_group, |
25 |
| - configurable_float, |
26 |
| - configurable_integer, |
27 |
| - selectable, |
28 |
| - string_attribute, |
29 |
| -) |
30 |
| -from ote_sdk.configuration.model_lifecycle import ModelLifecycle |
| 20 | +from ote_anomalib.configs.stfpm import STFPMAnomalyBaseConfig |
31 | 21 |
|
32 | 22 |
|
33 | 23 | @attrs
|
34 |
| -class STFPMAnomalyDetectionConfig(BaseAnomalyConfig): |
| 24 | +class STFPMAnomalyDetectionConfig(STFPMAnomalyBaseConfig): |
35 | 25 | """
|
36 |
| - Configurable parameters for STFPM anomaly classification task. |
| 26 | + Configurable parameters for STFPM anomaly detection task. |
37 | 27 | """
|
38 |
| - |
39 |
| - header = string_attribute("Configuration for STFPM") |
40 |
| - description = header |
41 |
| - |
42 |
| - @attrs |
43 |
| - class ModelParameters(ParameterGroup): |
44 |
| - """ |
45 |
| - Parameter Group for training model |
46 |
| - """ |
47 |
| - |
48 |
| - header = string_attribute("Model Parameters") |
49 |
| - description = header |
50 |
| - |
51 |
| - backbone = selectable( |
52 |
| - default_value=ModelBackbone.RESNET18, |
53 |
| - header="Model Backbone", |
54 |
| - description="Pre-trained backbone used for teacher and student network", |
55 |
| - ) |
56 |
| - |
57 |
| - lr = configurable_float( |
58 |
| - default_value=0.4, |
59 |
| - header="Learning Rate", |
60 |
| - min_value=1e-3, |
61 |
| - max_value=1, |
62 |
| - description="Learning rate used for optimizing the Student network.", |
63 |
| - ) |
64 |
| - |
65 |
| - momentum = configurable_float( |
66 |
| - default_value=0.9, |
67 |
| - header="Momentum", |
68 |
| - min_value=0.1, |
69 |
| - max_value=1.0, |
70 |
| - description="Momentum used for SGD optimizer", |
71 |
| - ) |
72 |
| - |
73 |
| - weight_decay = configurable_float( |
74 |
| - default_value=0.0001, |
75 |
| - header="Weight Decay", |
76 |
| - min_value=1e-5, |
77 |
| - max_value=1, |
78 |
| - description="Decay for SGD optimizer", |
79 |
| - ) |
80 |
| - |
81 |
| - @attrs |
82 |
| - class EarlyStoppingParameters(ParameterGroup): |
83 |
| - """ |
84 |
| - Early stopping parameters |
85 |
| - """ |
86 |
| - |
87 |
| - header = string_attribute("Early Stopping Parameters") |
88 |
| - description = header |
89 |
| - |
90 |
| - metric = selectable( |
91 |
| - default_value=EarlyStoppingMetrics.IMAGE_F1, |
92 |
| - header="Early Stopping Metric", |
93 |
| - description="The metric used to determine if the model should stop training", |
94 |
| - ) |
95 |
| - |
96 |
| - patience = configurable_integer( |
97 |
| - default_value=10, |
98 |
| - min_value=1, |
99 |
| - max_value=100, |
100 |
| - header="Early Stopping Patience", |
101 |
| - description="Number of epochs to wait for an improvement in the monitored metric. If the metric has " |
102 |
| - "not improved for this many epochs, the training will stop and the best model will be " |
103 |
| - "returned.", |
104 |
| - warning="Setting this value too low might lead to underfitting. Setting the value too high will " |
105 |
| - "increase the training time and might lead to overfitting.", |
106 |
| - affects_outcome_of=ModelLifecycle.TRAINING, |
107 |
| - ) |
108 |
| - |
109 |
| - early_stopping = add_parameter_group(EarlyStoppingParameters) |
110 |
| - |
111 |
| - model = add_parameter_group(ModelParameters) |
0 commit comments