Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to enable/disable act ckpt and seq parallelism in GPT #6327

Merged
merged 14 commits into from
Apr 13, 2023
Merged
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,24 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):

self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)

# Save activations checkpointing and sequence parallelism parameters to be able to restore them later.
self.original_checkpointing_granularity = self.model.language_model.encoder.activations_checkpoint_granularity
self.original_checkpointing_num_layers = self.model.language_model.encoder.activations_checkpoint_num_layers
self.original_checkpointing_method = self.model.language_model.encoder.activations_checkpoint_method
self.original_activations_checkpoint_layers_per_pipeline = (
self.model.language_model.encoder.activations_checkpoint_layers_per_pipeline
)
self.original_sequence_parallel = self.model.language_model.encoder.sequence_parallel
ericharper marked this conversation as resolved.
Show resolved Hide resolved

@property
def model(self):
if isinstance(self._model, list):
return [model.module if isinstance(model, Float16Module) else model for model in self._model]
elif isinstance(self._model, Float16Module):
return self._model.module
else:
return self._model

def set_inference_config(self, inference_config):
self._inference_config = inference_config

Expand Down Expand Up @@ -1071,3 +1089,45 @@ def on_train_batch_end(self, outputs, dataloader_iter: Any, batch_idx: int, unus
# Reset the optimizer update skipped to `None` - this is to prevent scheduler no-ops during
# accumulated gradient updates.
grad_scaler.optimizer_update_skipped = None

def _reset_activation_checkpointing_args(self):
# Reset config values. Needed for calling generate.
self.cfg.activations_checkpoint_granularity = None
self.cfg.activations_checkpoint_method = None
self.cfg.activations_checkpoint_num_layers = None
self.cfg.activations_checkpoint_layers_per_pipeline = None

# Reset model parameters.
self.model.language_model.encoder.activations_checkpoint_granularity = None
self.model.language_model.encoder.activations_checkpoint_method = None
self.model.language_model.encoder.activations_checkpoint_num_layers = None
self.model.language_model.encoder.activations_checkpoint_layers_per_pipeline = None

def _restore_activation_checkpointing_args(self):
# Restore config values.
self.cfg.activations_checkpoint_granularity = self.original_checkpointing_granularity
self.cfg.activations_checkpoint_method = self.original_checkpointing_method
self.cfg.activations_checkpoint_num_layers = self.original_checkpointing_num_layers
self.cfg.activations_checkpoint_layers_per_pipeline = self.original_activations_checkpoint_layers_per_pipeline

# Restore model parameters.
self.model.language_model.encoder.activations_checkpoint_granularity = self.original_checkpointing_granularity
self.model.language_model.encoder.activations_checkpoint_method = self.original_checkpointing_method
self.model.language_model.encoder.activations_checkpoint_num_layers = self.original_checkpointing_num_layers
self.model.language_model.encoder.activations_checkpoint_layers_per_pipeline = (
self.original_activations_checkpoint_layers_per_pipeline
)

def _reset_sequence_parallelism_args(self):
# Reset config values. Needed for calling generate.
self.cfg.sequence_parallel = None

# Reset model parameters.
self.model.language_model.encoder.sequence_parallel = None

def _restore_sequence_parallelism_args(self):
# Restore config values.
self.cfg.sequence_parallel = self.original_sequence_parallel

# Restore model parameters.
self.model.language_model.encoder.sequence_parallel = self.original_sequence_parallel