24 changes: 12 additions & 12 deletions src/transformers/integrations/mxfp4.py
@@ -314,12 +314,12 @@ def should_convert_module(current_key_name, patterns):
def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
from ..integrations.tensor_parallel import shard_and_distribute_module

- model = kwargs.get("model", None)
- empty_param = kwargs.get("empty_param", None)
- casting_dtype = kwargs.get("casting_dtype", None)
- to_contiguous = kwargs.get("to_contiguous", None)
- rank = kwargs.get("rank", None)
- device_mesh = kwargs.get("device_mesh", None)
+ model = kwargs.get("model")
+ empty_param = kwargs.get("empty_param")
+ casting_dtype = kwargs.get("casting_dtype")
+ to_contiguous = kwargs.get("to_contiguous")
+ rank = kwargs.get("rank")
+ device_mesh = kwargs.get("device_mesh")

for proj in ["gate_up_proj", "down_proj"]:
if proj in param_name:
@@ -357,12 +357,12 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwa
)
from ..integrations.tensor_parallel import shard_and_distribute_module

- model = kwargs.get("model", None)
- empty_param = kwargs.get("empty_param", None)
- casting_dtype = kwargs.get("casting_dtype", None)
- to_contiguous = kwargs.get("to_contiguous", None)
- rank = kwargs.get("rank", None)
- device_mesh = kwargs.get("device_mesh", None)
+ model = kwargs.get("model")
+ empty_param = kwargs.get("empty_param")
+ casting_dtype = kwargs.get("casting_dtype")
+ to_contiguous = kwargs.get("to_contiguous")
+ rank = kwargs.get("rank")
+ device_mesh = kwargs.get("device_mesh")

for proj in ["gate_up_proj", "down_proj"]:
if proj in param_name:
12 changes: 6 additions & 6 deletions src/transformers/quantizers/quantizer_mxfp4.py
@@ -101,7 +101,7 @@ def validate_environment(self, *args, **kwargs):
global triton_kernels_hub
triton_kernels_hub = get_kernel("kernels-community/triton_kernels")

- device_map = kwargs.get("device_map", None)
+ device_map = kwargs.get("device_map")
if device_map is None:
logger.warning_once(
"You have loaded an FP4 model on CPU and have a CUDA device available, make sure to set "
@@ -210,11 +210,11 @@ def create_quantized_param(
# we take this path if already quantized but not in a compatible way
# The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales
else:
- empty_param = kwargs.get("empty_param", None)
- casting_dtype = kwargs.get("casting_dtype", None)
- to_contiguous = kwargs.get("to_contiguous", None)
- rank = kwargs.get("rank", None)
- device_mesh = kwargs.get("device_mesh", None)
+ empty_param = kwargs.get("empty_param")
+ casting_dtype = kwargs.get("casting_dtype")
+ to_contiguous = kwargs.get("to_contiguous")
+ rank = kwargs.get("rank")
+ device_mesh = kwargs.get("device_mesh")
if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize:
# blocks and scales have the same length, so this works for both
module, _ = get_module_from_name(model, param_name[: -len("_blocks")])
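Note on the change itself: `dict.get(key)` already returns `None` when the key is missing, so dropping the explicit `None` default in these `kwargs.get(...)` calls does not alter behavior. A minimal sketch illustrating the equivalence (the `kwargs` contents below are made up for illustration):

```python
# dict.get(key) and dict.get(key, None) are equivalent: both return None
# (rather than raising KeyError) when the key is absent.
kwargs = {"model": "dummy-model", "rank": 0}

assert kwargs.get("device_mesh") is None        # key absent, implicit None default
assert kwargs.get("device_mesh", None) is None  # key absent, explicit None default
assert kwargs.get("rank") == 0                  # present keys are returned as before
```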