3 changes: 3 additions & 0 deletions .gitignore
@@ -85,3 +85,6 @@ tilelang/lib
 
 # tox
 .tox/
+
+# cython
+tilelang/jit/adapter/cython/.cycache
2 changes: 1 addition & 1 deletion README.md
@@ -154,7 +154,7 @@ func = matmul(1024, 1024, 1024, 128, 128, 32)
 # out_idx specifies the index of the output buffer in the argument list
 # if out_idx is specified, the tensor will be created during runtime
 # target currently can be "cuda" or "hip" or "cpu".
-jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
+jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
 
 # 3. Test the kernel in Python with PyTorch data
 import torch
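For readers skimming the diff, a minimal sketch of the renamed entry point in use, assuming the `matmul` program factory defined just above these lines in the README (and its float16 defaults):

```python
import tilelang
import torch

# `matmul` is the TileLang program factory from the README quickstart; it is
# assumed here rather than redefined.
func = matmul(1024, 1024, 1024, 128, 128, 32)

# Renamed entry point: tilelang.compile replaces tilelang.JITKernel in this PR.
# out_idx=[2] makes the kernel allocate its output tensor at call time.
jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")

a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
c = jit_kernel(a, b)  # output is created at runtime because out_idx is specified

# Check against a PyTorch reference matmul.
torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)
```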
2 changes: 1 addition & 1 deletion docs/deeplearning_operators/matmul.md
@@ -109,7 +109,7 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
 func = matmul(1024, 1024, 1024, 128, 128, 32)
 
 # 2. JIT-compile the kernel for NVIDIA GPU
-jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
+jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
 
 import torch
 
2 changes: 1 addition & 1 deletion examples/quickstart.py
@@ -63,7 +63,7 @@ def main(
 # out_idx specifies the index of the output buffer in the argument list
 # if out_idx is specified, the tensor will be created during runtime
 # target currently can be "cuda" or "hip" or "cpu".
-jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
+jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
 
 # 3. Test the kernel in Python with PyTorch data
 import torch
2 changes: 1 addition & 1 deletion requirements.txt
@@ -8,4 +8,4 @@ attrs
 cloudpickle
 ml_dtypes
 psutil
-torch
+torch>=2.2.0
19 changes: 17 additions & 2 deletions setup.py
@@ -151,7 +151,7 @@ def download_and_extract_llvm(version, is_aarch64=False, extract_path="3rdparty"
 
 
 package_data = {
-    "tilelang": ["py.typed"],
+    "tilelang": ["py.typed", "*pyx"],
 }
 
 LLVM_VERSION = "10.0.1"
@@ -227,7 +227,22 @@ def run(self):
         ext_output_dir = os.path.dirname(extdir)
         print(f"Extension output directory (parent): {ext_output_dir}")
         print(f"Build temp directory: {build_temp_dir}")
 
+        # copy cython files
+        CYTHON_SRC = [
+            "tilelang/jit/adapter/cython/cython_wrapper.pyx",
+        ]
+        for item in CYTHON_SRC:
+            source_dir = os.path.join(ROOT_DIR, item)
+            target_dir = os.path.join(self.build_lib, item)
+            if os.path.isdir(source_dir):
+                self.mkpath(target_dir)
+                distutils.dir_util.copy_tree(source_dir, target_dir)
+            else:
+                target_dir = os.path.dirname(target_dir)
+                if not os.path.exists(target_dir):
+                    os.makedirs(target_dir)
+                shutil.copy2(source_dir, target_dir)
         # copy the tl_templates
         TILELANG_SRC = [
             "src/tl_templates",
         ]
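The new packaging steps ship the Cython wrapper source (`*pyx` in `package_data` plus the explicit copy into `build_lib`). A hypothetical post-install check, not part of this PR, could confirm the file actually lands inside the installed package:

```python
# Hypothetical sanity check (not part of this PR): verify the Cython wrapper
# source was packaged with tilelang. Requires Python 3.9+ for resources.files.
from importlib import resources

pyx = resources.files("tilelang").joinpath("jit/adapter/cython/cython_wrapper.pyx")
print("cython_wrapper.pyx packaged:", pyx.is_file())
```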
10 changes: 5 additions & 5 deletions testing/python/debug/test_tilelang_debug_print.py
@@ -14,7 +14,7 @@ def program(Q: T.Buffer((M, N), dtype)):
 shared_buf = T.alloc_shared([M, N], dtype)
 T.print(shared_buf)
 
-jit_kernel = tilelang.JITKernel(program, target="cuda")
+jit_kernel = tilelang.compile(program, target="cuda")
 profiler = jit_kernel.get_profiler()
 profiler.run_once()
 
@@ -34,7 +34,7 @@ def program(Q: T.Buffer((M, N), dtype)):
 if bx == 0 and by == 0 and bz == 0:
     T.print(shared_buf)
 
-jit_kernel = tilelang.JITKernel(program, target="cuda")
+jit_kernel = tilelang.compile(program, target="cuda")
 profiler = jit_kernel.get_profiler()
 profiler.run_once()
 
@@ -53,7 +53,7 @@ def program(Q: T.Buffer((M, N), dtype)):
 if tid == 0:
     T.print(bx + by + bz)
 
-jit_kernel = tilelang.JITKernel(program, target="cuda")
+jit_kernel = tilelang.compile(program, target="cuda")
 profiler = jit_kernel.get_profiler()
 profiler.run_once()
 
@@ -72,7 +72,7 @@ def program(Q: T.Buffer((M, N), dtype)):
 for i, j in T.Parallel(M, N):
     T.print(register_buf[i, j])
 
-jit_kernel = tilelang.JITKernel(program, target="cuda")
+jit_kernel = tilelang.compile(program, target="cuda")
 profiler = jit_kernel.get_profiler()
 profiler.run_once()
 
@@ -91,7 +91,7 @@ def program(Q: T.Buffer((M, N), dtype)):
 if tid == 0:
     T.print(bx + by + bz, msg="hello world")
 
-jit_kernel = tilelang.JITKernel(program, target="cuda")
+jit_kernel = tilelang.compile(program, target="cuda")
 profiler = jit_kernel.get_profiler()
 profiler.run_once()
 
2 changes: 1 addition & 1 deletion testing/python/issue/test_tilelang_issue_96.py
@@ -42,7 +42,7 @@ def main(
 
 def run_gemm_pipeline_test(N, block_M=128, block_N=128, block_K=32):
     func = matmul(N, N, N, block_M, block_N, block_K)
-    jit_kernel = tilelang.JITKernel(func, out_idx=[2], target="cuda")
+    jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda")
 
     torch.manual_seed(0)
     a = torch.randn(N, N, device="cuda", dtype=torch.float16)
4 changes: 2 additions & 2 deletions testing/python/jit/test_tilelang_jit_callback.py
@@ -93,7 +93,7 @@ def tilelang_callback_cuda_postproc(code, _):
 code = f"// {stramp}\n" + code
 return code
 
-matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
+matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
 
 kernel_source = matmul_kernel.get_kernel_source()
 
@@ -196,7 +196,7 @@ def run_gemm_jit_kernel(
 num_threads,
 )
 
-matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
+matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
 
 A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
 B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
2 changes: 1 addition & 1 deletion testing/python/jit/test_tilelang_jit_gemm.py
@@ -206,7 +206,7 @@ def run_gemm_jit_kernel(
 num_threads,
 )
 
-matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="dlpack")
+matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="dlpack")
 
 A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
 B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
10 changes: 5 additions & 5 deletions testing/python/jit/test_tilelang_jit_gemm_ctypes.py
@@ -92,7 +92,7 @@ def tilelang_callback_cuda_postproc(code, _):
 code = f"// {stramp}\n" + code
 return code
 
-matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")
+matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes")
 
 kernel_source = matmul_kernel.get_kernel_source()
 
@@ -195,7 +195,7 @@ def run_gemm_jit_kernel(
 num_threads,
 )
 
-matmul_kernel = tilelang.JITKernel(program, out_idx=-1, execution_backend="ctypes")
+matmul_kernel = tilelang.compile(program, out_idx=-1, execution_backend="ctypes")
 
 A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
 B = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
@@ -263,7 +263,7 @@ def run_ctypes_kernel_do_bench(M,
 num_threads,
 )
 
-matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
+matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
 
 profiler = matmul_kernel.get_profiler()
 
@@ -312,7 +312,7 @@ def run_ctypes_kernel_multi_stream(M,
 num_threads,
 )
 
-matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
+matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
 
 tensor_a = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype)).cuda()
 tensor_b = torch.randn(K, N, dtype=torch.__getattribute__(in_dtype)).cuda()
@@ -364,7 +364,7 @@ def run_ctypes_dynamic_shape(M,
 num_threads,
 )
 
-matmul_kernel = tilelang.JITKernel(program, execution_backend="ctypes")
+matmul_kernel = tilelang.compile(program, execution_backend="ctypes")
 if isinstance(M, T.Var):
     M = 1024
 if isinstance(N, T.Var):
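The dynamic-shape test above compiles a kernel whose dimensions may be symbolic (`T.Var`) and substitutes concrete sizes before building test tensors. A small sketch of that substitution pattern; the `concretize` helper is ours for illustration, not part of the test suite:

```python
import torch
import tilelang.language as T

def concretize(dim, fallback=1024):
    # Symbolic dimensions (T.Var) carry no runtime value, so test inputs need
    # a concrete substitute size; plain ints pass through unchanged.
    return fallback if isinstance(dim, T.Var) else dim

# With static sizes this is a no-op; with T.Var dims it falls back to 1024,
# mirroring the `if isinstance(M, T.Var): M = 1024` lines in the test above.
M, N, K = (concretize(d) for d in (1024, 1024, 1024))
a = torch.randn(M, K, dtype=torch.float16, device="cuda")
b = torch.randn(K, N, dtype=torch.float16, device="cuda")
```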