
[BUG][IPEX/XPU] init_ipex_linear taking very long time, >10 minutes, with a small 1B model on XPU #977

Open
notsyncing opened this issue Dec 28, 2024 · 9 comments
Labels
bug Something isn't working

Comments

@notsyncing

Describe the bug

Hello, I'm trying out gptqmodel on an Intel Arc A770 16G with the Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4 model, using the following script:

import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import gptqmodel.integration
from datetime import datetime

print(torch.__version__)
print(ipex.__version__)
[print(f'[{i}]: {torch.xpu.get_device_properties(i)}') for i in range(torch.xpu.device_count())]

gptqmodel.integration.patch_hf()

model_4bit = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4"
).to("xpu")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4")

generator = pipeline("text-generation", model=model_4bit, tokenizer=tokenizer)
print(f"{datetime.now()}: Generating...")
print(generator("def helloWorld() {"))
print(f"{datetime.now()}: Generate again...")
print(generator("Hello!"))
print(f"{datetime.now()}: End!")

It takes almost 10 minutes after "Generating..." to produce the first generation output, with 100% CPU usage (one core) and about 20% GPU usage. The second generation takes about 4 seconds. Is this expected, or is something wrong?

CPU is Intel Core i9-10940X.

Full output:

[W1228 17:56:59.304756311 OperatorEntry.cpp:155] Warning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::_cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()
    registered at /build/pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: XPU
  previous kernel: registered at /build/pytorch/build/aten/src/ATen/RegisterCPU.cpp:30476
       new kernel: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterXPU.cpp:2971 (function operator())
2024-12-28 17:57:01,463 - datasets - INFO - PyTorch version 2.5.1+cxx11.abi available.
2.5.1+cxx11.abi
2.5.10+xpu
[0]: _XpuDeviceProperties(name='Intel(R) Arc(TM) A770 Graphics', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.31294+21', total_memory=15473MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=32, max_work_group_size=1024, max_num_sub_groups=128, sub_group_sizes=[8 16 32], has_fp16=1, has_fp64=0, has_atomic64=1)
2024-12-28 17:57:03,206 - gptqmodel.integration.src.transformers.utils.quantization_config - INFO - You have activated exllama backend. Note that you can get better inference speed using exllamav2 kernel by setting `exllama_config`.
`low_cpu_mem_usage` was None, now default to True since model is quantized.
INFO - Auto pick kernel based on compatibility: <class 'gptqmodel.nn_modules.qlinear.ipex.IPEXQuantLinear'>
/var/home/pods/test/venv/lib/python3.12/site-packages/transformers/modeling_utils.py:5006: FutureWarning: `_is_quantized_training_enabled` is going to be deprecated in transformers 4.39.0. Please use `model.hf_quantizer.is_trainable` instead
  warnings.warn(
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
2024-12-28 17:57:04,299 - gptqmodel.integration.src.optimum.gptq.quantizer - WARNING - Found modules on cpu/disk. Using Exllama/Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllama=True`
2024-12-28 17:57:05.686379: Generating...
/var/home/pods/test/venv/lib/python3.12/site-packages/transformers/generation/utils.py:1375: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
[{'generated_text': 'def helloWorld() {\n    println "Hello World!"\n}\n\nhelloWorld()  # 输出:'}]
2024-12-28 18:06:40.181055: Generate again...
[{'generated_text': "Hello! I'm a language model created by Anthropic. How can I assist you today?"}]
2024-12-28 18:06:44.581584: End!

GPU Info

Output:

+-----------+--------------------------------------------------------------------------------------+
| Device ID | Device Information                                                                   |
+-----------+--------------------------------------------------------------------------------------+
| 0         | Device Name: Intel(R) Arc(TM) A770 Graphics                                          |
|           | Vendor Name: Intel(R) Corporation                                                    |
|           | SOC UUID: 00000000-0000-0067-0000-000856a08086                                       |
|           | PCI BDF Address: 0000:67:00.0                                                        |
|           | DRM Device: /dev/dri/card1                                                           |
|           | Function Type: physical                                                              |
+-----------+--------------------------------------------------------------------------------------+

Software Info

Fedora 41, running a distrobox container from the intel/oneapi-basekit:2025.0.1-0-devel-ubuntu24.04 image, Python 3.12.3

Output:

WARNING: Package(s) not found: triton
Name: gptqmodel
Version: 1.5.0+cpu
Summary: A LLM quantization package with user-friendly apis. Based on GPTQ algorithm.
Home-page: https://github.com/ModelCloud/GPTQModel
Author: ModelCloud
Author-email: qubitium@modelcloud.ai
License: 
Location: /var/home/pods/test/venv/lib/python3.12/site-packages
Requires: accelerate, datasets, device-smi, numpy, packaging, pillow, protobuf, safetensors, sentencepiece, threadpoolctl, torch, transformers
Required-by: 
---
Name: torch
Version: 2.5.1+cxx11.abi
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /var/home/pods/test/venv/lib/python3.12/site-packages
Requires: filelock, fsspec, jinja2, networkx, setuptools, sympy, typing-extensions
Required-by: accelerate, azarrot, bitsandbytes, effdet, gptqmodel, optimum, optimum-intel, sentence-transformers, timm, torchaudio, torchvision, unstructured-inference
---
Name: transformers
Version: 4.46.3
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /var/home/pods/test/venv/lib/python3.12/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: azarrot, gptqmodel, optimum, optimum-intel, sentence-transformers, unstructured-inference
---
Name: accelerate
Version: 1.2.1
Summary: Accelerate
Home-page: https://github.com/huggingface/accelerate
Author: The HuggingFace team
Author-email: zach.mueller@huggingface.co
License: Apache
Location: /var/home/pods/test/venv/lib/python3.12/site-packages
Requires: huggingface-hub, numpy, packaging, psutil, pyyaml, safetensors, torch
Required-by: azarrot, gptqmodel

If you are reporting an inference bug of a post-quantized model, please post the content of config.json and quantize_config.json.

To Reproduce

Run the script above.

Expected behavior

The first generation should take less than 1 minute.

Model/Datasets

Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4

@notsyncing notsyncing added the bug Something isn't working label Dec 28, 2024
@Qubitium
Collaborator

@notsyncing The first run is slower due to model loading from disk, but 10 minutes vs. 4 seconds is not normal.

Let's load directly using gptqmodel's internal code, without the hf integration.

Use the code below and re-test. Do not move the model to xpu, since gptqmodel will already load it on xpu:

from gptqmodel import GPTQModel
from transformers import AutoTokenizer, pipeline
from datetime import datetime

model_4bit = GPTQModel.load(
    "Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4",
    device="xpu"  # gptqmodel loads the weights directly onto xpu
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4")

generator = pipeline("text-generation", model=model_4bit, tokenizer=tokenizer)
print(f"{datetime.now()}: Generating...")
print(generator("def helloWorld() {"))
print(f"{datetime.now()}: Generate again...")
print(generator("Hello!"))
print(f"{datetime.now()}: End!")

@Qubitium Qubitium changed the title [BUG] Very long first generation time (~10 minutes) with a small (1.5B) model on XPU [XPU] Very long first generation time (~10 minutes) with a small (1.5B) model on XPU Dec 29, 2024
@Qubitium
Collaborator

Qubitium commented Dec 29, 2024

Also test not using hf pipeline. Not sure if pipeline is doing extra torch.compile steps.

from gptqmodel import GPTQModel
from transformers import AutoTokenizer

model = GPTQModel.load(
    "Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4",
    device="xpu"
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4")

result = model.generate(
  **tokenizer(
      "def helloWorld() {", return_tensors="pt"
  ).to("xpu")
)[0]

I don't trust any API that wraps too many layers deep. I have not looked at what HF's pipeline actually does under the hood.
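To inspect the generated text rather than the raw token ids, the result can be decoded with the same tokenizer (a minimal follow-up sketch):

# Decode the generated token ids back into text; skip_special_tokens drops eos/pad tokens.
print(tokenizer.decode(result, skip_special_tokens=True))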

@notsyncing
Author


I tested again with your code; it first complains:

FileNotFoundError: [Errno 2] No such file or directory: '/var/home/pods/.cache/huggingface/hub/models--Qwen--Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4/snapshots/d45c7545dc428f013534f8bfd0441b3afffc0006/quantize_config.json'
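I then created that file by hand from the quantization_config section of the model's config.json, roughly like this (the snapshot path below is illustrative; substitute the actual snapshot hash):

import json
from pathlib import Path

# Illustrative path into the Hugging Face cache; the snapshot hash will differ.
snapshot_dir = Path(
    "~/.cache/huggingface/hub/models--Qwen--Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4/"
    "snapshots/<snapshot-hash>"
).expanduser()

# Copy the quantization_config block of config.json into quantize_config.json.
config = json.loads((snapshot_dir / "config.json").read_text())
(snapshot_dir / "quantize_config.json").write_text(
    json.dumps(config["quantization_config"], indent=2)
)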

With quantize_config.json in place, it now outputs:

[W1229 11:56:31.753561680 OperatorEntry.cpp:155] Warning: Warning only once for all operators,  other operators may also be overridden.
  Overriding a previously registered kernel for the same operator and the same dispatch key
  operator: aten::_cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()
    registered at /build/pytorch/build/aten/src/ATen/RegisterSchema.cpp:6
  dispatch key: XPU
  previous kernel: registered at /build/pytorch/build/aten/src/ATen/RegisterCPU.cpp:30476
       new kernel: registered at /build/intel-pytorch-extension/build/Release/csrc/gpu/csrc/aten/generated/ATen/RegisterXPU.cpp:2971 (function operator())
2024-12-29 11:56:33,320 - datasets - INFO - PyTorch version 2.5.1+cxx11.abi available.
2.5.1+cxx11.abi
2.5.10+xpu
[0]: _XpuDeviceProperties(name='Intel(R) Arc(TM) A770 Graphics', platform_name='Intel(R) oneAPI Unified Runtime over Level-Zero', type='gpu', driver_version='1.6.31294+21', total_memory=15473MB, max_compute_units=512, gpu_eu_count=512, gpu_subslice_count=32, max_work_group_size=1024, max_num_sub_groups=128, sub_group_sizes=[8 16 32], has_fp16=1, has_fp64=0, has_atomic64=1)
WARNING - No cuda found, use IPEX backend
Fetching 10 files: 100%|███████████████████████████████████████████████████| 10/10 [00:00<00:00, 54050.31it/s]
INFO - Ignoring unknown parameter in the quantization configuration: batch_size.
INFO - Ignoring unknown parameter in the quantization configuration: block_name_to_quantize.
INFO - Ignoring unknown parameter in the quantization configuration: cache_block_outputs.
INFO - Ignoring unknown parameter in the quantization configuration: dataset.
INFO - Ignoring unknown parameter in the quantization configuration: exllama_config.
INFO - Ignoring unknown parameter in the quantization configuration: max_input_length.
INFO - Ignoring unknown parameter in the quantization configuration: model_seqlen.
INFO - Ignoring unknown parameter in the quantization configuration: module_name_preceding_first_block.
INFO - Ignoring unknown parameter in the quantization configuration: modules_in_block_to_quantize.
INFO - Ignoring unknown parameter in the quantization configuration: pad_token_id.
INFO - Ignoring unknown parameter in the quantization configuration: tokenizer.
INFO - Ignoring unknown parameter in the quantization configuration: use_cuda_fp16.
INFO - Ignoring unknown parameter in the quantization configuration: use_exllama.
INFO - `checkpoint_format` is missing from the quantization configuration and is automatically inferred to gptq
2024-12-29 11:56:38.322534: Generating...
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
/var/home/pods/test/venv/lib/python3.12/site-packages/transformers/generation/utils.py:1375: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
tensor([  750, 23811, 10134,   368,   314,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
       device='xpu:0')
2024-12-29 12:05:02.594639: Generating again...
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
tensor([9707,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0], device='xpu:0')
2024-12-29 12:05:07.417358: End!

Still the same generation time, with the same CPU and GPU usage.

Full code:

from datetime import datetime
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoTokenizer
from gptqmodel import GPTQModel
import gptqmodel.integration

print(torch.__version__)
print(ipex.__version__)
[print(f'[{i}]: {torch.xpu.get_device_properties(i)}') for i in range(torch.xpu.device_count())]

gptqmodel.integration.patch_hf()

model = GPTQModel.load(
    "Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4",
    device="xpu"
)

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-1.5B-Instruct-GPTQ-Int4")

print(f"{datetime.now()}: Generating...")

result = model.generate(
  **tokenizer(
      "def helloWorld() {", return_tensors="pt"
  ).to("xpu")
)[0]
print(result)

print(f"{datetime.now()}: Generating again...")

result = model.generate(
  **tokenizer(
      "Hello!", return_tensors="pt"
  ).to("xpu")
)[0]
print(result)

print(f"{datetime.now()}: End!")

By the way, I forgot to mention my disk: a 4 TB NVMe SSD on a PCIe 3.0 x4 slot, so disk loading cannot be the bottleneck.

If I interrupt the first generation with Ctrl+C at about the 5th minute, it breaks at:

^CTraceback (most recent call last):
  File "/var/home/pods/test/test.py", line 34, in <module>
    result = model.generate(
             ^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/gptqmodel/models/base.py", line 761, in generate
    return self.model.generate(**kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 2215, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/transformers/generation/utils.py", line 3206, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 1164, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 895, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 623, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/transformers/models/qwen2/modeling_qwen2.py", line 502, in forward
    key_states = self.k_proj(hidden_states)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/gptqmodel/nn_modules/qlinear/ipex.py", line 206, in forward
    self.init_ipex_linear(x)
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/gptqmodel/nn_modules/qlinear/ipex.py", line 149, in init_ipex_linear
    self.ipex_linear = IPEXWeightOnlyQuantizedLinear.from_weight(self.qweight, self.scales, self.qzeros,
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/intel_extension_for_pytorch/llm/quantization/woq_linear.py", line 64, in from_weight
    woq_linear_impl = woq_linear_impl_cls.from_weight(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/intel_extension_for_pytorch/nn/utils/_quantize_convert.py", line 417, in from_weight
    qweight, g_idx = shuffler(qweight, g_idx)
                     ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/intel_extension_for_pytorch/nn/utils/_quantize_convert.py", line 76, in forward
    g_idx4kernel = self.convert_idx(g_idx, k).to(qweight_int32.device)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/home/pods/test/venv/lib/python3.12/site-packages/intel_extension_for_pytorch/nn/utils/_quantize_convert.py", line 29, in convert_idx
    g_counter[g_idx[i]] += 1
    ~~~~~~~~~^^^^^^^^^^
KeyboardInterrupt
^C
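For what it's worth, a more systematic way to confirm the hotspot than interrupting with Ctrl+C would be to profile just the first generation with the standard library's cProfile (a sketch; I have not run this here):

import cProfile
import pstats

# Profile only the first generate call, which contains the one-time init_ipex_linear work.
with cProfile.Profile() as profiler:
    model.generate(**tokenizer("def helloWorld() {", return_tensors="pt").to("xpu"))

# Print the 20 entries with the largest cumulative time.
pstats.Stats(profiler).sort_stats("cumulative").print_stats(20)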

@Qubitium
Collaborator

@notsyncing We will check this on our B580 test device on Monday and report back on whether this is normal for ipex/xpu or specific to gptqmodel.

@Qubitium
Collaborator

@jiqing-feng Can you check this? We isolated the issue to init_ipex_linear(), which runs once, the first time each IPEX quant linear's forward is called. We are not sure why this operation takes so long on xpu for a 1B model. We also tested Llama 1B and saw the same slowness in init_ipex_linear as with Qwen. Is this normal?
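A quick way to see that one-time cost in isolation, reusing the model and tokenizer from the snippets above (a sketch; max_new_tokens=1 keeps the decode loop short so the first call's timing is dominated by the lazy init):

import time

def timed_generate(prompt):
    # Generate a single new token so almost all of the measured time
    # on the first call is the per-layer init, not decoding.
    start = time.perf_counter()
    model.generate(
        **tokenizer(prompt, return_tensors="pt").to("xpu"),
        max_new_tokens=1,
    )
    return time.perf_counter() - start

t1 = timed_generate("def helloWorld() {")
t2 = timed_generate("Hello!")
print(f"first call:  {t1:.1f}s  (pays init_ipex_linear)")
print(f"second call: {t2:.1f}s  (init already done)")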

@Qubitium Qubitium changed the title [XPU] Very long first generation time (~10 minutes) with a small (1.5B) model on XPU [BUG][IPEX/XPU] init_ipex_linear taking very long time, >10 minutes, with a small 1B model on XPU Dec 30, 2024
@jiqing-feng
Contributor

Does only the 1B model take so much time, or do all models, like 3B and 7B?

@notsyncing
Author

@jiqing-feng I tested the same script with the Qwen/Qwen2.5-Coder-7B-Instruct-GPTQ-Int4 model, and it takes even longer: the first generation takes 25 minutes, while the second takes 8 seconds. No change in CPU or GPU usage.

@Qubitium
Collaborator

Qubitium commented Jan 8, 2025

@notsyncing With help from @jiqing-feng, we have tracked down the issue to the following:

  • All kernels (GPU feature code) must be precompiled for each combination of gpu * feature * shape: each kernel must be compiled multiple times, per GPU arch, per kernel, and per shape the kernel takes.
  • The B580 is fairly new, and IPEX did not precompile its kernels into the IPEX library.
  • Because there are no precompiled kernels, the first time you run a kernel IPEX has to compile it at runtime, which is extremely slow. This also explains why the second run is so fast.

In light of this, I will open up an issue with the IPEX packaging team so they can compile kernels for the B580 arch in their next release.
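Until precompiled kernels ship, one possible mitigation is a warmup generation at startup so the runtime compilation happens before any real request (a sketch under that assumption, not a gptqmodel API):

import time

# A single short generation runs every quant linear layer once, triggering the
# one-time init_ipex_linear / kernel compilation up front.
start = time.perf_counter()
model.generate(
    **tokenizer("warmup", return_tensors="pt").to("xpu"),
    max_new_tokens=1,
)
print(f"warmup took {time.perf_counter() - start:.1f}s; subsequent calls should be fast")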

@Qubitium
Collaborator

Qubitium commented Jan 8, 2025

Tracking IPEX issue: intel/intel-extension-for-pytorch#767
