Commit

Interface to provide custom userbuffer communicator settings by yaml file
erhoo82 committed Apr 6, 2023
1 parent c77d97c commit 0ea2930
Showing 2 changed files with 20 additions and 1 deletion.
examples/nlp/language_modeling/conf/megatron_gpt_config.yaml (4 additions, 0 deletions)
@@ -168,6 +168,10 @@ model:
ub_tp_comm_overlap: False
# Use userbuffer backend to overlap tensor-parallel communications with compute.
# This feature is only available with Transformer Engine and sequence parallelism enabled and, currently, supports only GPT models.
ub_tp_comm_overlap_cfg: null
# A yaml file with userbuffer communicator configurations. This file should provide `method`, `dtype`, `num_sm`, `num_splits`,
# `cga_size`, `set_sm_margin`, and `aggregate` for the communicators to use custom settings.
# If the configuration file is not provided, a default setting is used for all communicators.

data:
# Path to data must be specified by the user.
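For illustration, a minimal sketch of what a user-provided `ub_tp_comm_overlap_cfg` file could look like, built only from the fields listed in the comment above. The communicator names (`qkv_fprop`, `proj_dgrad`) and all values are assumptions for the example, not settings taken from this commit, and which fields are honored may depend on the chosen `method`:

    # Hypothetical userbuffer communicator settings; names and values are examples only.
    qkv_fprop:
      method: ring_exchange
      dtype: bf16
      num_sm: 4
      num_splits: 4
      cga_size: 2
      set_sm_margin: 0
      aggregate: 0
    proj_dgrad:
      method: pipeline
      dtype: bf16
      num_sm: 16
      num_splits: 4
      cga_size: 2
      set_sm_margin: 1
      aggregate: 0

Any communicator not listed in the file would keep the default setting, per the comment in the config above.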
@@ -337,7 +337,22 @@ def training_step(self, dataloader_iter, batch_idx):
    self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
    self.cfg.get('hidden_size'),
]
# Removed (replaced by initialize_ub below): te_module.pre_init_ub(shape=input_shape, is_fp8=self.cfg.get('fp8'))
ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
# Fall back to the default communicator settings unless a valid config file is provided.
ub_cfgs = None
if ub_cfg_file_name is not None:
    try:
        import yaml

        with open(ub_cfg_file_name, 'r') as ub_cfg_file:
            ub_cfgs = yaml.safe_load(ub_cfg_file)
    except (ImportError, TypeError, OSError):
        print("Failed to read the ub_tp_comm_overlap config file; using default communicator settings.")
te_module.initialize_ub(
    shape=input_shape,
    tp_size=self.cfg.get('tensor_model_parallel_size'),
    use_fp8=self.cfg.get('fp8'),
    ub_cfgs=ub_cfgs,
)
self.initialize_ub = False

# we zero grads here because we also call backward in the apex fwd/bwd functions
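Usage note: with this change, the feature would be driven by the two model config keys shown in the YAML diff above, for example (the config file path is a placeholder):

    model:
      ub_tp_comm_overlap: True
      ub_tp_comm_overlap_cfg: /path/to/ub_comm_overlap.yaml  # placeholder path; leave as null to use defaults

When `ub_tp_comm_overlap_cfg` is left as null, `ub_cfgs` stays None and `initialize_ub` uses the default settings for all communicators.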
