Commit 71bfccb

SpinQuant rotate bias (#2913)
Summary: Added bias rotation. This is needed to apply SpinQuant R2 to models that have a bias, such as Qwen models.

Differential Revision: D81352249
1 parent 266f749 commit 71bfccb
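
To see why the bias has to follow the weight: a linear layer computes y = x @ W.T + b. Folding an output-side rotation H into the layer replaces W with H @ W; if b is left untouched, the layer no longer produces H applied to the original output. The short pure-PyTorch sketch below illustrates this with a normalized 4x4 Hadamard matrix; the sizes and names are illustrative only and are not taken from the commit.

# Illustration only (not code from this commit): rotating the weight's output
# dimension without rotating the bias changes the layer's output; rotating
# both keeps the output equal to H applied to the original output.
import torch

torch.manual_seed(0)

# Normalized 4x4 Hadamard matrix; H is orthogonal, so H @ H.T == I.
H2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]])
H = torch.kron(H2, H2) / 2.0

linear = torch.nn.Linear(8, 4, bias=True)  # example sizes
x = torch.randn(3, 8)
y_ref = linear(x)

with torch.no_grad():
    linear.weight.copy_(H @ linear.weight)  # rotate weight rows (output dim)
    linear.bias.copy_(H @ linear.bias)      # rotate the bias the same way

y_rot = linear(x)
print(torch.allclose(y_rot, y_ref @ H.T, atol=1e-6))  # True: output is H @ y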

File tree

1 file changed, +15 -0 lines changed


torchao/prototype/spinquant/hadamard_utils.py

Lines changed: 15 additions & 0 deletions
@@ -237,16 +237,24 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!"
 
     W = module.weight.data
+    if module.bias is not None:
+        B = module.bias.data
+        bias_dtype_orig = B.dtype
+        B = B.float()
     dtype_orig = W.dtype
     W = W.float()
 
     if had_dim == -1:
         if output:
             had_K, K = get_hadK(out_features)
             W = matmul_hadU(W.t(), had_K.to(W.device), K).t()
+            if module.bias is not None:
+                B = matmul_hadU(B, had_K.to(B.device), K)
         else:
             had_K, K = get_hadK(in_features)
             W = matmul_hadU(W, had_K.to(W.device), K)
+            if module.bias is not None:
+                B = matmul_hadU(B, had_K.to(B.device), K)
     else:
         if R2 is not None:
             hadK = R2.to(torch.float64)
@@ -260,8 +268,15 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         temp = W.reshape(-1, shape[-1] // had_dim, had_dim)
         temp = temp.to(torch.float64) @ hadK
         W = temp.reshape(shape)
+        if module.bias is not None:
+            shape = B.shape
+            temp = B.reshape(-1, had_dim)
+            temp = temp.to(torch.float64) @ hadK
+            B = temp.reshape(shape)
 
         if output:
             W = W.t()
 
     module.weight.data = W.to(dtype=dtype_orig)
+    if module.bias is not None:
+        module.bias.data = B.to(dtype=bias_dtype_orig)
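
A usage sketch of the patched helper, under assumptions not guaranteed by the diff alone: a torchao build that contains this commit, and that get_hadK/matmul_hadU handle the chosen power-of-two size on CPU with an orthonormal Hadamard. The layer sizes and the norm check are illustrative.

# Sketch only: exercise apply_exact_had_to_linear on a Linear that has a bias.
# Assumes a torchao build that includes this commit.
import torch

from torchao.prototype.spinquant.hadamard_utils import apply_exact_had_to_linear

linear = torch.nn.Linear(128, 128, bias=True)  # example sizes (power of two)
x = torch.randn(2, 128)
y_ref = linear(x)

# Rotate the output side: weight rows and, with this commit, the bias too.
apply_exact_had_to_linear(linear, had_dim=-1, output=True)
y_rot = linear(x)

# If matmul_hadU applies an orthonormal Hadamard (an assumption), the rotated
# layer computes H @ y, so per-sample output norms are preserved.
print(torch.allclose(y_rot.norm(dim=-1), y_ref.norm(dim=-1), atol=1e-3))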
