Skip to content

Commit

Permalink
Fix the auto-tuner by registering the correct schedules
Browse files Browse the repository at this point in the history
Change-Id: Id9273688b2620e1ea849ab01b4c46af8fbf37fd0
  • Loading branch information
Giuseppe Rossini committed Jun 11, 2020
1 parent 7195ba2 commit b6dc7c5
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/arm_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
name="compute_conv2d_NHWC_quantized.arm_cpu")
name="conv2d_NHWC_quantized.arm_cpu")

strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
Expand Down
2 changes: 1 addition & 1 deletion topi/python/topi/arm_cpu/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
new_attrs['out_layout'], out_dtype], topi_tmpl)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
if topi_tmpl == "compute_conv2d_NHWC_quantized.arm_cpu":
if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu":
assert (data.dtype == 'int8' and kernel.dtype == 'int8' or
data.dtype == 'uint8' and kernel.dtype == 'uint8')
CO, IC, KH, KW = get_const_tuple(kernel.shape)
Expand Down
4 changes: 2 additions & 2 deletions topi/python/topi/arm_cpu/conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def traverse(op):
return s


@autotvm.register_topi_compute("compute_conv2d_NHWC_quantized.arm_cpu")
@autotvm.register_topi_compute("conv2d_NHWC_quantized.arm_cpu")
def compute_conv2d_NHWC_quantized(cfg, data, kernel, strides, padding, dilation, out_dtype):
N, IH, IW, IC = get_const_tuple(data.shape)
KH, KW, _, OC = get_const_tuple(kernel.shape)
Expand All @@ -122,7 +122,7 @@ def compute_conv2d_NHWC_quantized(cfg, data, kernel, strides, padding, dilation,
dilation, out_dtype, (KH, KW), OC)


@autotvm.register_topi_compute("compute_conv2d_NHWC_quantized_without_transform.arm_cpu")
@autotvm.register_topi_compute("conv2d_NHWC_quantized_without_transform.arm_cpu")
def compute_conv2d_NHWC_quantized_without_transform(cfg, data, B, strides, padding,
dilation, out_dtype, kernel_size=None,
output_channels=None):
Expand Down

0 comments on commit b6dc7c5

Please sign in to comment.