This repository was archived by the owner on Aug 1, 2025. It is now read-only.

Segmentation fault with inductor for huggingface/diffusers #1681

@mreso

Description

Hi,

running the stable_diffusion example from huggingface/diffusers results in a segmentation fault.

Repro:
pip install --upgrade diffusers transformers

Then run:

import torchdynamo

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")  # ACCESS TOKEN REQUIRED
pipe = pipe.to("cuda")

@torchdynamo.optimize("inductor")
def apply(x):
    return pipe(x)

prompt = "a photo of an astronaut riding a horse on mars"
image = apply(prompt).images[0]
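
In case it helps with triage: a hypothetical way to narrow down the crash would be to compile only the UNet forward pass rather than the whole pipeline, keeping the tokenizer/scheduler Python (the source of the graph breaks below) out of the compiled region. This is only a sketch under the same setup as above; wrapping pipe.unet.forward like this is my assumption, not part of the original repro:

import torchdynamo

from diffusers import StableDiffusionPipeline

# Sketch only: compile just the UNet's forward pass so inductor never sees
# the tokenizer/scheduler code paths. The bound-method wrapping below is an
# assumption for isolating the failure, not part of the original repro.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")  # ACCESS TOKEN REQUIRED
pipe = pipe.to("cuda")
pipe.unet.forward = torchdynamo.optimize("inductor")(pipe.unet.forward)

image = pipe("a photo of an astronaut riding a horse on mars").images[0]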

Error message:

ftfy or spacy is not installed using BERT BasicTokenizer instead of ftfy.
WARNING:torchinductor.lowering:make_fallback(aten.unfold): a decomposition exists, we should switch to it
WARNING:torchinductor.lowering:make_fallback(aten.unfold_backward): a decomposition exists, we should switch to it
[2022-10-15 01:43:28,669] torchdynamo.symbolic_convert: [WARNING] Graph break: call_function UserDefinedObjectVariable(StableDiffusionPipeline) [ConstantVariable(str)] {} from user code at   File "/home/ubuntu/torchdynamo/test.py", line 14, in apply
    return pipe(x)

[2022-10-15 01:43:28,770] torchdynamo.symbolic_convert: [WARNING] Graph break: call_function UserDefinedObjectVariable(CLIPTokenizer) [ConstantVariable(str)] {'padding': ConstantVariable(str), 'max_length': ConstantVariable(int), 'return_tensors': ConstantVariable(str)} from user code at   File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 214, in __call__
    text_inputs = self.tokenizer(

[2022-10-15 01:43:34,247] torchdynamo.symbolic_convert: [WARNING] Graph break: call_function BuiltinVariable(dict) [UserDefinedObjectVariable(FrozenDict)] {} from user code at   File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/pipeline_utils.py", line 205, in device
    module_names, _ = self.extract_init_dict(dict(self.config))

[2022-10-15 01:43:37,990] torchinductor.graph: [WARNING] Creating implicit fallback for:
  target: aten.triu_.default
  args[0]: TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FlexibleLayout('cpu', torch.float32, size=[1, 77, 77], stride=[5929, 77, 1]), data=Pointwise(
      'cpu',
      torch.float32,
      constant(-3.4028234663852886e+38, torch.float32),
      ranges=[1, 77, 77],
      origins={fill_}
    ))
  ))
  args[1]: 1
[2022-10-15 01:43:37,992] torchinductor.ir: [WARNING] Using FallbackKernel: torch.ops.aten.triu_.default
[2022-10-15 01:43:37,992] torchinductor.ir: [WARNING] DeviceCopy
[2022-10-15 01:43:43,374] torchinductor.compile_fx: [WARNING] skipping cudagraphs due to multiple devices
[2022-10-15 01:43:55,938] torchdynamo.symbolic_convert: [WARNING] Graph break: call_method GetAttrVariable(UserDefinedObjectVariable(Version), _regex) search [ConstantVariable(str)] {} from user code at   File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/schedulers/scheduling_pndm.py", line 161, in set_timesteps
    deprecated_offset = deprecate(
  File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/utils/deprecation_utils.py", line 17, in deprecate
    if version.parse(version.parse(__version__).base_version) >= version.parse(version_name):
  File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/packaging/version.py", line 49, in parse
    return Version(version)
  File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/packaging/version.py", line 264, in __init__
    match = self._regex.search(version)

[2022-10-15 01:43:55,947] torchdynamo.symbolic_convert: [WARNING] Graph break: numpy from user code at   File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/schedulers/scheduling_pndm.py", line 170, in <graph break in set_timesteps>
    self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()

[2022-10-15 01:43:56,134] torchdynamo.symbolic_convert: [WARNING] Graph break: call_function in skip_files /home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/tqdm/asyncio.py from user code at   File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/pipeline_utils.py", line 559, in progress_bar
    return tqdm(iterable, **self._progress_bar_config)

  0%|          | 0/51 [00:00<?, ?it/s]
[2022-10-15 01:43:59,438] torchdynamo.symbolic_convert: [WARNING] Graph break: call_function UserDefinedClassVariable() [] {'sample': TensorVariable()} from user code at   File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/models/unet_2d_condition.py", line 341, in forward
    return UNet2DConditionOutput(sample=sample)

[2022-10-15 01:47:47,542] torchdynamo.symbolic_convert: [WARNING] Graph break: inline __setitem__ from user code at   File "<string>", line 4, in <graph break in __init__>
  File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/utils/outputs.py", line 72, in __post_init__
    self[field.name] = v

[2022-10-15 01:47:47,550] torchdynamo.symbolic_convert: [WARNING] Graph break: non-function or method super: <slot wrapper '__setitem__' of 'collections.OrderedDict' objects> from user code at   File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/utils/outputs.py", line 104, in __setitem__
    super().__setitem__(key, value)

[2022-10-15 01:47:47,553] torchdynamo.symbolic_convert: [WARNING] Graph break: non-function or method super: <slot wrapper '__setattr__' of 'object' objects> from user code at   File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/utils/outputs.py", line 106, in <graph break in __setitem__>
    super().__setattr__(key, value)

[2022-10-15 01:47:47,565] torchdynamo.symbolic_convert: [WARNING] Graph break: data dependent operator: aten._local_scalar_dense.default from user code at   File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/schedulers/scheduling_pndm.py", line 225, in step
    return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
  File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/schedulers/scheduling_pndm.py", line 340, in step_plms
    prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
  File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/schedulers/scheduling_pndm.py", line 374, in _get_prev_sample
    alpha_prod_t = self.alphas_cumprod[timestep]

[2022-10-15 01:47:48,531] torchdynamo.symbolic_convert: [WARNING] Graph break: data dependent operator: aten._local_scalar_dense.default from user code at   File "/home/ubuntu/miniconda3/envs/torchdynamo/lib/python3.10/site-packages/diffusers/schedulers/scheduling_pndm.py", line 375, in <graph break in _get_prev_sample>
    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

[2022-10-15 01:47:49,100] torchinductor.compile_fx: [WARNING] skipping cudagraphs due to multiple devices
Segmentation fault (core dumped)

Expected result:
The example works with the eager and aot_eager backends and successfully generates an image; with inductor it crashes.
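
For completeness, a minimal sketch of the backend comparison described above, reusing the pipe object from the repro. The backend strings "eager" and "aot_eager" are the ones named in this report; calling torchdynamo.reset() between runs to clear compiled caches is my assumption:

import torchdynamo

# Sketch: run the same pipeline call under each backend. Per this report,
# "eager" and "aot_eager" complete and produce an image; "inductor" segfaults.
for backend in ("eager", "aot_eager", "inductor"):
    torchdynamo.reset()  # assumed: clears compiled caches between backend runs

    @torchdynamo.optimize(backend)
    def apply(x):
        return pipe(x)

    image = apply("a photo of an astronaut riding a horse on mars").images[0]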

Env info:

Collecting environment information...
PyTorch version: 1.14.0.dev20221014+cu116
Is debug build: False
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.24.1
Libc version: glibc-2.31

Python version: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:36:39) [GCC 10.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-1020-aws-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.6.124
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A10G
Nvidia driver version: 510.73.08
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.23.4
[pip3] torch==1.14.0.dev20221014+cu116
[pip3] torchdynamo==1.14.0.dev0
[conda] numpy                     1.23.4                   pypi_0    pypi
[conda] torch                     1.14.0.dev20221014+cu116          pypi_0    pypi
[conda] torchdynamo               1.14.0.dev0              pypi_0    pypi
