Feature/sg 757 resume for spots #870

Merged 18 commits on May 15, 2023
34 changes: 34 additions & 0 deletions documentation/source/Checkpoints.md
@@ -265,7 +265,41 @@ For this reason, SG also offers a safer option for resuming interrupted training
Note that resuming training this way requires the interrupted training to be launched with configuration files (i.e., `Trainer.train_from_config`), which outputs the Hydra final config to the `.hydra` directory inside the checkpoints directory.
See usage in our [resume_experiment_example](https://github.com/Deci-AI/super-gradients/blob/master/src/super_gradients/examples/resume_experiment_example/resume_experiment.py).

## Resuming Training from SG Logger's Remote Storage (WandB only)

SG supports saving checkpoints throughout the training process to the remote storage defined by the `SG Logger` (for more information about this object and its role during training in SG, see [Third-party experiment monitoring](experiment_monitoring.md)).
Suppose we run an experiment with a `WandB` SG logger; our `training_hyperparams` should then hold:
```yaml
sg_logger: wandb_sg_logger, # Weights&Biases Logger, see class super_gradients.common.sg_loggers.wandb_sg_logger.WandBSGLogger for details
sg_logger_params: # Params that will be passed to the __init__ of the logger super_gradients.common.sg_loggers.wandb_sg_logger.WandBSGLogger
project_name: project_name, # W&B project name
save_checkpoints_remote: True,
save_tensorboard_remote: True,
save_logs_remote: True,
entity: <YOUR-ENTITY-NAME>, # username or team name where you're sending runs
api_server: <OPTIONAL-WANDB-URL> # Optional: In case your experiment tracking is not hosted at wandb servers
```
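If you configure training directly in Python instead of through a YAML recipe, the equivalent `training_params` could look roughly like the sketch below (project, entity, and path names are placeholders, and `model`, `train_loader`, and `valid_loader` are assumed to be defined elsewhere):
```python
from super_gradients import Trainer

trainer = Trainer(experiment_name="my_experiment", ckpt_root_dir="/path/to/checkpoints")  # placeholder names

training_params = {
    # ... the rest of your training hyperparameters ...
    "sg_logger": "wandb_sg_logger",
    "sg_logger_params": {
        "project_name": "my_project",      # W&B project name (placeholder)
        "entity": "my_team",               # W&B username or team name (placeholder)
        "save_checkpoints_remote": True,   # upload checkpoints to the W&B run during training
        "save_tensorboard_remote": True,
        "save_logs_remote": True,
    },
}

# model, train_loader and valid_loader are assumed to be defined elsewhere.
trainer.train(model=model, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)
```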

The `save_checkpoints_remote` flag is set, which results in checkpoints being saved to WandB throughout training.
Now, if the training is interrupted, we can resume it from the checkpoint stored in the WandB run by setting two training hyperparameters:
1. Set `resume_from_remote_sg_logger`:
```yaml
resume_from_remote_sg_logger: True
```
2. Pass `run_id` through `wandb_id` to `sg_logger_params`:
```yaml
sg_logger: wandb_sg_logger, # Weights&Biases Logger, see class super_gradients.common.sg_loggers.wandb_sg_logger.WandBSGLogger for details
sg_logger_params: # Params that will be passed to the __init__ of the logger super_gradients.common.sg_loggers.wandb_sg_logger.WandBSGLogger
wandb_id: <YOUR_RUN_ID>
project_name: project_name, # W&B project name
save_checkpoints_remote: True,
save_tensorboard_remote: True,
save_logs_remote: True,
entity: <YOUR-ENTITY-NAME>, # username or team name where you're sending runs
api_server: <OPTIONAL-WANDB-URL> # Optional: In case your experiment tracking is not hosted at wandb servers
```

And that's it! Once you re-launch your training, `ckpt_latest.pth` (by default) will be downloaded to the checkpoints directory, and the training will resume from it just as if it were stored locally.
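Under the same assumptions as the sketch above, the resume step from Python could look roughly like this (the run id is a placeholder):
```python
from super_gradients import Trainer

trainer = Trainer(experiment_name="my_experiment", ckpt_root_dir="/path/to/checkpoints")  # same experiment as before

training_params = {
    # ... same hyperparameters as the original run ...
    "resume_from_remote_sg_logger": True,  # download ckpt_latest.pth from the W&B run before loading weights
    "sg_logger": "wandb_sg_logger",
    "sg_logger_params": {
        "wandb_id": "abcd1234",            # run id of the interrupted W&B run (placeholder)
        "project_name": "my_project",
        "entity": "my_team",
        "save_checkpoints_remote": True,
        "save_tensorboard_remote": True,
        "save_logs_remote": True,
    },
}

# model, train_loader and valid_loader are assumed to be defined as in the original run.
trainer.train(model=model, training_params=training_params, train_loader=train_loader, valid_loader=valid_loader)
```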

## Evaluating Checkpoints

4 changes: 4 additions & 0 deletions src/super_gradients/common/sg_loggers/abstract_sg_logger.py
@@ -170,3 +170,7 @@ def local_dir(self) -> str:
:return:
"""
raise NotImplementedError

def download_remote_ckpt(self, ckpt_name: str, *args, **kwargs):
    """
    Download the checkpoint file `ckpt_name` from the logger's remote storage into the local checkpoints directory.

    :param ckpt_name: Checkpoint filename to download (e.g. ckpt_latest.pth).
    """
    raise NotImplementedError
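Loggers that store checkpoints somewhere other than WandB would need to override this hook. Purely for illustration, a minimal sketch of what such an override could look like for a hypothetical S3-backed logger (the `S3SGLogger` class, the `s3_bucket`/`s3_prefix` parameters, and the use of `boto3` are assumptions, not part of this PR or of SG):
```python
import os

import boto3  # assumed dependency of this hypothetical logger

from super_gradients.common.sg_loggers.base_sg_logger import BaseSGLogger


class S3SGLogger(BaseSGLogger):
    """Hypothetical SG logger that mirrors checkpoints to an S3 bucket."""

    def __init__(self, *args, s3_bucket: str = "my-experiments", s3_prefix: str = "runs/my_experiment", **kwargs):
        super().__init__(*args, **kwargs)
        self.s3_bucket = s3_bucket
        self.s3_prefix = s3_prefix
        self._s3 = boto3.client("s3")

    def download_remote_ckpt(self, ckpt_name: str, *args, **kwargs):
        # Fetch the remote checkpoint into the local checkpoints directory so that
        # the regular local resume logic can load it.
        local_path = os.path.join(self.local_dir(), ckpt_name)
        self._s3.download_file(self.s3_bucket, f"{self.s3_prefix}/{ckpt_name}", local_path)
```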
3 changes: 2 additions & 1 deletion src/super_gradients/common/sg_loggers/base_sg_logger.py
@@ -17,7 +17,7 @@
from super_gradients.common.environment.ddp_utils import multi_process_safe
from super_gradients.common.sg_loggers.abstract_sg_logger import AbstractSGLogger
from super_gradients.training.params import TrainingParams
from super_gradients.training.utils import sg_trainer_utils
from super_gradients.training.utils import sg_trainer_utils, get_param
from super_gradients.common.environment.monitoring import SystemMonitor
from super_gradients.common.auto_logging.auto_logger import AutoLoggerConfig
from super_gradients.common.auto_logging.console_logging import ConsoleSink
@@ -101,6 +101,7 @@ def __init__(
self._init_system_monitor(monitor_system)

self._save_code()
self._resume_from_remote_sg_logger = get_param(training_params, "resume_from_remote_sg_logger")

@multi_process_safe
def _launch_tensorboard(self, port):
14 changes: 11 additions & 3 deletions src/super_gradients/common/sg_loggers/wandb_sg_logger.py
@@ -91,11 +91,16 @@ def __init__(

# allow passing an arbitrary pre-defined wandb_id
wandb_id = kwargs.pop("wandb_id", None)

self.resumed = resumed
if self.resumed:
if wandb_id is not None:
logger.warning("Resuming the run with a previous WandB ID instead of the one from logger params")
wandb_id = self._get_wandb_id()
if wandb_id is None:
if self._resume_from_remote_sg_logger:
raise RuntimeError(
"For WandB loggers, when training_params.resume_from_remote_sg_logger=True "
"pass the run id through the wandb_id arg in sg_logger_params"
)
wandb_id = self._get_wandb_id()

run = wandb.init(project=project_name, name=experiment_name, entity=entity, resume=resumed, id=wandb_id, **kwargs)
if save_code:
@@ -316,3 +321,6 @@ def _search_upwards_for_file(file_name: str):
return None

return None

def download_remote_ckpt(self, ckpt_name: str = "ckpt_latest.pth", *args, **kwargs):
    # Restore the requested checkpoint file from the current W&B run into the local checkpoints directory.
    wandb.restore(ckpt_name, replace=True, root=self.local_dir())
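`wandb.restore` here pulls the file from the run that was resumed via `wandb.init(..., resume=resumed, id=wandb_id)`. For reference, fetching the same checkpoint outside of SG could look roughly like this sketch using the public W&B API (entity, project, and run id are placeholders):
```python
import wandb

api = wandb.Api()
run = api.run("my-entity/my-project/abcd1234")  # "<entity>/<project>/<run_id>" placeholder path
run.file("ckpt_latest.pth").download(root="./checkpoints", replace=True)
```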
@@ -1,5 +1,14 @@
resume: False # whether to continue training from ckpt with the same experiment name.
resume_path: # Explicit checkpoint path (.pth file) to use to resume training.

resume_from_remote_sg_logger: False # bool (default=False). When True, ckpt_name (the checkpoint filename
# to resume from, i.e. ckpt_latest.pth by default) will be downloaded into the experiment checkpoints directory
# prior to loading weights, and training is resumed from that checkpoint. The source is unique to
# every logger, and is currently supported for WandB loggers only.
#
# IMPORTANT: Only works for experiments that were run with sg_logger_params.save_checkpoints_remote=True.
# IMPORTANT: For WandB loggers, one must also pass the run id through the wandb_id arg in sg_logger_params.

ckpt_name: ckpt_latest.pth # The checkpoint (.pth file) filename in CKPT_ROOT_DIR/EXPERIMENT_NAME/ to use when resume=True and resume_path=None
lr_mode: # Learning rate scheduling policy, one of ['step','poly','cosine','function']
lr_schedule_function: # Learning rate scheduling function to be used when `lr_mode` is 'function'.
5 changes: 5 additions & 0 deletions src/super_gradients/training/params.py
@@ -64,6 +64,11 @@
# (i.e iterating over train_loader) when reaching this number of batches.
"max_valid_batches": None, # For debug- when not None- will break out of inner valid loop
# (i.e iterating over valid_loader) when reaching this number of batches.
"resume_from_remote_sg_logger": False # When true, ckpt_name (checkpoint filename to resume, ckpt_latest.pth by
# default) will be downloaded into the experiment checkpoints directory prior to loading weights, then resumed
# from that checkpoint. The source is unique to every logger, and currently supported for WandB loggers only.
# Note that for this to work, the experiment must be ran with sg_logger_params.save_checkpoints_remote=True. For
# WandB loggers, one must also pass the run id through the wandb_id arg in sg_logger_params.
}

DEFAULT_OPTIMIZER_PARAMS_SGD = {"weight_decay": 1e-4, "momentum": 0.9}
20 changes: 16 additions & 4 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -641,9 +641,9 @@ def _prep_net_for_train(self) -> None:
self.load_ema_as_net = False
self.load_checkpoint = core_utils.get_param(self.training_params, "resume", False)
self.external_checkpoint_path = core_utils.get_param(self.training_params, "resume_path")
self.load_checkpoint = self.load_checkpoint or self.external_checkpoint_path is not None
self.ckpt_name = core_utils.get_param(self.training_params, "ckpt_name", "ckpt_latest.pth")
self._load_checkpoint_to_model()
self.resume_from_remote_sg_logger = core_utils.get_param(self.training_params, "resume_from_remote_sg_logger", False)
self.load_checkpoint = self.load_checkpoint or self.external_checkpoint_path is not None or self.resume_from_remote_sg_logger

def _init_arch_params(self) -> None:
default_arch_params = HpmStruct()
@@ -990,6 +990,14 @@ def forward(self, inputs, targets):
- `max_valid_batches`: int, for debug- when not None- will break out of inner valid loop (i.e. iterating over
valid_loader) when reaching this number of batches. Useful for debugging (default=None).

- `resume_from_remote_sg_logger`: bool (default=False), when True, ckpt_name (the checkpoint filename
to resume from, i.e. ckpt_latest.pth by default) will be downloaded into the experiment checkpoints directory
prior to loading weights, and training is resumed from that checkpoint. The source is unique to
every logger, and is currently supported for WandB loggers only.

IMPORTANT: Only works for experiments that were run with sg_logger_params.save_checkpoints_remote=True.
IMPORTANT: For WandB loggers, one must also pass the run id through the wandb_id arg in sg_logger_params.



:return:
@@ -1034,6 +1042,9 @@

self.net = model
self._prep_net_for_train()
if not self.ddp_silent_mode:
self._initialize_sg_logger_objects(additional_configs_to_log)
self._load_checkpoint_to_model()

# SET RANDOM SEED
random_seed(is_ddp=device_config.multi_gpu == MultiGPUMode.DISTRIBUTED_DATA_PARALLEL, device=device_config.device, seed=self.training_params.seed)
@@ -1131,8 +1142,6 @@ def forward(self, inputs, targets):
self.phase_callback_handler = CallbackHandler(callbacks=self.phase_callbacks)

if not self.ddp_silent_mode:
self._initialize_sg_logger_objects(additional_configs_to_log)

if self.training_params.dataset_statistics:
dataset_statistics_logger = DatasetStatisticsTensorboardLogger(self.sg_logger)
dataset_statistics_logger.analyze(self.train_loader, all_classes=self.classes, title="Train-set", anchors=self.net.module.arch_params.anchors)
@@ -1491,6 +1500,9 @@ def _load_checkpoint_to_model(self): # noqa: C901 - too complex
NOTE: 'acc', 'epoch', 'optimizer_state_dict' and the logs are NOT loaded if self.zeroize_prev_train_params
is True
"""
# In DDP, only the master process downloads the remote checkpoint; the other ranks wait until it is available locally.
with wait_for_the_master(get_local_rank()):
    if self.resume_from_remote_sg_logger and not self.ddp_silent_mode:
        self.sg_logger.download_remote_ckpt(ckpt_name=self.ckpt_name)

if self.load_checkpoint or self.external_checkpoint_path:
# GET LOCAL PATH TO THE CHECKPOINT FILE FIRST