Initial IPEX support for Intel Arc GPU #14171

Merged: 4 commits, Dec 2, 2023
1 change: 1 addition & 0 deletions modules/cmd_args.py
@@ -70,6 +70,7 @@
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
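The new switch is a plain store_true flag; downstream modules read it as shared.cmd_opts.use_ipex. A minimal standalone sketch of the argparse behavior (illustrative, not repo code):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use-ipex", action="store_true", help="use Intel XPU as torch device")

opts = parser.parse_args(["--use-ipex"])
assert opts.use_ipex is True  # argparse turns dashes into underscores on the namespace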
13 changes: 13 additions & 0 deletions modules/devices.py
@@ -8,6 +8,13 @@
if sys.platform == "darwin":
from modules import mac_specific

if shared.cmd_opts.use_ipex:
from modules import xpu_specific


def has_xpu() -> bool:
return shared.cmd_opts.use_ipex and xpu_specific.has_xpu


def has_mps() -> bool:
if sys.platform != "darwin":
@@ -30,6 +37,9 @@ def get_optimal_device_name():
if has_mps():
return "mps"

if has_xpu():
return xpu_specific.get_xpu_device_string()

return "cpu"


@@ -54,6 +64,9 @@ def torch_gc():
if has_mps():
mac_specific.torch_mps_gc()

if has_xpu():
xpu_specific.torch_xpu_gc()


def enable_tf32():
if torch.cuda.is_available():
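With these hooks in place, callers keep using the same entry points; XPU simply slots in after CUDA and MPS in the fallback order. A hedged usage sketch (assumes --use-ipex was passed and an Arc GPU is visible):

import torch
from modules import devices

# Resolves to "cuda", "mps", "xpu"/"xpu:<id>", or "cpu", in that priority order.
device = torch.device(devices.get_optimal_device_name())
x = torch.randn(4, 4, device=device)  # lands on the XPU when has_xpu() is true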
22 changes: 22 additions & 0 deletions modules/launch_utils.py
@@ -310,6 +310,26 @@ def requirements_met(requirements_file):
def prepare_environment():
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}")
if args.use_ipex:
if platform.system() == "Windows":
# The "Nuullll/intel-extension-for-pytorch" wheels were built from IPEX source for Intel Arc GPU: https://github.com/intel/intel-extension-for-pytorch/tree/xpu-main
# This is NOT an Intel official release so please use it at your own risk!!
# See https://github.com/Nuullll/intel-extension-for-pytorch/releases/tag/v2.0.110%2Bxpu-master%2Bdll-bundle for details.
#
# Strengths (over official IPEX 2.0.110 windows release):
# - AOT build (for Arc GPU only) to eliminate JIT compilation overhead: https://github.com/intel/intel-extension-for-pytorch/issues/399
# - Bundles minimal oneAPI 2023.2 dependencies into the python wheels, so users don't need to install oneAPI for the whole system.
# - Provides a compatible torchvision wheel: https://github.com/intel/intel-extension-for-pytorch/issues/465
# Limitation:
# - Only works for python 3.10
url_prefix = "https://github.com/Nuullll/intel-extension-for-pytorch/releases/download/v2.0.110%2Bxpu-master%2Bdll-bundle"
torch_command = os.environ.get('TORCH_COMMAND', f"pip install {url_prefix}/torch-2.0.0a0+gite9ebda2-cp310-cp310-win_amd64.whl {url_prefix}/torchvision-0.15.2a0+fa99a53-cp310-cp310-win_amd64.whl {url_prefix}/intel_extension_for_pytorch-2.0.110+gitc6ea20b-cp310-cp310-win_amd64.whl")
else:
# Using official IPEX release for linux since it's already an AOT build.
# However, users still have to install oneAPI toolkit and activate oneAPI environment manually.
# See https://intel.github.io/intel-extension-for-pytorch/index.html#installation for details.
torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/")
torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 --extra-index-url {torch_index_url}")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")

xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20')
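Both branches consult TORCH_COMMAND and TORCH_INDEX_URL before falling back to the defaults above, so a custom IPEX build can be forced without editing the code. A hypothetical override, set before launch (the values merely mirror the Linux defaults; swap in your own):

import os

# Hypothetical: pin a custom wheel set before prepare_environment() runs.
os.environ["TORCH_INDEX_URL"] = "https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
os.environ["TORCH_COMMAND"] = (
    "pip install torch==2.0.0a0 intel-extension-for-pytorch==2.0.110+gitba7f6c1 "
    "--extra-index-url " + os.environ["TORCH_INDEX_URL"]
)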
@@ -352,6 +372,8 @@ def prepare_environment():
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
startup_timer.record("install torch")

if args.use_ipex:
args.skip_torch_cuda_test = True
qiacheng commented on Dec 4, 2023:

It would be good to include a torch version check: if users have other torch packages installed in the env, run pip install to install the required ipex, torch, and torchvision packages.

A follow-up comment suggested:
if args.use_ipex:
    if is_installed("torch"):
        import torch
        if torch.__version__ != "2.0.0a0+gite9ebda2" or not is_installed("intel_extension_for_pytorch"):
            run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
            startup_timer.record("install torch")

The PR author replied:
Or we could run check_run_python("import torch; import intel_extension_for_pytorch; assert torch.xpu.is_available()") as a sanity test, so that we don't assume a specific torch version -- Intel may release newer versions, and users could build from source with a custom version.
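A minimal sketch of that version-agnostic approach, reusing the check_run_python helper that the existing CUDA test below relies on (illustrative, not the merged code):

if args.use_ipex:
    args.skip_torch_cuda_test = True
    # Reinstall only when torch/IPEX are absent or the XPU backend is unusable,
    # without assuming any particular torch version.
    if not check_run_python("import torch; import intel_extension_for_pytorch; assert torch.xpu.is_available()"):
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
        startup_timer.record("install torch")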

if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
raise RuntimeError(
'Torch is not able to use GPU; '
4 changes: 2 additions & 2 deletions modules/sd_samplers_timesteps_impl.py
@@ -11,7 +11,7 @@
def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))

@@ -43,7 +43,7 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
alphas = alphas_cumprod[timesteps]
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32)
alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' and x.device.type != 'xpu' else torch.float32)
sqrt_one_minus_alphas = torch.sqrt(1 - alphas)

extra_args = {} if extra_args is None else extra_args
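Both hunks encode the same rule: float64 is unsupported on the MPS and XPU backends, so the computation drops to float32 there. A small hypothetical helper could state the rule once (the name is illustrative, not from the PR):

import torch

def alphas_prev_dtype(device: torch.device) -> torch.dtype:
    # Apple MPS and Intel XPU lack float64 support; everything else keeps full precision.
    return torch.float32 if device.type in ("mps", "xpu") else torch.float64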
50 changes: 50 additions & 0 deletions modules/xpu_specific.py
@@ -0,0 +1,50 @@
from modules import shared
from modules.sd_hijack_utils import CondFunc

has_ipex = False
try:
import torch
import intel_extension_for_pytorch as ipex # noqa: F401
has_ipex = True
except Exception:
pass


def check_for_xpu():
return has_ipex and hasattr(torch, 'xpu') and torch.xpu.is_available()


def get_xpu_device_string():
if shared.cmd_opts.device_id is not None:
return f"xpu:{shared.cmd_opts.device_id}"
return "xpu"


def torch_xpu_gc():
with torch.xpu.device(get_xpu_device_string()):
torch.xpu.empty_cache()


has_xpu = check_for_xpu()

if has_xpu:
# W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
CondFunc('torch.Generator',
lambda orig_func, device=None: torch.xpu.Generator(device),
lambda orig_func, device=None: device is not None and device.type == "xpu")

# W/A for some OPs that could not handle different input dtypes
CondFunc('torch.nn.functional.layer_norm',
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
orig_func(input.to(weight.data.dtype), normalized_shape, weight, *args, **kwargs),
lambda orig_func, input, normalized_shape=None, weight=None, *args, **kwargs:
weight is not None and input.dtype != weight.data.dtype)
CondFunc('torch.nn.modules.GroupNorm.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.linear.Linear.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
CondFunc('torch.nn.modules.conv.Conv2d.forward',
lambda orig_func, self, input: orig_func(self, input.to(self.weight.data.dtype)),
lambda orig_func, self, input: input.dtype != self.weight.data.dtype)
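
For readers unfamiliar with CondFunc, here is a simplified, self-contained sketch of the conditional monkey-patching pattern it implements (an illustration only, not the actual sd_hijack_utils implementation; the real helper also resolves class attributes such as Linear.forward, which plain importlib alone cannot):

import importlib

def cond_patch(path: str, sub_func, cond_func):
    # Swap the module-level callable at dotted `path` for a wrapper that routes
    # the call to sub_func(orig, ...) whenever cond_func(orig, ...) is truthy,
    # and to the original callable otherwise.
    module_path, attr = path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    orig = getattr(module, attr)

    def wrapper(*args, **kwargs):
        if cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)
        return orig(*args, **kwargs)

    setattr(module, attr, wrapper)

# e.g. cond_patch('torch.nn.functional.layer_norm', sub, cond) mirrors the
# layer_norm hijack above: `cond` inspects the dtypes, `sub` casts and delegates.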