Commit: Merge branch 'master' into feature/SG-1189-update-notebooks

# Conflicts:
#   tests/deci_core_unit_test_suite_runner.py
Showing 4 changed files with 342 additions and 0 deletions.
super_gradients/convert_recipe_to_code.py
@@ -0,0 +1,239 @@
""" | ||
Entry point for converting recipe file to self-contained train.py file. | ||
Convert a recipe YAML file to a self-contained <train.py> file that can be run with python <train.py>. | ||
Generated file will contain all training hyperparameters from input recipe file but will be self-contained (no dependencies on original recipe). | ||
Limitations: Converting a recipe with command-line overrides of some parameters in this recipe is not supported. | ||
General use: python -m super_gradients.convert_recipe_to_code DESIRED_RECIPE OUTPUT_SCRIPT | ||
Example: python -m super_gradients.convert_recipe_to_code coco2017_yolo_nas_s train_coco2017_yolo_nas_s.py | ||
For recipe's specific instructions and details refer to the recipe's configuration file in the recipes' directory. | ||
""" | ||
import argparse
import collections
import os.path
import pathlib
from typing import Tuple, Mapping, Dict, Union, Optional

import hydra
import pkg_resources
from hydra.core.global_hydra import GlobalHydra
from omegaconf import DictConfig, OmegaConf, ListConfig

from super_gradients import Trainer
from super_gradients.common import MultiGPUMode
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.common.environment.omegaconf_utils import register_hydra_resolvers
from super_gradients.common.environment.path_utils import normalize_path
from super_gradients.training.utils import get_param

logger = get_logger(__name__)

def try_import_black():
    """
    Attempts to import the black code formatter.
    If black is not installed, it will attempt to install it with pip.
    If installation fails, it will return None.
    """
    try:
        import black

        return black
    except ImportError:
        logger.info("Trying to install black using pip to enable formatting of the generated script.")
        try:
            import pip

            pip.main(["install", "black==22.10.0"])
            import black

            logger.info("Black installed via pip.")
            return black
        except Exception:
            logger.info("Black installation failed. Formatting of the generated script will be disabled.")
            return None

def recursively_walk_and_extract_hydra_targets(
    cfg: DictConfig, objects: Optional[Mapping] = None, prefix: Optional[str] = None
) -> Tuple[DictConfig, Dict[str, Mapping]]:
    """
    Iterates over the input config, extracts all hydra targets present in it and replaces them with variable references.
    Extracted hydra targets are stored in the objects dictionary (used to generate instantiations of the objects in the generated script).

    :param cfg:     Input config
    :param objects: Dictionary of extracted hydra targets
    :param prefix:  A prefix used to track the path to the current config (gives variables meaningful names)
    :return:        A new config and the dictionary of objects that must be created in the generated script
    """
    if objects is None:
        objects = collections.OrderedDict()
    if prefix is None:
        prefix = ""

    if isinstance(cfg, DictConfig):
        for key, value in cfg.items():
            value, objects = recursively_walk_and_extract_hydra_targets(value, objects, prefix=f"{prefix}_{key}")
            cfg[key] = value

        if "_target_" in cfg:
            target_class = cfg["_target_"]
            target_params = dict([(k, v) for k, v in cfg.items() if k != "_target_"])
            object_name = f"{prefix}".replace(".", "_").lower()
            objects[object_name] = (target_class, target_params)
            cfg = object_name

    elif isinstance(cfg, ListConfig):
        for index, item in enumerate(cfg):
            item, objects = recursively_walk_and_extract_hydra_targets(item, objects, prefix=f"{prefix}_{index}")
            cfg[index] = item
    else:
        pass
    return cfg, objects

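To make the recursion concrete, here is what the helper produces for a toy config (an illustrative sketch, not part of the committed file; the _target_ path is only an example string):

    from omegaconf import OmegaConf

    toy_cfg = OmegaConf.create({"ema": {"_target_": "super_gradients.training.utils.ema.ModelEMA", "decay": 0.9999}})
    toy_cfg, extracted = recursively_walk_and_extract_hydra_targets(toy_cfg)
    # toy_cfg.ema is now the variable reference string "_ema", and:
    # extracted == OrderedDict([("_ema", ("super_gradients.training.utils.ema.ModelEMA", {"decay": 0.9999}))])

Each extracted (class, params) pair later becomes an assignment in the generated script, and the quote-stripping pass in convert_recipe_to_code turns the quoted '_ema' left inside the hyperparameters dict into a bare variable reference.
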
def convert_recipe_to_code(config_name: Union[str, pathlib.Path], config_dir: Union[str, pathlib.Path], output_script_path: Union[str, pathlib.Path]) -> None:
    """
    Convert a recipe YAML file to a self-contained <train.py> file that can be run with `python <train.py>`.
    The generated file will contain all training hyperparameters from the input recipe, but will be self-contained (no dependency on the original recipe).
    Limitations: converting a recipe together with command-line overrides of some of its parameters is not supported.

    :param config_name:        Name of the recipe file (can be with or without the .yaml extension)
    :param config_dir:         Directory where the recipe file is located
    :param output_script_path: Path to the output .py file
    :return:                   None
    """
    config_name = str(config_name)
    config_dir = str(config_dir)
    output_script_path = str(output_script_path)

    register_hydra_resolvers()
    GlobalHydra.instance().clear()
    with hydra.initialize_config_dir(config_dir=normalize_path(config_dir), version_base="1.2"):
        cfg = hydra.compose(config_name=config_name)

    cfg = Trainer._trigger_cfg_modifying_callbacks(cfg)
    OmegaConf.resolve(cfg)

    device = get_param(cfg, "device")
    multi_gpu = get_param(cfg, "multi_gpu")
    if multi_gpu is False:
        multi_gpu = MultiGPUMode.OFF
    num_gpus = get_param(cfg, "num_gpus")

    train_dataloader = get_param(cfg, "train_dataloader")
    train_dataset_params = OmegaConf.to_container(cfg.dataset_params.train_dataset_params, resolve=True)
    train_dataloader_params = OmegaConf.to_container(cfg.dataset_params.train_dataloader_params, resolve=True)

    val_dataloader = get_param(cfg, "val_dataloader")
    val_dataset_params = OmegaConf.to_container(cfg.dataset_params.val_dataset_params, resolve=True)
    val_dataloader_params = OmegaConf.to_container(cfg.dataset_params.val_dataloader_params, resolve=True)

    num_classes = cfg.arch_params.num_classes
    arch_params = OmegaConf.to_container(cfg.arch_params, resolve=True)

    strict_load = cfg.checkpoint_params.strict_load
    if isinstance(strict_load, Mapping) and "_target_" in strict_load:
        strict_load = hydra.utils.instantiate(strict_load)

    training_hyperparams, hydra_instantiated_objects = recursively_walk_and_extract_hydra_targets(cfg.training_hyperparams)

    checkpoint_num_classes = get_param(cfg.checkpoint_params, "checkpoint_num_classes")
    content = f"""
import super_gradients
from super_gradients import init_trainer, Trainer
from super_gradients.training.utils.distributed_training_utils import setup_device
from super_gradients.training import models, dataloaders
from super_gradients.common.data_types.enum import MultiGPUMode, StrictLoad
import numpy as np


def main():
    init_trainer()
    setup_device(device={device}, multi_gpu="{multi_gpu}", num_gpus={num_gpus})

    trainer = Trainer(experiment_name="{cfg.experiment_name}", ckpt_root_dir="{cfg.ckpt_root_dir}")

    num_classes = {num_classes}
    arch_params = {arch_params}
    model = models.get(
        model_name="{cfg.architecture}",
        num_classes=num_classes,
        arch_params=arch_params,
        strict_load={strict_load},
        pretrained_weights={cfg.checkpoint_params.pretrained_weights},
        checkpoint_path={cfg.checkpoint_params.checkpoint_path},
        load_backbone={cfg.checkpoint_params.load_backbone},
        checkpoint_num_classes={checkpoint_num_classes},
    )

    train_dataloader = dataloaders.get(
        name={train_dataloader},
        dataset_params={train_dataset_params},
        dataloader_params={train_dataloader_params},
    )

    val_dataloader = dataloaders.get(
        name={val_dataloader},
        dataset_params={val_dataset_params},
        dataloader_params={val_dataloader_params},
    )
"""
    for name, (class_name, class_params) in hydra_instantiated_objects.items():
        class_params_str = []
        for k, v in class_params.items():
            class_params_str.append(f"{k}={v}")
        class_params_str = ",".join(class_params_str)
        content += f"    {name} = {class_name}({class_params_str})\n\n"

    content += f"""
    training_hyperparams = {training_hyperparams}

    # TRAIN
    result = trainer.train(
        model=model,
        train_loader=train_dataloader,
        valid_loader=val_dataloader,
        training_params=training_hyperparams,
    )
    print(result)


if __name__ == "__main__":
    main()
"""

    # Remove quotes from dict values to reference them as variables
    for key in hydra_instantiated_objects.keys():
        key_to_search = f"'{key}'"
        key_to_replace_with = f"{key}"
        content = content.replace(key_to_search, key_to_replace_with)

    with open(output_script_path, "w") as f:
        black = try_import_black()
        if black is not None:
            content = black.format_str(content, mode=black.FileMode(line_length=160))
        f.write(content)

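For reference, the same conversion can be driven from Python instead of the CLI. A minimal usage sketch (not part of this commit; assumes super_gradients is installed, and uses the bundled cifar10_resnet recipe as an example):

    import pkg_resources
    from super_gradients.convert_recipe_to_code import convert_recipe_to_code

    # Default recipes directory, same as the CLI's --config_dir default below
    recipes_dir = pkg_resources.resource_filename("super_gradients.recipes", "")
    convert_recipe_to_code("cifar10_resnet.yaml", recipes_dir, "train_cifar10_resnet.py")
    # The generated script is then runnable directly:
    #   python train_cifar10_resnet.py
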
def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("config_name", type=str, help=".yaml filename")
    # nargs="?" makes save_path optional so the fallback below can apply when it is omitted
    parser.add_argument("save_path", type=str, nargs="?", default=None, help="Destination path to the output .py file")
    parser.add_argument("--config_dir", type=str, default=pkg_resources.resource_filename("super_gradients.recipes", ""), help="The config directory path")
    args = parser.parse_args()

    save_path = args.save_path or os.path.splitext(os.path.basename(args.config_name))[0] + ".py"
    logger.info(f"Saving recipe script to {save_path}")

    convert_recipe_to_code(args.config_name, args.config_dir, save_path)


if __name__ == "__main__":
    main()
@@ -0,0 +1,100 @@
import ast
import tempfile

import pkg_resources
import unittest

from super_gradients.convert_recipe_to_code import convert_recipe_to_code
from pathlib import Path

class TestConvertRecipeToCode(unittest.TestCase):
    def setUp(self) -> None:
        self.recipes_dir: Path = Path(pkg_resources.resource_filename("super_gradients.recipes", ""))
        self.recipes_that_should_work = [
            "cifar10_resnet.yaml",
            "cityscapes_al_ddrnet.yaml",
            "cityscapes_ddrnet.yaml",
            "cityscapes_pplite_seg50.yaml",
            "cityscapes_pplite_seg75.yaml",
            "cityscapes_regseg48.yaml",
            "cityscapes_segformer_b0.yaml",
            "cityscapes_segformer_b1.yaml",
            "cityscapes_segformer_b2.yaml",
            "cityscapes_segformer_b3.yaml",
            "cityscapes_segformer_b4.yaml",
            "cityscapes_segformer_b5.yaml",
            "cityscapes_stdc_base.yaml",
            "cityscapes_stdc_seg50.yaml",
            "cityscapes_stdc_seg75.yaml",
            "coco2017_pose_dekr_rescoring.yaml",
            "coco2017_pose_dekr_w32_no_dc.yaml",
            "coco2017_ppyoloe_l.yaml",
            "coco2017_ppyoloe_m.yaml",
            "coco2017_ppyoloe_s.yaml",
            "coco2017_ppyoloe_x.yaml",
            "coco2017_ssd_lite_mobilenet_v2.yaml",
            "coco2017_yolo_nas_s.yaml",
            "coco2017_yolox.yaml",
            "coco_segmentation_shelfnet_lw.yaml",
            "imagenet_efficientnet.yaml",
            "imagenet_mobilenetv2.yaml",
            "imagenet_mobilenetv3_large.yaml",
            "imagenet_mobilenetv3_small.yaml",
            "imagenet_regnetY.yaml",
            "imagenet_repvgg.yaml",
            "imagenet_resnet50.yaml",
            "imagenet_vit_base.yaml",
            "imagenet_vit_large.yaml",
            "supervisely_unet.yaml",
            "user_recipe_mnist_as_external_dataset_example.yaml",
            "user_recipe_mnist_example.yaml",
        ]

        self.recipes_that_does_not_work = [
            "cityscapes_kd_base.yaml",  # KD recipe not supported
            "imagenet_resnet50_kd.yaml",  # KD recipe not supported
            "imagenet_mobilenetv3_base.yaml",  # Base recipe (not complete) for other MobileNetV3 recipes
            "cityscapes_segformer.yaml",  # Base recipe (not complete) for other SegFormer recipes
            "roboflow_ppyoloe.yaml",  # Requires explicit command-line arguments
            "roboflow_yolo_nas_m.yaml",  # Requires explicit command-line arguments
            "roboflow_yolo_nas_s.yaml",  # Requires explicit command-line arguments
            "roboflow_yolo_nas_s_qat.yaml",  # Requires explicit command-line arguments
            "roboflow_yolox.yaml",  # Requires explicit command-line arguments
            "variable_setup.yaml",  # Not a recipe
            "script_generate_rescoring_data_dekr_coco2017.yaml",  # Not a recipe
        ]

    def test_all_recipes_are_tested(self):
        present_recipes = set(recipe.name for recipe in self.recipes_dir.glob("*.yaml"))
        known_recipes = set(self.recipes_that_should_work + self.recipes_that_does_not_work)
        new_recipes = present_recipes - known_recipes
        removed_recipes = known_recipes - present_recipes
        if len(new_recipes):
            self.fail(f"New recipes found: {new_recipes}. Please add them to the list of recipes to test.")
        if len(removed_recipes):
            self.fail(f"Removed recipes found: {removed_recipes}. Please remove them from the list of recipes to test.")

    def test_convert_recipes_that_should_work(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            for recipe in self.recipes_that_should_work:
                with self.subTest(recipe=recipe):
                    output_script_path = Path(temp_dir) / Path(recipe).name
                    convert_recipe_to_code(recipe, self.recipes_dir, output_script_path)
                    src = output_script_path.read_text()
                    try:
                        ast.parse(src, feature_version=(3, 9))
                    except SyntaxError as e:
                        self.fail(f"Recipe {recipe} failed to convert to python script: {e}")

    def test_convert_recipes_that_are_expected_to_fail(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            for recipe in self.recipes_that_does_not_work:
                with self.subTest(recipe=recipe):
                    output_script_path = Path(temp_dir) / Path(recipe).name
                    with self.assertRaises(Exception):
                        convert_recipe_to_code(recipe, self.recipes_dir, output_script_path)


if __name__ == "__main__":
    unittest.main()
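The suite can be run on its own through unittest's CLI; a hedged example (the module path below is an assumption, since the diff view does not show where the test file lands in the repo):

    # Hypothetical path; adjust to wherever the test file lives
    python -m unittest tests.unit_tests.test_convert_recipe_to_code -v

Because the module guards unittest.main() under __main__, executing the file directly works as well.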