Skip to content

Commit

Permalink
[AlterOpLayout][x86] NHWC to NCHWc conv support.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Oct 8, 2019
1 parent 76c2392 commit 4f15316
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions topi/python/topi/x86/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,21 +158,19 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
return None
return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs)

if data_layout == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)

if F.__name__ == 'nnvm.symbol':
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
return None
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)

if F.__name__ == 'nnvm.symbol':
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)


@conv2d_legalize.register("cpu")
Expand Down

0 comments on commit 4f15316

Please sign in to comment.