diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 3eaddf0ad1..12744c6370 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -611,7 +611,9 @@ def load_from_json(cls, accelerator: Accelerator, json_path: str): @torch.no_grad() -def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]: +def get_global_statistics( + accelerator, xs: torch.Tensor, mask=None, device="cpu" +) -> Tuple[torch.Tensor, torch.Tensor, int]: """ Computes element-wise mean and variance of the tensor across processes. Reference: https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75 @@ -626,7 +628,7 @@ def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu" sum_var = accelerator.reduce(sum_var) global_var = sum_var / count - return global_mean.to(device), global_var.to(device), count.to(device) + return global_mean.to(device), global_var.to(device), count.item() def compute_accuracy(eval_pred) -> Dict[str, float]: