@@ -75,20 +75,64 @@ def get_min_capability(cls) -> int:
def get_config_filenames(cls) -> list[str]:
    """Name of the quantization config file shipped with ModelOpt checkpoints."""
    return ["hf_quant_config.json"]
7777
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional["QuantizationMethods"]:
    """Detect whether this ModelOpt FP8 config should handle the checkpoint.

    Args:
        hf_quant_cfg: Parsed quantization config dict, or None when the
            checkpoint carries no quantization metadata.
        user_quant: Quantization method requested by the user (unused by
            this detector).

    Returns:
        "modelopt" when the config declares ``quant_method == "modelopt"``
        with an FP8 ``quant_algo``, otherwise None.
    """
    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method' key for dispatch; guard
    # against non-string values before calling .lower().
    quant_method = hf_quant_cfg.get("quant_method", "")
    if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
        return None

    # ModelOpt format nests the algorithm under "quantization";
    # compressed-tensors style keeps quant_algo at the top level.
    if "quantization" in hf_quant_cfg:
        quant_section = hf_quant_cfg["quantization"]
        quant_algo = quant_section.get("quant_algo", "") if isinstance(
            quant_section, dict) else ""
    else:
        quant_algo = hf_quant_cfg.get("quant_algo", "")

    # Type-guard both layouts: the original nested branch would raise
    # TypeError on a non-string quant_algo (e.g. None).
    if isinstance(quant_algo, str) and "FP8" in quant_algo:
        return "modelopt"

    return None
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptFp8Config":
    """Create an FP8 config from an ``hf_quant_config.json``-style dict.

    Accepts both the nested ModelOpt layout
    (``{"quantization": {"quant_algo": ...}}``) and the flat
    compressed-tensors style layout (``{"quant_algo": ..., ...}``).

    Raises:
        ValueError: if the quantization section is malformed or the
            declared algorithm is not one of QUANT_ALGOS.
    """
    if "quantization" in config:
        # Nested ModelOpt format.
        quant_section = cls.get_from_keys(config, ["quantization"])
        if not isinstance(quant_section, dict):
            raise ValueError(
                "Expected 'quantization' to be a dictionary in config")
        quant_method = quant_section.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")
    else:
        # Flat compressed-tensors style format; an absent/empty
        # quant_algo is rejected by the QUANT_ALGOS check below.
        quant_section = config
        quant_method = config.get("quant_algo", "")

    # Optional fields come from whichever section carried quant_algo —
    # this replaces the previously duplicated per-branch extraction.
    kv_cache_quant_method = quant_section.get("kv_cache_quant_algo")
    exclude_modules = quant_section.get("exclude_modules")

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration.")
    is_checkpoint_fp8_serialized = ("FP8" in quant_method)

    return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method,
               exclude_modules)
@@ -434,7 +478,7 @@ class ModelOptNvFp4Config(QuantizationConfig):
434478 def __init__ (
435479 self ,
436480 is_checkpoint_nvfp4_serialized : bool ,
437- kv_cache_quant_algo : str ,
481+ kv_cache_quant_algo : Optional [ str ] ,
438482 exclude_modules : list [str ],
439483 group_size : int = 16 ,
440484 ) -> None :
@@ -465,24 +509,138 @@ def get_min_capability(cls) -> int:
def get_config_filenames(cls) -> list[str]:
    """ModelOpt checkpoints ship their quant settings in this file."""
    return ["hf_quant_config.json"]
