Skip to content

Commit

Permalink
bert: bugfix for 1x1 training (#160)
Browse files Browse the repository at this point in the history
Co-authored-by: zhouyu <zhouyu@baai.ac.cn>
  • Loading branch information
yuzhou03 and zhouyu authored Jul 21, 2023
1 parent f992673 commit 6be8329
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
22 changes: 16 additions & 6 deletions training/benchmarks/bert/pytorch/dataloaders/dataset.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import random

import h5py
import numpy as np
import os
import sys

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.utils.data import Dataset

# Locate the folder that holds this module, then put the directory three
# levels above it on the import path so the `driver` import below resolves.
CURR_PATH = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(CURR_PATH, "..", "..", "..")))
from driver import dist_pytorch


class PretrainingDataset(Dataset):

Expand Down Expand Up @@ -60,10 +63,14 @@ def exchange_padding_fast(device, max_batch_size, input_ids, segment_ids,
input_mask = F.pad(input_mask, (0, 0, 0, pad_size))
masked_lm_labels = F.pad(masked_lm_labels, (0, 0, 0, pad_size))
next_sentence_labels = F.pad(next_sentence_labels, (0, pad_size))
ngpus = torch.distributed.get_world_size()

ngpus = 1
igpu = 0
if dist_pytorch.is_dist_avail_and_initialized():
ngpus = dist_pytorch.get_world_size()
igpu = dist_pytorch.get_rank()
nseqs = input_mask.shape[0]
ntokensperseq = input_mask.shape[1]
igpu = torch.distributed.get_rank()

flattened_length_seq = nseqs * ntokensperseq
flattened_length_nsp = nseqs
Expand Down Expand Up @@ -141,8 +148,11 @@ def decode_packet(flat_packet):
dtype=torch.float16)
tensors_ = list(torch.split(tensors_, 1))
# Address ValueError: ProcessGroupGloo::allgather: invalid tensor size at index 0 (expected (2049), got (1, 2049))
torch.distributed.all_gather(tensors_,
tensors.view(torch.float16).unsqueeze(0))

if dist_pytorch.is_dist_avail_and_initialized():
dist.all_gather(tensors_, tensors.view(torch.float16).unsqueeze(0))
else:
tensors_ = tuple(tensors.view(torch.float16).unsqueeze(0))

tensors_ = torch.stack(tensors_).view(torch.int16).long()
input_ids_, segment_ids_, input_mask_, masked_lm_labels_, next_sentence_labels_ = decode_packet(
Expand Down
1 change: 1 addition & 0 deletions training/benchmarks/bert/pytorch/run_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def main():
bert_driver = Driver(config, config.mutable_params)
bert_driver.setup_config(argparse.ArgumentParser("Bert"))
bert_driver.setup_modules(driver, globals(), locals())
config.distributed = dist_pytorch.get_world_size() > 1

logger = bert_driver.logger

Expand Down

0 comments on commit 6be8329

Please sign in to comment.