FIX: New bloom changes breaking prompt learning (#1969)
Bloom had two dimensions of the attention layer transposed (compared to
all other transformers models), which was fixed by:

huggingface/transformers#31445

Therefore, for future transformers versions, skip the special handling
in PEFT.

There was also an issue where prompt injection did not take place when
past_key_values was an empty Cache object. This should now work as
expected.
BenjaminBossan authored Jul 29, 2024
1 parent 273acf0 commit 27833a2
Showing 2 changed files with 26 additions and 11 deletions.
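
For context on the empty-cache part of the fix: an empty transformers `Cache` (for example a freshly created `DynamicCache`) evaluates as falsy, which is what the new `requires_prompt_injection` check in the diff below relies on. A minimal sketch, assuming a transformers version that ships `DynamicCache`; the tensor shapes are made up for illustration:

```python
import torch
from transformers import DynamicCache

# A freshly created cache holds no key/value states yet, so bool(cache) is
# False: this is the situation in which the virtual-token prompts still
# need to be injected.
cache = DynamicCache()
print(bool(cache))  # False -> prompt injection still required

# Once the cache is populated (as happens during generation), it becomes
# truthy. Shapes are arbitrary here: batch=1, heads=2, seq_len=3, head_dim=4.
key = value = torch.zeros(1, 2, 3, 4)
cache.update(key, value, layer_idx=0)
print(bool(cache))  # True -> past key/values exist, no injection needed
```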
src/peft/peft_model.py (28 changes: 18 additions & 10 deletions)
```diff
@@ -1654,6 +1654,10 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
         uses_transformers_4_38 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.38.0")
         uses_transformers_4_36 = packaging.version.parse(transformers.__version__) >= packaging.version.parse("4.36.0")
         transformers_new_cache_archs = ["llama", "mistral", "persimmon", "phi"]
+        if packaging.version.parse(transformers.__version__) > packaging.version.parse("4.43.3"):
+            # https://github.com/huggingface/transformers/pull/31445
+            transformers_new_cache_archs.append("bloom")
+
         uses_cache = uses_transformers_4_38 or (
             uses_transformers_4_36 and self.base_model.config.model_type in transformers_new_cache_archs
         )
@@ -1690,16 +1694,20 @@ def prepare_inputs_for_generation(self, *args, task_ids: Optional[torch.Tensor]
                 )
                 kwargs["token_type_ids"] = None
 
-            if model_kwargs["past_key_values"] is None and peft_config.peft_type == PeftType.PREFIX_TUNING:
-                past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
-                model_kwargs["past_key_values"] = past_key_values
-            else:
-                if model_kwargs["past_key_values"] is None:
-                    inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
-                    prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0], task_ids=task_ids)
-                    prompts = prompts.to(inputs_embeds.dtype)
-                    model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
-                    model_kwargs["input_ids"] = None
+            # no past_key_values or past_key_values empty cache
+            requires_prompt_injection = (model_kwargs["past_key_values"] is None) or (
+                isinstance(model_kwargs["past_key_values"], transformers.Cache) and not model_kwargs["past_key_values"]
+            )
+
+            if requires_prompt_injection and peft_config.peft_type == PeftType.PREFIX_TUNING:
+                new_past_key_values = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0])
+                model_kwargs["past_key_values"] = new_past_key_values
+            elif requires_prompt_injection:
+                inputs_embeds = self.word_embeddings(model_kwargs["input_ids"])
+                prompts = self.get_prompt(batch_size=model_kwargs["input_ids"].shape[0], task_ids=task_ids)
+                prompts = prompts.to(inputs_embeds.dtype)
+                model_kwargs["inputs_embeds"] = torch.cat((prompts, inputs_embeds), dim=1)
+                model_kwargs["input_ids"] = None
 
             # For transformers>=4.38.0 - for some architectures such as Llama, `cache_position` is
             # passed in the forward pass to keep track of the position ids of the cache. We have to
```
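
For the non-prefix prompt-learning methods, the `elif requires_prompt_injection` branch above amounts to prepending the learned virtual-token embeddings to the input embeddings and then generating from `inputs_embeds`. A simplified sketch with made-up shapes (the tensors below are stand-ins for PEFT's `word_embeddings` and `get_prompt` outputs, not the actual internals):

```python
import torch

batch_size, seq_len, hidden_dim = 2, 5, 16
num_virtual_tokens = 4

# Stand-ins for self.word_embeddings(input_ids) and self.get_prompt(...).
inputs_embeds = torch.randn(batch_size, seq_len, hidden_dim)
prompts = torch.randn(batch_size, num_virtual_tokens, hidden_dim)

# Prompt injection: match dtypes, prepend the virtual tokens along the
# sequence dimension, and pass inputs_embeds (with input_ids set to None)
# on to the base model.
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
print(inputs_embeds.shape)  # torch.Size([2, 9, 16])
```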
src/peft/utils/constants.py (9 changes: 8 additions & 1 deletion)
```diff
@@ -11,7 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import packaging.version
 import torch
+import transformers
 
 
 # needed for prefix-tuning of bloom model
@@ -40,10 +43,14 @@ def starcoder_model_postprocess_past_key_value(past_key_values):
 
 
 TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
-    "bloom": bloom_model_postprocess_past_key_value,
     "gpt_bigcode": starcoder_model_postprocess_past_key_value,
 }
 
+if packaging.version.parse(transformers.__version__) <= packaging.version.parse("4.43.3"):
+    # special handling for bloom architecture was fixed in:
+    # https://github.com/huggingface/transformers/pull/31445
+    TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING["bloom"] = bloom_model_postprocess_past_key_value
+
 TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING = {
     "llama": ["input_layernorm", "post_attention_layernorm", "norm"],
     "bloom": ["input_layernorm", "post_attention_layernorm", "ln_f"],
```
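
The constants change applies the same version gate: the bloom-specific postprocessing of past key/values is only registered when the installed transformers release predates the upstream layout fix. A minimal sketch of that check (the variable name is illustrative):

```python
import packaging.version
import transformers

# transformers releases after 4.43.3 include huggingface/transformers#31445,
# which gives bloom the same key/value layout as other architectures, so the
# special postprocessing is only registered for older releases.
needs_bloom_postprocessing = packaging.version.parse(
    transformers.__version__
) <= packaging.version.parse("4.43.3")
print(needs_bloom_postprocessing)
```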
