diff --git a/python/tvm/relay/frontend/__init__.py b/python/tvm/relay/frontend/__init__.py
index aba9eea494be..7154f5a1ab6d 100644
--- a/python/tvm/relay/frontend/__init__.py
+++ b/python/tvm/relay/frontend/__init__.py
@@ -33,3 +33,4 @@ from .tensorflow import from_tensorflow
 from .darknet import from_darknet
 from .pytorch import from_pytorch
+from .caffe import from_caffe
diff --git a/python/tvm/relay/frontend/caffe.py b/python/tvm/relay/frontend/caffe.py
new file mode 100644
index 000000000000..b7bcbde0de63
--- /dev/null
+++ b/python/tvm/relay/frontend/caffe.py
@@ -0,0 +1,848 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
+# pylint: disable=no-else-return, no-else-continue
+"""Caffe frontend."""
+import numpy as np
+import tvm
+from tvm.ir import IRModule
+from .. import analysis
+from .. import expr as _expr
+from .. import function as _function
+from .. import op as _op
+from ... import nd as _nd
+from .common import ExprTable
+from .common import infer_shape as _infer_shape
+
+__all__ = ['from_caffe']
+
+
+class OperatorConverter(object):
+    """ Operator converter for converting Caffe ops to Relay ops """
+    def __init__(self, init_layer_dict, predict_layer, exp_tab):
+        self.init_layer_dict = init_layer_dict
+        self.predict_layer = predict_layer
+        self.exp_tab = exp_tab
+        self.new_bn = {}
+        self.changed_layers = None
+
+        self.convert_map = {
+            'BatchNorm': self.convert_batch_norm,
+            'Concat': self.convert_concat,
+            'Convolution': self.convert_conv,
+            'Crop': self.convert_crop,
+            'Deconvolution': self.convert_deconv,
+            'Dropout': self.convert_dropout,
+            'Eltwise': self.convert_eltwise,
+            'Flatten': self.convert_flatten,
+            'InnerProduct': self.convert_innerproduct,
+            'Input': None,
+            'LRN': self.convert_lrn,
+            'Pooling': self.convert_pooling,
+            'PReLU': self.convert_prelu,
+            'ReLU': self.convert_relu,
+            'Reshape': self.convert_reshape,
+            'Scale': self.convert_scale,
+            'Sigmoid': self.convert_sigmoid,
+            'Slice': self.convert_slice,
+            'Softmax': self.convert_softmax,
+            'TanH': self.convert_tanh,
+        }
+
+    def convert_flatten(self, op):
+        """ Convert Flatten layer """
+        inputs = op.bottom
+        in_expr = self.exp_tab.get_expr(inputs[0])
+
+        flatten_params = op.flatten_param.axis
+        assert flatten_params == 1, "flatten axis should be 1"
+        out = _op.nn.batch_flatten(in_expr)
+
+        return out
+
+    def convert_eltwise(self, op):
+        """ Convert Eltwise layer """
+        inputs = op.bottom
+        assert len(inputs) == 2, "input tensors length should be 2"
+
+        lhs_expr = self.exp_tab.get_expr(inputs[0])
+        rhs_expr = self.exp_tab.get_expr(inputs[1])
+
+        lhs_shape = _infer_shape(lhs_expr)
+        rhs_shape = _infer_shape(rhs_expr)
+
+        assert lhs_shape == rhs_shape, "input tensors shape should be equal"
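+        # EltwiseParameter.operation is an enum index into ['PROD', 'SUM', 'MAX'];
+        # for SUM, Caffe may also supply one blending coefficient per input blob.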
+ + eltwise_params = op.eltwise_param + eltwise_type_dict = ['PROD', 'SUM', 'MAX'] + eltwise_type = eltwise_params.operation + coeff = list(eltwise_params.coeff) + + if eltwise_type_dict[eltwise_type] == 'PROD': + out = _op.multiply(lhs_expr, rhs_expr) + elif eltwise_type_dict[eltwise_type] == 'SUM': + if coeff: + left_coeff_expr = self.exp_tab.new_const( + np.asarray(coeff[0], np.float32)) + right_coeff_expr = self.exp_tab.new_const( + np.asarray(coeff[1], np.float32)) + lhs_expr_scale = _op.multiply(lhs_expr, left_coeff_expr) + rhs_expr_scale = _op.multiply(rhs_expr, right_coeff_expr) + out = _op.add(lhs_expr_scale, rhs_expr_scale) + else: + out = _op.add(lhs_expr, rhs_expr) + elif eltwise_type_dict[eltwise_type] == 'MAX': + out = _op.maximum(lhs_expr, rhs_expr) + else: + raise tvm.error.OpNotImplemented( + "eltwise_type {} is not supported for frontend Caffe.".format( + eltwise_type)) + + return out + + def _parse_conv_params(self, op): + """ Parse the parameters of Convolution and Deconvolution layer """ + nonzone = lambda val, pos, dflt: val[pos] if pos < len(val) else dflt + + conv_params = op.convolution_param + + params = dict() + # parse kernel size + if conv_params.kernel_h > 0 or conv_params.kernel_w > 0: + params['kernel_size'] = (conv_params.kernel_h, + conv_params.kernel_w) + else: + ksize_h = nonzone(conv_params.kernel_size, 0, 1) + ksize_w = nonzone(conv_params.kernel_size, 1, ksize_h) + params['kernel_size'] = (ksize_h, ksize_w) + + # parse padding size + if conv_params.pad_h > 0 or conv_params.pad_w > 0: + params['padding'] = (conv_params.pad_h, conv_params.pad_w) + else: + pad_h = nonzone(conv_params.pad, 0, 0) + pad_w = nonzone(conv_params.pad, 1, pad_h) + params['padding'] = (pad_h, pad_w) + + # parse stride size + if conv_params.stride_h > 0 or conv_params.stride_w > 0: + params['strides'] = (conv_params.stride_h, conv_params.stride_w) + else: + stride_h = nonzone(conv_params.stride, 0, 1) + stride_w = nonzone(conv_params.stride, 1, stride_h) + params['strides'] = (stride_h, stride_w) + + # parse dilation size + if hasattr(conv_params, 'dilation') and len(conv_params.dilation) > 0: + dilation = ' '.join(str(d) for d in conv_params.dilation) + dilation = tuple(map(int, dilation.split(' '))) + params['dilation'] = dilation + if len(dilation) == 1: + params['dilation'] = (dilation[0], dilation[0]) + + params['kernel_layout'] = 'OIHW' + params['data_layout'] = 'NCHW' + params['groups'] = conv_params.group + params['channels'] = conv_params.num_output + return params + + def convert_batch_norm(self, op): + """ Convert BatchNorm layer """ + inputs = op.bottom + in_expr = self.exp_tab.get_expr(inputs[0]) + n, c, h, w = _infer_shape(in_expr) + + if op.name in self.new_bn: + mean, var, eps, gamma, beta = self.new_bn[op.name] + mean_expr = self.exp_tab.new_const(mean, dtype='float32') + var_expr = self.exp_tab.new_const(var, dtype='float32') + gamma_expr = self.exp_tab.new_const(gamma, dtype='float32') + beta_expr = self.exp_tab.new_const(beta, dtype='float32') + out = _op.nn.batch_norm(in_expr, + gamma_expr, + beta_expr, + mean_expr, + var_expr, + epsilon=eps, + scale=True) + + else: + weight_bias_blobs = self.init_layer_dict[op.name].blobs + mean = np.asarray(weight_bias_blobs[0].data, np.float32) + var = np.asarray(weight_bias_blobs[1].data, np.float32) + if len(weight_bias_blobs) == 2: + mean = np.repeat(mean, h * w).reshape((c, h, w)) + mean = np.expand_dims(mean, 0).repeat(n, axis=0) + mean_expr = self.exp_tab.new_const(mean, dtype='float32') + + var = np.repeat(var, h * 
w).reshape((c, h, w))
+                var = np.expand_dims(var, 0).repeat(n, axis=0)
+                var_expr = self.exp_tab.new_const(var, dtype='float32')
+
+                tmp_out = _op.multiply(in_expr, mean_expr)
+                out = _op.add(tmp_out, var_expr)
+
+                return out
+            else:
+                scale = np.asarray(weight_bias_blobs[2].data, np.float32)
+                if scale:
+                    scale = 1 / scale
+            mean_expr = self.exp_tab.new_const(mean * scale, dtype='float32')
+            var_expr = self.exp_tab.new_const(var * scale, dtype='float32')
+
+            # Caffe's BatchNorm layer has no learnable affine part, so feed
+            # batch_norm constant gamma=1 and beta=0 and set scale=False
+            gamma_expr = self.exp_tab.new_const(np.ones(mean.shape,
+                                                        dtype=np.float32),
+                                                dtype='float32')
+            beta_expr = self.exp_tab.new_const(np.zeros(mean.shape,
+                                                        dtype=np.float32),
+                                               dtype='float32')
+
+            bn_params = op.batch_norm_param.eps
+            out = _op.nn.batch_norm(in_expr,
+                                    gamma_expr,
+                                    beta_expr,
+                                    mean_expr,
+                                    var_expr,
+                                    epsilon=bn_params,
+                                    scale=False)
+
+        return out[0]
+
+    def convert_scale(self, op):
+        """ Convert Scale layer """
+        inputs = op.bottom
+        in_expr = self.exp_tab.get_expr(inputs[0])
+        weight_bias_blobs = self.init_layer_dict[op.name].blobs
+
+        params = dict()
+        params['bias'] = op.scale_param.bias_term
+        params['axis'] = op.scale_param.axis
+
+        gamma = np.asarray(weight_bias_blobs[0].data, np.float32)
+        gamma_expr = self.exp_tab.new_const(gamma, dtype='float32')
+        if params['bias']:
+            beta = np.asarray(weight_bias_blobs[1].data, np.float32)
+            beta_expr = self.exp_tab.new_const(beta, dtype='float32')
+        else:
+            beta_expr = self.exp_tab.new_const(np.zeros(gamma.shape,
+                                                        dtype=np.float32),
+                                               dtype='float32')
+
+        _, c, _, _ = _infer_shape(in_expr)
+        gamma_expr = _op.reshape(gamma_expr, newshape=(1, c, 1, 1))
+        beta_expr = _op.reshape(beta_expr, newshape=(1, c, 1, 1))
+        out = _op.multiply(in_expr, gamma_expr)
+        out = _op.add(out, beta_expr)
+
+        return out
+
+    def convert_concat(self, op):
+        """ Convert Concat layer """
+        inputs = op.bottom
+        in_expr = (self.exp_tab.get_expr(inputs[i])
+                   for i in range(len(inputs)))
+
+        c_params = dict()
+        c_params['axis'] = op.concat_param.axis
+        out = _op.concatenate(in_expr, axis=c_params['axis'])
+
+        return out
+
+    def convert_reshape(self, op):
+        """ Convert Reshape layer """
+        inputs = op.bottom
+        input_name = inputs[0]
+
+        reshape_param = op.reshape_param
+        dims = list(reshape_param.shape.dim)
+
+        in_expr = self.exp_tab.get_expr(input_name)
+        input_shape = list(_infer_shape(in_expr))
+
+        start_axis = int(reshape_param.axis)
+        if start_axis < 0:
+            start_axis = len(input_shape) + start_axis + 1
+        num_axes = int(reshape_param.num_axes)
+        end_axis = len(input_shape)
+        if num_axes != -1:
+            end_axis = start_axis + num_axes
+
+        left_shape = input_shape[:start_axis]
+        if end_axis == len(input_shape):
+            center_shape = input_shape[start_axis:]
+            right_shape = []
+        else:
+            center_shape = input_shape[start_axis:end_axis]
+            right_shape = input_shape[end_axis:]
+
+        for idx, dim in enumerate(dims):
+            if dim == 0:
+                dims[idx] = center_shape[idx]
+
+        tmp = np.random.rand(*center_shape)
+        tmp = np.reshape(tmp, dims)
+        center_shape = list(tmp.shape)
+
+        newshape = left_shape + center_shape + right_shape
+
+        out = _op.reshape(in_expr, newshape=newshape)
+        return out
+
+    def convert_softmax(self, op):
+        """ Convert Softmax layer """
+        inputs = op.bottom
+        assert len(inputs) == 1, "input tensors length should be 1"
+
+        input_name = inputs[0]
+        in_expr = self.exp_tab.get_expr(input_name)
+
+        softmax_param = op.softmax_param
+        params = {'axis': softmax_param.axis}
+
+        out = _op.nn.softmax(in_expr, **params)
+
+        return out
+
+    def convert_conv(self, op):
+        """ Convert 
Convolution layer """ + params = self._parse_conv_params(op) + weight_bias_blobs = self.init_layer_dict[op.name].blobs + conv_params = op.convolution_param + inputs = op.bottom + # process weight and bias blobs + weight, bias = None, None + if len(weight_bias_blobs) > 1: + weight = weight_bias_blobs[0] + bias = weight_bias_blobs[1] + else: + weight = weight_bias_blobs[0] + if weight: + kh, kw = params['kernel_size'] + weight_shape = [conv_params.num_output, -1, kh, kw] + weight_value = np.asarray(weight.data, np.float32) + weight_value = np.reshape(weight_value, weight_shape) + else: + raise Exception('No weight value of layer {} in caffemodel'.format( + op.name)) + + weight_expr = self.exp_tab.new_const(weight_value, dtype='float32') + in_expr = self.exp_tab.get_expr(inputs[0]) + out = _op.nn.conv2d(data=in_expr, weight=weight_expr, **params) + if bias: + bias_value = np.asarray(bias.data, np.float32) + bias_expr = self.exp_tab.new_const(bias_value, dtype='float32') + out = _op.nn.bias_add(out, bias_expr) + return out + + def convert_pooling(self, op): + """ Convert Pooling layer """ + inputs = op.bottom + input_name = inputs[0] + + pool_params = op.pooling_param + pool_type_dict = ['MAX', 'AVE', 'STOCHASTIC'] + + params = dict() + # parse pool type: 0: MAX, 1: AVE, 2: STOCHASTIC + pool_type = pool_params.pool + # parse kernel size + if pool_params.kernel_h > 0 or pool_params.kernel_w > 0: + params['pool_size'] = (pool_params.kernel_h, pool_params.kernel_w) + else: + params['pool_size'] = (pool_params.kernel_size, + pool_params.kernel_size) + + # parse padding size + if pool_params.pad_h > 0 or pool_params.pad_w > 0: + params['padding'] = (pool_params.pad_h, pool_params.pad_w) + else: + params['padding'] = (pool_params.pad, pool_params.pad) + + # parse stride size + if pool_params.stride_h > 0 or pool_params.stride_w > 0: + params['strides'] = (pool_params.stride_h, pool_params.stride_w) + else: + params['strides'] = (pool_params.stride, pool_params.stride) + + params['ceil_mode'] = True + if hasattr(pool_params, 'ceil_mode'): + params['ceil_mode'] = pool_params.ceil_mode + + in_expr = self.exp_tab.get_expr(input_name) + + if pool_type_dict[pool_type] == 'MAX': + if pool_params.global_pooling: + out = _op.nn.global_max_pool2d(in_expr) + else: + if len(op.top) == 1: + out = _op.nn.max_pool2d(in_expr, **params) + elif len(op.top) == 2: + out1 = _op.nn.max_pool2d_with_argmax(in_expr, **params) + out2 = _op.vision.max_pool2d_location(in_expr, **params) + return _expr.Tuple((out1, out2)) + + elif pool_type_dict[pool_type] == 'AVE': # AVE + if pool_params.global_pooling: + out = _op.nn.global_avg_pool2d(in_expr) + else: + params['count_include_pad'] = True + out = _op.nn.avg_pool2d(in_expr, **params) + else: + raise tvm.error.OpNotImplemented( + "Operator {} is not supported for frontend Caffe.".format( + pool_type_dict[pool_type] + ' pool')) + + return out + + def convert_lrn(self, op): + """ Convert LRN layer """ + inputs = op.bottom + input_name = inputs[0] + + params = dict() + lrn_params = op.lrn_param + params['size'] = lrn_params.local_size + params['bias'] = lrn_params.k + params['alpha'] = lrn_params.alpha + params['beta'] = lrn_params.beta + + in_expr = self.exp_tab.get_expr(input_name) + out = _op.nn.lrn(in_expr, **params) + return out + + def convert_innerproduct(self, op): + """ Convert InnerProduct layer """ + inputs = op.bottom + weight_bias_blobs = self.init_layer_dict[op.name].blobs + dense_params = op.inner_product_param + + params = dict() + params["num_output"] = 
dense_params.num_output + params["bias"] = dense_params.bias_term + params["axis"] = dense_params.axis + if params["axis"] != 1: + raise Exception("Only support 2D InnerProduct") + + # process weight and bias blobs + weight, bias = None, None + if params["bias"]: + weight = weight_bias_blobs[0] + bias = weight_bias_blobs[1] + else: + weight = weight_bias_blobs[0] + + if weight: + weight_value = np.asarray(weight.data, np.float32) + weight_value = np.reshape(weight_value, (params["num_output"], -1)) + weight_shape = weight_value.shape + else: + raise Exception('No weight value of layer {} in caffemodel'.format( + op.name)) + + weight_expr = self.exp_tab.new_const(weight_value, dtype='float32') + + in_expr = self.exp_tab.get_expr(inputs[0]) + in_reshape = _op.reshape(data=in_expr, newshape=(-1, weight_shape[-1])) + + out = _op.nn.dense(data=in_reshape, weight=weight_expr) + + if bias: + bias_value = np.asarray(bias.data, np.float32) + bias_expr = self.exp_tab.new_const(bias_value, dtype='float32') + out = _op.nn.bias_add(out, bias_expr, axis=params["axis"]) + return out + + def convert_dropout(self, op): + """ Convert Dropout layer """ + inputs = op.bottom + input_name = inputs[0] + + params = dict() + dropout_params = op.dropout_param + + params['rate'] = dropout_params.dropout_ratio + + in_expr = self.exp_tab.get_expr(input_name) + out = _op.nn.dropout(in_expr, **params) + return out + + def convert_relu(self, op): + """ Convert ReLU layer """ + inputs = op.bottom + in_expr = self.exp_tab.get_expr(inputs[0]) + negative_slope = op.relu_param.negative_slope + if negative_slope: + out = _op.nn.leaky_relu(in_expr, negative_slope) + return out + + out = _op.nn.relu(in_expr) + return out + + def convert_prelu(self, op): + """ Convert PReLU layer """ + inputs = op.bottom + in_expr = self.exp_tab.get_expr(inputs[0]) + + alpha = self.init_layer_dict[op.name].blobs[0].data + alpha = np.asarray(alpha, np.float32) + alpha = self.exp_tab.new_const(alpha, dtype='float32') + axis = 1 + out = _op.nn.prelu(in_expr, alpha, axis=axis) + return out + + def convert_deconv(self, op): + """ Convert Deconvolution layer """ + params = self._parse_conv_params(op) + weight_bias_blobs = self.init_layer_dict[op.name].blobs + conv_params = op.convolution_param + inputs = op.bottom + + # process weight and bias blobs + weight, bias = None, None + if len(weight_bias_blobs) > 1: + weight = weight_bias_blobs[0] + bias = weight_bias_blobs[1] + else: + weight = weight_bias_blobs[0] + if weight: + kh, kw = params['kernel_size'] + weight_shape = [-1, conv_params.num_output, kh, kw] + weight_value = np.asarray(weight.data, np.float32) + weight_value = np.reshape(weight_value, weight_shape) + else: + raise Exception('No weight value of layer {} in caffemodel'.format( + op.name)) + + weight_expr = self.exp_tab.new_const(weight_value, dtype='float32') + in_expr = self.exp_tab.get_expr(inputs[0]) + out = _op.nn.conv2d_transpose(data=in_expr, + weight=weight_expr, + **params) + if bias: + + bias_value = np.asarray(bias.data, np.float32) + bias_expr = self.exp_tab.new_const(bias_value, dtype='float32') + out = _op.nn.bias_add(out, bias_expr) + return out + + def convert_slice(self, op): + """ Convert Slice layer """ + inputs = op.bottom + in_expr = self.exp_tab.get_expr(inputs[0]) + + output_num = len(op.top) + + slice_params = op.slice_param + axis = int(slice_params.axis) + indices_or_sections = list([int(s) for s in slice_params.slice_point]) + if len(indices_or_sections) == 0: + indices_or_sections = output_num + else: + 
indices_or_sections = sorted(indices_or_sections) + + out = _op.split(in_expr, + indices_or_sections=indices_or_sections, + axis=axis) + return out + + def convert_sigmoid(self, op): + """ Convert Sigmoid layer """ + inputs = op.bottom + in_expr = self.exp_tab.get_expr(inputs[0]) + out = _op.sigmoid(in_expr) + return out + + def convert_tanh(self, op): + """ Convert TanH layer """ + inputs = op.bottom + in_expr = self.exp_tab.get_expr(inputs[0]) + out = _op.tanh(in_expr) + return out + + def convert_crop(self, op): + """ Convert Crop layer """ + inputs = op.bottom + assert len(inputs) == 2, "Need two inputs of Crop layer" + in_expr_a = self.exp_tab.get_expr(inputs[0]) + in_expr_b = self.exp_tab.get_expr(inputs[1]) + + # parse crop params + crop_params = op.crop_param + axis = int(getattr(crop_params, 'axis', 2)) + offset = list(getattr(crop_params, 'offset', 0)) + + # expand offset to (offset1, offset2, ...) + in_a_shape = _infer_shape(in_expr_a) + num_to_crop = len(in_a_shape) - axis + if not offset: + offset = [0] * num_to_crop + if len(offset) == 1: + offset = offset * num_to_crop + elif len(offset) != num_to_crop: + raise Exception("No matching the number between axis and offset!") + + slice_end = in_a_shape + slice_start = [0] * len(in_a_shape) + for i in range(num_to_crop): + slice_start[i + axis] = offset[i] + + to_crop_axis = list(range(len(in_a_shape))) + to_crop_axis = to_crop_axis[axis:] + + # secondly, crop in_expr_a by in_expr_b + in_expr_a_stride = _op.strided_slice(in_expr_a, slice_start, slice_end) + out = _op.slice_like(in_expr_a_stride, in_expr_b, axes=to_crop_axis) + return out + + + def check_unsupported_ops(self): + """Check unsupported Caffe ops in our converter.""" + unsupported_ops_set = set() + + include_layer = dict() + for pl in self.predict_layer: + if pl.type not in include_layer: + include_layer[pl.type] = 1 + else: + include_layer[pl.type] = include_layer[pl.type] + 1 + + for pl in self.predict_layer: + op_name = pl.type + if op_name not in self.convert_map: + unsupported_ops_set.add(op_name) + + if unsupported_ops_set: + msg = 'The following operators are not supported in frontend ' \ + 'Caffe: {}' + ops = str(list(unsupported_ops_set)).strip('[,]') + raise tvm.error.OpNotImplemented(msg.format(ops)) + + def fuse_op(self, layers): + """ Fusing the BatchNorm and Scale layer """ + bn, scale = layers["bn"], layers["scale"] + + # bn params + bn_weight_bias_blobs = self.init_layer_dict[bn.name].blobs + bn_scale = np.asarray(bn_weight_bias_blobs[2].data, np.float32) + if bn_scale: + bn_scale = 1 / bn_scale + bn_mean = np.asarray(bn_weight_bias_blobs[0].data, + np.float32) * bn_scale + bn_var = np.asarray(bn_weight_bias_blobs[1].data, + np.float32) * bn_scale + bn_eps = bn.batch_norm_param.eps + + # scale params + scale_weight_bias_blobs = self.init_layer_dict[scale.name].blobs + scale_gamma = np.asarray(scale_weight_bias_blobs[0].data, np.float32) + scale_bias = scale.scale_param.bias_term + if scale_bias: + scale_beta = np.asarray(scale_weight_bias_blobs[1].data, + np.float32) + else: + scale_beta = np.zeros(scale_gamma.shape, dtype=np.float32) + + # new params + self.new_bn[bn.name] = [ + bn_mean, bn_var, bn_eps, scale_gamma, scale_beta + ] + return bn + + def op_fuse(self): + """fuse bn and scale """ + new_layers = [] + temp_layers = {} + changed_layers = {} + + for index, pl in enumerate(self.predict_layer): + op_type = pl.type + if op_type == "Input": + new_layers.append(pl) + continue + elif op_type == "BatchNorm": + if (index != len(self.predict_layer) - 1) 
and (
+                        self.predict_layer[index + 1].type == "Scale"):
+                    temp_layers["bn"] = pl
+                    continue
+                else:
+                    new_layers.append(pl)
+                    temp_layers.clear()
+            elif op_type == "Scale":
+                if self.predict_layer[index - 1].type == "BatchNorm":
+                    temp_layers["scale"] = pl
+                else:
+                    new_layers.append(pl)
+                    temp_layers.clear()
+            else:
+                temp_layers.clear()
+
+            if len(temp_layers) == 2:
+                layer = self.fuse_op(temp_layers)
+                new_layers.append(layer)
+                changed_layers[
+                    temp_layers["scale"].name] = temp_layers['bn'].name
+
+            for idx, plt in enumerate(pl.bottom):
+                if plt in changed_layers:
+                    pl.bottom[idx] = changed_layers[plt]
+
+            if op_type not in ['BatchNorm', 'Scale']:
+                new_layers.append(pl)
+
+        self.predict_layer = new_layers
+        self.changed_layers = changed_layers
+
+    def convert_op_to_relay(self):
+        """Convert Caffe ops to relay ops"""
+        for pl in self.predict_layer:
+            op_type = pl.type
+            if op_type == "Input":
+                continue
+            output_tensors = pl.top
+
+            ret = self.convert_map[op_type](pl)
+
+            if len(output_tensors) == 1:
+                self.exp_tab.set_expr(output_tensors[0], ret)
+            else:
+                for idx, output_tensor in enumerate(output_tensors):
+                    self.exp_tab.set_expr(output_tensor, ret[idx])
+
+
+def _rebuild_layers(predict_layer):
+    """Rebuild the Caffe layers. If the Caffe net contains in-place layers,
+    replace each such layer's top with the layer's name and update the
+    bottoms of the other layers that refer to it.
+    """
+    # dict of input name that will be changed to new name
+    changed_top_dict = dict()
+
+    for pl in predict_layer:
+        if pl.type == "Input":
+            continue
+        # if the current layer has a single input and output and they are
+        # equal, the layer computes "in-place"
+        if (len(pl.top) == 1 and len(pl.bottom) == 1):
+            if pl.top[0] == pl.bottom[0]:
+                # change current layer's input firstly
+                if pl.bottom[0] in changed_top_dict:
+                    pl.bottom[0] = changed_top_dict[pl.bottom[0]]
+                # update "change" dict
+                changed_top_dict[pl.top[0]] = pl.name
+                # change current layer's output to its name
+                pl.top[0] = pl.name
+            else:
+                if pl.bottom[0] in changed_top_dict:
+                    pl.bottom[0] = changed_top_dict[pl.bottom[0]]
+        # if the layer is not in-place, only remap its inputs
+        else:
+            for index, plt in enumerate(pl.bottom):
+                if plt in changed_top_dict:
+                    pl.bottom[index] = changed_top_dict[plt]
+
+
+def _get_inputs_outputs(predict_layer):
+    """Obtain the Caffe model's inputs and outputs"""
+    # model inputs / outputs
+    model_inputs = list()
+    model_outputs = list()
+
+    # The bottom of a layer can never be a model output
+    not_outputs = set()
+    for pl in predict_layer:
+        if pl.type == "Input":
+            assert len(
+                pl.top
+            ) == 1, "The number of Input layer's output is more than 1."
+            model_inputs.append(pl.top[0])
+        for i in pl.bottom:
+            not_outputs.add(i)
+
+    for pl in predict_layer:
+        if len(pl.bottom) > 0:
+            for t in pl.top:
+                if t not in not_outputs:
+                    model_outputs.append(t)
+    return model_inputs, model_outputs
+
+
+def from_caffe(init_net, predict_net, shape_dict, dtype_dict):
+    """Convert a Caffe model into an equivalent Relay Function.
+
+    Parameters
+    ----------
+    init_net : caffe_pb2.NetParameter
+        caffemodel
+    predict_net : caffe_pb2.NetParameter
+        caffe prototxt
+    shape_dict : dict of str to int list/tuple
+        Input shapes of the model.
+    dtype_dict : dict of str to str
+        Input types of the model.
+
+    Returns
+    -------
+    mod : tvm.relay.Module
+        The relay module for compilation.
+ + params : dict of str to tvm.NDArray + The parameter dict to be used by relay + """ + old_caffe = False + if len(predict_net.input) != 0: # old caffe version + old_caffe = True + model_inputs = list(predict_net.input) + + predict_layer = predict_net.layer + + # replace layer's top with its name and update other layers'bottoms + _rebuild_layers(predict_layer) + # obtain inputs and outputs of Net + if old_caffe: + _, model_outputs = _get_inputs_outputs(predict_layer) + else: + model_inputs, model_outputs = _get_inputs_outputs(predict_layer) + + exp_tab = ExprTable() + for in_name in model_inputs: + shape = shape_dict[in_name] if in_name in shape_dict else None + dtype = dtype_dict[in_name] if in_name in dtype_dict else "float32" + exp_tab.set_expr(in_name, _expr.var(in_name, shape=shape, dtype=dtype)) + if list(init_net.layer): + init_layer = init_net.layer + else: + init_layer = init_net.layers + init_layer_dict = {il.name: il for il in init_layer} + # op code in model + op_converter = OperatorConverter(init_layer_dict, predict_layer, exp_tab) + op_converter.check_unsupported_ops() + op_converter.op_fuse() + op_converter.convert_op_to_relay() + + # params and outputs + params = {k: _nd.array(np.array(v)) for k, v in exp_tab.params.items()} + outputs = list() + for n in model_outputs: + if n in op_converter.changed_layers: + n = op_converter.changed_layers[n] + outputs.append(exp_tab.get_expr(n)) + outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) + func = _function.Function(analysis.free_vars(outputs), outputs) + mod = IRModule.from_expr(func) + + return mod, params diff --git a/tests/python/frontend/caffe/test_forward.py b/tests/python/frontend/caffe/test_forward.py new file mode 100644 index 000000000000..8567e4b4f565 --- /dev/null +++ b/tests/python/frontend/caffe/test_forward.py @@ -0,0 +1,968 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name, unused-argument +""" +Caffe testcases +==================== +This article is a test script to test Caffe operator with Relay. 
+""" +import os +os.environ['GLOG_minloglevel'] = '2' +import sys +import logging +logging.basicConfig(level=logging.ERROR) + +import numpy as np +from google.protobuf import text_format +import caffe +from caffe import layers as L, params as P +from caffe.proto import caffe_pb2 as pb + +import tvm +from tvm import relay +from tvm.contrib import util, graph_runtime +from tvm.contrib.download import download_testdata + +CURRENT_DIR = os.path.join(os.path.expanduser('~'), '.tvm_test_data', 'caffe_test') + +####################################################################### +# Generic functions for TVM & Caffe +# ------------------------------------------ + + +def _create_dir(d_path): + """ If the directory is not existed, create it""" + if not (os.path.exists(d_path) and os.path.isdir(d_path)): + os.makedirs(d_path) + + +def _list_to_str(ll): + """ Convert list or tuple to str, separated by underline. """ + if isinstance(ll, (tuple, list)): + tmp = [str(i) for i in ll] + return '_'.join(tmp) + + +def _gen_filename_str(op_name, data_shape, *args, **kwargs): + """ Combining the filename according to the op_name, shape and other args. """ + file_dir = os.path.join(CURRENT_DIR, op_name) + _create_dir(file_dir) + res = op_name + "_" + shape_str = _list_to_str(list(data_shape)) + res += shape_str + for arg in args: + if isinstance(arg, (tuple, list)): + res += ("_" + _list_to_str(arg)) + elif isinstance(arg, (int, float, str)): + res += ("_" + str(arg)) + for _, v in kwargs.items(): + if isinstance(v, (tuple, list)): + res += ("_" + _list_to_str(v)) + elif isinstance(v, (int, float, str)): + res += ("_" + str(v)) + res = res.replace(".", "_") + res = res.replace("-", "_") + proto_file = os.path.join(file_dir, res + ".prototxt") + blob_file = os.path.join(file_dir, res + ".caffemodel") + solver_file = os.path.join(file_dir, res + "_solver.prototxt") + + return (proto_file, blob_file, solver_file) + + +def _save_prototxt(n_netspec, f_path): + """ Generate .prototxt file according to caffe.NetSpec""" + s = n_netspec.to_proto() + with open(f_path, 'w') as f: + f.write(str(s)) + + +def _save_solver(solver_file, proto_file, blob_file): + """ Define a solver proto, you can change the configs.""" + blob_file_prefix = blob_file.split(".caffemodel")[0] + s = pb.SolverParameter() + s.train_net = proto_file + s.base_lr = 0.01 + s.momentum = 0.9 + s.weight_decay = 0.0005 + s.lr_policy = "inv" + s.gamma = 0.0001 + s.power = 0.75 + s.display = 1 + s.max_iter = 100000 + s.snapshot = 100000 + s.snapshot_prefix = blob_file_prefix + + with open(solver_file, 'w') as f: + f.write(str(s)) + + +def _save_caffemodel(solver_file, blob_file): + """ Generate .caffemodel file.""" + solver = caffe.SGDSolver(solver_file) + solver.net.save(blob_file) + + +def _gen_model_files(n_netspec, proto_file, blob_file, solver_file): + _save_prototxt(n_netspec, proto_file) + _save_solver(solver_file, proto_file, blob_file) + _save_caffemodel(solver_file, blob_file) + + +def _siso_op(data, func, *args, **kwargs): + """ Create single input and single output Caffe op """ + n = caffe.NetSpec() + n.data = L.Input(input_param={'shape': {'dim': list(data.shape)}}) + n.output = func(n.data, *args, **kwargs) + return n + + +def _miso_op(data_list, func, *args, **kwargs): + """ Create multi input and single output Caffe op """ + n = caffe.NetSpec() + if not isinstance(data_list, (tuple, list)): + raise TypeError("Need tuple or list but get {}".format( + type(data_list))) + input_list = list() + for idx, data in enumerate(data_list): + n['data' 
+          str(idx)] = L.Input(input_param={'shape': {
+              'dim': list(data.shape)
+          }})
+        input_list.append(n['data' + str(idx)])
+    n.output = func(*input_list, *args, **kwargs)
+    return n
+
+
+def _simo_op(data, func, *args, **kwargs):
+    """ Create a single-input, multi-output Caffe op """
+    n = caffe.NetSpec()
+    n.data = L.Input(input_param={'shape': {'dim': list(data.shape)}})
+    output_list = func(n.data, *args, **kwargs)
+    for idx, out in enumerate(output_list):
+        n['output' + str(idx)] = out
+    return n
+
+
+def _run_caffe(data, proto_file, blob_file):
+    """ Run the model in Caffe, given the .caffemodel and .prototxt files """
+    net = caffe.Net(proto_file, blob_file, caffe.TEST)
+    if isinstance(data, (list, tuple)):
+        for idx, d in enumerate(data):
+            net.blobs['data' + str(idx)].data[...] = d
+    else:
+        net.blobs['data'].data[...] = data
+    out = net.forward()
+
+    caffe_output = list()
+    for i in range(len(out.keys())):
+        if 'output'+str(i) not in out.keys():
+            caffe_output.clear()
+            return list(out.values())
+        caffe_output.append(out['output'+str(i)])
+    return caffe_output
+
+
+def _run_tvm(data, proto_file, blob_file):
+    """ Run the model in TVM, given the .caffemodel and .prototxt files """
+    init_net = pb.NetParameter()
+    predict_net = pb.NetParameter()
+
+    # load model
+    with open(proto_file, 'r') as f:
+        text_format.Merge(f.read(), predict_net)
+    # load blob
+    with open(blob_file, 'rb') as f:
+        init_net.ParseFromString(f.read())
+
+    shape_dict = dict()
+    dtype_dict = dict()
+    if isinstance(data, (tuple, list)):
+        for idx, d in enumerate(data):
+            shape_dict['data' + str(idx)] = d.shape
+            dtype_dict['data' + str(idx)] = 'float32'
+    else:
+        shape_dict = {'data': data.shape}
+        dtype_dict = {'data': 'float32'}
+
+    mod, params = relay.frontend.from_caffe(
+        init_net, predict_net, shape_dict, dtype_dict)
+
+    target = 'llvm'
+    target_host = 'llvm'
+
+    ctx = tvm.cpu(0)
+    with tvm.transform.PassContext(opt_level=3):
+        lib = relay.build(mod,
+                          target=target,
+                          target_host=target_host,
+                          params=params)
+    dtype = 'float32'
+    m = graph_runtime.GraphModule(lib['default'](ctx))
+    if isinstance(data, (tuple, list)):
+        for idx, d in enumerate(data):
+            m.set_input('data' + str(idx), tvm.nd.array(d.astype(dtype)))
+    else:
+        m.set_input('data', tvm.nd.array(data.astype(dtype)))
+    # execute
+    m.run()
+    tvm_output = list()
+    # get outputs
+    for i in range(m.get_num_outputs()):
+        tvm_output.append(m.get_output(i).asnumpy())
+    return tvm_output
+
+
+def _compare_caffe_tvm(caffe_out, tvm_out, is_network=False):
+    for i in range(len(caffe_out)):
+        if is_network:
+            caffe_out[i] = caffe_out[i][:1]
+        tvm.testing.assert_allclose(caffe_out[i],
+                                    tvm_out[i],
+                                    rtol=1e-5,
+                                    atol=1e-5)
+
+
+def _test_op(data, func_op, op_name, **kwargs):
+    """ Single op testing pipeline. 
""" + shape_list = list() + if isinstance(data, (list, tuple)): + n = _miso_op(data, func_op, **kwargs) + for d in data: + shape_list.extend(list(d.shape)) + else: + output_num = 1 + if 'ntop' in kwargs.keys(): + output_num = kwargs['ntop'] + if output_num == 1: + n = _siso_op(data, func_op, **kwargs) + else: + n = _simo_op(data, func_op, **kwargs) + shape_list = list(data.shape) + + # obtain the .caffemodel file and .prototxt file + (proto_file, blob_file, + solver_file) = _gen_filename_str(op_name, shape_list, **kwargs) + _gen_model_files(n, proto_file, blob_file, solver_file) + # run model in Caffe + caffe_out = _run_caffe(data, proto_file, blob_file) + # run model in TVM + tvm_out = _run_tvm(data, proto_file, blob_file) + _compare_caffe_tvm(caffe_out, tvm_out) + + +def _test_network(data, proto_file, blob_file): + # run model in Caffe + caffe_out = _run_caffe(data, proto_file, blob_file) + # run model in TVM + tvm_out = _run_tvm(data, proto_file, blob_file) + _compare_caffe_tvm(caffe_out, tvm_out, is_network=True) + + +####################################################################### +# BatchNorm +# ----------- + + +def _test_batchnorm(data, moving_average_fraction=0.999, eps=1e-5): + """ One iteration of BatchNorm """ + _test_op(data, + L.BatchNorm, + "BatchNorm", + moving_average_fraction=moving_average_fraction, + eps=eps) + + +def test_forward_BatchNorm(): + """ BatchNorm """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + _test_batchnorm(data) + _test_batchnorm(data, moving_average_fraction=0.88, eps=1e-4) + + +####################################################################### +# Concat +# ----------- + + +def _test_concat(data_list, axis=1): + """ One iteration of Concat """ + _test_op(data_list, L.Concat, "Concat", axis=axis) + + +def test_forward_Concat(): + """ Concat """ + _test_concat([np.random.rand(1, 3, 10, 10), + np.random.rand(1, 2, 10, 10)], + axis=1) + _test_concat([np.random.rand(3, 10, 10), + np.random.rand(2, 10, 10)], + axis=0) + _test_concat([np.random.rand(3, 10), np.random.rand(2, 10)], axis=0) + + +####################################################################### +# Convolution +# ----------- + + +def _test_convolution(data, **kwargs): + """ One iteration of Convolution """ + _test_op(data, L.Convolution, "Convolution", **kwargs) + + +def test_forward_Convolution(): + """ Convolution """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + _test_convolution(data, + num_output=20, + bias_term=True, + pad=0, + kernel_size=3, + stride=2, + dilation=1, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier")) + _test_convolution(data, + num_output=20, + bias_term=False, + pad=[1, 2], + kernel_size=3, + stride=2, + dilation=1, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier")) + _test_convolution(data, + num_output=20, + bias_term=True, + pad=[1, 2], + kernel_size=[3, 5], + stride=[2, 1], + dilation=[1, 2], + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier")) + _test_convolution(np.random.rand(1, 2, 10, 10).astype(np.float32), + num_output=20, + bias_term=True, + pad=[1, 2], + kernel_size=[3, 5], + stride=[2, 1], + dilation=[1, 2], + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"), + group=2) + _test_convolution(data, + num_output=20, + bias_term=True, + pad_h=1, + pad_w=2, + kernel_h=3, + kernel_w=5, + stride_h=2, + stride_w=1, + dilation=[1, 2], + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier")) + + 
+####################################################################### +# Crop +# ----------- + + +def _test_crop(data, **kwargs): + """ One iteration of Crop """ + _test_op(data, L.Crop, "Crop", **kwargs) + + +def test_forward_Crop(): + """ Crop """ + _test_crop( + [np.random.rand(10, 10, 120, 120), + np.random.rand(10, 5, 50, 60)]) + _test_crop( + [np.random.rand(10, 10, 120, 120), + np.random.rand(10, 5, 50, 60)], + axis=1) + _test_crop( + [np.random.rand(10, 10, 120, 120), + np.random.rand(10, 5, 50, 60)], + axis=1, + offset=2) + _test_crop( + [np.random.rand(10, 10, 120, 120), + np.random.rand(10, 5, 50, 60)], + axis=1, + offset=[1, 2, 4]) + _test_crop( + [np.random.rand(10, 10, 120, 120), + np.random.rand(10, 5, 50, 60)], + axis=2, + offset=[2, 4]) + _test_crop([np.random.rand(10, 120, 120), + np.random.rand(5, 50, 60)], + axis=1, + offset=[2, 4]) + _test_crop([np.random.rand(120, 120), + np.random.rand(50, 60)], + axis=0, + offset=[2, 4]) + + +####################################################################### +# Deconvolution +# ----------- + + +def _test_deconvolution(data, **kwargs): + """ One iteration of Deconvolution """ + _test_op(data, L.Deconvolution, "Deconvolution", **kwargs) + + +def test_forward_Deconvolution(): + """ Deconvolution """ + data = np.random.rand(1, 16, 32, 32).astype(np.float32) + _test_deconvolution(data, + convolution_param=dict( + num_output=20, + bias_term=True, + pad=0, + kernel_size=3, + stride=2, + dilation=1, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"))) + _test_deconvolution(data, + convolution_param=dict( + num_output=20, + bias_term=False, + pad=[1, 2], + kernel_size=3, + stride=2, + dilation=1, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"))) + _test_deconvolution(data, + convolution_param=dict( + num_output=20, + bias_term=True, + pad_h=1, + pad_w=2, + kernel_h=3, + kernel_w=5, + stride_h=2, + stride_w=1, + dilation=1, + weight_filler=dict(type="xavier"), + bias_filler=dict(type="xavier"))) + + +####################################################################### +# Dropout +# ----------- + + +def _test_dropout(data, **kwargs): + """ One iteration of Dropout """ + _test_op(data, L.Dropout, "Dropout", **kwargs) + + +def test_forward_Dropout(): + """ Dropout """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + _test_dropout(data) + _test_dropout(data, dropout_ratio=0.7) + + +####################################################################### +# Eltwise +# ----------- + + +def _test_eltwise(data_list, **kwargs): + """ One iteration of Eltwise """ + _test_op(data_list, L.Eltwise, "Eltwise", **kwargs) + + +def test_forward_Eltwise(): + """ Eltwise """ + _test_eltwise([ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32) + ], + operation=0) + _test_eltwise([ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32) + ], + operation=1) + _test_eltwise([ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32) + ], + operation=2) + _test_eltwise([ + np.random.rand(1, 3, 10, 11).astype(np.float32), + np.random.rand(1, 3, 10, 11).astype(np.float32) + ], + operation=1, + coeff=[0.5, 1]) + + +####################################################################### +# Flatten +# ----------- + + +def _test_flatten(data, axis=1): + """ One iteration of Flatten """ + _test_op(data, L.Flatten, 'Flatten', axis=axis) + + +def test_forward_Flatten(): + """ 
Flatten """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + _test_flatten(data) + _test_flatten(data, axis=1) + + +####################################################################### +# Flatten +# ----------- + + +def _test_inner_product(data, **kwargs): + """ One iteration of InnerProduct""" + _test_op(data, L.InnerProduct, "InnerProduct", **kwargs) + + +def test_forward_InnerProduct(): + """ InnerProduct """ + data = np.random.rand(1, 3, 10, 10) + _test_inner_product(data, + num_output=20, + bias_term=False, + weight_filler=dict(type='xavier')) + _test_inner_product(data, + num_output=20, + bias_term=True, + weight_filler=dict(type='xavier'), + bias_filler=dict(type='xavier')) + _test_inner_product(np.random.rand(20, 10).astype(np.float32), + num_output=30, + bias_term=True, + weight_filler=dict(type='xavier'), + bias_filler=dict(type='xavier')) + + +####################################################################### +# LRN +# ----------- + + +def _test_lrn(data, local_size=5, alpha=1., beta=0.75, k=1.): + """ One iteration of LRN """ + _test_op(data, + L.LRN, + 'LRN', + local_size=local_size, + alpha=alpha, + beta=beta, + k=k) + + +def test_forward_LRN(): + """ LRN """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + _test_lrn(data) + _test_lrn(data, local_size=3) + _test_lrn(data, local_size=3, alpha=2.) + _test_lrn( + data, + local_size=3, + alpha=2., + beta=0.5, + ) + _test_lrn(data, local_size=3, alpha=2., beta=0.5, k=2.) + + +####################################################################### +# Pooling +# ----------- + + +def _test_pooling(data, **kwargs): + """ One iteration of Pooling. """ + _test_op(data, L.Pooling, "Pooling", **kwargs) + + +def test_forward_Pooling(): + """ Pooing """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + # MAX Pooling + _test_pooling(data, kernel_size=2, stride=2, pad=0, pool=P.Pooling.MAX) + _test_pooling(data, + kernel_h=2, + kernel_w=3, + stride_h=2, + stride_w=1, + pad_h=1, + pad_w=2, + pool=P.Pooling.MAX) + _test_pooling(data, pool=P.Pooling.MAX, global_pooling=True) + + # AVE Pooing + _test_pooling(data, kernel_size=2, stride=2, pad=0, pool=P.Pooling.AVE) + _test_pooling(data, + kernel_h=2, + kernel_w=3, + stride_h=2, + stride_w=1, + pad_h=1, + pad_w=2, + pool=P.Pooling.AVE) + _test_pooling(data, pool=P.Pooling.AVE, global_pooling=True) + + +####################################################################### +# PReLU +# ----------- + + +def _test_prelu(data, **kwargs): + """ One iteration of PReLU. """ + _test_op(data, L.PReLU, "PReLU", **kwargs) + + +def test_forward_PReLU(): + """ PReLU """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + _test_prelu(data, filler=dict(type='constant', value=0.5)) + _test_prelu(data) + _test_prelu(np.random.rand(10, 20).astype(np.float32)) + + +####################################################################### +# ReLU +# ----------- + + +def _test_relu(data, **kwargs): + """ One iteration of ReLU. """ + _test_op(data, L.ReLU, "ReLU", **kwargs) + + +def test_forward_ReLU(): + """ ReLU """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + _test_relu(data) + _test_relu(np.random.rand(10, 20).astype(np.float32)) + + +####################################################################### +# Reshape +# ----------- + + +def _test_reshape(data, **kwargs): + """ One iteration of Reshape. 
""" + _test_op(data, L.Reshape, "Reshape", **kwargs) + + +def test_forward_Reshape(): + """ Reshape """ + data = np.random.rand(1, 8, 6).astype(np.float32) + _test_reshape(data, reshape_param={'shape': {'dim': [4, 3, 4]}}) + _test_reshape(data, reshape_param={'shape': {'dim': [2, 0, 3]}}) + _test_reshape(data, reshape_param={'shape': {'dim': [2, 0, -1]}}) + _test_reshape(data, reshape_param={'shape': {'dim': [0, -1]}}) + + _test_reshape(data, reshape_param={'shape': {'dim': [2, 3]}, 'axis': 2}) + _test_reshape(data, reshape_param={'shape': {'dim': [4, 3, 4]}, 'axis': 1}) + _test_reshape(data, + reshape_param={ + 'shape': { + 'dim': [4, 3, 4] + }, + 'axis': -3 + }) + + _test_reshape(data, + reshape_param={ + 'shape': { + 'dim': [2, 4] + }, + 'axis': 1, + 'num_axes': 1 + }) + _test_reshape(data, + reshape_param={ + 'shape': { + 'dim': [3, 16] + }, + 'axis': 1, + 'num_axes': 2 + }) + + +####################################################################### +# Scale +# ----------- + + +def _test_scale(data, **kwargs): + """ One iteration of Scale. """ + _test_op(data, L.Scale, "Scale", **kwargs) + + +def test_forward_Scale(): + """ Scale """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + _test_scale(data, filler=dict(type="xavier")) + _test_scale(data, + filler=dict(type="xavier"), + bias_term=True, + bias_filler=dict(type="xavier")) + + +####################################################################### +# Sigmoid +# ----------- + + +def _test_sigmoid(data, **kwargs): + """ One iteration of Sigmoid. """ + _test_op(data, L.Sigmoid, "Sigmoid", **kwargs) + + +def test_forward_Sigmoid(): + """ Sigmoid """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + _test_sigmoid(data) + + +####################################################################### +# Slice +# ----------- + + +def _test_slice(data, **kwargs): + """ One iteration of Slice """ + _test_op(data, L.Slice, "Slice", **kwargs) + + +def test_forward_Slice(): + """ Slice """ + data = np.random.rand(1, 3, 10, 10).astype(np.float32) + _test_slice(data, ntop=2, slice_param=dict(axis=1, slice_point=[1])) + _test_slice(data, ntop=2, slice_param=dict(axis=-1, slice_point=[1])) + _test_slice(data, ntop=3, slice_param=dict(axis=2, slice_point=[1, 6])) + _test_slice(data, ntop=3) + + +####################################################################### +# Softmax +# ----------- + + +def _test_softmax(data, **kwargs): + """ One iteration of Softmax """ + _test_op(data, L.Softmax, "Softmax", **kwargs) + + +def test_forward_Softmax(): + """ Softmax""" + _test_softmax(np.random.rand(1, 3, 10, 10).astype(np.float32)) + _test_softmax(np.random.rand(1, 3, 10, 10).astype(np.float32), axis=2) + _test_softmax(np.random.rand(10, 10).astype(np.float32), axis=0) + _test_softmax(np.random.rand(2, 10, 10).astype(np.float32), axis=1) + + +####################################################################### +# TanH +# ----------- + + +def _test_tanh(data, **kwargs): + """ One iteration of TanH """ + _test_op(data, L.TanH, "TanH", **kwargs) + + +def test_forward_TanH(): + """ TanH """ + _test_tanh(np.random.rand(1, 3, 10, 10).astype(np.float32)) + _test_tanh(np.random.rand(3, 10, 10).astype(np.float32)) + _test_tanh(np.random.rand(10, 10).astype(np.float32)) + _test_tanh(np.random.rand(10).astype(np.float32)) + + +####################################################################### +# Mobilenetv2 +# ----------- + + +def _test_mobilenetv2(data): + """ One iteration of Mobilenetv2 """ + mean_val = np.array([103.939, 116.779, 
123.68], dtype=np.float32)
+    mean_val = np.reshape(mean_val, (1, 3, 1, 1))
+    mean_val = np.tile(mean_val, (1, 1, 224, 224))
+    data_process = data - mean_val
+    data_process = data_process / 58.8
+    data_process = data_process.astype(np.float32)
+
+    proto_file_url = ("https://github.com/shicai/MobileNet-Caffe/raw/"
+                      "master/mobilenet_v2_deploy.prototxt")
+    blob_file_url = ("https://github.com/shicai/MobileNet-Caffe/blob/"
+                     "master/mobilenet_v2.caffemodel?raw=true")
+    proto_file = download_testdata(proto_file_url, 'mobilenetv2.prototxt',
+                                   module='model')
+    blob_file = download_testdata(blob_file_url, 'mobilenetv2.caffemodel',
+                                  module='model')
+    _test_network(data_process, proto_file, blob_file)
+
+
+def test_forward_Mobilenetv2():
+    """ Mobilenetv2 """
+    data = np.random.randint(0, 256, size=(1, 3, 224, 224)).astype(np.float32)
+    _test_mobilenetv2(data)
+
+
+#######################################################################
+# Alexnet
+# -----------
+
+
+def _test_alexnet(data):
+    """ One iteration of Alexnet """
+    mean_val = np.array([103.939, 116.779, 123.68], dtype=np.float32)
+    mean_val = np.reshape(mean_val, (1, 3, 1, 1))
+    mean_val = np.tile(mean_val, (1, 1, 227, 227))
+    data_process = data - mean_val
+    data_process = data_process.astype(np.float32)
+
+    proto_file_url = ("https://github.com/BVLC/caffe/raw/master/models/"
+                      "bvlc_alexnet/deploy.prototxt")
+    blob_file_url = 'http://dl.caffe.berkeleyvision.org/bvlc_alexnet.caffemodel'
+    proto_file = download_testdata(proto_file_url, 'alexnet.prototxt',
+                                   module="model")
+    blob_file = download_testdata(blob_file_url, 'alexnet.caffemodel',
+                                  module='model')
+    _test_network(data_process, proto_file, blob_file)
+
+
+def test_forward_Alexnet():
+    """ Alexnet """
+    data = np.random.randint(0, 256, size=(1, 3, 227, 227)).astype(np.float32)
+    _test_alexnet(data)
+
+
+#######################################################################
+# Resnet50
+# -----------
+
+
+def _test_resnet50(data):
+    """ One iteration of Resnet50 """
+    mean_val = np.array([103.939, 116.779, 123.68], dtype=np.float32)
+    mean_val = np.reshape(mean_val, (1, 3, 1, 1))
+    mean_val = np.tile(mean_val, (1, 1, 224, 224))
+    data_process = data - mean_val
+    data_process = data_process.astype(np.float32)
+
+    proto_file_url = ("https://github.com/fernchen/CaffeModels/raw/"
+                      "master/resnet/ResNet-50-deploy.prototxt")
+    blob_file_url = ("https://github.com/fernchen/CaffeModels/raw/"
+                     "master/resnet/ResNet-50-model.caffemodel")
+
+    proto_file = download_testdata(proto_file_url, 'resnet50.prototxt',
+                                   module="model")
+    blob_file = download_testdata(blob_file_url, 'resnet50.caffemodel',
+                                  module='model')
+
+    _test_network(data_process, proto_file, blob_file)
+
+
+def test_forward_Resnet50():
+    """ Resnet50 """
+    data = np.random.randint(0, 256, size=(1, 3, 224, 224)).astype(np.float32)
+    _test_resnet50(data)
+
+
+#######################################################################
+# Inceptionv1
+# -----------
+
+
+def _test_inceptionv1(data):
+    """ One iteration of Inceptionv1 (BVLC GoogLeNet) """
+    mean_val = np.array([103.939, 116.779, 123.68], dtype=np.float32)
+    mean_val = np.reshape(mean_val, (1, 3, 1, 1))
+    mean_val = np.tile(mean_val, (1, 1, 224, 224))
+    data_process = data - mean_val
+    data_process = data_process / 58.8
+    data_process = data_process.astype(np.float32)
+
+    proto_file_url = ("https://github.com/BVLC/caffe/raw/master/models"
+                      "/bvlc_googlenet/deploy.prototxt")
+    blob_file_url = 'http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel'
+    proto_file = download_testdata(proto_file_url, 'inceptionv1.prototxt',
+                                   module="model")
+    blob_file = download_testdata(blob_file_url, 'inceptionv1.caffemodel',
+                                  module='model')
+    _test_network(data_process, proto_file, blob_file)
+
+
+def test_forward_Inceptionv1():
+    """ Inceptionv1 """
+    data = np.random.randint(0, 256, size=(1, 3, 224, 224)).astype(np.float32)
+    _test_inceptionv1(data)
+
+
+if __name__ == "__main__":
+    # NN
+    test_forward_Convolution()
+    test_forward_Deconvolution()
+    test_forward_Dropout()
+    test_forward_LRN()
+    test_forward_Pooling()
+    test_forward_Scale()
+    test_forward_InnerProduct()
+    test_forward_BatchNorm()
+
+    # Elemwise
+    test_forward_Eltwise()
+
+    # Activation
+    test_forward_PReLU()
+    test_forward_ReLU()
+    test_forward_Sigmoid()
+    test_forward_Softmax()
+    test_forward_TanH()
+
+    # Reshape
+    test_forward_Reshape()
+    test_forward_Flatten()
+
+    # Math
+    test_forward_Concat()
+    test_forward_Crop()
+    test_forward_Slice()
+
+    # End to End
+    test_forward_Mobilenetv2()
+    test_forward_Alexnet()
+    test_forward_Resnet50()
+    test_forward_Inceptionv1()
diff --git a/tests/scripts/task_python_frontend_cpu.sh b/tests/scripts/task_python_frontend_cpu.sh
index 96c5ce631a17..10354e588720 100755
--- a/tests/scripts/task_python_frontend_cpu.sh
+++ b/tests/scripts/task_python_frontend_cpu.sh
@@ -35,3 +35,6 @@ python3 -m pytest tests/python/frontend/tflite
 
 echo "Running relay Keras frontend test..."
 python3 -m pytest tests/python/frontend/keras
+
+echo "Running relay Caffe frontend test..."
+python3 -m pytest tests/python/frontend/caffe