Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

demo(nyz): slime volleyball league training #229

Merged
merged 7 commits into from
Mar 19, 2022
Merged
5 changes: 4 additions & 1 deletion ding/entry/application_entry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union, Optional, List, Any, Tuple
import pickle
import numpy as np
import torch
from functools import partial
import os
Expand Down Expand Up @@ -70,7 +71,9 @@ def eval(
evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode)

# Evaluate
_, eval_reward = evaluator.eval()
_, episode_info = evaluator.eval()
reward = [e['final_eval_reward'] for e in episode_info]
eval_reward = np.mean(reward)
print('Eval is over! The performance of your RL policy is {}'.format(eval_reward))
return eval_reward

Expand Down
19 changes: 18 additions & 1 deletion ding/league/base_league.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import Union, Dict
import uuid
import copy
import os
import os.path as osp
from abc import abstractmethod
from easydict import EasyDict
import os.path as osp
from tabulate import tabulate

from ding.league.player import ActivePlayer, HistoricalPlayer, create_player
from ding.league.shared_payoff import create_payoff
Expand Down Expand Up @@ -268,6 +270,21 @@ def save_checkpoint(src_checkpoint, dst_checkpoint) -> None:
checkpoint = read_file(src_checkpoint)
save_file(dst_checkpoint, checkpoint)

def player_rank(self, string: bool = False) -> Union[str, Dict[str, float]]:
rank = {}
for p in self.active_players + self.historical_players:
name = p.player_id
rank[name] = p.rating.exposure
if string:
headers = ["Player ID", "Rank (TrueSkill)"]
data = []
for k, v in rank.items():
data.append([k, "{:.2f}".format(v)])
s = "\n" + tabulate(data, headers=headers, tablefmt='pipe')
return s
else:
return rank


