Skip to content

Commit

Permalink
pass test_quant_args
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm committed May 3, 2024
1 parent e929df2 commit 1229c5a
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/compressed_tensors/quantization/quant_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class QuantizationArgs(BaseModel):
num_bits: int = 8
type: QuantizationType = QuantizationType.INT
symmetric: bool = True
strategy: Optional[QuantizationStrategy] = None
group_size: Optional[int] = None
strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None
dynamic: bool = False
observer: str = Field(
Expand Down Expand Up @@ -96,7 +96,7 @@ def get_observer(self):

return Observer.load_from_registry(self.observer, quantization_args=self)

@validator("strategy", pre=True)
@validator("strategy", pre=True, always=True)
def validate_strategy(cls, value, values):
group_size = values.get("group_size")

Expand All @@ -114,8 +114,7 @@ def validate_strategy(cls, value, values):
"group_size > 0 for strategy='group' and "
"group_size = -1 for 'channel'"
)
# breakpoint()
group_size = 128

if value == QuantizationStrategy.GROUP:
if group_size is None:
raise ValueError(f"strategy {value} requires group_size to be set.")
Expand Down

0 comments on commit 1229c5a

Please sign in to comment.