Skip to content

Commit

Permalink
add vgg
Browse files Browse the repository at this point in the history
  • Loading branch information
changqian committed Nov 1, 2019
1 parent 44460cb commit 5c7ea82
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
37 changes: 37 additions & 0 deletions vgg/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
from torch import nn
import torchvision
import os
import struct
from torchsummary import summary

def main():
print('cuda device count: ', torch.cuda.device_count())
net = torch.load('vgg.pth')
net = net.to('cuda:0')
net = net.eval()
print('model: ', net)
#print('state dict: ', net.state_dict().keys())
tmp = torch.ones(1, 3, 224, 224).to('cuda:0')
print('input: ', tmp)
out = net(tmp)

print('output:', out)

summary(net, (3, 224, 224))
#return
f = open("vgg.wts", 'w')
f.write("{}\n".format(len(net.state_dict().keys())))
for k,v in net.state_dict().items():
print('key: ', k)
print('value: ', v.shape)
vr = v.reshape(-1).cpu().numpy()
f.write("{} {}".format(k, len(vr)))
for vv in vr:
f.write(" ")
f.write(struct.pack(">f", float(vv)).hex())
f.write("\n")

if __name__ == '__main__':
main()

20 changes: 20 additions & 0 deletions vgg/vgg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import torch
from torch import nn
from torch.nn import functional as F
import torchvision

def main():
print('cuda device count: ', torch.cuda.device_count())
net = torchvision.models.vgg11(pretrained=True)
#net.fc = nn.Linear(512, 2)
net = net.eval()
net = net.to('cuda:1')
print(net)
tmp = torch.ones(2, 3, 224, 224).to('cuda:1')
out = net(tmp)
print('vgg out:', out.shape)
torch.save(net, "vgg.pth")

if __name__ == '__main__':
main()

0 comments on commit 5c7ea82

Please sign in to comment.