Skip to content

Commit 56e6624

Browse files
author
Songki Choi
authored
Merge pull request #1118 from openvinotoolkit/ashwin/fix_non_deterministic
[Anomaly Task] Fix non deterministic + sample.py
2 parents 2e18117 + 9ef77d0 commit 56e6624

File tree

5 files changed

+53
-10
lines changed

5 files changed

+53
-10
lines changed

external/anomaly/ote_anomalib/data/create_mvtec_ad_json_annotations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def create_task_annotations(task: str, data_path: str, annotation_path: str) ->
184184
Raises:
185185
ValueError: When task is not classification, detection or segmentation.
186186
"""
187-
annotation_path = os.path.join(data_path, task)
187+
annotation_path = os.path.join(annotation_path, task)
188188
os.makedirs(annotation_path, exist_ok=True)
189189

190190
for split in ["train", "val", "test"]:

external/anomaly/ote_anomalib/train_task.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# See the License for the specific language governing permissions
1515
# and limitations under the License.
1616

17+
from typing import Optional
18+
1719
from anomalib.utils.callbacks import MinMaxNormalizationCallback
1820
from ote_anomalib import AnomalyInferenceTask
1921
from ote_anomalib.callbacks import ProgressCallback
@@ -23,7 +25,7 @@
2325
from ote_sdk.entities.model import ModelEntity
2426
from ote_sdk.entities.train_parameters import TrainParameters
2527
from ote_sdk.usecases.tasks.interfaces.training_interface import ITrainingTask
26-
from pytorch_lightning import Trainer
28+
from pytorch_lightning import Trainer, seed_everything
2729

2830
logger = get_logger(__name__)
2931

@@ -36,17 +38,26 @@ def train(
3638
dataset: DatasetEntity,
3739
output_model: ModelEntity,
3840
train_parameters: TrainParameters,
41+
seed: Optional[int] = None,
3942
) -> None:
4043
"""Train the anomaly classification model.
4144
4245
Args:
4346
dataset (DatasetEntity): Input dataset.
4447
output_model (ModelEntity): Output model to save the model weights.
4548
train_parameters (TrainParameters): Training parameters
49+
seed: (Optional[int]): Setting seed to a value other than 0 also marks PytorchLightning trainer's
50+
deterministic flag to True.
4651
"""
4752
logger.info("Training the model.")
4853

4954
config = self.get_config()
55+
56+
if seed:
57+
logger.info(f"Setting seed to {seed}")
58+
seed_everything(seed, workers=True)
59+
config.trainer.deterministic = True
60+
5061
logger.info("Training Configs '%s'", config)
5162

5263
datamodule = OTEAnomalyDataModule(config=config, dataset=dataset, task_type=self.task_type)

external/anomaly/tests/test_ote_training.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,8 @@ def _run_ote_training(self, data_collector):
238238
self.copy_hyperparams = deepcopy(self.task.task_environment.get_hyper_parameters())
239239

240240
try:
241-
self.task.train(self.dataset, self.output_model, TrainParameters)
241+
# fix seed so that result is repeatable
242+
self.task.train(self.dataset, self.output_model, TrainParameters, seed=42)
242243
except Exception as ex:
243244
raise RuntimeError("Training failed") from ex
244245

external/anomaly/tools/README.md

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
OpenVINO Training Extension interacts with the anomaly detection library ([Anomalib](https://github.com/openvinotoolkit/anomalib)) by providing interfaces in the `external/anomaly` directory of this repository. The `sample.py` file contained in this folder serves as an end-to-end example of how these interfaces are used. To begin using this script, first ensure that the `ote_cli`, `ote_sdk` and `external/anomaly` dependencies are installed.
2+
3+
To get started, we provide a handy script in `ote_anomalib/data/create_mvtec_ad_json_annotations.py` to help generate annotation json files for the MVTec dataset. Assuming that you have placed the MVTec dataset in a directory in your home folder (`~/datasets/MVTec`), you can run the following command to generate the annotations.
4+
5+
```bash
6+
python create_mvtec_ad_json_annotations.py --data_path ~/datasets/MVTec --annotation_path ~/training_extensions/data/MVtec/
7+
```
8+
9+
This will generate three folders in `~/training_extensions/data/MVtec/`, one each for the classification, segmentation and detection tasks.
10+
11+
Then, to run sample.py you can use the following command.
12+
13+
```bash
14+
python tools/sample.py \
15+
--dataset_path ~/datasets/MVTec \
16+
--category bottle \
17+
--train-ann-files ../../data/MVtec/bottle/segmentation/train.json \
18+
--val-ann-files ../../data/MVtec/bottle/segmentation/val.json \
19+
--test-ann-files ../../data/MVtec/bottle/segmentation/test.json \
20+
--model_template_path ./configs/anomaly_segmentation/padim/template.yaml
21+
```
22+
23+
Optionally, you can also optimize the trained model with `nncf` or `pot` by using the `--optimization` flag.

external/anomaly/tools/sample.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import os
2323
import shutil
2424
from argparse import Namespace
25-
from typing import Any, Dict, Type, Union
25+
from typing import Any, Dict, Optional, Type, Union
2626

2727
from ote_anomalib import AnomalyNNCFTask, OpenVINOAnomalyTask
2828
from ote_anomalib.data.dataset import (
@@ -61,13 +61,18 @@ def __init__(
6161
val_subset: Dict[str, str],
6262
test_subset: Dict[str, str],
6363
model_template_path: str,
64+
seed: Optional[int] = None,
6465
) -> None:
6566
"""Initialize OteAnomalyTask.
6667
6768
Args:
6869
dataset_path (str): Path to the MVTec dataset.
69-
seed (int): Seed to split the dataset into train/val/test splits.
70+
train_subset (Dict[str, str]): Dictionary containing path to train annotation file and path to dataset.
71+
val_subset (Dict[str, str]): Dictionary containing path to validation annotation file and path to dataset.
72+
test_subset (Dict[str, str]): Dictionary containing path to test annotation file and path to dataset.
7073
model_template_path (str): Path to model template.
74+
seed (Optional[int]): Setting seed to a value other than 0 also marks PytorchLightning trainer's
75+
deterministic flag to True.
7176
7277
Example:
7378
>>> import os
@@ -78,9 +83,12 @@ def __init__(
7883
7984
>>> model_template_path = "./configs/anomaly_classification/padim/template.yaml"
8085
>>> dataset_path = "./datasets/MVTec"
81-
>>> seed = 0
8286
>>> task = OteAnomalyTask(
83-
... dataset_path=dataset_path, seed=seed, model_template_path=model_template_path
87+
... dataset_path=dataset_path,
88+
... train_subset={"ann_file": train.json, "data_root": dataset_path},
89+
... val_subset={"ann_file": val.json, "data_root": dataset_path},
90+
... test_subset={"ann_file": test.json, "data_root": dataset_path},
91+
... model_template_path=model_template_path
8492
... )
8593
8694
>>> task.train()
@@ -110,6 +118,7 @@ def __init__(
110118
self.openvino_task: OpenVINOAnomalyTask
111119
self.nncf_task: AnomalyNNCFTask
112120
self.results = {"category": dataset_path}
121+
self.seed = seed
113122

114123
def get_dataclass(
115124
self,
@@ -176,9 +185,7 @@ def train(self) -> ModelEntity:
176185
configuration=self.task_environment.get_model_configuration(),
177186
)
178187
self.torch_task.train(
179-
dataset=self.dataset,
180-
output_model=output_model,
181-
train_parameters=TrainParameters(),
188+
dataset=self.dataset, output_model=output_model, train_parameters=TrainParameters(), seed=self.seed
182189
)
183190

184191
logger.info("Inferring the base torch model on the validation set.")
@@ -364,6 +371,7 @@ def main() -> None:
364371
val_subset=val_subset,
365372
test_subset=test_subset,
366373
model_template_path=args.model_template_path,
374+
seed=args.seed,
367375
)
368376

369377
task.train()

0 commit comments

Comments
 (0)