Revert "pin nightly to 2.5.0.dev20240709+cu121 (#505)"
This reverts commit cc871c5.
jerryzh168 committed Jul 16, 2024
1 parent 6e7cf71 commit f03b194
Showing 3 changed files with 32 additions and 26 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/regression_test.yml
@@ -33,7 +33,7 @@ jobs:
gpu-arch-version: "12.1"
- name: CUDA Nightly
runs-on: linux.g5.12xlarge.nvidia.gpu
- torch-spec: '--pre torch==2.5.0.dev20240709+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121'
+ torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"
- name: CPU 2.2.2
@@ -48,7 +48,7 @@ jobs:
gpu-arch-version: ""
- name: CPU Nightly
runs-on: linux.4xlarge
- torch-spec: '--pre torch==2.5.0.dev20240709+cpu --index-url https://download.pytorch.org/whl/nightly/cpu'
+ torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cpu'
gpu-arch-type: "cpu"
gpu-arch-version: ""

4 changes: 4 additions & 0 deletions test/integration/test_integration.py
@@ -631,6 +631,7 @@ def test_dequantize_int8_weight_only_quant_subclass(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+ @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_dequantize_int4_weight_only_quant_subclass(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest("Currently only supports bfloat16.")
@@ -641,6 +642,7 @@ def test_dequantize_int4_weight_only_quant_subclass(device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+ @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_dequantize_int4_weight_only_quant_subclass_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest("Currently only supports bfloat16.")
@@ -821,6 +823,7 @@ def test_int8_weight_only_quant_with_freeze(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+ @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_api(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
@@ -835,6 +838,7 @@ def test_int4_weight_only_quant_subclass_api(self, device, dtype):

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "int4 requires torch nightly.")
+ @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 skipping 2.5+ for now")
def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
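All four skips added above key off the same `TORCH_VERSION_AFTER_2_5` flag imported from torchao's utilities. A minimal sketch of how such a flag is typically derived (the real helper may differ in detail; the `packaging` dependency is an assumption here):

import torch
from packaging.version import parse

# A 2.5 nightly such as "2.5.0.dev20240709+cu121" is a pre-release of 2.5.0,
# so comparing against "2.5.0.dev0" makes the flag True on 2.5 nightlies too.
TORCH_VERSION_AFTER_2_5 = parse(torch.__version__) >= parse("2.5.0.dev0")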
50 changes: 26 additions & 24 deletions test/test_ops.py
@@ -95,6 +95,7 @@ def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK):
TEST_CONFIGS_DEQUANT = list(itertools.product(SHAPES, INNERKTILES, QGROUP_SIZES))

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK, ids=str)
def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):
N, K = shape
@@ -107,14 +108,15 @@ def test_unpack_tensor_core_tiled_layout_correctness(shape, inner_k_tiles):

# TODO: Fix "test_aot_dispatch_dynamic" test failure
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+ @pytest.mark.skipif(TORCH_VERSION_AFTER_2_5, reason="weight packing is updated in 2.5+")
@pytest.mark.parametrize("shape, inner_k_tiles", TEST_CONFIGS_UNPACK , ids=str)
def test_unpack_tensor_core_tiled_layout_op(shape, inner_k_tiles):
test_utils = [
"test_schema",
"test_autograd_registration",
"test_faketensor",
]

# TODO: Figure out why test fails unless torch >= 2.5
if TORCH_VERSION_AFTER_2_5:
test_utils.append("test_aot_dispatch_dynamic")
@@ -137,10 +139,10 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
assert scales.shape == zeros.shape

midpoint = 2 ** (nbits - 1)

    # Convert from u4 -> s4 and upcast to bfloat16
q = q.sub(midpoint).to(dtype)

# Dequantize
q = q.reshape(-1, group_size)
dq = q * scales.reshape(-1, 1) + zeros.reshape(-1, 1)
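The midpoint subtraction above is the whole u4 -> s4 trick: unsigned 4-bit codes 0..15 become signed values -8..7 before the affine step. A standalone toy check of the arithmetic (illustrative values only):

import torch

midpoint = 2 ** (4 - 1)                 # 8 for 4-bit codes
q = torch.tensor([0, 7, 8, 15])
s4 = (q - midpoint).to(torch.bfloat16)  # tensor([-8., -1., 0., 7.])
scale, zero = 0.5, 1.0                  # one group's affine parameters
dq = s4 * scale + zero                  # tensor([-3.0, 0.5, 1.0, 4.5])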
@@ -152,18 +154,18 @@ def dequant_ref(q, scales, zeros, group_size, nbits=4, dtype=torch.bfloat16):
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, inner_k_tiles, group_size):
n, k = shape
    dtype = torch.bfloat16

device = "cuda"

t = torch.randn(n, k, dtype=dtype, device=device)
scales, zeros = get_groupwise_affine_qparams(t, n_bit=4, groupsize=group_size, dtype=dtype)

# Quantize
q = groupwise_affine_quantize_tensor_from_qparams(
t, scales, zeros, n_bit=4, groupsize=group_size
)

# Pack to tensor core layout
packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)
@@ -174,7 +176,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in
dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
q, scales, zeros, n_bit=4, groupsize=group_size
)

# Dequantize by passing in an identity matrix as the activation
a_eye = torch.eye(k, device=device, dtype=dtype)
dq_id = torch.ops.aten._weight_int4pack_mm(
@@ -183,23 +185,23 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in
group_size,
scales_and_zeros,
).t()

# Actual operation to test
dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

# Compare results
diff_ao_id = (dq_id - dq_ao).abs().max()
diff_op_id = (dq_op - dq_id).abs().max()
diff_op_ao = (dq_op - dq_ao).abs().max()

    # There are slight numerical differences when dequantizing with an identity matrix compared to `groupwise_affine_dequantize`.
    # Since the `dequantize_tensor_core_tiled_layout` kernel relies on the same underlying bit-twiddling tricks for fast
    # conversion from u4 -> s4 -> bf16, the identity-matrix dequant hack and `dequantize_tensor_core_tiled_layout` are
    # expected to give the same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.

    # Test that the `dequant` kernel gives the same results as identity matrix-based dequant
assert diff_op_id == 0

# Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
assert diff_op_ao == diff_ao_id
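The algebra behind the identity trick: `_weight_int4pack_mm(a, packed, ...)` computes `a @ dequant(packed).t()`, so passing `a = I_k` returns `dequant(packed).t()` unchanged, and the final `.t()` recovers the `[n, k]` dequantized weight. A shape-level restatement with a plain matmul standing in for the int4 kernel (illustrative only):

import torch

k, n = 4, 3
w_dq = torch.randn(n, k)       # stands in for dequant(packed)
out = torch.eye(k) @ w_dq.t()  # identity activation returns w_dq.t() unchanged
assert torch.equal(out.t(), w_dq)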

@@ -210,7 +212,7 @@ def test_dequantize_tensor_core_tiled_layout_correctness_quant_dequant(shape, in
@pytest.mark.parametrize("shape, inner_k_tiles, group_size", TEST_CONFIGS_DEQUANT, ids=str)
def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shape, inner_k_tiles, group_size):
n, k = shape
    dtype = torch.bfloat16
device = "cuda"

# Quantize and pack
@@ -222,13 +224,13 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap

packed = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

# Unpack and dequantize
unpacked = torchao.ops.unpack_tensor_core_tiled_layout(packed, inner_k_tiles)
dq_ao = groupwise_affine_dequantize_tensor_from_qparams(
unpacked, scales, zeros, n_bit=4, groupsize=group_size
)

# Dequantize by passing in an identity matrix as the activation
a_eye = torch.eye(k, device=device, dtype=dtype)
dq_id = torch.ops.aten._weight_int4pack_mm(
@@ -237,23 +239,23 @@ def test_dequantize_tensor_core_tiled_layout_correctness_unpack_and_dequant(shap
group_size,
scales_and_zeros,
).t()

# Actual operation to test
dq_op = torchao.ops.dequantize_tensor_core_tiled_layout(packed, scales_and_zeros, group_size, inner_k_tiles)

# Compare results
diff_ao_id = (dq_id - dq_ao).abs().max()
diff_op_id = (dq_op - dq_id).abs().max()
diff_op_ao = (dq_op - dq_ao).abs().max()

    # There are slight numerical differences when dequantizing with an identity matrix compared to `groupwise_affine_dequantize`.
    # Since the `dequantize_tensor_core_tiled_layout` kernel relies on the same underlying bit-twiddling tricks for fast
    # conversion from u4 -> s4 -> bf16, the identity-matrix dequant hack and `dequantize_tensor_core_tiled_layout` are
    # expected to give the same results, while both will have similar numerical differences to `groupwise_affine_dequantize`.

    # Test that the `dequant` kernel gives the same results as identity matrix-based dequant
assert diff_op_id == 0

# Test that the `dequant` kernel gives same numerical diffs as the `groupwise_affine_dequantize` when compared against the identity matrix
assert diff_op_ao == diff_ao_id

@@ -271,7 +273,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
scales = torch.randn(n, q_groups, dtype=torch.bfloat16, device=device)
zeros = torch.randn_like(scales)
scales_and_zeros = pack_tinygemm_scales_and_zeros(scales, zeros)

test_utils = [
"test_schema",
"test_autograd_registration",
@@ -287,4 +289,4 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
)

if __name__ == "__main__":
    run_tests()
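The `test_utils` entries above name the individual checks that PyTorch's opcheck harness can run against a custom op. A minimal sketch of that pattern (`torch.library.opcheck` is the public entry point in recent PyTorch releases; the exact call elided in this diff may differ, and the op and inputs below are placeholders):

import torch

def run_opcheck(op, example_inputs, test_utils):
    # Validates the registered custom op: schema matching, autograd
    # registration, fake-tensor propagation, and optionally AOT dispatch.
    torch.library.opcheck(op, example_inputs, test_utils=test_utils)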
