@@ -237,7 +237,7 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         assert is_pow2(had_dim), "Hadamard dimension must be a power of 2!"
 
     W = module.weight.data
-    if module.bias is not None:
+    if output and module.bias is not None:
         B = module.bias.data
         bias_dtype_orig = B.dtype
         B = B.float()
@@ -248,12 +248,12 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
         if output:
             had_K, K = get_hadK(out_features)
             W = matmul_hadU(W.t(), had_K.to(W.device), K).t()
-            if module.bias is not None:
+            if output and module.bias is not None:
                 B = matmul_hadU(B, had_K.to(B.device), K)
         else:
             had_K, K = get_hadK(in_features)
             W = matmul_hadU(W, had_K.to(W.device), K)
-            if module.bias is not None:
+            if output and module.bias is not None:
                 B = matmul_hadU(B, had_K.to(B.device), K)
     else:
         if R2 is not None:
@@ -268,7 +268,7 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
        temp = W.reshape(-1, shape[-1] // had_dim, had_dim)
        temp = temp.to(torch.float64) @ hadK
        W = temp.reshape(shape)
-        if module.bias is not None:
+        if output and module.bias is not None:
            shape = B.shape
            temp = B.reshape(-1, had_dim)
            temp = temp.to(torch.float64) @ hadK
@@ -278,5 +278,5 @@ def apply_exact_had_to_linear(module, had_dim=-1, output=False, R2=None):
            W = W.t()
 
    module.weight.data = W.to(dtype=dtype_orig)
-    if module.bias is not None:
+    if output and module.bias is not None:
        module.bias.data = B.to(dtype=bias_dtype_orig)
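
Side note on why gating the bias handling on `output` is sound: for a `Linear`, `y = x @ W.T + b`, the bias lives in the output space, so it only picks up the Hadamard when the output side is rotated; an input-side rotation leaves it untouched. Below is a minimal sketch checking this numerically, assuming a tiny hand-built Sylvester Hadamard in place of `get_hadK`/`matmul_hadU`; the `hadamard()` helper and the toy `Linear` are illustrative, not code from this repo.

```python
import torch

def hadamard(n: int) -> torch.Tensor:
    # Sylvester construction, normalized so H @ H.T == I (n must be a power of 2).
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1),
                       torch.cat([H, -H], dim=1)], dim=0)
    return H / H.shape[0] ** 0.5

torch.manual_seed(0)
lin = torch.nn.Linear(4, 4, bias=True)
x = torch.randn(2, 4)
H = hadamard(4)

# Output-side rotation: the target is y' = lin(x) @ H. Folding H into the layer
# requires rotating both the weight's output dimension and the bias, which is
# what the `if output and module.bias is not None` branches do.
W_out = H.t() @ lin.weight.data          # rotate rows (output dim) of W
b_out = lin.bias.data @ H                # bias lives in the output space
print(torch.allclose(lin(x) @ H, x @ W_out.t() + b_out, atol=1e-5))  # True

# Input-side rotation: activations arrive pre-rotated as x @ H, so only the
# weight's input dimension changes; the bias is added after the matmul and
# stays as-is, hence no bias handling in the input-side branch.
W_in = lin.weight.data @ H
print(torch.allclose(lin(x), (x @ H) @ W_in.t() + lin.bias.data, atol=1e-5))  # True
```

Both prints should be `True`; the second check mirrors the input-side (`else`) branch above, where the bias needs no rotation.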