Tilelang doesn't vectorize A[i] = (fp8_e4_t)B[i] #871

@LJC00118

Current code:

// A: fp8_e4m3, B: float
for (int i = 0; i < 16; i++)
    A[i] = (fp8_e4_t)B[i];

Expected vectorized version:

// A: fp8_e4m3, B: float
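__device__ unsigned short float2_cast_to_fp8x2(const float2 x) {
    unsigned short storage;
    // Converts two floats to e4m3 and packs them into 16 bits:
    // x.y (%2) goes to the high byte, x.x (%1) to the low byte.
    asm("{cvt.rn.satfinite.e4m3x2.f32 %0, %2, %1;}\n"
        : "=h"(storage)
        : "f"(x.x), "f"(x.y));
    return storage;
}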

auto C = reinterpret_cast<unsigned short*>(A);
for (int i = 0; i < 8; i++) {
    float2 value;
    value.x = B[i * 2];
    value.y = B[i * 2 + 1];
    C[i] = float2_cast_to_fp8x2(value);
}
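
A note on the operand order in the asm above (%2 before %1): as far as I understand the PTX semantics, cvt to e4m3x2 puts the first source operand in the upper 8 bits and the second in the lower 8 bits of the 16-bit result. On a little-endian target the lower byte is stored first, so x.x has to be the second operand for A[2*i] to receive B[2*i]. Below is a minimal host-side sketch of only that packing and store step, assuming a little-endian machine. pack_e4m3x2 and store_pair are hypothetical helper names, not TileLang or CUDA APIs, and the inputs are already-encoded fp8 bytes (no float conversion is done here).

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper: mimics only the packing of cvt.*.e4m3x2.f32 d, a, b
// (a -> upper 8 bits of d, b -> lower 8 bits of d).
// The arguments are already-encoded e4m3 bytes, not floats.
inline std::uint16_t pack_e4m3x2(std::uint8_t a_bits, std::uint8_t b_bits) {
    return static_cast<std::uint16_t>((static_cast<unsigned>(a_bits) << 8) | b_bits);
}

// Hypothetical helper: stores one packed pair at element index i of a
// 16-bit view of A, like C[i] = ... in the loop above.
// Assumes a little-endian host, so the low byte lands at A[2 * i].
inline void store_pair(std::uint8_t* A, int i, std::uint16_t packed) {
    assert(i >= 0);
    std::memcpy(A + 2 * i, &packed, sizeof(packed));
}
```

For example, with 0x38 and 0x40 (the e4m3 encodings of 1.0 and 2.0), store_pair(A, 1, pack_e4m3x2(0x40, 0x38)) leaves 0x38 in A[2] and 0x40 in A[3], which is the same layout the scalar loop would produce for B[2] = 1.0 and B[3] = 2.0.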