
FunctionalTensor is not supported by the IPEX custom kernel when using torch.compile after ipex.llm.optimize #760

Open
lostkingdom4 opened this issue Dec 30, 2024 · 14 comments

@lostkingdom4

Describe the issue

I was attempting to use torch.compile after applying ipex.llm.optimize to a language model on a Max 1100 GPU. My goal is to improve torch.compile by recognizing the FX graph pattern and calling the custom kernel, such as torch_ipex.xetla_sdp_dropout, directly. However, when I tested torch.compile on NewIPEXBertSelfAttention, as shown in the following code,

import torch
import torch.nn as nn
torch.set_default_dtype(torch.float16)
from transformers.models.bert.modeling_bert import BertConfig, BertSelfAttention
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.transformers.models.xpu.optimize_transformers.modules.bert import NewIPEXBertSelfAttention

config = BertConfig()

attention_layer = BertSelfAttention(config)

new_attention = NewIPEXBertSelfAttention(attention_layer, config).to('xpu')

batch_size = 2
seq_length = 10
hidden_size = config.hidden_size

hidden_states = torch.rand(batch_size, seq_length, hidden_size).to('xpu')  # Random input tensor
attention_mask = torch.ones(batch_size, 1, 1, seq_length).to('xpu')  # No masking for simplicity

# Forward pass
outputs = new_attention(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
)

# Print the outputs
print("Output shape:", outputs[0].shape)  # Should be (batch_size, seq_length, hidden_size)
print(outputs)
if len(outputs) > 1:
    print("Past key value shape:", [pkv.shape for pkv in outputs[1]])


# Compile the new_attention module using torch.compile
compiled_attention = torch.compile(new_attention)

# Forward pass with compiled module
compiled_outputs = compiled_attention(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
)

# Print the outputs from the compiled module
print("Compiled output shape:", compiled_outputs[0].shape)  # Should be (batch_size, seq_length, hidden_size)
print(compiled_outputs)
if len(compiled_outputs) > 1:
    print("Compiled past key value shape:", [pkv.shape for pkv in compiled_outputs[1]])

I got the following error:

Traceback (most recent call last):
  File "/workspace/newipexatten.py", line 57, in <module>
    compiled_outputs = compiled_attention(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1064, in __call__
    result = self._inner_convert(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1692, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1602, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/variables/torch.py", line 897, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2037, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/variables/builder.py", line 2124, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2082, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2017, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 1574, in wrap_fake_exception
    return fn()
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2018, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2150, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 2132, in run_node
    return node.target(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
    return self._op(*args, **(kwargs or {}))
torch._dynamo.exc.TorchRuntimeError: Failed running call_function torch_ipex.xetla_sdp_dropout(*(FakeTensor(..., device='xpu:0', size=(2, 12, 10, 64),
           grad_fn=<PermuteBackward0>), FakeTensor(..., device='xpu:0', size=(2, 12, 10, 64),
           grad_fn=<PermuteBackward0>), FakeTensor(..., device='xpu:0', size=(2, 12, 10, 64),
           grad_fn=<PermuteBackward0>), FakeTensor(..., device='xpu:0', size=(2, 1, 1, 10)), 0.1, False, None), **{}):
Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

from user code:
   File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/intel_extension_for_pytorch/transformers/models/xpu/optimize_transformers/modules/bert.py", line 103, in forward
    context_layer = torch.xpu.IpexSDP_dropout(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/intel_extension_for_pytorch/xpu/intrinsic/__init__.py", line 163, in IpexSDP_dropout
    return torch.ops.torch_ipex.xetla_sdp_dropout(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I followed "Adding torch.compile support for an operator" to add a FakeTensor kernel for torch.ops.torch_ipex.xetla_sdp_dropout and used torch.library.opcheck to get the detailed failure reason, as shown below.

import torch
from torch import Tensor   
import intel_extension_for_pytorch as ipex


@torch.library.register_fake("torch_ipex::xetla_sdp_dropout")
def _(query, key, value, attn_mask, dropout_p, is_causal, scale):
    # Python-side input checks
    torch._check(query.shape == key.shape == value.shape, lambda: "query, key and value must have the same shape")
    torch._check(query.dtype == key.dtype == value.dtype == torch.float16, lambda: "query, key and value must be float16")
    torch._check(query.device == key.device == value.device == attn_mask.device, lambda: "all inputs must be on the same device")

    # The fake kernel only needs to describe the output metadata
    return torch.empty_like(query)

sample_inputs = [
    (torch.randn(2, 12, 10, 64, dtype=torch.float16, device="xpu"),
     torch.randn(2, 12, 10, 64, dtype=torch.float16, device="xpu"),
     torch.randn(2, 12, 10, 64, dtype=torch.float16, device="xpu"),
     torch.ones(2, 1, 1, 10, dtype=torch.float16, device="xpu"),
     0.0, False, None)
]


for args in sample_inputs:
    torch.library.opcheck(torch.ops.torch_ipex.xetla_sdp_dropout, args, test_utils='test_aot_dispatch_static')

The code will fail on test_aot_dispatch_static and test_aot_dispatch_dynamic.
I had a close look at the source code and found that the problem is in OpOverload: it returns None, which causes the failure. It seems the issue is that FunctionalTensor is not supported even though I have already declared the FakeTensor kernel.
The error is as follows:

Traceback (most recent call last):
  File "/root/miniforge3/envs/py310/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/root/miniforge3/envs/py310/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/__main__.py", line 71, in <module>
    cli.main()
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 501, in main
    run()
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/adapter/../../debugpy/launcher/../../debugpy/../debugpy/server/cli.py", line 351, in run_file
    runpy.run_path(target, run_name="__main__")
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 310, in run_path
    return _run_module_code(code, init_globals, run_name, pkg_name=pkg_name, script_name=fname)
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 127, in _run_module_code
    _run_code(code, mod_globals, init_globals, mod_name, mod_spec, pkg_name, script_name)
  File "/root/.vscode-server/extensions/ms-python.debugpy-2024.14.0-linux-x64/bundled/libs/debugpy/_vendored/pydevd/_pydevd_bundle/pydevd_runpy.py", line 118, in _run_code
    exec(code, run_globals)
  File "/workspace/opcheck.py", line 48, in <module>
    torch.library.opcheck(torch.ops.torch_ipex.xetla_sdp_dropout, args, test_utils='test_aot_dispatch_static')
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/library.py", line 1322, in opcheck
    return optests.opcheck(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_internal/optests/generate_tests.py", line 657, in opcheck
    tester(op, args, kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_internal/optests/generate_tests.py", line 114, in safe_aot_autograd_check
    return aot_autograd_check(func, args, kwargs, dynamic, check_gradients="auto")
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_internal/optests/aot_autograd.py", line 75, in aot_autograd_check
    compiled_out = wrapper_set_seed(compiled_f, args)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_utils.py", line 18, in wrapper_set_seed
    output = op(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 868, in returned_function
    compiled_fn, _ = create_aot_dispatcher_function(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 623, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 173, in inner
    flat_f_outs = f(*flat_f_args)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 182, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_internal/optests/aot_autograd.py", line 64, in func_no_tensors
    return func(*c_args, **c_kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/testing/_internal/optests/generate_tests.py", line 110, in func
    return op(*args, **kwargs)
  File "/root/miniforge3/envs/py310/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
    return self._op(*args, **kwargs)
RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

Any tips on how to solve this problem would be really helpful!

My system configuration is as follows:

PyTorch version: 2.5.1+cxx11.abi
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-118-generic-x86_64-with-glibc2.35
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        52 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               35
On-line CPU(s) list:                  0-34
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Platinum 8480+
CPU family:                           6
Model:                                143
Thread(s) per core:                   1
Core(s) per socket:                   35
Socket(s):                            1
Stepping:                             8
BogoMIPS:                             4000.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            1.1 MiB (35 instances)
L1i cache:                            1.1 MiB (35 instances)
L2 cache:                             140 MiB (35 instances)
L3 cache:                             16 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-34
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] intel_extension_for_pytorch==2.5.10+xpu
[pip3] numpy==1.26.4
[pip3] pytorch-triton-xpu==3.1.0+91b14bf559
[pip3] torch==2.5.1+cxx11.abi
[pip3] triton==3.2.0+git6ee08cd2
[conda] intel-extension-for-pytorch 2.5.10+xpu               pypi_0    pypi
[conda] mkl                       2025.0.1                 pypi_0    pypi
[conda] mkl-dpcpp                 2025.0.1                 pypi_0    pypi
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] onemkl-sycl-blas          2025.0.1                 pypi_0    pypi
[conda] onemkl-sycl-datafitting   2025.0.1                 pypi_0    pypi
[conda] onemkl-sycl-dft           2025.0.1                 pypi_0    pypi
[conda] onemkl-sycl-lapack        2025.0.1                 pypi_0    pypi
[conda] onemkl-sycl-rng           2025.0.1                 pypi_0    pypi
[conda] onemkl-sycl-sparse        2025.0.1                 pypi_0    pypi
[conda] onemkl-sycl-stats         2025.0.1                 pypi_0    pypi
[conda] onemkl-sycl-vm            2025.0.1                 pypi_0    pypi
[conda] pytorch-triton-xpu        3.1.0+91b14bf559          pypi_0    pypi
[conda] torch                     2.5.1+cxx11.abi          pypi_0    pypi
[conda] triton                    3.2.0+git6ee08cd2          pypi_0    pypi
IPEX version: 2.5.10+xpu
IPEX commit: Unknown
@ZhaoqiongZ ZhaoqiongZ self-assigned this Dec 30, 2024
@EikanWang
Contributor

@ZhaoqiongZ , any update?

@ZhaoqiongZ
Contributor

Hi @EikanWang, I am passing the issue to Su, Tong, since it is very specific to the details of the torch.compile feature.

@Stonepia
Contributor

Stonepia commented Jan 9, 2025

Hi @lostkingdom4, thanks for the detailed reproducer!
I did a check; the problem is that the registered meta function's schema does not resolve to the exact custom op. Actually, if you change your registration as below:

@torch.library.register_fake("torch_ipex::xetla_sdp_dropout")
def xetla_sdp_dropout(query, key, value, attn_mask, dropout_p, is_causal, scale):
    print("run into fake")
    assert False

You will find that it still does not hit the assert False. This means the fake-tensor path does not actually reach this registration.

After further investigation, I found that when the custom op is registered, it carries the dispatch key c10::DispatchKey::AutogradXPU. Because of this, the schema cannot be resolved correctly.
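
For reference, one way to confirm which dispatch keys the op is registered under is to dump its registration state. This is a minimal diagnostic sketch; torch._C._dispatch_dump is an internal, undocumented helper, so treat its exact output format as an implementation detail rather than a stable API:

import torch
import intel_extension_for_pytorch as ipex  # loads the torch_ipex op registrations

# Dump the registration state of the op. The output lists which dispatch
# keys (e.g. AutogradXPU, XPU, Meta) currently have kernels registered,
# which shows whether the fake/meta registration is actually visible.
print(torch._C._dispatch_dump("torch_ipex::xetla_sdp_dropout"))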

A quick, temporary workaround is to change c10::DispatchKey::AutogradXPU to c10::DispatchKey::XPU and rebuild IPEX. Your fake-tensor registration will then take effect.

This is not a proper fix, but it normally does not affect performance or accuracy much. We will try to fix it properly later. Thanks again for your patience!

@lostkingdom4
Author

Hi @Stonepia, thanks for the feedback. It works for me!

@lostkingdom4
Author

Hi @Stonepia, I was trying to do the same thing for torch.ops.torch_ipex.mm_qkv_out(input, self.weight, self.bias, q, k, v), with a fake registration like the following as an example:

@torch.library.register_fake("torch_ipex::mm_qkv_out.xpu")
def _(query, key, value, attn_mask, dropout_p, is_causal, scale):
    print("run into fake")
    assert False

I got

RuntimeError: register_fake(...): the operator torch_ipex::mm_qkv_out.xpu already has an implementation for this device type via a pre-existing registration to DispatchKey::CompositeImplicitAutograd. CompositeImplicitAutograd operators do not need an fake impl; instead, the operator will decompose into its constituents and those can have fake impls defined on them.

I have already modified the dispatch key, but I don't think that is the problem here. Is there a quick fix for this type of operator? I went through the C++ source code, but I'm still not entirely sure where it gets registered as CompositeImplicitAutograd.

@lostkingdom4
Author

I tried to work around the problem by changing the registration in csrc from

  IPEX_OP_REGISTER("mm_qkv_out.xpu", at::AtenIpexTypeXPU::mm_qkv_out);
  IPEX_OP_REGISTER_DISPATCH(
      "mm_qkv_out.xpu",
      at::AtenIpexTypeXPU::mm_qkv_out_autocast,
      c10::DispatchKey::AutocastXPU);

to

// IPEX_OP_REGISTER("mm_qkv_out", at::AtenIpexTypeXPU::mm_qkv_out);
  IPEX_OP_REGISTER_DISPATCH(
      "mm_qkv_out",
      at::AtenIpexTypeXPU::mm_qkv_out,
      c10::DispatchKey::XPU);

This seems to resolve the fake-tensor problem. However, when running with torch.compile, Dynamo now causes a graph break on this operator. Could you double-check whether this approach works for operators that have a pre-existing registration to DispatchKey::CompositeImplicitAutograd?
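
For reference, a minimal sketch of how to surface the break reasons with torch._dynamo.explain is below. ToyModule is a made-up stand-in for the module that wraps mm_qkv_out, and the explicit graph_break() call only exists so the report is non-empty; running with TORCH_LOGS="graph_breaks" gives similar information.

import torch
import torch._dynamo

class ToyModule(torch.nn.Module):
    # Hypothetical placeholder for the module that calls mm_qkv_out.
    def forward(self, x):
        y = torch.relu(x @ x.t())
        torch._dynamo.graph_break()  # force a break so the report has content
        return y + 1

model = ToyModule()
x = torch.randn(4, 4)

# explain() runs Dynamo once and reports how many graphs were produced
# and why each break happened.
report = torch._dynamo.explain(model)(x)
print(report.graph_count, report.graph_break_count)
for reason in report.break_reasons:
    print(reason)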

Meanwhile, what is the difference between IPEX_OP_REGISTER and IPEX_OP_REGISTER_DISPATCH? And why is .xpu appended to mm_qkv_out? I understand it is an overload for XPU, but is that necessary?

Thanks

@ZhaoqiongZ ZhaoqiongZ assigned Stonepia and unassigned ZhaoqiongZ Feb 6, 2025
@lostkingdom4
Author

@Stonepia
Is there any update on this?

Thanks

@Stonepia
Contributor

Stonepia commented Mar 7, 2025

Hi @lostkingdom4,
I suspect this is a bug on the PyTorch side, not an issue with your implementation. We haven't hit it before because no one has exercised this path (custom ops registered under another dispatch key). The next step is to write a reproducer and file it as a PyTorch issue.

Apologies that I haven't managed to find the bandwidth for this yet. I will update you once I have new findings.

@lostkingdom4
Author

@Stonepia
Thanks for replying, and thanks for your effort on this matter.
Essentially, I want torch.compile to understand the operator through a simple fake-tensor registration, so that when the operator is significantly faster than what Triton generates, we can use it instead.

I also tried to understand the problem by going through the source code of IPEX_OP_REGISTER, IPEX_OP_REGISTER_DISPATCH, and TORCH_LIBRARY_IMPL. I think the problem is that once the operator is registered via IPEX_OP_REGISTER, its dispatch key ends up as CompositeImplicitAutograd. I have tried using IPEX_OP_REGISTER_DISPATCH to register it under a specific dispatch key and rebuilding, but the build is always unsuccessful.

@lostkingdom4
Author

@Stonepia
Hi, I accidentally marked this thread as closed. Can you reopen it?

@Stonepia Stonepia reopened this Mar 7, 2025
@Stonepia
Contributor

Stonepia commented Mar 7, 2025

Hi @lostkingdom4,
Yes, that's why I said the problem is not in your implementation; it should be a PyTorch-side issue (not even an IPEX one). I meant to write a pure-PyTorch reproducer for it, but haven't had the chance yet 😿
If you are interested, there are two manuals that would help:

  1. torch.compile : https://docs.google.com/document/d/1y5CRfMLdwEoF1nTk9q8qEu1mgMUuUtvhklPKJ2emLU8/edit?tab=t.0
  2. custom operators: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?tab=t.0

@lostkingdom4
Author

@Stonepia
Thanks for replying, and thanks for your effort on this matter.

Since I basically have to rebuild all of PyTorch and IPEX every time I try a new registration, it would be really helpful if you could give me some tips on building just the operator without rebuilding the entire system, so I can iterate on this problem more conveniently. I might have a way to solve it, but building IPEX from scratch is killing me.

@Stonepia
Contributor

Stonepia commented Mar 7, 2025

I don't think you need to rebuild every time. You could try first with a Python op registration, which should behave the same as the C++ side, and you don't need to build PyTorch either.

I suggest starting from a simpler custom op with a Python registration, to see if everything goes well, and then moving to the harder one (the fused IPEX operator).
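
To make that concrete, here is a minimal, self-contained sketch of the pure-Python workflow using torch.library.custom_op; mylib::scaled_add is a made-up toy op, and once this compiles cleanly the same register_fake pattern can be pointed at the real torch_ipex op:

import torch

# Define a toy custom op entirely in Python. mutates_args=() declares that
# the op does not mutate its inputs.
@torch.library.custom_op("mylib::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    return x + alpha * y

# The fake implementation only describes output metadata; this is all
# torch.compile needs to trace through the op without running the kernel.
@scaled_add.register_fake
def _(x, y, alpha):
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def f(x, y):
    return scaled_add(x, y, 2.0)

x = torch.randn(4, 4)
y = torch.randn(4, 4)
print(f(x, y))  # compiles without graph breaks; the op appears as a single node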

@lostkingdom4
Author

I see what you mean. I've already tried the custom operator registration, and it works. The only thing I'm not so sure about is these headers, for example:

#include "XeGemm.h"
#include <ATen/ATen.h>
#include <ATen/CPUApplyUtils.h>
#include <ATen/record_function.h>
#include <runtime/Utils.h>
#include <iostream>
#include "Blas.h"
#include "Linear.h"
#include "comm/ATDispatch.h"
#include "utils/CustomOperatorRegistration.h"
#if defined(USE_XETLA) && defined(USE_XETLA_XE_HPC) // XeGemm only supports PVC
#include "xetla/hgemm.h"
#endif
#include <ATen/autocast_mode.h>

But I think I will try to build it within the repository so I don't need to worry about the include paths. I will get back to you once I have some findings.

Again, thanks for the information.
