Skip to content

Commit fc6619d

Browse files
committed
[Lint]
1 parent 9e98307 commit fc6619d

File tree

3 files changed

+4
-5
lines changed

3 files changed

+4
-5
lines changed

examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,8 @@ def main(
201201
202202
Side effects:
203203
Writes results into C. Calls external device decode functions to expand B from its packed representation into shared memory before computation.
204-
"""
205-
with T.Kernel(
204+
"""
205+
with T.Kernel(
206206
T.ceildiv(N, block_N),
207207
T.ceildiv(M, block_M),
208208
threads=threads,

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def matmul(M,
9696
num_stages=2,
9797
threads=256,
9898
split=1):
99-
10099
"""
101100
Construct and return a tiled matrix-multiply TIR kernel that multiplies A (shape MxK) by a quantized B (shape Nx(QK)) and writes an MxN output in out_dtype.
102101

examples/dequantize_gemm/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33

44
def torch_convert_bit_twiddling(tensor):
5-
65
"""
76
Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.
87
@@ -17,6 +16,7 @@ def torch_convert_bit_twiddling(tensor):
1716
Raises:
1817
AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`.
1918
"""
19+
2020
def _convert(val0, val1, pos) -> torch.bfloat16:
2121
assert val0.dtype == torch.uint8
2222
assert val1.dtype == torch.uint8
@@ -51,7 +51,6 @@ def _convert(val0, val1, pos) -> torch.bfloat16:
5151

5252

5353
def torch_convert(tensor, scale_size=None, Scale=None):
54-
5554
"""
5655
Decode a 2D uint8 tensor into a 2D bfloat16 tensor by expanding each byte into two bf16 values using a 4-bit (nibble) encoding.
5756
@@ -65,6 +64,7 @@ def torch_convert(tensor, scale_size=None, Scale=None):
6564
Returns:
6665
torch.Tensor: A new tensor of shape (N, K*2) and dtype torch.bfloat16 containing the decoded bf16 values.
6766
"""
67+
6868
def _convert(val, pos, scale=None):
6969
assert val.dtype == torch.uint8
7070
# val = val.view(torch.int8)

0 commit comments

Comments (0)