From 54e5a23bebfad8bc7b6290b2f9b6e5fb08193137 Mon Sep 17 00:00:00 2001 From: Louis-Dupont <35190946+Louis-Dupont@users.noreply.github.com> Date: Wed, 23 Aug 2023 08:53:04 +0300 Subject: [PATCH] Feature/sg 1039 add factory doc (#1395) * wip * clean version + update register in init Signed-off-by: Louis Dupont * minor change * utilizing using * fix typo * fix * fix type, add ',' and explain diff between register and resolve_params --------- Signed-off-by: Louis Dupont Co-authored-by: Ofri Masad --- documentation/source/configuration_files.md | 88 +------- documentation/source/factories.md | 211 ++++++++++++++++++ mkdocs.yml | 1 + .../common/registry/__init__.py | 47 +++- 4 files changed, 266 insertions(+), 81 deletions(-) create mode 100644 documentation/source/factories.md diff --git a/documentation/source/configuration_files.md b/documentation/source/configuration_files.md index a3264f1a39..c85fe06f09 100644 --- a/documentation/source/configuration_files.md +++ b/documentation/source/configuration_files.md @@ -53,8 +53,8 @@ train_dataset_params: - RandomHorizontalFlip - ToTensor - Normalize: - mean: ${dataset_params.img_mean} - std: ${dataset_params.img_std} + mean: [0.485, 0.456, 0.406] # mean for normalization + std: [0.229, 0.224, 0.225] # std for normalization val_dataset_params: root: /data/Imagenet/val @@ -65,8 +65,8 @@ val_dataset_params: size: 224 - ToTensor - Normalize: - mean: ${dataset_params.img_mean} - std: ${dataset_params.img_std} + mean: [0.485, 0.456, 0.406] # mean for normalization + std: [0.229, 0.224, 0.225] # std for normalization ``` Configuration file can also help you track the exact settings used for each one of your experiments, tweak and tune these settings, and share them with others. 
@@ -107,7 +107,6 @@ python -m super_gradients.evaluate_from_recipe --config-name=cifar10_resnet that will run only the evaluation part of the recipe (without any training iterations) - ## Hydra Hydra is an open-source Python framework that provides us with many useful functionalities for YAML management. You can learn about Hydra [here](https://hydra.cc/docs/intro). We use Hydra to load YAML files and convert them into dictionaries, while @@ -130,7 +129,6 @@ in the first arg of the command line. In the experiment directory a `.hydra` subdirectory will be created. The configuration files related to this run will be saved by hydra to that subdirectory. --------- Two Hydra features worth mentioning are _YAML Composition_ and _Command-Line Overrides_. #### YAML Composition @@ -163,6 +161,7 @@ initial learning-rate. This feature is extremely usefully when experimenting wit Note that the arguments are referenced without the `--` prefix and that each parameter is referenced with its full path in the configuration tree, concatenated with a `.`. + ## Resolvers Resolvers are converting the strings from the YAML file into Python objects or values. The most basic resolvers are the Hydra native resolvers. Here are a few simple examples: @@ -178,79 +177,12 @@ third_of_list: "${getitem: ${my_list}, 2}" first_of_list: "${first: ${my_list}}" last_of_list: "${last: ${my_list}}" ``` +You can register any additional resolver you want by simply following the official [documentation](https://omegaconf.readthedocs.io/en/latest/usage.html#resolvers). -The more advanced resolvers will instantiate objects. In the following example we define a few transforms that -will be used to augment a dataset. 
-```yaml -train_dataset_params: - transforms: - # for more options see common.factories.transforms_factory.py - - SegColorJitter: - brightness: 0.1 - contrast: 0.1 - saturation: 0.1 - - - SegRandomFlip: - prob: 0.5 - - - SegRandomRescale: - scales: [ 0.4, 1.6 ] -``` -Each one of the keys (`SegColorJitter`, `SegRandomFlip`, `SegRandomRescale`) is mapped to a type, and the configuration parameters under that key will be passed -to the type constructor by name (as key word arguments). - -If you want to see where this magic is happening, you can look for the `@resolve_param` decorator in the code - -```python -class ImageNetDataset(torch_datasets.ImageFolder): - - @resolve_param("transforms", factory=TransformsFactory()) - def __init__(self, root: str, transforms: Union[list, dict] = [], *args, **kwargs): - ... - ... -``` - -The `@resolve_param` wraps functions and resolves a string or a dictionary argument (in the example above "transforms") to an object. -To do so, it uses a factory object that maps a string or a dictionary to a type. when `__init__(..)` will be called, the function will receive -an object, and not a dictionary. The parameters under "transforms" in the YAML will be passed as -arguments for instantiation the objects. We will learn how to add a new type of object into these mappings in the next sections. - -## Registering a new object -To use a new object from your configuration file, you need to define the mapping of the string to a type. -This is done using one of the many registration function supported by SG. -```python -register_model -register_detection_module -register_metric -register_loss -register_dataloader -register_callback -register_transform -register_dataset -``` - -These decorator functions can be imported and used as follows: - -```python -from super_gradients.common.registry import register_model - -@register_model(name="MyNet") -class MyExampleNet(nn.Module): - def __init__(self, num_classes: int): - .... 
-``` - -This simple decorator, maps the name "MyNet" to the type `MyExampleNet`. Note that if your constructor -include required arguments, you will be expected to provide them when using this string - -```yaml -... -architecture: - MyNet: - num_classes: 8 -... - -``` +## Factories +Factories are similar to resolvers but were built specifically to instantiate SuperGradients objects within a recipe. +This is a key feature of SuperGradients which is being used in all of our recipes, and we recommend you to +go over this [introduction to Factories](factories.md). ## Required Hyper-Parameters Most parameters can be defined by default when including `default_train_params` in you `defaults`. diff --git a/documentation/source/factories.md b/documentation/source/factories.md new file mode 100644 index 0000000000..f7d1036330 --- /dev/null +++ b/documentation/source/factories.md @@ -0,0 +1,211 @@ +# Working with Factories + +Factories in SuperGradients provide a powerful and concise way to instantiate objects in your configuration files. + +Prerequisites: +- [Training with Configuration Files](configuration_files.md) + +In this tutorial, we'll cover how to use existing factories, register new ones, and briefly explore the implementation details. + +## Using Existing Factories + +If you had a look at the [recipes](https://github.com/Deci-AI/super-gradients/tree/master/src/super_gradients/recipes), you may have noticed that many objects are defined directly in the recipes.
+ +In the [Supervisely dataset recipe](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/recipes/dataset_params/supervisely_persons_dataset_params.yaml) you can see the following + +```yaml +train_dataset_params: + transforms: + - SegColorJitter: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + - SegRandomFlip: + prob: 0.5 + - SegRandomRescale: + scales: [0.4, 1.6] +``` +If you load the `.yaml` recipe as is into a python dictionary, you would get the following +```python +{ + "train_dataset_params": { + "transforms": [ + { + "SegColorJitter": { + "brightness": 0.1, + "contrast": 0.1, + "saturation": 0.1 + } + }, + { + "SegRandomFlip": { + "prob": 0.5 + } + }, + { + "SegRandomRescale": { + "scales": [0.4, 1.6] + } + } + ] + } +} +``` + +This configuration alone is not very useful, as we need instances of the classes, not just their configurations. +So we would like to somehow instantiate these classes `SegColorJitter`, `SegRandomFlip` and `SegRandomRescale`. + +Factories in SuperGradients come into play here! All these objects were registered beforehand in SuperGradients, +so that when you write these names in the recipe, SuperGradients will detect and instantiate them for you. + +## Registering a Class + +As explained above, only registered objects can be instantiated. +This registration consists of mapping the object name to the corresponding class type. + +In the example above, the string `"SegColorJitter"` was mapped to the class `SegColorJitter`, and this is how SuperGradients knows how to convert the string defined in the recipe, into an object. + +You can register the class using a name different from the actual class name. +However, it's generally recommended to use the same name for consistency and clarity. + +### Example + +```python +from super_gradients.common.registry import register_transform + +@register_transform(name="MyTransformName") +class MyTransform: + def __init__(self, prob: float): + ... 
+``` +In this simple example, we register a new transform. +Note that here we registered (for the sake of the example) the class `MyTransform` to the name `MyTransformName` which is different. +We strongly recommend not doing this, and instead registering a class under its own name. + +Once you have registered a class, you can use it in your recipe. Here, we will add this transform to the original recipe +```yaml +train_dataset_params: + transforms: + - SegColorJitter: + brightness: 0.1 + contrast: 0.1 + saturation: 0.1 + - SegRandomFlip: + prob: 0.5 + - SegRandomRescale: + scales: [0.4, 1.6] + - MyTransformName: # We use the name used to register, which may be different from the name of the class + prob: 0.7 +``` + +Final Step: Ensure that you import the module containing `MyTransform` into your script. +Doing so will trigger the registration function, allowing SuperGradients to recognize it. + +Here is an example (adapted from the [train_from_recipe script](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/train_from_recipe.py)). + +```python +from .my_module import MyTransform # Importing the module is enough as it will trigger the register_transform function + +# The code below is the same as the basic `train_from_recipe.py` script +# See: https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/train_from_recipe.py +from omegaconf import DictConfig +import hydra + +from super_gradients import Trainer, init_trainer + + +@hydra.main(config_path="recipes", version_base="1.2") +def _main(cfg: DictConfig) -> None: + Trainer.train_from_config(cfg) + + +def main() -> None: + init_trainer() # `init_trainer` needs to be called before `@hydra.main` + _main() + + +if __name__ == "__main__": + main() + +``` + +## Under the Hood + +Until now, we saw how to use existing Factories, and how to register new ones. +In some cases, you may want to create objects that would benefit from using the factories.
+ +### Basic +The basic way to use factories is shown below. +``` +from super_gradients.common.factories import TransformsFactory +factory = TransformsFactory() +my_transform = factory.get({'MyTransformName': {'prob': 0.7}}) +``` +You may recognize that the input passed to `factory.get` is actually the dictionary that we get after loading the recipe +(See [Using Existing Factories](#using-existing-factories)) + +### Recommended +Factories become even more powerful when used with the `@resolve_param` decorator. +This feature allows functions to accept both instantiated objects and their dictionary representations. +It means you can pass either the actual python object or a dictionary that describes it straight from the recipe. + +```python +class ImageNetDataset(torch_datasets.ImageFolder): + + @resolve_param("transform", factory=TransformsFactory()) + def __init__(self, root: str, transform: Transform): + ... +``` + +Now, `ImageNetDataset` can be passed both an instance of `MyTransform` + +```python +my_transform = MyTransform(prob=0.7) +ImageNetDataset(root=..., transform=my_transform) +``` + +And a dictionary representing the same object +```python +my_transform = {'MyTransformName': {'prob': 0.7}} +ImageNetDataset(root=..., transform=my_transform) +``` + +This second way of instantiating the dataset combines perfectly with the concept of `.yaml` recipes. + +**Difference with `register_transform`** +- `register_transform` is responsible for mapping a string to a class type. +- `@resolve_param("transform", factory=TransformsFactory())` is responsible for converting a config into an object, using the mapping created with `register_transform`. + +## Supported Factory Types +Until now, we focused on a single type of factory, `TransformsFactory`, +associated with the registration decorator `register_transform`. + +SuperGradients supports a wide range of factories, used throughout the training process, +each with its own registering decorator.
+ +SuperGradients offers various types of factories, and each is associated with a specific registration decorator. + +``` python +from super_gradients.common.registry import ( + register_model, + register_kd_model, + register_detection_module, + register_metric, + register_loss, + register_dataloader, + register_callback, + register_transform, + register_dataset, + register_pre_launch_callback, + register_unet_backbone_stage, + register_unet_up_block, + register_target_generator, + register_lr_scheduler, + register_lr_warmup, + register_sg_logger, + register_collate_function, + register_sampler, + register_optimizer, + register_processing, +) +``` diff --git a/mkdocs.yml b/mkdocs.yml index 35c09e445d..79c640924f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -27,6 +27,7 @@ nav: - Phase Callbacks: ./documentation/source/PhaseCallbacks.md - YAMLs and Recipes: - Configurations: ./documentation/source/configuration_files.md + - Factories: ./documentation/source/factories.md - Recipes: ./src/super_gradients/recipes/Training_Recipes.md - Checkpoints: ./documentation/source/Checkpoints.md - Docker: ./documentation/source/SGDocker.md diff --git a/src/super_gradients/common/registry/__init__.py b/src/super_gradients/common/registry/__init__.py index a72c6d3465..291afde3bf 100644 --- a/src/super_gradients/common/registry/__init__.py +++ b/src/super_gradients/common/registry/__init__.py @@ -1,4 +1,45 @@ -from super_gradients.common.registry.registry import register_model, register_metric, register_loss, register_detection_module, register_lr_scheduler +from super_gradients.common.registry.registry import ( + register_model, + register_kd_model, + register_detection_module, + register_metric, + register_loss, + register_dataloader, + register_callback, + register_transform, + register_dataset, + register_pre_launch_callback, + register_unet_backbone_stage, + register_unet_up_block, + register_target_generator, + register_lr_scheduler, + register_lr_warmup, + register_sg_logger,
+ register_collate_function, + register_sampler, + register_optimizer, + register_processing, +) - -__all__ = ["register_model", "register_detection_module", "register_metric", "register_loss", "register_lr_scheduler"] +__all__ = [ + "register_model", + "register_kd_model", + "register_detection_module", + "register_metric", + "register_loss", + "register_dataloader", + "register_callback", + "register_transform", + "register_dataset", + "register_pre_launch_callback", + "register_unet_backbone_stage", + "register_unet_up_block", + "register_target_generator", + "register_lr_scheduler", + "register_lr_warmup", + "register_sg_logger", + "register_collate_function", + "register_sampler", + "register_optimizer", + "register_processing", +]