-
Notifications
You must be signed in to change notification settings - Fork 333
Closed
Labels
bugSomething isn't workingSomething isn't working
Description
Source Code
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
@tilelang.jit(out_idx=[-1])
def swiglu_trig_kernel(
M, N, K,
block_M, block_N, block_K,
dtype="float16",
accum_dtype="float"
):
"""
TileLang kernel for a SwiGLU-like activation with a trigonometric transformation.
Computes: tanh(X @ W) * cat(cos(up_part1), sin(up_part2))
where up = X @ V
"""
@T.prim_func
def swiglu_trig(
X: T.Tensor((M, K), dtype),
W: T.Tensor((K, N), dtype),
V: T.Tensor((K, N), dtype),
Out: T.Tensor((M, N), dtype),
):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
# 1. Memory Allocation
X_shared = T.alloc_shared((block_M, block_K), dtype)
W_shared = T.alloc_shared((block_K, block_N), dtype)
V_shared = T.alloc_shared((block_K, block_N), dtype)
# Accumulators in float32
gate_local_f32 = T.alloc_fragment((block_M, block_N), accum_dtype)
up_local_f32 = T.alloc_fragment((block_M, block_N), accum_dtype)
# 2. Clear Accumulators
T.clear(gate_local_f32)
T.clear(up_local_f32)
# 3. Pipelined GEMM Computation
for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
T.copy(X[by * block_M, k * block_K], X_shared)
T.copy(W[k * block_K, bx * block_N], W_shared)
T.copy(V[k * block_K, bx * block_N], V_shared)
T.gemm(X_shared, W_shared, gate_local_f32)
T.gemm(X_shared, V_shared, up_local_f32)
# 4. Fused Activation Function with Trigonometric Transform
# Explicitly cast intermediate results to float16
gate_local_f16 = T.alloc_fragment((block_M, block_N), dtype)
up_local_f16 = T.alloc_fragment((block_M, block_N), dtype)
T.copy(gate_local_f32, gate_local_f16)
T.copy(up_local_f32, up_local_f16)
for i, j in T.Parallel(block_M, block_N):
# Calculate tanh for the gate part
gate_tanh = T.tanh(gate_local_f16[i, j])
# --- NEW LOGIC START ---
# Apply cos to the first half and sin to the second half of the 'up' projection
transformed_up_val = T.if_then_else(
j < block_N // 2,
T.cos(up_local_f16[i, j]),
T.sin(up_local_f16[i, j])
)
# --- NEW LOGIC END ---
# Final element-wise multiplication
gate_local_f16[i, j] = gate_tanh * transformed_up_val
# 5. Store result back to global memory
T.copy(gate_local_f16, Out[by * block_M, bx * block_N])
return swiglu_trig
def main():
M, K, N = 4096, 2048, 1024
dtype = torch.float16
device = "cuda"
X = torch.randn(M, K, dtype=dtype, device=device)
W = torch.randn(K, N, dtype=dtype, device=device)
V = torch.randn(K, N, dtype=dtype, device=device)
block_M = 64
block_N = 64
block_K = 32
kernel = swiglu_trig_kernel(M, N, K, block_M, block_N, block_K, dtype="float16")
output_tilelang = kernel(X, W, V)
# --- Verification using the UPDATED PyTorch implementation ---
gate_out = torch.tanh(X @ W)
up_out = X @ V
# Apply the new trigonometric transformation logic
up_cos, up_sin = up_out.chunk(2, dim=-1)
up_transformed = torch.cat((torch.cos(up_cos), torch.sin(up_sin)), dim=-1)
output_ref = gate_out * up_transformed
print("Verifying correctness for tanh(X@W) * cat(cos(up_1), sin(up_2))...")
# Using a realistic tolerance for float16 operations
torch.testing.assert_close(output_tilelang, output_ref, rtol=1e-2, atol=1e-2)
print("✅ Correctness check passed!")
print("\nProfiling performance...")
profiler = kernel.get_profiler()
def torch_ref_func():
gate = torch.tanh(X @ W)
up = X @ V
up_cos, up_sin = up.chunk(2, dim=-1)
up_transformed = torch.cat((torch.cos(up_cos), torch.sin(up_sin)), dim=-1)
return gate * up_transformed
torch_latency = profiler.do_bench(torch_ref_func)
print(f"PyTorch (non-fused) Latency: {torch_latency:.4f} ms")
tilelang_latency = profiler.do_bench()
print(f"TileLang (fused) Latency: {tilelang_latency:.4f} ms")
speedup = torch_latency / tilelang_latency
print(f"\nSpeedup: {speedup:.2f}x")
if __name__ == "__main__":
main()Error Log
(venv) ➜ /tmp python3 tilelang_swiglu.py
/tmp/tmpzigyi78y.cu(127): error: no instance of overloaded function "hcos" matches the argument list
argument types are: (cutlass::half_t)
condval = hcos(up_local_f16[i_16]);
^
/opt/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(3127): note #3326-D: function "hcos(__nv_bfloat16)" does not match because argument #1 does not match parameter
__inline__ __nv_bfloat16 hcos(const __nv_bfloat16 a) {
^
/opt/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(2835): note #3326-D: function "hcos(__half)" does not match because argument #1 does not match parameter
__inline__ __half hcos(const __half a) {
^
/tmp/tmpzigyi78y.cu(129): error: no instance of overloaded function "hsin" matches the argument list
argument types are: (cutlass::half_t)
condval = hsin(up_local_f16[i_16]);
^
/opt/cuda/bin/../targets/x86_64-linux/include/cuda_bf16.hpp(3114): note #3326-D: function "hsin(__nv_bfloat16)" does not match because argument #1 does not match parameter
__inline__ __nv_bfloat16 hsin(const __nv_bfloat16 a) {
^
/opt/cuda/bin/../targets/x86_64-linux/include/cuda_fp16.hpp(2800): note #3326-D: function "hsin(__half)" does not match because argument #1 does not match parameter
__inline__ __half hsin(const __half a) {
^
2 errors detected in the compilation of "/tmp/tmpzigyi78y.cu".
Traceback (most recent call last):
File "/tmp/tilelang_swiglu.py", line 129, in <module>
main()
File "/tmp/tilelang_swiglu.py", line 90, in main
kernel = swiglu_trig_kernel(M, N, K, block_M, block_N, block_K, dtype="float16")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/__init__.py", line 201, in wrapper
kernel_result = compile(
^^^^^^^^
File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/__init__.py", line 72, in compile
return cached(
^^^^^^^
File "/home/@root/venv/lib/python3.11/site-packages/tilelang/cache/__init__.py", line 30, in cached
return _kernel_cache_instance.cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/@root/venv/lib/python3.11/site-packages/tilelang/cache/kernel_cache.py", line 167, in cached
kernel = JITKernel(
^^^^^^^^^^
File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/kernel.py", line 117, in __init__
adapter = self._compile_and_create_adapter(func, out_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/kernel.py", line 235, in _compile_and_create_adapter
adapter = CythonKernelAdapter(
^^^^^^^^^^^^^^^^^^^^
File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/adapter/cython/adapter.py", line 257, in __init__
self.lib_generator.compile_lib()
File "/home/@root/venv/lib/python3.11/site-packages/tilelang/jit/adapter/libgen.py", line 110, in compile_lib
raise RuntimeError(f"Compilation Failed! {command}"
RuntimeError: Compilation Failed! ['/opt/cuda/bin/nvcc', '-std=c++17', '-w', '-Xcudafe', '--diag_suppress=177', '--compiler-options', "'-fPIC'", '-lineinfo', '--shared', '/tmp/tmpzigyi78y.cu', '-lcuda', '-gencode', 'arch=compute_89,code=sm_89', '-I/home/@root/venv/lib/python3.11/site-packages/tilelang/3rdparty/cutlass/include', '-I/home/@root/venv/lib/python3.11/site-packages/tilelang/src', '-o', '/tmp/tmpzigyi78y.so']
#include <tl_templates/cuda/gemm.h>
#include <tl_templates/cuda/copy.h>
#include <tl_templates/cuda/reduce.h>
#include <tl_templates/cuda/ldsm.h>
#include <tl_templates/cuda/threadblock_swizzle.h>
#include <tl_templates/cuda/debug.h>
extern "C" __global__ void swiglu_trig_kernel(half_t* __restrict__ Out, half_t* __restrict__ V, half_t* __restrict__ W, half_t* __restrict__ X);
extern "C" __global__ void __launch_bounds__(128, 1) swiglu_trig_kernel(half_t* __restrict__ Out, half_t* __restrict__ V, half_t* __restrict__ W, half_t* __restrict__ X) {
extern __shared__ __align__(1024) uchar buf_dyn_shmem[];
float gate_local_f32[32];
float up_local_f32[32];
half_t gate_local_f16[32];
half_t up_local_f16[32];
#pragma unroll
for (int i = 0; i < 16; ++i) {
*(float2*)(gate_local_f32 + (i * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
}
#pragma unroll
for (int i_1 = 0; i_1 < 16; ++i_1) {
*(float2*)(up_local_f32 + (i_1 * 2)) = make_float2(0.000000e+00f, 0.000000e+00f);
}
#pragma unroll
for (int i_2 = 0; i_2 < 2; ++i_2) {
tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_2 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), W+((((i_2 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)));
}
tl::cp_async_commit();
#pragma unroll
for (int i_3 = 0; i_3 < 2; ++i_3) {
tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_3 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 12288), X+((((((int)blockIdx.y) * 131072) + (i_3 * 65536)) + ((((int)threadIdx.x) >> 2) * 2048)) + ((((int)threadIdx.x) & 3) * 8)));
}
#pragma unroll
for (int i_4 = 0; i_4 < 2; ++i_4) {
tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_4 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 24576), V+((((i_4 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)));
}
tl::cp_async_commit();
#pragma unroll
for (int i_5 = 0; i_5 < 2; ++i_5) {
tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_5 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 4096), W+(((((i_5 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 32768));
}
tl::cp_async_commit();
#pragma unroll
for (int i_6 = 0; i_6 < 2; ++i_6) {
tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_6 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 16384), X+(((((((int)blockIdx.y) * 131072) + (i_6 * 65536)) + ((((int)threadIdx.x) >> 2) * 2048)) + ((((int)threadIdx.x) & 3) * 8)) + 32));
}
#pragma unroll
for (int i_7 = 0; i_7 < 2; ++i_7) {
tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_7 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 28672), V+(((((i_7 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 32768));
}
tl::cp_async_commit();
#pragma unroll
for (int i_8 = 0; i_8 < 2; ++i_8) {
tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_8 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 8192), W+(((((i_8 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 65536));
}
tl::cp_async_commit();
#pragma unroll
for (int i_9 = 0; i_9 < 2; ++i_9) {
tl::cp_async_gs<16>(buf_dyn_shmem+(((((i_9 * 2048) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 20480), X+(((((((int)blockIdx.y) * 131072) + (i_9 * 65536)) + ((((int)threadIdx.x) >> 2) * 2048)) + ((((int)threadIdx.x) & 3) * 8)) + 64));
}
#pragma unroll
for (int i_10 = 0; i_10 < 2; ++i_10) {
tl::cp_async_gs<16>(buf_dyn_shmem+((((((i_10 * 2048) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 32768), V+(((((i_10 * 16384) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 65536));
}
tl::cp_async_commit();
for (int k = 0; k < 61; ++k) {
tl::cp_async_wait<4>();
__syncthreads();
tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[(((k % 3) * 2048) + 6144)])), (&(((half_t*)buf_dyn_shmem)[((k % 3) * 2048)])), (&(gate_local_f32[0])));
__syncthreads();
#pragma unroll
for (int i_11 = 0; i_11 < 2; ++i_11) {
tl::cp_async_gs<16>(buf_dyn_shmem+(((((((k % 3) * 4096) + (i_11 * 2048)) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)), W+((((((k * 32768) + (i_11 * 16384)) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 98304));
}
tl::cp_async_commit();
tl::cp_async_wait<5>();
__syncthreads();
tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[(((k % 3) * 2048) + 6144)])), (&(((half_t*)buf_dyn_shmem)[(((k % 3) * 2048) + 12288)])), (&(up_local_f32[0])));
__syncthreads();
#pragma unroll
for (int i_12 = 0; i_12 < 2; ++i_12) {
tl::cp_async_gs<16>(buf_dyn_shmem+(((((((k % 3) * 4096) + (i_12 * 2048)) + ((((int)threadIdx.x) >> 2) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 12288), X+((((((((int)blockIdx.y) * 131072) + (i_12 * 65536)) + ((((int)threadIdx.x) >> 2) * 2048)) + (k * 32)) + ((((int)threadIdx.x) & 3) * 8)) + 96));
}
#pragma unroll
for (int i_13 = 0; i_13 < 2; ++i_13) {
tl::cp_async_gs<16>(buf_dyn_shmem+((((((((k % 3) * 4096) + (i_13 * 2048)) + ((((int)threadIdx.x) >> 3) * 128)) + (((((((int)threadIdx.x) & 63) >> 5) + ((((int)threadIdx.x) & 7) >> 2)) & 1) * 64)) + (((((((int)threadIdx.x) & 31) >> 4) + ((((int)threadIdx.x) & 3) >> 1)) & 1) * 32)) + (((((((int)threadIdx.x) & 15) >> 3) + (((int)threadIdx.x) & 1)) & 1) * 16)) + 24576), V+((((((k * 32768) + (i_13 * 16384)) + ((((int)threadIdx.x) >> 3) * 1024)) + (((int)blockIdx.x) * 64)) + ((((int)threadIdx.x) & 7) * 8)) + 98304));
}
tl::cp_async_commit();
}
tl::cp_async_wait<4>();
__syncthreads();
tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[8192])), (&(((half_t*)buf_dyn_shmem)[2048])), (&(gate_local_f32[0])));
tl::cp_async_wait<4>();
__syncthreads();
tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[8192])), (&(((half_t*)buf_dyn_shmem)[14336])), (&(up_local_f32[0])));
tl::cp_async_wait<2>();
__syncthreads();
tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[10240])), (&(((half_t*)buf_dyn_shmem)[4096])), (&(gate_local_f32[0])));
tl::cp_async_wait<2>();
__syncthreads();
tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[10240])), (&(((half_t*)buf_dyn_shmem)[16384])), (&(up_local_f32[0])));
tl::cp_async_wait<0>();
__syncthreads();
tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[6144])), (&(((half_t*)buf_dyn_shmem)[0])), (&(gate_local_f32[0])));
tl::cp_async_wait<0>();
__syncthreads();
tl::gemm_ss<64, 64, 32, 2, 2, 0, 0, 0>((&(((half_t*)buf_dyn_shmem)[6144])), (&(((half_t*)buf_dyn_shmem)[12288])), (&(up_local_f32[0])));
#pragma unroll
for (int i_14 = 0; i_14 < 16; ++i_14) {
uint1 __1;
float2 v_ = *(float2*)(gate_local_f32 + (i_14 * 2));
((half2*)(&(__1.x)))->x = (half_t)(v_.x);
((half2*)(&(__1.x)))->y = (half_t)(v_.y);
*(uint1*)(gate_local_f16 + (i_14 * 2)) = __1;
}
#pragma unroll
for (int i_15 = 0; i_15 < 16; ++i_15) {
uint1 __2;
float2 v__1 = *(float2*)(up_local_f32 + (i_15 * 2));
((half2*)(&(__2.x)))->x = (half_t)(v__1.x);
((half2*)(&(__2.x)))->y = (half_t)(v__1.y);
*(uint1*)(up_local_f16 + (i_15 * 2)) = __2;
}
#pragma unroll
for (int i_16 = 0; i_16 < 32; ++i_16) {
half_t condval;
if ((i_16 < 16)) {
condval = hcos(up_local_f16[i_16]);
} else {
condval = hsin(up_local_f16[i_16]);
}
gate_local_f16[i_16] = (htanh(gate_local_f16[i_16]) * condval);
}
#pragma unroll
for (int i_17 = 0; i_17 < 16; ++i_17) {
*(uint1*)(Out + (((((((((((int)blockIdx.y) * 65536) + (((i_17 & 3) >> 1) * 32768)) + (((((int)threadIdx.x) & 63) >> 5) * 16384)) + ((i_17 & 1) * 8192)) + (((((int)threadIdx.x) & 31) >> 2) * 1024)) + (((int)blockIdx.x) * 64)) + ((i_17 >> 2) * 16)) + ((((int)threadIdx.x) >> 6) * 8)) + ((((int)threadIdx.x) & 3) * 2))) = *(uint1*)(gate_local_f16 + (i_17 * 2));
}
}
#define ERROR_BUF_SIZE 1024
static char error_buf[ERROR_BUF_SIZE];
extern "C" const char* get_last_error() {
return error_buf;
}
extern "C" int init() {
error_buf[0] = '\0';
cudaError_t result_swiglu_trig_kernel = cudaFuncSetAttribute(swiglu_trig_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 36864);
if (result_swiglu_trig_kernel != CUDA_SUCCESS) {
snprintf(error_buf, ERROR_BUF_SIZE, "Failed to set the allowed dynamic shared memory size to %d with error: %s", 36864, cudaGetErrorString(result_swiglu_trig_kernel));
return -1;
}
return 0;
}
extern "C" int call(half_t* __restrict__ X, half_t* __restrict__ W, half_t* __restrict__ V, half_t* __restrict__ Out, cudaStream_t stream=cudaStreamDefault) {
swiglu_trig_kernel<<<dim3(16, 64, 1), dim3(128, 1, 1), 36864, stream>>>(Out, V, W, X);
TILELANG_CHECK_LAST_ERROR("swiglu_trig_kernel");
return 0;
}is this are normal behavior, that we need manually implement function cos and sin?
Metadata
Metadata
Assignees
Labels
bugSomething isn't workingSomething isn't working