-
Notifications
You must be signed in to change notification settings - Fork 27.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Initial IPEX support for Intel Arc GPU #14171
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from modules import shared | ||
from modules.sd_hijack_utils import CondFunc | ||
|
||
has_ipex = False | ||
try: | ||
import torch | ||
import intel_extension_for_pytorch as ipex # noqa: F401 | ||
has_ipex = True | ||
except Exception: | ||
pass | ||
|
||
|
||
def check_for_xpu(): | ||
return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available() | ||
|
||
|
||
def get_xpu_device_string(): | ||
if shared.cmd_opts.device_id is not None: | ||
return f"xpu:{shared.cmd_opts.device_id}" | ||
return "xpu" | ||
|
||
|
||
def torch_xpu_gc(): | ||
with torch.xpu.device(get_xpu_device_string()): | ||
torch.xpu.empty_cache() | ||
|
||
|
||
has_xpu = check_for_xpu() | ||
|
||
if has_xpu: | ||
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device | ||
CondFunc('torch.Generator', | ||
lambda orig_func, device=None: torch.xpu.Generator(device), | ||
lambda orig_func, device=None: device is not None and device.type == "xpu") | ||
|
||
# W/A for some OPs that could not handle different input dtypes | ||
CondFunc('torch.nn.functional.layer_norm', | ||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: | ||
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs), | ||
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs: | ||
weight is not None and input.dtype != weight.data.dtype) | ||
CondFunc('torch.nn.modules.GroupNorm.forward', | ||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), | ||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) | ||
CondFunc('torch.nn.modules.linear.Linear.forward', | ||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), | ||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) | ||
CondFunc('torch.nn.modules.conv.Conv2d.forward', | ||
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)), | ||
lambda orig_func, self, input: input.dtype != self.weight.data.dtype) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be good to include a torch version check, if users have other torch packages installed in the env then run pip install to install required ipex, torch, torchvision packages
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or we could
check_run_python("import torch; import intel_extension_for_pytorch; assert torch.xpu.is_available()")
to perform a sanity test, so that we don't assume a specific torch version -- Intel may release newer versions and user could build from source with a custom version.