Use int for int8x4 due to performance overhead of char4 #1569

Merged: 3 commits from vinx13:feature/int_for_int8x4 into apache:master on Aug 9, 2018

Conversation

vinx13 (Member) commented Aug 8, 2018

Loading four int8 elements as char4 is likely to produce more integer instructions. When we use int8 intrinsics (e.g. dp4a), we need packed 32-bit data, which requires extra instructions to pack the int8 elements.

For example, below is a PTX code snippet generated for
__dp4a(((char4*)((signed char*)A_shared_local + ((k_inner_outer_outer % 2) * 32)))[0], ((char4*)((signed char*)B_shared_local + ((k_inner_outer_outer % 2) * 32)))[0], C_local[0]);

ld.shared.v4.u8 {%rs577, %rs578, %rs579, %rs580}, [%r5+24];
ld.shared.v4.u8 {%rs625, %rs626, %rs627, %rs628}, [%r6+24];
...
cvt.u32.u16 %r2873, %rs580;
mul.wide.u16 %r2874, %rs578, 256;
cvt.u32.u16 %r2875, %rs577;
cvt.u32.u16 %r2876, %rs579;
prmt.b32 %r2877, %r2876, %r2875, 28756;
prmt.b32 %r2878, %r2873, %r2877, 1620;
or.b32 %r2010, %r2878, %r2874;

cvt.u32.u16 %r2879, %rs628;
mul.wide.u16 %r2880, %rs626, 256;
cvt.u32.u16 %r2881, %rs625;
cvt.u32.u16 %r2882, %rs627;
prmt.b32 %r2883, %r2882, %r2881, 28756;
prmt.b32 %r2884, %r2879, %r2883, 1620;
or.b32 %r1815, %r2884, %r2880;

dp4a.s32.s32 %r1785, %r2010, %r1815, %r1529;

We would like to use ld.shared.u32 in this case so that 32-bit data can be directly loaded.

This disables support for vectorized int8 arithmetic operations. Since those operations are rarely used, we prefer the better performance here.
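As a rough illustration of the desired code path (a hypothetical hand-written kernel, not TVM-generated output), typing the int8x4 operands as int lets __dp4a consume them directly, so the shared-memory accesses compile to single 32-bit loads:

// Hypothetical sketch: int8x4 operands kept as packed 32-bit ints (requires sm_61+ for __dp4a).
__global__ void dot_int8x4(const int* A, const int* B, int* C, int n) {
  __shared__ int A_shared[256];
  __shared__ int B_shared[256];
  int tid = threadIdx.x;
  if (tid < n && tid < 256) {
    // One 32-bit load brings in four int8 lanes, already packed.
    A_shared[tid] = A[tid];
    B_shared[tid] = B[tid];
  }
  __syncthreads();
  if (tid < n && tid < 256) {
    // dp4a consumes the packed operands directly; no byte loads or prmt packing needed.
    C[tid] = __dp4a(A_shared[tid], B_shared[tid], 0);
  }
}

With this layout the loads should lower to ld.shared.u32 followed directly by dp4a.s32.s32.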

vinx13 (Member, Author) commented Aug 8, 2018

@tqchen Please review.

tqchen (Member) commented Aug 8, 2018

cc @nishi-t

@@ -90,7 +90,7 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
     if (t.lanes() == 4) {
       // directly 4 8 bit int in integer.
       enable_int8_ = true;
-      os << "char4"; return;
+      os << "int"; return;
Review comment from a Member:
Please add a comment block here on why we are making this choice, so people won't change it back.
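A sketch of what that comment block might look like in CodeGenCUDA::PrintType (illustrative only; the exact wording committed in 112c633 may differ):

if (t.lanes() == 4) {
  // int8x4 is printed as a plain 32-bit int rather than char4: loading char4
  // makes the compiler emit per-byte shared-memory loads plus prmt packing,
  // while int8 intrinsics such as dp4a want the four lanes already packed in
  // one 32-bit register. The cost is that vectorized int8 arithmetic on the
  // char4 lanes is no longer supported.
  enable_int8_ = true;
  os << "int"; return;
}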

nishi-t (Contributor) commented Aug 9, 2018

The vectorized_add test will no longer work for int8. Please remove this line:

check_cuda("int8", 64, 4)

https://github.com/dmlc/tvm/blob/master/tests/python/unittest/test_codegen_cuda.py#L34

vinx13 force-pushed the feature/int_for_int8x4 branch from 27e14ab to 112c633 on August 9, 2018 03:11
vinx13 (Member, Author) commented Aug 9, 2018

@tqchen Please review again.

tqchen merged commit 41d4dd6 into apache:master on Aug 9, 2018
Labels: status: need update
3 participants