Skip to content

Commit

Permalink
Add conv2d transpose (#9)
Browse files Browse the repository at this point in the history
* pass mobilenet

* rename

* failed to tensorize

* can pass one

* pass stride = 1 for conv2d_transpose

* add conv2d transpose

* add end2end support for gan

* fix
  • Loading branch information
merrymercy authored and tmoreau89 committed Dec 1, 2018
1 parent 85e5cec commit daaf172
Show file tree
Hide file tree
Showing 13 changed files with 826 additions and 10 deletions.
1 change: 1 addition & 0 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ def lower(sch,
stmt = ir_pass.LowerStorageAccessInfo(stmt)
stmt = ir_pass.RemoveNoOp(stmt)
stmt = ir_pass.RewriteUnsafeSelect(stmt)
stmt = ir_pass.CanonicalSimplify(stmt)
for f in lower_phase3:
stmt = f(stmt)
# Instrument BoundCheckers
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,4 @@ def find_all(op):
for out in outputs:
find_all(out)

return lower(s, inputs, simple_mode=True)
return lower(s, inputs + [x.output(0) for x in outputs], simple_mode=True)
3 changes: 2 additions & 1 deletion vta/python/vta/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def add_debug(stmt):
debug_flag)

return tvm.make.stmt_seq(debug, stmt)
pass_list = [(1, ptr_alias.lower_ptr_alias),
pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
(1, ptr_alias.lower_ptr_alias),
(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy),
(1, ir_pass.annotate_alu_coproc_scope),
Expand Down
129 changes: 129 additions & 0 deletions vta/python/vta/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,135 @@ def _do_fold(stmt):
stmt_in, _do_fold, None, ["AttrStmt"])


def show_dir(x):
print(type(x), x)
for key in dir(x):
print(key, getattr(x, key))


def _get_gemm_intrin_buffer():
env = get_env()
wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
assert inp_lanes == env.BATCH * env.BLOCK_IN
inp_shape = (env.BATCH, env.BLOCK_IN)
assert inp_shape[0] * inp_shape[1] == inp_lanes
out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
assert out_lanes == env.BATCH * env.BLOCK_OUT
out_shape = (env.BATCH, env.BLOCK_OUT)
assert out_shape[0] * out_shape[1] == out_lanes
wgt = tvm.placeholder((wgt_shape[0], wgt_shape[1]),
dtype="int%d" % env.WGT_WIDTH,
name=env.wgt_scope)
inp = tvm.placeholder((inp_shape[0], inp_shape[1]),
dtype="int%d" % env.INP_WIDTH,
name=env.inp_scope)
k = tvm.reduce_axis((0, wgt_shape[1]), name="k")
out_dtype = "int%d" % env.ACC_WIDTH
out = tvm.compute((out_shape[0], out_shape[1]),
lambda i, j: tvm.sum(inp[i, k].astype(out_dtype) *
wgt[j, k].astype(out_dtype),
axis=[k]),
name="out")
wgt_layout = tvm.decl_buffer(
wgt.shape, wgt.dtype, env.wgt_scope,
scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
inp_layout = tvm.decl_buffer(
inp.shape, inp.dtype, env.inp_scope,
scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
out_layout = tvm.decl_buffer(
out.shape, out.dtype, env.acc_scope,
scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)

return wgt_layout, inp_layout, out_layout


def inject_conv2d_transpose_skip(stmt_in):
env = get_env()
dwgt, dinp, dout = _get_gemm_intrin_buffer()

calls = []
selects = []

def _find_basics(op):
if isinstance(op, tvm.expr.Call):
calls.append(op)
elif isinstance(op, tvm.expr.Select):
selects.append(op)

def _do_fold(op):
if _match_pragma(op, "conv2d_transpose_gemm"):
is_init = ".init" in str(op)
tvm.ir_pass.PostOrderVisit(op, _find_basics)

if is_init:
# create inner most block
irb = tvm.ir_builder.create()
dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
irb.emit(tvm.call_extern("int32", "VTAUopPush",
0, 1,
dout.access_ptr("rw", "int32"),
0, 0,
0, 0, 0))
inner = irb.get()
args = op.body.body.args
res_tensor = op.body.body.func.output(0)
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, 16)
inner = tvm.make.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
else:
conv_call, data_call, kernel_call = calls[-3:]
pad_data_tensor, kernel_tensor, res_tensor = (data_call.func.output(0),
kernel_call.func.output(0), conv_call.func.output(0))

if selects:
condition = selects[0].condition
else:
condition = tvm.const(1, 'int')

# create inner most block
irb = tvm.ir_builder.create()
with irb.if_scope(condition):
dev = env.dev
irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
irb.emit(tvm.call_extern("int32", "VTAUopPush",
0, 0,
dout.access_ptr("rw", "int32"),
dinp.access_ptr("r", "int32"),
dwgt.access_ptr("r", "int32"),
0, 0, 0))
inner = irb.get()

args = conv_call.args
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, 16)
inner = tvm.make.AttrStmt(
[dout, res_tensor], 'buffer_bind_scope',
tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = kernel_call.args
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 16, 0, 16)
inner = tvm.make.AttrStmt(
[dwgt, kernel_tensor], 'buffer_bind_scope',
tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
args = data_call.args
tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, 16)
inner = tvm.make.AttrStmt(
[dinp, pad_data_tensor], 'buffer_bind_scope',
tvm.call_intrin('handle', 'tvm_tuple', *tpl), inner)
return inner
return None
ret = tvm.ir_pass.IRTransform(
stmt_in, _do_fold, None, ["AttrStmt"])
return ret


def inject_coproc_sync(stmt_in):
"""Pass to inject skip copy stmt, used in debug.
Expand Down
2 changes: 2 additions & 0 deletions vta/python/vta/top/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@
from . import arm_conv2d

from .bitpack import bitpack
from .vta_dense import packed_dense, schedule_packed_dense
from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d
from .vta_conv2d_transpose import packed_conv2d_transpose, schedule_packed_conv2d_transpose
1 change: 0 additions & 1 deletion vta/python/vta/top/arm_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from topi.nn import conv2d, conv2d_alter_layout
from topi import generic


@conv2d.register(["vtacpu", "vta"])
def compute(*args, **kwargs):
with tvm.target.arm_cpu("vtacpu"):
Expand Down
39 changes: 38 additions & 1 deletion vta/python/vta/top/vta_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..environment import get_env
from ..ptr_alias import reinterpret
from .vta_group_conv2d import packed_group_conv2d, schedule_packed_group_conv2d

from .vta_conv2d_transpose import packed_conv2d_transpose, schedule_packed_conv2d_transpose

Workload = namedtuple("Conv2DWorkload",
['batch', 'height', 'width', 'in_filter', 'out_filter',
Expand Down Expand Up @@ -156,6 +156,7 @@ def _get_data_movement_byte(schedule, layer):
return [fil_sched[xfer_size.index(min(xfer_size))]]
return fil_sched


def packed_conv2d(data,
kernel,
padding,
Expand Down Expand Up @@ -309,6 +310,42 @@ def schedule_conv2d(attrs, outs, target):
return _nn.schedule_conv2d(attrs, outs, target)


@reg.register_compute("conv2d_transpose", level=15)
def compute_conv2d_transpose(attrs, inputs, out):
""" 2D convolution algorithm.
"""
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
layout = attrs["layout"]
out_dtype = attrs['out_dtype']

print(inputs)

assert dilation == (1, 1), "not support dilate now"
if is_packed_layout(layout):
return packed_conv2d_transpose(inputs[0], inputs[1],
padding, strides,
out_dtype=out_dtype)
return _nn.compute_conv2d_transpose(attrs, inputs, out)


@reg.register_schedule("conv2d_transpose", level=15)
def schedule_conv2d_transpose(attrs, outs, target):
""" 2D convolution schedule.
"""
layout = attrs["layout"]

if is_packed_layout(layout):
target = tvm.target.create(target)
if target.device_name == "vta":
return schedule_packed_conv2d_transpose(outs)
elif str(target).startswith("llvm"):
return tvm.create_schedule([x.op for x in outs])
else:
raise RuntimeError("not support target %s" % target)
return _nn.schedule_conv2d_transpose(attrs, outs, target)

def _get_workload(data, pad_data, kernel, output):
""" Get the workload structure.
"""
Expand Down
Loading

0 comments on commit daaf172

Please sign in to comment.