I found an error when running the example pytorch-lightning/examples/fabric/reinforcement_learning on an M2 Mac (device type=mps).
Reproduce Error
reinforcement_learning git:(master) ✗ fabric run train_fabric.py
W0617 12:53:22.541000 8107367488 torch/distributed/elastic/multiprocessing/redirects.py:27] NOTE: Redirects are currently not supported in Windows or MacOs.
[rank: 0] Seed set to 42
Missing logger folder: logs/fabric_logs/2024-06-17_12-53-24/CartPole-v1_default_42_1718596404
set default torch dtype as torch.float32
Traceback (most recent call last):
File "/Users/user/workspace/pytorch-lightning/examples/fabric/reinforcement_learning/train_fabric.py", line 215, in <module>
main(args)
File "/Users/user/workspace/pytorch-lightning/examples/fabric/reinforcement_learning/train_fabric.py", line 154, in main
rewards[step] = torch.tensor(reward, device=device).view(-1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
The bug can be fixed by checking device.type and casting reward to torch.float32 when the device is MPS.
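Below is a minimal sketch of that fix. It assumes `reward` is the float64 NumPy array returned by a Gymnasium vector environment's `step()` and that `device` is the `torch.device` provided by Fabric. `torch.tensor()` keeps the NumPy dtype (float64) even when the default dtype is float32, which is why the MPS conversion fails. The helper name `reward_to_tensor` is hypothetical and does not appear in the example script.

```python
import numpy as np
import torch


def reward_to_tensor(reward: np.ndarray, device: torch.device) -> torch.Tensor:
    """Hypothetical helper: turn a vectorized-env reward array into a 1-D tensor.

    torch.tensor() infers float64 from a float64 NumPy array, and the MPS
    backend cannot store float64, so the dtype is forced to float32 there.
    """
    if device.type == "mps":
        # MPS has no float64 support: request float32 explicitly.
        return torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
    # Other devices (CPU, CUDA) keep the inferred dtype, as in the original code.
    return torch.tensor(reward, device=device).view(-1)
```

With this helper, line 154 of train_fabric.py could become `rewards[step] = reward_to_tensor(reward, device)`. Casting to float32 on every device would also avoid the error. The `device.type` check leaves behavior on other devices unchanged.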
What version are you seeing the problem on?
master
Environment
Current environment