diff --git a/python/tvm/autotvm/tophub.py b/python/tvm/autotvm/tophub.py index 92d1be2c8f20e..aa49beaecf53a 100644 --- a/python/tvm/autotvm/tophub.py +++ b/python/tvm/autotvm/tophub.py @@ -36,15 +36,16 @@ # the version of each package PACKAGE_VERSION = { - 'arm_cpu': "v0.04", - 'llvm': "v0.03", + 'arm_cpu': "v0.04", + 'llvm': "v0.03", - 'cuda': "v0.05", - 'rocm': "v0.03", - 'opencl': "v0.03", - 'mali': "v0.05", + 'cuda': "v0.05", + 'rocm': "v0.03", + 'opencl': "v0.03", + 'mali': "v0.05", + 'intel_graphics': "v0.01", - 'vta': "v0.06", + 'vta': "v0.06", } logger = logging.getLogger('autotvm') diff --git a/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py b/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py index b7d21912e44ea..ed8b9cd9ed977 100644 --- a/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py +++ b/tests/python/frontend/nnvm_to_relay/test_alter_conv2d.py @@ -69,6 +69,7 @@ def convnet(): targets=['cuda', 'opencl -device=mali', 'opencl -device=intel_graphics', + 'llvm -device=arm_cpu', 'llvm -device=core-avx-ii'] @@ -83,5 +84,6 @@ def convnet(): assert not relay.analysis.alpha_equal(N, O) if __name__ == "__main__": + import numpy as np np.random.seed(42) test_alter_layout_conv2d() diff --git a/topi/python/topi/intel_graphics/conv2d.py b/topi/python/topi/intel_graphics/conv2d.py index 8fcde2513705e..63cae4c7da8db 100644 --- a/topi/python/topi/intel_graphics/conv2d.py +++ b/topi/python/topi/intel_graphics/conv2d.py @@ -21,15 +21,116 @@ import tvm +from tvm import autotvm +from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity +from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, conv2d_infer_layout +from ..nn.util import get_pad_tuple +from ..nn.depthwise_conv2d import depthwise_conv2d_nchw +from ..nn import pad +from .. import tag from .. import generic from .. import util -from .. import tag -from ..nn import pad -from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_alter_layout, _get_workload -from ..nn.util import get_pad_tuple -from ..util import simplify +from ..util import simplify, get_const_tuple +def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False): + if is_depthwise: + raise RuntimeError("Depthwise not supported for intel graphics.") + else: + batch_size, in_channel, height, width = get_const_tuple(data.shape) + out_channel, _, hkernel, _ = get_const_tuple(kernel.shape) + HSTR, _ = strides + + ic_bn = 1 + oc_bn, oc_bn_upper = 16, 16 + for i in range(oc_bn_upper, 0, -1): + if out_channel % i == 0: + oc_bn = i + break + + if HSTR == 2: + if out_channel + hkernel == 515: + block_oh = 4 + block_ow = 4 + else: + block_oh = 4 + block_ow = 5 + elif hkernel == 3: + if out_channel == 512: + block_oh = 2 + block_ow = 7 + else: + block_oh = 2 + block_ow = 14 + else: + block_oh = 1 + block_ow = 16 + cfg["tile_ic"] = SplitEntity([in_channel // ic_bn, ic_bn]) + cfg["tile_oc"] = SplitEntity([out_channel // oc_bn, oc_bn]) + cfg["block_oh"] = OtherOptionEntity(block_oh) + cfg["block_ow"] = OtherOptionEntity(block_ow) + + +def _create_schedule_template(cfg, data, kernel, strides, padding, dilation, layout): + """Create schedule configuration from input arguments""" + dshape = get_const_tuple(data.shape) + kshape = get_const_tuple(kernel.shape) + if layout == 'NCHW': + n, ic, h, w = dshape + oc, _, kh, kw = kshape + else: + raise ValueError("Not support this layout {} with " + "schedule template.".format(layout)) + ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) + sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) + oh = (h - kh + 2 * ph) // sh + 1 + ow = (w - kw + 2 * pw) // sw + 1 + ic_bn_upper = 32 + oc_bn_upper = 64 + oc_bn_lower = min(oc, 8) + ic_bn_candidates, oc_bn_candidates = [], [] + for i in range(1, ic + 1): + if ic % i == 0 and i <= ic_bn_upper: + ic_bn_candidates.append(i) + if not ic_bn_candidates: + ic_bn_candidates.append(1) + ic_bn_candidates.append(ic) + + for i in range(1, oc + 1): + if oc % i == 0 and oc_bn_lower <= i <= oc_bn_upper: + oc_bn_candidates.append(i) + if not oc_bn_candidates: + oc_bn_candidates.append(1) + oc_bn_candidates.append(oc) + + blk_candidates_low_limits = 5 + blk_oh_list, blk_ow_list = [], [] + for i, j in zip(range(oh, 0, -1), range(ow, 0, -1)): + if i <= 16 and oh % i == 0: + blk_oh_list.append(i) + if j <= 16 and ow % j == 0: + blk_ow_list.append(j) + + if len(blk_oh_list) < blk_candidates_low_limits: + for i in range(2, oh): + if i not in blk_oh_list: + blk_oh_list.append(i) + if len(blk_oh_list) >= 5: + break + + if len(blk_ow_list) < blk_candidates_low_limits: + for i in range(min(ow - 1, 16), 1, -1): + if i not in blk_ow_list: + blk_ow_list.append(i) + if len(blk_ow_list) >= 5: + break + + # Create schedule config + cfg.define_knob("tile_ic", ic_bn_candidates) + cfg.define_knob("tile_oc", oc_bn_candidates) + cfg.define_knob("block_oh", blk_oh_list) + cfg.define_knob("block_ow", blk_ow_list) + ##### SCHEDULE UTILITIES ##### def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None): @@ -53,40 +154,82 @@ def tile_and_bind3d(s, tensor, z, y, x, z_factor=2, y_factor=None, x_factor=None return xi, thread_z, thread_y, thread_x @conv2d_alter_layout.register(["intel_graphics"]) -def _alter_conv2d_layout(attrs, inputs, tinfos, F): +def _alter_conv2d_layout(attrs, inputs, tinfo, F): + import nnvm.symbol as sym copy_inputs = [s for s in inputs] - - data = tinfos[0] - kernel = tinfos[1] - - import ast - padding = ast.literal_eval(str(attrs['padding'])) - stride = ast.literal_eval(str(attrs['strides'])) - - wkl = _get_workload(data, kernel, stride, padding, data.dtype) - oc_bn = 1 - kernel_shape = util.get_const_tuple(kernel.shape) - for oc_bn in range(16, 1, -1): - if kernel_shape[0] % oc_bn == 0: - break - - new_attrs = {k: attrs[k] for k in attrs.keys()} - new_attrs["kernel_layout"] = 'OIHW%do' % (oc_bn) + new_attrs = {k : attrs[k] for k in attrs.keys()} if F.__name__ == 'tvm.relay.op': # Derive channels for frontends (e.g ONNX) that miss "channel" field. new_attrs["channels"] = inputs[1].checked_type.shape[attrs['kernel_layout'].index('O')] - if F.__name__ == 'nnvm.symbol': - out = F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) - else: - out = F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) - - return out - -@conv2d_NCHWc.register(["intel_graphics"]) -def _decl_conv2d(data, kernel, stride, padding, dilation, layout, out_layout, out_dtype='float32'): + data, kernel = tinfo[0], tinfo[1] + batch_size, in_channel, height, width = get_const_tuple(data.shape) + + groups = attrs.get_int("groups") + out_channel = attrs.get_int("channels") + padding = attrs.get_int_tuple("padding") + strides = attrs.get_int_tuple("strides") + dilation = attrs.get_int_tuple("dilation") + out_dtype = attrs["out_dtype"] + + layout_name = 'layout' if F == sym else 'data_layout' + layout = attrs[layout_name] + kh, kw = attrs.get_int_tuple("kernel_size") + + dtype = data.dtype + out_dtype = dtype if out_dtype in ("same", "") else out_dtype + is_depthwise = groups == in_channel and groups == out_channel + + # only optimize for NCHW + if layout != 'NCHW': + return None + if groups != 1 and not is_depthwise: + return None + + dispatch_ctx = autotvm.task.DispatchContext.current + target = tvm.target.current_target() + + # query schedule and fallback if necessary + workload = autotvm.task.args_to_workload( + [data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) \ + if is_depthwise else \ + autotvm.task.args_to_workload( + [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d) + if is_depthwise: + return None + cfg = dispatch_ctx.query(target, workload) + if cfg.is_fallback: + _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise) + + ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + + new_attrs[layout_name] = 'NCHW%dc' % ic_bn + new_attrs['out_layout'] = 'NCHW%dc' % oc_bn + + new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn), + dtype=data.dtype) + + out_channel, _, kh, kw = get_const_tuple(kernel.shape) + # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) + new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) + + # Store altered operator's config + new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn), + dtype=kernel.dtype) + new_workload = autotvm.task.args_to_workload( + [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], + new_attrs['out_layout'], out_dtype], conv2d_NCHWc) + + dispatch_ctx.update(target, new_workload, cfg) + if F == sym: + return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) + return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs) + +@autotvm.register_topi_compute(conv2d_NCHWc, 'intel_graphics', 'direct') +def _decl_conv2d(cfg, data, kernel, strides, padding, dilation, + layout, out_layout, out_dtype='float32'): """Conv2D operator for Intel Graphics backend. Parameters @@ -111,21 +254,39 @@ def _decl_conv2d(data, kernel, stride, padding, dilation, layout, out_layout, ou output : tvm.Tensor 4-D with shape [batch, out_channel, out_height, out_width] """ - assert data.shape[0].value == 1, "only support batch size=1 convolution on intel gpu" - assert data.dtype == kernel.dtype, "Do not support inputs with different data types now." - - out_dtype = data.dtype - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) - kernel_shape = util.get_const_tuple(kernel.shape) - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - return _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype) - -@generic.schedule_conv2d_NCHWc.register(["intel_graphics"]) -def schedule_conv2d_NCHWc(outs): + dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + assert (dh, dw) == (1, 1), "Does not support dilation" + + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) + in_channel = ic_chunk * ic_bn + num_filter = oc_chunk * oc_bn + if cfg.is_fallback: + _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), + tvm.placeholder((num_filter, in_channel, kernel_height, kernel_width), + dtype=kernel.dtype), + strides, padding, out_dtype) + + return _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, out_dtype) + + +@conv2d_infer_layout.register("intel_graphics") +def _conv2d_infer_layout(workload, cfg): + _, data, kernel, strides, padding, dilation, layout, dtype = workload + batch_size, in_channel, in_height, in_width = data[:-1] + out_channel, _, k_height, k_width = kernel[:-1] + out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1 + out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1 + tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + in_shape = (batch_size, in_channel // tile_ic, in_height, in_width, tile_ic) + in_layout = "NCHW%dc" % tile_ic + out_shape = (batch_size, out_channel // tile_oc, out_height, out_width, tile_oc) + out_layout = "NCHW%dc" % tile_oc + return ((in_shape, in_layout),), ((out_shape, out_layout),) + + +@autotvm.register_topi_schedule(generic.schedule_conv2d_NCHWc, 'intel_graphics', ['direct']) +def _schedule_conv2d_NCHWc(cfg, outs): """Schedule for conv2d_nchw for Intel Graphics Parameters @@ -145,14 +306,14 @@ def schedule_conv2d_NCHWc(outs): def traverse(op): """inline all one-to-one-mapping operators except the last stage (output)""" - if tag.is_broadcast(op.tag): + if tag.is_injective(op.tag): if op not in s.outputs: s[op].compute_inline() for tensor in op.input_tensors: - if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: traverse(tensor.op) - if 'conv2d' in op.tag: - _schedule_cl_spatialpack_NCHWc(s, op) + if "conv" in op.tag: + _schedule_cl_spatialpack_NCHWc(cfg, s, op) scheduled_ops.append(op) @@ -160,97 +321,101 @@ def traverse(op): return s -def _decl_cl_spatialpack_NCHWc(data, kernel, stride, padding, out_dtype='float16'): - batch, in_channel, in_height, in_width = [util.get_const_int(x) for x in data.shape] - num_filter, channel, kernel_h, kernel_w, nv = [util.get_const_int(x) for x in kernel.shape] - num_filter = num_filter * nv +def _decl_cl_spatialpack_NCHWc(cfg, data, kernel, strides, padding, dilation, out_dtype='float16'): + batch, in_channel, in_height, in_width, vc = [util.get_const_int(x) for x in data.shape] + in_channel *= vc + num_filter, channel, kernel_h, kernel_w, ci, co = [util.get_const_int(x) for x in kernel.shape] + num_filter *= co pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, kernel) - if isinstance(stride, (tuple, list)): - stride_h, stride_w = stride + ic_bn = vc + assert vc == ci + + if isinstance(strides, (tuple, list)): + stride_h, stride_w = strides else: - stride_h, stride_w = stride, stride + stride_h, stride_w = strides, strides out_channel = num_filter out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1) out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1) - oshape = (batch, out_channel, out_height, out_width) + oshape = (batch, out_channel // co, out_height, out_width, co) rc = tvm.reduce_axis((0, in_channel), name='rc') ry = tvm.reduce_axis((0, kernel_h), name='ry') rx = tvm.reduce_axis((0, kernel_w), name='rx') - block_w = 1 - block_h = 1 - if stride_h == 2: - if num_filter + kernel_h == 515: - block_h = 4 - block_w = 4 - else: - block_h = 4 - block_w = 5 - elif kernel_h == 3: - if num_filter == 512: - block_h = 2 - block_w = 7 - else: - block_h = 2 - block_w = 14 - elif kernel_h == 7 and padding == 3 and stride == 1: - block_h = 3 - block_w = 4 - else: - block_h = 1 - block_w = 16 + block_h = cfg["block_oh"].val + block_w = cfg["block_ow"].val - attrs = {'block_h': block_h, 'block_w' : block_w} c_h = out_height c_w = out_width - if not out_height % block_h == 0: + if out_height % block_h != 0: c_h = (out_height // block_h + 1) * block_h - if not out_width % block_w == 0: + if out_width % block_w != 0: c_w = (out_width // block_w + 1) * block_w - pad_before = [0, 0, pad_top, pad_left] - pad_after = [0, 0, pad_down + c_h - block_h, pad_right + c_w - block_w] - temp = pad(data, pad_before, pad_after, name="pad_temp") + cshape = (batch, out_channel // co, c_h, c_w, co) - cshape = (batch, out_channel // nv, c_h, c_w, nv) + pad_before = [0, 0, pad_top, pad_left, 0] + pad_after = [0, 0, pad_down + c_h - out_height, pad_right + \ + c_w - out_width, 0] + DOPAD = (pad_top != 0 or pad_left != 0 or pad_down + c_h - out_height != 0 \ + or pad_right + c_w - out_width != 0) + DOUNPACK = (c_h - out_height != 0 or c_w - out_width != 0) + if DOPAD: + temp = pad(data, pad_before, pad_after, name="pad_temp") + else: + temp = data conv = tvm.compute( cshape, - lambda nn, ff, yy, xx, vc:\ - tvm.sum( - temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) * - kernel[ff, rc, ry, rx, vc].astype(out_dtype), - axis=[rc, ry, rx]), name='conv', attrs=attrs) + lambda nn, ff, yy, xx, ff_v: \ + tvm.sum( + temp[nn, rc//ic_bn, yy * stride_h + ry, xx * stride_w + rx, rc%ic_bn]. \ + astype(out_dtype) * + kernel[ff, rc//ic_bn, ry, rx, rc%ic_bn, ff_v].astype(out_dtype), + axis=[rc, ry, rx]), tag="conv", name='conv') + + if DOUNPACK: + output = tvm.compute( + oshape, + lambda nn, ff, yy, xx, ff_v: + conv[nn][ff][yy][xx][ff_v], + name='output_unpack', tag="conv_unpack") + else: + output = conv - output = tvm.compute( - oshape, - lambda nn, ff, yy, xx: - conv[nn][ff//nv][yy][xx][ff%nv], - name='output_unpack', tag='conv2d') return output -def _schedule_cl_spatialpack_NCHWc(s, op): - output = op.output(0) - _, _, out_height, out_width = [util.get_const_int(x) for x in output.shape] +def _schedule_cl_spatialpack_NCHWc(cfg, s, op): + output = op.output(0) conv = op.input_tensors[0] - temp = s[conv].op.input_tensors[0] - kernel = s[conv].op.input_tensors[1] - temp_W = s.cache_read(temp, "warp", [conv]) - conv_L = s.cache_write(conv, "local") - + if conv.op.name == "conv": + temp = s[conv].op.input_tensors[0] + kernel = s[conv].op.input_tensors[1] + temp_W = s.cache_read(temp, "warp", [conv]) + conv_L = s.cache_write(conv, "local") + SCHEDULE_OUTPUT = True + else: + temp = op.input_tensors[0] + kernel = op.input_tensors[1] + temp_W = s.cache_read(temp, "warp", [output]) + conv_L = s.cache_write(output, "local") + if output.op in s.outputs: + conv = output + else: + s[output].compute_inline() + conv = s.outputs[0] + SCHEDULE_OUTPUT = False kernel_L = s.cache_read(kernel, "local", [conv_L]) - _, in_channel, temp_h, temp_w = [util.get_const_int(x) for x in temp.shape] - attrs = s[conv].op.attrs - OUTPUT_BLOCK_HEIGHT = attrs['block_h'] - OUTPUT_BLOCK_WIDTH = attrs['block_w'] + OUTPUT_BLOCK_HEIGHT = cfg["block_oh"].val + OUTPUT_BLOCK_WIDTH = cfg["block_ow"].val # schedule conv z_factor = 1 @@ -286,12 +451,13 @@ def _schedule_cl_spatialpack_NCHWc(s, op): s[conv_L].unroll(rx) # schedule temp - _, ci, h, w = s[temp].op.axis - tile_and_bind3d(s, temp, ci, h, w, 1, 16, 16) + if temp.op.name == "pad_temp": + _, ci, h, w, vci = s[temp].op.axis + tile_and_bind3d(s, temp, ci, h, w, 1, 16, 16) # schedule temp_W - _, ci, h, w = s[temp_W].op.axis - zo, zi = s[temp_W].split(ci, 1) + _, ci, h, w, vci = s[temp_W].op.axis + zo, zi = s[temp_W].split(vci, 1) yo, yi = s[temp_W].split(h, 1) xo, xi = s[temp_W].split(w, 16) s[temp_W].reorder(zo, yo, xo, zi, yi, xi) @@ -300,27 +466,37 @@ def _schedule_cl_spatialpack_NCHWc(s, op): s[temp_W].bind(xi, thread_x) s[temp_W].storage_align(s[temp_W].op.axis[2], 16, 0) - #schedule kernel - # schedule kernel_L - if "2_14" in s[conv].op.tag: + if OUTPUT_BLOCK_HEIGHT == 2 and OUTPUT_BLOCK_WIDTH == 14: s[kernel_L].compute_at(s[conv_L], ry) else: s[kernel_L].compute_at(s[conv_L], rx) # schedule output - if output.op in s.outputs: - out = output - else: - s[output].compute_inline() - out = s.outputs[0] - - _, co, h, w = s[out].op.axis - tile_and_bind3d(s, out, w, h, co, 4, 8, 8) - - -@conv2d.register(["intel_graphics"]) -def decl_conv2d(data, kernel, stride, padding, dilation, layout='NCHW', out_dtype='float32'): + if SCHEDULE_OUTPUT: + if output.op in s.outputs: + out = output + else: + s[output].compute_inline() + out = s.outputs[0] + + _, co, h, w, vc = s[out].op.axis + tile_and_bind3d(s, out, w, h, vc, 4, 8, 8) + + +def conv_arg_to_workload(data, kernel, strides, padding, layout, out_dtype): + """convert argument to workload""" + if len(kernel.shape) == 4: + raw_kernel = kernel + else: # the input kernel is transformed by alter_op_layout + shape = get_const_tuple(kernel.shape) + raw_kernel = tvm.placeholder((shape[0] * shape[4], shape[1], shape[2], shape[3]), + dtype=kernel.dtype) + return ('conv2d', ) + autotvm.task.args_to_workload( + [data, raw_kernel, strides, padding, layout, out_dtype]) + +@autotvm.register_topi_compute(conv2d, 'intel_graphics', 'direct') +def decl_conv2d(cfg, data, kernel, stride, padding, dilation, layout='NCHW', out_dtype='float32'): """Conv2D operator for Intel Graphics backend. Parameters @@ -344,18 +520,10 @@ def decl_conv2d(data, kernel, stride, padding, dilation, layout='NCHW', out_dtyp assert data.shape[0].value == 1, "only support batch size=1 convolution on intel gpu" assert data.dtype == kernel.dtype, "Do not support inputs with different data types now." - out_dtype = data.dtype - HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) - kernel_shape = util.get_const_tuple(kernel.shape) - if isinstance(stride, (tuple, list)): - HSTR, WSTR = stride - else: - HSTR, WSTR = stride, stride - - return _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype) + return _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype) -@generic.schedule_conv2d_nchw.register(["intel_graphics"]) -def schedule_conv2d_nchw(outs): +@autotvm.task.register_topi_schedule(generic.schedule_conv2d_nchw, 'intel_graphics', ['direct']) +def schedule_conv2d_nchw(cfg, outs): """Schedule for conv2d_nchw for Intel Graphics Parameters @@ -378,17 +546,17 @@ def traverse(op): if op not in s.outputs: s[op].compute_inline() for tensor in op.input_tensors: - if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: traverse(tensor.op) if 'conv2d' in op.tag: - _schedule_cl_spatialpack(s, op) + _schedule_cl_spatialpack(cfg, s, op) scheduled_ops.append(op) traverse(outs[0].op) return s -def _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype='float16'): +def _decl_cl_spatialpack(cfg, data, kernel, stride, padding, layout, out_dtype='float16'): batch, in_channel, in_height, in_width = [util.get_const_int(x) for x in data.shape] num_filter, channel, kernel_h, kernel_w = [util.get_const_int(x) for x in kernel.shape] pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, kernel) @@ -429,23 +597,22 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype='float else: block_h = 1 block_w = 16 - attrs = {'block_h': block_h, 'block_w' : block_w} c_h = out_height c_w = out_width - if not out_width % block_w == 0: - c_w = (out_width // block_w + 1) * block_w - - if not out_height % block_h == 0: + if out_height % block_h != 0: c_h = (out_height // block_h + 1) * block_h + if out_width % block_w != 0: + c_w = (out_width // block_w + 1) * block_w + pad_before = [0, 0, pad_top, pad_left] pad_after = [0, 0, pad_down + c_h - block_h, pad_right + c_w - block_w] temp = pad(data, pad_before, pad_after, name="pad_temp") nv = 16 - if not num_filter % nv == 0: + if num_filter % nv != 0: num_filter = (num_filter // nv + 1) * nv out_channel = num_filter @@ -459,21 +626,23 @@ def _decl_cl_spatialpack(data, kernel, stride, padding, layout, out_dtype='float conv = tvm.compute( cshape, - lambda nn, ff, yy, xx, vc:\ - tvm.sum( - temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) * - kernel_vec[ff, rc, ry, rx, vc].astype(out_dtype), - axis=[rc, ry, rx]), name='conv', attrs=attrs) + lambda nn, ff, yy, xx, vc: \ + tvm.sum( + temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx].astype(out_dtype) * + kernel_vec[ff, rc, ry, rx, vc].astype(out_dtype), + axis=[rc, ry, rx]), name='conv', attrs=attrs) output = tvm.compute( oshape, lambda nn, ff, yy, xx: conv[nn][ff//nv][yy][xx][ff%nv], - name='output_unpack', tag='conv2d') + name='output_unpack', tag='conv2d', + attrs={'workload': conv_arg_to_workload(data, kernel, stride, padding, + layout, out_dtype)}) return output -def _schedule_cl_spatialpack(s, op): +def _schedule_cl_spatialpack(cfg, s, op): output = op.output(0) _, _, out_height, out_width = [util.get_const_int(x) for x in output.shape] diff --git a/topi/python/topi/intel_graphics/depthwise_conv2d.py b/topi/python/topi/intel_graphics/depthwise_conv2d.py new file mode 100644 index 0000000000000..424cb3c5f925c --- /dev/null +++ b/topi/python/topi/intel_graphics/depthwise_conv2d.py @@ -0,0 +1,342 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Schedule for depthwise_conv2d with auto fusion""" +import tvm +from tvm import autotvm +from ..util import traverse_inline +from .. import tag +from .. import generic, nn +from ..nn.depthwise_conv2d import depthwise_conv2d_infer_layout + +# register original implementation of depthwise_conv2d_nchw since we don't need to change this part +autotvm.register_topi_compute(nn.depthwise_conv2d_nchw, ['intel_graphics'], 'direct', + nn.depthwise_conv2d_nchw.fdefault) + +@autotvm.register_topi_schedule(generic.schedule_depthwise_conv2d_nchw, \ + ['intel_graphics'], 'direct') +def schedule_depthwise_conv2d_nchw_intel_graphics(cfg, outs): + """Schedule for depthwise_conv2d nchw forward. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of depthwise_conv2d + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for depthwise_conv2d nchw. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == 'depthwise_conv2d_nchw': + pad_data = op.input_tensors[0] + kernel = op.input_tensors[1] + conv = op.output(0) + + ##### space definition begin ##### + n, f, y, x = s[conv].op.axis + cfg.define_split("tile_f", f, num_outputs=4) + cfg.define_split("tile_y", y, num_outputs=4) + cfg.define_split("tile_x", x, num_outputs=4) + cfg.define_knob("auto_unroll_max_step", [0, 256, 1500]) + + target = tvm.target.current_target() + if target.target_name in ['nvptx', 'rocm']: + cfg.define_knob("unroll_explicit", [1]) + else: + cfg.define_knob("unroll_explicit", [0, 1]) + + # fallback support + if cfg.is_fallback: + ref_log = autotvm.tophub.load_reference_log( + target.target_name, target.model, 'depthwise_conv2d_nchw', 'direct') + cfg.fallback_with_reference_log(ref_log) + cfg['unroll_explicit'].val = 0 + ##### space definition end ##### + + s[pad_data].compute_inline() + if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag: + s[kernel].compute_inline() + + if conv.op in s.outputs: + output = conv + OL = s.cache_write(conv, 'local') + else: + output = s.outputs[0].output(0) + s[conv].set_scope('local') + OL = conv + + # create cache stage + AA = s.cache_read(pad_data, 'shared', [OL]) + WW = s.cache_read(kernel, 'shared', [OL]) + AL = s.cache_read(AA, 'local', [OL]) + WL = s.cache_read(WW, 'local', [OL]) + + # tile and bind spatial axes + n, f, y, x = s[output].op.axis + bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f) + by, vy, ty, yi = cfg["tile_y"].apply(s, output, y) + bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x) + + kernel_scope, n = s[output].split(n, nparts=1) + bf = s[output].fuse(n, bf) + s[output].bind(bf, tvm.thread_axis("blockIdx.z")) + s[output].bind(by, tvm.thread_axis("blockIdx.y")) + s[output].bind(bx, tvm.thread_axis("blockIdx.x")) + s[output].bind(vf, tvm.thread_axis("vthread")) + s[output].bind(vy, tvm.thread_axis("vthread")) + s[output].bind(vx, tvm.thread_axis("vthread")) + s[output].bind(tf, tvm.thread_axis("threadIdx.z")) + s[output].bind(ty, tvm.thread_axis("threadIdx.y")) + s[output].bind(tx, tvm.thread_axis("threadIdx.x")) + s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi) + s[OL].compute_at(s[output], tx) + + # cooperative fetching + s[AA].compute_at(s[output], bx) + s[WW].compute_at(s[output], bx) + s[AL].compute_at(s[output], tx) + s[WL].compute_at(s[output], tx) + + for load in [AA, WW]: + fused = s[load].fuse(*list(s[load].op.axis)) + fused, tx = s[load].split(fused, cfg["tile_x"].size[2]) + fused, ty = s[load].split(fused, cfg["tile_y"].size[2]) + fused, tz = s[load].split(fused, cfg["tile_f"].size[2]) + s[load].bind(tz, tvm.thread_axis("threadIdx.z")) + s[load].bind(ty, tvm.thread_axis("threadIdx.y")) + s[load].bind(tx, tvm.thread_axis("threadIdx.x")) + + s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + traverse_inline(s, outs[0].op, _callback) + return s + +@generic.schedule_depthwise_conv2d_nhwc.register(["intel_graphics"]) +def schedule_depthwise_conv2d_nhwc(outs): + """Schedule for depthwise_conv2d nhwc forward. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of depthwise_conv2d + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for depthwise_conv2d nhwc. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + + def _schedule(temp, Filter, DepthwiseConv2d): + s[temp].compute_inline() + FS = s.cache_read(Filter, "shared", [DepthwiseConv2d]) + if DepthwiseConv2d.op in s.outputs: + Output = DepthwiseConv2d + CL = s.cache_write(DepthwiseConv2d, "local") + else: + Output = outs[0].op.output(0) + s[DepthwiseConv2d].set_scope("local") + + block_x = tvm.thread_axis("blockIdx.x") + thread_x = tvm.thread_axis("threadIdx.x") + + b, h, w, c = s[Output].op.axis + + # num_thread here could be 728, it is larger than cuda.max_num_threads + num_thread = tvm.ir_pass.Simplify(temp.shape[3]).value + target = tvm.target.current_target() + if target and (target.target_name not in ["cuda", "nvptx"]): + num_thread = target.max_num_threads + xoc, xic = s[Output].split(c, factor=num_thread) + s[Output].reorder(xoc, b, h, w, xic) + xo, yo, _, _ = s[Output].tile(h, w, x_factor=2, y_factor=2) + fused = s[Output].fuse(yo, xo) + fused = s[Output].fuse(fused, b) + fused = s[Output].fuse(fused, xoc) + + s[Output].bind(fused, block_x) + s[Output].bind(xic, thread_x) + + if DepthwiseConv2d.op in s.outputs: + s[CL].compute_at(s[Output], xic) + else: + s[DepthwiseConv2d].compute_at(s[Output], xic) + + _, _, ci, fi = s[FS].op.axis + s[FS].compute_at(s[Output], fused) + fused = s[FS].fuse(fi, ci) + s[FS].bind(fused, thread_x) + + scheduled_ops = [] + + def traverse(OP): + """Internal travserse function""" + # inline all one-to-one-mapping operators except the last stage (output) + if tag.is_broadcast(OP.tag): + if OP not in s.outputs: + s[OP].compute_inline() + for tensor in OP.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + # schedule depthwise_conv2d + if OP.tag == 'depthwise_conv2d_nhwc': + PaddedInput = OP.input_tensors[0] + Filter = OP.input_tensors[1] + if isinstance(Filter.op, tvm.tensor.ComputeOp) and 'dilate' in Filter.op.tag: + s[Filter].compute_inline() + DepthwiseConv2d = OP.output(0) + _schedule(PaddedInput, Filter, DepthwiseConv2d) + + scheduled_ops.append(OP) + + traverse(outs[0].op) + return s + + +def schedule_depthwise_conv2d_backward_input_nhwc(outs): + """Schedule for depthwise_conv2d nhwc backward wrt input. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of depthwise_conv2d + backward wrt input in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for depthwise_conv2d backward + wrt input with layout nhwc. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + + def _schedule(Padded_out_grad, In_grad): + s[Padded_out_grad].compute_inline() + + block_x = tvm.thread_axis("blockIdx.x") + thread_x = tvm.thread_axis("threadIdx.x") + _, h, w, c = In_grad.op.axis + + fused_hwc = s[In_grad].fuse(h, w, c) + xoc, xic = s[In_grad].split(fused_hwc, factor=128) + + s[In_grad].bind(xoc, block_x) + s[In_grad].bind(xic, thread_x) + + def traverse(OP): + # inline all one-to-one-mapping operators except the last stage (output) + if OP.tag == 'depthwise_conv2d_backward_input_nhwc': + Padded_out_grad = OP.input_tensors[0] + Dilated_out_grad = Padded_out_grad.op.input_tensors[0] + s[Dilated_out_grad].compute_inline() + In_grad = OP.output(0) + _schedule(Padded_out_grad, In_grad) + else: + raise ValueError("Depthwise conv backward wrt input for non-NHWC is not supported.") + + traverse(outs[0].op) + return s + +def schedule_depthwise_conv2d_backward_weight_nhwc(outs): + """Schedule for depthwise_conv2d nhwc backward wrt weight. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of depthwise_conv2d + backward wrt weight in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for depthwise_conv2d backward + wrt weight with layout nhwc. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + + def _schedule(Weight_grad): + block_x = tvm.thread_axis("blockIdx.x") + thread_y = tvm.thread_axis("threadIdx.y") + thread_x = tvm.thread_axis("threadIdx.x") + + db, dh, dw = Weight_grad.op.reduce_axis + + fused_dbdhdw = s[Weight_grad].fuse(db, dh, dw) + _, ki = s[Weight_grad].split(fused_dbdhdw, factor=8) + BF = s.rfactor(Weight_grad, ki) + + fused_fwcm = s[Weight_grad].fuse(*s[Weight_grad].op.axis) + + xo, xi = s[Weight_grad].split(fused_fwcm, factor=32) + + s[Weight_grad].bind(xi, thread_x) + s[Weight_grad].bind(xo, block_x) + + s[Weight_grad].bind(s[Weight_grad].op.reduce_axis[0], thread_y) + s[BF].compute_at(s[Weight_grad], s[Weight_grad].op.reduce_axis[0]) + + def traverse(OP): + # inline all one-to-one-mapping operators except the last stage (output) + if OP.tag == 'depthwise_conv2d_backward_weight_nhwc': + Padded_in = OP.input_tensors[1] + s[Padded_in].compute_inline() + Weight_grad = OP.output(0) + _schedule(Weight_grad) + else: + raise ValueError("Depthwise conv backward wrt weight for non-NHWC is not supported.") + + traverse(outs[0].op) + return s + +@depthwise_conv2d_infer_layout.register("intel_graphics") +def _depthwise_conv2d_infer_layout(workload, _): + """Infer input/output shapes and layouts from a workload and cfg. + + Parameters + ---------- + workload : tuple + conv2d workload + + cfg : tuple + tvm.autotvm config + + Returns + ------- + Output : [tuple of tuple and str, tuple of tuple and str] + Input shapes and layouts, and output shapes and layouts + """ + _, data, kernel, strides, padding, _, _ = workload + batch_size, in_channel, in_height, in_width = data[:-1] + filter_channel, channel_multiplier, k_height, k_width = kernel[:-1] + out_channel = filter_channel * channel_multiplier + out_height = (in_height + 2 * padding[0] - k_height) // strides[0] + 1 + out_width = (in_width + 2 * padding[1] - k_width) // strides[1] + 1 + in_shape = (batch_size, in_channel, in_height, in_width) + out_shape = (batch_size, out_channel, out_height, out_width) + in_layout = out_layout = "NCHW" + return ((in_shape, in_layout),), ((out_shape, out_layout),)