467511
@classmethod
def override_quantization_method(
        cls, hf_quant_cfg, user_quant) -> Optional["QuantizationMethods"]:
    """Detect whether this ModelOpt NVFP4 config should handle the checkpoint.

    Args:
        hf_quant_cfg: Parsed quantization config dict, or None when the
            checkpoint carries no quantization metadata.
        user_quant: Quantization method requested by the user (unused by
            this detector).

    Returns:
        "modelopt_fp4" when the config declares
        ``quant_method == "modelopt"`` with an FP4-family ``quant_algo``,
        otherwise None.
    """
    if hf_quant_cfg is None:
        return None

    # Use the community standard 'quant_method' key for dispatch; guard
    # against non-string values before calling .lower().
    quant_method = hf_quant_cfg.get("quant_method", "")
    if not isinstance(quant_method, str) or quant_method.lower() != "modelopt":
        return None

    # ModelOpt format nests the algorithm under "quantization";
    # compressed-tensors style keeps quant_algo at the top level.
    if "quantization" in hf_quant_cfg:
        quant_section = hf_quant_cfg["quantization"]
        quant_algo = quant_section.get("quant_algo", "") if isinstance(
            quant_section, dict) else ""
    else:
        quant_algo = hf_quant_cfg.get("quant_algo", "")

    # Case-insensitive FP4 match in both layouts: the original nested
    # branch matched only the exact-case "NVFP4" substring and raised
    # TypeError on a non-string quant_algo.
    if isinstance(quant_algo, str) and "FP4" in quant_algo.upper():
        return "modelopt_fp4"

    return None
542+
@classmethod
def from_config(cls, config: dict[str, Any]) -> "ModelOptNvFp4Config":
    """Create an NVFP4 config from an ``hf_quant_config.json``-style dict.

    Accepts both the nested ModelOpt layout
    (``{"quantization": {"quant_algo": ...}}``) and the flat
    compressed-tensors style layout. Optional fields are type-validated;
    NVFP4 checkpoints in the nested layout must additionally declare
    group_size, kv_cache_quant_algo and exclude_modules.

    Raises:
        ValueError: on malformed sections, wrongly typed fields, an
            unsupported algorithm, or missing required NVFP4 fields.
    """
    if "quantization" in config:
        # Nested ModelOpt format.
        quant_section = cls.get_from_keys(config, ["quantization"])
        if not isinstance(quant_section, dict):
            raise ValueError(
                "Expected 'quantization' to be a dictionary in config")
        quant_method = quant_section.get("quant_algo", "")
        if not quant_method:
            raise ValueError("Missing 'quant_algo' in quantization config")
    else:
        # Flat compressed-tensors style format; an absent/empty
        # quant_algo is rejected by the QUANT_ALGOS check below.
        quant_section = config
        quant_method = config.get("quant_algo", "")

    # The field validation below was previously duplicated verbatim in
    # both branches; it now runs once against the selected section.

    # kv_cache_quant_algo: optional (None means no KV-cache
    # quantization), must be a string when present.
    kv_cache_quant_algo = quant_section.get("kv_cache_quant_algo")
    if kv_cache_quant_algo is not None and not isinstance(
            kv_cache_quant_algo, str):
        raise ValueError(f"kv_cache_quant_algo must be a string, got "
                         f"{type(kv_cache_quant_algo)}")

    # group_size: optional, defaults to 16; int-like values are coerced.
    group_size_raw = quant_section.get("group_size")
    if group_size_raw is None:
        group_size = 16  # Default value
    elif isinstance(group_size_raw, int):
        group_size = group_size_raw
    else:
        try:
            group_size = int(group_size_raw)
        except (ValueError, TypeError):
            raise ValueError(f"group_size must be an integer, got "
                             f"{type(group_size_raw)}") from None

    exclude_modules = quant_section.get("exclude_modules", [])
    if not isinstance(exclude_modules, list):
        raise ValueError(f"exclude_modules must be a list, got "
                         f"{type(exclude_modules)}")

    if quant_method not in QUANT_ALGOS:
        raise ValueError(
            f"ModelOpt currently only supports: {QUANT_ALGOS} "
            "quantizations in vLLM. Please check the "
            "`hf_quant_config.json` file for your model's "
            "quant configuration.")
    is_checkpoint_nvfp4_serialized = ("NVFP4" in quant_method)

    # NVFP4 checkpoints in the nested layout must spell these fields out
    # explicitly (flat-layout configs are exempt, as before).
    if is_checkpoint_nvfp4_serialized and "quantization" in config:
        required_fields = [
            "group_size", "kv_cache_quant_algo", "exclude_modules"
        ]
        missing_fields = [
            field for field in required_fields if field not in quant_section
        ]
        if missing_fields:
            raise ValueError(
                f"NVFP4 quantization requires the following fields in "
                f"hf_quant_config.json: {missing_fields}")

    return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
               exclude_modules, group_size)
488646
0 commit comments