Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DuelingDQN.ipynb中可能存在的两个BUG~ #140

Open
libermeng opened this issue Aug 3, 2023 · 2 comments
Open

DuelingDQN.ipynb中可能存在的两个BUG~ #140

libermeng opened this issue Aug 3, 2023 · 2 comments

Comments

@libermeng
Copy link

  1. 定义模型部分forward函数中return value + advantage - advantage.mean()可能有误,应该改为return value + advantage - advantage.mean(dim=1, keepdim=True)
    因为按照定义,优势网络输出的值要满足的条件应该是保持在动作维度上的和为0,那么减去的均值应该只是动作维度的均值,而不是总体的均值。
  2. 定义算法部分初始化函数中self.policy_net = model.to(self.device)self.target_net = model.to(self.device)有误,应该改成 self.policy_net = DuelingNet(cfg.n_states, cfg.n_actions, hidden_dim=cfg.hidden_dim).to(self.device)self.target_net = DuelingNet(cfg.n_states, cfg.n_actions, hidden_dim=cfg.hidden_dim).to(self.device)
    因为原初始化方式是初始化了两个相同内存地址的policy_net和target_net对象,修改后的初始化方式才是初始化两个不同内存地址的对象。
@severus98
Copy link

severus98 commented Aug 28, 2024

附议,DDPG的初始化也存在这个问题:

self.device = torch.device(cfg['device'])
self.critic = models['critic'].to(self.device)
self.target_critic = models['critic'].to(self.device)
self.actor = models['actor'].to(self.device)
self.target_actor = models['actor'].to(self.device)

for target_param, param in zip(self.target_critic.parameters(), self.critic.parameters()):
target_param.data.copy_(param.data)
for target_param, param in zip(self.target_actor.parameters(), self.actor.parameters()):
target_param.data.copy_(param.data)

这里的actor和target_actor,critic和target_critic实际上是对同一个网络模型的引用,并非独立的网络。

有两种改正方式:

  1. 一种如贴主所言,在初始化阶段传入类名Actor和Critic而非直接传入实例,然后分别赋参数,各实例化为独立的两个网络模型实例,然后再执行param.data.copy_参数拷贝;

  2. 另一种,可以如https://github.com/sfujim/TD3/blob/master/DDPG.py,在实例化一个网络之后,采用copy.deepcopy进行深拷贝到对应的target网络,从而创建两个独立且初始参数相同的网络模型

@severus98
Copy link

severus98 commented Aug 29, 2024

补充一下我尝试DDPG的实验结果:
分别使用:
1)蘑菇书配套原版代码
2)蘑菇书代码基础上,使用上述的copy.deepcopy复制target网络
3) https://github.com/sfujim/TD3/blob/master/DDPG.py
训练得到网络。 然后使用同一环境种子,测试对比100回合:
1)
load_WRONG_Network_test_seed_plus100

2)
load_deepcopy_test_seed_plus100

3)
load_github_test_seed_plus100

2)和 3) 大多数时候reward在-200 以上,性能接近,而 1)大多数时候在-200以下,性能差距较大;3)比2)略好,可能与训练方式、超参数有关。

考虑到方法3)是TD3算法作者实现的源码,有一定参考性和准确性,可以说明2)这种改正网络copy的方式是有效的。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants