Tensor-parallel communication overlap with userbuffer backend #6362
Changes from 5 commits: 9c10e08, c77d97c, 0ea2930, 1f1d7e3, a68f120, 67de1bb
@@ -108,6 +108,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True):
             micro_batch_size=cfg.get('micro_batch_size'),
             global_batch_size=cfg.get('global_batch_size'),
             use_fp8=cfg.get('fp8', False),
+            init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False),
             seed=self.cfg.get('seed', 1234),
             apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
         )
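For context, `init_mpi_proc_group` gates MPI bootstrap on the same `ub_tp_comm_overlap` flag. A minimal sketch of what a consumer of that flag might do, assuming it only needs an MPI communicator available for the userbuffer bootstrap (the helper name is hypothetical, not NeMo/Apex API):

```python
# Hypothetical helper: only bring up an MPI communicator when userbuffer
# tensor-parallel overlap is requested. Not part of NeMo or Apex.
def _maybe_init_mpi_proc_group(init_mpi_proc_group: bool):
    if not init_mpi_proc_group:
        return None
    from mpi4py import MPI  # importing mpi4py.MPI initializes MPI by default

    return MPI.COMM_WORLD
```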
@@ -512,6 +513,14 @@ def _validate_and_override_config(self):
                 'Make sure the number of model chunks is the same across all pipeline stages.'
             )

+        if self.cfg.get('ub_tp_comm_overlap', False):
+            if not self.cfg.get('transformer_engine', False) or not self.cfg.get('sequence_parallel', False):
+                logging.info(
+                    "Userbuffer tensor-parallel communication overlap is available with both Transformer Engine and sequence-parallelism."
+                )
+                with open_dict(self.cfg):
+                    self.cfg.ub_tp_comm_overlap = False
+
     def is_data_parallel_rank_zero(self):
         if is_global_rank_zero():
             return True

Review comment on lines +518 to +520: this should be
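In practice, the validation hunk above means the overlap path is only kept when Transformer Engine and sequence parallelism are both enabled; otherwise the flag is silently reset to False. A hedged sketch of the relevant model-config keys (values are illustrative, and the keys mirror the cfg lookups in the diff):

```python
from omegaconf import OmegaConf

# Illustrative values only; these keys mirror the cfg.get(...) calls in the diff above.
cfg = OmegaConf.create(
    {
        'transformer_engine': True,
        'sequence_parallel': True,  # sequence parallelism itself needs tensor parallelism
        'tensor_model_parallel_size': 2,
        'fp8': False,
        'ub_tp_comm_overlap': True,
        'ub_tp_comm_overlap_cfg': None,  # optional path to a YAML with userbuffer settings
    }
)

# With either prerequisite missing, _validate_and_override_config flips the flag off.
assert cfg.transformer_engine and cfg.sequence_parallel
```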
@@ -70,6 +70,7 @@

 try:
     import transformer_engine
+    from transformer_engine.pytorch import module as te_module

     HAVE_TE = True
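The new import sits inside the existing try/except that sets `HAVE_TE`, so downstream code can fail fast when Transformer Engine is absent. A minimal sketch of such a guard (the error message is illustrative, not taken from the PR):

```python
# Sketch: check the HAVE_TE flag set by the try/except above before using te_module.
if self.cfg.get('ub_tp_comm_overlap', False) and not HAVE_TE:
    raise RuntimeError(
        "ub_tp_comm_overlap=True requires Transformer Engine, which could not be imported."
    )
```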
@@ -158,6 +159,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             self._nsys_profile_end_step *= grad_accum_steps

         self.get_attention_mask_from_fusion = self.cfg.get('get_attention_mask_from_fusion', True)
+        self.initialize_ub = self.cfg.get('ub_tp_comm_overlap', False)

     def set_inference_config(self, inference_config):
         self._inference_config = inference_config
@@ -224,6 +226,7 @@ def model_provider_func(self, pre_process, post_process):
             fp8_amax_compute_algo=self.cfg.get('fp8_amax_compute_algo', 'most_recent'),
             reduce_amax=self.cfg.get('reduce_amax', True),
             use_emha=self.cfg.get('use_emha', False),
+            ub_tp_comm_overlap=self.cfg.get('ub_tp_comm_overlap', False),
         )

         return model
@@ -327,6 +330,31 @@ def training_step(self, dataloader_iter, batch_idx):
             The input batch to each micro-batch is fetched using the dataloader function
             in the micro-batch fwd function.
         """
+        # Initialize userbuffer communicators. Initialization is done only once at the
+        # beginning of the first training step.
+        if self.initialize_ub:
+            input_shape = [
+                self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
+                self.cfg.get('hidden_size'),
+            ]
+            ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
+            if ub_cfg_file_name is not None:
+                try:
+                    import yaml
+
+                    with open(ub_cfg_file_name, 'r') as ub_cfg_file:
+                        ub_cfgs = yaml.safe_load(ub_cfg_file)
+                except (ImportError, TypeError):
+                    print("Fail to read ub_tp_comm_overlap config file.")
+            else:
+                ub_cfgs = None
+            te_module.initialize_ub(
+                shape=input_shape,
+                tp_size=self.cfg.get('tensor_model_parallel_size'),
+                use_fp8=self.cfg.get('fp8'),
+                ub_cfgs=ub_cfgs,
+            )
+            self.initialize_ub = False

         # we zero grads here because we also call backward in the apex fwd/bwd functions
         self._optimizer.zero_grad()

Review comments on lines +333 to +357:
- Can this go in the …?
- Also, can we make it a private method and then call it?
- To add: I'm not a fan of how we're importing the whole file from TE (which isn't a part of the API).
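One way the suggested refactor could look (a hedged sketch, not the PR's code): move the one-time setup into a private helper and call it from `training_step`. The helper name is made up; the `te_module.initialize_ub` call and the config keys are copied from the hunk above.

```python
def _initialize_ub_communicators(self):
    """One-time setup of Transformer Engine userbuffer communicators (sketch)."""
    input_shape = [
        self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
        self.cfg.get('hidden_size'),
    ]
    ub_cfgs = None
    ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
    if ub_cfg_file_name is not None:
        import yaml

        with open(ub_cfg_file_name, 'r') as ub_cfg_file:
            ub_cfgs = yaml.safe_load(ub_cfg_file)
    te_module.initialize_ub(
        shape=input_shape,
        tp_size=self.cfg.get('tensor_model_parallel_size'),
        use_fp8=self.cfg.get('fp8'),
        ub_cfgs=ub_cfgs,
    )
    self.initialize_ub = False

def training_step(self, dataloader_iter, batch_idx):
    # Communicators are created lazily on the first step, then the flag is cleared.
    if self.initialize_ub:
        self._initialize_ub_communicators()
    ...
```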
Review comment: Why is this a yaml file? Should it just be …?

Reply: The two choices were either a config object or a config file, since there are too many UB-related args.
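For comparison, a hedged sketch of the config-object option being discussed: keep the userbuffer settings inside the model config and convert them to a plain dict before handing them to `initialize_ub`, instead of reading a separate YAML file. The nested section and option names below are placeholders, not Transformer Engine's real keys.

```python
from omegaconf import OmegaConf

# Hypothetical: userbuffer settings embedded in the model config rather than a YAML file.
# 'some_gemm' / 'some_option' are placeholders, not actual userbuffer option names.
cfg = OmegaConf.create(
    {
        'ub_tp_comm_overlap': True,
        'ub_tp_comm_overlap_cfg': {
            'some_gemm': {'some_option': 4},
        },
    }
)

ub_cfgs = None
if cfg.get('ub_tp_comm_overlap_cfg') is not None:
    # Convert the OmegaConf node to the plain dict shape that yaml.safe_load would produce.
    ub_cfgs = OmegaConf.to_container(cfg.ub_tp_comm_overlap_cfg, resolve=True)
```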