Skip to content

Commit

Permalink
fix serialization of RunningMoments on multiple GPUs (#1892)
Browse files Browse the repository at this point in the history
Co-authored-by: Clara Luise Pohland <clara-luise.pohland@telekom.de>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
  • Loading branch information
3 people authored Aug 4, 2024
1 parent ac7c8b1 commit 2004d62
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions trl/trainer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,9 @@ def load_from_json(cls, accelerator: Accelerator, json_path: str):


@torch.no_grad()
def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]:
def get_global_statistics(
accelerator, xs: torch.Tensor, mask=None, device="cpu"
) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""
Computes element-wise mean and variance of the tensor across processes. Reference:
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75
Expand All @@ -626,7 +628,7 @@ def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu"
sum_var = accelerator.reduce(sum_var)
global_var = sum_var / count

return global_mean.to(device), global_var.to(device), count.to(device)
return global_mean.to(device), global_var.to(device), count.item()


def compute_accuracy(eval_pred) -> Dict[str, float]:
Expand Down

0 comments on commit 2004d62

Please sign in to comment.