Formalize variable names (#509)
Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
ChenDRAG and Trinkle23897 authored Jan 29, 2022
1 parent bc53ead commit c25926d
Showing 42 changed files with 607 additions and 581 deletions.
10 changes: 5 additions & 5 deletions docs/tutorials/cheatsheet.rst
@@ -350,7 +350,7 @@ But the state stored in the buffer may be a shallow-copy. To make sure each of y

def reset():
return copy.deepcopy(self.graph)
def step(a):
def step(action):
...
return copy.deepcopy(self.graph), reward, done, {}

@@ -391,13 +391,13 @@ In addition, legal actions in multi-agent RL often vary with timestep (just like
The above description gives rise to the following formulation of multi-agent RL:
::

action = policy(state, agent_id, mask)
(next_state, next_agent_id, next_mask), reward = env.step(action)
act = policy(state, agent_id, mask)
(next_state, next_agent_id, next_mask), reward = env.step(act)

By constructing a new state ``state_ = (state, agent_id, mask)``, essentially we can return to the typical formulation of RL:
::

action = policy(state_)
next_state_, reward = env.step(action)
act = policy(state_)
next_state_, reward = env.step(act)

Following this idea, we write a tiny example of playing `Tic Tac Toe <https://en.wikipedia.org/wiki/Tic-tac-toe>`_ against a random player by using a Q-learning algorithm. The tutorial is at :doc:`/tutorials/tictactoe`.
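As a side note (not part of this commit): a minimal sketch of the loop implied by the renamed pseudocode above, assuming a hypothetical turn-based env whose ``reset``/``step`` return the ``(state, agent_id, mask)`` triple plus the usual gym-style ``reward``, ``done`` and ``info`` fields, and a random policy that respects the mask::

    import numpy as np

    def random_masked_policy(state, agent_id, mask):
        """Pick uniformly among the actions the mask marks as legal."""
        return int(np.random.choice(np.flatnonzero(mask)))

    def play_episode(env):
        (state, agent_id, mask), done, total_rew = env.reset(), False, 0.0
        while not done:
            act = random_masked_policy(state, agent_id, mask)
            (state, agent_id, mask), rew, done, info = env.step(act)
            total_rew += rew
        return total_rew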
20 changes: 10 additions & 10 deletions docs/tutorials/concepts.rst
@@ -298,14 +298,14 @@ where :math:`\gamma` is the discount factor, :math:`\gamma \in [0, 1]`. Here is
::

# pseudocode, cannot work
s = env.reset()
obs = env.reset()
buffer = Buffer(size=10000)
agent = DQN()
for i in range(int(1e6)):
a = agent.compute_action(s)
s_, r, d, _ = env.step(a)
buffer.store(s, a, s_, r, d)
s = s_
act = agent.compute_action(obs)
obs_next, rew, done, _ = env.step(act)
buffer.store(obs, act, obs_next, rew, done)
obs = obs_next
if i % 1000 == 0:
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64)
# compute 2-step returns. How?
@@ -390,14 +390,14 @@ We give a high-level explanation through the pseudocode used in section :ref:`pr
::

# pseudocode, cannot work # methods in tianshou
s = env.reset()
obs = env.reset()
buffer = Buffer(size=10000) # buffer = tianshou.data.ReplayBuffer(size=10000)
agent = DQN() # policy.__init__(...)
for i in range(int(1e6)): # done in trainer
a = agent.compute_action(s) # act = policy(batch, ...).act
s_, r, d, _ = env.step(a) # collector.collect(...)
buffer.store(s, a, s_, r, d) # collector.collect(...)
s = s_ # collector.collect(...)
act = agent.compute_action(obs) # act = policy(batch, ...).act
obs_next, rew, done, _ = env.step(act) # collector.collect(...)
buffer.store(obs, act, obs_next, rew, done) # collector.collect(...)
obs = obs_next # collector.collect(...)
if i % 1000 == 0: # done in trainer
# the following is done in policy.update(batch_size, buffer)
b_s, b_a, b_s_, b_r, b_d = buffer.get(size=64) # batch, indices = buffer.sample(batch_size)
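For reference (not part of this commit), the right-hand comments above roughly translate into the following Tianshou calls; ``make_env`` and ``build_dqn_policy`` are hypothetical stand-ins for whatever env and policy construction a project uses::

    from tianshou.data import Collector, ReplayBuffer

    env = make_env()                    # e.g. a gym.Env instance
    policy = build_dqn_policy(env)      # e.g. a tianshou.policy.DQNPolicy
    buffer = ReplayBuffer(size=10000)
    collector = Collector(policy, env, buffer)

    for i in range(1000):               # the trainer normally drives this loop
        collector.collect(n_step=1000)  # steps the env; stores obs/act/rew/done/obs_next
        policy.update(64, buffer)       # samples a minibatch from the buffer internally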
38 changes: 19 additions & 19 deletions examples/atari/atari_network.py
@@ -42,13 +42,13 @@ def __init__(

def forward(
self,
x: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Q(x, \*)."""
x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
return self.net(x), state
r"""Mapping: s -> Q(s, \*)."""
obs = torch.as_tensor(obs, device=self.device, dtype=torch.float32)
return self.net(obs), state


class C51(DQN):
@@ -73,15 +73,15 @@ def __init__(

def forward(
self,
x: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
x = x.view(-1, self.num_atoms).softmax(dim=-1)
x = x.view(-1, self.action_num, self.num_atoms)
return x, state
obs, state = super().forward(obs)
obs = obs.view(-1, self.num_atoms).softmax(dim=-1)
obs = obs.view(-1, self.action_num, self.num_atoms)
return obs, state


class Rainbow(DQN):
@@ -127,22 +127,22 @@ def linear(x, y):

def forward(
self,
x: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
q = self.Q(x)
obs, state = super().forward(obs)
q = self.Q(obs)
q = q.view(-1, self.action_num, self.num_atoms)
if self._is_dueling:
v = self.V(x)
v = self.V(obs)
v = v.view(-1, 1, self.num_atoms)
logits = q - q.mean(dim=1, keepdim=True) + v
else:
logits = q
y = logits.softmax(dim=2)
return y, state
probs = logits.softmax(dim=2)
return probs, state


class QRDQN(DQN):
@@ -168,11 +168,11 @@ def __init__(

def forward(
self,
x: Union[np.ndarray, torch.Tensor],
obs: Union[np.ndarray, torch.Tensor],
state: Optional[Any] = None,
info: Dict[str, Any] = {},
) -> Tuple[torch.Tensor, Any]:
r"""Mapping: x -> Z(x, \*)."""
x, state = super().forward(x)
x = x.view(-1, self.action_num, self.num_quantiles)
return x, state
obs, state = super().forward(obs)
obs = obs.view(-1, self.action_num, self.num_quantiles)
return obs, state
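For orientation (not part of this commit), a small usage sketch of the renamed ``forward``; it assumes the ``DQN`` constructor in ``examples/atari/atari_network.py`` takes the stacked-frame shape ``(c, h, w)`` and the number of actions, as the Atari example scripts pass them::

    import numpy as np
    import torch
    from atari_network import DQN  # the module changed above

    net = DQN(4, 84, 84, action_shape=6, device="cpu")
    obs = np.zeros((32, 4, 84, 84), dtype=np.float32)  # a batch of 32 stacked frames
    with torch.no_grad():
        q_values, state = net(obs)  # mapping: obs -> Q(obs, *)
    assert q_values.shape == (32, 6)  # one Q-value per action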
8 changes: 4 additions & 4 deletions examples/box2d/bipedal_hardcore_sac.py
@@ -56,16 +56,16 @@ def __init__(self, env, action_repeat=3, reward_scale=5, rm_done=True):
self.rm_done = rm_done

def step(self, action):
r = 0.0
rew_sum = 0.0
for _ in range(self.action_repeat):
obs, reward, done, info = self.env.step(action)
obs, rew, done, info = self.env.step(action)
# remove done reward penalty
if not done or not self.rm_done:
r = r + reward
rew_sum = rew_sum + rew
if done:
break
# scale reward
return obs, self.reward_scale * r, done, info
return obs, self.reward_scale * rew_sum, done, info


def test_sac_bipedal(args=get_args()):
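For context (not part of this commit): the ``step`` shown above repeats the chosen action ``action_repeat`` times, sums the per-step rewards (dropping the terminal penalty when ``rm_done`` is set), and returns the scaled sum. A hedged usage sketch, with ``Wrapper`` standing in for the class name that sits above the shown hunk and the env id assumed to be the usual BipedalWalker hardcore task::

    import gym

    env = Wrapper(
        gym.make("BipedalWalkerHardcore-v3"),
        action_repeat=3, reward_scale=5, rm_done=True,
    )
    obs = env.reset()
    act = env.action_space.sample()
    obs, rew, done, info = env.step(act)  # rew is the scaled sum over 3 inner steps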
6 changes: 3 additions & 3 deletions test/base/env.py
@@ -86,10 +86,10 @@ def reset(self, state=0):

def _get_reward(self):
"""Generate a non-scalar reward if ma_rew is True."""
x = int(self.done)
end_flag = int(self.done)
if self.ma_rew > 0:
return [x] * self.ma_rew
return x
return [end_flag] * self.ma_rew
return end_flag

def _get_state(self):
"""Generate state(observation) of MyTestEnv"""
14 changes: 8 additions & 6 deletions test/base/test_buffer.py
@@ -32,10 +32,12 @@ def test_replaybuffer(size=10, bufsize=20):
assert str(buf) == buf.__class__.__name__ + '()'
obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
for i, act in enumerate(action_list):
obs_next, rew, done, info = env.step(act)
buf.add(
Batch(obs=obs, act=[a], rew=rew, done=done, obs_next=obs_next, info=info)
Batch(
obs=obs, act=[act], rew=rew, done=done, obs_next=obs_next, info=info
)
)
obs = obs_next
assert len(buf) == min(bufsize, i + 1)
@@ -220,11 +222,11 @@ def test_priortized_replaybuffer(size=32, bufsize=15):
buf2 = PrioritizedVectorReplayBuffer(bufsize, buffer_num=3, alpha=0.5, beta=0.5)
obs = env.reset()
action_list = [1] * 5 + [0] * 10 + [1] * 10
for i, a in enumerate(action_list):
obs_next, rew, done, info = env.step(a)
for i, act in enumerate(action_list):
obs_next, rew, done, info = env.step(act)
batch = Batch(
obs=obs,
act=a,
act=act,
rew=rew,
done=done,
obs_next=obs_next,
20 changes: 10 additions & 10 deletions test/base/test_collector.py
@@ -331,20 +331,20 @@ def test_collector_with_ma():
policy = MyPolicy()
c0 = Collector(policy, env, ReplayBuffer(size=100), Logger.single_preprocess_fn)
# n_step=3 will collect a full episode
r = c0.collect(n_step=3)['rews']
assert len(r) == 0
r = c0.collect(n_episode=2)['rews']
assert r.shape == (2, 4) and np.all(r == 1)
rew = c0.collect(n_step=3)['rews']
assert len(rew) == 0
rew = c0.collect(n_episode=2)['rews']
assert rew.shape == (2, 4) and np.all(rew == 1)
env_fns = [lambda x=i: MyTestEnv(size=x, sleep=0, ma_rew=4) for i in [2, 3, 4, 5]]
envs = DummyVectorEnv(env_fns)
c1 = Collector(
policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4),
Logger.single_preprocess_fn
)
r = c1.collect(n_step=12)['rews']
assert r.shape == (2, 4) and np.all(r == 1), r
r = c1.collect(n_episode=8)['rews']
assert r.shape == (8, 4) and np.all(r == 1)
rew = c1.collect(n_step=12)['rews']
assert rew.shape == (2, 4) and np.all(rew == 1), rew
rew = c1.collect(n_episode=8)['rews']
assert rew.shape == (8, 4) and np.all(rew == 1)
batch, _ = c1.buffer.sample(10)
print(batch)
c0.buffer.update(c1.buffer)
@@ -446,8 +446,8 @@ def test_collector_with_ma():
policy, envs, VectorReplayBuffer(total_size=100, buffer_num=4, stack_num=4),
Logger.single_preprocess_fn
)
r = c2.collect(n_episode=10)['rews']
assert r.shape == (10, 4) and np.all(r == 1)
rew = c2.collect(n_episode=10)['rews']
assert rew.shape == (10, 4) and np.all(rew == 1)
batch, _ = c2.buffer.sample(10)


10 changes: 5 additions & 5 deletions test/base/test_env.py
@@ -58,22 +58,22 @@ def test_async_env(size=10000, num=8, sleep=0.1):
# should be smaller
action_list = [1] * num + [0] * (num * 2) + [1] * (num * 4)
current_idx_start = 0
action = action_list[:num]
act = action_list[:num]
env_ids = list(range(num))
o = []
spent_time = time.time()
while current_idx_start < len(action_list):
A, B, C, D = v.step(action=action, id=env_ids)
A, B, C, D = v.step(action=act, id=env_ids)
b = Batch({'obs': A, 'rew': B, 'done': C, 'info': D})
env_ids = b.info.env_id
o.append(b)
current_idx_start += len(action)
current_idx_start += len(act)
# len of action may be smaller than len(A) in the end
action = action_list[current_idx_start:current_idx_start + len(A)]
act = action_list[current_idx_start:current_idx_start + len(A)]
# truncate env_ids with the first terms
# typically len(env_ids) == len(A) == len(action), except for the
# last batch when actions are not enough
env_ids = env_ids[:len(action)]
env_ids = env_ids[:len(act)]
spent_time = time.time() - spent_time
Batch.cat(o)
v.close()
8 changes: 4 additions & 4 deletions test/base/test_returns.py
@@ -142,11 +142,11 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices):
returns = np.zeros_like(indices, dtype=float)
buf_len = len(buffer)
for i in range(len(indices)):
flag, r = False, 0.
flag, rew = False, 0.
real_step_n = nstep
for n in range(nstep):
idx = (indices[i] + n) % buf_len
r += buffer.rew[idx] * gamma**n
rew += buffer.rew[idx] * gamma**n
if buffer.done[idx]:
if not (
hasattr(buffer, 'info') and buffer.info['TimeLimit.truncated'][idx]
@@ -156,8 +156,8 @@ def compute_nstep_return_base(nstep, gamma, buffer, indices):
break
if not flag:
idx = (indices[i] + real_step_n - 1) % buf_len
r += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n
returns[i] = r
rew += to_numpy(target_q_fn(buffer, idx)) * gamma**real_step_n
returns[i] = rew
return returns


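Written out as a formula (not part of this commit), and assuming the hidden lines of the loop set ``real_step_n`` to the offset of the first ``done`` as the bootstrap index suggests, the reference implementation above computes, for each sampled index :math:`t`:

.. math::

    G_t = \sum_{k=0}^{m-1} \gamma^k r_{t+k}
          + (1 - d_{\mathrm{term}})\, \gamma^{m}\, Q_{\mathrm{target}}(s_{t+m}),

where :math:`m` is ``nstep``, cut short at the first ``done`` inside the window, and :math:`d_{\mathrm{term}} = 1` only when that ``done`` is a true termination rather than a ``TimeLimit.truncated`` time-out (which still bootstraps).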
6 changes: 3 additions & 3 deletions test/multiagent/Gomoku.py
@@ -41,7 +41,7 @@ def env_func():
return TicTacToeEnv(args.board_size, args.win_size)

test_envs = DummyVectorEnv([env_func for _ in range(args.test_num)])
for r in range(args.self_play_round):
for round in range(args.self_play_round):
rews = []
agent_learn.set_eps(0.0)
# compute the reward over previous learner
@@ -66,12 +66,12 @@ def env_func():
# previous learner can only be used for forward
agent.forward = opponent.forward
args.model_save_path = os.path.join(
args.logdir, 'Gomoku', 'dqn', f'policy_round_{r}_epoch_{epoch}.pth'
args.logdir, 'Gomoku', 'dqn', f'policy_round_{round}_epoch_{epoch}.pth'
)
result, agent_learn = train_agent(
args, agent_learn=agent_learn, agent_opponent=agent, optim=optim
)
print(f'round_{r}_epoch_{epoch}')
print(f'round_{round}_epoch_{epoch}')
pprint.pprint(result)
learnt_agent = deepcopy(agent_learn)
learnt_agent.set_eps(0.0)