diff --git a/training/benchmarks/bert/pytorch/dataloaders/dataset.py b/training/benchmarks/bert/pytorch/dataloaders/dataset.py index f142700e8..2115de493 100644 --- a/training/benchmarks/bert/pytorch/dataloaders/dataset.py +++ b/training/benchmarks/bert/pytorch/dataloaders/dataset.py @@ -1,14 +1,17 @@ -import random - import h5py import numpy as np import os +import sys import torch import torch.distributed as dist import torch.nn.functional as F from torch.utils.data import Dataset +CURR_PATH = os.path.abspath(os.path.dirname(__file__)) +sys.path.append(os.path.abspath(os.path.join(CURR_PATH, "../../../"))) +from driver import dist_pytorch + class PretrainingDataset(Dataset): @@ -60,10 +63,14 @@ def exchange_padding_fast(device, max_batch_size, input_ids, segment_ids, input_mask = F.pad(input_mask, (0, 0, 0, pad_size)) masked_lm_labels = F.pad(masked_lm_labels, (0, 0, 0, pad_size)) next_sentence_labels = F.pad(next_sentence_labels, (0, pad_size)) - ngpus = torch.distributed.get_world_size() + + ngpus = 1 + igpu = 0 + if dist_pytorch.is_dist_avail_and_initialized(): + ngpus = dist_pytorch.get_world_size() + igpu = dist_pytorch.get_rank() nseqs = input_mask.shape[0] ntokensperseq = input_mask.shape[1] - igpu = torch.distributed.get_rank() flattened_length_seq = nseqs * ntokensperseq flattened_length_nsp = nseqs @@ -141,8 +148,11 @@ def decode_packet(flat_packet): dtype=torch.float16) tensors_ = list(torch.split(tensors_, 1)) # Address valueError: ProcessGroupGloo::allgather: invalid tensor size at index 0 (expected (2049), got (1, 2049)) - torch.distributed.all_gather(tensors_, - tensors.view(torch.float16).unsqueeze(0)) + + if dist_pytorch.is_dist_avail_and_initialized(): + dist.all_gather(tensors_, tensors.view(torch.float16).unsqueeze(0)) + else: + tensors_ = tuple(tensors.view(torch.float16).unsqueeze(0)) tensors_ = torch.stack(tensors_).view(torch.int16).long() input_ids_, segment_ids_, input_mask_, masked_lm_labels_, next_sentence_labels_ = decode_packet( diff --git a/training/benchmarks/bert/pytorch/run_pretraining.py b/training/benchmarks/bert/pytorch/run_pretraining.py index ad346b36f..708c4eaac 100644 --- a/training/benchmarks/bert/pytorch/run_pretraining.py +++ b/training/benchmarks/bert/pytorch/run_pretraining.py @@ -42,6 +42,7 @@ def main(): bert_driver = Driver(config, config.mutable_params) bert_driver.setup_config(argparse.ArgumentParser("Bert")) bert_driver.setup_modules(driver, globals(), locals()) + config.distributed = dist_pytorch.get_world_size() > 1 logger = bert_driver.logger