Skip to content

Commit

Permalink
Review comments fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Jul 18, 2018
1 parent c8c37f1 commit 2aa9e26
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 40 deletions.
13 changes: 5 additions & 8 deletions nnvm/tests/python/frontend/darknet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,20 @@
"""
import os
import requests
import sys
import urllib
import numpy as np
import tvm
from tvm.contrib import graph_runtime
from nnvm import frontend
from nnvm.testing.darknet import __darknetffi__
import nnvm.compiler
import tvm
import sys
import urllib
if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2


def _download(url, path, overwrite=False, sizecompare=False):
''' Download from internet'''
if os.path.isfile(path) and not overwrite:
Expand Down Expand Up @@ -55,13 +57,8 @@ def _get_tvm_output(net, data):

target = 'llvm'
shape_dict = {'data': data.shape}
#with nnvm.compiler.build_config(opt_level=2):
graph, library, params = nnvm.compiler.build(sym, target, shape_dict, dtype, params=params)
######################################################################
# Execute on TVM
# ---------------
# The process is no different from other examples.
from tvm.contrib import graph_runtime
ctx = tvm.cpu(0)
m = graph_runtime.create(graph, library, ctx)
# set inputs
Expand Down
44 changes: 12 additions & 32 deletions tutorials/nnvm/nlp/from_darknet_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,34 +25,23 @@
import urllib
import requests
import numpy as np
import urllib.request as urllib2
import tvm
from tvm.contrib import graph_runtime
#from tvm.contrib.debugger import debug_runtime as graph_runtime
from nnvm.testing.darknet import __darknetffi__
import nnvm
import nnvm.frontend.darknet

if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2

######################################################################
# Prepare cfg and weights file
# Pretrained model available
# --------------------------------------------------------------------
# Download cfg and weights file first time.
MODEL_NAME = 'rnn'

MODEL_NAME = 'rnn' #Model name
seed = 'Thus' #Seed value
num = 1000 #Number of characters to predict

# Download cfg and weights file if first time.
CFG_NAME = MODEL_NAME + '.cfg'
WEIGHTS_NAME = MODEL_NAME + '.weights'
CFG_URL = 'https://github.com/siju-samuel/darknet/blob/master/cfg/' + \
CFG_NAME + '?raw=true'
WEIGHTS_URL = 'https://github.com/siju-samuel/darknet/blob/master/weights/' + \
WEIGHTS_NAME + '?raw=true'
REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/'
CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true'
WEIGHTS_URL = REPO_URL + 'weights/' + WEIGHTS_NAME + '?raw=true'

def _dl_progress(count, block_size, total_size):
"""Show the download progress."""
Expand Down Expand Up @@ -95,13 +84,9 @@ def _download(url, path, overwrite=False, sizecompare=False):
_download(CFG_URL, CFG_NAME)
_download(WEIGHTS_URL, WEIGHTS_NAME)

######################################################################
# Download and Load darknet library
# ---------------------------------

DARKNET_LIB = 'libdarknet.so'
DARKNET_URL = 'https://github.com/siju-samuel/darknet/blob/master/lib/' + \
DARKNET_LIB + '?raw=true'
DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true'
_download(DARKNET_URL, DARKNET_LIB)
DARKNET_LIB = __darknetffi__.dlopen('./' + DARKNET_LIB)
cfg = "./" + str(CFG_NAME)
Expand All @@ -112,11 +97,7 @@ def _download(url, path, overwrite=False, sizecompare=False):
print("Converting darknet rnn model to nnvm symbols...")
sym, params = nnvm.frontend.darknet.from_darknet(net, dtype)

######################################################################
# Compile the model on NNVM
# --------------------------------------------------------------------
# compile the model

data = np.empty([1, net.inputs], dtype)#net.inputs

target = 'llvm'
Expand All @@ -129,9 +110,7 @@ def _download(url, path, overwrite=False, sizecompare=False):
with nnvm.compiler.build_config(opt_level=2):
graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype_dict, params)

#####################################################################
# Save the json
# --------------------------------------------------------------------
def _save_lib():
'''Save the graph, params and .so to the current directory'''
print("Saving the compiled output...")
Expand All @@ -144,19 +123,17 @@ def _save_lib():
fo.write(nnvm.compiler.save_param_dict(params))
#_save_lib()

######################################################################
# Execute on TVM
# --------------------------------------------------------------------
# The process is no different from other examples.

ctx = tvm.cpu(0)

# Create graph runtime
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**params)

print("RNN generaring text...")

def _proc_rnn_output(out_data):
'''Generate the characters from the output array'''
sum_array = 0
n = out_data.size
r = random.uniform(0, 1)
Expand All @@ -173,6 +150,7 @@ def _proc_rnn_output(out_data):
return n-1

def _init_state_memory(rnn_cells_count, dtype):
'''Initialize memory for states'''
states = {}
state_shape = (1024,)
for i in range(rnn_cells_count):
Expand All @@ -181,10 +159,12 @@ def _init_state_memory(rnn_cells_count, dtype):
return states

def _set_state_input(runtime, states):
'''Set the state inputs'''
for state in states:
runtime.set_input(state, states[state])

def _get_state_output(runtime, states):
'''Get the state outputs and save'''
i = 1
for state in states:
data = states[state]
Expand Down

0 comments on commit 2aa9e26

Please sign in to comment.