Add pipeline parallel plan to PretrainedConfig and PreTrainedModel (huggingface#36091)

* Add `base_model_pp_plan` to `PretrainedConfig`

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add `_pp_plan` to `PreTrainedModel`

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add both to Llama for testing

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Fix type error

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Update to suggested schema

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* `_pp_plan` keys are not patterns

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Simplify schema

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Fix typing error

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Update input name for Llama

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Aria

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Bamba

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Cohere 1 & 2

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to diffllama and emu3

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Gemma 1 & 2

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to GLM and GPT NeoX

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Granite and Helium

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Mistral and Mixtral

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to OLMo 1 & 2

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan to Phi and Phi 3

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan for Qwen 2, 2 MoE, 2 VL and 2.5 VL

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add pp plan for Starcoder 2

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Add enum for accessing inputs and outputs

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Update type hints to use tuples

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

* Change outer list to tuple

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

---------

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
hmellor authored and sbucaille committed Feb 14, 2025
1 parent 2b24a44 commit 48c93bd
Showing 50 changed files with 188 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/transformers/configuration_utils.py
@@ -74,6 +74,8 @@ class PretrainedConfig(PushToHubMixin):
naming of attributes.
- **base_model_tp_plan** (`Dict[str, Any]`) -- A dict that maps sub-modules FQNs of a base model to a tensor
parallel plan applied to the sub-module when `model.tensor_parallel` is called.
- **base_model_pp_plan** (`Dict[str, Tuple[List[str]]]`) -- A dict that maps child-modules of a base model to a
pipeline parallel plan that enables users to place the child-module on the appropriate device.
Common attributes (present in all subclasses):
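For reference, a minimal sketch of what this new attribute holds, mirroring the Llama plan added later in this diff (module names are model-specific; this is illustrative, not part of the patch):

```python
# Each key is a child module of the base model; each value is a
# (input names, output names) pair describing that module's interface.
base_model_pp_plan = {
    "embed_tokens": (["input_ids"], ["inputs_embeds"]),
    "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
    "norm": (["hidden_states"], ["hidden_states"]),
}
```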
@@ -198,6 +200,7 @@ class PretrainedConfig(PushToHubMixin):
is_composition: bool = False
attribute_map: Dict[str, str] = {}
base_model_tp_plan: Optional[Dict[str, Any]] = None
base_model_pp_plan: Optional[Dict[str, Tuple[List[str]]]] = None
_auto_class: Optional[str] = None

def __setattr__(self, key, value):
@@ -860,6 +863,9 @@ def to_diff_dict(self) -> Dict[str, Any]:
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in serializable_config_dict:
del serializable_config_dict["base_model_pp_plan"]

return serializable_config_dict

@@ -882,6 +888,9 @@ def to_dict(self) -> Dict[str, Any]:
# Do not serialize `base_model_tp_plan` for now
if "base_model_tp_plan" in output:
del output["base_model_tp_plan"]
# Do not serialize `base_model_pp_plan` for now
if "base_model_pp_plan" in output:
del output["base_model_pp_plan"]

# Transformers version when serializing the model
output["transformers_version"] = __version__
29 changes: 29 additions & 0 deletions src/transformers/modeling_utils.py
@@ -28,6 +28,7 @@
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
@@ -923,6 +924,11 @@ def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
return weights_name


class PipelineParallel(Enum):
inputs = 0
outputs = 1


class ModuleUtilsMixin:
"""
A few utilities for `torch.nn.Modules`, to be used as a mixin.
@@ -1312,6 +1318,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# `config.base_model_tp_plan` during `post_init`.
_tp_plan = None

# A pipeline parallel plan specifying the layers which may not be present
# on all ranks when PP is enabled. For top-level models, this attribute is
# currently defined in respective model code. For base models, this
# attribute comes from `config.base_model_pp_plan` during `post_init`.
#
# The variable names for the inputs and outputs of the specified layers can
# be indexed using the `PipelineParallel` enum as follows:
# - `_pp_plan["layers"][PipelineParallel.inputs]`
# - `_pp_plan["layers"][PipelineParallel.outputs]`
_pp_plan = None

# This flag signal that the model can be used as an efficient backend in TGI and vLLM
# In practice, it means that they support attention interface functions, fully pass the kwargs
# through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
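To make the indexing comment above concrete, here is a minimal, self-contained sketch (the `plan` dict is a stand-in for `config.base_model_pp_plan`; `.value` is used because `PipelineParallel` is a plain `Enum`, so its members are not themselves integers):

```python
from enum import Enum

# Local copy of the enum added in this diff, for illustration only.
class PipelineParallel(Enum):
    inputs = 0
    outputs = 1

# Stand-in for a base model's pipeline parallel plan (Llama-style).
plan = {"layers": (["hidden_states", "attention_mask"], ["hidden_states"])}

input_names = plan["layers"][PipelineParallel.inputs.value]    # ['hidden_states', 'attention_mask']
output_names = plan["layers"][PipelineParallel.outputs.value]  # ['hidden_states']
```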
@@ -1374,6 +1391,9 @@ def post_init(self):
# If current model is a base model, attach `base_model_tp_plan` from config
if self.base_model is self:
self._tp_plan = self.config.base_model_tp_plan
# If current model is a base model, attach `base_model_pp_plan` from config
if self.base_model is self:
self._pp_plan = self.config.base_model_pp_plan

def dequantize(self):
"""
@@ -5196,6 +5216,15 @@ def tplize(mod: torch.nn.Module) -> None:
# function to every submodule.
self.apply(tplize)

@property
def supports_pp_plan(self):
if self._pp_plan is not None:
return True
# Check if base model has PP plan
if getattr(self.base_model, "_pp_plan", None) is not None:
return True
return False
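A possible usage sketch once this lands (assuming a Llama checkpoint you have access to; the expected values below follow the plans added in this diff):

```python
from transformers import AutoModelForCausalLM

# Any Llama-family checkpoint would do; this id is only an example.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

print(model.supports_pp_plan)  # True, since LlamaForCausalLM defines `_pp_plan`
print(model._pp_plan)          # {'lm_head': (['hidden_states'], ['logits'])}
print(model.model._pp_plan)    # base model plan attached from the config in `post_init`
```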

@property
def loss_function(self):
if hasattr(self, "_loss_function"):
5 changes: 5 additions & 0 deletions src/transformers/models/aria/configuration_aria.py
@@ -144,6 +144,11 @@ class AriaTextConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}
base_config_key = "text_config"

def __init__(
1 change: 1 addition & 0 deletions src/transformers/models/aria/modeling_aria.py
@@ -1141,6 +1141,7 @@ class AriaTextForCausalLM(AriaTextPreTrainedModel, GenerationMixin):

_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config_class = AriaTextConfig

def __init__(self, config: AriaTextConfig):
1 change: 1 addition & 0 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -1446,6 +1446,7 @@ def _update_mamba_mask(self, attention_mask, cache_position):
class BambaForCausalLM(BambaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
5 changes: 5 additions & 0 deletions src/transformers/models/cohere/configuration_cohere.py
@@ -148,6 +148,11 @@ class CohereConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -780,6 +780,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class CohereForCausalLM(CoherePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
5 changes: 5 additions & 0 deletions src/transformers/models/cohere2/configuration_cohere2.py
@@ -148,6 +148,11 @@ class Cohere2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/cohere2/modeling_cohere2.py
@@ -781,6 +781,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Cohere2ForCausalLM(Cohere2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config: Cohere2Config):
super().__init__(config)
5 changes: 5 additions & 0 deletions src/transformers/models/cohere2/modular_cohere2.py
@@ -173,6 +173,11 @@ class Cohere2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/diffllama/modeling_diffllama.py
@@ -1019,6 +1019,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class DiffLlamaForCausalLM(DiffLlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
1 change: 1 addition & 0 deletions src/transformers/models/emu3/modeling_emu3.py
@@ -1598,6 +1598,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class Emu3ForCausalLM(Emu3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
config_class = Emu3TextConfig

def __init__(self, config):
5 changes: 5 additions & 0 deletions src/transformers/models/gemma/configuration_gemma.py
@@ -102,6 +102,11 @@ class GemmaConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -752,6 +752,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GemmaForCausalLM(GemmaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
5 changes: 5 additions & 0 deletions src/transformers/models/gemma/modular_gemma.py
@@ -126,6 +126,11 @@ class GemmaConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
5 changes: 5 additions & 0 deletions src/transformers/models/gemma2/configuration_gemma2.py
@@ -106,6 +106,11 @@ class Gemma2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -790,6 +790,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
class Gemma2ForCausalLM(Gemma2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
5 changes: 5 additions & 0 deletions src/transformers/models/gemma2/modular_gemma2.py
@@ -132,6 +132,11 @@ class Gemma2Config(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
5 changes: 5 additions & 0 deletions src/transformers/models/glm/configuration_glm.py
@@ -93,6 +93,11 @@ class GlmConfig(PretrainedConfig):
"layers.*.mlp.gate_up_proj": "colwise_rep", # we need to replicate here due to the `chunk` operation
"layers.*.mlp.down_proj": "rowwise_rep", # we need to replicate here due to the `chunk` operation
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/glm/modeling_glm.py
@@ -761,6 +761,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
6 changes: 6 additions & 0 deletions src/transformers/models/gpt_neox/configuration_gpt_neox.py
@@ -137,6 +137,12 @@ class GPTNeoXConfig(PretrainedConfig):
"layers.*.mlp.dense_h_to_4h": "colwise",
"layers.*.mlp.dense_4h_to_h": "rowwise",
}
base_model_pp_plan = {
"embed_in": (["input_ids"], ["inputs_embeds"]),
"emb_dropout": (["inputs_embeds"], ["hidden_states"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"final_layer_norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -758,6 +758,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["embed_out.weight"]
_tp_plan = {"embed_out": "colwise_rep"}
_pp_plan = {"embed_out": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
1 change: 1 addition & 0 deletions src/transformers/models/gpt_neox/modular_gpt_neox.py
@@ -456,6 +456,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["embed_out.weight"]
_tp_plan = {"embed_out": "colwise_rep"}
_pp_plan = {"embed_out": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
5 changes: 5 additions & 0 deletions src/transformers/models/granite/configuration_granite.py
@@ -122,6 +122,11 @@ class GraniteConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/granite/modeling_granite.py
@@ -764,6 +764,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class GraniteForCausalLM(GranitePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
5 changes: 5 additions & 0 deletions src/transformers/models/helium/configuration_helium.py
@@ -95,6 +95,11 @@ class HeliumConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/helium/modeling_helium.py
@@ -748,6 +748,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class HeliumForCausalLM(HeliumPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config: HeliumConfig):
super().__init__(config)
5 changes: 5 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
@@ -151,6 +151,11 @@ class LlamaConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/llama/modeling_llama.py
@@ -750,6 +750,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)
5 changes: 5 additions & 0 deletions src/transformers/models/mistral/configuration_mistral.py
@@ -107,6 +107,11 @@ class MistralConfig(PretrainedConfig):
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
}
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
1 change: 1 addition & 0 deletions src/transformers/models/mistral/modeling_mistral.py
@@ -751,6 +751,7 @@ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

def __init__(self, config):
super().__init__(config)