Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Codegen][CUDA] Fix make_int4x cuda codegen vectorize #8137

Merged
merged 1 commit into from
May 26, 2021

Conversation

wyc-ruiker
Copy link
Contributor

Added support for int4x32 int4x16 int4x4 in BroadcastNode.

In the int4x4 testcase, the IR is:

primfn(compute_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {compute: Buffer(compute_2: Pointer(int4), int4, [64, 4], [])}
  buffer_map = {compute_1: compute} {
  attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 64;
  compute_2[ramp((blockIdx.x*4), 1, 4)] = broadcast(1i4, 4)
}

Before the fix in codegen_c.cc, the codegen cuda is:

extern "C" __global__ void make_int4x4_kernel0(int* __restrict__ compute) {
  ((int16_t*)(compute + ((((int)blockIdx.x) * 4)) / 8))[0] = (int16_t)4369;
}

For int16_t, this index (((int)blockIdx.x) * 4)) / 8 is a bug.
After the fix in codegen_c.cc, the codegen cuda is:

extern "C" __global__ void make_int4x4_kernel0(int* __restrict__ compute) {
  ((int16_t*)(compute) + ((((int)blockIdx.x) * 4)) / 4)[0] = (int16_t)4369;
}

Could you please help review this fix? @vinx13 @Hzfengsy

@tqchen
Copy link
Member

tqchen commented May 26, 2021

@vinx13 please help to manage this PR

@vinx13 vinx13 merged commit f4dce24 into apache:main May 26, 2021
@wyc-ruiker wyc-ruiker deleted the fix-int4 branch May 27, 2021 02:27
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Jun 17, 2021
Co-authored-by: wangyucheng <wangyucheng@sensetime.com>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jun 17, 2021
Co-authored-by: wangyucheng <wangyucheng@sensetime.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants