Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -943,22 +943,43 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { // NOLINT(
// NOTE: important to print expr first
// in case each expr have their own nested expressions
// print each elements
for (const PrimExpr& vec : op->vectors) {
std::string vec_value = this->PrintExpr(vec);
if (vec.dtype().lanes() == 1) {
if (op->vectors.size() > 1) {
for (const PrimExpr& vec : op->vectors) {
std::string vec_value = this->PrintExpr(vec);
if (vec.dtype().lanes() == 1) {
concat_vec.push_back(vec_value);
} else {
// print out each element
for (int i = 0; i < vec.dtype().lanes(); ++i) {
// access i-th element of each vector
std::ostringstream vec_elem_strm;
vec_elem_strm << vec_value << "[" << i << "]";
concat_vec.push_back(vec_elem_strm.str());
}
}
}
} else {
// Extract elements from a single vector-type value.
std::string vec_value = "(" + this->PrintExpr(op->vectors[0]) + ")";
if (op->vectors[0].dtype().lanes() == 1) {
concat_vec.push_back(vec_value);
} else {
// print out each element
for (int i = 0; i < vec.dtype().lanes(); ++i) {
for (int i = 0; i < op->vectors[0].dtype().lanes(); ++i) {
// access i-th element of each vector
std::ostringstream vec_elem_strm;
vec_elem_strm << vec_value << "[" << i << "]";
PrintVecElemLoad(vec_value, op->vectors[0].dtype(), i, vec_elem_strm);
concat_vec.push_back(vec_elem_strm.str());
}
}
}
if (op->indices.size() == 1) {
// This is an extract element
CHECK(op->indices[0]->IsInstance<IntImmNode>())
<< "The ShuffleNode indices are expected to be constants at codegen time. However, "
<< "a non-constant index is " << op->indices[0]
<< ". Please avoid using ShuffleNode or eliminate the ShuffleNode with loop unroll or "
<< "vectorize.";
int64_t idx = Downcast<IntImm>(op->indices[0])->value;
ICHECK_LT(idx, concat_vec.size());
os << concat_vec[idx];
Expand All @@ -969,6 +990,11 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { // NOLINT(
os << '(';
for (size_t i = 0; i < op->indices.size(); ++i) {
if (i != 0) os << ", ";
CHECK(op->indices[i]->IsInstance<IntImmNode>())
<< "The ShuffleNode indices are expected to be constants at codegen time. However, "
<< "a non-constant index is " << op->indices[i]
<< ". Please avoid using ShuffleNode or eliminate the ShuffleNode with loop unroll or "
<< "vectorize.";
os << concat_vec[Downcast<IntImm>(op->indices[i])->value];
}
os << ')';
Expand Down
44 changes: 24 additions & 20 deletions src/target/source/literal/cuda_half_t.h
Original file line number Diff line number Diff line change
Expand Up @@ -454,26 +454,6 @@ struct __align__(8) half4_bfloat164 {
(static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16));
return result;
}
__device__ __nv_fp8x2_e5m2 make_fp8x2_e5m2(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {
__nv_fp8x2_e5m2 result;
result.__x = (x) | (y << 8);
return result;
}
__device__ __nv_fp8x4_e5m2 make_fp8x4_e5m2(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
__nv_fp8x4_e5m2 result;
result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
return result;
}
__device__ __nv_fp8x2_e4m3 make_fp8x2_e4m3(__nv_fp8_storage_t x, __nv_fp8_storage_t y) {
__nv_fp8x2_e4m3 result;
result.__x = (x) | (y << 8);
return result;
}
__device__ __nv_fp8x4_e4m3 make_fp8x4_e4m3(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) {
__nv_fp8x4_e4m3 result;
result.__x = (a) | (b << 8) | (c << 16) | (d << 24);
return result;
}
)";
}
if (enable_fp4) {
Expand Down Expand Up @@ -542,6 +522,30 @@ __host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& fp8
)";
}
}
if (enable_fp8) {
stream << R"(
__device__ __nv_fp8x2_e5m2 make___nv_fp8x2_e5m2(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y) {
__nv_fp8x2_e5m2 result;
result.__x = (x.__x) | (y.__x << 8);
return result;
}
__device__ __nv_fp8x4_e5m2 make___nv_fp8x4_e5m2(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b, __nv_fp8_e5m2 c, __nv_fp8_e5m2 d) {
__nv_fp8x4_e5m2 result;
result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24);
return result;
}
__device__ __nv_fp8x2_e4m3 make___nv_fp8x2_e4m3(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y) {
__nv_fp8x2_e4m3 result;
result.__x = (x.__x) | (y.__x << 8);
return result;
}
__device__ __nv_fp8x4_e4m3 make___nv_fp8x4_e4m3(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b, __nv_fp8_e4m3 c, __nv_fp8_e4m3 d) {
__nv_fp8x4_e4m3 result;
result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24);
return result;
}
)";
}
if (enable_fp4) {
stream << R"(
__device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 y) {
Expand Down
5 changes: 3 additions & 2 deletions src/tir/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,11 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) {
PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) {
auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); };
auto vectors = op->vectors.Map(fexpr);
if (vectors.same_as(op->vectors)) {
auto indices = op->indices.Map(fexpr);
if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) {
return GetRef<PrimExpr>(op);
} else {
return Shuffle(vectors, op->indices);
return Shuffle(vectors, indices);
}
}

