
Commit

Merge pull request #21 from baai-open-internal/fix_multi_gpu_training
fix bug multi_gpu_training
Anhforth authored Jul 8, 2022
2 parents efc1310 + 9b81869 commit 7ad38a0
Showing 2 changed files with 14 additions and 10 deletions.
flagai/model/glm_model.py — 2 changes: 1 addition & 1 deletion

@@ -462,7 +462,7 @@ def forward(self,
         else:

            loss = F.cross_entropy(
-                logits_parallel.contiguous().float(), labels.long())
+                logits_parallel.reshape(-1, logits_parallel.shape[-1]).contiguous().float(), labels.reshape(-1).long())

        if self.parallel_output:  # Put in different GPUs
            return {
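For readers unfamiliar with the loss change above: F.cross_entropy expects 2-D logits of shape (N, C) with 1-D class targets, so per-token language-model logits of shape (batch, seq_len, vocab) have to be flattened before the call. A minimal standalone sketch, using made-up shapes rather than anything from the FlagAI code:

    import torch
    import torch.nn.functional as F

    # Illustrative shapes only: 2 sequences, 5 tokens each, vocabulary of 100.
    logits_parallel = torch.randn(2, 5, 100)      # (batch, seq_len, vocab)
    labels = torch.randint(0, 100, (2, 5))        # (batch, seq_len)

    # Flatten (batch, seq_len) into one token axis so logits become (N, C) and
    # labels become (N,), which is what F.cross_entropy expects; this mirrors
    # the reshape added in the diff above.
    loss = F.cross_entropy(
        logits_parallel.reshape(-1, logits_parallel.shape[-1]).contiguous().float(),
        labels.reshape(-1).long())
    print(loss.item())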
flagai/trainer.py — 22 changes: 13 additions & 9 deletions

@@ -309,12 +309,17 @@ def get_dataloader(self, dataset, collate_fn, shuffle=False):
                 shuffle=shuffle)
         else:
             if self.env_type == 'deepspeed+mpu':
-                num_replicas = self.world_size // mpu.get_model_parallel_world_size(
-                )
-                rank = self.rank // mpu.get_model_parallel_world_size()
+                # num_replicas = self.world_size // mpu.get_model_parallel_world_size(
+                # )
+                # rank = self.rank // mpu.get_model_parallel_world_size()
+                # rank = mpu.get_model_parallel_rank()
+                rank = mpu.get_model_parallel_src_rank()
+                print("*"*80)
+                print("local rank",self.rank, "model rank", rank)
+                print("*"*80)
                 sampler = torch.utils.data.distributed.DistributedSampler(
                     dataset,
-                    num_replicas=num_replicas,
+                    # num_replicas=num_replicas,
                     rank=rank,
                     shuffle=shuffle)
             else:
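As background on the sampler change: DistributedSampler splits a dataset into num_replicas shards and gives each process the shard matching its rank, so processes in the same model-parallel group must be handed the same rank to iterate over identical batches. A small self-contained sketch with a hard-coded num_replicas and rank (no real process group or mpu calls, purely illustrative):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    # Toy dataset of 8 samples; num_replicas/rank are hard-coded for illustration.
    dataset = TensorDataset(torch.arange(8))

    # Every process that passes the same rank receives the same shard of indices,
    # which is why model-parallel workers sharing one replica must agree on rank.
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    print(list(sampler))  # indices assigned to replica 0: [0, 2, 4, 6]

Note that when num_replicas is left unset (as in the commented-out keyword above), DistributedSampler falls back to the world size of the default process group.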
@@ -474,13 +479,12 @@ def train(self,
         for epoch in range(self.epochs):
             # log_dist('working on epoch {} ...'.format(epoch), [0])
             # Set the data loader epoch to shuffle the index iterator.
-            if self.env_type == 'deepspeed+mpu':
-                if mpu.get_model_parallel_rank() == 0:
-                    train_dataloader.sampler.set_epoch(epoch + self.world_size)
-            elif self.env_type != 'pytorch':
+            # if self.env_type == 'deepspeed+mpu':
+            #     if mpu.get_model_parallel_rank() == 0:
+            #         train_dataloader.sampler.set_epoch(epoch + self.world_size)
+            if self.env_type != 'pytorch':
                 train_dataloader.sampler.set_epoch(epoch + self.world_size)
-

             # For all the batches in the dataset.
             for iteration_, batch in enumerate(train_dataloader):
                 # Train for one step.
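The set_epoch change above affects shuffling only: DistributedSampler derives its per-epoch permutation from a seed plus the epoch number, so every replica must call set_epoch with the same value each epoch to get a fresh but rank-consistent ordering. A minimal sketch of that behaviour, again with hard-coded ranks and no process group:

    import torch
    from torch.utils.data import TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(8))
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

    for epoch in range(2):
        # The trainer above passes epoch + self.world_size; any value works as
        # long as every replica uses the same one, otherwise shards can overlap
        # or drop samples.
        sampler.set_epoch(epoch)
        print(epoch, list(sampler))  # a different permutation of rank 0's shard each epoch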
