From 130434748bf73c8684ec355786f75c2d3cce928d Mon Sep 17 00:00:00 2001
From: Wei Pan
Date: Wed, 1 Apr 2020 10:51:29 -0700
Subject: [PATCH] [CodeGen][CUDA] Fix bugs

- Support vectorized casts

- It is incorrect to extract elements from int8x4 with

    0x000000ff & (x >> i * 8)

  as this value is of type int in C/C++. If this expression is used for
  sign extension, the sign bit will be wrong. Simply use C-style casts
  instead and the sign bits will just work.

Signed-off-by: Wei Pan
---
 src/target/source/codegen_cuda.cc         | 39 +++++++++++++--
 src/target/source/codegen_cuda.h          |  1 +
 .../unittest/test_target_codegen_cuda.py  | 50 +++++++++++++++++++
 3 files changed, 87 insertions(+), 3 deletions(-)

diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index f8bc8731d4ed..9c4fc69a9d78 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -273,8 +273,10 @@ void CodeGenCUDA::PrintVecElemLoad(
     const std::string& vec, DataType t, int i, std::ostream& os) {  // NOLINT(*)
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
-  if (t.is_int() && t.bits() == 8) {
-    os << "(0x000000ff & (" << vec << " >> " << i * 8 << "))";
+  if ((t.is_int()) && t.bits() == 8) {
+    os << "((char)(" << vec << " >> " << i * 8 << "))";
+  } else if ((t.is_uint()) && t.bits() == 8) {
+    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
   } else if (t.is_float16()) {
     os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
        << access[i % 2];
@@ -288,7 +290,7 @@ void CodeGenCUDA::PrintVecElemStore(
   this->PrintIndent();
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
-  if (t.is_int() && t.bits() == 8) {
+  if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
     stream << vec << "=";
     // Do not read the first undef lane.
     if (i != 0) {
@@ -352,6 +354,37 @@ void CodeGenCUDA::PrintStorageScope(
   }
 }
 
+void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
+  DataType from_ty = op->value.dtype();
+  DataType target_ty = op->dtype;
+  CHECK_EQ(target_ty.lanes(), from_ty.lanes());
+
+  // Emit simple C-style type conversion.
+  if (from_ty.is_scalar())
+    return CodeGenC::VisitExpr_(op, os);
+
+  // We could emit make_float4 like calls, but the emitted code looks
+  // too compact to read. Emit this as vectorized unary ops.
+  std::string sret = GetUniqueName("_");
+  this->PrintIndent();
+  this->PrintType(target_ty, stream);
+  stream << ' ' << sret << ";\n";
+  {
+    EnterScopeRAII scope(this);
+    std::string src = SSAGetID(PrintExpr(op->value), from_ty);
+    for (int i = 0, lanes = from_ty.lanes(); i < lanes; ++i) {
+      std::ostringstream val;
+      val << "(";
+      PrintType(target_ty.element_of(), val);
+      val << ")(";
+      PrintVecElemLoad(src, from_ty, i, val);
+      val << ")";
+      PrintVecElemStore(sret, target_ty, i, val.str());
+    }
+  }
+  os << sret;
+}
+
 void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
   if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
     need_mma_h_ = true;
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index c31bdf5f2d59..6ba748755d5b 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -62,6 +62,7 @@ class CodeGenCUDA final : public CodeGenC {
   void VisitExpr_(const BroadcastNode* op, std::ostream& os) final;  // NOLINT(*)
   void VisitExpr_(const FloatImmNode *op, std::ostream& os) final;
   void VisitExpr_(const CallNode *op, std::ostream& os) final;
+  void VisitExpr_(const CastNode* op, std::ostream& os) final;
   void VisitStmt_(const EvaluateNode *op) final;
   void VisitStmt_(const AllocateNode *op) final;
   void VisitStmt_(const AttrStmtNode *op) final;
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index e8c6cd1925a8..75d6c1425732 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -348,6 +348,55 @@ def test_cuda_floordiv_with_vectorization():
         func(a_nd, b_nd)
         tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)
 
+def test_vectorized_casts():
+    if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
+        print("skip because cuda is not enabled..")
+        return
+
+    def check(t0, t1):
+        if (t0 == "float16" or t1 == "float16") and not have_fp16(tvm.gpu(0).compute_version):
+            print("Skip because gpu does not have fp16 support")
+            return
+
+        # compute
+        n = 128
+        A = te.placeholder((n,), dtype=t0, name='A')
+        B = te.placeholder((n,), dtype=t1, name='B')
+        C = te.compute((n,), lambda i: A[i] + topi.cast(B[i], A.dtype), name='C')
+
+        # schedule
+        s = tvm.te.create_schedule(C.op)
+        ob, ib = s[C].split(s[C].op.axis[0], nparts=32)
+        _, iib = s[C].split(ib, factor=4)
+        s[C].vectorize(iib)
+        s[C].bind(ob, tx)
+        func = tvm.build(s, [A, B, C], "cuda")
+
+        # correctness
+        ctx = tvm.gpu(0)
+        low, high = (0, 20) if t0.startswith('u') or t1.startswith('u') else (-10, 10)
+        a_np = np.random.randint(low, high, size=n).astype(A.dtype)
+        b_np = np.random.randint(low, high, size=n).astype(B.dtype)
+        c_np = (a_np + b_np).astype(A.dtype)
+        a_nd = tvm.nd.array(a_np, ctx)
+        b_nd = tvm.nd.array(b_np, ctx)
+        c_nd = tvm.nd.array(np.zeros(c_np.shape, dtype=c_np.dtype), ctx)
+        func(a_nd, b_nd, c_nd)
+        tvm.testing.assert_allclose(c_nd.asnumpy(), c_np, rtol=1e-3)
+
+    def skip(t0, t1):
+        if t0 == t1:
+            return True
+        # CUDA does not support casts between {u}int8 and fp16; skip those pairs.
+ skip_set = {"float16", "uint8", "int8"} + if t0 in skip_set and t1 in skip_set: + return True + return False + + types = ["float16", "float32", "int8", "uint8", "int16", "uint16", "int32", "uint32"] + for t0, t1 in [(x, y) for x in types for y in types if not skip(x, y)]: + check(t0, t1) + def sched(B): s = te.create_schedule(B.op) io, ii = s[B].split(s[B].op.axis[0], nparts=1) @@ -474,6 +523,7 @@ def run_test(dtype): test_cuda_make_int8x4() test_cuda_inf_nan() test_cuda_shuffle() + test_vectorized_casts() test_cuda_reducition_binding() test_rfactor_predicates() test_cuda_const_float_to_half()