1
1
import os
2
2
import sys
3
- import torch
4
3
import time
5
4
6
5
import habana_frameworks .torch .core as htcore
7
-
8
- from torch .utils .data import DataLoader
9
- from torchvision import transforms , datasets
6
+ import torch
10
7
import torch .nn as nn
11
8
import torch .nn .functional as F
9
+ from torch .utils .data import DataLoader
10
+ from torchvision import datasets , transforms
11
+
12
12
13
13
class Net (nn .Module ):
14
14
def __init__ (self ):
15
15
super (Net , self ).__init__ ()
16
- self .fc1 = nn .Linear (784 , 256 )
17
- self .fc2 = nn .Linear (256 , 64 )
18
- self .fc3 = nn .Linear (64 , 10 )
16
+ self .fc1 = nn .Linear (784 , 256 )
17
+ self .fc2 = nn .Linear (256 , 64 )
18
+ self .fc3 = nn .Linear (64 , 10 )
19
+
19
20
def forward (self , x ):
20
- out = x .view (- 1 ,28 * 28 )
21
+ out = x .view (- 1 , 28 * 28 )
21
22
out = F .relu (self .fc1 (out ))
22
23
out = F .relu (self .fc2 (out ))
23
24
out = self .fc3 (out )
24
25
out = F .log_softmax (out , dim = 1 )
25
26
return out
26
27
28
+
27
29
model = Net ()
28
30
model_link = "https://vault.habana.ai/artifactory/misc/inference/mnist/mnist-epoch_20.pth"
29
31
model_path = "/tmp/.neural_compressor/mnist-epoch_20.pth"
@@ -36,14 +38,12 @@ def forward(self, x):
36
38
model = model .to ("hpu" )
37
39
38
40
39
- transform = transforms .Compose ([
40
- transforms .ToTensor (),
41
- transforms .Normalize ((0.1307 ,), (0.3081 ,))])
41
+ transform = transforms .Compose ([transforms .ToTensor (), transforms .Normalize ((0.1307 ,), (0.3081 ,))])
42
42
43
- data_path = ' ./data'
44
- test_kwargs = {' batch_size' : 32 }
43
+ data_path = " ./data"
44
+ test_kwargs = {" batch_size" : 32 }
45
45
dataset1 = datasets .MNIST (data_path , train = False , download = True , transform = transform )
46
- test_loader = torch .utils .data .DataLoader (dataset1 ,** test_kwargs )
46
+ test_loader = torch .utils .data .DataLoader (dataset1 , ** test_kwargs )
47
47
48
48
correct = 0
49
49
for batch_idx , (data , label ) in enumerate (test_loader ):
@@ -56,4 +56,4 @@ def forward(self, x):
56
56
57
57
correct += output .max (1 )[1 ].eq (label ).sum ()
58
58
59
- print (' Accuracy: {:.2f}%' .format (100. * correct / (len (test_loader ) * 32 )))
59
+ print (" Accuracy: {:.2f}%" .format (100.0 * correct / (len (test_loader ) * 32 )))
0 commit comments