From 5b3f2bbfee68ac016a06fb0b8324026443d1efda Mon Sep 17 00:00:00 2001 From: Annanya Date: Tue, 18 Mar 2025 03:30:48 -0400 Subject: [PATCH 1/7] [Flashinfer] Added jit flow for sampling kernel --- python/tvm/relax/backend/cuda/flashinfer.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py index 8aa4817a302d..545923a76b91 100644 --- a/python/tvm/relax/backend/cuda/flashinfer.py +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -415,3 +415,18 @@ def gen_flashinfer_mla_module( object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) modules = _load_flashinfer_modules(object_files) return modules + +def gen_sampling_module(target: Target, num_threads: int = 8): + try: + from flashinfer.jit import ( # pylint: disable=import-outside-toplevel + gen_sampling_tvm_binding, + ) + except ImportError: + raise ImportError( + "FlashInfer is not installed. Please follow instructions " + "in https://docs.flashinfer.ai to install FlashInfer." + ) + uri, source_paths = gen_sampling_tvm_binding() + object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) + modules = _load_flashinfer_modules(object_files) + return modules From e2862b7cfff9b3371764d1e47b3a3eb22d3271fa Mon Sep 17 00:00:00 2001 From: Annanya Date: Fri, 21 Mar 2025 20:16:14 -0400 Subject: [PATCH 2/7] Small fixes --- python/tvm/relax/backend/cuda/flashinfer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py index 545923a76b91..e5e36142c1c1 100644 --- a/python/tvm/relax/backend/cuda/flashinfer.py +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -426,7 +426,7 @@ def gen_sampling_module(target: Target, num_threads: int = 8): "FlashInfer is not installed. Please follow instructions " "in https://docs.flashinfer.ai to install FlashInfer." ) - uri, source_paths = gen_sampling_tvm_binding() + uri, source_paths = gen_sampling_tvm_binding(uri="sampling") object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads) modules = _load_flashinfer_modules(object_files) return modules From ac858a8da99526d71d8a0ebcb26fd08736e2b671 Mon Sep 17 00:00:00 2001 From: Annanya Date: Tue, 1 Apr 2025 02:51:56 -0400 Subject: [PATCH 3/7] Added test --- .../relax/test_runtime_sampling_flashinfer.py | 103 ++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 tests/python/relax/test_runtime_sampling_flashinfer.py diff --git a/tests/python/relax/test_runtime_sampling_flashinfer.py b/tests/python/relax/test_runtime_sampling_flashinfer.py new file mode 100644 index 000000000000..b9e9f9fc3a28 --- /dev/null +++ b/tests/python/relax/test_runtime_sampling_flashinfer.py @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import random +import numpy as np +import tvm +import tvm.testing +import pytest +from tvm import relax +from tvm.contrib import utils +from typing import List + + + + +@pytest.mark.skip(reason="Requires FlashInfer enabled and proper setup") +def test_sampling(): + + def load_module(name: str, static_modules: List[tvm.runtime.Module]): + assert len(static_modules) > 0 + if len(static_modules) == 1: + return static_modules[0] + static_mod = static_modules[0] + for mod in static_modules[1:]: + static_mod.import_module(mod) + temp = utils.tempdir() + mod_path = temp.relpath(f"{name}.so") + static_mod.export_library(mod_path) + return tvm.runtime.load_module(mod_path) + + # Test configuration + batch_size = 10 + vocab_size = 5 + num_iterations = 1000 + tol_atol = 0.02 + tol_rtol = 0.05 # relative tolerance + + # Probability tensor (each row sums to 1) + probs_np = np.array([[0.1, 0.2, 0.3, 0.2, 0.2] for _ in range(batch_size)], dtype="float32") + + dev = tvm.cuda(0) + probs_tvm = tvm.nd.array(probs_np, device=dev) + output_tvm = tvm.nd.empty((batch_size,), "int32", device=dev) + + device = tvm.cuda() + target = tvm.target.Target.from_device(device) + sampling_mod = load_module( + "flashinfer_sampling", + relax.backend.cuda.flashinfer.gen_sampling_module( + target=target, + ), + ) + sampling_func = sampling_mod["sampling_from_probs"] + + counts = np.zeros((batch_size, vocab_size), dtype="int32") + + for _ in range(num_iterations): + deterministic = False + # Generate seed and a random offset. + philox_seed = np.uint64(random.getrandbits(63)) + philox_offset = np.uint64(random.getrandbits(63) % 1000) + + # the kernel expects (probs, output, maybe_indices, deterministic, philox_seed, philox_offset, cuda_stream) + sampling_func(probs_tvm, output_tvm, None, deterministic, + philox_seed, philox_offset, 0) + + out = output_tvm.asnumpy() + for i in range(batch_size): + sampled_token = out[i] + counts[i, sampled_token] += 1 + + # Convert counts to frequencies. + frequencies = counts / float(num_iterations) + + # For each row, check that the empirical frequency is close to the input probability. + for row in range(batch_size): + tvm.testing.assert_allclose( + frequencies[row], + probs_np[row], + rtol=tol_rtol, + atol=tol_atol + ) + +if __name__ == "__main__": + # Run the test standalone (if not using pytest) + test_sampling() + + From b1615e369aad3c3d3f4765627adbdfd34466b42a Mon Sep 17 00:00:00 2001 From: Annanya Date: Tue, 1 Apr 2025 15:05:18 -0400 Subject: [PATCH 4/7] Some small change --- tests/python/relax/test_runtime_sampling_flashinfer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_runtime_sampling_flashinfer.py b/tests/python/relax/test_runtime_sampling_flashinfer.py index b9e9f9fc3a28..5ca020b7bd6f 100644 --- a/tests/python/relax/test_runtime_sampling_flashinfer.py +++ b/tests/python/relax/test_runtime_sampling_flashinfer.py @@ -54,7 +54,7 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]): probs_np = np.array([[0.1, 0.2, 0.3, 0.2, 0.2] for _ in range(batch_size)], dtype="float32") dev = tvm.cuda(0) - probs_tvm = tvm.nd.array(probs_np, device=dev) + prob_tvm = tvm.nd.array(probs_np, device=dev) output_tvm = tvm.nd.empty((batch_size,), "int32", device=dev) device = tvm.cuda() @@ -76,7 +76,7 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]): philox_offset = np.uint64(random.getrandbits(63) % 1000) # the kernel expects (probs, output, maybe_indices, deterministic, philox_seed, philox_offset, cuda_stream) - sampling_func(probs_tvm, output_tvm, None, deterministic, + sampling_func(prob_tvm, output_tvm, None, deterministic, philox_seed, philox_offset, 0) out = output_tvm.asnumpy() From 32334aa6342432247745a3dc85acdae9a9a55fec Mon Sep 17 00:00:00 2001 From: Annanya Date: Tue, 1 Apr 2025 15:29:07 -0400 Subject: [PATCH 5/7] Some small change --- python/tvm/relax/backend/cuda/flashinfer.py | 1 + .../relax/test_runtime_sampling_flashinfer.py | 27 +++++++------------ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py index e5e36142c1c1..865ad5122700 100644 --- a/python/tvm/relax/backend/cuda/flashinfer.py +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -416,6 +416,7 @@ def gen_flashinfer_mla_module( modules = _load_flashinfer_modules(object_files) return modules + def gen_sampling_module(target: Target, num_threads: int = 8): try: from flashinfer.jit import ( # pylint: disable=import-outside-toplevel diff --git a/tests/python/relax/test_runtime_sampling_flashinfer.py b/tests/python/relax/test_runtime_sampling_flashinfer.py index 5ca020b7bd6f..608f59eb1567 100644 --- a/tests/python/relax/test_runtime_sampling_flashinfer.py +++ b/tests/python/relax/test_runtime_sampling_flashinfer.py @@ -26,8 +26,6 @@ from typing import List - - @pytest.mark.skip(reason="Requires FlashInfer enabled and proper setup") def test_sampling(): @@ -46,17 +44,17 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]): # Test configuration batch_size = 10 vocab_size = 5 - num_iterations = 1000 + num_iterations = 1000 tol_atol = 0.02 tol_rtol = 0.05 # relative tolerance # Probability tensor (each row sums to 1) probs_np = np.array([[0.1, 0.2, 0.3, 0.2, 0.2] for _ in range(batch_size)], dtype="float32") - + dev = tvm.cuda(0) prob_tvm = tvm.nd.array(probs_np, device=dev) output_tvm = tvm.nd.empty((batch_size,), "int32", device=dev) - + device = tvm.cuda() target = tvm.target.Target.from_device(device) sampling_mod = load_module( @@ -68,7 +66,7 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]): sampling_func = sampling_mod["sampling_from_probs"] counts = np.zeros((batch_size, vocab_size), dtype="int32") - + for _ in range(num_iterations): deterministic = False # Generate seed and a random offset. @@ -76,9 +74,8 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]): philox_offset = np.uint64(random.getrandbits(63) % 1000) # the kernel expects (probs, output, maybe_indices, deterministic, philox_seed, philox_offset, cuda_stream) - sampling_func(prob_tvm, output_tvm, None, deterministic, - philox_seed, philox_offset, 0) - + sampling_func(prob_tvm, output_tvm, None, deterministic, philox_seed, philox_offset, 0) + out = output_tvm.asnumpy() for i in range(batch_size): sampled_token = out[i] @@ -86,18 +83,12 @@ def load_module(name: str, static_modules: List[tvm.runtime.Module]): # Convert counts to frequencies. frequencies = counts / float(num_iterations) - + # For each row, check that the empirical frequency is close to the input probability. for row in range(batch_size): - tvm.testing.assert_allclose( - frequencies[row], - probs_np[row], - rtol=tol_rtol, - atol=tol_atol - ) + tvm.testing.assert_allclose(frequencies[row], probs_np[row], rtol=tol_rtol, atol=tol_atol) + if __name__ == "__main__": # Run the test standalone (if not using pytest) test_sampling() - - From 14d20a401dc08330bd6c459e8e98b9c7385d076c Mon Sep 17 00:00:00 2001 From: Annanya Date: Tue, 1 Apr 2025 17:42:21 -0400 Subject: [PATCH 6/7] Some small change --- tests/python/relax/test_runtime_sampling_flashinfer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relax/test_runtime_sampling_flashinfer.py b/tests/python/relax/test_runtime_sampling_flashinfer.py index 608f59eb1567..7b6deb6f0292 100644 --- a/tests/python/relax/test_runtime_sampling_flashinfer.py +++ b/tests/python/relax/test_runtime_sampling_flashinfer.py @@ -28,7 +28,6 @@ @pytest.mark.skip(reason="Requires FlashInfer enabled and proper setup") def test_sampling(): - def load_module(name: str, static_modules: List[tvm.runtime.Module]): assert len(static_modules) > 0 if len(static_modules) == 1: From df5a628c070936320e1d919788a111d1f06079df Mon Sep 17 00:00:00 2001 From: Annanya Date: Tue, 1 Apr 2025 18:30:37 -0400 Subject: [PATCH 7/7] Added doctstring --- python/tvm/relax/backend/cuda/flashinfer.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/tvm/relax/backend/cuda/flashinfer.py b/python/tvm/relax/backend/cuda/flashinfer.py index 865ad5122700..687987e4d66b 100644 --- a/python/tvm/relax/backend/cuda/flashinfer.py +++ b/python/tvm/relax/backend/cuda/flashinfer.py @@ -418,6 +418,21 @@ def gen_flashinfer_mla_module( def gen_sampling_module(target: Target, num_threads: int = 8): + """ + Generate a FlashInfer module for sampling kernels. + + Parameters + ---------- + target : Target + The target device for which the module will be compiled. + num_threads : int, optional + The number of threads to use during compilation (default is 8). + + Returns + ------- + List[tvm.runtime.Module] + A list of compiled static library modules for the FlashInfer sampling kernels. + """ try: from flashinfer.jit import ( # pylint: disable=import-outside-toplevel gen_sampling_tvm_binding,