diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py
index e7d2cfc5b3c01..caa735b33abeb 100644
--- a/topi/python/topi/arm_cpu/conv2d.py
+++ b/topi/python/topi/arm_cpu/conv2d.py
@@ -700,7 +700,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
 
     new_attrs = {k: attrs[k] for k in attrs.keys()}
 
-    if F == tvm.relay.op:
+    if F.__name__ == 'tvm.relay.op':
         # Derive channels for frontends (e.g ONNX) that miss "channel" field.
         new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
 
diff --git a/topi/python/topi/cuda/conv2d_winograd.py b/topi/python/topi/cuda/conv2d_winograd.py
index 4020b0713acd7..3d46bcd81c705 100644
--- a/topi/python/topi/cuda/conv2d_winograd.py
+++ b/topi/python/topi/cuda/conv2d_winograd.py
@@ -371,7 +371,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
     copy_inputs = [s for s in inputs]
     new_attrs = {k: attrs[k] for k in attrs.keys()}
 
-    if F == tvm.relay.op:
+    if F.__name__ == 'tvm.relay.op':
         # Derive channels for frontends (e.g ONNX) that miss "channel" field.
         new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
 
diff --git a/topi/python/topi/intel_graphics/conv2d.py b/topi/python/topi/intel_graphics/conv2d.py
index 3ef7799fc2123..bab49a7524aef 100644
--- a/topi/python/topi/intel_graphics/conv2d.py
+++ b/topi/python/topi/intel_graphics/conv2d.py
@@ -54,7 +54,6 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None
 
 @conv2d_alter_layout.register(["intel_graphics"])
 def _alter_conv2d_layout(attrs, inputs, tinfos, F):
-    import nnvm.symbol as sym
 
     copy_inputs = [s for s in inputs]
 
@@ -75,11 +74,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, F):
     new_attrs = {k: attrs[k] for k in attrs.keys()}
     new_attrs["kernel_layout"] = 'OIHW%do' % (oc_bn)
 
-    if F == tvm.relay.op:
+    if F.__name__ == 'tvm.relay.op':
         # Derive channels for frontends (e.g ONNX) that miss "channel" field.
         new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
 
-    if F == sym:
+    if F.__name__ == 'nnvm.symbol':
         out = F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
     else:
         out = F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index 8884a23d322df..02f78f8007f9d 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -323,12 +323,11 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
 
 @conv2d_alter_layout.register("cpu")
 def _alter_conv2d_layout(attrs, inputs, tinfo, F):
-    import nnvm.symbol as sym
 
     copy_inputs = [s for s in inputs]
     new_attrs = {k : attrs[k] for k in attrs.keys()}
 
-    if F == tvm.relay.op:
+    if F.__name__ == 'tvm.relay.op':
         # Derive channels for frontends (e.g ONNX) that miss "channel" field.
         new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')]
 
@@ -336,13 +335,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
     batch_size, in_channel, height, width = get_const_tuple(data.shape)
 
     groups = attrs.get_int("groups")
-    out_channel = attrs.get_int("channels") if F == sym else new_attrs["channels"]
+    out_channel = attrs.get_int("channels") \
+        if F.__name__ == 'nnvm.symbol' else new_attrs["channels"]
     padding = attrs.get_int_tuple("padding")
     strides = attrs.get_int_tuple("strides")
     dilation = attrs.get_int_tuple("dilation")
     out_dtype = attrs["out_dtype"]
 
-    layout_name = 'layout' if F == sym else 'data_layout'
+    layout_name = 'layout' if F.__name__ == 'nnvm.symbol' else 'data_layout'
     layout = attrs[layout_name]
     kh, kw = attrs.get_int_tuple("kernel_size")
 
@@ -399,12 +399,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
         dispatch_ctx.update(target, new_workload, cfg)
 
     if is_depthwise:
-        if F == sym:
+        if F.__name__ == 'nnvm.symbol':
             logging.warning("Use native layout for depthwise convolution on NNVM.")
             return None
         return F.nn.contrib_depthwise_conv2d_nchwc(*copy_inputs, **new_attrs)
     else:
-        if F == sym:
+        if F.__name__ == 'nnvm.symbol':
             return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
         return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)