diff --git a/python/tvm/relax/backend/cuda/cublas.py b/python/tvm/relax/backend/cuda/cublas.py
index 6828381e68e1..f8621d9b5621 100644
--- a/python/tvm/relax/backend/cuda/cublas.py
+++ b/python/tvm/relax/backend/cuda/cublas.py
@@ -43,6 +43,7 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype, out_dtype):
         (lhs_dtype == "float16" and rhs_dtype == "float16")
         or (lhs_dtype == "float32" and rhs_dtype == "float32")
         or (lhs_dtype == "int8" and rhs_dtype == "int8")
+        or (lhs_dtype == "bfloat16" and rhs_dtype == "bfloat16")
     )
diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc
index ba01f791d98a..3fbda3ac945d 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -125,7 +125,8 @@ bool CheckMixPrecisionType(DLDataType in_dtype, DLDataType out_dtype, bool int_s
   if (int_support && TypeMatch(out_dtype, kDLInt, 32)) {
     return TypeMatch(in_dtype, kDLInt, 8);
   } else if (TypeMatch(out_dtype, kDLFloat, 32)) {
-    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16);
+    return TypeMatch(in_dtype, kDLInt, 8) || TypeMatch(in_dtype, kDLFloat, 16) ||
+           TypeMatch(in_dtype, kDLBfloat, 16);
   } else {
     return false;
   }
@@ -162,6 +163,8 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
   if (TypeMatch(A->dtype, kDLFloat, 16)) {
     ab_type = CUDA_R_16F;
+  } else if (TypeMatch(A->dtype, kDLBfloat, 16)) {
+    ab_type = CUDA_R_16BF;
   } else if (TypeMatch(A->dtype, kDLInt, 8)) {
     ab_type = CUDA_R_8I;
   } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
@@ -171,6 +174,8 @@
   if (TypeMatch(C->dtype, kDLFloat, 16)) {
     c_type = CUDA_R_16F;
+  } else if (TypeMatch(C->dtype, kDLBfloat, 16)) {
+    c_type = CUDA_R_16BF;
   } else if (TypeMatch(C->dtype, kDLInt, 32)) {
     c_type = CUDA_R_32I;
     compute_type = CUBLAS_COMPUTE_32I;
diff --git a/src/runtime/contrib/cublas/cublas_utils.h b/src/runtime/contrib/cublas/cublas_utils.h
index 387065093eaa..3e9ded08deb1 100644
--- a/src/runtime/contrib/cublas/cublas_utils.h
+++ b/src/runtime/contrib/cublas/cublas_utils.h
@@ -116,6 +116,11 @@ inline cudaDataType_t GetCudaDataType(DLDataType type) {
       case 64:
         return CUDA_R_64F;
     }
+  } else if (type.code == kDLBfloat) {
+    switch (type.bits) {
+      case 16:
+        return CUDA_R_16BF;
+    }
   }
   LOG(FATAL) << "Unsupported cuda type";
 }
diff --git a/tests/python/relax/test_codegen_cublas.py b/tests/python/relax/test_codegen_cublas.py
index dbcb25b69d52..152f04fc3ce7 100644
--- a/tests/python/relax/test_codegen_cublas.py
+++ b/tests/python/relax/test_codegen_cublas.py
@@ -393,6 +393,47 @@ def test_matmul_fp8_multiply_offload():
     tvm.testing.assert_allclose(out, ref, rtol=1e-3, atol=1e-3)


+@pytest.mark.skipif(ml_dtypes is None, reason="requires ml_dtypes to be installed")
+@pytest.mark.parametrize(
+    "x_shape, y_shape, transpose_y, out_dtype",
+    [
+        ((10, 32), (64, 32), True, "float32"),
+        ((32, 16), (32, 16), True, "float32"),
+        ((2, 10, 32), (2, 64, 32), True, "float32"),
+    ],
+)
+def test_matmul_bfloat16_offload(
+    x_shape,
+    y_shape,
+    transpose_y,
+    out_dtype,
+):
+    in_dtype = "bfloat16"
+    mod = get_relax_matmul_module(
+        x_shape,
+        y_shape,
+        in_dtype,
+        out_dtype,
+        bias_shape=None,
+        transposed_y=transpose_y,
+        activation=None,
+    )
+    # Generate input data in float32 and then convert to bfloat16 using ml_dtypes.
+    x_float32 = np.random.uniform(low=0, high=5, size=x_shape).astype("float32")
+    y_float32 = np.random.uniform(low=0, high=5, size=y_shape).astype("float32")
+    x_bf16 = ml_dtypes.bfloat16(x_float32)
+    y_bf16 = ml_dtypes.bfloat16(y_float32)
+
+    # For the reference result, adjust y (if needed) in float32.
+    z = np.swapaxes(y_float32, -2, -1) if transpose_y else y_float32
+    args = (x_bf16, y_bf16)
+
+    out = get_result_with_relax_cublas_offload(mod, args)
+    ref_out = np.matmul(x_float32, z).astype(out_dtype)
+
+    tvm.testing.assert_allclose(out, ref_out, rtol=1e-2, atol=1e-2)
+
+
 @pytest.mark.parametrize(
     "M, N, K, out_dtype, transposed_y, partition_done",
     [