Update micro benchmarking code for AQT (#673)
Summary:
Just benchmark a single linear module with an (m * k) * (k * n) problem size
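For orientation, a minimal sketch of that shape convention (not part of the commit; the names M, N, K and linear are illustrative, while the committed benchmark uses ToyLinearModel in bfloat16 on CUDA):

import torch

# One (M, K) activation multiplied by a (K, N) weight, i.e. a single nn.Linear(K, N)
M, N, K = 20, 2048, 2048
linear = torch.nn.Linear(K, N, bias=True)
x = torch.randn(M, K)
y = linear(x)  # (M, K) @ (K, N) -> (M, N)
assert y.shape == (M, N)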

Test Plan:
python benchmarks/benchmark_aq.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored Aug 15, 2024
1 parent 18e38f1 commit 7f0621d
Showing 1 changed file with 74 additions and 21 deletions.
95 changes: 74 additions & 21 deletions benchmarks/benchmark_aq.py
@@ -7,24 +7,60 @@
)
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_4,
    TORCH_VERSION_AT_LEAST_2_5,
)
from torchao.quantization.quant_api import (
    int4_weight_only,
    int8_weight_only,
    int8_dynamic_activation_int8_weight,
    quantize_,
    _replace_with_custom_fn_if_matches_filter,
)
import copy

def _int8wo_api(mod, **kwargs):
    if TORCH_VERSION_AT_LEAST_2_4:
        quantize_(mod, int8_weight_only(**kwargs), set_inductor_config=False)
        if not TORCH_VERSION_AT_LEAST_2_5:
            unwrap_tensor_subclass(mod)
    else:
        change_linear_weights_to_int8_woqtensors(mod, **kwargs)

def _int8da_int8w_api(mod, **kwargs):
    if TORCH_VERSION_AT_LEAST_2_4:
        quantize_(mod, int8_dynamic_activation_int8_weight(**kwargs), set_inductor_config=False)
        if not TORCH_VERSION_AT_LEAST_2_5:
            unwrap_tensor_subclass(mod)
    else:
        change_linear_weights_to_int8_dqtensors(mod, **kwargs)

def _int4wo_api(mod, **kwargs):
    if TORCH_VERSION_AT_LEAST_2_4:
        kwargs_copy = kwargs.copy()
        if "groupsize" in kwargs_copy:
            kwargs_copy["group_size"] = kwargs_copy["groupsize"]
            del kwargs_copy["groupsize"]
        quantize_(mod, int4_weight_only(**kwargs_copy), set_inductor_config=False)
        if not TORCH_VERSION_AT_LEAST_2_5:
            unwrap_tensor_subclass(mod)
    else:
        change_linear_weights_to_int4_woqtensors(mod, **kwargs)

class ToyLinearModel(torch.nn.Module):
    def __init__(self, m=64, n=32, k=64):
    """Single linear for m * k * n problem size
    """
    def __init__(self, m=64, n=32, k=64, has_bias=False, dtype=torch.float, device="cuda"):
        super().__init__()
        self.linear1 = torch.nn.Linear(m, n, bias=False).to(torch.float)
        self.linear2 = torch.nn.Linear(n, k, bias=False).to(torch.float)
        self.m = m
        self.dtype = dtype
        self.device = device
        self.linear = torch.nn.Linear(k, n, bias=has_bias).to(dtype=self.dtype, device=self.device)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (torch.randn(batch_size, self.linear1.in_features, dtype=dtype, device=device),)
    def example_inputs(self):
        return (torch.randn(self.m, self.linear.in_features, dtype=self.dtype, device=self.device),)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear(x)
        return x

def _ref_change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
@@ -69,14 +105,17 @@ def _ref_change_linear_weights_to_woqtensors(model, filter_fn=None, **kwargs):
_ref_change_linear_weights_to_int4_woqtensors = _get_ref_change_linear_weights_to_woqtensors(Int4WeightOnlyQuantizedLinearWeight)


def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
torch._dynamo.config.cache_size_limit = 50000

@torch.no_grad
def _bench_quantized_tensor_subclass_perf(api, ref_api, M, N, K, kwargs=None):
    if kwargs is None:
        kwargs = {}

    m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
    m = ToyLinearModel(M, N, K, has_bias=True, dtype=torch.bfloat16, device="cuda").eval()
    m_bf16 = copy.deepcopy(m)
    m_ref = copy.deepcopy(m)
    # setting batch_size to 20 to be compatible with the kernel
    example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
    example_inputs = m.example_inputs()

    api(m, **kwargs)

@@ -91,27 +130,41 @@ def _bench_quantized_tensor_subclass_perf(api, ref_api, kwargs=None):
    # perf comparison
    from torchao.utils import benchmark_model
    # warmup
    WARMUP = 5
    WARMUP = 20
    RUNS = 100
    m = torch.compile(m, mode='max-autotune', fullgraph=True)

    benchmark_model(m, WARMUP, example_inputs)
    elapsed_time = benchmark_model(m, RUNS, example_inputs)

    m_ref = torch.compile(m_ref, mode='max-autotune', fullgraph=True)
    benchmark_model(m_ref, WARMUP, example_inputs)
    ref_elapsed_time = benchmark_model(m_ref, RUNS, example_inputs)

    print(f"elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}")
    assert elapsed_time < 1.05 * ref_elapsed_time
    m = torch.compile(m, mode='max-autotune', fullgraph=True)
    benchmark_model(m, WARMUP, example_inputs)
    elapsed_time = benchmark_model(m, RUNS, example_inputs)


    m_bf16 = torch.compile(m_bf16, mode='max-autotune', fullgraph=True)
    benchmark_model(m_bf16, WARMUP, example_inputs)
    bf16_elapsed_time = benchmark_model(m_bf16, RUNS, example_inputs)

    print(f"{(M, N, K)}: elapsed time: {elapsed_time}, ref elapsed time: {ref_elapsed_time}, bf16 elapsed time: {bf16_elapsed_time}")

if __name__ == "__main__" and TORCH_VERSION_AT_LEAST_2_4 and torch.cuda.is_available():
    all_shapes = [
        (20, 2048, 2048),
    ]

    print("_int8da_int8w_api")
    from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors
    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_dqtensors, _ref_change_linear_weights_to_int8_dqtensors)
    for M, N, K in all_shapes:
        _bench_quantized_tensor_subclass_perf(_int8da_int8w_api, _ref_change_linear_weights_to_int8_dqtensors, M, N, K)

    print("_int8wo_api")
    from torchao.quantization.quant_api import change_linear_weights_to_int8_woqtensors
    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int8_woqtensors, _ref_change_linear_weights_to_int8_woqtensors)
    for M, N, K in all_shapes:
        _bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, M, N, K)

    print("_int4wo_api")
    kwargs = {"groupsize": 32}
    from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
    _bench_quantized_tensor_subclass_perf(change_linear_weights_to_int4_woqtensors, _ref_change_linear_weights_to_int4_woqtensors, kwargs)
    for M, N, K in all_shapes:
        _bench_quantized_tensor_subclass_perf(_int4wo_api, _ref_change_linear_weights_to_int4_woqtensors, M, N, K, kwargs)
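
Usage note (not part of the commit): with the new helper signature, more problem sizes can be benchmarked by extending all_shapes, or by calling the helper directly for one shape. A small illustration, where the second shape is purely hypothetical:

all_shapes = [
    (20, 2048, 2048),
    (1, 4096, 4096),  # hypothetical extra (M, N, K) entry, not in the commit
]

# one shape, one API, mirroring the __main__ block above
_bench_quantized_tensor_subclass_perf(_int8wo_api, _ref_change_linear_weights_to_int8_woqtensors, 20, 2048, 2048)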
