Commit 9d38297

[CI] Refactor import paths in dequantization examples to use dequantize_utils (#914)
* Update requirements and refactor benchmark script for the deepseek_nsa example
  - Updated requirements.txt to pin the flash-linear-attention repository to a fixed commit.
  - Refactored import paths in benchmark_nsa_fwd.py for better organization.
  - Added a new function to generate configurations for autotuning (see the sketch below).
  - Modified the tilelang_sparse_attention function to accept parameters for block size, number of stages, and threads, improving flexibility.
  - Changed the shared-memory allocation for accumulators to improve performance.

* Refactor import paths in dequantization examples to use dequantize_utils
  - Updated import statements in multiple dequantization example scripts to replace references to the removed utils.py file with the new dequantize_utils module.
  - Ensured consistency across the example scripts for better organization and maintainability.
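The benchmark changes themselves are not shown in this diff view, so the following is only a rough sketch of what a configuration generator for autotuning over block size, number of stages, and thread count could look like. The function name get_configs, the dictionary keys, and every candidate value here are assumptions, not the actual code from benchmark_nsa_fwd.py.

import itertools

# Hypothetical sketch only: the real generator in benchmark_nsa_fwd.py is not
# part of this commit page. Names and candidate values are assumptions.
def get_configs():
    """Enumerate (block size, pipeline stages, threads) combinations for the autotuner."""
    block_sizes = [32, 64, 128]   # assumed tile sizes
    num_stages = [1, 2, 3]        # assumed software-pipelining depths
    num_threads = [128, 256]      # assumed threads per block
    return [
        {"block_size": b, "num_stages": s, "threads": t}
        for b, s, t in itertools.product(block_sizes, num_stages, num_threads)
    ]

print(len(get_configs()), "candidate configurations")  # 3 * 3 * 2 = 18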
1 parent 1b4cd38 commit 9d38297

File tree

5 files changed: +4 −4 lines changed

File renamed without changes.

examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from tvm import DataType
 from tvm import tir
 import torch
-from utils import torch_convert_bit_twiddling, torch_convert
+from dequantize_utils import torch_convert_bit_twiddling, torch_convert
 
 
 def get_configs():

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from tvm import DataType
 from tvm import tir
 import torch
-from utils import torch_convert_bit_twiddling, torch_convert
+from dequantize_utils import torch_convert_bit_twiddling, torch_convert
 
 
 def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from tvm import DataType
 from tvm import tir
 import torch
-from utils import torch_convert_bit_twiddling, torch_convert
+from dequantize_utils import torch_convert_bit_twiddling, torch_convert
 
 
 def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from tilelang import tvm as tvm
 from tvm import DataType
 import torch
-from utils import torch_convert_bit_twiddling, assert_similar
+from dequantize_utils import torch_convert_bit_twiddling, assert_similar
 from tilelang.autotuner import set_autotune_inputs
 
 

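As a quick way to confirm the rename locally, a minimal smoke test is sketched below. It assumes it is run from the examples/dequantize_gemm/ directory where dequantize_utils.py now lives, and it only checks that the helper names imported by the examples above resolve; it does not call them, since their signatures are not shown in this diff.

# Minimal smoke test (assumption: run from examples/dequantize_gemm/).
from dequantize_utils import (  # helpers referenced by the example scripts above
    torch_convert_bit_twiddling,
    torch_convert,
    assert_similar,
)

for helper in (torch_convert_bit_twiddling, torch_convert, assert_similar):
    print(f"{helper.__name__}: importable")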
0 commit comments