From ad0ac5f94e0f24845ad4190cd10d1dde164d5b77 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 22 Oct 2020 18:19:03 +0000 Subject: [PATCH 1/9] HWNC layout conversion support and better tensorcore strategy checking. --- python/tvm/relay/op/nn/_nn.py | 14 +++++ python/tvm/relay/op/strategy/cuda.py | 9 ++- python/tvm/relay/transform/transform.py | 26 ++++++++ src/runtime/graph/graph_runtime.cc | 2 +- .../relay/test_pass_convert_op_layout.py | 61 +++++++++++++++++++ 5 files changed, 108 insertions(+), 4 deletions(-) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index e1aabe1e15b5..a47e0114a5bc 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -28,6 +28,7 @@ from ..op import OpPattern from .._tensor import elemwise_shape_func from ..strategy.generic import is_depthwise_conv2d +from ...transform import LayoutConfig # relu reg.register_broadcast_schedule("nn.relu") @@ -164,6 +165,16 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts): from tvm import relay data, weight = inputs + + # First check if there is a LayoutConfig scope, and if so, whether + # it indicates we should ignore this layer or not. + layout_config = LayoutConfig.current + if layout_config is not None: + skip_layer = layout_config.check_skip() + if skip_layer: + return relay.nn.conv2d(data, weight, **attrs) + + # Prepare new layout. new_attrs = dict(attrs) assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs" desired_data_layout, desired_kernel_layout = map(str, desired_layouts) @@ -192,6 +203,9 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts): else: new_attrs["kernel_layout"] = "HWIO" return relay.nn.conv2d(data, weight, **new_attrs) + elif desired_data_layout == 'HWNC': + new_attrs['kernel_layout'] = 'HWOI' + return relay.nn.conv2d(data, weight, **new_attrs) raise ValueError("Layout %s is not yet supported." 
% desired_data_layout) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index 7031365251aa..ccd895029bd3 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -219,9 +219,12 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): out_channels = oc_chunk * oc_block_factor else: _, _, out_channels, _ = get_const_tuple(kernel.shape) - if topi.cuda.is_shape_tensorcore_direct_qualified( - batch=N, in_channels=in_channels, num_filter=out_channels, in_dtype=data.dtype - ): + + tensorcore_dtypes = ['int4', 'uint4', 'int8', 'uint8'] + if (N % 16 == 0 and in_channels % 16 == 0 and out_channels % 16 == 0) or \ + (N % 8 == 0 and in_channels % 16 == 0 and out_channels % 32 == 0) or \ + (N % 32 == 0 and in_channels % 16 == 0 and out_channels % 8 == 0) and \ + (data.dtype in tensorcore_dtypes and kernel.dtype in tensorcore_dtypes): strategy.add_implementation( wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore), wrap_topi_schedule(topi.cuda.schedule_conv2d_hwnc_tensorcore), diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index e155f83a7c5d..cbe9df7aa689 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -386,6 +386,32 @@ def AlterOpLayout(): return _ffi_api.AlterOpLayout() +class LayoutConfig(object): + """A structure for customizing the ConvertLayout pass.""" + current = None + + def __init__(self, skip_layers = []): + self.skip_counter = 0 + self.skip_layers = skip_layers + + def check_skip(self): + skip = self.skip_counter in self.skip_layers + self.skip_counter += 1 + return skip + + def reset(self): + self.skip_counter = 0 + self.skip_layers = [] + + def __enter__(self): + self._old_manager = LayoutConfig.current + LayoutConfig.current = self + return self + + def __exit__(self, ptype, value, trace): + LayoutConfig.current = self._old_manager + + def ConvertLayout(desired_layouts): """Given a dest layout, this pass transforms the expr such that most of the ops input data layout is changed to the dest layout. 
In ideal situation, there are only 2 layout transforms, diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 18245badcc74..c59b10103391 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -268,7 +268,7 @@ void GraphRuntime::SetupStorage() { CHECK_GE(storage_id, 0) << "Do not support runtime shape op"; DLDataType t = vtype[i]; size_t bits = t.bits * t.lanes; - CHECK(bits % 8U == 0U || bits == 1U); + CHECK(bits % 8U == 0U || bits == 1U || bits == 4U); size_t bytes = ((bits + 7U) / 8U) * size; uint32_t sid = static_cast(storage_id); diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 40aef264a335..6c3ef00bd739 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1162,6 +1162,66 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_convert_with_config(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var('weight', shape=(3, 3, 64, 64)) + y = relay.nn.conv2d(x, weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + y = relay.nn.relu(y) + + weight2 = relay.var('weight2', shape=(3, 3, 64, 64)) + y2 = relay.nn.conv2d(y, weight2, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + y2 = relay.nn.relu(y2) + + out = relay.Function([x, weight, weight2], y2) + return out + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var('weight', shape=(3, 3, 64, 64)) + + weight2 = relay.var('weight2', shape=(3, 3, 64, 64)) + weight2 = relay.layout_transform(weight2, 'HWIO', 'HWOI') + + y = relay.nn.conv2d(x, weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + y = relay.nn.relu(y) + y = relay.layout_transform(y, 'NHWC', 'HWNC') + + y2 = relay.nn.conv2d(y, weight2, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='HWNC', + kernel_layout='HWOI') + y2 = relay.nn.relu(y2) + + y2 = relay.layout_transform(y2, 'HWNC', 'NHWC') + output = relay.Function(relay.analysis.free_vars(y2), y2) + return output + + a = before() + layout_config = relay.transform.LayoutConfig(skip_layers=[0]) + with layout_config: + a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['HWNC', 'default']})) + b = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + if __name__ == "__main__": test_qnn_binary_no_convert_layout() test_no_convert_layout() @@ -1185,3 +1245,4 @@ def expected(): test_default_keyword() test_different_ops_convert_layout() test_no_desired_layout() + test_convert_with_config() From 7e71d4b0a556677840bc938be6151cd7efa65726 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 22 Oct 2020 23:31:08 +0000 Subject: [PATCH 2/9] Add first draft at recast pass. 
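This first draft lives in python/tvm/relay/transform/recast.py and rewrites the inputs and out_dtype of selected ops. The intended call pattern, sketched under the assumption that the entry point is eventually exposed as tvm.relay.transform.recast with the (expr, dtype, out_dtype, ops, skip_layers) signature used here (the later test patch imports it that way):

    import tvm
    from tvm import relay
    from tvm.relay.transform import recast  # assumed final import path

    # A tiny one-convolution function to recast.
    x = relay.var("x", shape=[8, 8, 8, 8])
    w = relay.var("w", shape=[8, 8, 3, 3])
    c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32")
    func = relay.Function([x, w], c)

    # Cast every nn.conv2d to int8 inputs with an int32 accumulator; the
    # pass wraps each matching call in casts and restores the original
    # dtype on its output.
    recast_func = recast(func, "int8", "int32", ops=["nn.conv2d"])

The hard casting is only meant for benchmarking kernels at a new dtype; it makes no attempt to preserve model accuracy.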
--- python/tvm/relay/depth_count.py | 121 +++++++++++++++++++++++++++ python/tvm/relay/transform/recast.py | 88 +++++++++++++++++++ 2 files changed, 209 insertions(+) create mode 100644 python/tvm/relay/depth_count.py create mode 100644 python/tvm/relay/transform/recast.py diff --git a/python/tvm/relay/depth_count.py b/python/tvm/relay/depth_count.py new file mode 100644 index 000000000000..696aaeb3a404 --- /dev/null +++ b/python/tvm/relay/depth_count.py @@ -0,0 +1,121 @@ +from tvm.ir import Op +from tvm import relay + +from tvm.relay import ExprVisitor +from tvm.relay.function import Function +from tvm.relay.expr import Call, Let, Var, GlobalVar +from tvm.relay.expr import If, Tuple, TupleGetItem, Constant +from tvm.relay.expr import RefCreate, RefRead, RefWrite +from tvm.relay.adt import Constructor, Match, Clause + +# TODO Make local_count some more generic name and separate this into a base class that +# allows stuff to be passed around and a specific implementation for counting depth. +# Also make one for exprmutator. Good to have both. +# Add to relay.whatever + +class DepthCounter(ExprVisitor): + """Determine how many operations of the specified type are in the graph.""" + def __init__(self, valid_ops): + self.depth_count = 0 + self.valid_ops = [relay.op.get(op) for op in valid_ops] + super().__init__() + + # pylint: disable=no-else-return + def visit(self, expr, local_count=0): + """Apply the visitor to an expression.""" + if expr in self.memo_map: + return self.memo_map[expr] + + if isinstance(expr, Function): + res = self.visit_function(expr, local_count) + elif isinstance(expr, Call): + res = self.visit_call(expr, local_count) + elif isinstance(expr, Let): + res = self.visit_let(expr, local_count) + elif isinstance(expr, Var): + res = self.visit_var(expr, local_count) + elif isinstance(expr, GlobalVar): + res = self.visit_global_var(expr, local_count) + elif isinstance(expr, If): + res = self.visit_if(expr, local_count) + elif isinstance(expr, Tuple): + res = self.visit_tuple(expr, local_count) + elif isinstance(expr, TupleGetItem): + res = self.visit_tuple_getitem(expr, local_count) + elif isinstance(expr, Constant): + res = self.visit_constant(expr, local_count) + elif isinstance(expr, Op): + res = self.visit_op(expr, local_count) + elif isinstance(expr, RefCreate): + res = self.visit_ref_create(expr, local_count) + elif isinstance(expr, RefRead): + res = self.visit_ref_read(expr, local_count) + elif isinstance(expr, RefWrite): + res = self.visit_ref_write(expr, local_count) + elif isinstance(expr, Constructor): + res = self.visit_constructor(expr, local_count) + elif isinstance(expr, Match): + res = self.visit_match(expr, local_count) + else: + raise Exception("warning unhandled case: {0}".format(type(expr))) + + self.memo_map[expr] = res + + return res + + def visit_call(self, call, local_count): + if call.op in self.valid_ops: + local_count = local_count + 1 + self.depth_count = max(self.depth_count, local_count) + for arg in call.args: + self.visit(arg, local_count) + + def visit_tuple(self, tup, local_count): + for x in tup.fields: + self.visit(x, local_count) + + def visit_var(self, var, local_count): + pass + + def visit_let(self, let, local_count): + self.visit(let.var, local_count) + self.visit(let.value, local_count) + self.visit(let.body, local_count) + + def visit_function(self, f, local_count): + self.visit(f.body, local_count) + + def visit_if(self, i, local_count): + self.visit(i.cond, local_count) + self.visit(i.true_branch, local_count) + 
self.visit(i.false_branch, local_count) + + def visit_global_var(self, gv, local_count): + pass + + def visit_constructor(self, c, local_count): + pass + + def visit_op(self, op, local_count): + pass + + def visit_constant(self, const, local_count): + pass + + def visit_ref_create(self, r, local_count): + self.visit(r.value, local_count) + + def visit_ref_read(self, r, local_count): + self.visit(r.ref, local_count) + + def visit_ref_write(self, r, local_count): + self.visit(r.ref, local_count) + self.visit(r.value, local_count) + + def visit_tuple_getitem(self, t, local_count): + self.visit(t.tuple_value, local_count) + + def visit_match(self, m, local_count): + self.visit(m.data, local_count) + for c in m.clauses: + self.visit(c.rhs, local_count) \ No newline at end of file diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py new file mode 100644 index 000000000000..23f3d612953c --- /dev/null +++ b/python/tvm/relay/transform/recast.py @@ -0,0 +1,88 @@ +"""Relay Downcast from Full-precision to Half-precision floating-point Pass""" +import tvm +from tvm import relay +from tvm.relay import ExprVisitor, ExprMutator, Call, Var, Constant, TupleGetItem, Function +import tvm.relay.transform as _transform +from tvm.relay.frontend.common import infer_type +from depth_count import DepthCounter + + +def recast(func, dtype, out_dtype, ops=['nn.conv2d'], skip_layers=[]): + # pylint: disable=line-too-long + """Downcast mutator + Parameters + --------- + function: Function + The original function that will have its type changed. + dtype: str + The target type to cast to. + out_dtype: str + The output type to cast to. + ops: List[str] + A list of operations that should have their type changed, + others will be left as is. + skip_layers: List[int] + A list of integers indicating operations that should + not have their type changed, counted starting with the + first valid operation encountered. Negative indices are + allowed and indicate starting at the last layer. + Returns + ------- + The graph after downcasting to the specified datatype. + """ + # Collect valid relay ops that should be recast. + valid_ops = [relay.op.get(op) for op in ops] + + class RecastMutator(ExprMutator): + """Cast operations to the target type.""" + def __init__(self, valid_op_count): + self.layer_count = 0 + self.valid_op_count = valid_op_count + self.skip_layers = skip_layers + # Convert negative indices to positive ones. + for i, layer in enumerate(skip_layers): + if layer < 0: + skip_layers[i] = self.valid_op_count + layer + super().__init__() + + def set_attr_dtype(self, attrs): + new_attr_dict= {} + for attr in attrs.keys(): + attr_value = attrs[attr] + if isinstance(attr_value, tvm.ir.container.Array): + attr_value = tuple(attr_value) + new_attr_dict[str(attr)] = attr_value + new_attr_dict['out_dtype'] = out_dtype + attr_type = str(attrs).split('(')[0] + return tvm.ir.make_node(attr_type, **new_attr_dict) + + def visit_call(self, call): + if call.op in valid_ops: + layer_count = self.valid_op_count - self.layer_count - 1 + self.layer_count += 1 + print(layer_count) + print(call) + print("\n\n\n") + if layer_count in skip_layers: + return super().visit_call(call) + + # Otherwise recast its inputs. 
+ new_fn = self.visit(call.op) + args = [self.visit(arg) for arg in call.args] + self.layer_count = 0 + new_args = list() + for arg in args: + new_args.append(relay.cast(arg, dtype=dtype)) + new_attrs = self.set_attr_dtype(call.attrs) + # Recast the output for compatibility with other graph operations. + return relay.cast(Call(new_fn, new_args, new_attrs), infer_type(args[0]).checked_type.dtype) + + return super().visit_call(call) + + count_pass = DepthCounter(ops) + count_pass.visit(func) + print(count_pass.depth_count) + exit() + recast_pass = RecastMutator(count_pass.valid_op_count) + func = recast_pass.visit(func) + return tvm.IRModule.from_expr(func) \ No newline at end of file From b68f5154f550589676fb22f2a19d029e6cd4c174 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 23 Oct 2020 20:54:13 +0000 Subject: [PATCH 3/9] Layer count pass now working and tested. --- python/tvm/relay/analysis/__init__.py | 3 + python/tvm/relay/analysis/count_layers.py | 66 ++++++++++++ python/tvm/relay/depth_count.py | 121 ---------------------- python/tvm/relay/transform/recast.py | 7 +- tests/python/relay/test_layer_count.py | 32 ++++++ 5 files changed, 104 insertions(+), 125 deletions(-) create mode 100644 python/tvm/relay/analysis/count_layers.py delete mode 100644 python/tvm/relay/depth_count.py create mode 100644 tests/python/relay/test_layer_count.py diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py index e5b21cb107f5..b062c6f11f5c 100644 --- a/python/tvm/relay/analysis/__init__.py +++ b/python/tvm/relay/analysis/__init__.py @@ -29,3 +29,6 @@ # Feature from . import feature from . import sparse_dense + +# Utilities +from . import count_layers diff --git a/python/tvm/relay/analysis/count_layers.py b/python/tvm/relay/analysis/count_layers.py new file mode 100644 index 000000000000..ce7df797f8f8 --- /dev/null +++ b/python/tvm/relay/analysis/count_layers.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from ..expr_functor import ExprVisitor + + +class LayerCounter(ExprVisitor): + """A visitor pass that computes the deepest chain of specified ops in graph.""" + def __init__(self, valid_ops): + self.depth_count = 0 + self.deepest_count = 0 + self.valid_ops = [relay.op.get(op) for op in valid_ops] + super().__init__() + + def visit_call(self, call): + if call.op in self.valid_ops: + self.depth_count = self.depth_count + 1 + current_count = self.depth_count + self.deepest_count = max(self.deepest_count, current_count) + for arg in call.args: + self.visit(arg) + self.depth_count = current_count + + def count(self): + return self.deepest_count + + +def count_layers(expr, valid_ops): + """Determine the number of layers of specified ops in a graph. 
+ This pass computes only the deepest chain of ops rather than the + total number of ops in a graph. Thus, if there are two parallel + convolutions (for example), they would be considered a single layer. + + Parameters + ---------- + expr : tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule. + The input expression. + + valid_ops: List[str] + A list of the operations that should be included in the count. + + Returns + ------- + layer_count : int + The number of layers of the specified operations found in the graph. + """ + if isinstance(expr, tvm.ir.IRModule): + expr = expr['main'] + count_pass = LayerCounter(valid_ops) + count_pass.visit(expr) + return count_pass.count() \ No newline at end of file diff --git a/python/tvm/relay/depth_count.py b/python/tvm/relay/depth_count.py deleted file mode 100644 index 696aaeb3a404..000000000000 --- a/python/tvm/relay/depth_count.py +++ /dev/null @@ -1,121 +0,0 @@ -from tvm.ir import Op -from tvm import relay - -from tvm.relay import ExprVisitor -from tvm.relay.function import Function -from tvm.relay.expr import Call, Let, Var, GlobalVar -from tvm.relay.expr import If, Tuple, TupleGetItem, Constant -from tvm.relay.expr import RefCreate, RefRead, RefWrite -from tvm.relay.adt import Constructor, Match, Clause - -# TODO Make local_count some more generic name and separate this into a base class that -# allows stuff to be passed around and a specific implementation for counting depth. -# Also make one for exprmutator. Good to have both. -# Add to relay.whatever - -class DepthCounter(ExprVisitor): - """Determine how many operations of the specified type are in the graph.""" - def __init__(self, valid_ops): - self.depth_count = 0 - self.valid_ops = [relay.op.get(op) for op in valid_ops] - super().__init__() - - # pylint: disable=no-else-return - def visit(self, expr, local_count=0): - """Apply the visitor to an expression.""" - if expr in self.memo_map: - return self.memo_map[expr] - - if isinstance(expr, Function): - res = self.visit_function(expr, local_count) - elif isinstance(expr, Call): - res = self.visit_call(expr, local_count) - elif isinstance(expr, Let): - res = self.visit_let(expr, local_count) - elif isinstance(expr, Var): - res = self.visit_var(expr, local_count) - elif isinstance(expr, GlobalVar): - res = self.visit_global_var(expr, local_count) - elif isinstance(expr, If): - res = self.visit_if(expr, local_count) - elif isinstance(expr, Tuple): - res = self.visit_tuple(expr, local_count) - elif isinstance(expr, TupleGetItem): - res = self.visit_tuple_getitem(expr, local_count) - elif isinstance(expr, Constant): - res = self.visit_constant(expr, local_count) - elif isinstance(expr, Op): - res = self.visit_op(expr, local_count) - elif isinstance(expr, RefCreate): - res = self.visit_ref_create(expr, local_count) - elif isinstance(expr, RefRead): - res = self.visit_ref_read(expr, local_count) - elif isinstance(expr, RefWrite): - res = self.visit_ref_write(expr, local_count) - elif isinstance(expr, Constructor): - res = self.visit_constructor(expr, local_count) - elif isinstance(expr, Match): - res = self.visit_match(expr, local_count) - else: - raise Exception("warning unhandled case: {0}".format(type(expr))) - - self.memo_map[expr] = res - - return res - - def visit_call(self, call, local_count): - if call.op in self.valid_ops: - local_count = local_count + 1 - self.depth_count = max(self.depth_count, local_count) - for arg in call.args: - self.visit(arg, local_count) - - def visit_tuple(self, tup, local_count): - for x in tup.fields: - 
self.visit(x, local_count) - - def visit_var(self, var, local_count): - pass - - def visit_let(self, let, local_count): - self.visit(let.var, local_count) - self.visit(let.value, local_count) - self.visit(let.body, local_count) - - def visit_function(self, f, local_count): - self.visit(f.body, local_count) - - def visit_if(self, i, local_count): - self.visit(i.cond, local_count) - self.visit(i.true_branch, local_count) - self.visit(i.false_branch, local_count) - - def visit_global_var(self, gv, local_count): - pass - - def visit_constructor(self, c, local_count): - pass - - def visit_op(self, op, local_count): - pass - - def visit_constant(self, const, local_count): - pass - - def visit_ref_create(self, r, local_count): - self.visit(r.value, local_count) - - def visit_ref_read(self, r, local_count): - self.visit(r.ref, local_count) - - def visit_ref_write(self, r, local_count): - self.visit(r.ref, local_count) - self.visit(r.value, local_count) - - def visit_tuple_getitem(self, t, local_count): - self.visit(t.tuple_value, local_count) - - def visit_match(self, m, local_count): - self.visit(m.data, local_count) - for c in m.clauses: - self.visit(c.rhs, local_count) \ No newline at end of file diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py index 23f3d612953c..58ec501a8578 100644 --- a/python/tvm/relay/transform/recast.py +++ b/python/tvm/relay/transform/recast.py @@ -4,7 +4,7 @@ from tvm.relay import ExprVisitor, ExprMutator, Call, Var, Constant, TupleGetItem, Function import tvm.relay.transform as _transform from tvm.relay.frontend.common import infer_type -from depth_count import DepthCounter +from tvm.relay.analysis import count_layers def recast(func, dtype, out_dtype, ops=['nn.conv2d'], skip_layers=[]): @@ -79,9 +79,8 @@ def visit_call(self, call): return super().visit_call(call) - count_pass = DepthCounter(ops) - count_pass.visit(func) - print(count_pass.depth_count) + layer_depth = count_layers.count_layers(func, ['nn.conv2d', 'nn.dense']) + print(layer_depth) exit() recast_pass = RecastMutator(count_pass.valid_op_count) func = recast_pass.visit(func) diff --git a/tests/python/relay/test_layer_count.py b/tests/python/relay/test_layer_count.py new file mode 100644 index 000000000000..c143b855d515 --- /dev/null +++ b/tests/python/relay/test_layer_count.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from tvm.relay.testing import resnet +from tvm.relay.analysis.count_layers import count_layers + +def test_layer_count(): + def verify(num_layers): + # Load a resnet with a known number of layers. + mod, _ = resnet.get_workload(num_layers=num_layers) + # Count the number of conv and dense layers. 
+ count = count_layers(mod, valid_ops=['nn.conv2d', 'nn.dense']) + assert count == num_layers + + verify(18) + verify(50) + +if __name__ == "__main__": + test_layer_count() \ No newline at end of file From 92da6259841e2878126712d72a47d7a7e03d5ce8 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 23 Oct 2020 21:34:34 +0000 Subject: [PATCH 4/9] Recast pass now working as expected. --- python/tvm/relay/analysis/__init__.py | 2 +- python/tvm/relay/analysis/count_layers.py | 2 +- python/tvm/relay/transform/__init__.py | 1 + python/tvm/relay/transform/recast.py | 161 +++++++++++++--------- tests/python/relay/test_layer_count.py | 2 +- 5 files changed, 102 insertions(+), 66 deletions(-) diff --git a/python/tvm/relay/analysis/__init__.py b/python/tvm/relay/analysis/__init__.py index b062c6f11f5c..b4ea7f3cff62 100644 --- a/python/tvm/relay/analysis/__init__.py +++ b/python/tvm/relay/analysis/__init__.py @@ -31,4 +31,4 @@ from . import sparse_dense # Utilities -from . import count_layers +from .count_layers import count_layers diff --git a/python/tvm/relay/analysis/count_layers.py b/python/tvm/relay/analysis/count_layers.py index ce7df797f8f8..b3f04a4f557b 100644 --- a/python/tvm/relay/analysis/count_layers.py +++ b/python/tvm/relay/analysis/count_layers.py @@ -29,7 +29,7 @@ def __init__(self, valid_ops): def visit_call(self, call): if call.op in self.valid_ops: - self.depth_count = self.depth_count + 1 + self.depth_count += 1 current_count = self.depth_count self.deepest_count = max(self.deepest_count, current_count) for arg in call.args: diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 138a36611c6f..1d0ea176b16f 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -18,4 +18,5 @@ """The Relay IR namespace containing transformations.""" # transformation passes from .transform import * +from .recast import recast from . import memory_alloc diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py index 58ec501a8578..f50583d81dc7 100644 --- a/python/tvm/relay/transform/recast.py +++ b/python/tvm/relay/transform/recast.py @@ -1,18 +1,97 @@ -"""Relay Downcast from Full-precision to Half-precision floating-point Pass""" +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""Relay type recasting pass""" import tvm from tvm import relay -from tvm.relay import ExprVisitor, ExprMutator, Call, Var, Constant, TupleGetItem, Function -import tvm.relay.transform as _transform +from tvm.relay import ExprMutator, Call from tvm.relay.frontend.common import infer_type from tvm.relay.analysis import count_layers +class RecastMutator(ExprMutator): + """Cast operations to the target type.""" + def __init__(self, dtype, out_dtype, valid_ops, valid_op_count, skip_layers=[]): + self.dtype = dtype + self.out_dtype = out_dtype + self.depth_count = 0 + self.valid_ops = [relay.op.get(op) for op in valid_ops] + self.valid_op_count = valid_op_count + self.skip_layers = skip_layers + # Convert negative indices to positive ones. + for i, layer in enumerate(skip_layers): + if layer < 0: + skip_layers[i] = self.valid_op_count + layer + super().__init__() + + def visit_call(self, call): + # Keep track of our current depth and layer count + # so we can know whether to skip this layer or not. + current_depth = self.depth_count + current_layer = self.valid_op_count - current_depth - 1 + if call.op in self.valid_ops: + self.depth_count += 1 + # Visit current call operation + new_fn = self.visit(call.op) + # Visit current arguments + args = [] + for arg in call.args: + args.append(self.visit(arg)) + self.depth_count = current_depth + + # Downcast this op if its the correct type and not skipped. + if call.op in self.valid_ops and current_layer not in self.skip_layers: + # Recast inputs to specified type. + args = [self.visit(arg) for arg in call.args] + new_args = list() + for arg in args: + new_args.append(relay.cast(arg, dtype=self.dtype)) + + # If out_dtype is in the attributes, we need to update it. + if 'out_dtype' in call.attrs.keys(): + new_attr_dict= {} + for attr in call.attrs.keys(): + attr_value = call.attrs[attr] + if isinstance(attr_value, tvm.ir.container.Array): + attr_value = tuple(attr_value) + new_attr_dict[str(attr)] = attr_value + new_attr_dict['out_dtype'] = self.out_dtype + attr_type = str(call.attrs).split('(')[0] + new_attrs = tvm.ir.make_node(attr_type, **new_attr_dict) + else: + new_attrs = call.attrs + # Recast the output for compatibility with other graph operations. + return relay.cast(Call(new_fn, new_args, new_attrs), infer_type(args[0]).checked_type.dtype) + + # Otherwise return the unchanged call. + return Call(new_fn, args, call.attrs) + + +def recast(expr, dtype, out_dtype, ops=['nn.conv2d'], skip_layers=[]): + """Convert the types of operations in a graph to a new value. + Note that this is primarily useful for testing performance of individual + operations at the new datatype. In a real setting, this pass will + almost certainly do a poor job converting from one datatype to another + as it just applies hard casting. For example, when recasting from float + to integer, many small values will simply be set to 0. Although this will + allow autotuning and benchmarking to produce proper timings at the new + data type, the output of the model will of course be heavily impacted. -def recast(func, dtype, out_dtype, ops=['nn.conv2d'], skip_layers=[]): - # pylint: disable=line-too-long - """Downcast mutator Parameters --------- - function: Function + expr: tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule The original function that will have its type changed. dtype: str The target type to cast to. @@ -28,60 +107,16 @@ def recast(func, dtype, out_dtype, ops=['nn.conv2d'], skip_layers=[]): allowed and indicate starting at the last layer. 
Returns ------- - The graph after downcasting to the specified datatype. + output_expr : tvm.relay.Expr, tvm.relay.Function, or tvm.ir.IRModule + The graph after recasting to the specified datatype. """ - # Collect valid relay ops that should be recast. - valid_ops = [relay.op.get(op) for op in ops] - - class RecastMutator(ExprMutator): - """Cast operations to the target type.""" - def __init__(self, valid_op_count): - self.layer_count = 0 - self.valid_op_count = valid_op_count - self.skip_layers = skip_layers - # Convert negative indices to positive ones. - for i, layer in enumerate(skip_layers): - if layer < 0: - skip_layers[i] = self.valid_op_count + layer - super().__init__() - - def set_attr_dtype(self, attrs): - new_attr_dict= {} - for attr in attrs.keys(): - attr_value = attrs[attr] - if isinstance(attr_value, tvm.ir.container.Array): - attr_value = tuple(attr_value) - new_attr_dict[str(attr)] = attr_value - new_attr_dict['out_dtype'] = out_dtype - attr_type = str(attrs).split('(')[0] - return tvm.ir.make_node(attr_type, **new_attr_dict) - - def visit_call(self, call): - if call.op in valid_ops: - layer_count = self.valid_op_count - self.layer_count - 1 - self.layer_count += 1 - print(layer_count) - print(call) - print("\n\n\n") - if layer_count in skip_layers: - return super().visit_call(call) - - # Otherwise recast its inputs. - new_fn = self.visit(call.op) - args = [self.visit(arg) for arg in call.args] - self.layer_count = 0 - new_args = list() - for arg in args: - new_args.append(relay.cast(arg, dtype=dtype)) - new_attrs = self.set_attr_dtype(call.attrs) - # Recast the output for compatibility with other graph operations. - return relay.cast(Call(new_fn, new_args, new_attrs), infer_type(args[0]).checked_type.dtype) - - return super().visit_call(call) - - layer_depth = count_layers.count_layers(func, ['nn.conv2d', 'nn.dense']) - print(layer_depth) - exit() - recast_pass = RecastMutator(count_pass.valid_op_count) - func = recast_pass.visit(func) - return tvm.IRModule.from_expr(func) \ No newline at end of file + return_mod = False + if isinstance(expr, tvm.ir.IRModule): + expr = expr['main'] + return_mod = True + layer_depth = count_layers.count_layers(expr, ops) + recast_pass = RecastMutator(dtype, out_dtype, ops, layer_depth, skip_layers) + expr = recast_pass.visit(expr) + if return_mod: + return tvm.IRModule.from_expr(expr) + return expr \ No newline at end of file diff --git a/tests/python/relay/test_layer_count.py b/tests/python/relay/test_layer_count.py index c143b855d515..8d603e5a30ca 100644 --- a/tests/python/relay/test_layer_count.py +++ b/tests/python/relay/test_layer_count.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from tvm.relay.testing import resnet -from tvm.relay.analysis.count_layers import count_layers +from tvm.relay.analysis import count_layers def test_layer_count(): def verify(num_layers): From f8d24938a5233ae159c89fda92860395f33a3b24 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 23 Oct 2020 22:36:44 +0000 Subject: [PATCH 5/9] Recast tests added. 
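Besides a single convolution and a two-convolution chain, the tests cover skip_layers. As a rough illustration of the skip semantics (separate from the test code itself): layer indices count matching ops from the input side of the graph, and negative indices count back from the last matching op:

    from tvm import relay
    from tvm.relay.transform import recast

    x = relay.var("x", shape=[8, 8, 8, 8])
    w1 = relay.var("w1", shape=[8, 8, 3, 3])
    w2 = relay.var("w2", shape=[8, 8, 3, 3])
    c1 = relay.nn.conv2d(x, w1, padding=(1, 1), out_dtype="float32")
    c2 = relay.nn.conv2d(c1, w2, padding=(1, 1), out_dtype="float32")
    func = relay.Function([x, w1, w2], c2)

    # Leaves the first convolution in float32 and recasts only the second;
    # skip_layers=[-1] would instead leave the last convolution untouched.
    out = recast(func, "int8", "int32", ops=["nn.conv2d"], skip_layers=[0])

This mirrors the expected graph in test_recast_skip.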
--- python/tvm/relay/transform/recast.py | 22 ++++-- tests/python/relay/test_recast.py | 104 +++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 5 deletions(-) create mode 100644 tests/python/relay/test_recast.py diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py index f50583d81dc7..0156cc91e60a 100644 --- a/python/tvm/relay/transform/recast.py +++ b/python/tvm/relay/transform/recast.py @@ -17,9 +17,11 @@ """Relay type recasting pass""" import tvm from tvm import relay -from tvm.relay import ExprMutator, Call -from tvm.relay.frontend.common import infer_type -from tvm.relay.analysis import count_layers +from tvm.ir import IRModule +from .transform import InferType +from ..function import Function +from ..analysis import count_layers +from ..expr_functor import ExprMutator, Call class RecastMutator(ExprMutator): """Cast operations to the target type.""" @@ -60,6 +62,7 @@ def visit_call(self, call): new_args.append(relay.cast(arg, dtype=self.dtype)) # If out_dtype is in the attributes, we need to update it. + orig_dtype = None if 'out_dtype' in call.attrs.keys(): new_attr_dict= {} for attr in call.attrs.keys(): @@ -70,10 +73,19 @@ def visit_call(self, call): new_attr_dict['out_dtype'] = self.out_dtype attr_type = str(call.attrs).split('(')[0] new_attrs = tvm.ir.make_node(attr_type, **new_attr_dict) + if call.attrs['out_dtype'] != "": + orig_dtype = call.attrs['out_dtype'] else: new_attrs = call.attrs + + if orig_dtype is None: + # Perform type inference to determine the original type. + new_mod = IRModule.from_expr(args[0]) + entry = new_mod['main'] + checked_arg = entry if isinstance(args[0], Function) else entry.body + orig_dtype = checked_arg.checked_type.dtype # Recast the output for compatibility with other graph operations. - return relay.cast(Call(new_fn, new_args, new_attrs), infer_type(args[0]).checked_type.dtype) + return relay.cast(Call(new_fn, new_args, new_attrs), orig_dtype) # Otherwise return the unchanged call. return Call(new_fn, args, call.attrs) @@ -114,7 +126,7 @@ def recast(expr, dtype, out_dtype, ops=['nn.conv2d'], skip_layers=[]): if isinstance(expr, tvm.ir.IRModule): expr = expr['main'] return_mod = True - layer_depth = count_layers.count_layers(expr, ops) + layer_depth = count_layers(expr, ops) recast_pass = RecastMutator(dtype, out_dtype, ops, layer_depth, skip_layers) expr = recast_pass.visit(expr) if return_mod: diff --git a/tests/python/relay/test_recast.py b/tests/python/relay/test_recast.py new file mode 100644 index 000000000000..1c6c2aecc53d --- /dev/null +++ b/tests/python/relay/test_recast.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +from tvm import relay +from tvm.relay.transform import recast + +def test_recast_simple(): + """Recast a single convolution operator.""" + def before(): + x = relay.var("x", shape=[8, 8, 8, 8]) + w = relay.var("w", shape=[8, 8, 3, 3]) + c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype='float32') + return relay.Function([x, w], c) + + def expected(): + x = relay.var("x", shape=[8, 8, 8, 8]) + w = relay.var("w", shape=[8, 8, 3, 3]) + x_int = relay.cast(x, 'int8') + w_int = relay.cast(w, 'int8') + c = relay.nn.conv2d(x_int, w_int, padding=(1, 1), out_dtype='int32') + c_float = relay.cast(c, 'float32') + return relay.Function([x, w], c_float) + + pre = before() + post = recast(pre, 'int8', 'int32') + expected = expected() + assert tvm.ir.structural_equal(expected, post) + + +def test_recast_medium(): + """Recast a slightly larger graph.""" + def before(): + x = relay.var("x", shape=[8, 8, 8, 8]) + w = relay.var("w", shape=[8, 8, 3, 3]) + c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype='float32') + w2 = relay.var("w2", shape=[8, 8, 3, 3]) + c2 = relay.nn.conv2d(c, w2, padding=(1, 1), out_dtype='float32') + return relay.Function([x, w, w2], c2) + + def expected(): + x = relay.var("x", shape=[8, 8, 8, 8]) + w = relay.var("w", shape=[8, 8, 3, 3]) + x_int = relay.cast(x, 'int8') + w_int = relay.cast(w, 'int8') + c = relay.nn.conv2d(x_int, w_int, padding=(1, 1), out_dtype='int32') + c_float = relay.cast(c, 'float32') + w2 = relay.var("w2", shape=[8, 8, 3, 3]) + w2_int = relay.cast(w2, 'int8') + c_float_int = relay.cast(c_float, 'int8') + c2 = relay.nn.conv2d(c_float_int, w2_int, padding=(1, 1), out_dtype='int32') + c2_float = relay.cast(c2, 'float32') + return relay.Function([x, w, w2], c2_float) + + pre = before() + post = recast(pre, 'int8', 'int32') + expected = expected() + assert tvm.ir.structural_equal(expected, post) + + +def test_recast_skip(): + """Recast a graph using skip layers.""" + def before(): + x = relay.var("x", shape=[8, 8, 8, 8]) + w = relay.var("w", shape=[8, 8, 3, 3]) + c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype='float32') + w2 = relay.var("w2", shape=[8, 8, 3, 3]) + c2 = relay.nn.conv2d(c, w2, padding=(1, 1), out_dtype='float32') + return relay.Function([x, w, w2], c2) + + def expected(): + x = relay.var("x", shape=[8, 8, 8, 8]) + w = relay.var("w", shape=[8, 8, 3, 3]) + c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype='float32') + w2 = relay.var("w2", shape=[8, 8, 3, 3]) + w2_int = relay.cast(w2, 'int8') + c_int = relay.cast(c, 'int8') + c2 = relay.nn.conv2d(c_int, w2_int, padding=(1, 1), out_dtype='int32') + c2_float = relay.cast(c2, 'float32') + return relay.Function([x, w, w2], c2_float) + + pre = before() + post = recast(pre, 'int8', 'int32', skip_layers=[0]) + expected = expected() + assert tvm.ir.structural_equal(expected, post) + + +if __name__ == "__main__": + test_recast_simple() + test_recast_medium() + test_recast_skip() From 079d8f49d8c462cb5a511bb2a160cb59eeef5bfa Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 23 Oct 2020 22:55:39 +0000 Subject: [PATCH 6/9] Formatting applied. 
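For context, the LayoutConfig scope and the HWNC ConvertLayout path reformatted here (added in the first patch) are meant to be used together, roughly as in the sketch below; it assumes an existing relay.IRModule `mod` whose conv2d ops use the NHWC/HWIO layouts:

    import tvm
    from tvm import relay

    # Skip the first conv2d and convert the rest to HWNC/HWOI for the
    # tensor core schedules.
    with relay.transform.LayoutConfig(skip_layers=[0]):
        seq = tvm.transform.Sequential(
            [
                relay.transform.InferType(),
                relay.transform.ConvertLayout({"nn.conv2d": ["HWNC", "default"]}),
            ]
        )
        mod = seq(mod)

Inside the scope, convert_conv2d consults LayoutConfig.current and leaves the skipped conv2d in its original layout while converting the others.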
--- python/tvm/relay/analysis/count_layers.py | 9 ++- python/tvm/relay/op/nn/_nn.py | 6 +- python/tvm/relay/op/strategy/cuda.py | 12 +-- python/tvm/relay/transform/recast.py | 28 +++---- python/tvm/relay/transform/transform.py | 7 +- tests/python/relay/test_layer_count.py | 8 +- .../relay/test_pass_convert_op_layout.py | 78 +++++++++++-------- tests/python/relay/test_recast.py | 54 +++++++------ 8 files changed, 113 insertions(+), 89 deletions(-) diff --git a/python/tvm/relay/analysis/count_layers.py b/python/tvm/relay/analysis/count_layers.py index b3f04a4f557b..9362c250dcc6 100644 --- a/python/tvm/relay/analysis/count_layers.py +++ b/python/tvm/relay/analysis/count_layers.py @@ -21,15 +21,16 @@ class LayerCounter(ExprVisitor): """A visitor pass that computes the deepest chain of specified ops in graph.""" + def __init__(self, valid_ops): self.depth_count = 0 self.deepest_count = 0 self.valid_ops = [relay.op.get(op) for op in valid_ops] super().__init__() - + def visit_call(self, call): if call.op in self.valid_ops: - self.depth_count += 1 + self.depth_count += 1 current_count = self.depth_count self.deepest_count = max(self.deepest_count, current_count) for arg in call.args: @@ -60,7 +61,7 @@ def count_layers(expr, valid_ops): The number of layers of the specified operations found in the graph. """ if isinstance(expr, tvm.ir.IRModule): - expr = expr['main'] + expr = expr["main"] count_pass = LayerCounter(valid_ops) count_pass.visit(expr) - return count_pass.count() \ No newline at end of file + return count_pass.count() diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index a47e0114a5bc..c9926647989e 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -172,7 +172,7 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts): if layout_config is not None: skip_layer = layout_config.check_skip() if skip_layer: - return relay.nn.conv2d(data, weight, **attrs) + return relay.nn.conv2d(data, weight, **attrs) # Prepare new layout. new_attrs = dict(attrs) @@ -203,8 +203,8 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts): else: new_attrs["kernel_layout"] = "HWIO" return relay.nn.conv2d(data, weight, **new_attrs) - elif desired_data_layout == 'HWNC': - new_attrs['kernel_layout'] = 'HWOI' + elif desired_data_layout == "HWNC": + new_attrs["kernel_layout"] = "HWOI" return relay.nn.conv2d(data, weight, **new_attrs) raise ValueError("Layout %s is not yet supported." 
% desired_data_layout) diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py index ccd895029bd3..ca44e49ce1dd 100644 --- a/python/tvm/relay/op/strategy/cuda.py +++ b/python/tvm/relay/op/strategy/cuda.py @@ -220,11 +220,13 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target): else: _, _, out_channels, _ = get_const_tuple(kernel.shape) - tensorcore_dtypes = ['int4', 'uint4', 'int8', 'uint8'] - if (N % 16 == 0 and in_channels % 16 == 0 and out_channels % 16 == 0) or \ - (N % 8 == 0 and in_channels % 16 == 0 and out_channels % 32 == 0) or \ - (N % 32 == 0 and in_channels % 16 == 0 and out_channels % 8 == 0) and \ - (data.dtype in tensorcore_dtypes and kernel.dtype in tensorcore_dtypes): + tensorcore_dtypes = ["int4", "uint4", "int8", "uint8"] + if ( + (N % 16 == 0 and in_channels % 16 == 0 and out_channels % 16 == 0) + or (N % 8 == 0 and in_channels % 16 == 0 and out_channels % 32 == 0) + or (N % 32 == 0 and in_channels % 16 == 0 and out_channels % 8 == 0) + and (data.dtype in tensorcore_dtypes and kernel.dtype in tensorcore_dtypes) + ): strategy.add_implementation( wrap_compute_conv2d(topi.cuda.conv2d_hwnc_tensorcore), wrap_topi_schedule(topi.cuda.schedule_conv2d_hwnc_tensorcore), diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py index 0156cc91e60a..9d72380a759f 100644 --- a/python/tvm/relay/transform/recast.py +++ b/python/tvm/relay/transform/recast.py @@ -23,8 +23,10 @@ from ..analysis import count_layers from ..expr_functor import ExprMutator, Call + class RecastMutator(ExprMutator): """Cast operations to the target type.""" + def __init__(self, dtype, out_dtype, valid_ops, valid_op_count, skip_layers=[]): self.dtype = dtype self.out_dtype = out_dtype @@ -63,35 +65,35 @@ def visit_call(self, call): # If out_dtype is in the attributes, we need to update it. orig_dtype = None - if 'out_dtype' in call.attrs.keys(): - new_attr_dict= {} + if "out_dtype" in call.attrs.keys(): + new_attr_dict = {} for attr in call.attrs.keys(): attr_value = call.attrs[attr] if isinstance(attr_value, tvm.ir.container.Array): attr_value = tuple(attr_value) new_attr_dict[str(attr)] = attr_value - new_attr_dict['out_dtype'] = self.out_dtype - attr_type = str(call.attrs).split('(')[0] + new_attr_dict["out_dtype"] = self.out_dtype + attr_type = str(call.attrs).split("(")[0] new_attrs = tvm.ir.make_node(attr_type, **new_attr_dict) - if call.attrs['out_dtype'] != "": - orig_dtype = call.attrs['out_dtype'] + if call.attrs["out_dtype"] != "": + orig_dtype = call.attrs["out_dtype"] else: new_attrs = call.attrs if orig_dtype is None: # Perform type inference to determine the original type. - new_mod = IRModule.from_expr(args[0]) - entry = new_mod['main'] - checked_arg = entry if isinstance(args[0], Function) else entry.body + new_mod = IRModule.from_expr(call) + new_mod = InferType()(new_mod) + checked_arg = new_mod["main"].body orig_dtype = checked_arg.checked_type.dtype # Recast the output for compatibility with other graph operations. return relay.cast(Call(new_fn, new_args, new_attrs), orig_dtype) - + # Otherwise return the unchanged call. return Call(new_fn, args, call.attrs) -def recast(expr, dtype, out_dtype, ops=['nn.conv2d'], skip_layers=[]): +def recast(expr, dtype, out_dtype, ops=["nn.conv2d"], skip_layers=[]): """Convert the types of operations in a graph to a new value. Note that this is primarily useful for testing performance of individual operations at the new datatype. 
In a real setting, this pass will @@ -124,11 +126,11 @@ def recast(expr, dtype, out_dtype, ops=['nn.conv2d'], skip_layers=[]): """ return_mod = False if isinstance(expr, tvm.ir.IRModule): - expr = expr['main'] + expr = expr["main"] return_mod = True layer_depth = count_layers(expr, ops) recast_pass = RecastMutator(dtype, out_dtype, ops, layer_depth, skip_layers) expr = recast_pass.visit(expr) if return_mod: return tvm.IRModule.from_expr(expr) - return expr \ No newline at end of file + return expr diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index cbe9df7aa689..b1561daa5ac4 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -388,9 +388,10 @@ def AlterOpLayout(): class LayoutConfig(object): """A structure for customizing the ConvertLayout pass.""" + current = None - def __init__(self, skip_layers = []): + def __init__(self, skip_layers=[]): self.skip_counter = 0 self.skip_layers = skip_layers @@ -398,7 +399,7 @@ def check_skip(self): skip = self.skip_counter in self.skip_layers self.skip_counter += 1 return skip - + def reset(self): self.skip_counter = 0 self.skip_layers = [] @@ -407,7 +408,7 @@ def __enter__(self): self._old_manager = LayoutConfig.current LayoutConfig.current = self return self - + def __exit__(self, ptype, value, trace): LayoutConfig.current = self._old_manager diff --git a/tests/python/relay/test_layer_count.py b/tests/python/relay/test_layer_count.py index 8d603e5a30ca..f680bb2725f2 100644 --- a/tests/python/relay/test_layer_count.py +++ b/tests/python/relay/test_layer_count.py @@ -17,16 +17,18 @@ from tvm.relay.testing import resnet from tvm.relay.analysis import count_layers + def test_layer_count(): def verify(num_layers): # Load a resnet with a known number of layers. mod, _ = resnet.get_workload(num_layers=num_layers) # Count the number of conv and dense layers. 
- count = count_layers(mod, valid_ops=['nn.conv2d', 'nn.dense']) + count = count_layers(mod, valid_ops=["nn.conv2d", "nn.dense"]) assert count == num_layers - + verify(18) verify(50) + if __name__ == "__main__": - test_layer_count() \ No newline at end of file + test_layer_count() diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 6c3ef00bd739..1fc5d39b9486 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1165,22 +1165,28 @@ def expected(): def test_convert_with_config(): def before(): x = relay.var("x", shape=(1, 56, 56, 64)) - weight = relay.var('weight', shape=(3, 3, 64, 64)) - y = relay.nn.conv2d(x, weight, - channels=64, - kernel_size=(3, 3), - padding=(1, 1), - data_layout='NHWC', - kernel_layout='HWIO') + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) y = relay.nn.relu(y) - weight2 = relay.var('weight2', shape=(3, 3, 64, 64)) - y2 = relay.nn.conv2d(y, weight2, - channels=64, - kernel_size=(3, 3), - padding=(1, 1), - data_layout='NHWC', - kernel_layout='HWIO') + weight2 = relay.var("weight2", shape=(3, 3, 64, 64)) + y2 = relay.nn.conv2d( + y, + weight2, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) y2 = relay.nn.relu(y2) out = relay.Function([x, weight, weight2], y2) @@ -1188,36 +1194,42 @@ def before(): def expected(): x = relay.var("x", shape=(1, 56, 56, 64)) - weight = relay.var('weight', shape=(3, 3, 64, 64)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) - weight2 = relay.var('weight2', shape=(3, 3, 64, 64)) - weight2 = relay.layout_transform(weight2, 'HWIO', 'HWOI') + weight2 = relay.var("weight2", shape=(3, 3, 64, 64)) + weight2 = relay.layout_transform(weight2, "HWIO", "HWOI") - y = relay.nn.conv2d(x, weight, - channels=64, - kernel_size=(3, 3), - padding=(1, 1), - data_layout='NHWC', - kernel_layout='HWIO') + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) y = relay.nn.relu(y) - y = relay.layout_transform(y, 'NHWC', 'HWNC') - - y2 = relay.nn.conv2d(y, weight2, - channels=64, - kernel_size=(3, 3), - padding=(1, 1), - data_layout='HWNC', - kernel_layout='HWOI') + y = relay.layout_transform(y, "NHWC", "HWNC") + + y2 = relay.nn.conv2d( + y, + weight2, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="HWNC", + kernel_layout="HWOI", + ) y2 = relay.nn.relu(y2) - y2 = relay.layout_transform(y2, 'HWNC', 'NHWC') + y2 = relay.layout_transform(y2, "HWNC", "NHWC") output = relay.Function(relay.analysis.free_vars(y2), y2) return output a = before() layout_config = relay.transform.LayoutConfig(skip_layers=[0]) with layout_config: - a = run_opt_pass(a, transform.ConvertLayout({'nn.conv2d': ['HWNC', 'default']})) + a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["HWNC", "default"]})) b = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) diff --git a/tests/python/relay/test_recast.py b/tests/python/relay/test_recast.py index 1c6c2aecc53d..8c5a562ddbba 100644 --- a/tests/python/relay/test_recast.py +++ b/tests/python/relay/test_recast.py @@ -18,82 +18,86 @@ from tvm import relay from tvm.relay.transform import recast + def test_recast_simple(): """Recast a single 
convolution operator.""" + def before(): x = relay.var("x", shape=[8, 8, 8, 8]) w = relay.var("w", shape=[8, 8, 3, 3]) - c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype='float32') + c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32") return relay.Function([x, w], c) def expected(): x = relay.var("x", shape=[8, 8, 8, 8]) w = relay.var("w", shape=[8, 8, 3, 3]) - x_int = relay.cast(x, 'int8') - w_int = relay.cast(w, 'int8') - c = relay.nn.conv2d(x_int, w_int, padding=(1, 1), out_dtype='int32') - c_float = relay.cast(c, 'float32') + x_int = relay.cast(x, "int8") + w_int = relay.cast(w, "int8") + c = relay.nn.conv2d(x_int, w_int, padding=(1, 1), out_dtype="int32") + c_float = relay.cast(c, "float32") return relay.Function([x, w], c_float) pre = before() - post = recast(pre, 'int8', 'int32') + post = recast(pre, "int8", "int32") expected = expected() assert tvm.ir.structural_equal(expected, post) def test_recast_medium(): """Recast a slightly larger graph.""" + def before(): x = relay.var("x", shape=[8, 8, 8, 8]) w = relay.var("w", shape=[8, 8, 3, 3]) - c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype='float32') + c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32") w2 = relay.var("w2", shape=[8, 8, 3, 3]) - c2 = relay.nn.conv2d(c, w2, padding=(1, 1), out_dtype='float32') + c2 = relay.nn.conv2d(c, w2, padding=(1, 1), out_dtype="float32") return relay.Function([x, w, w2], c2) def expected(): x = relay.var("x", shape=[8, 8, 8, 8]) w = relay.var("w", shape=[8, 8, 3, 3]) - x_int = relay.cast(x, 'int8') - w_int = relay.cast(w, 'int8') - c = relay.nn.conv2d(x_int, w_int, padding=(1, 1), out_dtype='int32') - c_float = relay.cast(c, 'float32') + x_int = relay.cast(x, "int8") + w_int = relay.cast(w, "int8") + c = relay.nn.conv2d(x_int, w_int, padding=(1, 1), out_dtype="int32") + c_float = relay.cast(c, "float32") w2 = relay.var("w2", shape=[8, 8, 3, 3]) - w2_int = relay.cast(w2, 'int8') - c_float_int = relay.cast(c_float, 'int8') - c2 = relay.nn.conv2d(c_float_int, w2_int, padding=(1, 1), out_dtype='int32') - c2_float = relay.cast(c2, 'float32') + w2_int = relay.cast(w2, "int8") + c_float_int = relay.cast(c_float, "int8") + c2 = relay.nn.conv2d(c_float_int, w2_int, padding=(1, 1), out_dtype="int32") + c2_float = relay.cast(c2, "float32") return relay.Function([x, w, w2], c2_float) pre = before() - post = recast(pre, 'int8', 'int32') + post = recast(pre, "int8", "int32") expected = expected() assert tvm.ir.structural_equal(expected, post) def test_recast_skip(): """Recast a graph using skip layers.""" + def before(): x = relay.var("x", shape=[8, 8, 8, 8]) w = relay.var("w", shape=[8, 8, 3, 3]) - c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype='float32') + c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32") w2 = relay.var("w2", shape=[8, 8, 3, 3]) - c2 = relay.nn.conv2d(c, w2, padding=(1, 1), out_dtype='float32') + c2 = relay.nn.conv2d(c, w2, padding=(1, 1), out_dtype="float32") return relay.Function([x, w, w2], c2) def expected(): x = relay.var("x", shape=[8, 8, 8, 8]) w = relay.var("w", shape=[8, 8, 3, 3]) - c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype='float32') + c = relay.nn.conv2d(x, w, padding=(1, 1), out_dtype="float32") w2 = relay.var("w2", shape=[8, 8, 3, 3]) - w2_int = relay.cast(w2, 'int8') - c_int = relay.cast(c, 'int8') - c2 = relay.nn.conv2d(c_int, w2_int, padding=(1, 1), out_dtype='int32') - c2_float = relay.cast(c2, 'float32') + w2_int = relay.cast(w2, "int8") + c_int = relay.cast(c, "int8") + c2 = relay.nn.conv2d(c_int, w2_int, padding=(1, 
1), out_dtype="int32") + c2_float = relay.cast(c2, "float32") return relay.Function([x, w, w2], c2_float) pre = before() - post = recast(pre, 'int8', 'int32', skip_layers=[0]) + post = recast(pre, "int8", "int32", skip_layers=[0]) expected = expected() assert tvm.ir.structural_equal(expected, post) From b8710b70f520ce70a0d4f9364a2f9f493fcfc664 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Fri, 23 Oct 2020 23:55:26 +0000 Subject: [PATCH 7/9] Style fixes. --- python/tvm/relay/analysis/count_layers.py | 1 + python/tvm/relay/transform/recast.py | 7 +++---- python/tvm/relay/transform/transform.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/relay/analysis/count_layers.py b/python/tvm/relay/analysis/count_layers.py index 9362c250dcc6..93d4f2766284 100644 --- a/python/tvm/relay/analysis/count_layers.py +++ b/python/tvm/relay/analysis/count_layers.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""Utilities that enable counting the number of layers in a graph.""" import tvm from tvm import relay from ..expr_functor import ExprVisitor diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py index 9d72380a759f..11150f15e795 100644 --- a/python/tvm/relay/transform/recast.py +++ b/python/tvm/relay/transform/recast.py @@ -19,7 +19,6 @@ from tvm import relay from tvm.ir import IRModule from .transform import InferType -from ..function import Function from ..analysis import count_layers from ..expr_functor import ExprMutator, Call @@ -27,13 +26,13 @@ class RecastMutator(ExprMutator): """Cast operations to the target type.""" - def __init__(self, dtype, out_dtype, valid_ops, valid_op_count, skip_layers=[]): + def __init__(self, dtype, out_dtype, valid_ops, valid_op_count, skip_layers=None): self.dtype = dtype self.out_dtype = out_dtype self.depth_count = 0 self.valid_ops = [relay.op.get(op) for op in valid_ops] self.valid_op_count = valid_op_count - self.skip_layers = skip_layers + self.skip_layers = skip_layers if skip_layers is not None else [] # Convert negative indices to positive ones. for i, layer in enumerate(skip_layers): if layer < 0: @@ -93,7 +92,7 @@ def visit_call(self, call): return Call(new_fn, args, call.attrs) -def recast(expr, dtype, out_dtype, ops=["nn.conv2d"], skip_layers=[]): +def recast(expr, dtype, out_dtype, ops=["nn.conv2d"], skip_layers=None): """Convert the types of operations in a graph to a new value. Note that this is primarily useful for testing performance of individual operations at the new datatype. In a real setting, this pass will diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index b1561daa5ac4..060547e4c4d7 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -391,9 +391,9 @@ class LayoutConfig(object): current = None - def __init__(self, skip_layers=[]): + def __init__(self, skip_layers=None): self.skip_counter = 0 - self.skip_layers = skip_layers + self.skip_layers = skip_layers if skip_layers is not None else [] def check_skip(self): skip = self.skip_counter in self.skip_layers From 4f8e47744197e17e89c62446bc271c2dd740e787 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Sat, 24 Oct 2020 00:10:50 +0000 Subject: [PATCH 8/9] Another style fix. 
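This round replaces the mutable default arguments on recast() with None defaults that are re-bound inside the function, which sidesteps the usual Python pitfall of a shared default list leaking state across calls. An illustrative snippet of the pitfall and the pattern now used (not the pass code itself):

    def buggy(skip_layers=[]):
        skip_layers.append(0)   # mutates the single shared default list
        return skip_layers

    buggy()  # [0]
    buggy()  # [0, 0]  <- state from the first call leaked in

    def fixed(skip_layers=None):
        skip_layers = [] if skip_layers is None else skip_layers
        skip_layers.append(0)
        return skip_layers

Since RecastMutator rewrites negative skip indices in the list it is given, handing each call a fresh list is the safer convention.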
--- python/tvm/relay/transform/recast.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/transform/recast.py b/python/tvm/relay/transform/recast.py index 11150f15e795..05a72676a907 100644 --- a/python/tvm/relay/transform/recast.py +++ b/python/tvm/relay/transform/recast.py @@ -26,13 +26,13 @@ class RecastMutator(ExprMutator): """Cast operations to the target type.""" - def __init__(self, dtype, out_dtype, valid_ops, valid_op_count, skip_layers=None): + def __init__(self, dtype, out_dtype, valid_ops, valid_op_count, skip_layers): self.dtype = dtype self.out_dtype = out_dtype self.depth_count = 0 self.valid_ops = [relay.op.get(op) for op in valid_ops] self.valid_op_count = valid_op_count - self.skip_layers = skip_layers if skip_layers is not None else [] + self.skip_layers = skip_layers # Convert negative indices to positive ones. for i, layer in enumerate(skip_layers): if layer < 0: @@ -92,7 +92,7 @@ def visit_call(self, call): return Call(new_fn, args, call.attrs) -def recast(expr, dtype, out_dtype, ops=["nn.conv2d"], skip_layers=None): +def recast(expr, dtype, out_dtype, ops=None, skip_layers=None): """Convert the types of operations in a graph to a new value. Note that this is primarily useful for testing performance of individual operations at the new datatype. In a real setting, this pass will @@ -127,6 +127,10 @@ def recast(expr, dtype, out_dtype, ops=["nn.conv2d"], skip_layers=None): if isinstance(expr, tvm.ir.IRModule): expr = expr["main"] return_mod = True + if ops is None: + ops = ["nn.conv2d"] + if skip_layers is None: + skip_layers = [] layer_depth = count_layers(expr, ops) recast_pass = RecastMutator(dtype, out_dtype, ops, layer_depth, skip_layers) expr = recast_pass.visit(expr) From 0055631aa86cfedde5a34710b9e2c17ef718bdf8 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Sun, 25 Oct 2020 11:35:06 -0700 Subject: [PATCH 9/9] Remove extra newline. --- src/runtime/graph/graph_runtime.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 968f7fe270d7..45a36900b586 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -269,7 +269,6 @@ void GraphRuntime::SetupStorage() { DLDataType t = vtype[i]; size_t bits = t.bits * t.lanes; ICHECK(bits % 8U == 0U || bits == 1U || bits == 4U); - size_t bytes = ((bits + 7U) / 8U) * size; uint32_t sid = static_cast(storage_id);