Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improvement] max_batches support to training log and tqdm progress bar. #1554

Merged
merged 3 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions src/super_gradients/training/sg_trainer/sg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,8 +458,12 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
# SET THE MODEL IN training STATE
self.net.train()

expected_iterations = len(self.train_loader) if self.max_train_batches is None else self.max_train_batches

# THE DISABLE FLAG CONTROLS WHETHER THE PROGRESS BAR IS SILENT OR PRINTS THE LOGS
with tqdm(self.train_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_train_loader:
with tqdm(
self.train_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode
) as progress_bar_train_loader:
progress_bar_train_loader.set_description(f"Train epoch {context.epoch}")

# RESET/INIT THE METRIC LOGGERS
Expand All @@ -471,6 +475,9 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
context.update_context(loss_avg_meter=loss_avg_meter, metrics_compute_fn=self.train_metrics)

for batch_idx, batch_items in enumerate(progress_bar_train_loader):
if self.max_train_batches is not None and self.max_train_batches <= batch_idx:
hakuryuu96 marked this conversation as resolved.
Show resolved Hide resolved
break

batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True)
inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)

Expand Down Expand Up @@ -510,9 +517,6 @@ def _train_epoch(self, context: PhaseContext, silent_mode: bool = False) -> tupl
progress_bar_train_loader.set_postfix(**pbar_message_dict)
self.phase_callback_handler.on_train_batch_end(context)

if self.max_train_batches is not None and self.max_train_batches - 1 <= batch_idx:
break

self.train_monitored_values = sg_trainer_utils.update_monitored_values_dict(
monitored_values_dict=self.train_monitored_values, new_values_dict=pbar_message_dict
)
Expand Down Expand Up @@ -1331,21 +1335,23 @@ def forward(self, inputs, targets):

self.ckpt_best_name = self.training_params.ckpt_best_name

self.max_train_batches = self.training_params.max_train_batches
self.max_valid_batches = self.training_params.max_valid_batches

if self.training_params.max_train_batches is not None:
if self.training_params.max_train_batches > len(self.train_loader):
logger.warning("max_train_batches is greater than len(self.train_loader) and will have no effect.")
self.max_train_batches = len(self.train_loader)
elif self.training_params.max_train_batches <= 0:
raise ValueError("max_train_batches must be positive.")

if self.training_params.max_valid_batches is not None:
if self.training_params.max_valid_batches > len(self.valid_loader):
logger.warning("max_valid_batches is greater than len(self.valid_loader) and will have no effect.")
self.max_valid_batches = len(self.valid_loader)
elif self.training_params.max_valid_batches <= 0:
raise ValueError("max_valid_batches must be positive.")

self.max_train_batches = self.training_params.max_train_batches
self.max_valid_batches = self.training_params.max_valid_batches

# STATE ATTRIBUTE SET HERE FOR SUBSEQUENT TRAIN() CALLS
self._first_backward = True

Expand Down Expand Up @@ -1394,6 +1400,7 @@ def forward(self, inputs, targets):
batch_accumulate=self.batch_accumulate,
train_dataset_length=len(self.train_loader.dataset),
train_dataloader_len=len(self.train_loader),
max_train_batches=self.max_train_batches,
)

processing_params = self._get_preprocessing_from_valid_loader()
Expand Down Expand Up @@ -2014,7 +2021,12 @@ def _validate_epoch(self, context: PhaseContext, silent_mode: bool = False) -> D
self._reset_metrics()
self.valid_metrics.to(device_config.device)
return self.evaluate(
data_loader=self.valid_loader, metrics=self.valid_metrics, evaluation_type=EvaluationType.VALIDATION, epoch=context.epoch, silent_mode=silent_mode
data_loader=self.valid_loader,
metrics=self.valid_metrics,
evaluation_type=EvaluationType.VALIDATION,
epoch=context.epoch,
silent_mode=silent_mode,
max_batches=self.max_valid_batches,
)

