Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Codegen][CUDA] Fix: cuda codegen vectorize cast #7561

Merged
merged 4 commits into from
Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 113 additions & 20 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ std::string CodeGenCUDA::Finish() {
decl_stream << " #define uint unsigned int\n";
decl_stream << " #define uchar unsigned char\n";
decl_stream << " #define ushort unsigned short\n";
decl_stream << " #define int64_t long\n";
decl_stream << " #define uint64_t ulong\n";
decl_stream << " #define int64_t long long\n";
decl_stream << " #define uint64_t unsigned long long\n";
decl_stream << "#endif\n";

return CodeGenC::Finish();
Expand Down Expand Up @@ -141,7 +141,21 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
}
break;
case 32:
os << "float";
if (lanes <= 4) {
os << "float";
} else if (lanes <= 8) {
// Emit CUDA code to access fp32 vector elements for 4 < lanes <= 8.
//
// float8 is stored as ulonglong4
//
// f8.v1 is emitted as *(float2*)(&(ul4.x)).x
// f8.v2 is emitted as *(float2*)(&(ul4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for float type with lanes > 4";
os << "ulonglong" << lanes / 2;
} else {
fail = true;
}
break;
case 64:
os << "double";
Expand All @@ -151,6 +165,7 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
break;
}
if (!fail && (t.is_scalar() || t.bits() == 16)) return;
if (!fail && (lanes > 4 && lanes <= 8 && t.bits() == 32)) return;
if (!fail && (lanes >= 2 && lanes <= 4)) {
os << lanes;
return;
Expand Down Expand Up @@ -238,12 +253,54 @@ void CodeGenCUDA::PrintType(DataType t, std::ostream& os) { // NOLINT(*)
break;
}
}
case 16:
os << "short";
case 16: {
if (t.is_scalar()) {
os << "short";
} else if (t.lanes() <= 4) {
os << "short" << lanes;
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int16 vector elements.
//
// short4 is stored as int2
//
// s4.x is emitted as *(short2*)(&(i2.x)).x
// s4.y is emitted as *(short2*)(&(i2.x)).y
// s4.z is emitted as *(short2*)(&(i2.y)).x
// s4.w is emitted as *(short2*)(&(i2.y)).y
//
ICHECK_EQ(t.lanes() % 2, 0) << "only support even lane for shorT type with lanes > 4";
os << "int" << t.lanes() / 2;
} else {
fail = true;
}
if (!fail) {
return;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing break here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

break;
case 32:
os << "int";
}
case 32: {
if (t.is_scalar()) {
os << "int";
} else if (t.lanes() <= 4) {
os << "int" << t.lanes();
} else if (t.lanes() <= 8) {
// Emit CUDA code to access int32 vector elements for 4 < lanes <= 8.
//
// int8 is stored as longlong4
//
// i8.v1 is emitted as *(int2*)(&(l4.x)).x
// i8.v2 is emitted as *(int2*)(&(l4.x)).y
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for int32 type with lanes > 4";
os << "longlong" << lanes / 2;
} else {
fail = true;
}
if (!fail) {
return;
}
break;
}
case 64: {
if (t.is_scalar()) {
os << "int64_t";
Expand Down Expand Up @@ -314,21 +371,36 @@ void CodeGenCUDA::PrintVecElemLoad(const std::string& vec, DataType t, int i,
}

static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
if ((t.is_int()) && t.bits() == 8) {
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else {
os << "((char)(" << vec << " >> " << i * 8 << "))";
}
} else if ((t.is_uint()) && t.bits() == 8) {
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else {
os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
}
} else if (t.is_float16()) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
if (t.is_int()) {
type_name = "short";
} else if (t.is_uint()) {
type_name = "ushort";
}
} else if (t.bits() == 32) {
if (t.is_int()) {
type_name = "int";
} else if (t.is_uint()) {
type_name = "uint";
} else if (t.is_float()) {
type_name = "float";
}
}
ICHECK(!type_name.empty());
os << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2];
} else {
os << vec << "." << access[i];
}
Expand All @@ -338,22 +410,43 @@ void CodeGenCUDA::PrintVecElemStore(const std::string& vec, DataType t, int i,
const std::string& value) {
this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
ICHECK(i >= 0 && i < (t.bits() == 8 ? 16 : (t.bits() == 16 || t.bits() == 32) ? 8 : 4));
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "="
<< "(" << value << ");\n";
} else {
stream << vec << "=";
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
stream << ac << "=";
// Do not read the first undef lane.
if (i != 0) {
stream << vec << " & ~(0x000000ff << " << i * 8 << ") |";
stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
}
stream << "(" << value << " << " << i * 8 << ");\n";
stream << "(" << value << " << " << i % 4 * 8 << ");\n";
}
} else if (t.is_float16()) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->" << access[i % 2] << " = "
<< value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
if (t.is_int()) {
type_name = "short";
} else if (t.is_uint()) {
type_name = "ushort";
}
} else if (t.bits() == 32) {
if (t.is_int()) {
type_name = "int";
} else if (t.is_uint()) {
type_name = "uint";
} else if (t.is_float()) {
type_name = "float";
}
}
ICHECK(!type_name.empty());
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
} else {
stream << vec << "." << access[i] << " = " << value << ";\n";
}
Expand Down
30 changes: 23 additions & 7 deletions tests/python/unittest/test_target_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def test_cuda_floormod_with_vectorization():
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_vectorized_casts():
def check(t0, t1):
def check(t0, t1, factor):
if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.gpu(0).compute_version):
print("Skip because gpu does not have fp16 support")
return
Expand All @@ -511,9 +511,8 @@ def check(t0, t1):

# schedule
s = tvm.te.create_schedule(C.op)
ob, ib = s[C].split(s[C].op.axis[0], nparts=32)
_, iib = s[C].split(ib, factor=4)
s[C].vectorize(iib)
ob, ib = s[C].split(s[C].op.axis[0], factor=factor)
s[C].vectorize(ib)
s[C].bind(ob, tx)
func = tvm.build(s, [A, B, C], "cuda")

Expand All @@ -538,9 +537,26 @@ def skip(t0, t1):
return True
return False

types = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"]
for t0, t1 in [(x, y) for x in types for y in types if not skip(x, y)]:
check(t0, t1)
types_4 = [
"float16",
"float32",
"int8",
"uint8",
"int16",
"uint16",
"int32",
"uint32",
"float64",
"int64",
"uint64",
]
types_8 = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"]
for t0, t1 in [(x, y) for x in types_4 for y in types_4 if not skip(x, y)]:
check(t0, t1, 4)
for t0, t1 in [(x, y) for x in types_8 for y in types_8 if not skip(x, y)]:
check(t0, t1, 8)
check("int8", "uint8", 16)
check("uint8", "int8", 16)


def sched(B):
Expand Down