# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
import sys

import tvm
import tvm.testing
from tvm import te
import numpy as np

# from tvm import relay
from tvm import autotvm
from tvm import topi
from tvm.contrib import cudnn
from tvm.contrib import nvcc
from tvm.tir.op import indexdiv
from tvm.topi.nn.utils import get_pad_tuple
from tvm.topi.utils import get_const_tuple
from tvm.topi.nn.pad import pad
from tvm.topi import transform
from tvm.topi import nn
from tvm.topi.testing import conv2d_nhwc_python
from tvm.topi.cuda.tensor_intrin import (
    intrin_wmma_load_matrix_A,
    intrin_wmma_load_matrix_W,
    intrin_wmma_store_matrix,
    intrin_wmma_gemm,
)
from tvm.runtime import const

add_bias = True


def implicit_gemm_conv_tensorecore(
    Input,
    Filter,
    stride,
    padding,
    dilation,
    out_dtype="int32",
):
    """Convolution operator in NHWC layout.

    Parameters
    ----------
    Input : tvm.te.Tensor
        4-D with shape [batch, in_height, in_width, in_channel]

    Filter : tvm.te.Tensor
        4-D with shape [num_filter, filter_height, filter_width, in_channel]

    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]

    padding : int or a list/tuple of 2 or 4 ints
        Padding size, or
        [pad_height, pad_width] for 2 ints, or
        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints

    dilation : int or a list/tuple of two ints
        Dilation size, or [dilation_height, dilation_width]

    out_dtype : str, default "int32"
        The type of the output tensor

    Returns
    -------
    output : tvm.te.Tensor
        4-D with shape [batch, out_height, out_width, out_channel]
    """
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2

    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    # Implicit GEMM from CUTLASS, refer to:
    # - https://github.com/NVIDIA/cutlass/blob/master/media/docs/implicit_gemm_convolution.md
    B, IH, IW, IC = Input.shape
    OC, KH, KW, IC = Filter.shape

    # compute the output shape
    dilated_kernel_h = (KH - 1) * dilation_h + 1
    dilated_kernel_w = (KW - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    OH = (IH + pad_top + pad_down - dilated_kernel_h) // stride_h + 1
    OW = (IW + pad_left + pad_right - dilated_kernel_w) // stride_w + 1

    if OH == IH and OW == IW and KH == 1 and KW == 1:
        # No padding, no im2col, no reshape: just a GEMM.
        A_shape = (B, OH, OW, IC)
        B_shape = (OC, IC)
        M = B * OH * OW
        N = OC
        K = KH * KW * IC
        assert (
            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
        ), "The shape of Im2col (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)"
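        # In this degenerate case the NHWC input already is the GEMM A operand
        # (M = B*OH*OW rows by K = IC columns), so only the filter has to be
        # flattened from (OC, 1, 1, IC) to the (OC, IC) B operand below.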
        GEMM_A_Input = Input
        GEMM_B_Input = te.compute(
            (OC, IC),
            lambda oc, ic: Filter[oc, 0, 0, ic],
            name="FilterReshape",
        )
        k = te.reduce_axis((0, IC), "gemm_k")
        C = te.compute(
            (B, OH, OW, OC),
            lambda b, oh, ow, oc: te.sum(
                GEMM_A_Input[b, oh, ow, k].astype(out_dtype)
                * GEMM_B_Input[oc, k].astype(out_dtype),
                axis=[k],
            ),
            name="implicit_gemm_conv",
        )
        return C

    # 1. Padding
    # if pad_top or pad_left:
    #     Pad_Input = nn.pad(
    #         Input, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
    #         pad_value=0,
    #         name="PadInput")
    # else:
    #     Pad_Input = Input
    pad_before = [0, pad_top, pad_left, 0]
    pad_after = [0, pad_down, pad_right, 0]
    filter_pad_before = [0, 0, 0, 0]
    filter_pad_after = [0, 0, 0, 0]
    pad_channel = 0
    if IC == 24:
        pad_channel = 8
        pad_after = [0, pad_down, pad_right, pad_channel]
        filter_pad_after = [0, 0, 0, pad_channel]
        IC = IC + 8
    PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
    PaddedFilter = pad(Filter, filter_pad_before, filter_pad_after, name="PaddedFilter")

    # 2. Im2col
    M = B * OH * OW
    N = OC
    K = KH * KW * IC
    assert (
        (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
        or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
        or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
    ), "The shape of Im2col (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)"
    A_shape = (B * OH * OW, KH * KW * IC)
    B_shape = (OC, KH * KW * IC)
    Im2Col = te.compute(
        (B, OH, OW, KH, KW, IC),
        lambda b, oh, ow, kh, kw, ic: PaddedInput[
            b, oh * stride_h + kh * dilation_h, ow * stride_w + kw * dilation_w, ic
        ],
        name="im2col")
    GEMM_A_Input = transform.reshape(Im2Col, A_shape)
    GEMM_B_Input = transform.reshape(PaddedFilter, B_shape)

    # 3. GEMM
    k = te.reduce_axis((0, KH * KW * IC), 'k')
    C = te.compute(
        (B * OH * OW, OC),
        lambda i, j: te.sum(
            GEMM_A_Input[i, k].astype(out_dtype) * GEMM_B_Input[j, k].astype(out_dtype),
            axis=[k],
        ),
        name="implicit_gemm_conv",
    )
    RC = te.compute(
        (B, OH, OW, OC),
        lambda b, oh, ow, oc: C[b * OH * OW + oh * OW + ow, oc],
        name="out_reshape",
    )
    return RC


@autotvm.template("implicit_gemm_conv_tensorecore")
def implicit_gemm_conv_tensorecore_schedule(
    input_shape, kernel_shape, stride, padding, dilation, in_dtype, out_dtype
):
    Input = te.placeholder(input_shape, name='input', dtype=in_dtype)
    Filter = te.placeholder(kernel_shape, name='weight', dtype=in_dtype)
    Bias = te.placeholder((1, 1, 1, kernel_shape[0]), dtype=in_dtype, name="bias")
    RC = implicit_gemm_conv_tensorecore(Input, Filter, stride, padding, dilation, out_dtype)
    BC = topi.add(RC, Bias)

    sch = te.create_schedule(BC.op)
    C = sch[RC].op.input_tensors[0]
    A, B = sch[C].op.input_tensors
    Im2Col = sch[A].op.input_tensors[0]
    PaddX = sch[Im2Col].op.input_tensors[0]
    PaddF = sch[B].op.input_tensors[0]
    # if PaddX != Input:
    sch[PaddF].compute_inline()
    sch[PaddX].compute_inline()
    # print(tvm.lower(sch, [Input, Filter, RC], simple_mode=True))
    sch[Im2Col].compute_inline()
    sch[A].compute_inline()
    sch[B].compute_inline()
    sch[C].compute_inline()

    AS = sch.cache_read(A, "shared", [C])
    BS = sch.cache_read(B, "shared", [C])
    # AF = sch.cache_read(AS, "local", [C])
    # BF = sch.cache_read(BS, "local", [C])
    # CF = sch.cache_write(C, "local")
    AF = sch.cache_read(AS, "wmma.matrix_a", [C])
    BF = sch.cache_read(BS, "wmma.matrix_b", [C])
    CF = sch.cache_write(C, "wmma.accumulator")
    CS = sch.cache_read(CF, "shared", [C])

    # Deal with op fusion, such as bias/relu and slice after padding
    sch[RC].compute_inline()
    RC = sch.outputs[0].output(0)

    cfg = autotvm.get_config()
    cfg.define_knob("block_row_warps", [2, 1, 2, 4])
    cfg.define_knob("block_col_warps", [1, 2, 4])
cfg.define_knob("warp_row_tiles", [4, 1, 2, 4]) cfg.define_knob("warp_col_tiles", [1, 2, 4]) cfg.define_knob("chunk", [1, 2, 4, 8]) cfg.define_knob("offset", [0, 8, 16]) cfg.define_knob("offsetCS", [0, 8, 16]) cfg.define_knob("vecc", [4, 2, 1]) cfg.define_knob("veci", [8, 16, 4, 2, 1]) cfg.define_knob("auto_unroll_max_step", [0, 16, 64]) cfg.define_knob("unroll_explicit", [0, 1]) M_dim, K_dim = get_const_tuple(A.shape) N_dim, _ = get_const_tuple(B.shape) # print("M_dim: ", M_dim, " N_dim: ", N_dim, " K_dim: ", K_dim) # Ensure that the default parameters are applicable when autotvm is not in use if M_dim % 32 == 0 and N_dim % 8 == 0: cfg.define_knob("wmma_m", [32, 16, 8]) elif M_dim % 16 == 0 and N_dim % 16 == 0: cfg.define_knob("wmma_m", [16, 8, 32]) elif M_dim % 8 == 0 and N_dim % 32 == 0: cfg.define_knob("wmma_m", [8, 16, 32]) warp_size = 32 wmma_k = 16 block_row_warps = cfg["block_row_warps"].val block_col_warps = cfg["block_col_warps"].val warp_row_tiles = cfg["warp_row_tiles"].val warp_col_tiles = cfg["warp_col_tiles"].val chunk = cfg["chunk"].val offset = cfg["offset"].val offsetCS = cfg["offsetCS"].val wmma_m = cfg["wmma_m"].val vecc = cfg['vecc'].val veci = cfg['veci'].val if wmma_m == 16: wmma_n = 16 elif wmma_m == 8: wmma_n = 32 elif wmma_m == 32: wmma_n = 8 AS_align = chunk * wmma_k + offset BS_align = chunk * wmma_k + offset CS_align = warp_col_tiles * block_col_warps * wmma_n + offsetCS AS_stride = [AS_align, 1] BS_stride = [BS_align, 1] AF_stride = [wmma_k, 1] BF_stride = [wmma_k, 1] CF_stride = [warp_col_tiles * wmma_n, 1] CS_stride = [CS_align, 1] block_x = te.thread_axis("blockIdx.x") block_y = te.thread_axis("blockIdx.y") block_z = te.thread_axis("blockIdx.z") thread_x = te.thread_axis("threadIdx.x") thread_y = te.thread_axis("threadIdx.y") thread_z = te.thread_axis("threadIdx.z") # Schedule for dense computation block_factor_m = wmma_m * warp_row_tiles * block_row_warps block_factor_n = wmma_n * warp_col_tiles * block_col_warps rc_b, rc_oh, rc_ow, rc_c = RC.op.axis rc_m = sch[RC].fuse(rc_b, rc_oh, rc_ow) rc_n = rc_c rc_by, rc_byi = sch[RC].split(rc_m, factor=block_factor_m) rc_bx, rc_bxi = sch[RC].split(rc_n, factor=block_factor_n) sch[RC].reorder(rc_by, rc_bx, rc_byi, rc_bxi) rc_t = sch[RC].fuse(rc_byi, rc_bxi) rc_t, rc_vi = sch[RC].split(rc_t, factor=vecc) rc_t, rc_tx = sch[RC].split(rc_t, factor=warp_size) rc_t, rc_ty = sch[RC].split(rc_t, factor=block_row_warps) rc_t, rc_tz = sch[RC].split(rc_t, factor=block_col_warps) sch[RC].bind(rc_by, block_y) sch[RC].bind(rc_bx, block_x) sch[RC].bind(rc_tz, thread_z) sch[RC].bind(rc_ty, thread_y) sch[RC].bind(rc_tx, thread_x) sch[RC].vectorize(rc_vi) # print(tvm.lower(sch, [Input, Filter, Bias, BC])) # Schedule for wmma store sch[CS].compute_at(sch[RC], rc_bx) cs_m, cs_n = CS.op.axis sch[CS].storage_align(cs_m, CS_align - 1, CS_align) cs_m, cs_mtc = sch[CS].split(cs_m, factor=wmma_m) cs_n, cs_ntc = sch[CS].split(cs_n, factor=wmma_n) cs_m, cs_mwi = sch[CS].split(cs_m, factor=warp_row_tiles) cs_n, cs_nwi = sch[CS].split(cs_n, factor=warp_col_tiles) sch[CS].reorder(cs_m, cs_n, cs_mwi, cs_nwi, cs_mtc, cs_ntc) # Schedule for wmma computation sch[CF].compute_at(sch[CS], cs_n) warp_i, warp_j = CF.op.axis warp_i, _ii = sch[CF].split(warp_i, factor=wmma_m) warp_j, _jj = sch[CF].split(warp_j, factor=wmma_n) (k,) = CF.op.reduce_axis k, _k = sch[CF].split(k, factor=wmma_k) ko, ki = sch[CF].split(k, factor=chunk) # koo, ko = sch[CF].split(ko, factor=3) sch[CF].reorder(ko, ki, warp_i, warp_j, _ii, _jj, _k) # Schedule for wmma_matrix_a load 
    sch[AF].compute_at(sch[CF], ki)
    af_m, af_k = AF.op.axis
    af_m, af_mi = sch[AF].split(af_m, factor=wmma_m)
    af_k, af_ki = sch[AF].split(af_k, factor=wmma_k)
    sch[AF].reorder(af_m, af_k, af_mi, af_ki)

    # Schedule for wmma_matrix_b load
    sch[BF].compute_at(sch[CF], ki)
    bf_n, bf_k = BF.op.axis
    bf_n, bf_ni = sch[BF].split(bf_n, factor=wmma_n)
    bf_k, bf_ki = sch[BF].split(bf_k, factor=wmma_k)
    sch[BF].reorder(bf_n, bf_k, bf_ni, bf_ki)

    # double buffer
    sch[AS].double_buffer()
    sch[BS].double_buffer()

    # Schedule for A's and B's shared memory load
    def shared_schedule(stage, strides):
        sch[stage].compute_at(sch[CF], ko)
        xo, yo = stage.op.axis
        sch[stage].storage_align(xo, strides - 1, strides)
        yo, vi = sch[stage].split(yo, factor=veci)
        t = sch[stage].fuse(xo, yo)
        t, tx = sch[stage].split(t, factor=warp_size)
        t, ty = sch[stage].split(t, factor=block_row_warps)
        _, tz = sch[stage].split(t, factor=block_col_warps)
        sch[stage].bind(ty, thread_y)
        sch[stage].bind(tz, thread_z)
        sch[stage].bind(tx, thread_x)
        sch[stage].vectorize(vi)

    shared_schedule(AS, AS_align)
    shared_schedule(BS, BS_align)

    shape = (wmma_m, wmma_n, wmma_k)
    # TODO: add checking here, datatype casting may cause precision loss
    AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
    BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
    k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
    CL_compute = te.compute(
        (wmma_m, wmma_n),
        lambda ii, jj: te.sum(
            AL_gemm[ii, k_gemm].astype(out_dtype) * BL_gemm[jj, k_gemm].astype(out_dtype),
            axis=k_gemm,
        ),
        name="CL_compute",
    )
    # print(tvm.lower(sch, [Input, Filter, Bias, BC]))

    # Lower the computation loops down to TensorCore hardware intrinsics
    # by mapping the dense tensorcore to tensor intrinsics
    sch[AF].tensorize(
        af_mi,
        intrin_wmma_load_matrix_A(
            AF_stride,
            AS_stride,
            shape,
            "row_major",
            (wmma_m, wmma_k),
            (wmma_m, wmma_k),
            in_dtype=in_dtype,
        ),
    )
    sch[BF].tensorize(
        bf_ni,
        intrin_wmma_load_matrix_W(
            BF_stride,
            BS_stride,
            shape,
            "col_major",
            (wmma_n, wmma_k),
            (wmma_n, wmma_k),
            in_dtype=in_dtype,
        ),
    )
    sch[CF].tensorize(
        _ii,
        intrin_wmma_gemm(AL_gemm, BL_gemm, CL_compute, AF_stride, BF_stride, CF_stride, shape),
    )
    sch[CS].tensorize(
        cs_mtc,
        intrin_wmma_store_matrix(
            CS_stride, CF_stride, shape, out_dtype, (wmma_m, wmma_n), (wmma_m, wmma_n)
        ),
    )

    return sch, [Input, Filter, Bias, BC]


######################################### Run Testing #################################
# [batch, ih, iw, ic, oc, kh, kw, ph, pw, sh, sw]
trials = [
    (1, 16, 16, 128, 128, 3, 3, 1, 1, 1, 1),
]

for (b, ih, iw, ic, oc, kh, kw, ph, pw, sh, sw) in trials:
    input_shape = (b, ih, iw, ic)
    kernel_shape = (oc, kh, kw, ic)
    bias_shape = (1, 1, 1, oc)
    print("params: ", (b, ih, iw, ic, oc, kh, kw, ph, pw, sh, sw))
    print("input_shape: ", input_shape)
    print("kernel_shape: ", kernel_shape)

    in_dtype = "int8"
    out_dtype = "int32"

    x_np = np.random.uniform(-10, 20, input_shape).astype(in_dtype)
    w_np = np.random.uniform(-10, 20, kernel_shape).astype(in_dtype)
    b_np = np.random.uniform(-10, 20, bias_shape).astype(in_dtype)

    dev = tvm.cuda(0)
    target = tvm.target.cuda()
    target_host = 'llvm'
    x_tc = tvm.nd.array(x_np, dev)
    w_tc = tvm.nd.array(w_np, dev)
    b_tc = tvm.nd.array(b_np, dev)

    ################# use ours ####################
    logging.getLogger("autotvm").setLevel(logging.DEBUG)
    logging.getLogger("autotvm").addHandler(logging.StreamHandler(sys.stdout))

    task = autotvm.task.create(
        "implicit_gemm_conv_tensorecore",
        args=(input_shape, kernel_shape, (sh, sw), (ph, pw), 1, in_dtype, out_dtype),
        target="cuda",
        target_host=target_host,
    )
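    # Each entry in task.config_space is one combination of the knobs defined
    # in the template above; the tuner below tries at most 10 of them (n_trial).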
    print(task.config_space)

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=1, repeat=10, min_repeat_ms=0),
    )

    str_value = (b, ih, iw, ic, oc, kh, kw, ph, pw, sh, sw)
    log_name = "conv2d"  # + str(index)
    for val in str_value:
        log_name += "_" + str(val)
    log_name += '.log'
    log_path = os.path.join("./", log_name)

    # tuner = autotvm.tuner.XGBTuner(task)
    tuner = autotvm.tuner.GATuner(task)
    if os.path.isfile(log_path):
        tuner.load_history(autotvm.record.load_from_file(log_path))
    tuner.tune(
        n_trial=min(10, len(task.config_space)),
        early_stopping=None,
        measure_option=measure_option,
        callbacks=[autotvm.callback.log_to_file(log_path)],
    )

    # inspect the best config
    dispatch_context = autotvm.apply_history_best(log_path)
    best_config = dispatch_context.query(task.target, task.workload)
    print("\nBest config:")
    print(best_config)

    dev = tvm.cuda(0)
    with autotvm.apply_history_best(log_path):
        with tvm.target.Target("cuda"):
            sch, arg_bufs = implicit_gemm_conv_tensorecore_schedule(
                input_shape, kernel_shape, (sh, sw), (ph, pw), 1, in_dtype, out_dtype)
            print(tvm.lower(sch, arg_bufs, simple_mode=True))
            tensorcore_f = tvm.build(sch, arg_bufs)
            print(tensorcore_f.imported_modules[0].get_source())

    _, _, _, RC = arg_bufs
    outshape = [i.value for i in RC.shape]
    y_tc = tvm.nd.array(np.zeros(outshape, dtype=RC.dtype), dev)
    tensorcore_f(x_tc, w_tc, b_tc, y_tc)
    evaluator = tensorcore_f.time_evaluator(tensorcore_f.entry_name, dev, number=50)
    print("Time cost of this operator: %fus"
          % (evaluator(x_tc, w_tc, b_tc, y_tc).mean * 1000000))

    ################ use python ######################
    wt = w_np.transpose((1, 2, 3, 0))  # OHWI ==> HWIO
    c_np = conv2d_nhwc_python(x_np, wt, sh, ph).astype(RC.dtype)
    c_np += b_np
    tvm.testing.assert_allclose(y_tc.asnumpy(), c_np, atol=1e-2, rtol=1e-2)
    print("Pass!")
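
# An optional, rough throughput estimate from the measured mean latency.  This is
# only a sketch: it assumes the usual dense NHWC convolution multiply-accumulate
# count (2 * B * OH * OW * OC * KH * KW * IC operations) and is not part of the
# schedule or the correctness check above.
def conv_tera_ops(mean_sec, batch, oh, ow, oc, kh, kw, ic):
    """Return tera-ops/s of a dense NHWC conv given its mean runtime in seconds."""
    ops = 2.0 * batch * oh * ow * oc * kh * kw * ic
    return ops / mean_sec / 1e12

# Hypothetical usage: for the workload above (3x3 kernel, stride 1, pad 1, so
# OH == ih and OW == iw) one could report
#   conv_tera_ops(evaluator(x_tc, w_tc, b_tc, y_tc).mean, b, ih, iw, oc, kh, kw, ic)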