Error in train.py #36

Open
zqs010908 opened this issue Jul 13, 2024 · 5 comments

@zqs010908

Thank you for providing the code.
I am trying to train my model with train.py, but I encountered the following errors:

Traceback (most recent call last):
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/train.py", line 402, in <module>
    main(config)
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/train.py", line 326, in main
    train_eval_loop_nomad(
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/vint_train/training/train_eval_loop.py", line 196, in train_eval_loop_nomad
    ema_model = EMAModel(model=model,power=0.75)
TypeError: __init__() missing 1 required positional argument: 'parameters'

Traceback (most recent call last):
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/train.py", line 402, in <module>
    main(config)
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/train.py", line 326, in main
    train_eval_loop_nomad(
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/vint_train/training/train_eval_loop.py", line 203, in train_eval_loop_nomad
    train_nomad(
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/vint_train/training/train_utils.py", line 661, in train_nomad
    loss.backward()
  File "/home/iiau-vln/miniconda3/envs/nomad/lib/python3.8/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/iiau-vln/miniconda3/envs/nomad/lib/python3.8/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

@zolmaeng

Try using EMAModel(model.parameters(), power=0.75) instead of EMAModel(model, power=0.75).
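Some context for this suggestion: in recent diffusers releases, `diffusers.training_utils.EMAModel` is built around an iterable of parameters (its first positional argument is `parameters`), so `EMAModel(model=model, power=0.75)` leaves that argument unfilled. The toy sketch below mimics that parameter-based design in plain Python, with lists of floats standing in for tensors. `ParamEMA` is a hypothetical class with a fixed decay; it only borrows the `step`/`copy_to` naming and is not the diffusers implementation.

```python
class ParamEMA:
    """Toy parameter-based EMA (hypothetical; loosely modeled on the
    step/copy_to style of diffusers' EMAModel, not its actual code)."""

    def __init__(self, parameters, decay=0.9):
        # Keep *copies* of the values so later training updates do not
        # overwrite the running average.
        self.shadow = [float(p) for p in parameters]
        self.decay = decay

    def step(self, parameters):
        # shadow <- decay * shadow + (1 - decay) * current
        for i, p in enumerate(parameters):
            self.shadow[i] = self.decay * self.shadow[i] + (1 - self.decay) * p

    def copy_to(self, parameters):
        # Write the averaged values into a separate container (e.g. an eval copy).
        parameters[:] = self.shadow


weights = [0.0, 1.0]
ema = ParamEMA(weights, decay=0.5)
weights[:] = [2.0, 3.0]  # pretend an optimizer step changed the weights
ema.step(weights)
eval_weights = [0.0, 0.0]
ema.copy_to(eval_weights)
print(eval_weights)
```

In this design the running average lives only in the shadow values until it is copied into some model, so there is no `averaged_model` attribute to read, which is consistent with the `AttributeError` reported later in this thread.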

@robodhruv
Owner

robodhruv commented Jul 14, 2024

Thanks @zolmaeng! Does that fix the problem, @zqs010908 ?
We also welcome PRs with reproducible bug reports :)

This seems to affect other users too; tracking it in #30.

robodhruv self-assigned this on Jul 14, 2024
@zqs010908
Author

I have already tried EMAModel(model.parameters(), power=0.75), but a new problem has arisen:

/home/iiau-vln/miniconda3/envs/nomad/lib/python3.8/site-packages/diffusers/training_utils.py:361: FutureWarning: Passing a torch.nn.Module to ExponentialMovingAverage.step is deprecated. Please pass the parameters of the module instead.
  deprecate(
Traceback (most recent call last):
  File "/home/iiau-vln/ws_zqs/nomad/ori_nomad/visualnav-transformer/train/train.py", line 402, in <module>
    main(config)
  File "/home/iiau-vln/ws_zqs/nomad/ori_nomad/visualnav-transformer/train/train.py", line 326, in main
    train_eval_loop_nomad(
  File "/home/iiau-vln/ws_zqs/nomad/ori_nomad/visualnav-transformer/train/vint_train/training/train_eval_loop.py", line 203, in train_eval_loop_nomad
    train_nomad(
  File "/home/iiau-vln/ws_zqs/nomad/ori_nomad/visualnav-transformer/train/vint_train/training/train_utils.py", line 676, in train_nomad
    ema_model.averaged_model,
AttributeError: 'EMAModel' object has no attribute 'averaged_model'

I found that the EMAModel class in diffusers.training_utils indeed has no averaged_model attribute. The EMAModel in diffusion_policy does have one (see line 31 of ema_model.py), but I'm not sure whether switching to it is the correct approach.
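Both EMA classes discussed here take the `power=0.75` argument, which controls how fast the EMA decay warms up. A minimal sketch of that warmup schedule, assuming the form used in diffusers' and diffusion_policy's EMA code (decay `1 - (1 + step / inv_gamma) ** -power`, clamped to `[min_value, max_value]`, and zero before the first real update). `ema_warmup_decay` is a hypothetical helper, and its default values are assumptions:

```python
def ema_warmup_decay(optimization_step, power=0.75, inv_gamma=1.0,
                     update_after_step=0, min_value=0.0, max_value=0.9999):
    """Hypothetical helper: EMA decay with a power-law warmup."""
    step = max(0, optimization_step - update_after_step - 1)
    if step <= 0:
        # No averaging yet: the EMA simply tracks the live weights.
        return 0.0
    value = 1 - (1 + step / inv_gamma) ** -power
    # Clamp so the decay never exceeds the configured cap.
    return max(min_value, min(value, max_value))


for s in (1, 2, 10, 1000):
    print(s, round(ema_warmup_decay(s), 4))
```

With `power=0.75` the decay starts at 0 and climbs toward the `max_value` cap as training proceeds, so the poorly trained weights from early steps are forgotten quickly.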

ajaysridhar0 reopened this on Jul 14, 2024
@zqs010908
Author

Switching to the EMAModel from diffusion_policy solved the previous problem, but now a new issue has arisen. I am using the nomad config and training on the SACSoN/HuRoN dataset. Have you encountered this before?

Traceback (most recent call last):
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/train.py", line 402, in <module>
    main(config)
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/train.py", line 326, in main
    train_eval_loop_nomad(
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/vint_train/training/train_eval_loop.py", line 203, in train_eval_loop_nomad
    train_nomad(
  File "/home/iiau-vln/ws_zqs/nomad/visualnav-transformer/train/vint_train/training/train_utils.py", line 860, in train_nomad
    loss.backward()
  File "/home/iiau-vln/miniconda3/envs/nomad/lib/python3.8/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/iiau-vln/miniconda3/envs/nomad/lib/python3.8/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

@keshav0306

Change line 31 of diffusion_policy/diffusion_policy/model/diffusion/ema_model.py from

self.averaged_model = model

to

self.averaged_model = copy.deepcopy(model)
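Why the copy.deepcopy change likely helps: diffusion_policy's EMAModel constructor assigns the wrapped model to averaged_model and then disables gradients on it (via requires_grad_(False)), because the EMA copy is never trained directly. If that object is the training model itself, the training weights are frozen too, which matches the "element 0 of tensors does not require grad" error. The toy sketch below reproduces the aliasing effect in plain Python; `ToyModel` and `WrappingEMA` are hypothetical stand-ins for a `torch.nn.Module` and the EMA wrapper, not the real classes.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for torch.nn.Module with a gradient flag."""

    def __init__(self):
        self.weights = [0.5, -0.5]
        self.requires_grad = True

    def requires_grad_(self, flag):
        self.requires_grad = flag
        return self


class WrappingEMA:
    """Mimics only the relevant part of a model-wrapping EMA constructor."""

    def __init__(self, model):
        self.averaged_model = model
        # The EMA copy should never be trained, so gradients are disabled.
        self.averaged_model.requires_grad_(False)


# Buggy: the EMA wraps the live training model itself.
train_model = ToyModel()
WrappingEMA(train_model)
print("aliased:", train_model.requires_grad)  # the training model is now frozen

# Fixed: the EMA wraps an independent deep copy.
train_model = ToyModel()
ema = WrappingEMA(copy.deepcopy(train_model))
print("deep-copied:", train_model.requires_grad)
print("ema copy:", ema.averaged_model.requires_grad)
```

With the deep copy, only the EMA's own weights are frozen, while the training model keeps its gradients.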

5 participants