Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor-parallel communication overlap with userbuffer backend #6362

Merged
merged 6 commits into from
Apr 18, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ model:
fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history
reduce_amax: True # Perform reduction to sync amax tensors across GPUs after every iteration
use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False.
ub_tp_comm_overlap: False
# Use userbuffer backend to overlap tensor-parallel communications with computes.
# This feature is only available with Transformer Engine and sequence parallelism enabled and, currently, supports only GPT models.
ub_tp_comm_overlap_cfg: null
# A yaml file with userbuffer communicator configurations. This file should provide `method`, `dtype`, `num_sm`,
# `num_splits`, `cga_size`, `set_sm_margin`, and `aggregate` for the communicators to use custom settings.
# If the configuration file is not provided a default setting is used for all communicators.
Comment on lines +171 to +174
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this a yaml file? Should it just be

ub_tp_comm_overlap_cfg:
  method: blah
  dtype: blah
  ...

?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two choices were either a config object or a config file, since there are too many UB-related args


data:
# Path to data must be specified by the user.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
):
super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

Expand Down Expand Up @@ -228,6 +229,7 @@ def __init__(
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)

if self.share_embeddings_and_output_weights:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
micro_batch_size=cfg.get('micro_batch_size'),
global_batch_size=cfg.get('global_batch_size'),
use_fp8=cfg.get('fp8', False),
init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
seed=self.cfg.get('seed', 1234),
apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
)
Expand Down Expand Up @@ -512,6 +513,14 @@ def _validate_and_override_config(self):
'Make sure the number of model chunks is the same across all pipeline stages.'
)

if self.cfg.get('ub_tp_comm_overlap', False):
if not self.cfg.get('transformer_engine', False) or not self.cfg.get('sequence_parallel', False):
logging.info(
"Userbuffer tensor-parallel communication overlap is available with both Transformer Engine and sequence-parallelism."
)
Comment on lines +518 to +520
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be .warning and maybe add ".. only available with ... Setting ub_tp_comm_overlap to True"

with open_dict(self.cfg):
self.cfg.ub_tp_comm_overlap = False

def is_data_parallel_rank_zero(self):
if is_global_rank_zero():
return True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@

try:
import transformer_engine
from transformer_engine.pytorch import module as te_module

HAVE_TE = True

Expand Down Expand Up @@ -158,6 +159,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self._nsys_profile_end_step *= grad_accum_steps

self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)

def set_inference_config(self, inference_config):
self._inference_config = inference_config
Expand Down Expand Up @@ -224,6 +226,7 @@ def model_provider_func(self, pre_process, post_process):
fp8_amax_compute_algo=self.cfg.get('fp8_amax_compute_algo', 'most_recent'),
reduce_amax=self.cfg.get('reduce_amax', True),
use_emha=self.cfg.get('use_emha', False),
ub_tp_comm_overlap=self.cfg.get('ub_tp_comm_overlap', False),
)

