@@ -237,16 +237,24 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!"

     W = module.weight.data
+    if module.bias is not None:
+        B = module.bias.data
+        bias_dtype_orig = B.dtype
+        B = B.float()
     dtype_orig = W.dtype
     W = W.float()

     if had_dim == -1:
         if output:
             had_K, K = get_hadK(out_features)
             W = matmul_hadU(W.t(), had_K.to(W.device), K).t()
+            if module.bias is not None:
+                B = matmul_hadU(B, had_K.to(B.device), K)
         else:
             had_K, K = get_hadK(in_features)
             W = matmul_hadU(W, had_K.to(W.device), K)
+            if module.bias is not None:
+                B = matmul_hadU(B, had_K.to(B.device), K)
     else:
         if R2 is not None:
             hadK = R2.to(torch.float64)
@@ -260,8 +268,15 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         temp = W.reshape(-1, shape[-1] // had_dim, had_dim)
         temp = temp.to(torch.float64) @ hadK
         W = temp.reshape(shape)
+        if module.bias is not None:
+            shape = B.shape
+            temp = B.reshape(-1, had_dim)
+            temp = temp.to(torch.float64) @ hadK
+            B = temp.reshape(shape)

         if output:
             W = W.t()

     module.weight.data = W.to(dtype=dtype_orig)
+    if module.bias is not None:
+        module.bias.data = B.to(dtype=bias_dtype_orig)
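Why the bias has to follow the weight: a Linear layer computes y = W x + b, so when the layer's output is rotated by an orthogonal Hadamard matrix H, we need H(W x + b) = (H W) x + H b, i.e. the same H folded into the weight rows must also be applied to the bias. The snippet below is a minimal, self-contained sanity check of that identity. It uses scipy.linalg.hadamard as a stand-in rotation, not this repo's matmul_hadU/get_hadK helpers, so it illustrates the math rather than the exact code path above.

import torch
from scipy.linalg import hadamard

# Check that rotating the output of a linear layer is equivalent to rotating
# both its weight and its bias: H @ (W @ x + b) == (H @ W) @ x + (H @ b).
# H is scipy's +/-1 Hadamard matrix scaled to be orthogonal; this is an
# illustrative stand-in, not the repository's Hadamard kernels.
torch.manual_seed(0)
out_features, in_features = 8, 16
H = torch.tensor(hadamard(out_features), dtype=torch.float64) / out_features ** 0.5

W = torch.randn(out_features, in_features, dtype=torch.float64)
b = torch.randn(out_features, dtype=torch.float64)
x = torch.randn(in_features, dtype=torch.float64)

assert torch.allclose(H @ (W @ x + b), (H @ W) @ x + H @ b)

The diff also mirrors the existing dtype handling for the weight: the bias is promoted to float for the transform and cast back to its original dtype at the end.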