diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index 575f52e2257a..a67cb80b917b 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -943,22 +943,43 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {  // NOLINT(
   // NOTE: important to print expr first
   // in case each expr have their own nested expressions
   // print each elements
-  for (const PrimExpr& vec : op->vectors) {
-    std::string vec_value = this->PrintExpr(vec);
-    if (vec.dtype().lanes() == 1) {
+  if (op->vectors.size() > 1) {
+    for (const PrimExpr& vec : op->vectors) {
+      std::string vec_value = this->PrintExpr(vec);
+      if (vec.dtype().lanes() == 1) {
+        concat_vec.push_back(vec_value);
+      } else {
+        // print out each element
+        for (int i = 0; i < vec.dtype().lanes(); ++i) {
+          // access i-th element of each vector
+          std::ostringstream vec_elem_strm;
+          vec_elem_strm << vec_value << "[" << i << "]";
+          concat_vec.push_back(vec_elem_strm.str());
+        }
+      }
+    }
+  } else {
+    // Extract elements from a single vector-type value.
+    std::string vec_value = "(" + this->PrintExpr(op->vectors[0]) + ")";
+    if (op->vectors[0].dtype().lanes() == 1) {
       concat_vec.push_back(vec_value);
     } else {
       // print out each element
-      for (int i = 0; i < vec.dtype().lanes(); ++i) {
+      for (int i = 0; i < op->vectors[0].dtype().lanes(); ++i) {
         // access i-th element of each vector
         std::ostringstream vec_elem_strm;
-        vec_elem_strm << vec_value << "[" << i << "]";
+        PrintVecElemLoad(vec_value, op->vectors[0].dtype(), i, vec_elem_strm);
         concat_vec.push_back(vec_elem_strm.str());
       }
     }
   }
   if (op->indices.size() == 1) {
     // This is an extract element
+    CHECK(op->indices[0]->IsInstance<IntImmNode>())
+        << "The ShuffleNode indices are expected to be constants at codegen time. However, "
+        << "a non-constant index is " << op->indices[0]
+        << ". Please avoid using ShuffleNode or eliminate the ShuffleNode with loop unroll or "
+        << "vectorize.";
     int64_t idx = Downcast<IntImm>(op->indices[0])->value;
     ICHECK_LT(idx, concat_vec.size());
     os << concat_vec[idx];
@@ -969,6 +990,11 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {  // NOLINT(
     os << '(';
     for (size_t i = 0; i < op->indices.size(); ++i) {
       if (i != 0) os << ", ";
+      CHECK(op->indices[i]->IsInstance<IntImmNode>())
+          << "The ShuffleNode indices are expected to be constants at codegen time. However, "
+          << "a non-constant index is " << op->indices[i]
+          << ". Please avoid using ShuffleNode or eliminate the ShuffleNode with loop unroll or "
+          << "vectorize.";
       os << concat_vec[Downcast<IntImm>(op->indices[i])->value];
     }
     os << ')';
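
Note: the new CHECKs make an existing requirement explicit rather than adding a new one: shuffle indices must already be IntImm constants when codegen runs, and a symbolic index previously died in Downcast with an opaque error. A minimal sketch of the constraint, built with the public tvm.tir constructors (illustrative only, not part of the patch):

import tvm
from tvm import tir

vec = tir.Var("vec", "float32x4")
# Constant lane: the single-vector path above emits an element load for lane 1.
ok = tir.Shuffle([vec], [tir.IntImm("int32", 1)])
# Symbolic lane: if this survives lowering (the producing loop is neither
# unrolled nor vectorized), codegen now fails with the CHECK message above.
i = tir.Var("i", "int32")
bad = tir.Shuffle([vec], [i])
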
diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h
index b095f5b8cf20..039d89b93feb 100644
--- a/src/target/source/literal/cuda_half_t.h
+++ b/src/target/source/literal/cuda_half_t.h
@@ -454,26 +454,6 @@ struct __align__(8) half4_bfloat164 {
         (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
     return result;
   }
-  __device__ __nv_fp8x2_e5m2 make_fp8x2_e5m2(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {
-    __nv_fp8x2_e5m2 result;
-    result.__x = (x) | (y << 8);
-    return result;
-  }
-  __device__ __nv_fp8x4_e5m2 make_fp8x4_e5m2(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
-    __nv_fp8x4_e5m2 result;
-    result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
-    return result;
-  }
-  __device__ __nv_fp8x2_e4m3 make_fp8x2_e4m3(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {
-    __nv_fp8x2_e4m3 result;
-    result.__x = (x) | (y << 8);
-    return result;
-  }
-  __device__ __nv_fp8x4_e4m3 make_fp8x4_e4m3(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
-    __nv_fp8x4_e4m3 result;
-    result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
-    return result;
-  }
 )";
   }
   if (enable_fp4) {
@@ -542,6 +522,30 @@ __host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& fp8
 )";
     }
   }
+  if (enable_fp8) {
+    stream << R"(
+__device__ __nv_fp8x2_e5m2 make___nv_fp8x2_e5m2(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y) {
+  __nv_fp8x2_e5m2 result;
+  result.__x = (x.__x) | (y.__x << 8);
+  return result;
+}
+__device__ __nv_fp8x4_e5m2 make___nv_fp8x4_e5m2(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b, __nv_fp8_e5m2 c, __nv_fp8_e5m2 d) {
+  __nv_fp8x4_e5m2 result;
+  result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24);
+  return result;
+}
+__device__ __nv_fp8x2_e4m3 make___nv_fp8x2_e4m3(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y) {
+  __nv_fp8x2_e4m3 result;
+  result.__x = (x.__x) | (y.__x << 8);
+  return result;
+}
+__device__ __nv_fp8x4_e4m3 make___nv_fp8x4_e4m3(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b, __nv_fp8_e4m3 c, __nv_fp8_e4m3 d) {
+  __nv_fp8x4_e4m3 result;
+  result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24);
+  return result;
+}
+)";
+  }
   if (enable_fp4) {
     stream << R"(
 __device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 y) {
diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc
index 34b46583d5ad..3c117b58a7a3 100644
--- a/src/tir/ir/expr_functor.cc
+++ b/src/tir/ir/expr_functor.cc
@@ -279,10 +279,11 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) {
 PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) {
   auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
   auto vectors = op->vectors.Map(fexpr);
-  if (vectors.same_as(op->vectors)) {
+  auto indices = op->indices.Map(fexpr);
+  if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) {
     return GetRef<PrimExpr>(op);
   } else {
-    return Shuffle(vectors, op->indices);
+    return Shuffle(vectors, indices);
   }
 }
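
Note: mutating the indices in ExprMutator matters for the vectorizer change below, because the loop variable lives inside `indices` and substitution has to reach it. A small sketch of the behavioral difference using the public substitute helper (the example is illustrative, not part of the patch):

import tvm
from tvm import tir

v = tir.Var("v", "int32")
vec = tir.Var("x", "float16x2")
shuf = tir.Shuffle([vec], [v])
# With this fix, rewriting `v` also rewrites the shuffle index; previously
# ExprMutator left op->indices untouched, so the stale `v` survived.
rewritten = tir.stmt_functor.substitute(shuf, {v: tir.IntImm("int32", 0)})
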
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index ec290e48d457..58ce6d61742a 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -503,7 +503,11 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
       return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value});
     } else {
-      return Call(op->dtype.with_lanes(lanes), op->op, {value});
+      int new_lanes = (op->dtype != DataType::NVFloat4E2M1FN() &&
+                       op->args[0].dtype() != DataType::NVFloat4E2M1FN())
+                          ? (value.dtype().bits() * value.dtype().lanes()) / op->dtype.bits()
+                          : value.dtype().lanes();
+      return Call(op->dtype.with_lanes(new_lanes), op->op, {value});
     }
   }
 }
@@ -624,6 +628,68 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExpr&)> {
+  // Shuffle
+  PrimExpr VisitExpr_(const ShuffleNode* op) final {
+    CHECK(op->vectors.size() == 1 && op->indices.size() == 1)
+        << "Cannot vectorize ShuffleNode with multiple vectors or indices: the vector size is "
+        << op->vectors.size() << " and the index size is " << op->indices.size();
+    int lane_vectors = 0;
+    int lane_indices = 0;
+    Array<PrimExpr> vectors = MutateArray(op->vectors, &lane_vectors);
+    Array<PrimExpr> indices = MutateArray(op->indices, &lane_indices);
+    if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) {
+      return GetRef<PrimExpr>(op);
+    }
+
+    int new_vec_length = Downcast<IntImm>(var_lanes_)->value / op->vectors[0].dtype().lanes();
+    PrimExpr updated_index = indices[0];
+    // Check that the indices satisfy the specific patterns.
+    auto f_check_index = [this, op](const PrimExpr& index) {
+      // Allowing Ramp(0, 1, var_lanes_)
+      if (const auto* ramp = index.as<RampNode>()) {
+        if (ramp->base->IsInstance<IntImmNode>() && Downcast<IntImm>(ramp->base)->value == 0 &&
+            ramp->stride->IsInstance<IntImmNode>() && Downcast<IntImm>(ramp->stride)->value == 1 &&
+            ramp->lanes->IsInstance<IntImmNode>() &&
+            Downcast<IntImm>(ramp->lanes)->value == Downcast<IntImm>(var_lanes_)->value) {
+          return true;
+        }
+      }
+      // Allowing FloorMod(Ramp(0, 1, var_lanes_), Broadcast(op->vectors[0]->lanes, var_lanes_))
+      if (const auto* floordiv = index.as<FloorModNode>()) {
+        if (const auto* ramp = floordiv->a.as<RampNode>()) {
+          if (const auto* broadcast = floordiv->b.as<BroadcastNode>()) {
+            if (ramp->base->IsInstance<IntImmNode>() && Downcast<IntImm>(ramp->base)->value == 0 &&
+                ramp->stride->IsInstance<IntImmNode>() &&
+                Downcast<IntImm>(ramp->stride)->value == 1 &&
+                ramp->lanes->IsInstance<IntImmNode>() &&
+                Downcast<IntImm>(ramp->lanes)->value == Downcast<IntImm>(var_lanes_)->value &&
+                broadcast->value->IsInstance<IntImmNode>() &&
+                Downcast<IntImm>(broadcast->value)->value == op->vectors[0]->dtype.lanes() &&
+                broadcast->lanes->IsInstance<IntImmNode>() &&
+                Downcast<IntImm>(broadcast->lanes)->value == Downcast<IntImm>(var_lanes_)->value) {
+              return true;
+            }
+          }
+        }
+      }
+
+      return false;
+    };
+    CHECK(f_check_index(updated_index));
+
+    if (new_vec_length == 1) {
+      return tir::Substitute(op->vectors[0], {{var_, tvm::IntImm(var_->dtype, 0)}});
+    } else {
+      PrimExpr prev_ramp = ramp_;
+      PrimExpr prev_var_lanes = var_lanes_;
+      ramp_ = Ramp(IntImm(var_->dtype, 0), IntImm(var_->dtype, 2), new_vec_length);
+      var_lanes_ = tvm::IntImm(var_lanes_.dtype(), new_vec_length);
+      lane_vectors = 0;
+      vectors = MutateArray(op->vectors, &lane_vectors);
+      ramp_ = prev_ramp;
+      var_lanes_ = prev_var_lanes;
+      return vectors[0];
+    }
+  }
+
   // BufferStore
   Stmt VisitStmt_(const BufferStoreNode* op) final {
     auto store = GetRef<BufferStore>(op);
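
Note: the vectorizer only accepts single-vector, single-index shuffles, and `f_check_index` pins the vectorized index to one of two shapes. A sketch of those two accepted index expressions, built with the public tvm.tir constructors (illustrative, not part of the patch):

from tvm import tir

lanes = 4  # extent of the vectorized loop
# Shape 1: the identity permutation Ramp(0, 1, lanes).
idx_ramp = tir.Ramp(tir.IntImm("int32", 0), tir.IntImm("int32", 1), lanes)
# Shape 2: FloorMod(Ramp(0, 1, lanes), Broadcast(inner_lanes, lanes)), which
# cycles through lanes 0..inner_lanes-1, e.g. both halves of an fp4x2 pair.
inner_lanes = 2
idx_mod = tir.FloorMod(idx_ramp, tir.Broadcast(tir.IntImm("int32", inner_lanes), lanes))
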
T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(n): + with T.block("C"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.Shuffle( + [ + T.reinterpret( + "float4_e2m1fnx2", + T.bitwise_and( + T.shift_right( + A[v_i // num_elem_per_storage], + ((v_i % num_elem_per_storage) // 2 * 4 * 2).astype( + "uint32" + ), + ), + T.uint32((1 << (4 * 2)) - 1), + ).astype("uint8"), + ).astype("float16x2") + ], + indices=[v_i % 2], + ) + + @T.prim_func + def scalar_reinterpret( + A: T.Buffer((n // num_elem_per_storage,), "uint32"), + B: T.Buffer((n,), "float16"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(n): + with T.block("C"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.reinterpret( + "float4_e2m1fn", + T.bitwise_and( + T.shift_right( + A[v_i // num_elem_per_storage], + (v_i % num_elem_per_storage * 4).astype("uint32"), + ), + T.uint32((1 << 4) - 1), + ).astype("uint8"), + ).astype("float16") + + func = shuffle_reinterpret if func_type == "shuffle" else scalar_reinterpret + sch = tvm.tir.Schedule(func) + block = sch.get_block("C") + b = sch.get_loops(block) + bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + return sch.mod + + # We only test the whether the code can be compiled. + for func_type, vector_length in product(["shuffle", "scalar"], [1, 2, 4]): + if func_type == "shuffle" and vector_length == 1: + # Vectorize is necessary for shuffle. + continue + mod = get_reinterpret_mod(func_type, vector_length) + tvm.compile(mod, target=target) + + if __name__ == "__main__": tvm.testing.main()