
[Bug] T.tanh compiles, but T.cos and T.sin fail to compile #862

@aaababaaz

Description

Source Code
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T

@tilelang.jit(out_idx=[-1])
def swiglu_trig_kernel(
    M, N, K,
    block_M, block_N, block_K,
    dtype="float16",
    accum_dtype="float"
):
    """
    TileLang kernel for a SwiGLU-like activation with a trigonometric transformation.
    Computes: tanh(X @ W) * cat(cos(up_part1), sin(up_part2))
    where up = X @ V
    """
    @T.prim_func
    def swiglu_trig(
        X: T.Tensor((M, K), dtype),
        W: T.Tensor((K, N), dtype),
        V: T.Tensor((K, N), dtype),
        Out: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            # 1. Memory Allocation
            X_shared = T.alloc_shared((block_M, block_K), dtype)
            W_shared = T.alloc_shared((block_K, block_N), dtype)
            V_shared = T.alloc_shared((block_K, block_N), dtype)

            # Accumulators in float32
            gate_local_f32 = T.alloc_fragment((block_M, block_N), accum_dtype)
            up_local_f32 = T.alloc_fragment((block_M, block_N), accum_dtype)

            # 2. Clear Accumulators
            T.clear(gate_local_f32)
            T.clear(up_local_f32)

            # 3. Pipelined GEMM Computation
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(X[by * block_M, k * block_K], X_shared)
                T.copy(W[k * block_K, bx * block_N], W_shared)
                T.copy(V[k * block_K, bx * block_N], V_shared)

                T.gemm(X_shared, W_shared, gate_local_f32)
                T.gemm(X_shared, V_shared, up_local_f32)

            # 4. Fused Activation Function with Trigonometric Transform
            # Explicitly cast intermediate results to float16
            gate_local_f16 = T.alloc_fragment((block_M, block_N), dtype)
            up_local_f16 = T.alloc_fragment((block_M, block_N), dtype)

            T.copy(gate_local_f32, gate_local_f16)
            T.copy(up_local_f32, up_local_f16)

            for i, j in T.Parallel(block_M, block_N):
                # Calculate tanh for the gate part
                gate_tanh = T.tanh(gate_local_f16[i, j])

                # --- NEW LOGIC START ---
                # Apply cos to the first half and sin to the second half of the 'up' projection
                transformed_up_val = T.if_then_else(
                    j < block_N // 2,
                    T.cos(up_local_f16[i, j]),
                    T.sin(up_local_f16[i, j])
                )
                # --- NEW LOGIC END ---

                # Final element-wise multiplication
                gate_local_f16[i, j] = gate_tanh * transformed_up_val

            # 5. Store result back to global memory
            T.copy(gate_local_f16, Out[by * block_M, bx * block_N])

    return swiglu_trig

def main():
    M, K, N = 4096, 2048, 1024
    dtype = torch.float16
    device = "cuda"

    X = torch.randn(M, K, dtype=dtype, device=device)
    W = torch.randn(K, N, dtype=dtype, device=device)
    V = torch.randn(K, N, dtype=dtype, device=device)

    block_M = 64
    block_N = 64
    block_K = 32

    kernel = swiglu_trig_kernel(M, N, K, block_M, block_N, block_K, dtype="float16")

    output_tilelang = kernel(X, W, V)

    # --- Verification using the UPDATED PyTorch implementation ---
    gate_out = torch.tanh(X @ W)
    up_out = X @ V
    
    # Apply the new trigonometric transformation logic
    up_cos, up_sin = up_out.chunk(2, dim=-1)
    up_transformed = torch.cat((torch.cos(up_cos), torch.sin(up_sin)), dim=-1)
    
    output_ref = gate_out * up_transformed

    print("Verifying correctness for tanh(X@W) * cat(cos(up_1), sin(up_2))...")
    # Using a realistic tolerance for float16 operations
    torch.testing.assert_close(output_tilelang, output_ref, rtol=1e-2, atol=1e-2)
    print("✅ Correctness check passed!")

    print("\nProfiling performance...")
    profiler = kernel.get_profiler()

    def torch_ref_func():
        gate = torch.tanh(X @ W)
        up = X @ V
        up_cos, up_sin = up.chunk(2, dim=-1)
        up_transformed = torch.cat((torch.cos(up_cos), torch.sin(up_sin)), dim=-1)
        return gate * up_transformed

    torch_latency = profiler.do_bench(torch_ref_func)
    print(f"PyTorch (non-fused) Latency: {torch_latency:.4f} ms")

    tilelang_latency = profiler.do_bench()
    print(f"TileLang (fused) Latency:   {tilelang_latency:.4f} ms")

    speedup = torch_latency / tilelang_latency
    print(f"\nSpeedup: {speedup:.2f}x")

if __name__ == "__main__":
    main()

Error Log
(venv) ➜  /tmp python3 tilelang_swiglu.py 
	/tmp/tmpzigyi78y.cu(127): error: no instance of overloaded function "hcos" matches the argument list
            argument types are: (cutlass::half_t)
        condval = hcos(up_local_f16[i_16]);
                  ^
/opt/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(3127): note #3326-D: function "hcos(__nv_bfloat16)" does not match because argument #1 does not match parameter
  __inline__ __nv_bfloat16 hcos(const __nv_bfloat16 a) {
                           ^
/opt/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(2835): note #3326-D: function "hcos(__half)" does not match because argument #1 does not match parameter
  __inline__ __half hcos(const __half a) {
                    ^

/tmp/tmpzigyi78y.cu(129): error: no instance of overloaded function "hsin" matches the argument list
            argument types are: (cutlass::half_t)
        condval = hsin(up_local_f16[i_16]);
                  ^
/opt/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(3114): note #3326-D: function "hsin(__nv_bfloat16)" does not match because argument #1 does not match parameter
  __inline__ __nv_bfloat16 hsin(const __nv_bfloat16 a) {
                           ^
/opt/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(2800): note #3326-D: function "hsin(__half)" does not match because argument #1 does not match parameter
  __inline__ __half hsin(const __half a) {
                    ^

2 errors detected in the compilation of "/tmp/tmpzigyi78y.cu".
Traceback (most recent call last):
  File "/tmp/tilelang_swiglu.py", line 129, in <module>
    main()
  File "/tmp/tilelang_swiglu.py", line 90, in main
    kernel = swiglu_trig_kernel(M, N, K, block_M, block_N, block_K, dtype="float16")
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/__init__.py", line 201, in wrapper
    kernel_result = compile(
                    ^^^^^^^^
  File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/__init__.py", line 72, in compile
    return cached(
           ^^^^^^^
  File "/home/@root/venv/lib/python3.11/site-packages/tilelang/cache/__init__.py", line 30, in cached
    return _kernel_cache_instance.cached(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/@root/venv/lib/python3.11/site-packages/tilelang/cache/kernel_cache.py", line 167, in cached
    kernel = JITKernel(
             ^^^^^^^^^^
  File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/kernel.py", line 117, in __init__
    adapter = self._compile_and_create_adapter(func, out_idx)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/kernel.py", line 235, in _compile_and_create_adapter
    adapter = CythonKernelAdapter(
              ^^^^^^^^^^^^^^^^^^^^
  File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/adapter/cython/adapter.py", line 257, in __init__
    self.lib_generator.compile_lib()
  File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/adapter/libgen.py", line 110, in compile_lib
    raise RuntimeError(f"Compilation Failed! {command}"
RuntimeError: Compilation Failed! ['/opt/cuda/bin/nvcc', '-std=c++17', '-w', '-Xcudafe', '--diag_suppress=177', '--compiler-options', "'-fPIC'", '-lineinfo', '--shared', '/tmp/tmpzigyi78y.cu', '-lcuda', '-gencode', 'arch=compute_89,code=sm_89', '-I/home/@root/venv/lib/python3.11/site-packages/tilelang/3rdparty/cutlass/include', '-I/home/@root/venv/lib/python3.11/site-packages/tilelang/src', '-o', '/tmp/tmpzigyi78y.so']
Generated CUDA source:

#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>

extern "C" __global__ void swiglu_trig_kernel(half_t* __restrict__ Out, half_t* __restrict__ V, half_t* __restrict__ W, half_t* __restrict__ X);
extern "C" __global__ void __launch_bounds__(128, 1) swiglu_trig_kernel(half_t* __restrict__ Out, half_t* __restrict__ V, half_t* __restrict__ W, half_t* __restrict__ X) {
  extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
  float gate_local_f32[32];
  float up_local_f32[32];
  half_t gate_local_f16[32];
  half_t up_local_f16[32];
  #pragma unroll
  for (int i = 0; i < 16; ++i) {
    *(float2*)(gate_local_f32 + (i * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
  }
  #pragma unroll
  for (int i_1 = 0; i_1 < 16; ++i_1) {
    *(float2*)(up_local_f32 + (i_1 * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
  }
  #pragma unroll
  for (int i_2 = 0; i_2 < 2; ++i_2) {
    tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_2 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), W+((((i_2 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)));
  }
  tl::cp_async_commit();
  #pragma unroll
  for (int i_3 = 0; i_3 < 2; ++i_3) {
    tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_3 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 12288), X+((((((int)blockIdx.y) * 131072) + (i_3 * 65536)) + ((((int)threadIdx.x) >> 2) * 2048)) + ((((int)threadIdx.x) & 3) * 8)));
  }
  #pragma unroll
  for (int i_4 = 0; i_4 < 2; ++i_4) {
    tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_4 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 24576), V+((((i_4 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)));
  }
  tl::cp_async_commit();
  #pragma unroll
  for (int i_5 = 0; i_5 < 2; ++i_5) {
    tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_5 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 4096), W+(((((i_5 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 32768));
  }
  tl::cp_async_commit();
  #pragma unroll
  for (int i_6 = 0; i_6 < 2; ++i_6) {
    tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_6 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 16384), X+(((((((int)blockIdx.y) * 131072) + (i_6 * 65536)) + ((((int)threadIdx.x) >> 2) * 2048)) + ((((int)threadIdx.x) & 3) * 8)) + 32));
  }
  #pragma unroll
  for (int i_7 = 0; i_7 < 2; ++i_7) {
    tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_7 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 28672), V+(((((i_7 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 32768));
  }
  tl::cp_async_commit();
  #pragma unroll
  for (int i_8 = 0; i_8 < 2; ++i_8) {
    tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_8 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 8192), W+(((((i_8 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 65536));
  }
  tl::cp_async_commit();
  #pragma unroll
  for (int i_9 = 0; i_9 < 2; ++i_9) {
    tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_9 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 20480), X+(((((((int)blockIdx.y) * 131072) + (i_9 * 65536)) + ((((int)threadIdx.x) >> 2) * 2048)) + ((((int)threadIdx.x) & 3) * 8)) + 64));
  }
  #pragma unroll
  for (int i_10 = 0; i_10 < 2; ++i_10) {
    tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_10 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 32768), V+(((((i_10 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 65536));
  }
  tl::cp_async_commit();
  for (int k = 0; k < 61; ++k) {
    tl::cp_async_wait<4>();
    __syncthreads();
    tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[(((k % 3) * 2048) + 6144)])), (&(((half_t*)buf_dyn_shmem)[((k % 3) * 2048)])), (&(gate_local_f32[0])));
    __syncthreads();
    #pragma unroll
    for (int i_11 = 0; i_11 < 2; ++i_11) {
      tl::cp_async_gs<16>(buf_dyn_shmem+(((((((k % 3) * 4096) + (i_11 * 2048)) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), W+((((((k * 32768) + (i_11 * 16384)) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 98304));
    }
    tl::cp_async_commit();
    tl::cp_async_wait<5>();
    __syncthreads();
    tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[(((k % 3) * 2048) + 6144)])), (&(((half_t*)buf_dyn_shmem)[(((k % 3) * 2048) + 12288)])), (&(up_local_f32[0])));
    __syncthreads();
    #pragma unroll
    for (int i_12 = 0; i_12 < 2; ++i_12) {
      tl::cp_async_gs<16>(buf_dyn_shmem+(((((((k % 3) * 4096) + (i_12 * 2048)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 12288), X+((((((((int)blockIdx.y) * 131072) + (i_12 * 65536)) + ((((int)threadIdx.x) >> 2) * 2048)) + (k * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 96));
    }
    #pragma unroll
    for (int i_13 = 0; i_13 < 2; ++i_13) {
      tl::cp_async_gs<16>(buf_dyn_shmem+((((((((k % 3) * 4096) + (i_13 * 2048)) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 24576), V+((((((k * 32768) + (i_13 * 16384)) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 98304));
    }
    tl::cp_async_commit();
  }
  tl::cp_async_wait<4>();
  __syncthreads();
  tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[8192])), (&(((half_t*)buf_dyn_shmem)[2048])), (&(gate_local_f32[0])));
  tl::cp_async_wait<4>();
  __syncthreads();
  tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[8192])), (&(((half_t*)buf_dyn_shmem)[14336])), (&(up_local_f32[0])));
  tl::cp_async_wait<2>();
  __syncthreads();
  tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[10240])), (&(((half_t*)buf_dyn_shmem)[4096])), (&(gate_local_f32[0])));
  tl::cp_async_wait<2>();
  __syncthreads();
  tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[10240])), (&(((half_t*)buf_dyn_shmem)[16384])), (&(up_local_f32[0])));
  tl::cp_async_wait<0>();
  __syncthreads();
  tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[6144])), (&(((half_t*)buf_dyn_shmem)[0])), (&(gate_local_f32[0])));
  tl::cp_async_wait<0>();
  __syncthreads();
  tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[6144])), (&(((half_t*)buf_dyn_shmem)[12288])), (&(up_local_f32[0])));
  #pragma unroll
  for (int i_14 = 0; i_14 < 16; ++i_14) {
    uint1 __1;
    float2 v_ = *(float2*)(gate_local_f32 + (i_14 * 2));
    ((half2*)(&(__1.x)))->x = (half_t)(v_.x);
    ((half2*)(&(__1.x)))->y = (half_t)(v_.y);
    *(uint1*)(gate_local_f16 + (i_14 * 2)) = __1;
  }
  #pragma unroll
  for (int i_15 = 0; i_15 < 16; ++i_15) {
    uint1 __2;
    float2 v__1 = *(float2*)(up_local_f32 + (i_15 * 2));
    ((half2*)(&(__2.x)))->x = (half_t)(v__1.x);
    ((half2*)(&(__2.x)))->y = (half_t)(v__1.y);
    *(uint1*)(up_local_f16 + (i_15 * 2)) = __2;
  }
  #pragma unroll
  for (int i_16 = 0; i_16 < 32; ++i_16) {
    half_t condval;
    if ((i_16 < 16)) {
      condval = hcos(up_local_f16[i_16]);
    } else {
      condval = hsin(up_local_f16[i_16]);
    }
    gate_local_f16[i_16] = (htanh(gate_local_f16[i_16]) * condval);
  }
  #pragma unroll
  for (int i_17 = 0; i_17 < 16; ++i_17) {
    *(uint1*)(Out + (((((((((((int)blockIdx.y) * 65536) + (((i_17 & 3) >> 1) * 32768)) + (((((int)threadIdx.x) & 63) >> 5) * 16384)) + ((i_17 & 1) * 8192)) + (((((int)threadIdx.x) & 31) >> 2) * 1024)) + (((int)blockIdx.x) * 64)) + ((i_17 >> 2) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = *(uint1*)(gate_local_f16 + (i_17 * 2));
  }
}


#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];

extern "C" const char* get_last_error() {
    return error_buf;
}

extern "C" int init() {
    error_buf[0] = '\0';
    
    cudaError_t result_swiglu_trig_kernel = cudaFuncSetAttribute(swiglu_trig_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 36864);
    if (result_swiglu_trig_kernel != CUDA_SUCCESS) {
        snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", 36864, cudaGetErrorString(result_swiglu_trig_kernel));
        return -1;
    }

    return 0;
}

extern "C" int call(half_t* __restrict__ X, half_t* __restrict__ W, half_t* __restrict__ V, half_t* __restrict__ Out, cudaStream_t stream=cudaStreamDefault) {
	swiglu_trig_kernel<<<dim3(16, 64, 1), dim3(128, 1, 1), 36864, stream>>>(Out, V, W, X);
	TILELANG_CHECK_LAST_ERROR("swiglu_trig_kernel");

	return 0;
}

Is this normal behavior? Judging from the error log, the generated code calls hcos/hsin on cutlass::half_t, but CUDA only declares overloads for __half and __nv_bfloat16, so no overload matches. Do we need to implement cos and sin manually for float16 buffers?
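
For now, one workaround that sidesteps the half-precision intrinsics entirely seems to be applying the whole activation while the values are still in the float32 accumulators, so codegen never has to emit hcos/hsin on half_t. This is only a sketch, not verified against this TileLang version; it assumes T.cos/T.sin lower to the regular float32 cosf/sinf path when given float32 operands (the generated code above suggests the operand dtype drives which intrinsic is emitted). It would replace step 4 of the kernel:

    # Do the trig transform on the float32 accumulators, so the emitted
    # intrinsics are cosf/sinf/tanhf instead of hcos/hsin/htanh on half_t
    for i, j in T.Parallel(block_M, block_N):
        transformed_up = T.if_then_else(
            j < block_N // 2,
            T.cos(up_local_f32[i, j]),
            T.sin(up_local_f32[i, j]),
        )
        gate_local_f32[i, j] = T.tanh(gate_local_f32[i, j]) * transformed_up

    # Cast down to float16 once at the end, then write out
    gate_local_f16 = T.alloc_fragment((block_M, block_N), dtype)
    T.copy(gate_local_f32, gate_local_f16)
    T.copy(gate_local_f16, Out[by * block_M, bx * block_N])

Doing the math in float32 should also be slightly more accurate than the half-precision intrinsics, at the cost of skipping the fp16 fast path.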
