T5ForConditionalGeneration: After calling adapters.init() the data_collator input misses attention_mask #737

Open
lenglaender opened this issue Aug 21, 2024 · 0 comments
Labels
bug Something isn't working

Environment info

  • adapters version: 1.0
  • Platform: Linux-5.15.0-118-generic-x86_64-with-glibc2.35
  • Python version: 3.11.9
  • PyTorch version: 2.3.1
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Information

Model I am using (Bert, XLNet ...): T5ForConditionalGeneration

Language I am using the model on (English, Chinese ...): any

Adapter setup I am using (if any): The bug occurs independently of the adapter setup.

Expected behavior

The input that the data collator receives should be the same regardless of whether I use

  1. the default Hugging Face T5ForConditionalGeneration,
  2. T5ForConditionalGeneration with adapters.init(model), or
  3. AutoAdapterModel with a sequence-to-sequence head.

For T5ForConditionalGeneration this is not the case: with the default HF T5ForConditionalGeneration or the AutoAdapterModel with a seq2seq head, the data collator receives dict_keys(['input_ids', 'attention_mask']). However, when using

model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
adapters.init(model)

then the data collator receives only dict_keys(['input_ids']), i.e. the "attention_mask" is missing!
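
For reference, the Hugging Face Trainer's _remove_unused_columns step keeps only dataset columns that match parameters of the model's forward signature. A minimal diagnostic sketch (assuming this mechanism is involved, which is not confirmed here) is to compare the forward signature before and after calling adapters.init:

import inspect
import adapters
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")
# Parameters the Trainer's column filtering would keep; should include "attention_mask"
print(list(inspect.signature(model.forward).parameters))

adapters.init(model)
# If "attention_mask" no longer appears here, Trainer._remove_unused_columns
# would silently drop that column from the dataset features.
print(list(inspect.signature(model.forward).parameters))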

I tested other models & tasks as well:

  • BART, also for conditional generation
  • BERT, for masked language modeling

Neither of these models shows the bug. Output of the script below:

Standard T5 conditional generation: Default HuggingFace model & trainer
keys of first feature: dict_keys(['input_ids', 'attention_mask'])

Setup 1: T5 conditional generation with adapters.init
keys of first feature: dict_keys(['input_ids'])  # THIS SHOULD CONTAIN "attention_mask"

Setup 2: T5 conditional generation with AutoAdapterModel
keys of first feature: dict_keys(['input_ids', 'attention_mask'])


Standard: BART conditional generation: Default HuggingFace model & trainer
keys of first feature: dict_keys(['input_ids', 'attention_mask'])

Setup 3: BART conditional generation with adapters.init
keys of first feature: dict_keys(['input_ids', 'attention_mask'])

Setup 4: BART conditional generation with AutoAdapterModel
keys of first feature: dict_keys(['input_ids', 'attention_mask'])


Standard: BERT masked language modeling: Default HuggingFace model & trainer
keys of first feature: dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

Setup 5: BERT masked language modeling with adapters.init
keys of first feature: dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

Setup 6: BERT masked language modeling with AutoAdapterModel
keys of first feature: dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])

To reproduce

The output pasted above was produced by the following script:

from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    BertForMaskedLM,
    TrainingArguments,
    BartForConditionalGeneration,
    logging,
    Trainer,
)
from datasets import load_dataset
import adapters
from adapters import AdapterTrainer, AutoAdapterModel
from typing import List, Dict, Any

logging.set_verbosity_error()


def simple_data_collator(features: List[Dict[str, Any]]) -> Dict[str, Any]:
    print(f"keys of first feature: {features[0].keys()}")
    raise ValueError("Aborting")  # We only need to see the keys of the first feature


def run_auto_adapter_model(model_name, dataset, head_type):
    model = AutoAdapterModel.from_pretrained(model_name)
    model.add_adapter("my_adapter")
    if head_type == "seq2seq":
        model.add_seq2seq_lm_head("my_adapter")
    elif head_type == "mlm":
        model.add_masked_lm_head("my_adapter")

    model.train_adapter("my_adapter")
    model.set_active_adapters("my_adapter")

    try:
        AdapterTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            data_collator=simple_data_collator,
        ).train()
    except Exception:
        pass  # As expected, caught exception.


def run_adapters_init_model(model_class, model_name, dataset):
    model = model_class.from_pretrained(model_name)
    adapters.init(model)
    model.add_adapter("my_adapter")
    model.train_adapter("my_adapter")
    model.set_active_adapters("my_adapter")

    try:
        AdapterTrainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            data_collator=simple_data_collator,
        ).train()
    except Exception:
        pass  # As expected, caught exception.


def run_default_huggingface_model(model_class, model_name, dataset):
    model = model_class.from_pretrained(model_name)
    try:
        Trainer(
            model=model,
            args=training_args,
            train_dataset=dataset,
            data_collator=simple_data_collator,
        ).train()
    except Exception:
        pass  # As expected, caught exception.


# Prepare datasets and tokenizers
conditional_generation_ds = load_dataset("glue", "sst2", split="train[:100]")
mlm_ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:100]")
tokenizer_t5 = AutoTokenizer.from_pretrained("t5-small")
tokenizer_bart = AutoTokenizer.from_pretrained("facebook/bart-base")
tokenizer_bert = AutoTokenizer.from_pretrained("bert-base-uncased")


tokenized_conditional_generation_ds_t5 = conditional_generation_ds.map(
    lambda x: tokenizer_t5(x["sentence"], return_special_tokens_mask=True), batched=True
).remove_columns(["label"])

tokenized_conditional_generation_ds_bart = conditional_generation_ds.map(
    lambda x: tokenizer_bart(x["sentence"], return_special_tokens_mask=True), batched=True
).remove_columns(["label"])
tokenized_mlm_ds_bert = mlm_ds.map(lambda x: tokenizer_bert(x["text"], return_special_tokens_mask=True), batched=True)


training_args = TrainingArguments(
    output_dir="./output",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    logging_steps=10,
    report_to="none",
)

# Run experiments
print("\nStandard T5 conditional generation: Default HuggingFace model & trainer")
run_default_huggingface_model(
    T5ForConditionalGeneration, "google/flan-t5-small", tokenized_conditional_generation_ds_t5
)

print("\nSetup 1: T5 conditional generation with adapters.init")
run_adapters_init_model(T5ForConditionalGeneration, "google/flan-t5-small", tokenized_conditional_generation_ds_t5)

print("\nSetup 2: T5 conditional generation with AutoAdapterModel")
run_auto_adapter_model("google/flan-t5-small", tokenized_conditional_generation_ds_t5, head_type="seq2seq")

# For all other setups, keys are always: dict_keys(['label', 'input_ids', 'attention_mask', 'labels'])
# Doesn't matter if it's for seq2seq or mlm
print("\n\nStandard: BART conditional generation: Default HuggingFace model & trainer")
run_default_huggingface_model(
    BartForConditionalGeneration, "facebook/bart-base", tokenized_conditional_generation_ds_bart
)

print("\nSetup 3: BART conditional generation with adapters.init")
run_adapters_init_model(BartForConditionalGeneration, "facebook/bart-base", tokenized_conditional_generation_ds_bart)

print("\nSetup 4: BART conditional generation with AutoAdapterModel")
run_auto_adapter_model("facebook/bart-base", tokenized_conditional_generation_ds_bart, head_type="seq2seq")


print("\n\nStandard: BERT masked language modeling: Default HuggingFace model & trainer")
run_default_huggingface_model(BertForMaskedLM, "bert-base-uncased", tokenized_mlm_ds_bert)

print("\nSetup 5: BERT masked language modeling with adapters.init")
run_adapters_init_model(BertForMaskedLM, "bert-base-uncased", tokenized_mlm_ds_bert)

print("\nSetup 6: BERT masked language modeling with AutoAdapterModel")
run_auto_adapter_model("bert-base-uncased", tokenized_mlm_ds_bert, head_type="mlm")