Skip to content

Commit bc52aa7

Browse files
authored
Added SpinQuant rotation unit test (#2925)
SpinQuant bias rotation fix; added test
1 parent 183068e commit bc52aa7

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

torchao/prototype/spinquant/hadamard_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
237237
assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!"
238238

239239
W = module.weight.data
240-
if module.bias is not None:
240+
if output and module.bias is not None:
241241
B = module.bias.data
242242
bias_dtype_orig = B.dtype
243243
B = B.float()
@@ -248,12 +248,12 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
248248
if output:
249249
had_K, K = get_hadK(out_features)
250250
W = matmul_hadU(W.t(), had_K.to(W.device), K).t()
251-
if module.bias is not None:
251+
if output and module.bias is not None:
252252
B = matmul_hadU(B, had_K.to(B.device), K)
253253
else:
254254
had_K, K = get_hadK(in_features)
255255
W = matmul_hadU(W, had_K.to(W.device), K)
256-
if module.bias is not None:
256+
if output and module.bias is not None:
257257
B = matmul_hadU(B, had_K.to(B.device), K)
258258
else:
259259
if R2 is not None:
@@ -268,7 +268,7 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
268268
temp = W.reshape(-1, shape[-1] // had_dim, had_dim)
269269
temp = temp.to(torch.float64) @ hadK
270270
W = temp.reshape(shape)
271-
if module.bias is not None:
271+
if output and module.bias is not None:
272272
shape = B.shape
273273
temp = B.reshape(-1, had_dim)
274274
temp = temp.to(torch.float64) @ hadK
@@ -278,5 +278,5 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
278278
W = W.t()
279279

280280
module.weight.data = W.to(dtype=dtype_orig)
281-
if module.bias is not None:
281+
if output and module.bias is not None:
282282
module.bias.data = B.to(dtype=bias_dtype_orig)
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
import torch.nn as nn
11+
12+
from torchao.prototype.spinquant.hadamard_utils import apply_exact_had_to_linear
13+
14+
15+
class TestSpinQuant(unittest.TestCase):
    """Unit tests for SpinQuant Hadamard rotation utilities."""

    def test_rotate_in_and_out(self):
        """Rotate the output of one linear layer and apply the inverse
        rotation to the input of the next layer; the composed model's
        output must be unchanged (the rotations cancel).
        """
        with torch.no_grad():
            layer1 = nn.Linear(256, 256, bias=True)
            layer2 = nn.Linear(256, 256, bias=True)
            model = nn.Sequential(layer1, layer2)
            # Renamed from `input` to avoid shadowing the builtin.
            x = torch.rand(256)
            output = model(x)
            # Forward Hadamard rotation on layer1's output side,
            # inverse rotation on layer2's input side.
            apply_exact_had_to_linear(layer1, output=True)
            apply_exact_had_to_linear(layer2, output=False)
            new_output = model(x)
            # torch.testing.assert_allclose is deprecated; use assert_close.
            torch.testing.assert_close(output, new_output)

0 commit comments

Comments
 (0)