cumsum op: pytorch failed to run GPT-2 model in M1's MPS device #79112

Closed
@liulhdarks

🐛 Describe the bug

My transformers inference script runs successfully on the CPU device, but on the MPS device on a macOS M1 Pro it reports that the 'aten::cumsum.out' op is missing. I then set the environment variable 'PYTORCH_ENABLE_MPS_FALLBACK', after which the huggingface transformers GPT-2 model fails with the following error:

/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py:999: UserWarning: The operator 'aten::cumsum.out' is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications. (Triggered internally at  /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/mps/MPSFallback.mm:11.)
  position_ids = attention_mask.long().cumsum(-1) - 1
Traceback (most recent call last):
  File "/Users/lihua.llh/Documents/codes/lab/python/gpt2_demo/inferences/demo/beam_generation_demo.py", line 40, in <module>
    main()
  File "/Users/lihua.llh/Documents/codes/lab/python/gpt2_demo/inferences/demo/beam_generation_demo.py", line 31, in main
    outputs = model.generate(input_ids=input_ids, num_beams=2, max_length=500, num_return_sequences=2,
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/transformers/generation_utils.py", line 1344, in generate
    return self.beam_search(
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/transformers/generation_utils.py", line 2192, in beam_search
    outputs = self(
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 1046, in forward
    transformer_outputs = self.transformer(
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 889, in forward
    outputs = block(
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 390, in forward
    attn_outputs = self.attn(
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/transformers/models/gpt2/modeling_gpt2.py", line 312, in forward
    query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/lihua.llh/miniconda3/envs/torch-m1/lib/python3.8/site-packages/transformers/pytorch_utils.py", line 107, in forward
    x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
RuntimeError: tensors must be 2-D
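
For reference, the last frame of the traceback is the `torch.addmm` call in the transformers `Conv1D` layer. The sketch below reruns the same expression on the CPU with made-up shapes (the sizes and the function name `conv1d_forward` are assumptions for illustration, not taken from the model) to show what `addmm` expects: a 1-D bias, a 2-D input, and a 2-D weight. It does not reproduce the MPS failure itself.

```python
import torch


def conv1d_forward(x, weight, bias):
    # Mirrors the expression from the failing frame: all leading
    # dimensions of x are flattened so that addmm receives a 2-D matrix.
    size_out = x.size()[:-1] + (weight.size(-1),)
    out = torch.addmm(bias, x.view(-1, x.size(-1)), weight)
    # Restore the leading dimensions afterwards.
    return out.view(size_out)


# Illustrative shapes only: (batch, sequence, features).
x = torch.randn(2, 5, 8)
weight = torch.randn(8, 24)
bias = torch.zeros(24)

y = conv1d_forward(x, weight, bias)
print(y.shape)
```

On the CPU the flattened input `x.view(-1, x.size(-1))` is 2-D, so the call succeeds, which matches the observation that the script works on the CPU device.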

Script

Using huggingface transformers version 4.19.2:

import torch

from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
)

MODEL_CLASSES = {
    "distilgpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "gpt2-large": (GPT2LMHeadModel, GPT2Tokenizer),
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
}


def main():
    model_type = "gpt2"
    model_class, tokenizer_class = MODEL_CLASSES[model_type]

    prompt_text = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered."""
    tokenizer = tokenizer_class.from_pretrained(model_type)
    model = model_class.from_pretrained(model_type)

    input_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
    model.eval()
    device = torch.device("mps")
    model = model.to(device)
    input_ids = input_ids.to(device)
    outputs = model.generate(input_ids=input_ids, num_beams=2, max_length=500, num_return_sequences=2,
                             repetition_penalty=1.2, length_penalty=1.2, no_repeat_ngram_size=5, top_p=1.0,
                             early_stopping=True)
    ret = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    for item in ret:
        print(item)


if __name__ == "__main__":
    main()
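
The script above does not show how 'PYTORCH_ENABLE_MPS_FALLBACK' was set. A minimal sketch of one way to do it from Python, assuming the variable has to be in the environment before `torch` is imported and that the value `1` enables the fallback (both are assumptions, not stated in this report); `pick_device` is a hypothetical helper name:

```python
import os

# Assumption: the fallback flag is read when torch initializes,
# so it is set before the import below.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch


def pick_device():
    # Hypothetical helper: use MPS when this build exposes it and it is
    # available on the machine, otherwise fall back to the CPU.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print(device)
```

Exporting the variable in the shell before launching the script should have the same effect.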

Versions

Collecting environment information...
PyTorch version: 1.13.0.dev20220601
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 12.4 (arm64)
GCC version: Could not collect
Clang version: 13.0.0 (clang-1300.0.29.30)
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16) [Clang 12.0.1 ] (64-bit runtime)
Python platform: macOS-12.4-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.0rc2
[pip3] torch==1.13.0.dev20220601
[pip3] torchaudio==0.14.0.dev20220601
[pip3] torchvision==0.14.0a0+f9f721d
[conda] numpy 1.23.0rc2 pypi_0 pypi
[conda] torch 1.13.0.dev20220601 pypi_0 pypi
[conda] torchaudio 0.14.0.dev20220601 pypi_0 pypi
[conda] torchvision 0.14.0a0+f9f721d pypi_0 pypi

cc @kulinseth @albanD

Labels

module: mps (Related to Apple Metal Performance Shaders framework)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
