
Trainer class on Mac uses accelerate to incorrectly set MPS device #24697

Closed

alex2awesome opened this issue Jul 6, 2023 · 10 comments

@alex2awesome commented Jul 6, 2023

System Info

transformers==4.30.2
Mac 2019, Ventura 13.4

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

ISSUE: I am running generic model training using Trainer locally on my Mac. My model is being moved to MPS, but my tensors are staying on CPU.

I can provide more details about my script, but I suspect this is a general library problem. Here are the lines of code I discovered:

When the accelerator is instantiated in the Trainer class, it doesn't get passed any user-specific arguments (e.g. from TrainingArguments) that would give the user control over which device to use. As a result, when running locally on a Mac, Accelerate infers which device we want and moves the model to self.device in the non-distributed setting. I'm not sure yet how self.device is determined in Accelerate; however, Trainer doesn't natively move my data to MPS, so my script crashes.
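To illustrate the mismatch, here is a minimal sketch (not the actual Trainer code) assuming a Mac where the MPS backend is available:

# Sketch only: Accelerator() with no arguments picks MPS on a Mac where it is
# available, so the model ends up on "mps" while a manually built batch stays
# on "cpu" -- the device mismatch described above.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 2).to(accelerator.device)  # moved to mps
batch = torch.randn(3, 4)                             # still on cpu
print(accelerator.device, batch.device)
# model(batch) would raise a device-mismatch error at this point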

Expected behavior

Ideally, there would be a flag I could pass into Trainer to avoid MPS when I don't want it and just stick with the CPU.

@alex2awesome (Author)

EDIT:

Adding the --no_cuda flag in TrainingArguments takes care of this issue.

I suggest renaming it to something like --use_cpu or --no_cuda_or_mps, because I totally didn't realize it could be used for this purpose and had to dive to the very bottom of the codebase to find it.
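A minimal sketch of that workaround in code (assuming transformers ~4.30, where the flag is still named no_cuda):

# Sketch of the workaround: no_cuda=True keeps training on CPU, and per the
# discussion in this thread it also prevents the Trainer from selecting MPS.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir=".",
    no_cuda=True,
)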

@ydshieh (Collaborator) commented Jul 7, 2023

I am not really an expert on this topic, but do you think #24660 will help?

@ydshieh (Collaborator) commented Jul 7, 2023

If not, a reproducible script is indeed necessary, please 🙏

@tcapelle

I have a similar issue: the Trainer was automatically using the MPS backend and I couldn't figure out a way to run on CPU. (The MPS backend is missing some operations, so not all models run!)
Using no_cuda=True in the TrainingArguments solved the issue! Pretty unintuitive!

@sgugger (Collaborator) commented Jul 17, 2023

cc @SunMarc. Maybe we could deprecate the no_cuda flag and replace it with use_cpu, which would be more intuitive?
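A rough sketch of how such a deprecation could look (a hypothetical toy dataclass, not the actual transformers implementation):

# Hypothetical sketch of deprecating no_cuda in favor of use_cpu; the real
# TrainingArguments dataclass in transformers is far larger than this.
import warnings
from dataclasses import dataclass, field

@dataclass
class ToyTrainingArguments:
    use_cpu: bool = field(default=False, metadata={"help": "Force training on CPU."})
    no_cuda: bool = field(default=False, metadata={"help": "Deprecated, use `use_cpu` instead."})

    def __post_init__(self):
        if self.no_cuda:
            warnings.warn("`no_cuda` is deprecated, use `use_cpu` instead.", FutureWarning)
            self.use_cpu = True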

@SunMarc (Member) commented Jul 17, 2023

Yes, we should do that, since we automatically set the device to cuda or mps if available. Furthermore, use_mps_device in TrainingArguments is also deprecated. I will open a PR for that. The other issue is that we don't dispatch the data to the right device. @muellerzr, I see that we don't move the dataloader to a specific device in get_train_dataloader. Is this something we want to add? I can open a PR for it if needed.

@muellerzr (Contributor) commented Jul 17, 2023

@SunMarc accelerate does this automatically in its dataloader / with the Accelerator, so this should already be happening. If not, it's something we need to fix in accelerate.
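A standalone sketch of that behavior, using a toy PyTorch DataLoader rather than the Trainer's own:

# Sketch: Accelerator.prepare wraps the DataLoader so that the batches it
# yields are placed on accelerator.device (cuda, mps, or cpu) automatically.
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()  # Accelerator(cpu=True) would force CPU instead
dataset = TensorDataset(torch.randn(8, 4))
dataloader = accelerator.prepare(DataLoader(dataset, batch_size=2))

for (batch,) in dataloader:
    print(batch.device)  # expected to match accelerator.device
    break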

@tcapelle commented Jul 17, 2023

There is also another issue: the default device is mps, but the data is not moved to mps, so the Trainer raises an error. Minimal code:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

model_checkpoint = "roneneldan/TinyStories-33M"
ds = load_dataset('MohamedRashad/characters_backstories')["train"]

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(example):
    merged = example["text"] + " " + example["target"]
    batch = tokenizer(merged, padding='max_length', truncation=True, max_length=128)
    batch["labels"] = batch["input_ids"].copy()
    return batch

tokenized_dataset = ds.map(tokenize_function, remove_columns=["text", "target"])

model = AutoModelForCausalLM.from_pretrained(model_checkpoint)

training_args = TrainingArguments(
    num_train_epochs=1,
    output_dir=".",
    # use_mps_device=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

print(trainer.accelerator.device)
# device("mps")

# Let's train!
trainer.train()

You can solve the issue by explicitly passing use_mps_device=True or no_cuda=True to the TrainingArguments.

PS: I am on the latest versions of transformers, datasets, and accelerate (pip install -U ...).

@SunMarc (Member) commented Jul 17, 2023

Hey @tcapelle, thanks for the snippet. It helps a lot in solving the issue. I was able to reproduce the bug on the latest released version of transformers. This bug is fixed on the main branch of transformers, which you can install with pip install git+https://github.com/huggingface/transformers.git. Let me know if it works on your side.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
