Revert "Refactor quant_llm to work with affine quantized tensor (#696)"
This reverts commit 0fed444.
jerryzh168 committed Aug 28, 2024
1 parent 0fed444 commit 2add1b5
Showing 17 changed files with 400 additions and 600 deletions.
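
For orientation, a minimal sketch (not part of the commit) of the user-facing API this revert restores, mirroring the restored test code further down. After the revert, FPx/FP6 weight-only quantization is configured through torchao.prototype.quant_llm again rather than the fpx_weight_only config that lived under torchao.quantization:

import torch
from torchao.prototype.quant_llm import quant_llm_fpx_weight_only
from torchao.quantization.quant_api import quantize_

# a plain fp16 linear layer, quantized in place to FPx with ebits=3, mbits=2 (FP6 e3m2)
linear = torch.nn.Linear(64, 256, device="cuda", dtype=torch.half)
quantize_(linear, quant_llm_fpx_weight_only(3, 2))
# the FP6-specific helper restored alongside it, fp6_llm_weight_only(), is applied the same way
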
10 changes: 5 additions & 5 deletions benchmarks/benchmark_fp6.py → benchmarks/benchmark_fp6_llm.py
@@ -1,16 +1,16 @@
import torch
import pandas as pd
import torch.nn.functional as F
# from torchao.prototype.quant_llm import QuantLlmLinearWeight
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType
from torchao.prototype.quant_llm import QuantLlmLinearWeight
from torchao.utils import benchmark_torch_function_in_microseconds
from tqdm import tqdm


def benchmark(m: int, k: int, n: int):
float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
fp6_weight = to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2))
fp6_data = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
scale = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
fp6_weight = QuantLlmLinearWeight(fp6_data, scale, 3, 2)

fp16_weight = fp6_weight.dequantize(torch.half)

fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
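
For comparison, a short sketch (not part of the diff) of the construction this file swaps back in: the reverted benchmark quantized a float tensor through the affine-quantized path, to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2)), whereas the restored benchmark builds a QuantLlmLinearWeight from pre-packed uint8 data plus a per-row scale, or from a float weight via from_float as the restored tests do:

import torch
from torchao.prototype.quant_llm import QuantLlmLinearWeight

# assumes a CUDA device, as in the benchmark above
fp16_weight = torch.randn(256, 64, dtype=torch.half, device="cuda")
fp6_weight = QuantLlmLinearWeight.from_float(fp16_weight, 3, 2)  # ebits=3, mbits=2
fp16_again = fp6_weight.dequantize(torch.half)                   # back to an fp16 tensor
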
7 changes: 2 additions & 5 deletions scripts/hf_eval.py
@@ -20,7 +20,6 @@
int8_dynamic_activation_int8_weight,
quantize_,
autoquant,
fpx_weight_only,
)
from torchao.sparsity import (
sparsify_,
@@ -60,8 +59,6 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars
elif quantization == "int4wo":
# note cannot quantize this model on cpu and run it on cuda at this time
quantize_(model.to(device=device), int4_weight_only())
elif quantization == "fp6":
quantize_(model, fpx_weight_only(3, 2))
elif quantization == "autoquant":
model = autoquant(model.to(device=device))

@@ -82,7 +79,7 @@ def all_linear(mod, name):
return False
torch.sparse.semi_structured._FORCE_CUTLASS = False
sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)

if sparsity and compile:
model = torch.compile(model, mode="max-autotune", fullgraph=True)

@@ -114,7 +111,7 @@ def all_linear(mod, name):
parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "fp6", "None"], help='Which quantization technique to apply')
parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--save', action='store_true', help='Whether to save the model.')
75 changes: 46 additions & 29 deletions test/dtypes/test_fpx.py → test/prototype/test_quant_llm.py
@@ -8,27 +8,23 @@
parametrize,
run_tests,
)
from torchao.dtypes.fpx import (
FpxTensorCoreAQTLayout,
FpxTensorCoreLayoutType,
from torchao.prototype.quant_llm import (
QuantLlmLinearWeight,
quant_llm_fpx_weight_only,
fp6_llm_weight_only,
to_scaled_tc_fpx,
from_scaled_tc_fpx,
)
from torchao.dtypes.fpx.fpx import _pack_tc_fpx, _pack_tc_fp6
from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6
from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
from torchao.quantization import (
quantize_,
fpx_weight_only,
)

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from torchao.quantization.quant_api import quantize_


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
_FPx_DTYPES = [(3, 2), (2, 2)]


class TestFpxTensorCoreAQTLayout(TestCase):
class TestQuantLlmLinearWeight(TestCase):
@parametrize("device", _DEVICES)
def test_pack_tc_fp6_correctness(self, device):
x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device)
@@ -73,40 +69,61 @@ def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
def test_to_copy_device(self, ebits, mbits):
from torchao.quantization.quant_primitives import (
choose_qparams_affine_fpx,
quantize_affine_fpx,
)

x = torch.randn(256, 64)
scale = choose_qparams_affine_fpx(x, ebits, mbits)
x = quantize_affine_fpx(x, scale, ebits, mbits)
layout_type = FpxTensorCoreLayoutType(ebits, mbits)
fpx_layout_tensor = FpxTensorCoreAQTLayout.from_plain(x, scale, None, layout_type).cuda()
assert fpx_layout_tensor.device.type == "cuda"
fpx_layout_tensor = fpx_layout_tensor.cpu()
assert fpx_layout_tensor.device.type == "cpu"
fpx = QuantLlmLinearWeight.from_float(x, ebits, mbits).cuda()
assert fpx.device.type == "cuda"
fpx = fpx.cpu()
assert fpx.device.type == "cpu"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("leading_dims", [(4,), (2, 4)])
@parametrize("bias", [False, True])
def test_quant_llm_linear_weight(self, ebits, mbits, bias, leading_dims):
OC, IC = 256, 64
device = "cuda"

fp16_weight = torch.randn(OC, IC, device=device, dtype=torch.half)
fp16_bias = torch.randn(OC, device=device, dtype=torch.half) if bias else None

fpx_weight = QuantLlmLinearWeight.from_float(fp16_weight, ebits, mbits)

x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half)
out = torch.nn.functional.linear(x, fpx_weight, fp16_bias)
assert out.shape == leading_dims + (OC,)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="quantization only works with torch.compile for 2.5+")
@parametrize("ebits,mbits", _FPx_DTYPES)
@parametrize("bias", [False, True])
def test_fpx_weight_only(self, ebits, mbits, bias):
def test_quant_llm_quantize(self, ebits, mbits, bias):
N, OC, IC = 4, 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, bias=bias, device=device)
fpx_linear = copy.deepcopy(linear)
quantize_(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits))

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fpx_linear(x)
actual = torch.compile(fpx_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_fp6_llm_quantize(self):
N, OC, IC = 4, 256, 64
device = "cuda"

linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=torch.half)
linear = torch.nn.Linear(IC, OC, device=device)
fpx_linear = copy.deepcopy(linear)
quantize_(fpx_linear, fpx_weight_only(ebits, mbits))
quantize_(fpx_linear, fp6_llm_weight_only())

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fpx_linear(x)
actual = torch.compile(fpx_linear, fullgraph=True)(x)
# somehow compile now changes the result a bit
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestFpxTensorCoreAQTLayout)
instantiate_parametrized_tests(TestQuantLlmLinearWeight)


if __name__ == "__main__":
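
The restored test module also exercises the low-level packing helpers directly. A round-trip sketch, assuming the pre-refactor signatures (to_scaled_tc_fpx returning the packed tensor and its scale, from_scaled_tc_fpx accepting them back):

import torch
from torchao.prototype.quant_llm import to_scaled_tc_fpx, from_scaled_tc_fpx

w = torch.randn(256, 64, dtype=torch.half)
packed, scale = to_scaled_tc_fpx(w, 3, 2)        # pack to the tensor-core FPx layout (FP6 e3m2)
w_hat = from_scaled_tc_fpx(packed, 3, 2, scale)  # dequantize back to a float tensor
assert w_hat.shape == w.shape
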
6 changes: 3 additions & 3 deletions test/test_ops.py
@@ -11,7 +11,7 @@
)
from torch.testing._internal.optests import opcheck
from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
from torchao.dtypes.fpx import from_scaled_tc_fpx
from torchao.prototype.quant_llm import from_scaled_tc_fpx
from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
import pytest

@@ -318,7 +318,7 @@ def test_dequantize_tensor_core_tiled_layout_op(shape, inner_k_tiles, group_size
MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(itertools.product(
MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
MARLIN_24_K_CHUNKS, MARLIN_24_N_CHUNKS, MARLIN_24_SUPPORTED_NUM_BITS,
MARLIN_24_SUPPORTED_GROUP_SIZES, MNK_FACTORS
))

@@ -399,7 +399,7 @@ def test_marlin_24(k_chunk, n_chunk, num_bits, group_size, mnk_factors):
workspace_24 = marlin_24_workspace(size_n)

fn_inputs = (
a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
a_input, marlin_24_q_w_comp, meta, marlin_24_scale, workspace_24,
num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1],
)
output = torchao.ops.marlin_24_gemm(*fn_inputs)
(Diff for the remaining changed files is not shown.)