def _test_epoch(self, data_loader: DataLoader, context: PhaseContext, silent_mode: bool = False, dataset_name: str = "") -> Dict[str, float]:
Expand Down Expand Up @@ -2047,6 +2059,7 @@ def evaluate(
silent_mode: bool = False,
metrics_progress_verbose: bool = False,
dataset_name: str = "",
max_batches: Optional[int] = None,
) -> Dict[str, float]:
"""
Evaluates the model on given dataloader and metrics.
Expand Down Expand Up @@ -2081,7 +2094,11 @@ def evaluate(
loss_logging_items_names=self.loss_logging_items_names,
)

with tqdm(data_loader, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode) as progress_bar_data_loader:
expected_iterations = len(data_loader) if max_batches is None else max_batches

with tqdm(
data_loader, total=expected_iterations, bar_format="{l_bar}{bar:10}{r_bar}", dynamic_ncols=True, disable=silent_mode
) as progress_bar_data_loader:

if not silent_mode:
# PRINT TITLES
Expand All @@ -2091,9 +2108,11 @@ def evaluate(
if epoch:
pbar_start_msg += f" epoch {epoch}"
progress_bar_data_loader.set_description(pbar_start_msg)

with torch.no_grad():
for batch_idx, batch_items in enumerate(progress_bar_data_loader):
if evaluation_type == EvaluationType.VALIDATION and self.max_valid_batches is not None and self.max_valid_batches <= batch_idx:
hakuryuu96 marked this conversation as resolved.
Show resolved Hide resolved
break

batch_items = core_utils.tensor_container_to_device(batch_items, device_config.device, non_blocking=True)
inputs, targets, additional_batch_items = sg_trainer_utils.unpack_batch_items(batch_items)

Expand Down Expand Up @@ -2128,9 +2147,6 @@ def evaluate(

progress_bar_data_loader.set_postfix(**pbar_message_dict)

if evaluation_type == EvaluationType.VALIDATION and self.max_valid_batches is not None and self.max_valid_batches - 1 <= batch_idx:
break

logging_values = get_logging_values(loss_avg_meter, metrics, self.criterion)
# NEED TO COMPUTE METRICS FOR THE FIRST TIME IF PROGRESS VERBOSITY IS NOT SET
if not metrics_progress_verbose:
Expand Down
23 changes: 19 additions & 4 deletions src/super_gradients/training/utils/sg_trainer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,19 +447,34 @@ def get_callable_param_names(obj: callable) -> Tuple[str]:


def log_main_training_params(
multi_gpu: MultiGPUMode, num_gpus: int, batch_size: int, batch_accumulate: int, train_dataset_length: int, train_dataloader_len: int
multi_gpu: MultiGPUMode,
num_gpus: int,
batch_size: int,
batch_accumulate: int,
train_dataset_length: int,
train_dataloader_len: int,
max_train_batches: Optional[int] = None,
):
"""Log training parameters"""

iterations_per_epoch = int(train_dataloader_len) if max_train_batches is None else max_train_batches
gradients_updates_per_epoch = int(iterations_per_epoch / batch_accumulate)
what_used_str = "len(train_loader)" if max_train_batches is None else "max_train_batches"

msg = (
"TRAINING PARAMETERS:\n"
f" - Mode: {multi_gpu.name if multi_gpu else 'Single GPU'}\n"
f" - Number of GPUs: {num_gpus if 'cuda' in device_config.device else 0:<10} ({torch.cuda.device_count()} available on the machine)\n"
f" - Dataset size: {train_dataset_length:<10} (len(train_set))\n"
f" - Full dataset size: {train_dataset_length:<10} (len(train_set))\n"
f" - Batch size per GPU: {batch_size:<10} (batch_size)\n"
f" - Batch Accumulate: {batch_accumulate:<10} (batch_accumulate)\n"
f" - Total batch size: {num_gpus * batch_size:<10} (num_gpus * batch_size)\n"
f" - Effective Batch size: {num_gpus * batch_size * batch_accumulate:<10} (num_gpus * batch_size * batch_accumulate)\n"
f" - Iterations per epoch: {int(train_dataloader_len):<10} (len(train_loader))\n"
f" - Gradient updates per epoch: {int(train_dataloader_len / batch_accumulate):<10} (len(train_loader) / batch_accumulate)\n"
f" - Iterations per epoch: {iterations_per_epoch:<10} ({what_used_str})\n"
f" - Gradient updates per epoch: {gradients_updates_per_epoch:<10} ({what_used_str} / batch_accumulate)\n"
)

logger.info(msg)

if max_train_batches:
logger.warning(f"max_train_batch is set to {max_train_batches}. This limits the number of iterations per epoch and gradient updates per epoch.")