Merge pull request #27 from Chen001117/dev
fix "rewards not found" bug
WentseChen authored May 4, 2023
2 parents fa73dfa + 0f612cb commit e64becf
Showing 2 changed files with 6 additions and 21 deletions.
openrl/envs/vec_env/vec_info/simple_vec_info.py (1 addition, 1 deletion)
@@ -17,7 +17,7 @@ def __init__(self, parallel_env_num: int, agent_num: int):

     def statistics(self, buffer: Any) -> Dict[str, Any]:
         # this function should be called each episode
-        rewards = buffer.data.rewardsc.copy()
+        rewards = buffer.data.rewards.copy()
         self.total_step += np.prod(rewards.shape[:2])
         rewards = rewards.transpose(2, 1, 0, 3)
         info_dict = {}
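
The fix in this file is a one-character attribute typo: `buffer.data.rewardsc` does not exist, so `statistics()` would fail at the `.copy()` call. Below is a minimal sketch of the failure mode; the stand-in class and the array shape are assumptions for illustration, not OpenRL's real buffer implementation.

import numpy as np


class FakeBufferData:
    """Hypothetical stand-in for buffer.data; not OpenRL's real class."""

    def __init__(self):
        # Assumed shape: (episode_length, parallel_env_num, agent_num, 1),
        # consistent with the transpose(2, 1, 0, 3) call in statistics().
        self.rewards = np.zeros((128, 4, 2, 1))


data = FakeBufferData()
rewards = data.rewards.copy()  # after the fix: the attribute exists
# data.rewardsc.copy() would raise AttributeError, i.e. rewards "not found"
total_step = int(np.prod(rewards.shape[:2]))  # 128 * 4 = 512 steps counted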
openrl/rewards/base_reward.py (5 additions, 20 deletions)
@@ -12,29 +12,14 @@ def __init__(self):
     def step_reward(
         self, data: Dict[str, Any]
     ) -> Union[np.ndarray, List[Dict[str, Any]]]:
-        rewards = data["reward"].copy()
-        infos = []
-
-        for rew_func in self.step_rew_funcs.values():
-            new_rew, new_info = rew_func(data)
-            if len(infos) == 0:
-                infos = new_info
-            else:
-                for i in range(len(infos)):
-                    infos[i].update(new_info[i])
-            rewards += new_rew
-
+        rewards = data["rewards"].copy()
+        infos = [dict() for _ in range(rewards.shape[0])]
+
         return rewards, infos

     def batch_rewards(self, buffer: Any) -> Dict[str, Any]:

         infos = dict()

-        for rew_func in self.batch_rew_funcs.values():
-            new_rew, new_info = rew_func()
-            if len(infos) == 0:
-                infos = new_info
-            else:
-                infos.update(new_info)
-        # update rewards, and infos here
-
-        return dict()
+        return infos
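
The old `step_reward` read `data["reward"]` (singular) where callers evidently supply `data["rewards"]`, which is presumably the "rewards not found" error in the commit title. After the patch, `step_reward` passes the environment rewards through unchanged and returns one empty info dict per parallel environment, and `batch_rewards` returns its empty `infos` dict instead of `dict()`. A usage sketch under two assumptions: that the class in base_reward.py is importable as `BaseReward`, and that `data["rewards"]` is a numpy array whose first axis indexes parallel environments.

import numpy as np

from openrl.rewards.base_reward import BaseReward  # class name is an assumption

reward = BaseReward()
data = {"rewards": np.ones((4, 2, 1))}  # hypothetical (env, agent, 1) array

rewards, infos = reward.step_reward(data)
print(rewards.shape)  # (4, 2, 1): rewards are returned unchanged
print(infos)          # [{}, {}, {}, {}]: one empty dict per parallel env

print(reward.batch_rewards(buffer=None))  # {}: buffer is currently unused

Note that the reward functions in `self.step_rew_funcs` and `self.batch_rew_funcs` are no longer consulted here; presumably custom reward shaping now lives in subclasses rather than in this base class.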
