Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make things work for MPS. #666

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Conversation

hameerabbasi
Copy link

No description provided.

Verified

This commit was signed with the committer’s verified signature.
hameerabbasi Hameer Abbasi
@kohya-ss
Copy link
Owner

Thank you for this! It seems like a lot of work to get it to work properly in mps😅

@hameerabbasi
Copy link
Author

Not much really, just had to get the device and use it instead of CUDA.

@kohya-ss
Copy link
Owner

If the assignment like unet = unet.to(accelerator.device, dtype=weight_dtype) is needed when changing the device of the model, it is really hard. Is the assignment not needed for mps?

@hameerabbasi
Copy link
Author

The assignment is needed for all backends. This was a bug I fixed, otherwise it moves the device and discards the result instead of using it.

@kohya-ss
Copy link
Owner

According to the document of PyTorch, the assignment doesn't seem to be needed for nn.Module.
https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.to

However, Tensor requires an assignment:
https://pytorch.org/docs/stable/generated/torch.Tensor.to.html#torch.Tensor.to

I think unet is nn.Module.

@hameerabbasi
Copy link
Author

I will adjust this PR soon. ;)

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants