From 62d8636e57e35afea0cefdd6b307dd6082d66a03 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Sun, 15 May 2022 03:50:25 +0800
Subject: [PATCH] OneFlow frontend: support more models and fix bugs (#11321)

* add relay.f.frontend.fm_oneflow to support cnns

* support cuda

* fix mobilenetv2 and reviews

* fix: model without meta info

* support eager and yolo, add test

* fix: license

* add: tutorials

* fix: support new graph

* fix some comments

* refine

* fix concat op convert bug

* refine

* refine

* change cuda to cpu

* fix bug

* fix ci error in tvm

* fix pylint check

* delete useless file

* add skimage package in docker

* fix ci error

* fix bug

* add oneflow frontend test in ci

* merge conflict

* fix tutorial

* try to find error in ci

* revert

* merge conflict

* black oneflow

* Delete from_oneflow.py

* restructure oneflow frontend

* support vision-transformer

* black format

* update black version and reformat

* fix ci error

* fix doc error

* fix gpu frontend test failure

Co-authored-by: hhhfccz
---
 python/tvm/relay/frontend/oneflow.py          | 418 ++++++++++++------
 tests/python/frontend/oneflow/test_forward.py | 199 +++++++++
 .../frontend/oneflow/test_vision_models.py    | 150 +++++++
 3 files changed, 630 insertions(+), 137 deletions(-)
 create mode 100644 tests/python/frontend/oneflow/test_vision_models.py

diff --git a/python/tvm/relay/frontend/oneflow.py b/python/tvm/relay/frontend/oneflow.py
index a1a7d513f8d0..ff4b5a5bcc42 100644
--- a/python/tvm/relay/frontend/oneflow.py
+++ b/python/tvm/relay/frontend/oneflow.py
@@ -21,7 +21,7 @@
 import os
 import re
 import copy
-import warnings
+from collections import OrderedDict
 
 import numpy as np
 import tvm
@@ -38,7 +38,6 @@
     Renamer,
     fold_constant,
     get_relay_op,
-    infer_channels,
     infer_shape,
     infer_type,
     new_var,
@@ -97,7 +96,6 @@ def _dtype_shape_promotion(inputs):
     """Promote data type and shape for list of tensors."""
     dtype_order = ["bool", "int8", "int16", "int32", "int64", "float32", "float64"]
-
    ranks = [len(infer_shape(x)) for x in inputs]
     if set(ranks) == set([1, 0]):
         for i, r in enumerate(ranks):
@@ -497,19 +495,26 @@ class Flatten(OneFlowOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attrs, params):
-        axis = attrs.get("axis", 1)
-        ishape = _op.shape_of(inputs[0])
-        ndim = infer_shape(ishape)[0]
-        if axis < 0:
-            axis = axis + ndim
-
-        if axis == 1:
-            out = _op.nn.batch_flatten(inputs[0])
-        else:
-            pre_shape = _op.prod(_op.strided_slice(ishape, [0], [axis], [1]), keepdims=True)
-            post_shape = _op.prod(_op.strided_slice(ishape, [axis], [ndim], [1]), keepdims=True)
-            newshape = _op.concatenate([pre_shape, post_shape], axis=0)
-            out = _op.reshape(inputs[0], newshape)
+        x = inputs[0]
+        input_shape = list(infer_shape(x))
+
+        start = attrs["start_dim"]
+        end = attrs["end_dim"]
+        ndim = len(input_shape)
+        if end < 0:
+            end += ndim
+        # In relay reshape, 0 copies the corresponding input dim and -1 infers the
+        # flattened extent, so dims [start, end] collapse into a single dimension.
+        # E.g. start=1, end=2 on (2, 3, 4, 5): new_shape = [0, -1, 1, 0] gives
+        # (2, 12, 1, 5), and squeezing axis 2 yields (2, 12, 5).
+        new_shape = [0] * start
+
+        new_shape.append(-1)
+        squeeze_axes = []
+        for i in range(start + 1, end + 1):
+            new_shape.append(1)
+            squeeze_axes.append(i)
+        for _ in range(end + 1, ndim):
+            new_shape.append(0)
+        out = _op.reshape(x, new_shape)
+        if squeeze_axes:
+            out = _op.squeeze(out, axis=squeeze_axes)
         return out
 
 
@@ -518,36 +523,119 @@ class MatMul(OneFlowOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attrs, params):
-        assert len(inputs) == 2, "Gemm op take 2 inputs, {} given".format(len(inputs))
-        # Similar to 'class Conv'
-        true_names = ["weight"]
-        false_names = ["_input."]
-        for i in inputs:
-            T_NAMES = any(x in str(i) for x in true_names)
-            F_NAMES = any(x in str(i) for x in false_names)
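+        # Reviewer sketch (hypothetical shapes, not from the patch): the code below
+        # infers operand order from shapes instead of tensor names. For example,
+        # with inputs[0] of shape (4, 3), inputs[1] of shape (2, 4), and no
+        # transpose flags, a_shape[-1] != b_shape[-2] (3 != 2), so the operands are
+        # swapped and the multiply becomes (2, 4) x (4, 3) -> (2, 3).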
-            if T_NAMES and not F_NAMES:
-                matmul_b = i
-            else:
-                matmul_a = i
-
-        dtype = infer_type(matmul_a).checked_type.dtype
+        assert len(inputs) == 2, "MatMul op takes 2 inputs, {} given".format(len(inputs))
+        dtype = infer_type(inputs[0]).checked_type.dtype
         # Y = alpha * A * B
         alpha = float(attrs.get("alpha", 1.0))
         transA = bool(attrs.get("transpose_a", False))
         transB = bool(attrs.get("transpose_b", False))
 
-        # get number of channels
-        channels = infer_channels(matmul_b, not transB)
-        if transA:
-            matmul_a = _op.transpose(matmul_a, axes=(1, 0))
-        if not transB:
-            matmul_b = _op.transpose(matmul_b, axes=(1, 0))
-        matmul_a = _op.nn.batch_flatten(matmul_a)
-        if alpha != 1.0:
-            matmul_a *= _expr.const(alpha, dtype=dtype)
+        a_shape = infer_shape(inputs[0])
+        b_shape = infer_shape(inputs[1])
+        if (
+            (transA and transB and a_shape[-2] != b_shape[-1])
+            or (transA and not transB and a_shape[-2] != b_shape[-2])
+            or (transB and not transA and a_shape[-1] != b_shape[-1])
+            or (not transB and not transA and a_shape[-1] != b_shape[-2])
+        ):
+            matmul_a = inputs[1]
+            matmul_b = inputs[0]
+        else:
+            matmul_a = inputs[0]
+            matmul_b = inputs[1]
 
-        return _op.nn.dense(matmul_a, matmul_b, units=channels)
+        if transA:
+            perm = list(range(len(a_shape)))
+            perm[-2] = len(a_shape) - 1
+            perm[-1] = len(a_shape) - 2
+            matmul_a = _op.transpose(matmul_a, axes=perm)
+        if transB:
+            perm = list(range(len(b_shape)))
+            perm[-2] = len(b_shape) - 1
+            perm[-1] = len(b_shape) - 2
+            matmul_b = _op.transpose(matmul_b, axes=perm)
+
+        # This implementation closely follows the ONNX MatMul converter.
+        # The input shapes must be inspected because batch matmul has to be supported.
+        a_shape = shape_of(matmul_a, dtype="int32")
+        a_rank = infer_shape(a_shape)[0]
+        b_shape = shape_of(matmul_b, dtype="int32")
+        b_rank = infer_shape(b_shape)[0]
+        # When performing a batch matmul, we need to properly handle N-dim shapes.
+        if a_rank > 2 or b_rank > 2:
+
+            def flatten_to_nd(x, x_shape, nd=3):
+                ndims = infer_shape(x_shape)[0]
+                if ndims == nd:
+                    return x
+                newshape = _op.concatenate(
+                    [
+                        _expr.const([-1], dtype=infer_type(x_shape).checked_type.dtype),
+                        _op.strided_slice(x_shape, [ndims - nd + 1], [ndims]),
+                    ],
+                    0,
+                )
+                out = _op.reshape(x, fold_constant(newshape))
+                return out
+
+            b_type = infer_type(matmul_b)
+            # Convert to dense if the second matrix is 2d and non-dynamic
+            if b_rank == 2 and not _ty.is_dynamic(b_type.checked_type):
+                a = flatten_to_nd(matmul_a, a_shape, 2)
+                b = _op.transpose(matmul_b)
+                output = _op.nn.dense(a, b)
+            else:
+                # Convert a and b into 3 dimensional tensors.
+                a = flatten_to_nd(matmul_a, a_shape, 3)
+                b = flatten_to_nd(matmul_b, b_shape, 3)
+                # Transpose matrix dimensions of b.
+                b = _op.transpose(b, [0, 2, 1])
+                # Perform a batch matmul.
+                output = _op.nn.batch_matmul(a, b)
+            # Determine the output batch dimension.
+            if a_rank > b_rank:
+                out_batch = _op.strided_slice(a_shape, [0], [a_rank - 2])
+            elif a_rank < b_rank:
+                out_batch = _op.strided_slice(b_shape, [0], [b_rank - 2])
+            # If it's unclear how broadcasting should be applied, the output
+            # shape is determined by choosing the maximum value from each input.
+            else:
+                out_batch = _op.concatenate(
+                    [
+                        _op.maximum(
+                            _op.strided_slice(a_shape, [i], [i + 1]),
+                            _op.strided_slice(b_shape, [i], [i + 1]),
+                        )
+                        for i in range(a_rank - 2)
+                    ],
+                    0,
+                )
+            # Reshape output to original dimensions.
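+            # Shape walkthrough (hypothetical example, not from the patch): for
+            # a of shape (2, 3, 4, 5) and b of shape (2, 3, 5, 6), flatten_to_nd
+            # gives (6, 4, 5) and (6, 5, 6), batch_matmul yields (6, 4, 6), and
+            # final_shape below concatenates out_batch (2, 3) with a's rows (4)
+            # and b's cols (6) to restore (2, 3, 4, 6).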
+ final_shape = _op.concatenate( + [ + out_batch, + _op.strided_slice( + a_shape, [infer_shape(a_shape)[0] - 2], [infer_shape(a_shape)[0] - 1] + ), + _op.strided_slice( + b_shape, [infer_shape(b_shape)[0] - 1], [infer_shape(b_shape)[0]] + ), + ], + 0, + ) + out = _op.reshape(output, fold_constant(final_shape)) + else: + if b_rank == 1: + matmul_b = _op.expand_dims(matmul_b, 1, 1) + # Otherwise a simple dense op will get the job done. + input_1_t = _op.transpose(matmul_b, axes=(1, 0)) + out = _op.nn.dense(matmul_a, input_1_t) + if b_rank == 1: + out = _op.squeeze(out, axis=[-1]) + if not np.isclose(alpha, 1.0): + out = out * _expr.const(alpha, dtype=dtype) + return out class Reduce(OneFlowOpConverter): @@ -635,15 +723,34 @@ class Expand(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): - input_shape = infer_shape(inputs[0]) - assert input_shape == attrs["in_shape"], "shape wrong" - - new_shape = attrs["out_shape"] - out = _op.broadcast_to(inputs[0], shape=new_shape) + data_in = inputs[0] + shape = list(infer_shape(data_in)) + + ndims = len(shape) + sizes = attrs["logical_expand_shape"] + out = data_in + out_dims = len(sizes) + if ndims < out_dims: + num_newaxis = out_dims - ndims + out = _op.expand_dims(out, axis=0, num_newaxis=num_newaxis) + shape = [1] * num_newaxis + shape + + for i in range(out_dims): + if sizes[i] != -1 and shape[i] == 1: + out = _op.repeat(out, sizes[i], axis=i) return out +class Transpose(OneFlowOpConverter): + """Operator converter for transpose.""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + perm = attrs["perm"] + return _op.transpose(inputs[0], axes=perm) + + class ExpandDim(OneFlowOpConverter): """Operator converter for ExpandDim""" @@ -718,12 +825,25 @@ class BroadcastDiv(BroadcastMath): name = "divide" -class Greater(OneFlowOpConverter): +class LogicalGreater(OneFlowOpConverter): """Operator converter for greater""" @classmethod def _impl_v1(cls, inputs, attrs, params): - return _op.greater(inputs[0], inputs[1]) + res = None + if attrs.get("has_int_operand", True): + value = attrs.get("int_operand", 0.0) + res = _op.greater(inputs[0], _op.full_like(inputs[0], fill_value=_expr.const(value))) + elif attrs.get("has_float_operand", True): + value = float(attrs.get("float_operand", 0.0)) + res = _op.greater( + inputs[0], _op.full_like(inputs[0], fill_value=_expr.const(value)).astype("float32") + ) + else: + raise AttributeError( + "please check if has_int_operand or has_float_operand in your attrs" + ) + return res class Log1p(OneFlowOpConverter): @@ -734,6 +854,15 @@ def _impl_v1(cls, inputs, attrs, params): return _op.log(inputs[0] + _expr.const(1.0)) +class Pow(OneFlowOpConverter): + """Operator converter for Power""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + inputs = _dtype_shape_promotion(inputs) + return get_relay_op(cls.name)(inputs[0], inputs[1]) + + class Expm1(OneFlowOpConverter): """Operator converter for Expm1""" @@ -812,14 +941,35 @@ def _impl_v1(cls, inputs, attrs, params): return res +class ScalarDiv(OneFlowOpConverter): + """Operator convert for Div_scalar""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + assert len(inputs) == 1, "div_scalar take == 1 inputs, but {} given.".format(len(inputs)) + + if attrs.get("has_int_operand", True): + res = inputs[0] / _expr.const(attrs["int_operand"], dtype="float32") + elif attrs.get("has_float_operand", True): + res = inputs[0] / _expr.const(attrs["float_operand"]) + else: + raise AttributeError( + "please check if has_int_operand 
or has_float_operand in your attrs" + ) + + return res + + class ScalarPow(OneFlowOpConverter): """Operator convert for Pow_scalar""" @classmethod def _impl_v1(cls, inputs, attrs, params): - exponent = attrs.get("exponent", 1.0) - exponent = _expr.const(exponent, dtype="float32") - return _op.power(inputs[0], exponent) + if attrs.get("has_int_operand", True): + coeff = _expr.const(attrs["int_operand"]) + elif attrs.get("has_float_operand", True): + coeff = _expr.const(attrs["float_operand"]) + return _op.power(inputs[0], coeff) class MaxPool2d(Pool): @@ -857,15 +1007,12 @@ class Softmax(OneFlowOpConverter): @classmethod def _impl_v1(cls, inputs, attrs, params): - axis = attrs.get("axis", 1) - ndim = len(infer_shape(inputs[0])) - if axis < 0: - axis += ndim - axes = list(range(axis, ndim)) - x = inputs[0] - m = _op.max(x, axes, keepdims=True) - e = _op.exp(x - m) - return e / _op.sum(e, axes, keepdims=True) + axis = attrs.get("axis", -1) + data = inputs[0] + if isinstance(axis, str): + axis = int(axis) + + return _op.nn.softmax(data, axis=axis) class LogSoftmax(OneFlowOpConverter): @@ -1000,6 +1147,17 @@ def _impl_v1(cls, inputs, attrs, params): return inputs[0] / (_expr.const(1.0) + Absolute.get_converter()(inputs, attrs, params)) +class Variance(OneFlowOpConverter): + """Operator converter for Variance""" + + @classmethod + def _impl_v1(cls, inputs, attrs, params): + axis = attrs["dim"] + keepdims = attrs["keepdim"] + unbiased = bool(attrs["unbiased"]) + return _op.reduce.variance(inputs[0], axis=axis, keepdims=keepdims, unbiased=unbiased) + + class Concat(OneFlowOpConverter): """Operator converter for Concat""" @@ -1234,6 +1392,7 @@ def get_convert_map(): "bias_add": Add.get_converter(), "scalar_add": ScalarAdd.get_converter(), "scalar_mul": ScalarMul.get_converter(), + "scalar_div": ScalarDiv.get_converter(), "scalar_pow": ScalarPow.get_converter(), "reduce_sum": ReduceSum.get_converter(), "reduce_max": ReduceMax.get_converter(), @@ -1243,7 +1402,7 @@ def get_convert_map(): "broadcast_mul": BroadcastMul.get_converter(), "broadcast_sub": BroadcastSub.get_converter(), "broadcast_div": BroadcastDiv.get_converter(), - "broadcast_greater": Greater.get_converter(), + "scalar_logical_greater": LogicalGreater.get_converter(), "log": Renamer("log"), "log1p": Log1p.get_converter(), "acos": Renamer("acos"), @@ -1258,7 +1417,7 @@ def get_convert_map(): "sinh": Renamer("sinh"), "tan": Renamer("tan"), "tanh": Renamer("tanh"), - "pow": Renamer("power"), + "pow": Pow.get_converter(), "exp": Renamer("exp"), "expm1": Expm1.get_converter(), "floor": Renamer("floor"), @@ -1271,7 +1430,7 @@ def get_convert_map(): "sign": Sign.get_converter(), "erf": Erf.get_converter(), "erfc": Erfc.get_converter(), - "reciprocal_no_nan": Reciprocal.get_converter(), + "reciprocal": Reciprocal.get_converter(), # defs/activation "softmax": Softmax.get_converter(), "softsign": Softsign.get_converter(), @@ -1295,24 +1454,29 @@ def get_convert_map(): "upsample_bilinear_2d": UpsampleBiLinear.get_converter(), # defs/tensor "matmul": MatMul.get_converter(), + "batch_matmul": MatMul.get_converter(), + "broadcast_matmul": MatMul.get_converter(), "concat": Concat.get_converter(), "clip_by_scalar": Clip.get_converter(), "slice": Slice.get_converter(), "expand": Expand.get_converter(), - "transpose": AttrCvt("transpose", {"perm": "axes"}), + "transpose": Transpose.get_converter(), "expand_dims": ExpandDim.get_converter(), "range": Range.get_converter(), "cast": Cast.get_converter(), # defs/others "reshape": Reshape.get_converter(), 
"constant": Constant.get_converter(), - # "where": Where.get_converter(), + "where": Where.get_converter(), "flatten": Flatten.get_converter(), "sigmoid": Renamer("sigmoid"), "sigmoid_v2": Renamer("sigmoid"), "hardsigmoid": HardSigmoid.get_converter(), + "softplus": Softplus.get_converter(), "squeeze": AttrCvt("squeeze", {"axes": "axis"}), "unsqueeze": Unsqueeze.get_converter(), + "identity": Renamer("copy"), + "var": Variance.get_converter(), } @@ -1402,7 +1566,7 @@ def deal_parameter_convert( ): """deal with parameter(weight) convert in oneflow.""" for node_input_path in node_input_paths: - node_path = os.path.join(model_dir_path, node_input_path.replace("m.", "")) + node_path = os.path.join(model_dir_path, node_input_path.replace("m.", "", 1)) node_input_name = node_input_path.split("/")[0] _input_path_2_name[node_path] = node_input_name for param_name in _model_array: @@ -1503,7 +1667,11 @@ def __init__(self, shape, dtype, nodes, model_dir_path): print("{} should be defined by user".format(self._init_variable_node)) def _parse_input(self, node, model_dir_path): + input_user_conf_list = [] for input_name in node.user_conf.input: + input_user_conf_list.append(input_name) + input_user_conf_list.sort() + for input_name in input_user_conf_list: node_input_paths = getattr(node.user_conf.input[input_name], "s") for i in node_input_paths: node_input = i.split("/")[0] @@ -1548,58 +1716,11 @@ def _parse_output(self, op_name, outputs, cnt_init=0): return outputs - def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=None): + def from_oneflow(self, nodes, model_dir_path): """ - Parameters - ---------- - nodes : dict, keys: node.name, value: node - contain the graph - model_dir_path: str - The path of parameter - freeze_params: bool - If freeze_params is True, - the computational graph input is the input of the first layer of the network, - which cannot be specified by the user, e.g. - Default input is: %v_ResNetGraph_0_input.0: Tensor[(1, 3, 224, 224), float32] - User-defined input is: %_0_input.0: Tensor[(1, 3, 640, 480), float32] - If freeze_params is on, then conv1-in will be the graph input, not Input_0 - user_input: dict - User-defined input information for the graph - { - node1_name: - { - 'name': node1_name, # str, like "%v_ResNetGraph_0_input.0" - 'shape': node1_shape, # tuple - 'dtype': node1_dtype # str, like "float32" - } - ... - } - We recommend that users specify the input by specifying the job function, - rather than by this function - - Returns - ------- - mod : tvm.IRModule - The returned relay module - params : dict - A dict of name: tvm.nd.array pairs, used as pretrained weights + Implementation of convert the OneFlow model into an equivalent Relay Function. """ - # step 1: get the graph input - if not freeze_params: - for node_init_name in user_input: - if "_input." not in node_init_name: - raise KeyError( - "user_input['name'] should contain '_input.' 
" - + "to let program know that this is input node" - ) - self._nodes[node_init_name] = new_var( - node_init_name, - shape=user_input[node_init_name]["shape"], - dtype=user_input[node_init_name]["dtype"], - ) - self._inputs[node_init_name] = self._nodes[node_init_name] - - # step 2: find out if unsupported ops are used + # step 1: find out if unsupported ops are used convert_map = get_convert_map() unsupported_ops = set() for node_name in nodes: @@ -1619,7 +1740,7 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non msg += ", ".join(unsupported_ops) raise tvm.error.OpNotImplemented(msg) - # step 3: convert op + # step 2: convert op for node_name in nodes: node = nodes[node_name] if is_user_op(node): @@ -1633,7 +1754,11 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non self._parse_input(node, model_dir_path=model_dir_path) node_inputs = oneflow_input() + input_user_conf_list = [] for input_name in node.user_conf.input: + input_user_conf_list.append(input_name) + input_user_conf_list.sort() + for input_name in input_user_conf_list: node_input_paths = getattr(node.user_conf.input[input_name], "s") for i in node_input_paths: node_input = i.split("/")[0] @@ -1663,7 +1788,6 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non ), "Number of output mismatch {} vs {} in {}.".format( len(node_outputs), outputs_num, op_name ) - if outputs_num == 1: op = fold_constant(op) else: @@ -1678,10 +1802,9 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non else: self._nodes[node_outputs[i]] = op_temp[i] - # step 4: get the outputs + # step 3: get the outputs outputs = [] - for node_name in nodes: - node = nodes[node_name] + for node_name, node in nodes.items(): if is_output_op(node): node_name_v2 = getattr(node.output_conf, "in").split("/")[0] if node_name in self._nodes: @@ -1690,13 +1813,21 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non outputs.append(self._nodes[node_name_v2]) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) - # step 5: get the relay IR + # step 4: get the relay IR free_vars = analysis.free_vars(outputs) nodes = {v: k for k, v in self._nodes.items()} free_vars = [nodes[var] for var in free_vars] + free_vars_inputs = [] + free_vars_parameters = [] + for x in free_vars: + if "_input.0" in x: + free_vars_inputs.append(x) + else: + free_vars_parameters.append(x) + free_vars = free_vars_inputs + free_vars_parameters - # step 6: make sure the '_input.0' is the first in self._inputs + # step 5: make sure the '_input.0' is the first in self._inputs for free_var in free_vars: if free_var not in self._inputs: self._inputs[free_var] = self._nodes[free_var] @@ -1708,7 +1839,7 @@ def from_oneflow(self, nodes, model_dir_path, freeze_params=True, user_input=Non else: raise IndexError("{} is not in self._inputs".format(input_name)) - # step 7: create a function from our output expression and all input variables. + # step 6: create a function from our output expression and all input variables. 
         func = _function.Function([v for _, v in self._sort_inputs.items()], outputs)
         return IRModule.from_expr(func), self._params
 
@@ -1740,20 +1871,38 @@ def _convert_operator(self, op_name, node_inputs, op_attr):
         return sym
 
 
-def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None):
-    """
-    see OneflowGraph.from_oneflow
+def from_oneflow(graph, model_dir_path):
+    """Convert a OneFlow model into an equivalent Relay Function.
+
+    At present there are two ways to run a model in a deep learning framework,
+    Dynamic Graph and Static Graph, which are also called Eager Mode and Graph
+    Mode in OneFlow.
+
+    In general, dynamic graphs are easier to use and static graphs have better performance.
+    OneFlow offers nn.Graph, so that users can use an eager-like programming style to build
+    static graphs and train models.
+
+    We utilize the intermediate representation of nn.Graph to convert the OneFlow model
+    to Relay.
+
+    Parameters
+    ----------
+    graph : oneflow.nn.Graph
+        The compiled OneFlow graph whose intermediate representation is converted
+    model_dir_path : str
+        The path of the model weights
+
+    Returns
+    -------
+    mod : tvm.IRModule
+        The returned relay module
+    params : dict
+        A dict of name: tvm.nd.array pairs, used as pretrained weights
     """
     try:
         import oneflow as flow
     except ImportError:
         raise ImportError("please check that OneFlow is installed")
 
-    if not freeze_params and user_input is None:
-        raise ValueError("if you want to specify graph input, please give the 'user_input'")
-    if freeze_params and user_input is not None:
-        warnings.warn("'user_input' will not work, please check the 'freeze_params'")
-
     # get info of nodes
     shape = {}
     dtype = {}
@@ -1800,18 +1949,13 @@ def from_oneflow(graph, model_dir_path, freeze_params=True, user_input=None):
     graph_proto = graph._graph_proto
 
     # get all nodes
-    nodes = {}
+    nodes = OrderedDict()
     for op in graph_proto.net.op:
        nodes[op.name] = op
 
     g = OneflowGraph(shape, dtype, nodes, model_dir_path)
 
     # Use the graph proto as a scope so that ops can access other nodes if needed.
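+    # Hedged usage sketch (reviewer addition, not part of the patch): with a
+    # compiled nn.Graph, this public entry point is all a user needs, e.g.
+    #     mod, params = relay.frontend.from_oneflow(graph, "./model_weights")
+    #     lib = tvm.relay.build(mod, target="llvm", params=params)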
- mod, params = g.from_oneflow( - nodes=nodes, - model_dir_path=model_dir_path, - freeze_params=freeze_params, - user_input=user_input, - ) + mod, params = g.from_oneflow(nodes=nodes, model_dir_path=model_dir_path) return mod, params diff --git a/tests/python/frontend/oneflow/test_forward.py b/tests/python/frontend/oneflow/test_forward.py index d144cdad2bc5..0d18a2fb5c21 100644 --- a/tests/python/frontend/oneflow/test_forward.py +++ b/tests/python/frontend/oneflow/test_forward.py @@ -79,6 +79,16 @@ def build(self, x1, x2, x3): return out +class OneFlowGraph_v3(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x1, x2): + out = self.m(x1, x2) + return out + + def get_oneflow_output(model, inputs): flow_output = model(inputs) return flow_output.numpy() @@ -89,6 +99,10 @@ def get_oneflow_concat_output(model, input1, input2, input3): return flow_output +def get_oneflow_elementwise_output(model, input1, input2): + return model(input1, input2).numpy() + + def get_tvm_output(graph, model_path, inputs: flow.tensor, target="llvm", dtype="float32"): inputs_numpy = inputs.numpy() if target == "llvm": @@ -132,6 +146,32 @@ def get_tvm_concat_output( return tvm_output +def get_tvm_elementwise_output( + graph, + model_path, + input1: flow.tensor, + input2: flow.tensor, + target="llvm", + dtype="float32", +): + input1_numpy = input1.numpy() + input2_numpy = input2.numpy() + if target == "llvm": + device = tvm.cpu(0) + elif target == "cuda": + device = tvm.cuda(0) + + mod, params = relay.frontend.from_oneflow(graph, model_path) + with tvm.transform.PassContext(opt_level=10): + intrp = relay.build_module.create_executor("graph", mod, device, target) + tvm_output = intrp.evaluate()( + tvm.nd.array(input1_numpy.astype(dtype)), + tvm.nd.array(input2_numpy.astype(dtype)), + **params, + ).numpy() + return tvm_output + + def verify_conv( model, name="", @@ -336,6 +376,33 @@ def verify_math( tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) +def verify_matmul( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs1=flow.tensor(np.random.randn(2, 5), dtype=flow.float32), + inputs2=flow.tensor(np.random.randn(5, 2), dtype=flow.float32), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs1 = inputs1.to(device) + inputs2 = inputs2.to(device) + + graph = OneFlowGraph_v3(model) + graph._compile(inputs1, inputs2) + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_elementwise_output(graph, inputs1, inputs2) + out_tvm = get_tvm_elementwise_output(graph, MODEL_HOME, inputs1, inputs2, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + def verify_concat( model, name="", @@ -602,6 +669,23 @@ def forward(self, x): x = self.active(x) return x + class HardTanh(flow.nn.Module): + def __init__(self): + super().__init__() + self.active = flow.nn.Hardtanh() + + def forward(self, x): + x = self.active(x) + return x + + class TensorSoftmax(flow.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + x = x.softmax(dim=-1) + return x + if os.path.exists(MODEL_HOME): rmdir(MODEL_HOME) @@ -616,6 +700,8 @@ def forward(self, x): model9 = SiLU().eval() model10 = LeakyReLU().eval() model11 = GELU().eval() + model12 = HardTanh().eval() + model13 = TensorSoftmax().eval() for device in ["llvm"]: verify_activation(model1, device=device) @@ -629,6 +715,12 @@ def forward(self, x): 
verify_activation(model9, device=device) verify_activation(model10, device=device) verify_activation(model11, device=device) + verify_activation(model12, device=device) + verify_activation( + model13, + device=device, + inputs=flow.tensor(np.random.rand(1, 12, 197, 197).astype(np.float32)), + ) @tvm.testing.uses_gpu @@ -665,12 +757,19 @@ class Exp2(flow.nn.Module): def forward(self, x): return flow.expm1(x) + class Variance(flow.nn.Module): + def forward(self, x): + return flow.var(x, 1, unbiased=False, keepdim=True) + model1 = Sigmoid().eval() model2 = Sign().eval() model3 = Log().eval() model4 = Log2().eval() model5 = Exp().eval() model6 = Exp2().eval() + model7 = Reciprocal().eval() + model8 = Pow().eval() + model9 = Variance().eval() for device in ["llvm"]: verify_math(model1, device=device) @@ -679,6 +778,9 @@ def forward(self, x): verify_math(model4, device=device) verify_math(model5, device=device) verify_math(model6, device=device) + verify_math(model7, device=device) + verify_math(model8, device=device) + verify_math(model9, device=device) @tvm.testing.uses_gpu @@ -710,6 +812,99 @@ def forward(self, x1, x2, x3): verify_concat(model, device=device) +@tvm.testing.uses_gpu +def test_add_constant(): + class ConstantAdd(flow.nn.Module): + def forward(self, x): + out = flow.add(1.0, x) + return out + + model = ConstantAdd().eval() + + for device in ["llvm"]: + verify_math( + model, device=device, inputs=flow.tensor(np.random.randn(3, 6, 9).astype(np.float32)) + ) + + +@tvm.testing.uses_gpu +def test_logical(): + class LogicalGreater(flow.nn.Module): + def forward(self, x): + return x > 1.0 + + model1 = LogicalGreater().eval() + + for device in ["llvm"]: + verify_math( + model1, device=device, inputs=flow.tensor(np.random.randn(3, 6, 9).astype(np.float32)) + ) + + +@tvm.testing.uses_gpu +def test_expand(): + class Expand(flow.nn.Module): + def forward(self, x): + return x.expand(2, -1, -1) + + model1 = Expand().eval() + + for device in ["llvm"]: + verify_math( + model1, device=device, inputs=flow.tensor(np.random.randn(1, 6, 9).astype(np.float32)) + ) + + +@tvm.testing.uses_gpu +def test_matmul(): + class MatMul(flow.nn.Module): + def forward(self, x, y): + return flow._C.matmul(x, y) + + class MatMulTranspose(flow.nn.Module): + def forward(self, x, y): + return flow._C.matmul(x, y, transpose_b=True) + + class BatchMatMul(flow.nn.Module): + def forward(self, x, y): + return flow._C.batch_matmul(x, y) + + class BroadCastMatMul(flow.nn.Module): + def forward(self, x, y): + return flow._C.matmul(x, y) + + model1 = MatMul().eval() + model2 = MatMulTranspose().eval() + model3 = BatchMatMul().eval() + model4 = BroadCastMatMul().eval() + + for device in ["llvm"]: + verify_matmul( + model1, + device=device, + inputs1=flow.tensor(np.random.randn(2, 3).astype(np.float32)), + inputs2=flow.tensor(np.random.randn(3, 3).astype(np.float32)), + ) + verify_matmul( + model2, + device=device, + inputs1=flow.tensor(np.random.randn(1, 2).astype(np.float32)), + inputs2=flow.tensor(np.random.randn(3, 2).astype(np.float32)), + ) + verify_matmul( + model3, + device=device, + inputs1=flow.tensor(np.random.randn(2, 1, 2).astype(np.float32)), + inputs2=flow.tensor(np.random.randn(2, 2, 3).astype(np.float32)), + ) + verify_matmul( + model4, + device=device, + inputs1=flow.tensor(np.random.randn(3, 8, 8, 16).astype(np.float32)), + inputs2=flow.tensor(np.random.randn(16, 8).astype(np.float32)), + ) + + if __name__ == "__main__": test_conv2d() test_pool2d() @@ -720,4 +915,8 @@ def forward(self, x1, x2, x3): test_math() 
test_slice() test_concat() + test_add_constant() + test_logical() + test_expand() + test_matmul() rmdir("log") diff --git a/tests/python/frontend/oneflow/test_vision_models.py b/tests/python/frontend/oneflow/test_vision_models.py new file mode 100644 index 000000000000..e8d0627001ca --- /dev/null +++ b/tests/python/frontend/oneflow/test_vision_models.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=import-self, invalid-name +# pylint: disable=arguments-differ, unused-argument, unused-import +"""Unit tests for various models and operators""" +import os +import sys + +import numpy as np +import pytest +import tvm +import tvm.testing +import tvm.topi.testing +from tvm import relay +from tvm.contrib import graph_executor + +import oneflow as flow +from flowvision.models.alexnet import alexnet +from flowvision.models.squeezenet import squeezenet1_0 +from flowvision.models.shufflenet_v2 import shufflenet_v2_x0_5 +from flowvision.models.mobilenet import mobilenet_v2 +from flowvision.models.ghostnet import ghostnet +from flowvision.models.vision_transformer import vit_base_patch16_224 + +MODEL_HOME = "test_model" + + +def mkdir(path): + # init + path = path.strip() + path = path.rstrip("\\") + + if not os.path.exists(path): + os.makedirs(path) + else: + print("{} is already here".format(path)) + + +def rmdir(path): + for root, dirs, files in os.walk(path, topdown=False): + for name in files: + os.remove(os.path.join(root, name)) + for name in dirs: + os.rmdir(os.path.join(root, name)) + os.removedirs(path) + + +def assert_shape(out1, out2): + if out1.shape != out2.shape: + msg = "Output shapes {} and {} don't match" + raise AssertionError(msg.format(out1.shape, out2.shape)) + + +class OneFlowGraph(flow.nn.Graph): + def __init__(self, module): + super().__init__() + self.m = module + + def build(self, x): + out = self.m(x) + return out + + +def get_oneflow_output(model, inputs): + flow_output = model(inputs) + return flow_output.numpy() + + +def get_tvm_output(graph, model_path, inputs: flow.tensor, target="llvm", dtype="float32"): + inputs_numpy = inputs.numpy() + if target == "llvm": + device = tvm.cpu(0) + elif target == "cuda": + device = tvm.cuda(0) + + mod, params = relay.frontend.from_oneflow(graph, model_path) + with tvm.transform.PassContext(opt_level=10): + intrp = relay.build_module.create_executor("graph", mod, device, target) + tvm_output = intrp.evaluate()(tvm.nd.array(inputs_numpy.astype(dtype)), **params).numpy() + return tvm_output + + +def verify_model( + model, + name="", + rtol=1e-5, + atol=1e-5, + inputs=flow.tensor( + np.random.rand(1, 3, 224, 224), + dtype=flow.float32, + ), + device="llvm", +): + if device == "cuda": + model.to(device) + inputs = inputs.to(device) + + graph = OneFlowGraph(model) + 
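+    # NOTE (reviewer assumption): nn.Graph builds lazily, so `_compile` is called
+    # here to trace the module once and populate `graph._graph_proto`, which
+    # relay.frontend.from_oneflow reads during conversion.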
graph._compile(inputs) + + mkdir(MODEL_HOME) + flow.save(model.state_dict(), MODEL_HOME) + + out_flow = get_oneflow_output(graph, inputs) + out_tvm = get_tvm_output(graph, MODEL_HOME, inputs, target=device) + rmdir(MODEL_HOME) + + assert_shape(out_flow, out_tvm) + tvm.testing.assert_allclose(out_flow, out_tvm, rtol=rtol, atol=atol) + + +@tvm.testing.uses_gpu +def test_vision_models(): + + if os.path.exists(MODEL_HOME): + rmdir(MODEL_HOME) + + vision_alexnet = alexnet().eval() + vision_squeezenet = squeezenet1_0().eval() + vision_shufflenet = shufflenet_v2_x0_5().eval() + vision_mobilenetv2 = mobilenet_v2().eval() + vision_ghostnet = ghostnet().eval() + vision_vit = vit_base_patch16_224().eval() + + for device in ["llvm"]: + verify_model(vision_alexnet, device=device) + verify_model(vision_squeezenet, device=device) + verify_model(vision_shufflenet, device=device) + verify_model(vision_mobilenetv2, device=device) + verify_model(vision_ghostnet, device=device) + verify_model(vision_vit, device=device) + + +if __name__ == "__main__": + test_vision_models() + rmdir("log")
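+
+# Hedged end-to-end sketch (illustrative reviewer addition, not executed by CI):
+#
+#     import numpy as np
+#     import oneflow as flow
+#     from flowvision.models.alexnet import alexnet
+#     from tvm import relay
+#
+#     model = alexnet().eval()
+#     graph = OneFlowGraph(model)
+#     x = flow.tensor(np.random.rand(1, 3, 224, 224), dtype=flow.float32)
+#     graph._compile(x)
+#     flow.save(model.state_dict(), "alexnet_weights")
+#     mod, params = relay.frontend.from_oneflow(graph, "alexnet_weights")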