
Commit 971d140

before tests, correctness verified
horheynm committed Apr 12, 2024
1 parent 0804be3 commit 971d140
Showing 6 changed files with 2 additions and 103 deletions.
60 changes: 0 additions & 60 deletions bin/quant.py

This file was deleted.

33 changes: 0 additions & 33 deletions src/sparsetensors/quantization/lifecycle/forward.py
@@ -16,11 +16,6 @@

import torch
from sparsetensors.quantization.lifecycle.status import QuantizationStatus

# from sparsetensors.quantization.utils.quantization_scheme import (
# QuantizationArgs,
# QuantizationScheme,
# )
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module
@@ -59,35 +54,7 @@ def fake_quantize(
args: QuantizationArgs,
) -> torch.Tensor:
max_q = torch.tensor(2**args.num_bits - 1)
columns = x.shape[1]
Q = torch.zeros_like(x)
# for i1 in range(0, columns, args.block_size):
# i2 = min(i1 + args.block_size, columns)
# count = i2 - i1

# W1 = x[:, i1:i2].clone()
# Q1 = torch.zeros_like(W1)

# for i in range(count):
# w = W1[:, i]
# breakpoint()
# if args.group_size != -1:
# if (i1 + i) % args.group_size == 0:
# xmin, xmax = get_qparams(
# x[:, (i1 + i) : (i1 + i + args.group_size)], args.symmetric
# )
# scale, zero = get_scale_zero_point(
# x[:, (i1 + i) : (i1 + i + args.group_size)],
# max_q,
# xmax,
# xmin,
# args.symmetric,
# args.group_size,
# )

# q = quantize(w.unsqueeze(1), scale, zero, max_q).flatten()
# Q1[:, i] = q
# Q[:, i1:i2] = Q1
Q = quantize(x, scale, zero_point, max_q)
return dequantize(Q, scale, zero_point)

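Note on the change above: after this commit, fake_quantize drops the commented-out block/group-wise loop and reduces to a single quantize/dequantize round trip over the whole tensor. The following is a minimal sketch of that round trip using the standard affine formulas; the helper names mirror the diff, but their exact signatures and where scale and zero_point come from are assumptions, not taken from this file.

    import torch


    def quantize(x, scale, zero_point, max_q):
        # Affine map onto the integer grid [0, max_q].
        return torch.clamp(torch.round(x / scale + zero_point), 0, max_q)


    def dequantize(x_q, scale, zero_point):
        # Map grid values back to real numbers.
        return (x_q - zero_point) * scale


    def fake_quantize(x, scale, zero_point, num_bits=8):
        # Round trip: the output stays floating point but only takes values
        # representable on the num_bits integer grid.
        max_q = 2**num_bits - 1
        x_q = quantize(x, scale, zero_point, max_q)
        return dequantize(x_q, scale, zero_point)
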
5 changes: 0 additions & 5 deletions src/sparsetensors/quantization/lifecycle/initialize.py
@@ -18,11 +18,6 @@
import torch
from sparsetensors.quantization.lifecycle.forward import wrap_module_forward_quantized
from sparsetensors.quantization.lifecycle.status import QuantizationStatus

# from sparsetensors.quantization.utils.quantization_scheme import (
# QuantizationArgs,
# QuantizationScheme,
# )
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsetensors.quantization.quant_scheme import QuantizationScheme
from torch.nn import Module, Parameter
1 change: 0 additions & 1 deletion src/sparsetensors/quantization/observers/base.py
@@ -14,7 +14,6 @@

from typing import Optional, Tuple

# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs
from sparsetensors.quantization.quant_args import QuantizationArgs
from sparsezoo.utils.registry import RegistryMixin
from torch import FloatTensor, IntTensor, Tensor
3 changes: 0 additions & 3 deletions src/sparsetensors/quantization/observers/memoryless.py
@@ -19,9 +19,6 @@
from torch import FloatTensor, IntTensor, Tensor


# from sparsetensors.quantization.utils.quantization_scheme import QuantizationArgs


__all__ = ["MemorylessObserver"]


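The memoryless observer touched above keeps no running statistics: each call derives scale and zero-point from the current tensor's min and max alone. Below is a generic, self-contained sketch of that calculation; it does not reproduce this repo's Observer or RegistryMixin interface, and the asymmetric handling and epsilon floor are assumptions.

    import torch


    def memoryless_qparams(x, num_bits=8):
        # Stateless: scale and zero-point come from this call's tensor only.
        max_q = 2**num_bits - 1
        x_min = torch.clamp(x.min(), max=0.0)  # keep zero inside the range
        x_max = torch.clamp(x.max(), min=0.0)
        scale = torch.clamp((x_max - x_min) / max_q, min=1e-8)  # avoid divide-by-zero
        zero_point = torch.round(-x_min / scale)
        return scale, zero_point
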
3 changes: 2 additions & 1 deletion src/sparsetensors/quantization/quant_args.py
@@ -64,7 +64,8 @@ class QuantizationArgs(BaseModel):
observer: str = Field(
default="minmax",
description=(
"The class to use to compute the quantization params - scale and zero-point'"
"The class to use to compute the quantization param - "
"scale and zero-point'"
),
)
observer_kwargs: Dict[str, Any] = Field(
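The description edited above belongs to the observer field of QuantizationArgs, a Pydantic model, so an observer is selected by name with optional keyword arguments. A hedged usage sketch follows: num_bits is inferred from args.num_bits in forward.py, and the remaining field set of QuantizationArgs is not confirmed by this diff.

    from sparsetensors.quantization.quant_args import QuantizationArgs

    # `observer` and `observer_kwargs` appear in this file; `num_bits` is
    # assumed from `args.num_bits` usage elsewhere in the commit.
    args = QuantizationArgs(
        num_bits=8,
        observer="minmax",
        observer_kwargs={},
    )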
