We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0acb81f commit 9447852Copy full SHA for 9447852
beginner_source/quickstart/buildmodel_tutorial.py
@@ -80,7 +80,7 @@ def forward(self, x):
80
# Calling the model on the input returns a 10-dimensional tensor with raw predicted values for each class.
81
# We get the prediction probabilities by passing it through an instance of the ``nn.Softmax`` module.
82
83
-X = torch.rand(1, 28, 28)
+X = torch.rand(1, 28, 28, device=device)
84
logits = model(X)
85
pred_probab = nn.Softmax(dim=1)(logits)
86
y_pred = pred_probab.argmax(1)
0 commit comments