Skip to content

Commit 4b94a98

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2697120 commit 4b94a98

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed
Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
import os
22
import sys
3-
import torch
43
import time
54

65
import habana_frameworks.torch.core as htcore
7-
8-
from torch.utils.data import DataLoader
9-
from torchvision import transforms, datasets
6+
import torch
107
import torch.nn as nn
118
import torch.nn.functional as F
9+
from torch.utils.data import DataLoader
10+
from torchvision import datasets, transforms
11+
1212

1313
class Net(nn.Module):
1414
def __init__(self):
1515
super(Net, self).__init__()
16-
self.fc1 = nn.Linear(784, 256)
17-
self.fc2 = nn.Linear(256, 64)
18-
self.fc3 = nn.Linear(64, 10)
16+
self.fc1 = nn.Linear(784, 256)
17+
self.fc2 = nn.Linear(256, 64)
18+
self.fc3 = nn.Linear(64, 10)
19+
1920
def forward(self, x):
20-
out = x.view(-1,28*28)
21+
out = x.view(-1, 28 * 28)
2122
out = F.relu(self.fc1(out))
2223
out = F.relu(self.fc2(out))
2324
out = self.fc3(out)
2425
out = F.log_softmax(out, dim=1)
2526
return out
2627

28+
2729
model = Net()
2830
model_link = "https://vault.habana.ai/artifactory/misc/inference/mnist/mnist-epoch_20.pth"
2931
model_path = "/tmp/.neural_compressor/mnist-epoch_20.pth"
@@ -36,14 +38,12 @@ def forward(self, x):
3638
model = model.to("hpu")
3739

3840

39-
transform=transforms.Compose([
40-
transforms.ToTensor(),
41-
transforms.Normalize((0.1307,), (0.3081,))])
41+
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
4242

43-
data_path = './data'
44-
test_kwargs = {'batch_size': 32}
43+
data_path = "./data"
44+
test_kwargs = {"batch_size": 32}
4545
dataset1 = datasets.MNIST(data_path, train=False, download=True, transform=transform)
46-
test_loader = torch.utils.data.DataLoader(dataset1,**test_kwargs)
46+
test_loader = torch.utils.data.DataLoader(dataset1, **test_kwargs)
4747

4848
correct = 0
4949
for batch_idx, (data, label) in enumerate(test_loader):
@@ -56,4 +56,4 @@ def forward(self, x):
5656

5757
correct += output.max(1)[1].eq(label).sum()
5858

59-
print('Accuracy: {:.2f}%'.format(100. * correct / (len(test_loader) * 32)))
59+
print("Accuracy: {:.2f}%".format(100.0 * correct / (len(test_loader) * 32)))

0 commit comments

Comments
 (0)