
[Bug] Precision issue when working with floating point constants (FloatImm) #17276

@SerodioJ

Description

I have been working with TVM+Ansor (auto-scheduler) to generate code for a set of operators for both CPU (LLVM backend) and GPU (CUDA backend).
The operators use trigonometric functions in some steps, and I set the value of PI with pi_const = te.const(np.pi, X.dtype).
I noticed that the CPU and GPU results were diverging.

While looking for the source of the discrepancy in my code, I found that COS and SIN returned different values on the two targets, which led me to suspect the scheduling or code generation steps.

To rule out Ansor's schedule exploration as the cause, I tested similar operators with AutoTVM, and the same problem appeared.

The only remaining suspect was the code generation pipeline, so I read through the codegen source and found what I believe is the root cause of this behavior.
When generating CUDA code, FloatImm nodes are handled as shown in this snippet from codegen_cuda.cc:

// Type code is kFloat
  switch (op->dtype.bits()) {
    case 64:
    case 32: {
      std::ostringstream temp;
      if (std::isinf(op->value)) {
        if (op->value < 0) {
          temp << "-";
        }
        temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF");
        p->need_math_constants_h_ = true;
      } else if (std::isnan(op->value)) {
        temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN");
        p->need_math_constants_h_ = true;
      } else {
        temp << std::scientific << op->value;
        if (op->dtype.bits() == 32) temp << 'f';
      }
      p->MarkConst(temp.str());
      os << temp.str();
      break;
    }
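To make the precision loss concrete, here is a small Python sketch that mirrors the branch above. emit_float_literal is a hypothetical helper, not a TVM API. It relies on the fact that std::scientific at the stream's default precision of 6 formats a double the same way as C's "%e", which Python's % operator also follows.

```python
import math


def emit_float_literal(value: float, bits: int) -> str:
    """Hypothetical Python mirror of the quoted case 64 / case 32 branch."""
    if math.isinf(value):
        sign = "-" if value < 0 else ""
        return sign + ("CUDART_INF_F" if bits == 32 else "CUDART_INF")
    if math.isnan(value):
        return "CUDART_NAN_F" if bits == 32 else "CUDART_NAN"
    # std::scientific with the default precision (6 digits after the point)
    # produces the same text as printf-style "%e".
    text = "%e" % value
    return (text + "f") if bits == 32 else text


literal = emit_float_literal(math.pi, 64)
print(literal)                    # the text that ends up in the CUDA source
print(float(literal) == math.pi)  # False: the literal does not round-trip
print(math.pi - float(literal))   # about -3.464e-07
```

Both bit widths go through the same six-digit formatting, so the float64 constant carries only about seven significant digits into the kernel.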

So float32 and float64 (double) constants take the same path: both are written with std::scientific at the stream's default precision of 6 digits after the decimal point, so a value such as 3.141592653589793 is reduced to 3.141593e+00. This precision loss in the string conversion of the generated CUDA source is what causes my problem. Since case 64 falls through to case 32, I tried giving it its own branch with something like

temp << std::fixed << std::setprecision(15) << op->value;

and the results start to converge.
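The change above fixes np.pi, but fixed notation with 15 digits after the decimal point is not round-trip safe in general: very small magnitudes collapse to zero, and a value like 1/3 keeps only 15 significant digits. The sketch below compares it, using the equivalent printf-style formats in Python, with scientific notation at 17 significant digits (std::numeric_limits<double>::max_digits10), which always parses back to the same double. The function names are hypothetical and this only illustrates the formatting choice; it is not a patch.

```python
import math


def fixed15(value: float) -> str:
    # Analogous to: temp << std::fixed << std::setprecision(15) << value;
    return "%.15f" % value


def scientific17(value: float) -> str:
    # Analogous to: temp << std::scientific << std::setprecision(16) << value;
    # 1 digit before the point + 16 after = 17 significant digits (max_digits10).
    return "%.16e" % value


samples = [math.pi, 1.0 / 3.0, 1e-20, 0.015625, 6.02214076e23]
for x in samples:
    print(x,
          fixed15(x), float(fixed15(x)) == x,
          scientific17(x), float(scientific17(x)) == x)
```

In C++ the same idea would be std::scientific combined with std::setprecision(std::numeric_limits<double>::max_digits10 - 1), since in scientific mode the precision counts digits after the decimal point.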

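I believe this issue also affects other source-emitting backends such as C, whereas the LLVM backend is not affected (presumably because it builds the constant directly in the IR instead of printing it as decimal text).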

Expected behavior

COS and SIN values from the CPU (LLVM) and GPU (CUDA) builds should agree closely.

Actual behavior

Divergent values: only the CPU results match the NumPy ground truth.
The output below is produced by the code listed under Steps to reproduce.

####################################################### SIN #######################################################
LLVM - Ground Truth: MAX = 2.220446049250313e-16 MEAN = 6.754579534584693e-17 STD = 5.829086717290594e-17
CUDA - Ground Truth: MAX = 3.405868034614401e-07 MEAN = 1.0753732889140066e-07 STD = 1.0267789908758974e-07
LLVM - CUDA        : MAX = 3.405868034545012e-07 MEAN = 1.0753732889064171e-07 STD = 1.0267789909304487e-07
SIN of PI/2: LLVM = 1.0 | CUDA = 0.999999999999985 | Ground Truth = 1.0
####################################################### COS #######################################################
LLVM - Ground Truth: MAX = 2.220446049250313e-16 MEAN = 6.754579534584693e-17 STD = 6.676141599144182e-17
CUDA - Ground Truth: MAX = 2.0061242439473048e-07 MEAN = 1.102436397725252e-07 STD = 6.933040331236782e-08
LLVM - CUDA        : MAX = 2.0061242439473048e-07 MEAN = 1.1024363976526104e-07 STD = 6.933040330991245e-08
COS of PI/2: LLVM = 6.123233995736766e-17 | CUDA = -1.7320510330969933e-07 | Ground Truth = 6.123233995736766e-17

For my use case, the main problem is COS of PI/2, which comes out negative. This value matches np.cos(3.141593/2), where 3.141593 is the value the constant is rounded to when it is printed in scientific notation (3.141593e+00).
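A quick worked check of that claim, using only Python's math module: for X = 32 the kernel computes (32 * 3.141593) * 0.015625, and because both factors scale by powers of two this equals 3.141593 / 2 exactly. With d = 3.141593/2 - pi/2 (about 1.73e-7), cos(pi/2 + d) = -sin(d) is about -d and sin(pi/2 + d) = cos(d) is about 1 - d^2/2, which is where the CUDA values in the table come from. The name TRUNCATED_PI is only for illustration.

```python
import math

TRUNCATED_PI = float("3.141593e+00")  # the literal that appears in the CUDA kernel

# Argument the kernel computes for X = 32; both scalings are exact powers of two.
arg = (32.0 * TRUNCATED_PI) * 0.015625
print(arg == TRUNCATED_PI / 2)  # exact equality holds
print(math.cos(arg))            # about -1.732e-07, compare the CUDA column above
print(math.sin(arg))            # about 0.999999999999985
print(math.cos(math.pi / 2))    # the LLVM / ground-truth value
```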

Steps to reproduce

Here are a couple of simplified modules that reproduce this issue:

import tvm
import numpy as np
from tvm.script import ir as I
from tvm.script import tir as T

# Modules definition
@I.ir_module
class CPU_Module:
    @T.prim_func
    def main(X: T.Buffer((64,), "float64"), sin: T.Buffer((64,), "float64"), cos: T.Buffer((64,), "float64")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        X_1 = T.Buffer((64,), "float64", data=X.data)
        for i in T.parallel(64):
            sin_1 = T.Buffer((64,), "float64", data=sin.data)
            sin_1[i] = T.sin(X_1[i] * T.float64(np.pi) * T.float64(0.015625))
            cos_1 = T.Buffer((64,), "float64", data=cos.data)
            cos_1[i] = T.cos(X_1[i] * T.float64(np.pi) * T.float64(0.015625))
            
@I.ir_module
class GPU_Module:
    @T.prim_func
    def main(X: T.Buffer((64,), "float64"), sin: T.Buffer((64,), "float64"), cos: T.Buffer((64,), "float64")):
        T.func_attr({"from_legacy_te_schedule": T.bool(True), "tir.noalias": T.bool(True)})
        with T.launch_thread("blockIdx.x", 2) as blockIdx_x:
            threadIdx_x = T.launch_thread("threadIdx.x", 32)
            X_1 = T.Buffer((64,), "float64", data=X.data)
            sin_1 = T.Buffer((64,), "float64", data=sin.data)
            sin_1[blockIdx_x*32 + threadIdx_x] = T.sin(X_1[blockIdx_x*32 + threadIdx_x] * T.float64(np.pi) * T.float64(0.015625))
            cos_1 = T.Buffer((64,), "float64", data=cos.data)
            cos_1[blockIdx_x*32 + threadIdx_x] = T.cos(X_1[blockIdx_x*32 + threadIdx_x] * T.float64(np.pi) * T.float64(0.015625))

# Module build
llvm_module = tvm.build(CPU_Module, target='llvm')
cuda_module = tvm.build(GPU_Module, target='cuda')

# Input/Output definition
data = np.arange(64).astype("float64")

cpu = tvm.cpu()
gpu = tvm.cuda()

cpu_input = tvm.nd.array(data, device=cpu)
gpu_input = tvm.nd.array(data, device=gpu)

llvm_output = [tvm.nd.empty(data.shape, dtype=data.dtype, device=cpu) for _ in range(2)]
cuda_output = [tvm.nd.empty(data.shape, dtype=data.dtype, device=gpu) for _ in range(2)]

# Modules execution
llvm_module(cpu_input, *llvm_output)
cuda_module(gpu_input, *cuda_output)

# Comparison
expected = [np.sin((data*np.pi)/64), np.cos((data*np.pi)/64)]
for i, op in enumerate(["SIN", "COS"]):
    llvm = llvm_output[i].numpy()
    cuda = cuda_output[i].numpy()
    gt = expected[i]
    llvm_gt = np.abs(llvm - gt)
    cuda_gt = np.abs(cuda - gt)
    llvm_cuda = np.abs(llvm - cuda)
    print(f"{55*'#'} {op} {55*'#'}")
    print(f"LLVM - Ground Truth: MAX = {llvm_gt.max()} MEAN = {llvm_gt.mean()} STD = {llvm_gt.std()}")
    print(f"CUDA - Ground Truth: MAX = {cuda_gt.max()} MEAN = {cuda_gt.mean()} STD = {cuda_gt.std()}")
    print(f"LLVM - CUDA        : MAX = {llvm_cuda.max()} MEAN = {llvm_cuda.mean()} STD = {llvm_cuda.std()}")
    print(f"{op} of PI/2: LLVM = {llvm[32]} | CUDA = {cuda[32]} | Ground Truth = {gt[32]}")
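If the truncated literal is the whole story, the CUDA - Ground Truth statistics above should be reproducible on the CPU with NumPy alone, by substituting 3.141593e+00 for np.pi in the kernel's expression. The sketch below does that without TVM or a GPU. Any leftover difference would come from NumPy's sin/cos versus CUDA's device math functions, which may differ in the last bit or so; at the magnitudes involved that is far below the reported errors.

```python
import numpy as np

data = np.arange(64).astype("float64")
ground_truth = {"SIN": np.sin((data * np.pi) / 64), "COS": np.cos((data * np.pi) / 64)}

# Same expression and evaluation order as the CUDA kernel, with the literal it actually contains.
emulated_arg = (data * 3.141593e+00) * 1.562500e-02
emulated = {"SIN": np.sin(emulated_arg), "COS": np.cos(emulated_arg)}

errors = {}
for name in ("SIN", "COS"):
    errors[name] = np.abs(emulated[name] - ground_truth[name])
    print(f"emulated CUDA - Ground Truth ({name}): "
          f"MAX = {errors[name].max()} MEAN = {errors[name].mean()}")
```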

The generated CUDA code is listed below; it shows that 3.141592653589793 (np.pi) is emitted as 3.141593e+00.

#if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \
     (__CUDACC_VER_MAJOR__ > 11))
#define TVM_ENABLE_L2_PREFETCH 1
#else
#define TVM_ENABLE_L2_PREFETCH 0
#endif

#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(32) main_kernel(double* __restrict__ X_1, double* __restrict__ cos_1, double* __restrict__ sin_1) {
  sin_1[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] = sin(((X_1[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] * 3.141593e+00) * 1.562500e-02));
  cos_1[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] = cos(((X_1[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] * 3.141593e+00) * 1.562500e-02));
}
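When checking other kernels (or the C backend, which may be affected too), it can help to scan the emitted source for floating point literals and compare them against the constants the TIR was built with. This is a rough sketch that assumes literals look like std::scientific output with an optional f suffix; audit_literals is a hypothetical helper, the regex will not catch every literal form, and obtaining the source string is left out.

```python
import math
import re

# Literals in the form std::scientific prints them, with an optional float32 'f' suffix.
FLOAT_LITERAL = re.compile(r"[-+]?\d+\.\d+e[-+]\d+f?")


def audit_literals(source, intended):
    """Hypothetical helper: map each literal to (parsed value, nearest intended constant, abs error)."""
    report = {}
    for literal in FLOAT_LITERAL.findall(source):
        value = float(literal.rstrip("f"))
        nearest = min(intended, key=lambda c: abs(c - value))
        report[literal] = (value, nearest, abs(nearest - value))
    return report


kernel_line = (
    "sin_1[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] = "
    "sin(((X_1[((((int)blockIdx.x) * 32) + ((int)threadIdx.x))] * 3.141593e+00) * 1.562500e-02));"
)
report = audit_literals(kernel_line, [math.pi, 0.015625])
for literal, (value, nearest, error) in report.items():
    print(f"{literal}: parsed {value!r}, intended {nearest!r}, error {error:.3e}")
```

Here 1.562500e-02 is exact (it is 1/64), so only the pi literal shows a nonzero error.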

Triage

  • needs-triage
  • backend: c
  • backend: cuda

Labels: needs-triage, type: bug
