low GPU usage while training with PPO #483

Closed
quark2019 opened this issue Nov 21, 2021 · 9 comments
Labels: question (Further information is requested)

Comments

@quark2019

quark2019 commented Nov 21, 2021

  • I have marked all applicable categories:
    • exception-raising bug
    • RL algorithm bug
    • documentation request (i.e. "X is missing from the documentation.")
    • new feature request

Question:
When running the PPO example, no matter what value of args.training_num I use (16, 64, 128, 256, ...),
only about 2 GB of GPU memory is used (most GPU memory is free)
and GPU utilization is only about 1%~2% (shown by nvidia-smi: Volatile GPU-Util = 2%).

How can I fix this? Thanks!


hardware info:
GPU: Nvidia A100 (with driver 470, CUDA 11.4)

software version info:
tianshou: 0.4.4
torch: 1.9.1+cu11
numpy: 1.20.3
sys: Ubuntu 20.04


PPO args (most other arg values are the same as in test_PPO.py in examples):
env.max_step: 50000
buffer_size: 4096 * 16
hidden_size: [128, 128]
step_per_epoch: 50000
batch_size: 2048

@Trinkle23897
Collaborator

import argparse
from typing import Any, Dict, Optional, Tuple, Type

import envpool
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from gae import compute_gae


class CnnActorCritic(nn.Module):
  def __init__(self, action_size: int):
    super().__init__()
    layers = [
      nn.Conv2d(4, 32, kernel_size=8, stride=4),
      nn.ReLU(inplace=True),
      nn.Conv2d(32, 64, kernel_size=4, stride=2),
      nn.ReLU(inplace=True),
      nn.Conv2d(64, 64, kernel_size=3, stride=1),
      nn.ReLU(inplace=True),
      nn.Flatten(),
      nn.Linear(3136, 512),
      nn.ReLU(inplace=True),
    ]
    self.net = nn.Sequential(*layers)
    self.actor = nn.Linear(512, action_size)
    self.critic = nn.Linear(512, 1)
    # orthogonal initialization
    for m in self.modules():
      if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight)
        nn.init.zeros_(m.bias)

  def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    feature = self.net(x / 255.0)
    return F.softmax(self.actor(feature), dim=-1), self.critic(feature)


class MlpActorCritic(nn.Module):
  def __init__(self, state_size: int, action_size: int):
    super().__init__()
    layers = [
      nn.Linear(state_size, 64),
      nn.ReLU(inplace=True),
      nn.Linear(64, 64),
      nn.ReLU(inplace=True),
    ]
    self.net = nn.Sequential(*layers)
    self.actor = nn.Linear(64, action_size)
    self.critic = nn.Linear(64, 1)
    # orthogonal initialization
    for m in self.modules():
      if isinstance(m, nn.Linear):
        nn.init.orthogonal_(m.weight)
        nn.init.zeros_(m.bias)

  def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    feature = self.net(x)
    return F.softmax(self.actor(feature), dim=-1), self.critic(feature)


class DiscretePPO:
  def __init__(
    self,
    actor_critic: nn.Module,
    optim: torch.optim.Optimizer,
    dist_fn: Type[torch.distributions.Distribution],
    lr_scheduler: torch.optim.lr_scheduler.LambdaLR,
    config: argparse.Namespace,
  ):
    self.actor_critic = actor_critic
    self.optim = optim
    self.dist_fn = dist_fn
    self.config = config
    self.training = True
    self.numenv = config.numenv
    self.lr_scheduler = lr_scheduler

  def predictor(
    self, obs: torch.Tensor
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor,
             Optional[torch.distributions.Distribution]]:
    logits, value = self.actor_critic(obs)
    if not self.training:
      action = logits.argmax(-1)
      dist = None
      log_prob = None
    else:
      dist = self.dist_fn(logits)
      action = dist.sample()
      log_prob = dist.log_prob(action)
    return action, log_prob, value, dist

  def learner(
    self,
    obs: torch.Tensor,
    act: torch.Tensor,
    rew: np.ndarray,
    done: np.ndarray,
    env_id: np.ndarray,
    log_prob: torch.Tensor,
    value: torch.Tensor,
  ) -> Dict[str, float]:
    # compute GAE
    T, B = rew.shape
    N = T * B
    returns, advantage, mask = compute_gae(
      self.config.gamma,
      self.config.gae_lambda,
      value.cpu().numpy().reshape(T, B),
      rew,
      done,
      env_id,
      self.numenv,
    )
    index = np.arange(N)[mask.reshape(N) > 0]
    returns = torch.from_numpy(returns.reshape(N)).to(value.device)
    advantage = torch.from_numpy(advantage.reshape(N)).to(value.device)
    losses, clip_losses, vf_losses, ent_losses = [], [], [], []
    # split traj
    for _ in range(self.config.repeat_per_collect):
      np.random.shuffle(index)
      for start_index in range(0, len(index), self.config.batch_size):
        i = index[start_index:start_index + self.config.batch_size]
        b_adv = advantage[i]
        if self.config.norm_adv:
          mean, std = b_adv.mean(), b_adv.std()
          b_adv = (b_adv - mean) / std
        _, b_log_prob, b_value, b_dist = self.predictor(obs[i])
        ratio = (b_dist.log_prob(act[i]) - log_prob[i]).exp().float()
        ratio = ratio.reshape(ratio.shape[0], -1).transpose(0, 1)
        surr1 = ratio * b_adv
        surr2 = ratio.clamp(1.0 - self.config.eps_clip, 1.0 + self.config.eps_clip) * b_adv
        clip_loss = -torch.min(surr1, surr2).mean()
        vf_loss = (returns[i] - b_value.flatten()).pow(2).mean()
        ent_loss = b_dist.entropy().mean()
        loss = clip_loss + self.config.vf_coef * vf_loss - self.config.ent_coef * ent_loss
        # update param
        self.optim.zero_grad()
        loss.backward()
        if self.config.max_grad_norm:
          nn.utils.clip_grad_norm_(
            self.actor_critic.parameters(), max_norm=self.config.max_grad_norm
          )
        self.optim.step()
        clip_losses.append(clip_loss.item())
        vf_losses.append(vf_loss.item())
        ent_losses.append(ent_loss.item())
        losses.append(loss.item())
    self.lr_scheduler.step()
    # return loss
    return {
      "loss": np.mean(losses),
      "loss/clip": np.mean(clip_losses),
      "loss/vf": np.mean(vf_losses),
      "loss/ent": np.mean(ent_losses),
    }


class MovAvg:
  def __init__(self, size: int = 100):
    self.size = size
    self.cache = []

  def add_bulk(self, x: np.ndarray) -> float:
    self.cache += x.tolist()
    if len(self.cache) > self.size:
      self.cache = self.cache[-self.size:]
    return np.mean(self.cache)


class Actor:
  def __init__(
    self,
    policy: DiscretePPO,
    train_envs: Any,
    test_envs: Any,
    writer: SummaryWriter,
    config: argparse.Namespace,
  ):
    self.policy = policy
    self.train_envs = train_envs
    self.test_envs = test_envs
    self.writer = writer
    self.config = config
    self.obs_batch = []
    self.act_batch = []
    self.rew_batch = []
    self.done_batch = []
    self.envid_batch = []
    self.value_batch = []
    self.logprob_batch = []
    self.reward_stat = np.zeros(len(train_envs))
    train_envs.async_reset()
    test_envs.async_reset()

  def run(self) -> None:
    env_step = 0
    stat = MovAvg()
    episodic_reward = 0
    for epoch in range(1, 1 + self.config.epoch):
      with tqdm.trange(self.config.step_per_epoch, desc=f'Epoch #{epoch}') as t:
        while t.n < self.config.step_per_epoch:
          # collect
          for i in range(self.config.step_per_collect // self.config.waitnum):
            obs, rew, done, info = self.train_envs.recv()
            env_id = info["env_id"]
            obs = torch.tensor(obs, device="cuda")
            self.obs_batch.append(obs)
            with torch.no_grad():
              act, log_prob, value, _ = self.policy.predictor(obs)
            self.act_batch.append(act)
            self.logprob_batch.append(log_prob)
            self.value_batch.append(value)
            self.train_envs.send(act.cpu().numpy(), env_id)
            self.rew_batch.append(rew)
            self.done_batch.append(done)
            self.envid_batch.append(env_id)
            t.update(self.config.waitnum)
            env_step += self.config.waitnum
            self.reward_stat[env_id] += rew
            if np.any(done):
              done_id = env_id[done]
              episodic_reward = self.reward_stat[done_id]
              self.reward_stat[done_id] = 0
              self.writer.add_scalar(
                "train/reward",
                stat.add_bulk(episodic_reward),
                global_step=env_step,
              )
          # learn
          result = self.policy.learner(
            torch.cat(self.obs_batch),
            torch.cat(self.act_batch),
            np.stack(self.rew_batch),
            np.stack(self.done_batch),
            np.stack(self.envid_batch),
            torch.cat(self.logprob_batch),
            torch.cat(self.value_batch),
          )
          result["reward"] = np.mean(episodic_reward)
          self.obs_batch = []
          self.act_batch = []
          self.rew_batch = []
          self.done_batch = []
          self.envid_batch = []
          self.value_batch = []
          self.logprob_batch = []
          t.set_postfix(**result)
          for k, v in result.items():
            self.writer.add_scalar(f"train/{k}", v, global_step=env_step)


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument('--task', type=str, default='Pong-v5')
  parser.add_argument('--seed', type=int, default=0)
  parser.add_argument('--lr', type=float, default=2.5e-4)
  parser.add_argument('--gamma', type=float, default=0.99)
  parser.add_argument('--epoch', type=int, default=100)
  parser.add_argument('--step-per-epoch', type=int, default=100000)
  parser.add_argument('--step-per-collect', type=int, default=1024)
  parser.add_argument('--repeat-per-collect', type=int, default=4)
  parser.add_argument('--batch-size', type=int, default=256)
  parser.add_argument('--numenv', type=int, default=16)
  parser.add_argument('--waitnum', type=int, default=8)
  parser.add_argument('--test-num', type=int, default=10)
  parser.add_argument('--logdir', type=str, default='log')
  parser.add_argument(
    '--watch',
    default=False,
    action='store_true',
    help='watch the play of pre-trained policy only'
  )
  # ppo special
  parser.add_argument('--vf-coef', type=float, default=0.25)
  parser.add_argument('--ent-coef', type=float, default=0.01)
  parser.add_argument('--gae-lambda', type=float, default=0.95)
  parser.add_argument('--eps-clip', type=float, default=0.2)
  parser.add_argument('--max-grad-norm', type=float, default=0.5)
  parser.add_argument('--rew-norm', type=int, default=0)
  parser.add_argument('--norm-adv', type=int, default=1)
  parser.add_argument('--recompute-adv', type=int, default=0)
  parser.add_argument('--dual-clip', type=float, default=None)
  parser.add_argument('--value-clip', type=int, default=0)
  parser.add_argument('--lr-decay', type=int, default=True)
  args = parser.parse_args()

  train_envs = envpool.make(
    args.task,
    env_type="gym",
    num_envs=args.numenv,
    batch_size=args.waitnum,
    episodic_life=True,
    reward_clip=True,
    # thread_affinity=False,
  )
  test_envs = envpool.make(
    args.task,
    env_type="gym",
    num_envs=args.test_num,
    episodic_life=False,
    reward_clip=False,
    # thread_affinity=False,
  )
  state_n = np.prod(train_envs.observation_space.shape)
  action_n = train_envs.action_space.n
  actor_critic = nn.DataParallel(CnnActorCritic(action_n).cuda())
  # actor_critic = nn.DataParallel(MlpActorCritic(state_n, action_n).cuda())
  optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)
  # decay learning rate to 0 linearly
  max_update_num = np.ceil(
    args.step_per_epoch / args.step_per_collect
  ) * args.epoch

  lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
  )
  dist = torch.distributions.Categorical
  policy = DiscretePPO(
    actor_critic=actor_critic,
    optim=optim,
    dist_fn=dist,
    lr_scheduler=lr_scheduler,
    config=args,
  )
  writer = SummaryWriter(args.logdir)
  writer.add_text("args", str(args))
  Actor(policy, train_envs, test_envs, writer, args).run()

@Trinkle23897
Collaborator

Trinkle23897 commented Nov 21, 2021

Above is what I use to achieve 200k FPS on a DGX-A100, by setting numenv and waitnum to large values.

As for the question, there are several possible reasons:

  1. if you set numenv to a large number but are still using a Python vector env, the env throughput may stay the same. Try replacing the vector env with envpool, or use async settings (numenv > waitnum in tianshou's vector env, or num_envs > batch_size in envpool); see the sketch after this list;
  2. tianshou's PPO is not the most efficient one due to its API design: in learn(), it passes through the shared CNN twice instead of doing a single pass that generates both action and value;
  3. data movement is carefully designed in the above script; there is no overhead from organizing trajectories through a common buffer layer;
  4. which environment do you use?
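
To make points 1 and 2 concrete, here is a minimal sketch of async collection with envpool (num_envs > batch_size) driving a single shared-trunk forward pass that returns both the action distribution and the value. The tiny MLP model and the env/batch sizes are illustrative stand-ins, not the exact settings of the script above:

import envpool
import numpy as np
import torch
from torch import nn


class TinyActorCritic(nn.Module):
  # Stand-in model: one shared trunk, two heads, evaluated in a single pass.
  def __init__(self, obs_dim: int, action_n: int):
    super().__init__()
    self.trunk = nn.Sequential(nn.Flatten(), nn.Linear(obs_dim, 256), nn.ReLU())
    self.actor = nn.Linear(256, action_n)
    self.critic = nn.Linear(256, 1)

  def forward(self, x: torch.Tensor):
    h = self.trunk(x.float() / 255.0)
    return self.actor(h), self.critic(h)  # action logits and value in one pass


# Async mode: 64 envs total, but each recv() returns only the 16 that are ready,
# so env simulation and the model forward pass overlap instead of waiting for
# the slowest env in a synchronous batch.
envs = envpool.make("Pong-v5", env_type="gym", num_envs=64, batch_size=16)
model = TinyActorCritic(int(np.prod(envs.observation_space.shape)),
                        envs.action_space.n)

envs.async_reset()
for _ in range(1000):
  obs, rew, done, info = envs.recv()             # batch of 16 ready envs
  with torch.no_grad():
    logits, value = model(torch.as_tensor(obs))  # one pass, both outputs
  act = torch.distributions.Categorical(logits=logits).sample()
  envs.send(act.numpy(), info["env_id"])         # step only those envs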

@Trinkle23897 Trinkle23897 added the "question" label Nov 21, 2021
@Trinkle23897 Trinkle23897 changed the title from "low GPU usage while training" to "low GPU usage while training with PPO" Nov 21, 2021
@quark2019
Author

quark2019 commented Nov 22, 2021


I use a self-defined env (continuous state/action space) on my side; I'll follow your suggestions and try more.
(It seems there is no continuous-state/action env in EnvPool's supported-env list?)

When trying your code, `from gae import compute_gae` raises an import error (no module named 'gae'). Which module should I install to fix this error?

Many thanks.

@Trinkle23897
Collaborator

It seems there is no continuous-state/action env in EnvPool's supported-env list?

No, EnvPool supports both discrete and continuous action spaces. That's a good point, though; I'll add it to the documentation.
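
For reference, a quick way to check what a given envpool task exposes (a minimal sketch; I'm assuming `envpool.list_all_envs()` exists in your installed envpool version, and the task id is just the Atari example from above):

import envpool

# List every task id registered in the installed envpool build
# (assumption: this helper is available in your envpool version).
print(envpool.list_all_envs())

# Inspect the action space of a given task: gym.spaces.Discrete means a
# discrete action space, gym.spaces.Box means a continuous one.
env = envpool.make("Pong-v5", env_type="gym", num_envs=1)
print(type(env.action_space), env.action_space)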

`from gae import compute_gae` raises an import error (no module named 'gae'). Which module should I install to fix this error?

Here is the `gae` module (save it as gae.py next to the script):

from timeit import timeit
from numba import njit
import numpy as np
from typing import Tuple


@njit
def compute_gae(
  gamma: float,
  gae_lambda: float,
  value: np.ndarray,
  reward: np.ndarray,
  done: np.ndarray,
  env_id: np.ndarray,
  numenv: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
  # shape of array: [T, B]
  # return returns, advantage, mask
  T, B = value.shape
  mask = (1.0 - done) * (gamma * gae_lambda)
  index_tp1 = np.zeros(numenv, np.int32) - 1
  value_tp1 = np.zeros(numenv)
  gae_tp1 = np.zeros(numenv)
  delta = reward - value
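  # NOTE: `delta` here holds only r_t - V(s_t); the bootstrap term
  # gamma * V(s_{t+1}) * (1 - done_t) is added inside the loop below, because
  # rows from different envs are interleaved and V(s_{t+1}) has to be looked
  # up per env via value_tp1[eid].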
  adv = np.zeros((T, B))
  for t in range(T - 1, -1, -1):
    eid = env_id[t]
    adv[t] = delta[t] + gamma * value_tp1[eid] * (1 - done[t]) + mask[t] * gae_tp1[eid]
    mask[t] = (done[t] | (index_tp1[eid] != -1))
    gae_tp1[eid] = adv[t]
    value_tp1[eid] = value[t]
    index_tp1[eid] = t
  return adv + value, adv, mask


def test_episodic_returns():
  # basic test for 1d array
  value = np.zeros([8, 1])
  done = np.array([1, 0, 0, 1, 0, 1, 0, 1.]).reshape(8, 1).astype(bool)
  rew = np.array([0, 1, 2, 3, 4, 5, 6, 7.]).reshape(8, 1)
  env_id = np.zeros([8, 1], int)
  returns, adv, mask = compute_gae(
    gamma=0.1, gae_lambda=1, value=value, reward=rew, done=done, env_id=env_id, numenv=1
  )
  ans = np.array([0, 1.23, 2.3, 3, 4.5, 5, 6.7, 7]).reshape([8, 1])
  assert np.allclose(returns, ans) and np.allclose(adv, ans)
  ref_mask = np.array([1, 1, 1, 1, 1, 1, 1, 1]).reshape(8, 1)
  assert np.allclose(ref_mask, mask)

  # same as above, only shuffle index
  env_id = np.array([[1, 2, 0, 1], [3, 3, 1, 2]]).transpose()
  value = np.zeros([4, 2])
  done = np.array([[0, 0, 1, 1], [0, 1, 0, 1]], bool).transpose().astype(bool)
  rew = np.array([[1, 4, 0, 3], [6, 7, 2, 5]]).transpose()
  returns, adv, mask = compute_gae(
    gamma=0.1, gae_lambda=1, value=value, reward=rew, done=done, env_id=env_id, numenv=4
  )
  ans = np.array([[1.23, 4.5, 0, 3], [6.7, 7, 2.3, 5]]).transpose()
  assert np.allclose(returns, ans) and np.allclose(adv, ans), returns
  ref_mask = np.ones([4, 2])
  assert np.allclose(ref_mask, mask)

  # check if mask correct in done=False at the end of trajectory
  env_id = np.zeros([7, 1])
  done = np.array([0, 1, 0, 1, 0, 1, 0]).reshape(7, 1).astype(bool)
  rew = np.array([7, 6, 1, 2, 3, 4, 5.]).reshape(7, 1)
  env_id = np.zeros([7, 1], int)
  value = np.zeros([7, 1])
  returns, adv, mask = compute_gae(
    gamma=0.1, gae_lambda=1, value=value, reward=rew, done=done, env_id=env_id, numenv=1
  )
  ans = np.array([7.6, 6, 1.2, 2, 3.4, 4, 5]).reshape(7, 1)
  assert np.allclose(returns, ans) and np.allclose(adv, ans)
  ref_mask = np.ones([7, 1])
  ref_mask[-1] = 0
  assert np.allclose(ref_mask, mask), mask

  done = np.array([0, 1, 0, 1, 0, 0, 1], bool).reshape(7, 1).astype(bool)
  rew = np.array([7, 6, 1, 2, 3, 4, 5.]).reshape(7, 1)
  returns, adv, mask = compute_gae(
    gamma=0.1, gae_lambda=1, value=value, reward=rew, done=done, env_id=env_id, numenv=1
  )
  ans = np.array([7.6, 6, 1.2, 2, 3.45, 4.5, 5]).reshape(7, 1)
  assert np.allclose(returns, ans) and np.allclose(adv, ans)
  ref_mask = np.ones([7, 1])
  assert np.allclose(ref_mask, mask)

  done = np.array([0, 0, 0, 1., 0, 0, 0, 1, 0, 0, 0, 1]).reshape([12, 1]).astype(bool)
  rew = np.array([101, 102, 103., 200, 104, 105, 106, 201, 107, 108, 109, 202])
  rew = rew.reshape([12, 1])
  value = np.array([1000, 2., 3., 4, -1, 5., 6., 7, -2, 8., 9., 10]).reshape([12, 1])
  env_id = np.zeros([12, 1], int)
  returns, adv, mask = compute_gae(
    gamma=0.99, gae_lambda=0.95, value=value, reward=rew, done=done,
    env_id=env_id, numenv=1,
  )
  ans = np.array([
    454.8344, 376.1143, 291.298, 200.,
    464.5610, 383.1085, 295.387, 201.,
    474.2876, 390.1027, 299.476, 202.,
  ]).reshape([12, 1])
  assert np.allclose(returns, ans), (returns, adv)
  ref_mask = np.ones([12, 1])
  assert np.allclose(ref_mask, mask)

  done = np.zeros([4, 3], bool)
  done[-1] = 1
  env_id = np.array([[0, 1, 2, 1], [1, 0, 1, 2], [2, 2, 0, 0]]).transpose()
  value = np.array([[-1000, 5, 9, 7], [-1000, 2, 6, 10], [-1000, 8., 3, 4]]).transpose()
  rew = np.array([
    [101, 105, 109, 201], [104, 102, 106, 202], [107, 108, 103, 200.]]).transpose()
  returns, adv, mask = compute_gae(
    gamma=0.99, gae_lambda=0.95, value=value, reward=rew, done=done,
    env_id=env_id, numenv=3,
  )
  ans = np.array([
    [454.8344, 383.1085, 299.476, 201.],
    [464.5610, 376.1143, 295.387, 202.],
    [474.2876, 390.1027, 291.298, 200.],
  ]).transpose()
  assert np.allclose(returns, ans), returns
  assert np.allclose(mask, 1)


def test_time():
  T, B, N = 128, 8, 8 * 4
  cnt = 10000
  value = np.random.rand(T, B)
  rew = np.random.rand(T, B)
  done = np.random.randint(2, size=[T, B]).astype(bool)
  env_id = np.random.randint(N, size=[T, B])

  def wrapper():
    return compute_gae(
      gamma=0.99, gae_lambda=0.95, value=value, reward=rew, done=done,
      env_id=env_id, numenv=N,
    )

  wrapper()  # for compile

  print(timeit(wrapper, setup=wrapper, number=cnt) / cnt)


if __name__ == "__main__":
  # tests are from tianshou unit test
  # test_episodic_returns()
  test_time()

@quark2019
Author

quark2019 commented Nov 23, 2021


Hi Trinkle,
I tried your code above with different tasks and args:
CPU usage goes up as envnum increases from 64 to 256,
and GPU usage also goes up, to about 30%.

There are two problems:
1. If envnum > 256, the probability of an exception being thrown seems to increase with envnum (please see below).
2. Only about 1.8 GB of GPU memory is used, and this does not change with envnum. Does that mean only one torch process is running on the GPU? The Nvidia A100 has a lot of GPU memory; how can I improve GPU utilization by running more work on the GPU in parallel?


Another issue about envpool:
It seems that as envnum increases, the probability of an exception being thrown increases.
It is not just one task/env; as far as I can see, many tasks have the same issue.
Maybe each env re-implemented in envpool needs more guards / boundary-condition protection?

--- If numenv >= 512, there will be an error:
$ python ppo.ts.py --task='DoubleDunk-v5' --epoch=7 --numenv=521 --waitnum=128
Epoch #1: 3%| | 3200/100000 [00:16<04:22, 368.61it/s, loss=-.0331, loss/clip=-.00431, loss/ent=2.88, loss/vf=1Invalid Player A Action: Invalid Player A Action: 1852795251
543911541
Invalid Player A Action: 32737
Invalid Player A Action: 49
[1] 39901 segmentation fault (core dumped) python ppo.ts.py --task='DoubleDunk-v5' --epoch=7 --numenv=521 --waitnum=128

--- Some tasks raise the same error (the error seems more likely to be raised if envnum > 256?):
$ python ppo.ts.py --task='Pong-v5' --epoch=10 --numenv=512 --waitnum=128

Epoch #1: 100352it [01:08, 1466.75it/s, loss=-.00233, loss/clip=0.000189, loss/ent=1.61, loss/vf=0.0542, reward=0]
Epoch #2: 100352it [00:51, 1934.26it/s, loss=-.00852, loss/clip=-3.17e-5, loss/ent=1.79, loss/vf=0.0377, reward=0]
Epoch #3: 100352it [00:53, 1886.16it/s, loss=-.00736, loss/clip=-1.86e-5, loss/ent=1.79, loss/vf=0.0423, reward=0]
Epoch #4:  87%|▊| 87040/100000 [00:45<00:06, 1928.39it/s, loss=-.00603, loss/clip=-6.75e-6, loss/ent=1.79, loss/...]
ppo.ts.py:144: FutureWarning: Non-finite norm encountered in torch.nn.utils.clip_grad_norm_; continuing anyway. Note that the default behavior will change in a future release to error out if a non-finite total norm is encountered. At that point, setting error_if_nonfinite=false will be required to retain the old behavior.
  nn.utils.clip_grad_norm_(
Traceback (most recent call last):
  File "ppo.ts.py", line 331, in <module>
    Actor(policy, train_envs, test_envs, writer, args).run()
  File "ppo.ts.py", line 234, in run
    result = self.policy.learner(
  File "ppo.ts.py", line 131, in learner
    _, b_log_prob, b_value, b_dist = self.predictor(obs[i])
  File "ppo.ts.py", line 91, in predictor
    dist = self.dist_fn(logits)
  File "/home/test/work/usr/pyenv/versions/3.8.12/lib/python3.8/site-packages/torch/distributions/categorical.py", line 64, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "/home/test/work/usr/pyenv/versions/3.8.12/lib/python3.8/site-packages/torch/distributions/distribution.py", line 53, in __init__
    raise ValueError("The parameter {} has invalid values".format(param))
ValueError: The parameter probs has invalid values

Thanks!

@Trinkle23897
Collaborator

  1. How many CPU cores do you have?
  2. "The parameter probs has invalid values" means nan occurs in dist_fn; see the debugging sketch after this list.
  3. 1.8 GB is normal because we use a simple CNN instead of a complex model like ResNet or BERT.
  4. Did you directly plug envpool's async mode into tianshou? I think sync mode would work well, but async mode is a little different from the previous setting (there is no explicit reset in envpool). I'm planning a big refactor of tianshou.
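
For point 2, a small debugging helper you could drop in right before `dist = self.dist_fn(logits)` in `predictor` (the name `check_finite` is hypothetical, not part of the script above):

import torch


def check_finite(logits: torch.Tensor, obs: torch.Tensor) -> None:
  # Raise with a bit of context if the policy head produced non-finite values.
  bad = ~torch.isfinite(logits).all(dim=-1)
  if bad.any():
    print("non-finite logits in", int(bad.sum()), "samples;",
          "obs min/max:", obs.float().min().item(), obs.float().max().item())
    raise RuntimeError("non-finite logits: check lr, grad clipping, and inputs")

# Usage inside DiscretePPO.predictor, right before `dist = self.dist_fn(logits)`:
#   check_finite(logits, obs)
# Optionally call torch.autograd.set_detect_anomaly(True) at startup to locate
# the backward op that first produces a NaN (much slower; debugging only).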

@quark2019
Author

quark2019 commented Nov 24, 2021

  1. How many CPU cores do you have?

256 cores

  2. "The parameter probs has invalid values" means nan occurs in dist_fn

Where do the 'nan' values come from? Does it have anything to do with the envs implemented in envpool?

  4. Did you directly plug envpool's async mode into tianshou? I think sync mode would work well, but async mode is a little different from the previous setting (there is no explicit reset in envpool).

I didn't adjust anything in tianshou or in the code you showed above; I assumed your code would work fine ...

I'm planning a big refactor of tianshou.

That's worth looking forward to!

Thanks.

@Trinkle23897
Collaborator

Trinkle23897 commented Nov 24, 2021

I use the following config:

  parser = argparse.ArgumentParser()
  parser.add_argument('--task', type=str, default='Pong-v5')
  parser.add_argument('--seed', type=int, default=0)
  parser.add_argument('--lr', type=float, default=2.5e-4)
  parser.add_argument('--gamma', type=float, default=0.99)
  parser.add_argument('--epoch', type=int, default=100)
  parser.add_argument('--step-per-epoch', type=int, default=1000000)
  parser.add_argument('--step-per-collect', type=int, default=30720)
  parser.add_argument('--repeat-per-collect', type=int, default=4)
  parser.add_argument('--batch-size', type=int, default=2560)
  parser.add_argument('--numenv', type=int, default=624)
  parser.add_argument('--waitnum', type=int, default=240)
  parser.add_argument('--test-num', type=int, default=10)
  parser.add_argument('--logdir', type=str, default='log')
  parser.add_argument(
    '--watch',
    default=False,
    action='store_true',
    help='watch the play of pre-trained policy only'
  )
  # ppo special
  parser.add_argument('--vf-coef', type=float, default=0.25)
  parser.add_argument('--ent-coef', type=float, default=0.01)
  parser.add_argument('--gae-lambda', type=float, default=0.95)
  parser.add_argument('--eps-clip', type=float, default=0.2)
  parser.add_argument('--max-grad-norm', type=float, default=0.5)
  parser.add_argument('--rew-norm', type=int, default=0)
  parser.add_argument('--norm-adv', type=int, default=1)
  parser.add_argument('--recompute-adv', type=int, default=0)
  parser.add_argument('--dual-clip', type=float, default=None)
  parser.add_argument('--value-clip', type=int, default=0)
  parser.add_argument('--lr-decay', type=int, default=True)
  args = parser.parse_args()

and you'll see the training FPS * 4 reach 100k~200k or more.

Trinkle23897 added a commit to Trinkle23897/envpool that referenced this issue Jan 12, 2022
Trinkle23897 added a commit to sail-sg/envpool that referenced this issue Jan 12, 2022