[Bug] no suitable user-defined conversion from "int2" to "longlong2" exists #13379

Closed
comaniac opened this issue Nov 14, 2022 · 0 comments · Fixed by #13382
Labels
needs-triage (PRs or issues that need to be investigated by maintainers to find the right assignees to address it), type: bug

Comments

@comaniac
Contributor

PR #13317 introduces a bug when generating CUDA kernels for index-based ops (e.g., gather).
After some investigation, the failure turns out to depend on the input shape: in the unit test below, the error only happens when the last dimension is an even number (e.g., 2).
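For the failing shape (4, 7, 2), the generated kernel in the log below runs 28 active threads, and each thread gathers two adjacent float16 outputs. As a rough illustration (not TVM code), the NumPy sketch below mirrors that per-thread index arithmetic, naming the intermediates after the kernel's `__4`, `__3`, `__2`, `__1`, and checks the result against `np.take_along_axis`. It assumes `axis=0` and a last dimension of exactly 2, as in the failing case; the final step adds the per-lane int64 offsets (0, 1).

```python
import numpy as np

# Failing case from the report: axis=0, float16 data, int64 indices.
shape = (4, 7, 2)
rng = np.random.default_rng(0)
data = rng.standard_normal(shape).astype("float16")
indices = rng.integers(0, shape[0], size=shape, dtype=np.int64)

flat_data = data.reshape(-1)
flat_idx = indices.reshape(-1)
out = np.empty(data.size, dtype=data.dtype)

row_stride = shape[1] * shape[2]      # 14: stride of axis 0 in the flattened data
lanes = np.arange(2, dtype=np.int64)  # per-lane offsets (0, 1)

# One iteration per CUDA thread; each thread handles 2 outputs (28 threads total).
for t in range(data.size // 2):
    v4 = flat_idx[2 * t : 2 * t + 2]     # __4: vector load of two int64 indices
    v3 = v4 * row_stride                 # __3 = __4 * (14, 14)
    v2 = v3 + (t % shape[1]) * shape[2]  # __2 = __3 + ((t % 7) * 2, (t % 7) * 2)
    v1 = v2 + lanes                      # __1 = __2 + (0, 1)
    out[2 * t : 2 * t + 2] = flat_data[v1]  # data[__1.x], data[__1.y]

# gather along axis 0 equals np.take_along_axis when indices has the same rank as data
expected = np.take_along_axis(data, indices, axis=0)
assert np.array_equal(out.reshape(shape), expected)
print("index arithmetic matches gather")
```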

cc @vinx13 @junrushao

Expected behavior

No error.

Actual behavior

Testing shape  (4, 7, 5)
data type float16
idx type int64
Running on target: llvm
Running on target: cuda
Running on target: nvptx
Successed
Testing shape  (4, 2, 3)
data type float16
idx type int64
Running on target: llvm
Running on target: cuda
Running on target: nvptx
Successed
Testing shape  (4, 7, 3)
data type float16
idx type int64
Running on target: llvm
Running on target: cuda
Running on target: nvptx
Successed
Testing shape  (4, 7, 2)
data type float16
idx type int64
Running on target: llvm
Running on target: cuda
Failed: Traceback (most recent call last):
  23: TVMFuncCall
        at /home/ubuntu/raf/3rdparty/tvm/src/runtime/c_runtime_api.cc:477
  22: tvm::runtime::PackedFuncObj::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1217
  21: Call
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1213
  20: operator()
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1731
  19: unpack_call<tvm::runtime::Module, 2, tvm::<lambda(const tvm::runtime::Map<tvm::Target, tvm::IRModule>&, tvm::Target)> >
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1671
  18: run<>
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1631
  17: run<tvm::runtime::TVMMovableArgValueWithContext_>
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1631
  16: run<tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_>
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1646
  15: operator()
        at /home/ubuntu/raf/3rdparty/tvm/src/driver/driver_api.cc:501
  14: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
        at /home/ubuntu/raf/3rdparty/tvm/src/driver/driver_api.cc:483
  13: tvm::codegen::Build(tvm::IRModule, tvm::Target)
        at /home/ubuntu/raf/3rdparty/tvm/src/target/codegen.cc:59
  12: tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::operator()<tvm::IRModule&, tvm::Target&>(tvm::IRModule&, tvm::Target&) const
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1618
  11: tvm::runtime::PackedFuncObj::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1217
  10: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1213
  9: tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1731
  8: void tvm::runtime::detail::unpack_call<tvm::runtime::Module, 2, tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, tvm::runtime::Module (* const&)(tvm::IRModule, tvm::Target), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1671
  7: void tvm::runtime::detail::unpack_call_dispatcher<tvm::runtime::Module, 2, 0, tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>::run<>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > (*)(), tvm::runtime::Module (* const&)(tvm::IRModule, tvm::Target), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1631
  6: void tvm::runtime::detail::unpack_call_dispatcher<tvm::runtime::Module, 1, 1, tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>::run<tvm::runtime::TVMMovableArgValueWithContext_>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > (*)(), tvm::runtime::Module (* const&)(tvm::IRModule, tvm::Target), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMMovableArgValueWithContext_&&)
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1631
  5: void tvm::runtime::detail::unpack_call_dispatcher<tvm::runtime::Module, 0, 2, tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>::run<tvm::runtime::TVMMovableArgValueWithContext_, tvm::runtime::TVMMovableArgValueWithContext_>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > (*)(), tvm::runtime::Module (* const&)(tvm::IRModule, tvm::Target), tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*, tvm::runtime::TVMMovableArgValueWithContext_&&, tvm::runtime::TVMMovableArgValueWithContext_&&)
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1646
  4: tvm::codegen::BuildCUDA(tvm::IRModule, tvm::Target)
        at /home/ubuntu/raf/3rdparty/tvm/src/target/opt/build_cuda_on.cc:153
  3: tvm::runtime::TVMRetValue tvm::runtime::PackedFunc::operator()<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&) const
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1618
  2: tvm::runtime::PackedFuncObj::CallPacked(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*) const
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1217
  1: Call
        at /home/ubuntu/raf/3rdparty/tvm/include/tvm/runtime/packed_func.h:1213
  0: operator()
        at /home/ubuntu/raf/3rdparty/tvm/src/runtime/c_runtime_api.cc:534
  File "/home/ubuntu/raf/3rdparty/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/ubuntu/raf/3rdparty/tvm/python/tvm/contrib/nvcc.py", line 189, in tvm_callback_cuda_compile
    ptx = compile_cuda(code, target_format="fatbin")
  File "/home/ubuntu/raf/3rdparty/tvm/python/tvm/contrib/nvcc.py", line 113, in compile_cuda
    raise RuntimeError(msg)
RuntimeError: #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
#include <cuda_fp16.h>
__device__ half max(half a, half b)
{
  return __hgt(__half(a), __half(b)) ? a : b;
}
__device__ half min(half a, half b)
{
  return __hlt(__half(a), __half(b)) ? a : b;
}
#else

typedef unsigned short uint16_t;
typedef unsigned char uint8_t;
typedef signed char int8_t;
typedef int int32_t;
typedef unsigned long long uint64_t;
typedef unsigned int uint32_t;

#define TVM_FORCE_INLINE inline __attribute__((always_inline))
#define TVM_XINLINE TVM_FORCE_INLINE __device__ __host__
#define TVM_ALIGNED(x) __attribute__ ((aligned(x)))
#define TVM_HALF_OPERATOR(RTYPE, OP)                              \
  TVM_XINLINE RTYPE operator OP (half a, half b) {                \
    return RTYPE(float(a) OP float(b));                           \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE RTYPE operator OP (half a, T b) {                   \
    return RTYPE(float(a) OP float(b));                           \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE RTYPE operator OP (T a, half b) {                   \
    return RTYPE(float(a) OP float(b));                           \
  }

#define TVM_HALF_ASSIGNOP(AOP, OP)                                \
  template<typename T>                                            \
  TVM_XINLINE half operator AOP (const T& a) {                    \
    return *this = half(float(*this) OP float(a));                \
  }                                                               \
  template<typename T>                                            \
  TVM_XINLINE half operator AOP (const volatile T& a) volatile {  \
    return *this = half(float(*this) OP float(a));                \
  }

class TVM_ALIGNED(2) half {
 public:
  uint16_t half_;

  static TVM_XINLINE half Binary(uint16_t value) {
    half res;
    res.half_ = value;
    return res;
  }

  TVM_XINLINE half() {}

  TVM_XINLINE half(const float& value) { constructor(value); }
  TVM_XINLINE explicit half(const double& value) { constructor(value); }
  TVM_XINLINE explicit half(const int8_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint8_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const int32_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint32_t& value) { constructor(value); }
  TVM_XINLINE explicit half(const long long& value) { constructor(value); }
  TVM_XINLINE explicit half(const uint64_t& value) { constructor(value); }

  TVM_XINLINE operator float() const {                          \
    return float(half2float(half_));                            \
  }                                                             \
  TVM_XINLINE operator float() const volatile {                 \
    return float(half2float(half_));                            \
  }


  TVM_HALF_ASSIGNOP(+=, +)
  TVM_HALF_ASSIGNOP(-=, -)
  TVM_HALF_ASSIGNOP(*=, *)
  TVM_HALF_ASSIGNOP(/=, /)

  TVM_XINLINE half operator+() {
    return *this;
  }

  TVM_XINLINE half operator-() {
    return half(-float(*this));
  }

  TVM_XINLINE half operator=(const half& a) {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  TVM_XINLINE half operator=(const T& a) {
    return *this = half(a);
  }

  TVM_XINLINE half operator=(const half& a) volatile {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  TVM_XINLINE half operator=(const T& a) volatile {
    return *this = half(a);
  }

 private:
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static int const fp16FractionBits = 10;
  static int const fp32FractionBits = 23;
  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);   // == 0x7fffff
  static int32_t const fp32HiddenBit = 1 << fp32FractionBits;   // == 0x800000
  static int const shift = fp32FractionBits - fp16FractionBits;   // == 13
  static int const shiftSign = 16;
  static int32_t const expAdjust = 127 - 15;   // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

  static int32_t const infN = 0x7F800000;   // flt32 infinity
  static int32_t const maxN = 0x477FFFFF;   // max flt32 that's a flt16 normal after >> by shift
  static int32_t const minN = 0x38800000;   // min flt16 normal as a flt32
  static int32_t const maxZ = 0x33000000;   // max fp32 number that's still rounded to zero in fp16
  static int32_t const signN = 0x80000000;  // flt32 sign bit

  static int32_t const infC = infN >> shift;
  static int32_t const nanN = (infC + 1) << shift;   // minimum flt16 nan as a flt32
  static int32_t const maxC = maxN >> shift;
  static int32_t const minC = minN >> shift;
  static int32_t const signC = signN >> shiftSign;  // flt16 sign bit

  static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
  static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))

  static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
  static int32_t const norC = 0x00400;  // min flt32 normal down shifted

  static int32_t const maxD = infC - maxC - 1;
  static int32_t const minD = minC - subC - 1;

  TVM_XINLINE uint16_t float2half(const float& value) const {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure
      // vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  // Same as above routine, except for addition of volatile keyword
  TVM_XINLINE uint16_t float2half(
    const volatile float& value) const volatile {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure
      // vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
    } else if (v.si <= maxN) {
      // Handle norms
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  TVM_XINLINE float half2float(const uint16_t& value) const {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  TVM_XINLINE float half2float(
    const volatile uint16_t& value) const volatile {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  template<typename T>
  TVM_XINLINE void constructor(const T& value) {
    half_ = float2half(float(value));
  }
};

TVM_HALF_OPERATOR(half, +)
TVM_HALF_OPERATOR(half, -)
TVM_HALF_OPERATOR(half, *)
TVM_HALF_OPERATOR(half, /)
TVM_HALF_OPERATOR(bool, >)
TVM_HALF_OPERATOR(bool, <)
TVM_HALF_OPERATOR(bool, >=)
TVM_HALF_OPERATOR(bool, <=)

TVM_XINLINE half __float2half_rn(const float a) {
  return half(a);
}
#endif


// Pack two half values.
static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
  return (v1 << 16) | v0;
}

// Some fp16 math functions are not supported in cuda_fp16.h,
// so we define them here to make sure the generated CUDA code
// is valid.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
#define CUDA_UNSUPPORTED_HALF_MATH_BINARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ half HALF_MATH_NAME(half x, half y) {   \
  float tmp_x = __half2float(x);                                          \
  float tmp_y = __half2float(y);                                          \
  float result = FP32_MATH_NAME(tmp_x, tmp_y);                            \
  return __float2half(result);                                            \
}

#define CUDA_UNSUPPORTED_HALF_MATH_UNARY(HALF_MATH_NAME, FP32_MATH_NAME) \
static inline __device__ __host__ half HALF_MATH_NAME(half x) {          \
  float tmp_x = __half2float(x);                                         \
  float result = FP32_MATH_NAME(tmp_x);                                  \
  return __float2half(result);                                           \
}

CUDA_UNSUPPORTED_HALF_MATH_BINARY(hpow, powf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htanh, tanhf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(htan, tanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(hatan, atanf)
CUDA_UNSUPPORTED_HALF_MATH_UNARY(herf, erf)

#undef CUDA_UNSUPPORTED_HALF_MATH_BINARY
#undef CUDA_UNSUPPORTED_HALF_MATH_UNARY

#endif

#ifdef _WIN32
  using uint = unsigned int;
  using uchar = unsigned char;
  using ushort = unsigned short;
  using int64_t = long long;
  using uint64_t = unsigned long long;
#else
  #define uint unsigned int
  #define uchar unsigned char
  #define ushort unsigned short
  #define int64_t long long
  #define uint64_t unsigned long long
#endif
extern "C" __global__ void __launch_bounds__(56) gather_kernel0(half* __restrict__ T_gather, half* __restrict__ data, int64_t* __restrict__ indices) {
  if (((int)threadIdx.x) < 28) {
    longlong2 __1;
      longlong2 __2;
        longlong2 __3;
          longlong2 __4 = *(longlong2*)(indices + (((int)threadIdx.x) * 2));
          longlong2 __5 = make_longlong2((int64_t)14, (int64_t)14);
          __3.x = (__4.x*__5.x);
          __3.y = (__4.y*__5.y);
        longlong2 __6 = make_longlong2(((((int64_t)((int)threadIdx.x)) % (int64_t)7) * (int64_t)2), ((((int64_t)((int)threadIdx.x)) % (int64_t)7) * (int64_t)2));
        __2.x = (__3.x+__6.x);
        __2.y = (__3.y+__6.y);
      longlong2 __7 = make_int2(((int64_t)0)+((int64_t)1*0), ((int64_t)0)+((int64_t)1*1));
      __1.x = (__2.x+__7.x);
      __1.y = (__2.y+__7.y);
    *(uint1*)(T_gather + (((int)threadIdx.x) * 2)) = make_uint1(__pack_half2(data[__1.x],data[__1.y]));
  }
}


Compilation error:
/tmp/tmpmqvrsimv/my_kernel.cu(337): error: no suitable user-defined conversion from "int2" to "longlong2" exists

/tmp/tmpmqvrsimv/my_kernel.cu(301): warning: function "hpow" was declared but never referenced

/tmp/tmpmqvrsimv/my_kernel.cu(302): warning: function "htanh" was declared but never referenced

/tmp/tmpmqvrsimv/my_kernel.cu(303): warning: function "htan" was declared but never referenced

/tmp/tmpmqvrsimv/my_kernel.cu(304): warning: function "hatan" was declared but never referenced

/tmp/tmpmqvrsimv/my_kernel.cu(305): warning: function "herf" was declared but never referenced

1 error detected in the compilation of "/tmp/tmpmqvrsimv/my_kernel.cu".
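Line 337 of `my_kernel.cu` lines up with `longlong2 __7 = make_int2(((int64_t)0)+((int64_t)1*0), ((int64_t)0)+((int64_t)1*1));` in the kernel above (the warnings place `hpow` at line 301, which gives the offset). That line builds the int64 lane offsets (0, 1) with the 32-bit `make_int2` constructor and assigns the result to a `longlong2`. The sketch below is a hypothetical helper (`cuda_vector_ctor` is not a TVM API) that only illustrates which CUDA built-in vector constructor matches each element type. It does not describe how #13382 fixes the codegen.

```python
# Hypothetical helper (not part of TVM): maps a TVM scalar dtype and lane count
# to the CUDA built-in vector constructor with a matching element type.
CUDA_VECTOR_BASE = {
    "int32": "int",
    "uint32": "uint",
    "int64": "longlong",
    "uint64": "ulonglong",
    "float32": "float",
    "float64": "double",
}


def cuda_vector_ctor(dtype: str, lanes: int) -> str:
    """Return e.g. 'make_longlong2' for ('int64', 2)."""
    if lanes not in (1, 2, 3, 4):
        raise ValueError(f"CUDA has no built-in {lanes}-lane vector type")
    return f"make_{CUDA_VECTOR_BASE[dtype]}{lanes}"


# The lane offsets in the failing kernel are int64, so a longlong2 needs
# make_longlong2; make_int2 builds an int2, which cannot convert to longlong2.
print(cuda_vector_ctor("int64", 2))  # make_longlong2
print(cuda_vector_ctor("int32", 2))  # make_int2
```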

Environment

CUDA 11.3 on NVIDIA T4.

Steps to reproduce

  1. Check out commit 244bceb (from #13317, "[TIR] Allow folding cast with broadcast and ramp") or later.
  2. Run the following script:
import numpy as np
import tvm
from tvm import te
from tvm import topi
import tvm.topi.testing
import tvm.testing


# From https://github.com/apache/tvm/blob/3224817d0835909c2673184a6c20bac3b7672632/tests/python/topi/python/test_topi_transform.py
def verify_gather(data, axis, indices):
    data = np.asarray(data)
    indices = np.asarray(indices)

    print(f"data type {data.dtype.name}")
    print(f"idx type {indices.dtype.name}")
    var_data = te.placeholder(shape=data.shape, dtype=data.dtype.name, name="data")
    var_indices = te.placeholder(shape=indices.shape, dtype=indices.dtype.name, name="indices")
    out_tensor = topi.gather(var_data, axis, var_indices)

    def check_device(target, dev):
        print("Running on target: %s" % target)
        with tvm.target.Target(target):
            s = tvm.topi.testing.get_injective_schedule(target)(out_tensor)

        func = tvm.build(s, [var_data, var_indices, out_tensor], target, name="gather")
        out_npys = tvm.topi.testing.gather_python(data, axis, indices)

        data_nd = tvm.nd.array(data, dev)
        indices_nd = tvm.nd.array(indices, dev)
        out_nd = tvm.nd.empty(out_npys.shape, device=dev, dtype=data.dtype.name)
        func(data_nd, indices_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.numpy(), out_npys)

    for target, dev in tvm.testing.enabled_targets():
        check_device(target, dev)


def test():
    axis = 0
    for shape in [(4, 7, 5), (4, 2, 3), (4, 7, 3), (4, 7, 2)]:
        print(f"Testing shape  {shape}")
        try:
            data = np.random.randn(*shape).astype("float16")
            idx = np.random.randint(size=shape, low=0, high=shape[0])
            verify_gather(data, axis, idx)
            print("Successed")
        except Exception as err:
            print(f"Failed: {err}")


test()

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage
@comaniac added the type: bug and needs-triage labels Nov 14, 2022
@vinx13 self-assigned this Nov 14, 2022