def create_league(cfg: EasyDict, *args) -> BaseLeague:
"""
Expand Down
27 changes: 17 additions & 10 deletions ding/league/shared_payoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __repr__(self) -> str:
naive_win_rate = (v['wins'] + v['draws'] / 2) / (v['wins'] + v['losses'] + v['draws'] + 1e-8)
data.append([k1[0], k1[1], v['wins'], v['draws'], v['losses'], naive_win_rate])
data = sorted(data, key=lambda x: x[0])
s = tabulate(data, headers=headers, tablefmt='grid')
s = tabulate(data, headers=headers, tablefmt='pipe')
return s

def __getitem__(self, players: tuple) -> np.ndarray:
Expand Down Expand Up @@ -202,15 +202,22 @@ def _win_loss_reverse(result_: str, reverse_: bool) -> str:
except Exception as e:
print("[ERROR] invalid job_info: {}\n\tError reason is: {}".format(job_info, e))
return False
key, reverse = self.get_key(home_id, away_id)
# Update with decay
for j in job_info_result:
for i in j:
# All categories should decay
self._data[key] *= self._decay
self._data[key]['games'] += 1
result = _win_loss_reverse(i, reverse)
self._data[key][result] += 1
if home_id == away_id: # self-play
key, reverse = self.get_key(home_id, away_id)
self._data[key]['draws'] += 1 # self-play defaults to draws
self._data[key]['games'] += 1
else:
key, reverse = self.get_key(home_id, away_id)
# Update with decay
# job_info_result is a two-layer list, including total NxM episodes of M envs,
# the first(outer) layer is episode dimension and the second(inner) layer is env dimension.
for one_episode_result in job_info_result:
for one_episode_result_per_env in one_episode_result:
# All categories should decay
self._data[key] *= self._decay
self._data[key]['games'] += 1
result = _win_loss_reverse(one_episode_result_per_env, reverse)
self._data[key][result] += 1
return True

def get_key(self, home: str, away: str) -> Tuple[str, bool]:
Expand Down
53 changes: 45 additions & 8 deletions ding/league/tests/test_one_vs_one_league.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random

import pytest
import copy
from easydict import EasyDict
import torch

Expand Down Expand Up @@ -52,6 +53,16 @@
one_vs_one_league_default_config = EasyDict(one_vs_one_league_default_config)


def get_random_result():
ran = random.random()
if ran < 1. / 3:
return "wins"
elif ran < 1. / 2:
return "losses"
else:
return "draws"


@pytest.mark.unittest
class TestOneVsOneLeague:

Expand Down Expand Up @@ -109,14 +120,6 @@ def test_naive(self):
assert eval_job_info['eval_opponent'] in league.active_players[0]._eval_opponent_difficulty

# finish_job
def get_random_result():
ran = random.random()
if ran < 1. / 3:
return "wins"
elif ran < 2. / 3:
return "losses"
else:
return "draws"

episode_num = 5
env_num = 8
Expand All @@ -141,6 +144,40 @@ def get_random_result():
os.popen("rm -rf {}".format(path_policy))
print("Finish!")

def test_league_info(self):
cfg = copy.deepcopy(one_vs_one_league_default_config.league)
cfg.path_policy = 'test_league_info'
league = create_league(cfg)
active_player_id = [p.player_id for p in league.active_players][0]
active_player_ckpt = [p.checkpoint_path for p in league.active_players][0]
tmp = torch.tensor([1, 2, 3])
torch.save(tmp, active_player_ckpt)
assert (len(league.active_players) == 1)
assert (len(league.historical_players) == 0)
print('\n')
print(repr(league.payoff))
print(league.player_rank(string=True))
league.judge_snapshot(active_player_id, force=True)
for i in range(10):
job = league.get_job_info(active_player_id, eval_flag=False)
payoff_update_info = {
'launch_player': active_player_id,
'player_id': job['player_id'],
'episode_num': 2,
'env_num': 4,
'result': [[get_random_result() for __ in range(4)] for _ in range(2)]
}
league.finish_job(payoff_update_info)
# if not self-play
if job['player_id'][0] != job['player_id'][1]:
win_loss_result = sum(payoff_update_info['result'], [])
home = league.get_player_by_id(job['player_id'][0])
away = league.get_player_by_id(job['player_id'][1])
home.rating, away.rating = league.metric_env.rate_1vs1(home.rating, away.rating, win_loss_result)
print(repr(league.payoff))
print(league.player_rank(string=True))
os.popen("rm -rf {}".format(cfg.path_policy))


if __name__ == '__main__':
pytest.main(["-sv", os.path.basename(__file__)])
2 changes: 2 additions & 0 deletions ding/league/tests/test_payoff.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def test_update(self, setup_battle_shared_payoff, random_job_result, get_job_res

for home in player_list:
for away in player_list:
if home == away:
continue # ignore self-play case
for i in range(games_per_player):
episode_num = 2
env_num = 4
Expand Down
3 changes: 2 additions & 1 deletion ding/policy/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
value *= self._running_mean_std.std
next_value *= self._running_mean_std.std

compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], data['traj_flag'])
traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], traj_flag)
data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)

unnormalized_returns = value + data['adv']
Expand Down
15 changes: 0 additions & 15 deletions ding/worker/collector/battle_episode_serial_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,6 @@ def collect(self,
for output in policy_output:
actions[env_id].append(output[env_id]['action'])
actions = to_ndarray(actions)
# temporally for viz
probs0 = torch.softmax(torch.stack([o['logit'] for o in policy_output[0].values()], 0), 1).mean(0)
probs1 = torch.softmax(torch.stack([o['logit'] for o in policy_output[1].values()], 0), 1).mean(0)
timesteps = self._env.step(actions)

# TODO(nyz) this duration may be inaccurate in async env
Expand Down Expand Up @@ -281,8 +278,6 @@ def collect(self,
'reward1': timestep.info[1]['final_eval_reward'],
'time': self._env_info[env_id]['time'],
'step': self._env_info[env_id]['step'],
'probs0': probs0,
'probs1': probs1,
}
collected_episode += 1
self._episode_info.append(info)
Expand Down Expand Up @@ -313,8 +308,6 @@ def _output_log(self, train_iter: int) -> None:
duration = sum([d['time'] for d in self._episode_info])
episode_reward0 = [d['reward0'] for d in self._episode_info]
episode_reward1 = [d['reward1'] for d in self._episode_info]
probs0 = [d['probs0'] for d in self._episode_info]
probs1 = [d['probs1'] for d in self._episode_info]
self._total_duration += duration
info = {
'episode_count': episode_count,
Expand All @@ -335,14 +328,6 @@ def _output_log(self, train_iter: int) -> None:
'total_episode_count': self._total_episode_count,
'total_duration': self._total_duration,
}
info.update(
{
'probs0_select_action0': sum([p[0] for p in probs0]) / len(probs0),
'probs0_select_action1': sum([p[1] for p in probs0]) / len(probs0),
'probs1_select_action0': sum([p[0] for p in probs1]) / len(probs1),
'probs1_select_action1': sum([p[1] for p in probs1]) / len(probs1),
}
)
self._episode_info.clear()
self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
for k, v in info.items():
Expand Down
7 changes: 3 additions & 4 deletions ding/worker/collector/battle_interaction_serial_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def eval(
train_iter: int = -1,
envstep: int = -1,
n_episode: Optional[int] = None
) -> Tuple[bool, float, list]:
) -> Tuple[bool, List[dict]]:
'''
Overview:
Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
Expand All @@ -185,8 +185,7 @@ def eval(
- n_episode (:obj:`int`): Number of evaluation episodes.
Returns:
- stop_flag (:obj:`bool`): Whether this training program can be ended.
- eval_reward (:obj:`float`): Current eval_reward.
- return_info (:obj:`list`): Environment information of each finished episode
- return_info (:obj:`list`): Environment information of each finished episode.
'''
if n_episode is None:
n_episode = self._default_n_episode
Expand Down Expand Up @@ -273,7 +272,7 @@ def eval(
"Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) +
", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
)
return stop_flag, eval_reward, return_info
return stop_flag, return_info


class VectorEvalMonitor(object):
Expand Down
8 changes: 5 additions & 3 deletions ding/worker/collector/interaction_serial_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def eval(
train_iter: int = -1,
envstep: int = -1,
n_episode: Optional[int] = None
) -> Tuple[bool, float]:
) -> Tuple[bool, dict]:
'''
Overview:
Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
Expand All @@ -166,13 +166,14 @@ def eval(
- n_episode (:obj:`int`): Number of evaluation episodes.
Returns:
- stop_flag (:obj:`bool`): Whether this training program can be ended.
- eval_reward (:obj:`float`): Current eval_reward.
- return_info (:obj:`dict`): Current evaluation return information.
'''
if n_episode is None:
n_episode = self._default_n_episode
assert n_episode is not None, "please indicate eval n_episode"
envstep_count = 0
info = {}
return_info = []
eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
self._env.reset()
self._policy.reset()
Expand All @@ -198,6 +199,7 @@ def eval(
if 'episode_info' in t.info:
eval_monitor.update_info(env_id, t.info['episode_info'])
eval_monitor.update_reward(env_id, reward)
return_info.append(t.info)
self._logger.info(
"[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
Expand Down Expand Up @@ -245,4 +247,4 @@ def eval(
"Current eval_reward: {} is greater than stop_value: {}".format(eval_reward, self._stop_value) +
", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
)
return stop_flag, eval_reward
return stop_flag, return_info
Loading