
Commit 9962ec5

[misc] polish && add reference && apply review suggestions && format
1 parent de09434 commit 9962ec5

File tree: 8 files changed (+26, -40 lines)

benchmark/matmul/benchmark_matmul_sp.py

Lines changed: 3 additions & 1 deletion
@@ -288,6 +288,8 @@ def main(
     print(f"Best config: {best_config}")

     if args.bench_torch_sparse is not None:
-        print(f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}")
+        print(
+            f"Torch sparse ({args.bench_torch_sparse}) TFlops: {total_flops / torch_sparse_latency * 1e-9:.3f}"
+        )

     print(f"Reference Dense TFlops: {total_flops / ref_latency * 1e-9:.3f}")

src/tl_templates/cpp/half.hpp

Lines changed: 5 additions & 15 deletions
@@ -513,9 +513,7 @@ using std::true_type;
 template <typename T> struct is_float : std::is_floating_point<T> {};
 #else
 /// Conditional type.
-template <bool, typename T, typename> struct conditional {
-  typedef T type;
-};
+template <bool, typename T, typename> struct conditional { typedef T type; };
 template <typename T, typename F> struct conditional<false, T, F> {
   typedef F type;
 };
@@ -536,9 +534,7 @@ template <> struct is_float<long double> : true_type {};
 #endif

 /// Type traits for floating-point bits.
-template <typename T> struct bits {
-  typedef unsigned char type;
-};
+template <typename T> struct bits { typedef unsigned char type; };
 template <typename T> struct bits<const T> : bits<T> {};
 template <typename T> struct bits<volatile T> : bits<T> {};
 template <typename T> struct bits<const volatile T> : bits<T> {};
@@ -554,14 +550,10 @@ typedef std::uint_fast32_t uint32;
 typedef std::int_fast32_t int32;

 /// Unsigned integer of (at least) 32 bits width.
-template <> struct bits<float> {
-  typedef std::uint_least32_t type;
-};
+template <> struct bits<float> { typedef std::uint_least32_t type; };

 /// Unsigned integer of (at least) 64 bits width.
-template <> struct bits<double> {
-  typedef std::uint_least64_t type;
-};
+template <> struct bits<double> { typedef std::uint_least64_t type; };
 #else
 /// Unsigned integer of (at least) 16 bits width.
 typedef unsigned short uint16;
@@ -586,9 +578,7 @@ struct bits<double>
                   unsigned long, unsigned long long> {};
 #else
 /// Unsigned integer of (at least) 64 bits width.
-template <> struct bits<double> {
-  typedef unsigned long type;
-};
+template <> struct bits<double> { typedef unsigned long type; };
 #endif
 #endif

src/tl_templates/cuda/common.h

Lines changed: 1 addition & 3 deletions
@@ -136,9 +136,7 @@ TL_DEVICE unsigned int cast_smem_ptr_to_int(const void *const smem_ptr) {
   return smem_int;
 }

-template <typename T> struct normalize_atomic_type {
-  using type = T;
-};
+template <typename T> struct normalize_atomic_type { using type = T; };

 template <> /**
              * Map the public half_t alias to the native `half` type for atomic

src/tl_templates/cuda/gemm_sp_sm80.h

Lines changed: 4 additions & 6 deletions
@@ -28,6 +28,8 @@ template <typename Shape> struct ShapeCheck<uint8_t, Shape> {
       (Shape::kM % 16 == 0) && (Shape::kN % 16 == 0) && (Shape::kK % 64 == 0);
 };

+// ref:
+// https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h
 template <typename T> struct DispatchInstructionShape {
   static_assert(!std::is_same_v<T, T>,
                 "Unsupported type for DispatchInstructionShape");
@@ -119,13 +121,9 @@ template <> struct DispatchType<cutlass::bfloat16_t> {
   using Type = cutlass::bfloat16_t;
 };

-template <> struct DispatchType<unsigned char> {
-  using Type = uint8_t;
-};
+template <> struct DispatchType<unsigned char> { using Type = uint8_t; };

-template <> struct DispatchType<signed char> {
-  using Type = int8_t;
-};
+template <> struct DispatchType<signed char> { using Type = int8_t; };

 template <typename Shape, int num_warp_m, int num_warp_n, bool trans_A,
           bool trans_B, bool clear_accum, typename A_type_raw,

testing/python/tilelibrary/test_tilelang_tilelibrary_gemm_sp.py

Lines changed: 2 additions & 2 deletions
@@ -364,8 +364,8 @@ def test_gemm_sp_sm90():
     run_gemm_sp_sm90(512, 1024, 768, "int8", "int32", "int32", 64, 64, 64, 2, 128, False, True)


-# @tilelang.testing.requires_cuda
-# @tilelang.testing.requires_cuda_compute_version(8, 0)
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version(8, 0)
 def test_gemm_sp_sm80():
     run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 32, 32, 32, 0, 32)
     run_gemm_sp_sm80(512, 1024, 768, "float16", "float32", "float32", 64, 64, 64, 0, 32)
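Re-enabling these decorators means test_gemm_sp_sm80 is skipped unless a CUDA device with compute capability 8.0 or newer is present, rather than being commented out entirely. A rough, illustrative equivalent of that gate written directly against pytest and torch (the real tilelang.testing helpers may be implemented differently):

import pytest
import torch


def requires_cuda_compute_version(major: int, minor: int = 0):
    """Skip the decorated test unless a CUDA device with at least the
    given compute capability is available (illustrative sketch only)."""
    ok = torch.cuda.is_available()
    if ok:
        ok = torch.cuda.get_device_capability(0) >= (major, minor)
    return pytest.mark.skipif(
        not ok, reason=f"requires CUDA compute capability >= {major}.{minor}")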

tilelang/language/builtin.py

Lines changed: 1 addition & 1 deletion
@@ -355,4 +355,4 @@ def sync_grid():
 def cp_async_barrier_noinc(barrier_id: Union[int, PrimExpr, tir.Call]):
     """Perform a ptx async copy barrier using cp.async.mbarrier.arrive.noinc.
     """
-    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)
+    return tir.call_intrin("handle", tir.op.Op.get("tl.ptx_cp_async_barrier_noinc"), barrier_id)

tilelang/layout/gemm_sp.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 """Wrapping Layouts."""
 # pylint: disable=invalid-name, unsupported-binary-operation

-from tilelang.autotuner.capture import Optional
+from typing import Optional
 import tvm
 import tilelang.language as T
 import warnings
@@ -116,7 +116,7 @@ def _make_metadata_layout_sm8x_cutlass(buffer: tvm.tir.Buffer, mma_dtype: str):
     # https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/layout/matrix.h#L405
     # https://github.com/nvidia/cutlass/blob/ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e/include/cutlass/gemm/warp/mma_sparse_tensor_op.h#L172

-    if mma_dtype in ["float16, bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
+    if mma_dtype in ["float16", "bfloat16"] and buffer.dtype not in ["uint16", "int16"]:
         raise ValueError(f"metadata should be 16 bit, got {buffer.dtype}")

     if mma_dtype in ["float8", "int8", "uint8"] and buffer.dtype not in ["uint32", "int32"]:
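The quoting fix above is more than formatting: the old check tested membership in a one-element list holding the single string "float16, bfloat16", so neither dtype ever matched and the 16-bit metadata check never fired. A quick illustration:

# Old, buggy form: the list holds ONE string, so real dtypes never match.
"float16" in ["float16, bfloat16"]     # False
"bfloat16" in ["float16, bfloat16"]    # False

# Fixed form: two separate strings.
"float16" in ["float16", "bfloat16"]   # True
"bfloat16" in ["float16", "bfloat16"]  # True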

tilelang/utils/sparse.py

Lines changed: 8 additions & 10 deletions
@@ -62,17 +62,15 @@ def compress_sm80(A: torch.Tensor, transposed: bool) -> tuple[torch.Tensor, torc
     except ImportError as err:
         raise ImportError("SparseSemiStructuredTensor is not available in this version of PyTorch. "
                           "Please install a compatible version.") from err
-
     orig_val = SparseSemiStructuredTensor._FORCE_CUTLASS
-    SparseSemiStructuredTensor._FORCE_CUTLASS = True
-
-    if transposed is not False:
-        raise NotImplementedError("transposed flag is deprecated by pytorch")
-
-    compressed = to_sparse_semi_structured(A)
-    SparseSemiStructuredTensor._FORCE_CUTLASS = orig_val
-
-    return compressed.packed, compressed.meta
+    try:
+        SparseSemiStructuredTensor._FORCE_CUTLASS = True
+        if transposed is not False:
+            raise NotImplementedError("transposed flag is deprecated by pytorch")
+        compressed = to_sparse_semi_structured(A)
+        return compressed.packed, compressed.meta
+    finally:
+        SparseSemiStructuredTensor._FORCE_CUTLASS = orig_val


 def compress(A: torch.Tensor,
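The restructuring in compress_sm80 guarantees that _FORCE_CUTLASS is restored even when to_sparse_semi_structured raises or the transposed check trips the NotImplementedError. The same save/override/restore shape in isolation, using a hypothetical _Flags holder in place of PyTorch's class attribute:

class _Flags:
    # Stand-in for SparseSemiStructuredTensor._FORCE_CUTLASS.
    FORCE_CUTLASS = False


def with_forced_cutlass(fn):
    """Run fn with FORCE_CUTLASS temporarily set, restoring the previous
    value even if fn raises (mirrors the compress_sm80 change)."""
    orig_val = _Flags.FORCE_CUTLASS
    try:
        _Flags.FORCE_CUTLASS = True
        return fn()
    finally:
        _Flags.FORCE_CUTLASS = orig_val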
