1 change: 1 addition & 0 deletions python/tvm/relax/backend/cuda/cublas.py
@@ -43,6 +43,7 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
(lhs_dtype == "float16" and rhs_dtype == "float16")
or (lhs_dtype == "float32" and rhs_dtype == "float32")
or (lhs_dtype == "int8" and rhs_dtype == "int8")
or (lhs_dtype == "bfloat16" and rhs_dtype == "bfloat16")
)


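For context, the full predicate after this one-line addition reads as below. This is a standalone sketch reconstructed from the hunk (the surrounding pattern-registration code is not shown in the diff): matching bfloat16 operands are now treated as cuBLAS-offloadable, alongside the existing float16, float32, and int8 pairs.

```python
# Standalone copy of the updated check, reconstructed from this hunk:
# only same-dtype operand pairs are offloaded, and bfloat16 is now one of them.
def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
    return (
        (lhs_dtype == "float16" and rhs_dtype == "float16")
        or (lhs_dtype == "float32" and rhs_dtype == "float32")
        or (lhs_dtype == "int8" and rhs_dtype == "int8")
        or (lhs_dtype == "bfloat16" and rhs_dtype == "bfloat16")
    )

assert _is_supported_dtype("bfloat16", "bfloat16", "float32")
assert not _is_supported_dtype("bfloat16", "float16", "float32")
```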
7 changes: 6 additions & 1 deletion src/runtime/contrib/cublas/cublas.cc
@@ -125,7 +125,8 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
return TypeMatch(in_dtype, kDLInt, 8);
} else if (TypeMatch(out_dtype, kDLFloat, 32)) {
return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16);
return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16) ||
TypeMatch(in_dtype, kDLBfloat, 16);
} else {
return false;
}
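As a quick reference, here is a Python paraphrase of the mixed-precision combinations this function accepts after the change (illustrative only; the actual check is the C++ function above operating on DLDataType values):

```python
# Python paraphrase of CheckMixPrecisionType after this change (illustrative,
# not the C++ API): which input dtypes may produce each output dtype.
def check_mix_precision(in_dtype: str, out_dtype: str, int_support: bool = True) -> bool:
    if int_support and out_dtype == "int32":
        return in_dtype == "int8"
    if out_dtype == "float32":
        # bfloat16 inputs are newly allowed to accumulate into float32.
        return in_dtype in ("int8", "float16", "bfloat16")
    return False

assert check_mix_precision("bfloat16", "float32")
assert not check_mix_precision("bfloat16", "int32")
```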
@@ -162,6 +163,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,

if (TypeMatch(A->dtype, kDLFloat, 16)) {
ab_type = CUDA_R_16F;
} else if (TypeMatch(A->dtype, kDLBfloat, 16)) {
ab_type = CUDA_R_16BF;
} else if (TypeMatch(A->dtype, kDLInt, 8)) {
ab_type = CUDA_R_8I;
} else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
@@ -171,6 +174,8 @@

if (TypeMatch(C->dtype, kDLFloat, 16)) {
c_type = CUDA_R_16F;
} else if (TypeMatch(C->dtype, kDLBfloat, 16)) {
c_type = CUDA_R_16BF;
} else if (TypeMatch(C->dtype, kDLInt, 32)) {
c_type = CUDA_R_32I;
compute_type = CUBLAS_COMPUTE_32I;
5 changes: 5 additions & 0 deletions src/runtime/contrib/cublas/cublas_utils.h
@@ -116,6 +116,11 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
case 64:
return CUDA_R_64F;
}
} else if (type.code == kDLBfloat) {
switch (type.bits) {
case 16:
return CUDA_R_16BF;
}
}
LOG(FATAL) << "Unsupported cuda type";
}
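Taken together, the two runtime changes teach the cuBLASLt path how to describe bfloat16 buffers: the operand (A/B) and output (C) descriptors in CallCublasLt, plus the generic GetCudaDataType helper, now resolve kDLBfloat with 16 bits to CUDA_R_16BF. The mapping visible in these hunks is summarized below (Python-style, for illustration only; the fp8 branch is collapsed in the diff view and omitted here):

```python
# Summary (illustration only) of the DLPack-dtype -> CUDA type-enum mapping that
# appears in the hunks above; the C++ code branches on (type code, bit width).
CUDA_TYPE_OF = {
    ("float", 16): "CUDA_R_16F",
    ("bfloat", 16): "CUDA_R_16BF",  # new in this PR
    ("int", 8): "CUDA_R_8I",
    ("int", 32): "CUDA_R_32I",      # also switches compute_type to CUBLAS_COMPUTE_32I
    ("float", 64): "CUDA_R_64F",
}
```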
41 changes: 41 additions & 0 deletions tests/python/relax/test_codegen_cublas.py
@@ -393,6 +393,47 @@ def test_matmul_fp8_multiply_offload():
tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)


@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
@pytest.mark.parametrize(
"x_shape, y_shape, transpose_y, out_dtype",
[
((10, 32), (64, 32), True, "float32"),
((32, 16), (32, 16), True, "float32"),
((2, 10, 32), (2, 64, 32), True, "float32"),
],
)
def test_matmul_bfloat16_offload(
x_shape,
y_shape,
transpose_y,
out_dtype,
):
in_dtype = "bfloat16"
mod = get_relax_matmul_module(
x_shape,
y_shape,
in_dtype,
out_dtype,
bias_shape=None,
transposed_y=transpose_y,
activation=None,
)
# Generate input data in float32 and then convert to bfloat16 using ml_dtypes.
x_float32 = np.random.uniform(low=0, high=5, size=x_shape).astype("float32")
y_float32 = np.random.uniform(low=0, high=5, size=y_shape).astype("float32")
x_bf16 = ml_dtypes.bfloat16(x_float32)
y_bf16 = ml_dtypes.bfloat16(y_float32)

# For the reference result, adjust y (if needed) in float32.
z = np.swapaxes(y_float32, -2, -1) if transpose_y else y_float32
args = (x_bf16, y_bf16)

out = get_result_with_relax_cublas_offload(mod, args)
ref_out = np.matmul(x_float32, z).astype(out_dtype)

tvm.testing.assert_allclose(out, ref_out, rtol=1e-2, atol=1e-2)
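A note on the tolerances above: bfloat16 keeps float32's 8-bit exponent but carries only an 8-bit significand (1 implicit + 7 stored), so each rounded element has a relative error of up to about 2^-8 ≈ 0.4%. The sketch below quantifies that round-trip, assuming ml_dtypes is installed (the same dependency the test skips on):

```python
# Minimal sketch (assumes ml_dtypes is installed): quantify the rounding error
# introduced by the float32 -> bfloat16 conversion used in the test above.
import ml_dtypes
import numpy as np

x = np.random.uniform(low=0, high=5, size=(10, 32)).astype("float32")
x_bf16 = x.astype(ml_dtypes.bfloat16)  # round to nearest bfloat16
x_back = x_bf16.astype("float32")      # widen back for comparison

# Unit roundoff for an 8-bit significand is 2**-8, which is why the test above
# uses rtol/atol of 1e-2 rather than the 1e-3 used by the fp8 test.
assert np.all(np.abs(x_back - x) <= np.abs(x) * 2.0**-8)
```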


@pytest.mark.parametrize(
"M, N, K, out_dtype, transposed_y, partition_done",
[