diff --git a/trinity/algorithm/advantage_fn/opmd_advantage.py b/trinity/algorithm/advantage_fn/opmd_advantage.py index 2707e11e0b..82bea6c90d 100644 --- a/trinity/algorithm/advantage_fn/opmd_advantage.py +++ b/trinity/algorithm/advantage_fn/opmd_advantage.py @@ -119,12 +119,13 @@ def calculate_group_advantage( self, group_id: str, exps: List[Experience] ) -> Tuple[List[Experience], Dict]: with torch.no_grad(): + group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) + reward_mean = torch.mean(group_rewards) if len(exps) == 1: group_baseline = torch.tensor(0.0) else: - group_rewards = torch.tensor([exp.reward for exp in exps], dtype=torch.float32) if self.opmd_baseline == "mean": - group_baseline = torch.mean(group_rewards) + group_baseline = reward_mean else: group_baseline = self.tau * ( torch.logsumexp(group_rewards / self.tau, dim=-1) @@ -136,7 +137,7 @@ def calculate_group_advantage( exp.returns = exp.advantages.clone() metrics = { "group_baseline": group_baseline.item(), - "reward_mean": torch.mean(group_rewards).item(), + "reward_mean": reward_mean.item(), } return exps, metrics