26 changes: 16 additions & 10 deletions tests/test_trtllm_cutlass_fused_moe.py
@@ -215,7 +215,7 @@ def compute_with_experts(
1,
]
HIDDEN_SIZES = [
- 128,
+ 256,
Collaborator:
Thanks for the catch. Could you please add another case where hidden_size is not divisible by 128, to trigger the padding logic?

Contributor Author:
Sure, will fix it.
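For illustration, a minimal sketch of what such an extra case might look like (the value 192 is a hypothetical choice, not taken from this PR; any hidden_size that is not a multiple of the 128-wide quantization block would exercise the padding path):

```python
HIDDEN_SIZES = [
    256,  # multiple of the 128-element block size, no padding needed
    192,  # hypothetical extra case: not a multiple of 128, triggers the padding logic
]
```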

Contributor Author:
@wenscarl, I believe the padding logic and the per-block quantization logic are also not correct.

]
NUM_EXPERTS = [2]
TOP_K_VALUES = [2]
@@ -884,6 +884,7 @@ def dequantize_block(
scales: torch.Tensor,
dtype: torch.dtype,
original_shape: tuple,
+ block_size_n: int = 128,
) -> torch.Tensor:
"""
Dequantize a block-quantized tensor.
@@ -893,6 +894,7 @@
scales: Block scaling factors
dtype: Target dtype for dequantization
original_shape: Original shape of the tensor before padding
+ block_size_n: Block size

Returns:
torch.Tensor: Dequantized tensor
@@ -904,18 +906,22 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor:
if dim != -1:
a = a.transpose(dim, -1)
# Broadcast and reshape
- a_broadcasted = a.unsqueeze(-1).expand(*a.shape, 128)
- a_reshaped = a_broadcasted.reshape(*a.shape[:-1], a.shape[-1] * 128)
+ a_broadcasted = a.unsqueeze(-1).expand(*a.shape, block_size_n)
+ a_reshaped = a_broadcasted.reshape(*a.shape[:-1], a.shape[-1] * block_size_n)
# Move back if needed
if dim != -1:
a_reshaped = a_reshaped.transpose(dim, -1)
return a_reshaped

if x_quant.dim() == 2: # For activation tensors [batch_size, hidden_size]
batch_size, hidden_size = x_quant.shape
- num_blocks = (hidden_size + 127) // 128
- scales = scales.view(batch_size, num_blocks, 1).expand(-1, -1, 128)
- scales = scales[:, :, : hidden_size % 128] if hidden_size % 128 != 0 else scales
+ num_blocks = ceil_div(hidden_size, block_size_n)
+ scales = (
+     scales.view(batch_size, num_blocks, 1)
+     .expand(-1, -1, block_size_n)
+     .reshape(batch_size, -1)
+ )
+ scales = scales[:, :hidden_size]
else: # For weight tensors [..., in_dim, out_dim]
*_dims, in_dim, out_dim = x_quant.shape

@@ -924,10 +930,10 @@ def transform_dim(a: torch.Tensor, dim: int = -1) -> torch.Tensor:
scales = transform_dim(scales, -2) # Second-to-last dim

# Handle padding
- if in_dim % 128 != 0:
-     scales = scales[..., : in_dim % 128, :]
- if out_dim % 128 != 0:
-     scales = scales[..., :, : out_dim % 128]
+ if in_dim % block_size_n != 0:
+     scales = scales[..., : in_dim % block_size_n, :]
+ if out_dim % block_size_n != 0:
+     scales = scales[..., :, : out_dim % block_size_n]
Comment on lines +933 to +936
Contributor:
Severity: high

The logic for handling padding for weight tensors appears to be incorrect. After transform_dim is applied, the dimensions of scales are padded up to be multiples of block_size_n. To truncate them back to the original dimensions, you should slice to in_dim and out_dim respectively, not in_dim % block_size_n and out_dim % block_size_n. The use of the modulo operator here is a bug.

This can be simplified to a single line that handles both dimensions, which is more concise and correct for all cases, including when dimensions are already multiples of block_size_n.

        scales = scales[..., :in_dim, :out_dim]
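To make the point concrete, a small standalone sketch (the shapes and the ceil_to_block helper are invented for illustration; scales stands in for the already-broadcast per-block scales):

```python
import torch

block_size_n = 128
in_dim, out_dim = 192, 256  # 192 is not a multiple of the block size


def ceil_to_block(n: int) -> int:
    # Round n up to the next multiple of block_size_n.
    return (n + block_size_n - 1) // block_size_n * block_size_n


# Stand-in for the broadcast scales after transform_dim: padded up to block multiples.
scales = torch.ones(ceil_to_block(in_dim), ceil_to_block(out_dim))  # [256, 256]

# The modulo-based slice from the diff keeps only the remainder (192 % 128 = 64 rows).
buggy = scales
if in_dim % block_size_n != 0:
    buggy = buggy[..., : in_dim % block_size_n, :]
if out_dim % block_size_n != 0:
    buggy = buggy[..., :, : out_dim % block_size_n]
print(buggy.shape)  # torch.Size([64, 256]) -- too few rows to match x_quant

# Slicing straight to the original dimensions is correct in every case.
fixed = scales[..., :in_dim, :out_dim]
print(fixed.shape)  # torch.Size([192, 256])
```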

Contributor Author (@rainj-me, Sep 18, 2025):
The in_dim and out_dim in the scales are per block.


x_dequant = x_quant.to(dtype) * scales.to(dtype)
return x_dequant.view(original_shape)
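For context, a rough usage sketch of the activation path of dequantize_block (tensor contents, the bfloat16 target dtype, and the 192-wide hidden size are illustrative stand-ins, not values from this PR):

```python
import torch

batch_size, hidden_size, block = 4, 192, 128
num_blocks = (hidden_size + block - 1) // block  # 2 scale blocks per row

# Stand-ins for an already-quantized activation and its per-block scales.
x_quant = torch.randn(batch_size, hidden_size)
scales = torch.rand(batch_size, num_blocks)

x = dequantize_block(
    x_quant, scales, torch.bfloat16, (batch_size, hidden_size), block_size_n=block
)
assert x.shape == (batch_size, hidden_size)
```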