SFT patch: (1) enable sequence parallelism and (2) enable profile (#7963)

* SFT profile start and end step fix

Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Removed sequence parallelism assertion check

Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

---------

Signed-off-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Cheng-Ping Hsieh <37269846+hsiehjackson@users.noreply.github.com>
3 people authored Dec 18, 2023
1 parent 6b40e62 commit e482965
Showing 3 changed files with 11 additions and 9 deletions.
@@ -94,8 +94,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
         else:
             base_module = self.model
 
+        # Set the profile start and end steps in the unit of global batch
+        if hasattr(self, '_nsys_profile_enabled'):
+            self._nsys_profile_start_step = self.cfg.nsys_profile.get('start_step', 0)
+            self._nsys_profile_end_step = self.cfg.nsys_profile.get('end_step', 0)
+
         self._reset_activation_checkpointing_args()
-        self._reset_sequence_parallelism_args()
         self.virtual_tokens = 0
 
     def setup_metric(self, data_cfg):
@@ -593,7 +597,6 @@ def inference_epoch_end(self, outputs, mode, data_cfg):
         # Merge the functionality of previous on_inference_epoch_end() within inference_epoch_end() func here
         app_state = AppState()
         self._restore_activation_checkpointing_args()
-        self._restore_sequence_parallelism_args()
         if hasattr(self, "_train_ds"):
             _reconfigure_microbatch_calculator(
                 rank=app_state.global_rank,
@@ -816,7 +819,6 @@ def setup_eval_dataloader(self, datasets, data_cfg):
 
     def on_validation_epoch_start(self):
         self._reset_activation_checkpointing_args()
-        self._reset_sequence_parallelism_args()
         app_state = AppState()
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
@@ -829,7 +831,6 @@ def on_validation_epoch_start(self):
 
     def on_test_epoch_start(self):
        self._reset_activation_checkpointing_args()
-        self._reset_sequence_parallelism_args()
         app_state = AppState()
         _reconfigure_microbatch_calculator(
             rank=app_state.global_rank,
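For context (not part of the commit), here is a minimal sketch of the nsys_profile config block that the new lines in __init__ read, using the same .get() fallback pattern. Only start_step and end_step are touched by this hunk; the enabled and ranks keys are assumptions carried over from the profiling hooks in modelPT.py further down.

from omegaconf import OmegaConf

# Hypothetical model config fragment; key names other than start_step/end_step are assumed.
cfg = OmegaConf.create(
    {
        "nsys_profile": {
            "enabled": True,   # assumed gate checked during profiling setup
            "start_step": 10,  # global-batch step at which cudaProfilerStart() fires
            "end_step": 10,    # global-batch step at which cudaProfilerStop() fires
            "ranks": [0],      # assumed list of ranks that emit a profile
        }
    }
)

# Same lookup pattern as the patched __init__: missing keys fall back to 0.
start_step = cfg.nsys_profile.get("start_step", 0)
end_step = cfg.nsys_profile.get("end_step", 0)
print(start_step, end_step)  # -> 10 10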
nemo/collections/nlp/modules/common/text_generation_utils.py (3 changes: 0 additions & 3 deletions)
@@ -730,9 +730,6 @@ def sample_sequence_batch(
         micro_batch_size=micro_batch_size,
         data_parallel_size=1,
     )
-    assert (
-        model.cfg.get('sequence_parallel', False) == False
-    ), 'sequence_parallel should be False during inference. Disable it in the model config if restoring from nemo or in hparams.yaml if restoring from PTL checkpoint'
     assert (
         model.cfg.get('activations_checkpoint_granularity', None) is None
     ), 'activations_checkpoint_granularity should be None during inference. Disable it in the model config if restoring from nemo or in hparams.yaml if restoring from PTL checkpoint'
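With the sequence_parallel assertion removed, only the activation-checkpointing assertion above still guards inference. As a hedged usage sketch (not from the commit), one way to clear that flag on a restored model's config before calling text generation; the helper name is illustrative and activations_checkpoint_method is an assumed companion key.

from omegaconf import DictConfig, open_dict

def clear_inference_incompatible_flags(cfg: DictConfig) -> None:
    # Clear the flag that sample_sequence_batch still asserts against.
    with open_dict(cfg):  # allow overwriting keys on a struct config
        cfg.activations_checkpoint_granularity = None
        cfg.activations_checkpoint_method = None  # assumed companion key; drop if absent from your config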
nemo/core/classes/modelPT.py (8 changes: 6 additions & 2 deletions)
@@ -190,6 +190,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         # Setup nsys profiling if it has been enabled in the model config
         self._setup_nsys_profiling()
 
+        # A flag for the profile generation
+        self._profile_complete = False
+
     def __init_subclass__(cls) -> None:
         cls._save_restore_connector = SaveRestoreConnector()
 
@@ -1720,7 +1723,7 @@ def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> O
         # nsys profiling
         if self.device.type == 'cuda':
             if hasattr(self, '_nsys_profile_enabled'):
-                if self._nsys_profile_enabled:
+                if self._nsys_profile_enabled and not self._profile_complete:
                     if batch_idx == self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks:
                         logging.info("====== Start nsys profiling ======")
                         torch.cuda.cudart().cudaProfilerStart()
@@ -1757,10 +1760,11 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int =
 
         if self.device.type == 'cuda':
             if hasattr(self, '_nsys_profile_enabled'):
-                if self._nsys_profile_enabled:
+                if self._nsys_profile_enabled and not self._profile_complete:
                     if batch_idx == self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks:
                         logging.info("====== End nsys profiling ======")
                         torch.cuda.cudart().cudaProfilerStop()
+                        self._profile_complete = True
 
     def _cleanup_on_execution_end(self):
         """
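Taken together, the modelPT.py changes turn the nsys capture into a one-shot window: once cudaProfilerStop() has fired, _profile_complete keeps a later pass where batch_idx matches again (for example a new epoch counting from zero) from re-opening the capture. Below is a standalone sketch of that guard pattern, with hypothetical class and method names and the rank check omitted.

import torch

class OneShotNsysWindow:
    # Mirrors the _profile_complete guard added to ModelPT; names here are illustrative only.
    def __init__(self, start_step: int, end_step: int):
        self.start_step = start_step
        self.end_step = end_step
        self.profile_complete = False

    def on_step_start(self, step: int) -> None:
        if not self.profile_complete and step == self.start_step:
            torch.cuda.cudart().cudaProfilerStart()  # begin the nsys capture region

    def on_step_end(self, step: int) -> None:
        if not self.profile_complete and step == self.end_step:
            torch.cuda.cudart().cudaProfilerStop()   # end the capture region
            self.profile_complete = True             # never re-open the window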
