diff --git a/nnvm/python/nnvm/frontend/darknet.py b/nnvm/python/nnvm/frontend/darknet.py index 878b31846406..49d608137212 100644 --- a/nnvm/python/nnvm/frontend/darknet.py +++ b/nnvm/python/nnvm/frontend/darknet.py @@ -227,8 +227,7 @@ def _darknet_dense(inputs, attrs): op_name, new_attrs = 'dense', {} new_attrs['units'] = _darknet_required_attr(attrs, 'num_hidden') out_name = {} - if attrs.get('use_bias', False) is True: - new_attrs['use_bias'] = True + new_attrs['use_bias'] = attrs.get('use_bias', False) if attrs.get('use_flatten', False) is True: inputs[0] = _sym.flatten(inputs[0]) sym = _darknet_get_nnvm_op(op_name)(*inputs, **new_attrs) @@ -397,235 +396,301 @@ def _as_list(arr): return arr return [arr] -def _read_memory_buffer(shape, data, dtype): - length = 1 - for x in shape: - length *= x - data_np = np.zeros(length, dtype=dtype) - for i in range(length): - data_np[i] = data[i] - return data_np.reshape(shape) - -def _get_convolution_weights(layer, opname, params, dtype): - """Get the convolution layer weights and biases.""" - if layer.nweights == 0: - return - if (layer.n * layer.c * layer.size * layer.size) != layer.nweights: - raise RuntimeError("layer weights size not matching with n c h w") +class GraphProto(object): + """A helper class for handling nnvm graph copying from darknet model. + """ - weights = _read_memory_buffer((layer.n, layer.c, layer.size, layer.size), layer.weights, dtype) + def __init__(self, net, dtype='float32'): + self.net = net + self.dtype = dtype + self._sym_array = {} + self._tvmparams = {} + self._outs = [] + self._rnn_state_ctr = 0 - biases = _read_memory_buffer((layer.n, ), layer.biases, dtype) + def _read_memory_buffer(self, shape, data): + length = 1 + for x in shape: + length *= x + data_np = np.zeros(length, dtype=self.dtype) + for i in range(length): + data_np[i] = data[i] + return data_np.reshape(shape) - k = _get_tvm_params_name(opname[0], 'weight') - params[k] = tvm.nd.array(weights) + def _get_convolution_weights(self, layer, opname): + """Get the convolution layer weights and biases.""" + if layer.nweights == 0: + return - if layer.batch_normalize == 1 and layer.dontloadscales != 1: - _get_batchnorm_weights(layer, opname[1], params, layer.n, dtype) - k = _get_tvm_params_name(opname[1], 'beta') - params[k] = tvm.nd.array(biases) - else: - k = _get_tvm_params_name(opname[0], 'bias') - params[k] = tvm.nd.array(biases) + if (layer.n * layer.c * layer.size * layer.size) != layer.nweights: + raise RuntimeError("layer weights size not matching with n c h w") -def _get_connected_weights(layer, opname, params, dtype): - """Parse the weights and biases for fully connected or dense layer.""" - size = layer.outputs * layer.inputs - if size == 0: - return + shape = (layer.n, layer.c, layer.size, layer.size) + weights = self._read_memory_buffer(shape, layer.weights) - weights = _read_memory_buffer((layer.outputs, layer.inputs), layer.weights, dtype) - biases = _read_memory_buffer((layer.outputs, ), layer.biases, dtype) + biases = self._read_memory_buffer((layer.n, ), layer.biases) - k = _get_tvm_params_name(opname[0], 'weight') - params[k] = tvm.nd.array(weights) + k = self._get_tvm_params_name(opname[0], 'weight') + self._tvmparams[k] = tvm.nd.array(weights) - if layer.batch_normalize == 1 and layer.dontloadscales != 1: - _get_batchnorm_weights(layer, opname[1], params, layer.outputs, dtype) - k = _get_tvm_params_name(opname[1], 'beta') - params[k] = tvm.nd.array(biases) - else: - k = _get_tvm_params_name(opname[0], 'bias') - params[k] = tvm.nd.array(biases) - -def _get_batchnorm_weights(layer, opname, params, size, dtype): - """Parse the weights for batchnorm, which includes, scales, moving mean - and moving variances.""" - scales = _read_memory_buffer((size, ), layer.scales, dtype) - rolling_mean = _read_memory_buffer((size, ), layer.rolling_mean, dtype) - rolling_variance = _read_memory_buffer((size, ), layer.rolling_variance, dtype) - - k = _get_tvm_params_name(opname, 'moving_mean') - params[k] = tvm.nd.array(rolling_mean) - k = _get_tvm_params_name(opname, 'moving_var') - params[k] = tvm.nd.array(rolling_variance) - k = _get_tvm_params_name(opname, 'gamma') - params[k] = tvm.nd.array(scales) - -def _get_darknet_attrs(net, layer_num): - """Parse attributes of each layer and return.""" - attr = {} - use_flatten = True - layer = net.layers[layer_num] - if LAYERTYPE.CONVOLUTIONAL == layer.type: - attr.update({'layout' : 'NCHW'}) - attr.update({'pad' : str(layer.pad)}) - attr.update({'num_group' : str(layer.groups)}) - attr.update({'num_filter' : str(layer.n)}) - attr.update({'stride' : str(layer.stride)}) - attr.update({'kernel' : str(layer.size)}) - attr.update({'activation' : (layer.activation)}) - - if layer.nbiases == 0: - attr.update({'use_bias' : False}) + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + self._get_batchnorm_weights(layer, opname[1], layer.n) + k = self._get_tvm_params_name(opname[1], 'beta') + self._tvmparams[k] = tvm.nd.array(biases) else: - attr.update({'use_bias' : True}) + k = self._get_tvm_params_name(opname[0], 'bias') + self._tvmparams[k] = tvm.nd.array(biases) + + def _get_connected_weights(self, layer, opname): + """Parse the weights and biases for fully connected or dense layer.""" + size = layer.outputs * layer.inputs + if size == 0: + return + + weights = self._read_memory_buffer((layer.outputs, layer.inputs), layer.weights) + biases = self._read_memory_buffer((layer.outputs, ), layer.biases) + + k = self._get_tvm_params_name(opname[0], 'weight') + self._tvmparams[k] = tvm.nd.array(weights) if layer.batch_normalize == 1 and layer.dontloadscales != 1: - attr.update({'use_batchNorm' : True}) - attr.update({'use_scales' : True}) - - #elif LAYERTYPE.BATCHNORM == layer.type: - # attr.update({'flatten' : str('True')}) - - elif LAYERTYPE.CONNECTED == layer.type: - attr.update({'num_hidden' : str(layer.outputs)}) - attr.update({'activation' : (layer.activation)}) - if layer_num != 0: - layer_prev = net.layers[layer_num - 1] - if (layer_prev.out_h == layer.h and - layer_prev.out_w == layer.w and - layer_prev.out_c == layer.c): - use_flatten = False - attr.update({'use_flatten' : use_flatten}) - if layer.nbiases == 0: - attr.update({'use_bias' : False}) + self._get_batchnorm_weights(layer, opname[1], layer.outputs) + k = self._get_tvm_params_name(opname[1], 'beta') + self._tvmparams[k] = tvm.nd.array(biases) else: + k = self._get_tvm_params_name(opname[0], 'bias') + self._tvmparams[k] = tvm.nd.array(biases) + + def _get_batchnorm_weights(self, layer, opname, size): + """Parse the weights for batchnorm, which includes, scales, moving mean + and moving variances.""" + scales = self._read_memory_buffer((size, ), layer.scales) + rolling_mean = self._read_memory_buffer((size, ), layer.rolling_mean) + rolling_variance = self._read_memory_buffer((size, ), layer.rolling_variance) + + k = self._get_tvm_params_name(opname, 'moving_mean') + self._tvmparams[k] = tvm.nd.array(rolling_mean) + k = self._get_tvm_params_name(opname, 'moving_var') + self._tvmparams[k] = tvm.nd.array(rolling_variance) + k = self._get_tvm_params_name(opname, 'gamma') + self._tvmparams[k] = tvm.nd.array(scales) + + def _get_darknet_attrs(self, layer, layer_num): + """Parse attributes of each layer and return.""" + attr = {} + use_flatten = True + if LAYERTYPE.CONVOLUTIONAL == layer.type: + attr.update({'layout' : 'NCHW'}) + attr.update({'pad' : str(layer.pad)}) + attr.update({'num_group' : str(layer.groups)}) + attr.update({'num_filter' : str(layer.n)}) + attr.update({'stride' : str(layer.stride)}) + attr.update({'kernel' : str(layer.size)}) + attr.update({'activation' : (layer.activation)}) + + if layer.nbiases == 0: + attr.update({'use_bias' : False}) + else: + attr.update({'use_bias' : True}) + + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + attr.update({'use_batchNorm' : True}) + attr.update({'use_scales' : True}) + + elif LAYERTYPE.CONNECTED == layer.type: + attr.update({'num_hidden' : str(layer.outputs)}) + attr.update({'activation' : (layer.activation)}) + if layer_num != 0: + layer_prev = self.net.layers[layer_num - 1] + if (layer_prev.out_h == layer.h and + layer_prev.out_w == layer.w and + layer_prev.out_c == layer.c): + use_flatten = False + attr.update({'use_flatten' : use_flatten}) attr.update({'use_bias' : True}) - if layer.batch_normalize == 1 and layer.dontloadscales != 1: - attr.update({'use_batchNorm' : True}) - attr.update({'use_scales' : True}) - - elif LAYERTYPE.MAXPOOL == layer.type: - attr.update({'pad' : str(layer.pad)}) - attr.update({'stride' : str(layer.stride)}) - attr.update({'kernel' : str(layer.size)}) - max_output = (layer.w - layer.size + 2 * layer.pad)/float(layer.stride) + 1 - if max_output < layer.out_w: - extra_pad = (layer.out_w - max_output)*layer.stride - attr.update({'extra_pad_size' : int(extra_pad)}) - elif LAYERTYPE.AVGPOOL == layer.type: - attr.update({'pad' : str(layer.pad)}) - if layer.stride == 0: - attr.update({'stride' : str(1)}) - else: + if layer.batch_normalize == 1 and layer.dontloadscales != 1: + attr.update({'use_batchNorm' : True}) + attr.update({'use_scales' : True}) + attr.update({'use_bias' : False}) + + elif LAYERTYPE.MAXPOOL == layer.type: + attr.update({'pad' : str(layer.pad)}) attr.update({'stride' : str(layer.stride)}) - if layer.size == 0 and layer.h == layer.w: - attr.update({'kernel' : str(layer.h)}) - else: attr.update({'kernel' : str(layer.size)}) + max_output = (layer.w - layer.size + 2 * layer.pad)/float(layer.stride) + 1 + if max_output < layer.out_w: + extra_pad = (layer.out_w - max_output)*layer.stride + attr.update({'extra_pad_size' : int(extra_pad)}) + elif LAYERTYPE.AVGPOOL == layer.type: + attr.update({'pad' : str(layer.pad)}) + if layer.stride == 0: + attr.update({'stride' : str(1)}) + else: + attr.update({'stride' : str(layer.stride)}) + if layer.size == 0 and layer.h == layer.w: + attr.update({'kernel' : str(layer.h)}) + else: + attr.update({'kernel' : str(layer.size)}) + + elif LAYERTYPE.DROPOUT == layer.type: + attr.update({'p' : str(layer.probability)}) + + elif LAYERTYPE.SOFTMAX == layer.type: + attr.update({'axis' : 1}) + attr.update({'use_flatten' : True}) + if layer.temperature: + attr.update({'temperature' : str(layer.temperature)}) + + elif LAYERTYPE.SHORTCUT == layer.type: + add_layer = self.net.layers[layer.index] + attr.update({'activation' : (layer.activation)}) + attr.update({'out_channel' : (layer.out_c)}) + attr.update({'out_size' : (layer.out_h)}) + attr.update({'add_out_channel' : (add_layer.out_c)}) + attr.update({'add_out_size' : (add_layer.out_h)}) + + elif LAYERTYPE.ROUTE == layer.type: + pass + + elif LAYERTYPE.COST == layer.type: + pass + + elif LAYERTYPE.REORG == layer.type: + attr.update({'stride' : layer.stride}) + + elif LAYERTYPE.REGION == layer.type: + attr.update({'n' : layer.n}) + attr.update({'classes' : layer.classes}) + attr.update({'coords' : layer.coords}) + attr.update({'background' : layer.background}) + attr.update({'softmax' : layer.softmax}) + else: + err = "Darknet layer type {} is not supported in nnvm.".format(layer.type) + raise NotImplementedError(err) - elif LAYERTYPE.DROPOUT == layer.type: - attr.update({'p' : str(layer.probability)}) - - elif LAYERTYPE.SOFTMAX == layer.type: - attr.update({'axis' : 1}) - attr.update({'use_flatten' : True}) - if layer.temperature: - attr.update({'temperature' : str(layer.temperature)}) - - elif LAYERTYPE.SHORTCUT == layer.type: - add_layer = net.layers[layer.index] - attr.update({'activation' : (layer.activation)}) - attr.update({'out_channel' : (layer.out_c)}) - attr.update({'out_size' : (layer.out_h)}) - attr.update({'add_out_channel' : (add_layer.out_c)}) - attr.update({'add_out_size' : (add_layer.out_h)}) - - elif LAYERTYPE.ROUTE == layer.type: - pass - - elif LAYERTYPE.COST == layer.type: - pass - - elif LAYERTYPE.REORG == layer.type: - attr.update({'stride' : layer.stride}) - - elif LAYERTYPE.REGION == layer.type: - attr.update({'n' : layer.n}) - attr.update({'classes' : layer.classes}) - attr.update({'coords' : layer.coords}) - attr.update({'background' : layer.background}) - attr.update({'softmax' : layer.softmax}) - else: - err = "Darknet layer type {} is not supported in nnvm.".format(layer.type) - raise NotImplementedError(err) - - return layer.type, attr - -def _get_tvm_params_name(opname, arg_name): - """Makes the params name for the k,v pair.""" - return opname + '_'+ arg_name - -def _get_darknet_params(layer, opname, tvmparams, dtype='float32'): - """To parse and get the darknet params.""" - if LAYERTYPE.CONVOLUTIONAL == layer.type: - _get_convolution_weights(layer, opname, tvmparams, dtype) - - #elif LAYERTYPE.BATCHNORM == layer.type: - # size = layer.outputs - # _get_batchnorm_weights(layer, opname, tvmparams, size, dtype) - - elif LAYERTYPE.CONNECTED == layer.type: - _get_connected_weights(layer, opname, tvmparams, dtype) - -def _preproc_layer(net, i, sym_array): - """To preprocess each darknet layer, some layer doesnt need processing.""" - layer = net.layers[i] - if i == 0: - name = 'data' - attribute = {} - sym = [_sym.Variable(name, **attribute)] - else: - sym = sym_array[i - 1] - skip_layer = False - - if LAYERTYPE.ROUTE == layer.type: - sym = [] - for j in range(layer.n): - sym.append(sym_array[layer.input_layers[j]]) - if layer.n == 1: + return attr + + def _get_tvm_params_name(self, opname, arg_name): + """Makes the params name for the k,v pair.""" + return opname + '_'+ arg_name + + def _get_darknet_params(self, layer, opname): + """To parse and get the darknet params.""" + if LAYERTYPE.CONVOLUTIONAL == layer.type: + self._get_convolution_weights(layer, opname) + + elif LAYERTYPE.CONNECTED == layer.type: + self._get_connected_weights(layer, opname) + + def _preproc_layer(self, layer, layer_num): + """To preprocess each darknet layer, some layer doesnt need processing.""" + if layer_num == 0: + name = 'data' + attribute = {} + sym = [_sym.Variable(name, **attribute)] + else: + sym = self._sym_array[layer_num - 1] + skip_layer = False + + if LAYERTYPE.ROUTE == layer.type: + sym = [] + for j in range(layer.n): + sym.append(self._sym_array[layer.input_layers[j]]) + if layer.n == 1: + skip_layer = True + + elif LAYERTYPE.COST == layer.type: skip_layer = True - elif LAYERTYPE.COST == layer.type: - skip_layer = True + elif LAYERTYPE.SHORTCUT == layer.type: + sym = [sym, self._sym_array[layer.index]] - elif LAYERTYPE.SHORTCUT == layer.type: - sym = [sym, sym_array[layer.index]] + elif LAYERTYPE.BLANK == layer.type: + skip_layer = True - elif LAYERTYPE.BLANK == layer.type: - skip_layer = True + if skip_layer is True: + self._sym_array[layer_num] = sym - if skip_layer is True: - sym_array[i] = sym + return skip_layer, sym - return skip_layer, sym + def _get_opname(self, layer): + """Returs the layer name.""" + return layer.type -def _from_darknet(net, dtype='float32'): - """To convert the darknet symbol to nnvm symbols.""" - sym_array = {} - tvmparams = {} - for i in range(net.n): - need_skip, sym = _preproc_layer(net, i, sym_array) - if need_skip is True: - continue - op_name, attr = _get_darknet_attrs(net, i) - layer_name, sym = _darknet_convert_symbol(op_name, _as_list(sym), attr) - _get_darknet_params(net.layers[i], layer_name, tvmparams, dtype) - sym_array[i] = sym + def _new_rnn_state_sym(self, state=None): + """Returs a symbol for state""" + name = "rnn%d_state" % (self._rnn_state_ctr) + self._rnn_state_ctr += 1 + return _sym.Variable(name=name, init=state) + + def _get_rnn_state_buffer(self, layer): + """Get the state buffer for rnn.""" + buffer = np.zeros((1, layer.outputs), self.dtype) + return self._new_rnn_state_sym(buffer) - return sym, tvmparams + def _get_darknet_rnn_attrs(self, layer, sym): + """Get the rnn converted symbol from attributes.""" + attr = self._get_darknet_attrs(layer, 0) + op_name = self._get_opname(layer) + layer_name, sym = _darknet_convert_symbol(op_name, _as_list(sym), attr) + self._get_darknet_params(layer, layer_name) + return sym + + def _handle_darknet_rnn_layers(self, layer_num, sym): + """Parse attributes and handle the rnn layers.""" + attr = {} + layer = self.net.layers[layer_num] + processed = False + + if LAYERTYPE.RNN == layer.type: + attr.update({'n' : layer.n}) + attr.update({'batch' : layer.batch}) + attr.update({'num_hidden' : str(layer.outputs)}) + + state = self._get_rnn_state_buffer(layer) + + for _ in range(layer.steps): + input_layer = layer.input_layer + sym = self._get_darknet_rnn_attrs(input_layer, sym) + + self_layer = layer.self_layer + state = self._get_darknet_rnn_attrs(self_layer, state) + + op_name, new_attrs = 'elemwise_add', {} + new_inputs = _as_list([sym, state]) + state = _darknet_get_nnvm_op(op_name)(*new_inputs, **new_attrs) + self._outs.append(state) + + output_layer = layer.output_layer + sym = self._get_darknet_rnn_attrs(output_layer, state) + + self._sym_array[layer_num] = sym + processed = True + + return processed, sym + + def from_darknet(self): + """To convert the darknet symbol to nnvm symbols.""" + for i in range(self.net.n): + layer = self.net.layers[i] + need_skip, sym = self._preproc_layer(layer, i) + if need_skip is True: + continue + + processed, sym = self._handle_darknet_rnn_layers(i, sym) + if processed is True: + continue + + attr = self._get_darknet_attrs(layer, i) + op_name = self._get_opname(layer) + layer_name, sym = _darknet_convert_symbol(op_name, _as_list(sym), attr) + self._get_darknet_params(self.net.layers[i], layer_name) + self._sym_array[i] = sym + self._outs = _as_list(sym) + self._outs + if isinstance(self._outs, list): + sym = _sym.Group(self._outs) + return sym, self._tvmparams def from_darknet(net, dtype='float32'): """Convert from darknet's model into compatible NNVM format. @@ -648,4 +713,4 @@ def from_darknet(net, dtype='float32'): The parameter dict to be used by nnvm """ - return _from_darknet(net, dtype) + return GraphProto(net, dtype).from_darknet() diff --git a/nnvm/python/nnvm/testing/darknet.py b/nnvm/python/nnvm/testing/darknet.py index f6bbf8f7d951..362fd3058954 100644 --- a/nnvm/python/nnvm/testing/darknet.py +++ b/nnvm/python/nnvm/testing/darknet.py @@ -479,6 +479,7 @@ class ACTIVATION(object): void free_image(image m); image load_image_color(char *filename, int w, int h); float *network_predict_image(network *net, image im); +float *network_predict(network *net, float *input); network *make_network(int n); layer make_convolutional_layer(int batch, int h, int w, int c, int n, int groups, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam); layer make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation, int batch_normalize, int adam); @@ -488,6 +489,8 @@ class ACTIVATION(object): layer make_batchnorm_layer(int batch, int w, int h, int c); layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse, int flatten, int extra); layer make_region_layer(int batch, int w, int h, int n, int classes, int coords); +layer make_softmax_layer(int batch, int inputs, int groups); +layer make_rnn_layer(int batch, int inputs, int outputs, int steps, ACTIVATION activation, int batch_normalize, int adam); void free_network(network *net); """ ) diff --git a/nnvm/tests/python/frontend/darknet/test_forward.py b/nnvm/tests/python/frontend/darknet/test_forward.py index 0abc595c426f..a8f7ccdeffd7 100644 --- a/nnvm/tests/python/frontend/darknet/test_forward.py +++ b/nnvm/tests/python/frontend/darknet/test_forward.py @@ -7,18 +7,20 @@ """ import os import requests +import sys +import urllib import numpy as np +import tvm +from tvm.contrib import graph_runtime from nnvm import frontend from nnvm.testing.darknet import __darknetffi__ import nnvm.compiler -import tvm -import sys -import urllib if sys.version_info >= (3,): import urllib.request as urllib2 else: import urllib2 + def _download(url, path, overwrite=False, sizecompare=False): ''' Download from internet''' if os.path.isfile(path) and not overwrite: @@ -48,43 +50,31 @@ def _download(url, path, overwrite=False, sizecompare=False): _download(DARKNETLIB_URL, DARKNET_LIB) LIB = __darknetffi__.dlopen('./' + DARKNET_LIB) +def _get_tvm_output(net, data): + '''Compute TVM output''' + dtype = 'float32' + sym, params = frontend.darknet.from_darknet(net, dtype) + + target = 'llvm' + shape_dict = {'data': data.shape} + graph, library, params = nnvm.compiler.build(sym, target, shape_dict, dtype, params=params) + # Execute on TVM + ctx = tvm.cpu(0) + m = graph_runtime.create(graph, library, ctx) + # set inputs + m.set_input('data', tvm.nd.array(data.astype(dtype))) + m.set_input(**params) + m.run() + # get outputs + out_shape = (net.outputs,) + tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy() + return tvm_out + def test_forward(net): '''Test network with given input image on both darknet and tvm''' def get_darknet_output(net, img): return LIB.network_predict_image(net, img) - - def get_tvm_output(net, img): - '''Compute TVM output''' - dtype = 'float32' - batch_size = 1 - sym, params = frontend.darknet.from_darknet(net, dtype) - data = np.empty([batch_size, img.c, img.h, img.w], dtype) - i = 0 - for c in range(img.c): - for h in range(img.h): - for k in range(img.w): - data[0][c][h][k] = img.data[i] - i = i + 1 - - target = 'llvm' - shape_dict = {'data': data.shape} - #with nnvm.compiler.build_config(opt_level=2): - graph, library, params = nnvm.compiler.build(sym, target, shape_dict, dtype, params=params) - ###################################################################### - # Execute on TVM - # --------------- - # The process is no different from other examples. - from tvm.contrib import graph_runtime - ctx = tvm.cpu(0) - m = graph_runtime.create(graph, library, ctx) - # set inputs - m.set_input('data', tvm.nd.array(data.astype(dtype))) - m.set_input(**params) - m.run() - # get outputs - out_shape = (net.outputs,) - tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy() - return tvm_out + dtype = 'float32' test_image = 'dog.jpg' img_url = 'https://github.com/siju-samuel/darknet/blob/master/data/' + test_image +'?raw=true' @@ -94,9 +84,35 @@ def get_tvm_output(net, img): darknet_out = np.zeros(net.outputs, dtype='float32') for i in range(net.outputs): darknet_out[i] = darknet_output[i] - tvm_out = get_tvm_output(net, img) + batch_size = 1 + + data = np.empty([batch_size, img.c, img.h, img.w], dtype) + i = 0 + for c in range(img.c): + for h in range(img.h): + for k in range(img.w): + data[0][c][h][k] = img.data[i] + i = i + 1 + + tvm_out = _get_tvm_output(net, data) np.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-3, atol=1e-3) +def test_rnn_forward(net): + '''Test network with given input data on both darknet and tvm''' + def get_darknet_network_predict(net, data): + return LIB.network_predict(net, data) + from cffi import FFI + ffi = FFI() + np_arr = np.zeros([1, net.inputs], dtype='float32') + np_arr[0, 84] = 1 + cffi_arr = ffi.cast('float*', np_arr.ctypes.data) + tvm_out = _get_tvm_output(net, np_arr) + darknet_output = get_darknet_network_predict(net, cffi_arr) + darknet_out = np.zeros(net.outputs, dtype='float32') + for i in range(net.outputs): + darknet_out[i] = darknet_output[i] + np.testing.assert_allclose(darknet_out, tvm_out, rtol=1e-4, atol=1e-4) + def test_forward_extraction(): '''test extraction model''' model_name = 'extraction' @@ -289,6 +305,25 @@ def test_forward_softmax_temperature(): test_forward(net) LIB.free_network(net) +def test_forward_rnn(): + '''test softmax layer''' + net = LIB.make_network(1) + batch = 1 + inputs = 256 + outputs = 256 + steps = 1 + activation = 1 + batch_normalize = 0 + adam = 0 + layer_1 = LIB.make_rnn_layer(batch, inputs, outputs, steps, activation, batch_normalize, adam) + net.layers[0] = layer_1 + net.inputs = inputs + net.outputs = outputs + net.w = net.h = 0 + LIB.resize_network(net, net.w, net.h) + test_rnn_forward(net) + LIB.free_network(net) + if __name__ == '__main__': test_forward_resnet50() test_forward_alexnet() @@ -303,6 +338,7 @@ def test_forward_softmax_temperature(): test_forward_dense_batchnorm() test_forward_softmax() test_forward_softmax_temperature() + test_forward_rnn() test_forward_reorg() test_forward_region() test_forward_elu() diff --git a/tutorials/nnvm/nlp/from_darknet_rnn.py b/tutorials/nnvm/nlp/from_darknet_rnn.py new file mode 100644 index 000000000000..54013f04fca6 --- /dev/null +++ b/tutorials/nnvm/nlp/from_darknet_rnn.py @@ -0,0 +1,184 @@ +""" +Compile Darknet Models for RNN +============================== +**Author**: `Siju Samuel `_ + +This article is an introductory tutorial to deploy darknet rnn models with NNVM. + +This script will run a character prediction model +Each module consists of 3 fully-connected layers. The input layer propagates information from the +input to the current state. The recurrent layer propagates information through time from the +previous state to the current one. + +The input to the network is a 1-hot encoding of ASCII characters. We train the network to predict +the next character in a stream of characters. The output is constrained to be a probability +distribution using a softmax layer. + +Since each recurrent layer contains information about the current character and the past +characters, it can use this context to predict the future characters in a word or phrase. + +All the required models and libraries will be downloaded from the internet +by the script. +""" +import random +import numpy as np +from mxnet.gluon.utils import download +import tvm +from tvm.contrib import graph_runtime +from nnvm.testing.darknet import __darknetffi__ +import nnvm +import nnvm.frontend.darknet + +# Set the parameters +# ----------------------- +# Set the seed value and the number of characters to predict + +#Model name +MODEL_NAME = 'rnn' +#Seed value +seed = 'Thus' +#Number of characters to predict +num = 1000 + +# Download required files +# ----------------------- +# Download cfg and weights file if first time. +CFG_NAME = MODEL_NAME + '.cfg' +WEIGHTS_NAME = MODEL_NAME + '.weights' +REPO_URL = 'https://github.com/dmlc/web-data/blob/master/darknet/' +CFG_URL = REPO_URL + 'cfg/' + CFG_NAME + '?raw=true' +WEIGHTS_URL = REPO_URL + 'weights/' + WEIGHTS_NAME + '?raw=true' + +download(CFG_URL, CFG_NAME) +download(WEIGHTS_URL, WEIGHTS_NAME) + +# Download and Load darknet library +DARKNET_LIB = 'libdarknet.so' +DARKNET_URL = REPO_URL + 'lib/' + DARKNET_LIB + '?raw=true' +download(DARKNET_URL, DARKNET_LIB) +DARKNET_LIB = __darknetffi__.dlopen('./' + DARKNET_LIB) +cfg = "./" + str(CFG_NAME) +weights = "./" + str(WEIGHTS_NAME) +net = DARKNET_LIB.load_network(cfg.encode('utf-8'), weights.encode('utf-8'), 0) +dtype = 'float32' +batch_size = 1 + +# Import the graph to NNVM +# ------------------------ +# Import darknet graph definition to nnvm. +# +# Results: +# sym: nnvm graph for rnn model +# params: params converted from darknet weights +print("Converting darknet rnn model to nnvm symbols...") +sym, params = nnvm.frontend.darknet.from_darknet(net, dtype) + +# Compile the model on NNVM +data = np.empty([1, net.inputs], dtype)#net.inputs + +target = 'llvm' +shape = {'data': data.shape} +print("Compiling the model...") + +shape_dict = {'data': data.shape} +dtype_dict = {'data': data.dtype} + +with nnvm.compiler.build_config(opt_level=2): + graph, lib, params = nnvm.compiler.build(sym, target, shape_dict, dtype_dict, params) + +# Execute the portable graph on TVM +# --------------------------------- +# Now we can try deploying the NNVM compiled model on cpu target. + +# Set the cpu context +ctx = tvm.cpu(0) +# Create graph runtime +m = graph_runtime.create(graph, lib, ctx) +# Set the params to runtime +m.set_input(**params) + +def _init_state_memory(rnn_cells_count, dtype): + '''Initialize memory for states''' + states = {} + state_shape = (1024,) + for i in range(rnn_cells_count): + k = 'rnn' + str(i) + '_state' + states[k] = tvm.nd.array(np.zeros(state_shape, dtype).astype(dtype)) + return states + +def _set_state_input(runtime, states): + '''Set the state inputs''' + for state in states: + runtime.set_input(state, states[state]) + +def _get_state_output(runtime, states): + '''Get the state outputs and save''' + i = 1 + for state in states: + data = states[state] + states[state] = runtime.get_output((i), tvm.nd.empty(data.shape, data.dtype)) + i += 1 + +def _proc_rnn_output(out_data): + '''Generate the characters from the output array''' + sum_array = 0 + n = out_data.size + r = random.uniform(0, 1) + for j in range(n): + if out_data[j] < 0.0001: + out_data[j] = 0 + sum_array += out_data[j] + + for j in range(n): + out_data[j] *= float(1.0) / sum_array + r = r - out_data[j] + if r <= 0: + return j + return n-1 + +print("RNN generaring text...") + +out_shape = (net.outputs,) +rnn_cells_count = 3 + +# Initialize state memory +# ----------------------- +states = _init_state_memory(rnn_cells_count, dtype) + +len_seed = len(seed) +count = len_seed + num +out_txt = "" + +#Initialize random seed +random.seed(0) +c = ord(seed[0]) +inp_data = np.zeros([net.inputs], dtype) + +# Run the model +# ------------- + +# Predict character by character till `num` +for i in range(count): + inp_data[c] = 1 + + # Set the input data + m.set_input('data', tvm.nd.array(inp_data.astype(dtype))) + inp_data[c] = 0 + + # Set the state inputs + _set_state_input(m, states) + + # Run the model + m.run() + + # Get the output + tvm_out = m.get_output(0, tvm.nd.empty(out_shape, dtype)).asnumpy() + + # Get the state outputs + _get_state_output(m, states) + + # Get the predicted character and keep buffering it + c = ord(seed[i]) if i < len_seed else _proc_rnn_output(tvm_out) + out_txt += chr(c) + +print("Predicted Text =", out_txt)