Skip to content

Commit

Permalink
comments
Browse files Browse the repository at this point in the history
  • Loading branch information
horheynm committed May 2, 2024
1 parent e62de87 commit e929df2
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions src/compressed_tensors/quantization/lifecycle/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,15 @@ def fake_quantize(

columns = x.shape[1]
if columns >= group_size:
assert columns % group_size == 0
if columns % group_size != 0:
raise ValueError(
    "tensor column shape must be divisible "
    f"by the given group_size {group_size}"
)
for i in range(ceil(columns / group_size)):

# scale.shape should be [nchan, ndim]
# sc.shape should be [nchan, 1] after unsqueeze

sc = scale[:, i].unsqueeze(1)
zp = zero_point[:, i].unsqueeze(1)

Expand All @@ -122,8 +126,10 @@ def fake_quantize(

# per-token
elif args.strategy == QuantizationStrategy.TOKEN:
# before: scale shape = [channel_size]
# after: scale shape = [channel_size, 1]
# before: scale shape = [num_tokens]
# after: scale shape = [num_tokens, 1]
# x.shape = [1, num_tokens, 1]
# scale gets broadcast as expected without needing a [1, num_tokens, 1] shape

scale = scale.unsqueeze(1)
zero_point = zero_point.unsqueeze(1)
Expand Down

0 comments on commit e929df2

Please sign in to comment.