Feature/sg 541 auto batch selection #628

Merged · 11 commits · Jan 22, 2023
7 changes: 7 additions & 0 deletions src/super_gradients/common/factories/pre_launch_callbacks_factory.py
@@ -0,0 +1,7 @@
from super_gradients.common.factories.base_factory import BaseFactory
from super_gradients.training import pre_launch_callbacks


class PreLaunchCallbacksFactory(BaseFactory):
def __init__(self):
super().__init__(pre_launch_callbacks.ALL_PRE_LAUNCH_CALLBACKS)
2 changes: 2 additions & 0 deletions src/super_gradients/common/registry/registry.py
@@ -9,6 +9,7 @@
from super_gradients.training.utils.callbacks.all_callbacks import CALLBACKS
from super_gradients.training.transforms.all_transforms import TRANSFORMS
from super_gradients.training.datasets.all_datasets import ALL_DATASETS
from super_gradients.training.pre_launch_callbacks import ALL_PRE_LAUNCH_CALLBACKS


def create_register_decorator(registry: Dict[str, Callable]) -> Callable:
@@ -51,3 +52,4 @@ def decorator(cls: Callable) -> Callable:
register_callback = create_register_decorator(registry=CALLBACKS)
register_transform = create_register_decorator(registry=TRANSFORMS)
register_dataset = create_register_decorator(registry=ALL_DATASETS)
register_pre_launch_callback = create_register_decorator(registry=ALL_PRE_LAUNCH_CALLBACKS)
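For reference, a minimal sketch (not part of this PR) of how a user-defined pre-launch callback could be registered with the new decorator. The class name and its body are hypothetical, and the bare decorator call is assumed to follow the other register_* decorators:

from omegaconf import DictConfig

from super_gradients.common.registry.registry import register_pre_launch_callback
from super_gradients.training.pre_launch_callbacks import PreLaunchCallback


@register_pre_launch_callback()
class HalveTrainBatchSizeCallback(PreLaunchCallback):
    """Hypothetical example: halve the train batch size before launching training."""

    def __call__(self, cfg: DictConfig) -> DictConfig:
        cfg.dataset_params.train_dataloader_params.batch_size //= 2
        return cfg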
1 change: 1 addition & 0 deletions src/super_gradients/training/params.py
@@ -69,6 +69,7 @@
"ckpt_name": "ckpt_latest.pth",
"resume_strict_load": False,
"sync_bn": False,
"kill_ddp_pgroup_on_end": True,  # Whether to kill the DDP process group at the end of training.
"max_train_batches": None,  # For debugging: when not None, breaks out of the inner train loop
# (i.e. iterating over train_loader) when reaching this number of batches.
"max_valid_batches": None,  # For debugging: when not None, breaks out of the inner valid loop
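As a hedged sketch (illustrative values, other required hyperparameters omitted), the new flag can be overridden together with max_train_batches in a training_params dict for short probing runs:

training_params = {
    "max_epochs": 1,
    "kill_ddp_pgroup_on_end": False,  # keep the DDP process group alive after this throwaway run
    "max_train_batches": 3,  # debug/probing: break out of the train loop after 3 batches
}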
5 changes: 5 additions & 0 deletions src/super_gradients/training/pre_launch_callbacks/__init__.py
@@ -0,0 +1,5 @@
from super_gradients.training.pre_launch_callbacks.pre_launch_callbacks import PreLaunchCallback, AutoTrainBatchSizeSelectionCallback

ALL_PRE_LAUNCH_CALLBACKS = {"AutoTrainBatchSizeSelectionCallback": AutoTrainBatchSizeSelectionCallback}

__all__ = ["PreLaunchCallback", "AutoTrainBatchSizeSelectionCallback", "ALL_PRE_LAUNCH_CALLBACKS"]
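A short sketch of how this registry dict maps a recipe-level name to its callback class (constructor arguments are illustrative):

from super_gradients.training.pre_launch_callbacks import ALL_PRE_LAUNCH_CALLBACKS

# Resolve the class by the name used in recipe YAMLs, then instantiate it.
callback_cls = ALL_PRE_LAUNCH_CALLBACKS["AutoTrainBatchSizeSelectionCallback"]
callback = callback_cls(min_batch_size=128, size_step=64, num_forward_passes=10)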
153 changes: 153 additions & 0 deletions src/super_gradients/training/pre_launch_callbacks/pre_launch_callbacks.py
@@ -0,0 +1,153 @@
from copy import deepcopy
from typing import Union

from omegaconf import DictConfig
import torch

from super_gradients import is_distributed
from super_gradients.common.abstractions.abstract_logger import get_logger
from super_gradients.training import models
from torch.distributed import barrier

logger = get_logger(__name__)


class PreLaunchCallback:
"""
PreLaunchCallback

Base class for callbacks to be triggered, manipulating the config (cfg) prior to launching training,
when calling Trainer.train_from_config(cfg).

"""

def __call__(self, cfg: Union[dict, DictConfig]) -> Union[dict, DictConfig]:
raise NotImplementedError


class AutoTrainBatchSizeSelectionCallback(PreLaunchCallback):
"""
AutoTrainBatchSizeSelectionCallback

Modifies cfg.dataset_params.train_dataloader_params.batch_size by searching for the maximal batch size that fits in
GPU memory. Works out of the box for DDP.

The search is done by running a few forward passes for increasing batch sizes, until CUDA OUT OF MEMORY is raised:

For batch_size in range(min_batch_size:max_batch_size:size_step):
if batch_size raises CUDA OUT OF MEMORY ERROR:
return batch_size-size_step
return batch_size

Example usage: Inside the main recipe .YAML file (for example super_gradients/recipes/cifar10_resnet.yaml),
add the following:

pre_launch_callbacks_list:
- AutoTrainBatchSizeSelectionCallback:
min_batch_size: 128
size_step: 64
num_forward_passes: 10

Then, when running super_gradients/examples/train_from_recipe_example/train_from_recipe.py --config-name=...
this pre_launch_callback will modify cfg.dataset_params.train_dataloader_params.batch_size then pass cfg to
Trainer.train_from_config(cfg) and training will continue with the selected batch size.


:param min_batch_size: int, the first batch size to try running forward passes with. Should fit in memory.

:param size_step: int, the difference between 2 consecutive batch_size trials.

:param num_forward_passes: int, number of forward passes (i.e. train_loader data iterations inside an epoch).
Note that the more forward passes are run, the less prone the selected batch size is to fail later. This is because,
besides the gradients, model computations, data and other fixed GPU memory being used, additional GPU memory
might be used by the metric objects and PhaseCallbacks.

:param max_batch_size: int, optional, upper limit of the batch sizes to try. When None, the search will continue until
the maximal batch size that does not raise CUDA OUT OF MEMORY is found (default=None).

:param scale_lr: bool, whether to linearly scale cfg.training_hyperparams.initial_lr, i.e. multiply by
FOUND_BATCH_SIZE/cfg.dataset_params.train_dataloader_params.batch_size (default=True)
"""

def __init__(self, min_batch_size: int, size_step: int, num_forward_passes: int = 3, max_batch_size=None, scale_lr: bool = True):
self.scale_lr = scale_lr
self.min_batch_size = min_batch_size
self.size_step = size_step
self.max_batch_size = max_batch_size
self.num_forward_passes = num_forward_passes

def __call__(self, cfg: DictConfig) -> DictConfig:

# IMPORT IS HERE DUE TO CIRCULAR IMPORT PROBLEM
from super_gradients.training.sg_trainer import Trainer

curr_batch_size = self.min_batch_size
# BUILD NETWORK
model = models.get(
model_name=cfg.architecture,
num_classes=cfg.arch_params.num_classes,
arch_params=cfg.arch_params,
strict_load=cfg.checkpoint_params.strict_load,
pretrained_weights=cfg.checkpoint_params.pretrained_weights,
checkpoint_path=cfg.checkpoint_params.checkpoint_path,
load_backbone=cfg.checkpoint_params.load_backbone,
)
tmp_cfg = deepcopy(cfg)
tmp_cfg.training_hyperparams.batch_accumulate = 1
tmp_cfg.training_hyperparams.max_train_batches = self.num_forward_passes
tmp_cfg.training_hyperparams.run_validation_freq = 2
tmp_cfg.training_hyperparams.silent_mode = True
tmp_cfg.training_hyperparams.save_model = False
tmp_cfg.training_hyperparams.max_epochs = 1
tmp_cfg.training_hyperparams.average_best_models = False
tmp_cfg.training_hyperparams.kill_ddp_pgroup_on_end = False
tmp_cfg.pre_launch_callbacks_list = []

while True:
tmp_cfg.dataset_params.train_dataloader_params.batch_size = curr_batch_size

try:
Trainer.train_from_config(tmp_cfg)

except RuntimeError as e:
if "out of memory" in str(e):
if curr_batch_size == self.min_batch_size:
logger.error("Ran out of memory for the smallest batch, try setting smaller min_batch_size.")
raise e
else:
logger.info(f"Ran out of memory for {curr_batch_size}, setting batch size to {curr_batch_size - self.size_step}.")
self._adapt_lr_if_needed(cfg, found_batch_size=curr_batch_size - self.size_step)
cfg.dataset_params.train_dataloader_params.batch_size = curr_batch_size - self.size_step
self._clear_model_gpu_mem(model)
return cfg
else:
raise e

else:
if self.max_batch_size is not None and curr_batch_size >= self.max_batch_size:
logger.info(
f"Did not run out of memory for {curr_batch_size} >= max_batch_size={self.max_batch_size}, " f"setting batch to {self.max_batch_size}."
)
self._adapt_lr_if_needed(cfg, found_batch_size=self.max_batch_size)
cfg.dataset_params.train_dataloader_params.batch_size = self.max_batch_size
self._clear_model_gpu_mem(model)
return cfg
logger.info(f"Did not run out of memory for {curr_batch_size}, retrying batch {curr_batch_size + self.size_step}.")
curr_batch_size += self.size_step
self._clear_model_gpu_mem(model)

def _adapt_lr_if_needed(self, cfg: DictConfig, found_batch_size: int) -> DictConfig:
if self.scale_lr:
scale_factor = found_batch_size / cfg.dataset_params.train_dataloader_params.batch_size
cfg.training_hyperparams.initial_lr = cfg.training_hyperparams.initial_lr * scale_factor
return cfg

@classmethod
def _clear_model_gpu_mem(cls, model):
for p in model.parameters():
if p.grad is not None:
del p.grad # free some memory
torch.cuda.empty_cache()
# WAIT FOR ALL PROCESSES TO CLEAR THEIR MEMORY BEFORE MOVING ON
if is_distributed():
barrier()
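To make the scale_lr behavior concrete, a small worked example with illustrative numbers, mirroring _adapt_lr_if_needed: if the recipe originally sets batch_size=64 and initial_lr=0.1 and the search settles on 512, the learning rate is scaled linearly:

original_batch_size = 64   # cfg.dataset_params.train_dataloader_params.batch_size from the recipe
initial_lr = 0.1           # cfg.training_hyperparams.initial_lr from the recipe
found_batch_size = 512     # batch size selected by the search

scale_factor = found_batch_size / original_batch_size  # 8.0
scaled_initial_lr = initial_lr * scale_factor          # 0.8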
35 changes: 24 additions & 11 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -86,6 +86,7 @@
from super_gradients.training.utils import HpmStruct
from super_gradients.training.utils.hydra_utils import load_experiment_cfg, add_params_to_cfg
from omegaconf import OmegaConf
from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory

logger = get_logger(__name__)

@@ -219,8 +220,22 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup
# INSTANTIATE ALL OBJECTS IN CFG
cfg = hydra.utils.instantiate(cfg)

# TRIGGER CFG MODIFYING CALLBACKS
cfg = cls._trigger_cfg_modifying_callbacks(cfg)

trainer = Trainer(experiment_name=cfg.experiment_name, ckpt_root_dir=cfg.ckpt_root_dir)

# BUILD NETWORK
model = models.get(
model_name=cfg.architecture,
num_classes=cfg.arch_params.num_classes,
arch_params=cfg.arch_params,
strict_load=cfg.checkpoint_params.strict_load,
pretrained_weights=cfg.checkpoint_params.pretrained_weights,
checkpoint_path=cfg.checkpoint_params.checkpoint_path,
load_backbone=cfg.checkpoint_params.load_backbone,
)

# INSTANTIATE DATA LOADERS

train_dataloader = dataloaders.get(
@@ -235,16 +250,6 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup
dataloader_params=cfg.dataset_params.val_dataloader_params,
)

# BUILD NETWORK
model = models.get(
model_name=cfg.architecture,
num_classes=cfg.arch_params.num_classes,
arch_params=cfg.arch_params,
strict_load=cfg.checkpoint_params.strict_load,
pretrained_weights=cfg.checkpoint_params.pretrained_weights,
checkpoint_path=cfg.checkpoint_params.checkpoint_path,
load_backbone=cfg.checkpoint_params.load_backbone,
)
recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
# TRAIN
res = trainer.train(
@@ -257,6 +262,14 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> Tuple[nn.Module, Tup

return model, res

@classmethod
def _trigger_cfg_modifying_callbacks(cls, cfg):
pre_launch_cbs = get_param(cfg, "pre_launch_callbacks_list", list())
pre_launch_cbs = ListFactory(PreLaunchCallbacksFactory()).get(pre_launch_cbs)
for plcb in pre_launch_cbs:
cfg = plcb(cfg)
return cfg

@classmethod
def resume_experiment(cls, experiment_name: str, ckpt_root_dir: str = None) -> Tuple[nn.Module, Tuple]:
"""
@@ -1318,7 +1331,7 @@ def forward(self, inputs, targets):
finally:
if device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL:
# CLEAN UP THE MULTI-GPU PROCESS GROUP WHEN DONE
if torch.distributed.is_initialized():
if torch.distributed.is_initialized() and self.training_params.kill_ddp_pgroup_on_end:
torch.distributed.destroy_process_group()

# PHASE.TRAIN_END
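For context, a hedged sketch of what _trigger_cfg_modifying_callbacks does with a recipe's pre_launch_callbacks_list (the ListFactory import path and the inline callback list are assumptions for illustration; in practice the list comes from the fully instantiated recipe config):

from super_gradients.common.factories.list_factory import ListFactory  # assumed import path
from super_gradients.common.factories.pre_launch_callbacks_factory import PreLaunchCallbacksFactory

# Normally read from cfg.pre_launch_callbacks_list of the instantiated recipe.
pre_launch_cbs = [{"AutoTrainBatchSizeSelectionCallback": {"min_batch_size": 128, "size_step": 64, "num_forward_passes": 10}}]

callbacks = ListFactory(PreLaunchCallbacksFactory()).get(pre_launch_cbs)
# Each callback receives the config and returns a (possibly modified) config:
# for plcb in callbacks:
#     cfg = plcb(cfg)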
2 changes: 2 additions & 0 deletions tests/deci_core_recipe_test_suite_runner.py
@@ -1,6 +1,7 @@
import sys
import unittest

from tests.recipe_training_tests.automatic_batch_selection_single_gpu_test import TestAutoBatchSelectionSingleGPU
from tests.recipe_training_tests.shortened_recipes_accuracy_test import ShortenedRecipesAccuracyTests


@@ -17,6 +18,7 @@ def _add_modules_to_unit_tests_suite(self):
:return:
"""
self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(ShortenedRecipesAccuracyTests))
self.recipe_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestAutoBatchSelectionSingleGPU))


if __name__ == "__main__":