return model
Expand Down Expand Up @@ -327,6 +330,31 @@ def training_step(self, dataloader_iter, batch_idx):
The input batch to each micro-batch is fetched using the dataloader function
in the micro-batch fwd function.
"""
# Initialize userbuffer communicators. Initialization is done only once at the
# beginning of the first training step.
if self.initialize_ub:
input_shape = [
self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
self.cfg.get('hidden_size'),
]
ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
if ub_cfg_file_name is not None:
try:
import yaml

with open(ub_cfg_file_name, 'r') as ub_cfg_file:
ub_cfgs = yaml.safe_load(ub_cfg_file)
except (ImportError, TypeError):
print("Fail to read ub_tp_comm_overlap config file.")
else:
ub_cfgs = None
te_module.initialize_ub(
shape=input_shape,
tp_size=self.cfg.get('tensor_model_parallel_size'),
use_fp8=self.cfg.get('fp8'),
ub_cfgs=ub_cfgs,
)
self.initialize_ub = False
Comment on lines +333 to +357
Copy link
Collaborator

@ericharper ericharper Apr 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this go in the .setup method then? (since it is only called once at beginning of training)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, can we make it a private method and then call the it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To add; I'm not a fan of how we're importing the whole file from TE (which isn't a part of the API)


# we zero grads here because we also call backward in the apex fwd/bwd functions
self._optimizer.zero_grad()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def get_language_model(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
):
"""Build language model and return along with the key to save."""

Expand Down Expand Up @@ -173,6 +174,7 @@ def get_language_model(
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)
# key used for checkpoints.
language_model_key = 'language_model'
Expand Down Expand Up @@ -472,6 +474,7 @@ def __init__(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
):
super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights)

Expand Down Expand Up @@ -573,6 +576,7 @@ def __init__(
fp8_amax_compute_algo=fp8_amax_compute_algo,
reduce_amax=reduce_amax,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)
self._encoder_key = 'encoder'

Expand Down
2 changes: 2 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def initialize_model_parallel_for_nemo(
micro_batch_size=None,
global_batch_size=None,
use_fp8=False,
init_mpi_proc_group=False,
seed=1234,
apex_transformer_log_level=30,
):
Expand All @@ -76,6 +77,7 @@ def initialize_model_parallel_for_nemo(
app_state.pipeline_model_parallel_size = pipeline_model_parallel_size
app_state.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size
app_state.use_fp8 = use_fp8
app_state.init_mpi_proc_group = init_mpi_proc_group
(
app_state.tensor_model_parallel_rank,
app_state.pipeline_model_parallel_rank,
Expand Down
4 changes: 4 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,7 @@ def __init__(
layer_type: str = "encoder",
drop_path_rate: float = 0,
use_emha: bool = False,
ub_tp_comm_overlap: bool = False,
autocast_dtype: Any = 16,
zero_centered_gamma: bool = False,
) -> None:
Expand Down Expand Up @@ -811,6 +812,7 @@ def __init__(
set_parallel_mode=tp_size > 1,
fuse_qkv_params=True,
zero_centered_gamma=zero_centered_gamma,
ub_tp_comm_overlap=ub_tp_comm_overlap,
)
# use_emha=use_emha,

Expand Down Expand Up @@ -911,6 +913,7 @@ def __init__(
fp8_amax_compute_algo='most_recent',
reduce_amax=True,
use_emha=False,
ub_tp_comm_overlap=False,
normalize_attention_scores=True,
multi_query_attention=False,
num_moe_experts=1,
Expand Down Expand Up @@ -1050,6 +1053,7 @@ def build_layer(layer_number):
apply_residual_connection_post_layernorm=False,
autocast_dtype=precision,
use_emha=use_emha,
ub_tp_comm_overlap=ub_tp_comm_overlap,
zero_centered_gamma=normalization == 'layernorm1p',
)
else:
Expand Down
1 change: 1 addition & 0 deletions nemo/collections/nlp/parts/nlp_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def init_model_parallel(self, global_rank: int, world_size: int) -> None:
pipeline_model_parallel_split_rank_=app_state.pipeline_model_parallel_split_rank,
virtual_pipeline_model_parallel_size_=app_state.virtual_pipeline_model_parallel_size,
use_fp8_=app_state.use_fp8,
init_mpi_proc_group=app_state.init_mpi_proc_group,
)

# assert that fake tp and pp rank match after model parallel init
Expand Down
17 changes: 17 additions & 0 deletions nemo/utils/app_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(self):
self._data_parallel_group = None
self._megatron_checkpoint_version = None
self._use_fp8 = False
self._init_mpi_proc_gruop = False

self._random_seed = None

Expand Down Expand Up @@ -363,6 +364,22 @@ def use_fp8(self, use_fp8):
"""
self._use_fp8 = use_fp8

@property
def init_mpi_proc_group(self):
    """ Property returns whether the MPI process group should be initialized.

    Set when userbuffer tensor-parallel communication overlap is enabled
    (see `initialize_model_parallel_for_nemo`).

    Returns:
        Whether to initialize the MPI process group.
    """
    # NOTE(review): __init__ assigns `self._init_mpi_proc_gruop` (misspelled) rather
    # than `self._init_mpi_proc_group`, so reading this property before the setter
    # has run raises AttributeError — the spelling in __init__ should be fixed.
    return self._init_mpi_proc_group

@init_mpi_proc_group.setter
def init_mpi_proc_group(self, init_mpi_proc_group):
    """ Property sets whether the MPI process group should be initialized.

    Args:
        init_mpi_proc_group: Whether to initialize the MPI process group.
    """
    self._init_mpi_proc_group = init_mpi_proc_group

@property
def random_seed(self):
""" Property returns the random seed.
Expand Down