Skip to content

Commit

Permalink
Fix batch size reconf for T5 FT for multi-validation (#6582)
Browse files Browse the repository at this point in the history
Signed-off-by: Abhinav Khattar <aklife97@gmail.com>
  • Loading branch information
aklife97 authored May 7, 2023
1 parent 22f1f2f commit 680cdac
Showing 1 changed file with 27 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
try:
from apex.transformer.pipeline_parallel.utils import (
_reconfigure_microbatch_calculator,
get_current_global_batch_size,
get_micro_batch_size,
get_num_microbatches,
)
Expand Down Expand Up @@ -260,16 +261,33 @@ def cast_for_metric(self, pred, label, metric_name, class_labels=None, labels_ar
def _reconfigure_and_process_inference_batch(self, batch, ds_config):
    """Reconfigure the apex microbatch calculator for an inference/validation batch.

    Validation batches must run without gradient accumulation, so whenever the
    incoming batch is smaller than the currently-configured per-GPU global batch
    (typically only the last, partial batch of a dataset), the microbatch
    calculator is reconfigured to treat the whole batch as a single microbatch.

    Args:
        batch: dict of input tensors; ``batch['text_enc']`` has shape
            (per-GPU batch size, ...) — only its size(0) is read here.
        ds_config: per-dataset config providing ``global_batch_size`` and
            ``micro_batch_size`` for this validation dataset.

    Side effects:
        May call ``_reconfigure_microbatch_calculator`` (apex global state).
    """
    global_batch_size_per_gpu = batch['text_enc'].size(0)
    # This should happen only on the last batch of the dataset.
    # NOTE: compare against the CURRENT calculator state, not ds_config —
    # with multiple validation datasets the calculator may still hold a
    # previous dataset's (or a last-partial-batch) configuration.
    if (
        global_batch_size_per_gpu
        != get_current_global_batch_size() // parallel_state.get_data_parallel_world_size()
    ):
        # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches.
        if (
            global_batch_size_per_gpu
            != ds_config.global_batch_size // parallel_state.get_data_parallel_world_size()
        ):
            # Partial last batch: shrink global/micro batch so the whole
            # per-GPU batch is one microbatch (no gradient accumulation).
            app_state = AppState()
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(),
                micro_batch_size=global_batch_size_per_gpu,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )
        # NOTE: need to explicitly handle resetting for multi-validation
        else:
            # Batch matches this dataset's expected size but the calculator is
            # stale (left over from another dataset) — reset to ds_config.
            app_state = AppState()
            _reconfigure_microbatch_calculator(
                rank=app_state.global_rank,
                rampup_batch_size=None,
                global_batch_size=ds_config.global_batch_size,
                micro_batch_size=ds_config.micro_batch_size,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )

def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
"""
Expand Down

0 comments on commit 680cdac

Please sign in to comment.