@@ -505,30 +505,47 @@ def tensor_parallel(self):
505505 Apply the model's tensor parallelization plan.
506506 Currently only supports linear layers.
507507 """
508-        tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
508+ # Look for tp plans in all of the PreTrainedModels found in self.model
509+        is_pretrained_model = lambda m: isinstance(m, PreTrainedModel)
510+        supports_tp_plan = lambda m: m.config.base_model_tp_plan is not None
511+        pretrained_models = filter(is_pretrained_model, self.model.modules())
512+        models_with_tp_plan = filter(supports_tp_plan, pretrained_models)
509513
510-        if not tp_plan and self.tp_size > 1:
514+        if not any(models_with_tp_plan) and self.tp_size > 1:
511515            raise ValueError(
512516                f"{type(self.model)} does not support tensor parallel yet!")
513517
514- # Some weight loaders expect linear layers to inherit from vLLM's
515- # LinearBase class, so we set a default style which causes any
516- # unspecified linear layers to be replaced with ReplicatedLinear
517-        tp_plan[".*"] = "replicate"
518-
519-        def _tensor_parallel(module: nn.Module, prefix: str = ""):
518+        def _tensor_parallel(module: nn.Module,
519+                             prefix: str = "",
520+                             tp_plan=None):
521+            tp_plan = tp_plan or {}
522+
523+            # If the current module is a PreTrainedModel, set the tp_plan for
524+            # all of its children
525+            if isinstance(module, PreTrainedModel):
526+                tp_plan = module.config.base_model_tp_plan or {}
527+                tp_plan = {
528+                    maybe_prefix(prefix, k): v
529+                    for k, v in tp_plan.items()
530+                }
531+
532+ # Some weight loaders expect linear layers to inherit from vLLM's
533+ # LinearBase class, so we set a default style which causes any
534+ # unspecified linear layers to be replaced with ReplicatedLinear
520535            for child_name, child_module in module.named_children():
521536                qual_name = maybe_prefix(prefix, child_name)
522-                for pattern, style in tp_plan.items():
523-                    if re.match(pattern, qual_name) and isinstance(
524-                            child_module, nn.Linear):
525-                        new_module = replace_linear_class(
526-                            child_module, style, self.quant_config)
527-                        setattr(module, child_name, new_module)
528-                        log_replacement(qual_name, child_module, new_module)
529-                        break
537+                if isinstance(child_module, nn.Linear):
538+                    generator = (p for p in tp_plan if re.match(p, qual_name))
539+                    pattern = next(generator, None)
540+                    style = tp_plan.get(pattern, "replicate")
541+                    new_module = replace_linear_class(child_module, style,
542+                                                      self.quant_config)
543+                    setattr(module, child_name, new_module)
544+                    log_replacement(qual_name, child_module, new_module)
530545                else:
531-                    _tensor_parallel(child_module, prefix=qual_name)
546+                    _tensor_parallel(child_module,
547+                                     prefix=qual_name,
548+                                     tp_plan=tp_plan)
532549
533550        _tensor_parallel(self.model)
534551
0 commit comments