diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc
index 4454dd319768..defc94efa28f 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -1615,13 +1615,17 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p)
   // Type code is kBFloat
   if (op->dtype.is_bfloat16()) {
     os << "__float2bfloat16_rn";
-    os << '(' << std::scientific << op->value << 'f' << ')';
+    os << '(' << std::hexfloat << op->value << 'f';
+    os << "/*" << std::scientific << op->value << "*/";
+    os << ')';
     return;
   }
   // Type code is kFloat8_e5m2 or kE4M4Float
   if (op->dtype.is_float8() || op->dtype.is_float4()) {
     p->PrintType(op->dtype, os);
-    os << '(' << std::scientific << op->value << 'f' << ')';
+    os << '(' << std::hexfloat << op->value << 'f';
+    os << "/*" << std::scientific << op->value << "*/";
+    os << ')';
     return;
   }
   // Type code is kFloat
@@ -1656,7 +1660,8 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p)
       temp << "CUDART_NAN_F";
       p->need_math_constants_h_ = true;
     } else {
-      temp << std::scientific << op->value << 'f';
+      temp << std::hexfloat << op->value << 'f';
+      temp << "/*" << std::scientific << op->value << "*/";
     }
     p->MarkConst(temp.str());
     os << temp.str();
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index db49f56045ad..0841d0f54562 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -801,6 +801,25 @@ def main(
     assert 'extern "C" __device__ float add(float a, float b) {\n  return (a + b);\n}' in cuda_code
 
 
+@tvm.testing.requires_cuda
+def test_cuda_float_const_hex_format():
+    """Test that float constants are emitted in hexadecimal format for precision"""
+
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(
+            A: T.Buffer((1024, 1024), "float32"),
+        ):
+            for bx in T.thread_binding(1024, "blockIdx.x"):
+                for tx in T.thread_binding(1024, "threadIdx.x"):
+                    A[bx, tx] = T.float32(1 / 27)
+
+    lib = tvm.compile(Module, target="cuda")
+    cuda_code = lib.mod.imports[0].inspect_source()
+    assert "0x1.2f684bda12f68p-5f" in cuda_code
+
+
 @tvm.testing.requires_cuda
 def test_device_host_call_same_func():
     @I.ir_module
diff --git a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
index 67598b0ba04f..aa4f5138a17f 100644
--- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py
@@ -264,8 +264,8 @@ def test_inject_async_copy_barrier():
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) {
   __shared__ float A_shared[64];
   __shared__ float B_shared[64];
-  A_shared[((int)threadIdx.x)] = 0.000000e+00f;
-  B_shared[((int)threadIdx.x)] = 0.000000e+00f;
+  A_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
+  B_shared[((int)threadIdx.x)] = 0x0p+0f/*0.000000e+00*/;
   __asm__ __volatile__("cp.async.commit_group;");
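For context, a minimal standalone sketch (not part of the patch, and independent of TVM) of why the change prefers `std::hexfloat` over `std::scientific` when emitting float constants: scientific notation at the default stream precision keeps only six significant digits, so the emitted literal does not parse back to the original value, whereas a hexfloat literal reproduces the significand bit for bit. The `1.0 / 27.0` constant below is the same one asserted in `test_cuda_float_const_hex_format` above.

```cpp
#include <iostream>
#include <sstream>
#include <string>

int main() {
  double v = 1.0 / 27.0;  // same constant as in test_cuda_float_const_hex_format

  // Scientific notation at the default precision keeps 6 significant digits,
  // so the printed literal does not round-trip back to v exactly.
  std::ostringstream sci;
  sci << std::scientific << v;  // "3.703704e-02"
  std::cout << sci.str() << " round-trips: "
            << (std::stod(sci.str()) == v) << "\n";  // prints 0

  // Hexfloat prints the exact binary significand, so the literal parses back
  // to the identical double; this is the "0x1.2f684bda12f68p-5" in the test.
  std::ostringstream hex;
  hex << std::hexfloat << v;
  std::cout << hex.str() << " round-trips: "
            << (std::stod(hex.str()) == v) << "\n";  // prints 1
}
```

This also shows why the patch appends the scientific form in a `/*...*/` comment: the hexfloat literal is exact but hard to read, so the comment keeps the generated CUDA source human-scannable.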