
Wav2Vec2 CUDA memory usage doubled in v4.11.3 compared to v4.10.3 with the same batch size #14388

Closed
MarktHart opened this issue Nov 14, 2021 · 12 comments · Fixed by #14407 or #14408

@MarktHart

Environment info

  • transformers version: 4.11.3
  • Platform: Linux-5.11.0-40-generic-x86_64-with-glibc2.29
  • Python version: 3.8.10
  • PyTorch version (GPU?): 1.8.1+cu111 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: Yes, 3090
  • Using distributed or parallel set-up in script?: No

Who can help

@patrickvonplaten, @anton-l

Information

When using Wav2Vec2, memory usage roughly doubles when going from Transformers v4.10.3 to v4.11.3.
Whereas my 3090 (24 GB memory) could handle a batch size of ~32 on v4.10.3, on v4.11.3 this is reduced to ~10.

The problem arises when using:

  • my own modified scripts

The task I am working on is:

  • ASR

To reproduce

Steps to reproduce the behavior:

  1. Run script with v4.10 and v4.11 and watch CUDA memory usage

Reproduce script (relatively minimal):

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, TrainingArguments
from transformers.trainer import Trainer
from torch.utils.data.dataset import Dataset
import numpy as np

class ProcessedDataset(Dataset):
    def __init__(self, processor):
        self.processor = processor

    def __getitem__(self, i):
        x = np.ones(16000 * 10) # 10 seconds
        y = "this is a random sentence"
        with self.processor.as_target_processor():
            batch = {"labels": self.processor(y).input_ids}
        batch["input_values"] = self.processor(x, sampling_rate=16000).input_values
        return batch

    def __len__(self):
        return 10000

class DataCollator:
    def __init__(self, processor):
        self.processor = processor

    def __call__(self, features):
        input_features = [{"input_values": feature["input_values"][0]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        batch = self.processor.pad(
            input_features,
            padding=True,
            max_length=None,
            pad_to_multiple_of=None,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=True,
                max_length=None,
                pad_to_multiple_of=None,
                return_tensors="pt",
            )
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels
        return batch


proc = Wav2Vec2Processor.from_pretrained("wietsedv/wav2vec2-large-xlsr-53-dutch")
model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-nl-voxpopuli",
    attention_dropout=0,
    hidden_dropout=0,
    feat_proj_dropout=0,
    mask_time_prob=0,
    layerdrop=0,
    activation_dropout=0,
    gradient_checkpointing=True,
    ctc_loss_reduction="mean",
    pad_token_id=proc.tokenizer.pad_token_id,
    vocab_size=len(proc.tokenizer),
    ctc_zero_infinity=True
)
ds = ProcessedDataset(proc)
data_collator = DataCollator(processor=proc)
args = TrainingArguments(
    output_dir="/tmp/tmp_model",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    do_eval=False,
    num_train_epochs=1,
    fp16=True,
    group_by_length=False,
    save_steps=-1,
    eval_steps=1024,
    logging_steps=1024,
    warmup_steps=128,
    save_total_limit=1,
    dataloader_num_workers=1,
    seed=11
)

trainer = Trainer(model=model, args=args, train_dataset=ds, data_collator=data_collator)
trainer.train()
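
For reference, peak memory can be compared between the two versions by printing PyTorch's allocator statistics after training (a minimal sketch, not part of the training script itself; nvidia-smi will show a somewhat higher number because of the CUDA context and caching allocator overhead):

import torch

# Peak CUDA memory seen by PyTorch over the whole run, reported in MiB
peak_mib = torch.cuda.max_memory_allocated() / 1024**2
print(f"Peak CUDA memory allocated: {peak_mib:.0f} MiB")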

Expected behavior

Upgrading Hugging Face Transformers from 4.10 to a later version should keep the memory usage in the same ballpark.

@patrickvonplaten
Contributor

Looking into it now

@patrickvonplaten
Contributor

Benchmarking your script on current master gives me a peak GPU mem usage of 20068 MiB.

@patrickvonplaten
Contributor

And with 4.10 it gives me 10738 MiB => so this seems like a pretty heavy bug! Thanks for the heads-up!

@patrickvonplaten
Contributor

Will investigate now

@MarktHart
Author

No problem at all! If there is anything I can do to assist I would be happy to help.

@patrickvonplaten
Contributor

Ok I think I already found one problem. It seems like the gradient_checkpointing PR refactor wasn't 100% backward compatible.

@MarktHart - could you add

model.gradient_checkpointing_enable()

before this line:

trainer = Trainer(model=model, args=args, train_dataset=ds, data_collator=data_collator)

this should more or less solve the problem
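
In context, the tail end of the reproduce script would then look like this (a sketch of the workaround only; everything else stays unchanged):

# Re-enable gradient checkpointing explicitly on the instantiated model;
# since the refactor, the flag passed to from_pretrained may not take effect.
model.gradient_checkpointing_enable()

trainer = Trainer(model=model, args=args, train_dataset=ds, data_collator=data_collator)
trainer.train()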

@MarktHart
Author

That does solve the issue. Thanks a bunch!

@MarktHart
Author

@patrickvonplaten, do you decide whether to close the issue or whether backward compatibility should be restored?

@patrickvonplaten
Contributor

@sgugger - this is a weird issue. For some reason from_pretrained(...) doesn't currently set gradient_checkpointing to True at the first init, since the main model does not have the nn.Modules attached yet.

Will open a hacky PR to fix it
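
Until that lands, a workaround sketch tied to the root cause above is to drop gradient_checkpointing=True from the from_pretrained(...) kwargs entirely and enable it on the instantiated model instead (kwargs here are taken from the reproduce script):

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-nl-voxpopuli",
    ctc_loss_reduction="mean",
    ctc_zero_infinity=True,
    pad_token_id=proc.tokenizer.pad_token_id,
    vocab_size=len(proc.tokenizer),
)
# Enabling checkpointing on the model itself applies it to the already
# attached nn.Modules, independent of when the config flag is read.
model.gradient_checkpointing_enable()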

@voidful
Contributor

voidful commented Dec 17, 2021

Ok I think I already found one problem. It seems like the gradient_checkpointing PR refactor wasn't 100% backward compatible.

@MarktHart - could you add

model.gradient_checkpointing_enable()

before this line:

trainer = Trainer(model=model, args=args, train_dataset=ds, data_collator=data_collator)

this should more or less solve the problem

I have this issue in 4.14.1 when I set group_by_length=True. Adding model.gradient_checkpointing_enable() doesn't solve this problem.

@patrickvonplaten
Contributor

@voidful - can you provide a reproducible script here? :-) Thanks a lot!

@voidful
Contributor

voidful commented Dec 17, 2021

@voidful - can you provide a reproducible script here? :-) Thanks a lot!

It turned out to be a length issue in my custom dataset; simply applying .filter solved the problem.
Sorry for the misleading report.
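
For anyone hitting the same thing with a datasets dataset, the filter looked roughly like this (a sketch; the column names and the 20-second cutoff are illustrative, not my exact setup):

MAX_SECONDS = 20
SAMPLING_RATE = 16_000

# Drop examples whose raw audio exceeds the cutoff so a single very long
# clip cannot blow up the memory of a padded batch.
dataset = dataset.filter(
    lambda example: len(example["audio"]["array"]) <= MAX_SECONDS * SAMPLING_RATE
)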
