
Commit

Bring back some of local-rank -1
muellerzr committed May 23, 2023
1 parent d497f8d commit 22b4e24
Showing 1 changed file with 4 additions and 4 deletions.
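
In TrainingArguments, local_rank defaults to -1 and only takes a non-negative value when the process is started by a distributed launcher (for example torchrun, which exports LOCAL_RANK for every worker). The restored self.args.local_rank != -1 checks therefore ask the same question as the ParallelMode.DISTRIBUTED comparison they replace: is this a multi-process run? A minimal sketch of that convention (illustrative only; detect_local_rank is not a Trainer function):

import os

def detect_local_rank() -> int:
    # torchrun and similar launchers export LOCAL_RANK for each worker
    # process; a plain single-process run leaves it unset.
    return int(os.environ.get("LOCAL_RANK", -1))

if detect_local_rank() != -1:
    print("distributed run: one process per device")
else:
    print("single-process run")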
src/transformers/trainer.py (8 changes: 4 additions & 4 deletions)

@@ -2336,7 +2336,7 @@ def _load_rng_state(self, checkpoint):
         np.random.set_state(checkpoint_rng_state["numpy"])
         torch.random.set_rng_state(checkpoint_rng_state["cpu"])
         if torch.cuda.is_available():
-            if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+            if self.args.local_rank != -1:
                 torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
             else:
                 try:
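
The guarded branch restores the CUDA RNG state saved in a checkpoint: in a distributed run every visible device gets its state back via set_rng_state_all, otherwise a single device state is restored. A standalone sketch of that save/restore pattern with public PyTorch APIs (the function names and checkpoint layout here are hypothetical, not the Trainer's):

import torch

def save_rng_state(path, distributed):
    state = {"cpu": torch.random.get_rng_state()}
    if torch.cuda.is_available():
        # One entry per visible GPU in the distributed case.
        state["cuda"] = torch.cuda.get_rng_state_all() if distributed else torch.cuda.get_rng_state()
    torch.save(state, path)

def load_rng_state(path, distributed):
    state = torch.load(path)
    torch.random.set_rng_state(state["cpu"])
    if torch.cuda.is_available():
        if distributed:
            torch.cuda.set_rng_state_all(state["cuda"])
        else:
            torch.cuda.set_rng_state(state["cuda"])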

@@ -2931,7 +2931,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
 
     def store_flos(self):
         # Storing the number of floating-point operations that went into the model
-        if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+        if self.args.local_rank != -1:
             self.state.total_flos += (
                 distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
             )
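
store_flos uses the same distributed check before summing the per-process floating-point-operation counters: each rank's current_flos is gathered and the sum goes into state.total_flos. A rough sketch of that reduction with plain torch.distributed (distributed_broadcast_scalars is the Trainer's own helper; this only shows the underlying idea, and total_flos_across_ranks is hypothetical):

import torch
import torch.distributed as dist

def total_flos_across_ranks(current_flos, device):
    # Assumes torch.distributed has already been initialized.
    # Every rank contributes its local counter; after the all-reduce the
    # sum is identical on all processes.
    flos = torch.tensor(current_flos, dtype=torch.float64, device=device)
    dist.all_reduce(flos, op=dist.ReduceOp.SUM)
    return flos.item()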

@@ -3347,7 +3347,7 @@ def _nested_gather(self, tensors, name=None):
             tensors = nested_xla_mesh_reduce(tensors, name)
         elif is_sagemaker_mp_enabled():
             tensors = smp_gather(tensors)
-        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+        elif self.args.local_rank != -1:
             tensors = distributed_concat(tensors)
         return tensors

@@ -3873,7 +3873,7 @@ def _gather_and_numpify(self, tensors, name):
             tensors = nested_xla_mesh_reduce(tensors, name)
         elif is_sagemaker_mp_enabled():
             tensors = smp_gather(tensors)
-        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+        elif self.args.local_rank != -1:
             tensors = distributed_concat(tensors)
 
         return nested_numpify(tensors)
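
Both _nested_gather and _gather_and_numpify fall through to distributed_concat when the run is distributed, so evaluation tensors produced on every rank are collected and concatenated before metrics are computed. A simplified sketch of that gather-and-concatenate step for a single tensor, assuming equal per-rank shapes (gather_and_concat is illustrative; the real helper also walks nested lists and tuples):

import torch
import torch.distributed as dist

def gather_and_concat(tensor):
    # Assumes torch.distributed has already been initialized.
    # Each process receives a copy of every other process's tensor, and the
    # copies are concatenated along the first (batch) dimension.
    buffers = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(buffers, tensor)
    return torch.cat(buffers, dim=0)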