@@ -314,12 +314,12 @@ def should_convert_module(current_key_name, patterns):
314314def dequantize (module , param_name , param_value , target_device , dq_param_name , ** kwargs ):
315315 from ..integrations .tensor_parallel import shard_and_distribute_module
316316
317- model = kwargs .get ("model" , None )
318- empty_param = kwargs .get ("empty_param" , None )
319- casting_dtype = kwargs .get ("casting_dtype" , None )
320- to_contiguous = kwargs .get ("to_contiguous" , None )
321- rank = kwargs .get ("rank" , None )
322- device_mesh = kwargs .get ("device_mesh" , None )
317+ model = kwargs .get ("model" )
318+ empty_param = kwargs .get ("empty_param" )
319+ casting_dtype = kwargs .get ("casting_dtype" )
320+ to_contiguous = kwargs .get ("to_contiguous" )
321+ rank = kwargs .get ("rank" )
322+ device_mesh = kwargs .get ("device_mesh" )
323323
324324 for proj in ["gate_up_proj" , "down_proj" ]:
325325 if proj in param_name :
@@ -357,12 +357,12 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwa
357357 )
358358 from ..integrations .tensor_parallel import shard_and_distribute_module
359359
360- model = kwargs .get ("model" , None )
361- empty_param = kwargs .get ("empty_param" , None )
362- casting_dtype = kwargs .get ("casting_dtype" , None )
363- to_contiguous = kwargs .get ("to_contiguous" , None )
364- rank = kwargs .get ("rank" , None )
365- device_mesh = kwargs .get ("device_mesh" , None )
360+ model = kwargs .get ("model" )
361+ empty_param = kwargs .get ("empty_param" )
362+ casting_dtype = kwargs .get ("casting_dtype" )
363+ to_contiguous = kwargs .get ("to_contiguous" )
364+ rank = kwargs .get ("rank" )
365+ device_mesh = kwargs .get ("device_mesh" )
366366
367367 for proj in ["gate_up_proj" , "down_proj" ]:
368368 if proj in param_name :
0 commit comments