diff --git a/test/test_ops.py b/test/test_ops.py index 1420e1df8..26f1b8414 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -271,6 +271,8 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size device = "cuda" q = torch.randint(0, 16, shape, dtype=torch.int, device=device) + if TORCH_VERSION_AFTER_2_5: + q = (q[::, ::2] << 4 | q[::, 1::2]).to(torch.uint8) packed_w = torch._convert_weight_to_int4pack(q, inner_k_tiles) q_groups = k // group_size scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)