Trainer evaluation loop can't be used with predict_with_generate #729

Open
TimoImhof opened this issue Aug 3, 2024 · 0 comments
Labels
bug Something isn't working

@TimoImhof (Contributor)

Environment info

  • adapters version: 1.0.0.dev0 (latest main)
  • Platform: Windows
  • Python version: 3.12.3
  • PyTorch version (GPU?): 2.4.0+cu121 (True)
  • Using GPU in script?: Yes, NVIDIA GeForce GTX 1660 SUPER
  • Using distributed or parallel set-up in script?: No

Information

  • Model I am using (Bert, XLNet ...): openai/whisper-small
  • Language I am using the model on (English, Chinese ...): Hindi
  • Adapter setup I am using (if any): LoRAConfig(r=8, alpha=16)

To reproduce

Execute this example script:

from datasets import load_dataset, DatasetDict
from transformers import WhisperProcessor, Seq2SeqTrainingArguments
from datasets import Audio
from adapters import WhisperAdapterModel, LoRAConfig, Seq2SeqAdapterTrainer
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import evaluate


def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]
    # compute log-Mel input features from input audio array
    batch["input_features"] = \
        processor.feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    # encode target text to label ids
    batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
    return batch


# Preprocessing
processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")
common_voice = DatasetDict()
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train",
                                     trust_remote_code=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test", trust_remote_code=True)
common_voice = common_voice.remove_columns(
    ["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
common_voice = common_voice.map(
    prepare_dataset, remove_columns=common_voice.column_names["train"]
)  # no multiprocessing used here as this results in errors

# Model setup
model = WhisperAdapterModel.from_pretrained("openai/whisper-small")
model.generation_config.language = "hindi"
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None
task_name = "whisper_hindi_lora"
config = LoRAConfig(r=8, alpha=16)
model.add_adapter(task_name, config=config)
model.add_seq2seq_lm_head(task_name)
model.train_adapter(task_name)


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        # if the bos token was already appended in the tokenization step,
        # cut it here as it is appended again later anyway
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch


data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

metric = evaluate.load("wer")


def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}


training_args = Seq2SeqTrainingArguments(
    output_dir="./" + task_name,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=1e-5,
    fp16=True,
    evaluation_strategy="steps",
    predict_with_generate=True,  # This is the parameter causing the error
    generation_max_length=225,
    eval_steps=2,
    metric_for_best_model="wer",
    push_to_hub=False,
    overwrite_output_dir=True,
)

trainer = Seq2SeqAdapterTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

trainer.train()

Resulting error:

  File transformers\trainer.py", line 1932, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File transformers\trainer.py", line 2345, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
  File transformers\trainer.py", line 2793, in _maybe_log_save_evaluate
    metrics = self._evaluate(trial, ignore_keys_for_eval)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File transformers\trainer.py", line 2750, in _evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File transformers\trainer_seq2seq.py", line 180, in evaluate
    return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File transformers\trainer.py", line 3641, in evaluate
    output = eval_loop(
             ^^^^^^^^^^
  File transformers\trainer.py", line 3826, in evaluation_loop
    losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File transformers\trainer_seq2seq.py", line 310, in prediction_step
    generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File transformers\generation\utils.py", line 1640, in generate
    self._validate_model_kwargs(model_kwargs.copy())
  File transformers\generation\utils.py", line 1238, in _validate_model_kwargs
    raise ValueError(
ValueError: The following `model_kwargs` are not used by the model: ['labels'] (note: typos in the generate arguments will also show up in this list)

Explanation: In generate(), the method _validate_model_kwargs() checks for any unused kwargs. Because the AdapterModel classes do not declare labels as an explicit parameter of their forward() method and only accept it via **kwargs, this check

        for key, value in model_kwargs.items():
            if value is not None and key not in model_args:
                unused_model_args.append(key)

in _validate_model_kwargs() identifies labels as unused, which raises the ValueError shown above.
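
The signature mismatch described above can be verified directly. The following snippet is only an illustration added for clarity (not part of the original report); it assumes, as stated above, that the plain transformers model lists labels in its forward() signature while the WhisperAdapterModel only accepts it through **kwargs:

import inspect

from adapters import WhisperAdapterModel
from transformers import WhisperForConditionalGeneration

hf_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
adapter_model = WhisperAdapterModel.from_pretrained("openai/whisper-small")

# The plain transformers model declares `labels` explicitly, so generate() accepts it ...
print("labels" in inspect.signature(hf_model.forward).parameters)       # expected: True
# ... while the AdapterModel only takes it via **kwargs, so _validate_model_kwargs() flags it as unused.
print("labels" in inspect.signature(adapter_model.forward).parameters)  # expected: False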

Expected behavior

The evaluation loop should run with the generate() method (i.e. with predict_with_generate=True) without raising this error.
Current workaround: set predict_with_generate=False.
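
A possible stop-gap until this is fixed (an untested sketch, not an official adapters API) is to shadow the model's generate() with a thin wrapper that drops the stray labels kwarg forwarded by Seq2SeqTrainer.prediction_step before the kwargs are validated; the loss computation is unaffected because labels are still passed to forward() during the separate loss pass:

import functools

# Hypothetical workaround sketch: wrap model.generate so that `labels` is removed
# before generation. Generation never needs the labels; they are only used for the
# loss inside forward().
_original_generate = model.generate

@functools.wraps(_original_generate)
def _generate_without_labels(*args, **kwargs):
    kwargs.pop("labels", None)
    return _original_generate(*args, **kwargs)

model.generate = _generate_without_labels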

@TimoImhof added the bug label on Aug 3, 2024