Fix wandb logger on resume #766

Status: Merged. 19 commits, merged on Apr 18, 2023.

Commits
67c1e72  fix (Louis-Dupont, Mar 8, 2023)
49e4ba2  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Mar 8, 2023)
5dfc06f  deepcopy (Louis-Dupont, Mar 8, 2023)
05d11ae  fix according to comments (Louis-Dupont, Mar 12, 2023)
b175cb5  make private (Louis-Dupont, Mar 12, 2023)
592b3ea  remove deepcopy (Louis-Dupont, Mar 12, 2023)
e285580  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Mar 12, 2023)
8d39743  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Mar 18, 2023)
a4fbbae  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Mar 20, 2023)
805b19b  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Mar 20, 2023)
20b4fd6  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Apr 16, 2023)
092e500  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Apr 16, 2023)
c82775c  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Apr 17, 2023)
405ecbb  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (BloodAxe, Apr 17, 2023)
c4ac2ee  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (shaydeci, Apr 17, 2023)
5ed3909  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (shaydeci, Apr 18, 2023)
3b5bacb  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Apr 18, 2023)
ee3ff33  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Apr 18, 2023)
7197da3  Merge branch 'master' into feature/SG-670-fix_wandb_logger_bug_on_resume (Louis-Dupont, Apr 18, 2023)

Changes from all commits:

src/super_gradients/training/sg_trainer/sg_trainer.py (23 changes: 17 additions, 6 deletions)

@@ -195,6 +195,8 @@ def __init__(self, experiment_name: str, device: str = None, multi_gpu: Union[Mu
         self.max_train_batches = None
         self.max_valid_batches = None

+        self._epoch_start_logging_values = {}
+
     @property
     def device(self) -> str:
         return device_config.device

@@ -443,9 +445,8 @@ def _train_epoch(self, epoch: int, silent_mode: bool = False) -> tuple:
             context.update_context(preds=outputs, loss_log_items=loss_log_items)
             self.phase_callback_handler.on_train_batch_loss_end(context)

-            # LOG LR THAT WILL BE USED IN CURRENT EPOCH AND AFTER FIRST WARMUP/LR_SCHEDULER UPDATE BEFORE WEIGHT UPDATE
             if not self.ddp_silent_mode and batch_idx == 0:
-                self._write_lrs(epoch)
+                self._epoch_start_logging_values = self._get_epoch_start_logging_values()

             self._backward_step(loss, epoch, batch_idx, context)

@@ -1294,7 +1295,14 @@ def forward(self, inputs, targets):

             if not self.ddp_silent_mode:
                 # SAVING AND LOGGING OCCURS ONLY IN THE MAIN PROCESS (IN CASES THERE ARE SEVERAL PROCESSES - DDP)
-                self._write_to_disk_operations(train_metrics_tuple, validation_results_tuple, inf_time, epoch, context)
+                self._write_to_disk_operations(
+                    train_metrics=train_metrics_tuple,
+                    validation_results=validation_results_tuple,
+                    lr_dict=self._epoch_start_logging_values,
+                    inf_time=inf_time,
+                    epoch=epoch,
+                    context=context,
+                )
                 self.sg_logger.upload()

         # Evaluating the average model and removing snapshot averaging file if training is completed

@@ -1649,24 +1657,27 @@ def _get_hyper_param_config(self):
         }
         return hyper_param_config

-    def _write_to_disk_operations(self, train_metrics: tuple, validation_results: tuple, inf_time: float, epoch: int, context: PhaseContext):
+    def _write_to_disk_operations(self, train_metrics: tuple, validation_results: tuple, lr_dict: dict, inf_time: float, epoch: int, context: PhaseContext):
         """Run the various logging operations, e.g.: log file, Tensorboard, save checkpoint etc."""
         # STORE VALUES IN A TENSORBOARD FILE
         train_results = list(train_metrics) + list(validation_results) + [inf_time]
         all_titles = self.results_titles + ["Inference Time"]

         result_dict = {all_titles[i]: train_results[i] for i in range(len(train_results))}
         self.sg_logger.add_scalars(tag_scalar_dict=result_dict, global_step=epoch)
+        self.sg_logger.add_scalars(tag_scalar_dict=lr_dict, global_step=epoch)

         # SAVE THE CHECKPOINT
         if self.training_params.save_model:
             self._save_checkpoint(self.optimizer, epoch + 1, validation_results, context)

-    def _write_lrs(self, epoch):
+    def _get_epoch_start_logging_values(self) -> dict:
+        """Get all the values that should be logged at the start of each epoch.
+        This is useful for values like Learning Rate that can change over an epoch."""
         lrs = [self.optimizer.param_groups[i]["lr"] for i in range(len(self.optimizer.param_groups))]
         lr_titles = ["LR/Param_group_" + str(i) for i in range(len(self.optimizer.param_groups))] if len(self.optimizer.param_groups) > 1 else ["LR"]
         lr_dict = {lr_titles[i]: lrs[i] for i in range(len(lrs))}
-        self.sg_logger.add_scalars(tag_scalar_dict=lr_dict, global_step=epoch)
+        return lr_dict

     def test(
         self,
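
In short, the change defers learning-rate logging: instead of _write_lrs pushing the LRs to the logger from inside the first batch of every epoch, the values are stashed in _epoch_start_logging_values and written at the end of the epoch together with the other metrics in _write_to_disk_operations. Below is a minimal, framework-free sketch of that pattern, assuming a plain PyTorch optimizer; the standalone helper and the SGD example are illustrative, not super-gradients API.

    # Illustrative sketch only: collect values that can change during an epoch
    # (here, learning rates) into a dict, then log that dict later alongside
    # the rest of the epoch metrics instead of logging it mid-batch.
    import torch

    def get_epoch_start_logging_values(optimizer: torch.optim.Optimizer) -> dict:
        """Return {"LR": lr} for one param group, or {"LR/Param_group_i": lr_i, ...} for several."""
        lrs = [group["lr"] for group in optimizer.param_groups]
        if len(lrs) > 1:
            return {f"LR/Param_group_{i}": lr for i, lr in enumerate(lrs)}
        return {"LR": lrs[0]}

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Gathered on the first batch of the epoch...
    epoch_start_logging_values = get_epoch_start_logging_values(optimizer)
    # ...and handed to the end-of-epoch logging call, mirroring
    # sg_logger.add_scalars(tag_scalar_dict=lr_dict, global_step=epoch) in the diff.
    print(epoch_start_logging_values)  # {'LR': 0.1}

The key titles mirror the ones produced by _get_epoch_start_logging_values in the diff, so single-group runs log a plain LR scalar while multi-group optimizers get one scalar per param group.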