
Add workaround for MPS gather crash #15

Merged · 1 commit · Apr 3, 2023

Conversation

brkirch (Contributor) commented Apr 2, 2023

This is a workaround for gather on MPS: it applies unsqueeze before the gather and squeeze afterwards so that gather doesn't crash. Fixes #4.
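
A minimal sketch of the wrapper this describes, reconstructed from the gather call quoted later in this thread; the function name is illustrative rather than the PR's exact code:

```python
import torch

def mps_gather_workaround(input, dim, index):
    # Temporarily append a size-1 dimension, gather, then drop it again.
    # A negative dim must shift by one so it still points at the same axis
    # after the unsqueeze; a non-negative dim is unaffected.
    return torch.gather(
        input.unsqueeze(-1),
        dim - 1 if dim < 0 else dim,
        index.unsqueeze(-1),
    ).squeeze(-1)
```

Callers can then use this in place of torch.gather on MPS devices.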

dbolya (Owner) commented Apr 2, 2023

Oh sweet, I didn't realize the fix was this simple. I don't have an MPS device to test this on, so can someone else (e.g., @KohakuBlueleaf or @tvdtran) confirm that it works for them?

Edit: This might also fix the directml issues (#13). If so, the workaround could be applied when the device type is "mps" or "dml".
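
A minimal sketch of that conditional, following the metric.device.type check referenced below; the "dml" string is an assumption from this comment, and the thread later questions whether directml actually reports that device type:

```python
# Pick the gather implementation per device type (sketch only);
# mps_gather_workaround is the wrapper sketched above.
gather = (
    mps_gather_workaround
    if metric.device.type in ("mps", "dml")
    else torch.gather
)
```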

GreenLandisaLie commented Apr 2, 2023

> Oh sweet, I didn't realize the fix was this simple. I don't have an MPS device to test this on, so can someone else (e.g., @KohakuBlueleaf or @tvdtran) confirm that it works for them?
>
> Edit: This might also fix the directml issues (#13). If so, the workaround could be applied when the device type is "mps" or "dml".

I tried it with directml (https://github.com/lshqqytiger/stable-diffusion-webui-directml) and still got an error in:
return torch.gather(input.unsqueeze(-1), dim - 1 if dim < 0 else dim, index.unsqueeze(-1)).squeeze(-1)
triggered after this line:
dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
The error was something akin to 'invalid parameter'. Also, metric.device.type was not 'dml' but something else I can't remember; that might be specific to the UI I mentioned.

dbolya (Owner) commented Apr 2, 2023

> I tried it with directml (https://github.com/lshqqytiger/stable-diffusion-webui-directml) and still got an error in:
> return torch.gather(input.unsqueeze(-1), dim - 1 if dim < 0 else dim, index.unsqueeze(-1)).squeeze(-1)
> triggered after this line:
> dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)
> The error was something akin to 'invalid parameter'. Also, metric.device.type was not 'dml' but something else I can't remember; that might be specific to the UI I mentioned.

Thanks for testing this. Seems like the directml issue is separate then, unfortunately. We can keep this as a fix just for MPS and look into a separate fix for directml.

jrittvo commented Apr 3, 2023

Using this patched tomesd version in a simple Python diffusers pipeline threw an error saying I needed to set an environment variable: export PYTORCH_ENABLE_MPS_FALLBACK=1

Worked after that. A straight 768x768 generation went from 3.75 s/it without tomesd to 2.13 s/it with tomesd at 0.5, so I assume the "fallback" isn't a big detriment, if it is one at all. I don't know how to do Hires fix with Python commands yet, so I can't test this all the way.
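
For anyone hitting the same error: the usual advice is to have the variable in the environment before torch is imported, so either export it in the shell as above, or set it at the top of the script. A minimal sketch, where the ordering is the important part:

```python
# Enable the CPU fallback before importing torch, so MPS picks it up.
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # imported only after the variable is set
```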

dbolya (Owner) commented Apr 3, 2023

> Using this patched tomesd version in a simple Python diffusers pipeline threw an error saying I needed to set an environment variable: export PYTORCH_ENABLE_MPS_FALLBACK=1
>
> Worked after that. A straight 768x768 generation went from 3.75 s/it without tomesd to 2.13 s/it with tomesd at 0.5, so I assume the "fallback" isn't a big detriment, if it is one at all. I don't know how to do Hires fix with Python commands yet, so I can't test this all the way.

Sounds good to me. Requiring PYTORCH_ENABLE_MPS_FALLBACK=1 is better than just crashing, and that speed-up seems not too far off expected: 3.75 / 2.13 ≈ 1.76x using MPS vs. ~1.87x using CUDA.
Thank you for testing this! Merging now.

Successfully merging this pull request may close these issues: Failed to run on M1Mac