Expand Down
68 changes: 67 additions & 1 deletion src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,11 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
if (value.dtype().is_scalable_vector()) {
return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, {value});
} else {
return Call(op->dtype.with_lanes(lanes), op->op, {value});
int new_lanes = (op->dtype != DataType::NVFloat4E2M1FN() &&
op->args[0].dtype() != DataType::NVFloat4E2M1FN())
? (value.dtype().bits() * value.dtype().lanes()) / op->dtype.bits()
: value.dtype().lanes();
return Call(op->dtype.with_lanes(new_lanes), op->op, {value});
}
}
}
Expand Down Expand Up @@ -624,6 +628,68 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
}
}
// Vectorize a ShuffleNode that extracts one element from a single vector.
//
// Only the restricted form with exactly one source vector and one index is
// supported; anything else aborts codegen with a CHECK failure. After the
// loop variable has been substituted into vectors/indices, the index must
// match one of two patterns (see f_check_index below) before the shuffle can
// be collapsed into a plain vector value.
PrimExpr VisitExpr_(const ShuffleNode* op) final {
  CHECK(op->vectors.size() == 1 && op->indices.size() == 1)
      << "Cannot vectorize ShuffleNode with multiple vectors or indices: the vector size is "
      << op->vectors.size() << " and the index size is " << op->indices.size();
  int lane_vectors = 0;
  int lane_indices = 0;
  // Rewrite the source vector and the index with the vectorized loop var.
  Array<PrimExpr> vectors = MutateArray(op->vectors, &lane_vectors);
  Array<PrimExpr> indices = MutateArray(op->indices, &lane_indices);
  if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) {
    // Nothing referenced the loop variable; keep the node untouched.
    return GetRef<PrimExpr>(op);
  }

  // Number of shuffle instances that fit in one vectorized iteration:
  // total vector lanes divided by the lanes of the (scalar-extracted) source.
  int new_vec_length = Downcast<IntImm>(var_lanes_)->value / op->vectors[0].dtype().lanes();
  PrimExpr updated_index = indices[0];
  // Check that the indices satisfy the specific patterns.
  auto f_check_index = [this, op](const PrimExpr& index) {
    // Allowing Ramp(0, 1, var_lanes_)
    if (const auto* ramp = index.as<RampNode>()) {
      if (ramp->base->IsInstance<IntImmNode>() && Downcast<IntImm>(ramp->base)->value == 0 &&
          ramp->stride->IsInstance<IntImmNode>() && Downcast<IntImm>(ramp->stride)->value == 1 &&
          ramp->lanes->IsInstance<IntImmNode>() &&
          Downcast<IntImm>(ramp->lanes)->value == Downcast<IntImm>(var_lanes_)->value) {
        return true;
      }
    }
    // Allowing FloorMod(Ramp(0, 1, var_lanes_), Broadcast(op->vectors[0]->lanes, var_lanes_))
    // NOTE(review): the local is named `floordiv` but it binds a FloorModNode
    // — consider renaming to `floormod` for clarity.
    if (const auto* floordiv = index.as<FloorModNode>()) {
      if (const auto* ramp = floordiv->a.as<RampNode>()) {
        if (const auto* broadcast = floordiv->b.as<BroadcastNode>()) {
          if (ramp->base->IsInstance<IntImmNode>() && Downcast<IntImm>(ramp->base)->value == 0 &&
              ramp->stride->IsInstance<IntImmNode>() &&
              Downcast<IntImm>(ramp->stride)->value == 1 &&
              ramp->lanes->IsInstance<IntImmNode>() &&
              Downcast<IntImm>(ramp->lanes)->value == Downcast<IntImm>(var_lanes_)->value &&
              broadcast->value->IsInstance<IntImmNode>() &&
              Downcast<IntImm>(broadcast->value)->value == op->vectors[0]->dtype.lanes() &&
              broadcast->lanes->IsInstance<IntImmNode>() &&
              Downcast<IntImm>(broadcast->lanes)->value == Downcast<IntImm>(var_lanes_)->value) {
            return true;
          }
        }
      }
    }

    return false;
  };
  CHECK(f_check_index(updated_index));

  if (new_vec_length == 1) {
    // The whole iteration collapses to a single shuffle instance: evaluate the
    // source vector at loop index 0.
    return tir::Substitute(op->vectors[0], {{var_, tvm::IntImm(var_->dtype, 0)}});
  } else {
    // Re-run mutation with a narrower vectorization context so the source
    // vector is produced at the reduced lane count, then restore state.
    PrimExpr prev_ramp = ramp_;
    PrimExpr prev_var_lanes = var_lanes_;
    // NOTE(review): the replacement ramp uses stride 2 (one step per
    // source-vector lane pair?) — confirm this is intended rather than
    // stride = op->vectors[0].dtype().lanes().
    ramp_ = Ramp(IntImm(var_->dtype, 0), IntImm(var_->dtype, 2), new_vec_length);
    var_lanes_ = tvm::IntImm(var_lanes_.dtype(), new_vec_length);
    lane_vectors = 0;
    vectors = MutateArray(op->vectors, &lane_vectors);
    ramp_ = prev_ramp;
    var_lanes_ = prev_var_lanes;
    return vectors[0];
  }
}
// BufferStore
Stmt VisitStmt_(const BufferStoreNode* op) final {
auto store = GetRef<BufferStore>(op);
Expand Down
79 changes: 79 additions & 0 deletions tests/python/codegen/test_target_codegen_cuda_fp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,5 +211,84 @@ def reinterpret(
)


@tvm.testing.requires_cuda_compute_version(10)
def test_e2m1_dequantize():
    """Compile-only test for fp4 (e2m1) dequantization codegen on CUDA.

    Builds two TIR variants that unpack 4-bit float4_e2m1fn values from
    packed uint32 storage into float16:

    * ``shuffle_reinterpret`` — reinterprets an 8-bit slice as a
      ``float4_e2m1fnx2`` pair, casts it to ``float16x2``, and selects one
      lane with ``T.Shuffle`` (exercises ShuffleNode vectorization).
    * ``scalar_reinterpret`` — reinterprets a single 4-bit value as
      ``float4_e2m1fn`` and casts it to ``float16``.

    Each variant is scheduled with blockIdx/threadIdx binding plus an inner
    vectorized loop, and only ``tvm.compile`` success is checked — no
    numerical comparison is performed.
    """
    n = 128

    dev = tvm.device("cuda", 0)
    target = tvm.target.Target.from_device(dev)
    # 32-bit storage words hold eight 4-bit elements each.
    num_elem_per_storage = 32 // 4

    def get_reinterpret_mod(func_type, vector_length):
        # Build and schedule one of the two dequantize variants.
        @T.prim_func
        def shuffle_reinterpret(
            A: T.Buffer((n // num_elem_per_storage,), "uint32"),
            B: T.Buffer((n,), "float16"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            for i in range(n):
                with T.block("C"):
                    v_i = T.axis.spatial(n, i)
                    T.reads(A[v_i])
                    T.writes(B[v_i])
                    # Extract the 8-bit pair containing element v_i, view it
                    # as two fp4 values, widen to float16x2, then pick the
                    # (v_i % 2)-th lane via Shuffle.
                    B[v_i] = T.Shuffle(
                        [
                            T.reinterpret(
                                "float4_e2m1fnx2",
                                T.bitwise_and(
                                    T.shift_right(
                                        A[v_i // num_elem_per_storage],
                                        ((v_i % num_elem_per_storage) // 2 * 4 * 2).astype(
                                            "uint32"
                                        ),
                                    ),
                                    T.uint32((1 << (4 * 2)) - 1),
                                ).astype("uint8"),
                            ).astype("float16x2")
                        ],
                        indices=[v_i % 2],
                    )

        @T.prim_func
        def scalar_reinterpret(
            A: T.Buffer((n // num_elem_per_storage,), "uint32"),
            B: T.Buffer((n,), "float16"),
        ):
            T.func_attr({"tir.noalias": T.bool(True)})
            for i in range(n):
                with T.block("C"):
                    v_i = T.axis.spatial(n, i)
                    T.reads(A[v_i])
                    T.writes(B[v_i])
                    # Shift the 4-bit element to the low bits, mask it, and
                    # reinterpret the byte as a single fp4 value.
                    B[v_i] = T.reinterpret(
                        "float4_e2m1fn",
                        T.bitwise_and(
                            T.shift_right(
                                A[v_i // num_elem_per_storage],
                                (v_i % num_elem_per_storage * 4).astype("uint32"),
                            ),
                            T.uint32((1 << 4) - 1),
                        ).astype("uint8"),
                    ).astype("float16")

        func = shuffle_reinterpret if func_type == "shuffle" else scalar_reinterpret
        sch = tvm.tir.Schedule(func)
        block = sch.get_block("C")
        b = sch.get_loops(block)
        # Split the flat loop into (blocks, 32 threads, vector lanes).
        bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length])
        sch.bind(bx, "blockIdx.x")
        sch.bind(tx, "threadIdx.x")
        sch.vectorize(vec)
        return sch.mod

    # We only test whether the code can be compiled.
    for func_type, vector_length in product(["shuffle", "scalar"], [1, 2, 4]):
        if func_type == "shuffle" and vector_length == 1:
            # Vectorization is necessary for the shuffle variant.
            continue
        mod = get_reinterpret_mod(func_type, vector_length)
        tvm.compile(mod, target=target)


# Allow running this test file directly (delegates to TVM's pytest wrapper).
if __name__ == "__main__":
    tvm.testing.main()
Loading