Skip to content

Commit

Permalink
Fixes device-tma kernel (#37)
Browse files: browse the repository at this point in the history
  • Loading branch information
drisspg authored Oct 8, 2024
1 parent b4a789c commit 34e51d0
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 35 deletions.
2 changes: 1 addition & 1 deletion benchmarks/fp8_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def get_configs_varying_k(M: int = 8192, N: int = 8192) -> List[ExperimentConfig
FP8Kernel.SCALED_MM,
# FP8Kernel.PERSISTENT,
FP8Kernel.PERSISTENT_TMA,
# FP8Kernel.DEVICE_TMA,
FP8Kernel.DEVICE_TMA,
]

for (M, K, N), strategy, compile, kernel in itertools.product(
Expand Down
34 changes: 0 additions & 34 deletions transformer_nuggets/fp8/fp8_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,6 @@ def matmul_tma_persistent(
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_device_tma_persistent(
workspace_ptr,
tiles_per_update: tl.constexpr,
a_ptr,
a_scale_ptr,
b_ptr,
Expand Down Expand Up @@ -532,7 +531,6 @@ def matmul_kernel_device_tma_persistent(

tile_id = start_pid - NUM_SMS
ki = -1
ni = -1

pid_m = 0
pid_n = 0
Expand All @@ -547,36 +545,6 @@ def matmul_kernel_device_tma_persistent(
for _ in range(0, k_tiles * tiles_per_SM):
ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
if ki == 0:
ni += 1

# Simulate a grouped gemm
if ni == tiles_per_update:
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=a_desc_ptr,
global_address=a_ptr,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_K],
global_size=[M, K],
element_ty=a_ptr.dtype.element_ty,
)
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=b_desc_ptr,
global_address=b_ptr,
load_size=[BLOCK_SIZE_N, BLOCK_SIZE_K],
global_size=[N, K],
element_ty=b_ptr.dtype.element_ty,
)
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[M, N],
element_ty=c_ptr.dtype.element_ty,
)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
ni = 0

tile_id += NUM_SMS
group_id = tile_id // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
Expand Down Expand Up @@ -627,7 +595,6 @@ def matmul_device_tma_persistent(
b: torch.Tensor,
b_scale: torch.Tensor,
output_dtype: torch.dtype,
tiles_per_update: int = 1,
) -> torch.Tensor:
assert is_row_major(a.stride()), "a must be row major"
assert is_col_major(b.stride()), "b must be col major"
Expand All @@ -649,7 +616,6 @@ def matmul_device_tma_persistent(
)
matmul_kernel_device_tma_persistent[grid](
workspace,
tiles_per_update,
a,
a_scale,
b,
Expand Down

0 comments on commit 34e51d0

Please sign in to comment.