diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index d1a195fbad6f..ffd6b6d09533 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -244,6 +244,7 @@ def apply(
             LOAD_V_SHARED,
             LOAD_V_VEC,
             UNROLL,
+            SUPPORT_WARP_SHUFFLE,
         ):
             # rfactor: reduce to tx * vec_c
             _, s, r, c = sch.get_loops(block=gemv)
@@ -273,10 +274,17 @@ def apply(
 
             shared_mem_usage = 0
             for buf in vector_input_buffers:
-                buf_size = reduce(
-                    lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1)
-                ) * get_bytes(buf.dtype)
+                dtype_bytes = get_bytes(buf.dtype)
+                buf_size = (
+                    reduce(lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1))
+                    * dtype_bytes
+                )
                 shared_mem_usage += buf_size
+            if not SUPPORT_WARP_SHUFFLE:
+                # When warp shuffle is not available, cross-thread allreduce
+                # is implemented with shared memory.
+                shared_mem_usage += TS * TR * dtype_bytes
+
             LOAD_V_SHARED = (
                 LOAD_V_SHARED
                 and isinstance(shared_mem_usage, tir.IntImm)
@@ -421,11 +429,13 @@ def apply(
         len_R = len_r * len_c
 
         TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
+        SUPPORT_WARP_SHUFFLE = False
         if target.kind.name == "cuda":
             VEC_C = 4
             LOAD_V_SHARED = True
             LOAD_V_VEC = 8
             UNROLL = 256
+            SUPPORT_WARP_SHUFFLE = True
             if isinstance(len_S, int):
                 if len_S > len_R:
                     TS, TR = 4, 64
@@ -438,6 +448,7 @@ def apply(
             LOAD_V_SHARED = False
             LOAD_V_VEC = -1
             UNROLL = 256
+            SUPPORT_WARP_SHUFFLE = True
             if isinstance(len_S, int):
                 if len_S > len_R:
                     TS, TR = 4, 16
@@ -515,6 +526,7 @@ def apply(
             LOAD_V_SHARED=LOAD_V_SHARED,
             LOAD_V_VEC=LOAD_V_VEC,
             UNROLL=UNROLL,
+            SUPPORT_WARP_SHUFFLE=SUPPORT_WARP_SHUFFLE,
         )
 
     def sch_outer_reduction(  # pylint: disable=too-many-arguments, invalid-name, unused-argument
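
For reference, a minimal standalone sketch of the shared-memory accounting this patch introduces: the vector operand's footprint in bytes plus, when the target cannot use warp shuffle, a `TS * TR` staging buffer for the cross-thread allreduce. The shape, dtype width, and thread extents below are made-up example values, not inputs taken from the scheduler itself.

```python
# Illustrative sketch only -- a plain-Python restatement of the budget check,
# using assumed example numbers rather than the scheduler's actual inputs.
from functools import reduce


def gemv_shared_mem_bytes(buf_shape, dtype_bytes, ts, tr, support_warp_shuffle):
    # Bytes needed to stage the vector input buffer in shared memory.
    usage = reduce(lambda x, y: x * y, buf_shape, 1) * dtype_bytes
    if not support_warp_shuffle:
        # Without warp shuffle, the cross-thread allreduce also needs a
        # TS x TR element staging area in shared memory.
        usage += ts * tr * dtype_bytes
    return usage


# Example: float16 vector of length 4096 with TS=4, TR=64 on a target
# without warp shuffle: 4096 * 2 + 4 * 64 * 2 = 8704 bytes, which the
# schedule then compares against the target's shared-memory limit when
# deciding whether to keep LOAD_V_SHARED enabled.
print(gemv_shared_mem_bytes([4096], 2, 4, 64, support_warp_shuffle=False))
```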