How to load MushroomRL agent trained on a GPU? #15
Hello there! I used MushroomRL to train an agent on a GPU, but now I can't load it to evaluate it on a CPU, because PyTorch tries to load it back onto the GPU. This is the error I'm getting:
As PyTorch suggests, I need to specify that the internal modules should be mapped back to the CPU when loading them. I looked through the MushroomRL docs and couldn't find a way to specify this. I also looked through the code, and the library calls PyTorch's loader directly, without a way to override the mapping, as seen here. Do you have any suggestions on how to work around this problem? Any help would be greatly appreciated. Thanks!
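For context, the fix that PyTorch's error message points at is the `map_location` argument of `torch.load`. A minimal sketch outside MushroomRL (the file path and toy network here are illustrative only, not part of the library's API):

```python
import os
import tempfile

import torch

# Toy network standing in for the agent's internal modules.
net = torch.nn.Linear(4, 2)

path = os.path.join(tempfile.mkdtemp(), "model.pt")
torch.save(net.state_dict(), path)

# On a CPU-only machine, map_location remaps any CUDA storages to the CPU,
# so loading no longer fails when the weights were saved on a GPU.
state = torch.load(path, map_location=torch.device("cpu"))
net.load_state_dict(state)
```

The issue in the question is precisely that MushroomRL's internal loading code does not expose this argument.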
Replies: 1 comment
Unfortunately, you cannot do anything with the saved model as it is. You need to load it on a machine with a GPU, move the tensors back to the CPU, and then save the agent again.
You can access the torch network through the `network` attribute of the `TorchApproximator` and move the module to the CPU using standard torch code. You also need to set `use_cuda` to false.
If you are using a mushroom `Regressor` (the interface that wraps the approximator class, which is used by every RL algorithm in mushroom), you can access your network with:
`regressor.model.network`
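The move-to-CPU-and-resave step described above boils down to standard torch code. A sketch with a plain module standing in for `regressor.model.network` (in MushroomRL you would reach the real network through that attribute chain on the GPU machine; the exact attribute names and the `use_cuda` flag location are assumptions from the reply, not verified API):

```python
import os
import tempfile

import torch

# Stand-in for regressor.model.network on the GPU machine.
net = torch.nn.Linear(4, 2)
if torch.cuda.is_available():
    net.cuda()  # mimic a network whose weights live on the GPU

# Standard torch: move every parameter and buffer back to the CPU in place.
net.cpu()
use_cuda = False  # mirror of the approximator's use_cuda flag (assumed name)

# Re-save; the new checkpoint is now loadable on CPU-only machines.
path = os.path.join(tempfile.mkdtemp(), "agent_cpu.pt")
torch.save(net.state_dict(), path)
```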
Otherwise, you can modify the code of mushroom-rl at this line and force loading on the CPU:
https://github.com/MushroomRL/mushroom-rl/blob/b2a715221cf000417f2ad69fa6a…
You may also want to make sure that all the `use_cuda` variables are set properly. In future releases of mushroom-rl we might consider supporting safe loading of agents trained on a GPU; unfortunately, right now it is not possible.
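As an alternative to editing the library source at the line above, one possible workaround is to patch `torch.load` before loading the agent, so that every load maps storages to the CPU. This relies only on the documented `map_location` argument; whether it covers every load path inside MushroomRL is an assumption:

```python
import functools

import torch

# Force every subsequent torch.load to remap CUDA storages to the CPU.
# Callers that later pass map_location explicitly still override this default.
torch.load = functools.partial(torch.load, map_location=torch.device("cpu"))

# From here on, loading the saved agent should pull all tensors onto the CPU.
```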