Exploring Stable Diffusion in Torch Compile Path #2144

Closed
gs-olive opened this issue Jul 26, 2023 · 4 comments
Labels
- component: dynamo (issues relating to the `torch.compile` or `torch._dynamo.export` paths)
- No Activity
- Story: Dynamo Compile Improvements (issues relating to improvement of the Dynamo compile path)

Comments

gs-olive (Collaborator) commented Jul 26, 2023

Benchmarking Torch-TRT on Stable Diffusion via TorchBench (see tutorial here) with the following command runs into several issues:

HUGGING_FACE_HUB_TOKEN={YOUR HF TOKEN} python run.py stable_diffusion --backend torch_trt --precision fp32 -d cuda -t eval --ir torch_compile

### `set_module` Issue

The command above first fails at this line, because the `set_module` attribute does not appear to work for Stable Diffusion.

### Missing ATen Operators

If we intercept the compilation above that line, it succeeds but is very slow, generating over 100 TRT engines. The missing operators are listed below; some of these may have since been implemented.

- torch.ops.aten.var_mean.correction
- torch.ops.aten._unsafe_view.default
- torch.ops.aten.arange.start_step
- torch.ops.aten.bmm.default
- torch.ops.aten.amax.default
- torch.ops.aten.erf.default
- torch.ops.aten.full.default
- torch.ops.aten._to_copy.default
- torch.ops.aten.permute.default
- torch.ops.aten.sum.dim_IntList
- torch.ops.aten.exp.default
- torch.ops.aten.slice.Tensor
- _operator.getitem
- torch.ops.aten.argmax.default
- torch.ops.aten.unsqueeze.default
- torch.ops.aten.index.Tensor
- torch.ops.aten.sqrt.default
- torch.ops.aten.mm.default
- torch.ops.aten.embedding.default
- torch.ops.aten.reciprocal.default
- torch.ops.aten.clone.default

### Accuracy Issue

Additionally, there is an accuracy issue: the outputs of the `trt_model` have a very low cosine similarity score (~0.14 on one test) relative to their PyTorch counterparts. The graph segmentation may contribute to this.
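
For reference, a score of this kind can be obtained by flattening both outputs and taking their cosine similarity. The sketch below is plain Python so it runs without PyTorch; the helper name `cosine_similarity` and the sample vectors are assumptions for illustration, and the issue does not state how the ~0.14 score was actually measured.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length flat sequences of floats."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Hypothetical flattened outputs, not taken from the actual benchmark run.
torch_out = [0.5, -1.0, 2.0, 0.25]
trt_out = [0.5, -1.0, 2.0, 0.25]
print(cosine_similarity(torch_out, trt_out))
```

A score close to 1 means the two outputs point in nearly the same direction, so a value around 0.14 indicates that they are close to unrelated.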

### Tasks
- [x] _operator.getitem
- [ ] torch.ops.aten.var_mean.correction
- [ ] torch.ops.aten._unsafe_view.default
- [ ] torch.ops.aten.amax.default
- [ ] torch.ops.aten.erf.default
- [ ] torch.ops.aten.full.default
- [x] torch.ops.aten._to_copy.default
- [x] torch.ops.aten.permute.default
- [ ] https://github.com/pytorch/TensorRT/issues/2244
- [ ] torch.ops.aten.exp.default
- [x] torch.ops.aten.slice.Tensor
- [ ] https://github.com/pytorch/TensorRT/issues/2245
- [x] torch.ops.aten.unsqueeze.default
- [ ] https://github.com/pytorch/TensorRT/issues/2231
- [ ] torch.ops.aten.sqrt.default
- [x] torch.ops.aten.mm.default
- [x] torch.ops.aten.embedding.default
- [ ] torch.ops.aten.reciprocal.default
- [x] torch.ops.aten.clone.default
- [ ] https://github.com/pytorch/TensorRT/issues/2236
- [ ] https://github.com/pytorch/TensorRT/issues/2243
- [ ] https://github.com/pytorch/TensorRT/issues/1795
- [x] https://github.com/pytorch/TensorRT/issues/2346
gs-olive added the "component: dynamo" and "Story: Dynamo Compile Improvements" labels Jul 26, 2023
gs-olive (Collaborator, Author) commented Aug 3, 2023

Different parts of the Stable Diffusion pipeline have different missing operators. For instance, the UNet is missing the following set:
(Jul 3 Torch Nightly)

- torch.ops.aten.arange.start_step + Operator Count: 7
- torch.ops.aten.exp.default + Operator Count: 1
- torch.ops.aten.sin.default + Operator Count: 1
- torch.ops.aten.cos.default + Operator Count: 1
- torch.ops.aten.var_mean.correction + Operator Count: 109
- torch.ops.aten.sqrt.default + Operator Count: 109
- torch.ops.aten._scaled_dot_product_efficient_attention.default + Operator Count: 32
- torch.ops.aten.split.Tensor + Operator Count: 16
- torch.ops.aten.erf.default + Operator Count: 16
- torch.ops.aten._to_copy.default + Operator Count: 6
- torch.ops.aten._unsafe_index.Tensor + Operator Count: 3

(Aug 3 Torch Nightly)

- torch.ops.aten.arange.start_step + Operator Count: 7
- torch.ops.aten.exp.default + Operator Count: 1
- torch.ops.aten.sin.default + Operator Count: 1
- torch.ops.aten.cos.default + Operator Count: 1
- torch.ops.aten.native_group_norm.default + Operator Count: 61
- torch.ops.aten.native_layer_norm.default + Operator Count: 48
- torch.ops.aten._scaled_dot_product_efficient_attention.default + Operator Count: 32
- torch.ops.aten.split.Tensor + Operator Count: 16
- torch.ops.aten._to_copy.default + Operator Count: 6
- torch.ops.aten._unsafe_index.Tensor + Operator Count: 3
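
To make the change between the two nightlies easier to see, the two UNet lists can be compared directly. This is a small bookkeeping sketch in plain Python; the dictionary names are assumptions, and the counts are copied from the lists above.

```python
# Unsupported-operator counts for the UNet, copied from the two lists above.
jul_nightly = {
    "torch.ops.aten.arange.start_step": 7,
    "torch.ops.aten.exp.default": 1,
    "torch.ops.aten.sin.default": 1,
    "torch.ops.aten.cos.default": 1,
    "torch.ops.aten.var_mean.correction": 109,
    "torch.ops.aten.sqrt.default": 109,
    "torch.ops.aten._scaled_dot_product_efficient_attention.default": 32,
    "torch.ops.aten.split.Tensor": 16,
    "torch.ops.aten.erf.default": 16,
    "torch.ops.aten._to_copy.default": 6,
    "torch.ops.aten._unsafe_index.Tensor": 3,
}
aug_nightly = {
    "torch.ops.aten.arange.start_step": 7,
    "torch.ops.aten.exp.default": 1,
    "torch.ops.aten.sin.default": 1,
    "torch.ops.aten.cos.default": 1,
    "torch.ops.aten.native_group_norm.default": 61,
    "torch.ops.aten.native_layer_norm.default": 48,
    "torch.ops.aten._scaled_dot_product_efficient_attention.default": 32,
    "torch.ops.aten.split.Tensor": 16,
    "torch.ops.aten._to_copy.default": 6,
    "torch.ops.aten._unsafe_index.Tensor": 3,
}

# Operators that appear in only one of the two lists.
no_longer_listed = sorted(set(jul_nightly) - set(aug_nightly))
newly_listed = sorted(set(aug_nightly) - set(jul_nightly))

print("No longer listed:", no_longer_listed)
print("Newly listed:", newly_listed)
print("Total count (Jul):", sum(jul_nightly.values()))
print("Total count (Aug):", sum(aug_nightly.values()))
```

The comparison shows `var_mean.correction`, `sqrt.default`, and `erf.default` dropping out, and `native_group_norm.default` and `native_layer_norm.default` appearing in their place.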

Next step: find decompositions for group norm and layer norm that transform them back into their component ops.
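
As an illustration of what such a decomposition computes, layer norm over a single vector can be written with a variance/mean step, a square root, and elementwise arithmetic. The sketch below is plain Python rather than an actual PyTorch decomposition; the function names and the `eps` default are assumptions. Note that the two norm counts above (61 + 48) add up to the 109 `var_mean` and `sqrt` calls in the earlier list, which is consistent with, though not proof of, the norms having previously been decomposed this way.

```python
import math


def var_mean(xs, correction=0):
    """Return (variance, mean) of xs, mirroring the role of aten.var_mean."""
    mean = sum(xs) / len(xs)
    var = sum((x - mean) ** 2 for x in xs) / (len(xs) - correction)
    return var, mean


def layer_norm_decomposed(xs, weight, bias, eps=1e-5):
    """Layer norm of one vector, expressed via var_mean, sqrt, and arithmetic."""
    var, mean = var_mean(xs, correction=0)  # layer norm uses the biased variance
    denom = math.sqrt(var + eps)
    return [(x - mean) / denom * w + b for x, w, b in zip(xs, weight, bias)]


# Hypothetical input with identity affine parameters.
out = layer_norm_decomposed([1.0, 2.0, 3.0, 4.0], [1.0] * 4, [0.0] * 4)
print(out)
```

Group norm follows the same pattern, with the statistics taken per group of channels instead of over the whole normalized shape.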

gs-olive changed the title from "Accuracy Issue with Stable Diffusion in Torch Compile Path" to "Exploring Stable Diffusion in Torch Compile Path" Aug 14, 2023
gs-olive (Collaborator, Author) commented Aug 15, 2023

Next up:

- torch.ops.aten._scaled_dot_product_efficient_attention.default + Operator Count: 32
- _operator.getitem + Operator Count: 64 [Covered by split below]
- torch.ops.aten.split.Tensor + Operator Count: 16
- torch.ops.aten.index.Tensor + Operator Count: 3

@gs-olive (Collaborator, Author)

Missing operators from the text encoder:

- torch.ops.aten.bmm.default + Operator Count: 24
- torch.ops.aten.amax.default + Operator Count: 12
- torch.ops.aten.exp.default + Operator Count: 12
- torch.ops.aten.sum.dim_IntList + Operator Count: 12
- torch.ops.aten.argmax.default + Operator Count: 1
- torch.ops.aten.index.Tensor + Operator Count: 1
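
The matching counts for `amax`, `exp`, and `sum.dim_IntList` (12 each) resemble a softmax that has been decomposed into its component ops, although the issue does not confirm this. Under that assumption, the computation those three ops would perform on one row looks like the plain-Python sketch below; the function name is hypothetical, and the comments only indicate which listed op plays a similar role.

```python
import math


def softmax_decomposed(row):
    """Numerically stable softmax of one row, built from max, exp, and sum."""
    row_max = max(row)                           # role of aten.amax
    exps = [math.exp(x - row_max) for x in row]  # subtraction, then role of aten.exp
    total = sum(exps)                            # role of aten.sum.dim_IntList
    return [e / total for e in exps]             # final division


# Hypothetical attention scores for one query position.
probs = softmax_decomposed([2.0, 1.0, 0.1])
print(probs)
```

Subtracting the row maximum before exponentiating keeps the result unchanged while avoiding overflow for large scores.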


This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
