ValueError: Pointer argument (at 0) doesn't reference XPU device memory (cpu tensor?) #2188

Closed
dvrogozh opened this issue Sep 10, 2024 · 1 comment · Fixed by #2192

dvrogozh (Contributor) commented Sep 10, 2024

@vlad-penkin : here is a reproducer associated with PyTorch PR pytorch/pytorch#135567, which changes how PyTorch returns a tensor's data_ptr. The following test in Hugging Face accelerate fails after that PR:

git clone https://github.com/huggingface/accelerate.git && cd accelerate
python3 -m pytest --pspec tests/test_utils.py::UtilsTester::test_dynamo

See details in pytorch/pytorch#135567 (review).
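
What test_dynamo exercises boils down to compiling a tiny model with torch.compile and running it on XPU. A minimal sketch of that workload (assumed shape only, not the exact accelerate test; the class name is made up):

import torch

class TinyModel(torch.nn.Module):  # hypothetical stand-in for the test's model
    def __init__(self):
        super().__init__()
        # Scalar parameters; in the generated code below they arrive as
        # 0-d CPU tensors (primals_1, primals_3).
        self.a = torch.nn.Parameter(torch.tensor(2.0))
        self.b = torch.nn.Parameter(torch.tensor(3.0))

    def forward(self, x):
        return x * self.a + self.b

compiled = torch.compile(TinyModel())
x = torch.randn(4, 10, device="xpu")  # the (4, 10) shape matches primals_2 below
out = compiled(x)  # fails with the ValueError above once the pointer is mangled

The TorchInductor wrapper generated for this graph is reproduced below: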

# AOT ID: ['0_forward']
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._inductor.codegen.multi_kernel import MultiKernelCall
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels, start_graph, end_graph
from torch._C import _xpu_getCurrentRawStream as get_raw_stream

aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()


# kernel path: /tmp/torchinductor_dvrogozh/hg/chgfabyslvbgrlhijrr6hjw4sjcxtg6f2q7tfxxuhivdvcf2rwij.py
# Topologically Sorted Source Nodes: [mul, add], Original ATen: [aten.mul, aten.add]
# Source node to ATen node mapping:
#   add => add
#   mul => mul
# Graph fragment:
#   %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%primals_2, %primals_1), kwargs = {})
#   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %primals_3), kwargs = {})
triton_poi_fused_add_mul_0 = async_compile.triton('triton_', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties

@triton_heuristics.pointwise(
    size_hints=[64], 
    filename=__file__,
    triton_meta={'signature': {0: '*fp32', 1: 'fp32', 2: 'fp32', 3: '*fp32', 4: 'i32'}, 'device': DeviceProperties(type='xpu', index=0, cc={'driver_version': '1.3.30049', 'gpu_eu_count': 448, 'gpu_subslice_count': 56, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 448, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1100', 'platform_name': 'Intel(R) Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 51539607552, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '1.3'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None), 'constants': {}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 3), equal_to_1=())]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_mul_0', 'mutated_arg_names': [], 'no_x_dim': False, 'num_load': 3, 'num_reduction': 0, 'backend_hash': '2D2AD0B55A7DC439E8979A8831462B3B9FC5AD26CB3B7F3AFDEFF5FB7681E7F1', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
    min_elem_per_thread=0
)
@triton.jit
def triton_(in_ptr0, in_ptr1, in_ptr2, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 40
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = in_ptr1
    tmp3 = in_ptr2
    tmp2 = tmp0 * tmp1
    tmp4 = tmp2 + tmp3
    tl.store(out_ptr0 + (x0), tmp4, xmask)
''', device_str='xpu')


async_compile.wait(globals())
del async_compile

def call(args):
    primals_1, primals_2, primals_3 = args
    args.clear()
    assert_size_stride(primals_1, (), ())
    assert_size_stride(primals_2, (4, 10), (10, 1))
    assert_size_stride(primals_3, (), ())
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        buf0 = empty_strided_xpu((4, 10), (10, 1), torch.float32)
        # Topologically Sorted Source Nodes: [mul, add], Original ATen: [aten.mul, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_mul_0.run(primals_2, primals_1.item(), primals_3.item(), buf0, 40, grid=grid(40), stream=stream0)
        del primals_1
        del primals_3
    return (buf0, primals_2, )


def benchmark_compiled_module(times=10, repeat=10):
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((), (), device='cpu', dtype=torch.float32)
    primals_2 = rand_strided((4, 10), (10, 1), device='xpu:0', dtype=torch.float32)
    primals_3 = rand_strided((), (), device='cpu', dtype=torch.float32)
    fn = lambda: call([primals_1, primals_2, primals_3])
    return print_performance(fn, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    compiled_module_main('None', benchmark_compiled_module)

vlad-penkin self-assigned this Sep 10, 2024
dvrogozh added a commit to dvrogozh/intel-xpu-backend-for-triton that referenced this issue Sep 11, 2024
Fixes: intel#2188

Calling PyLong_AsLongLong causes an overflow and returns a wrong pointer.

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

dvrogozh (Contributor, Author) commented:
@vlad-penkin : I've debugged this on my side and posted PR #2192. The issue is that Triton calls PyLong_AsLongLong() to convert device memory pointers instead of PyLong_AsVoidPtr(). The fix works for my failing test both with and without @guangyey's change pytorch/pytorch#135567 on the PyTorch side. I hope this won't regress any tests on your side. Let's see...
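
For the record, the overflow is easy to demonstrate from plain Python through ctypes (an illustrative sketch only; the pointer value is made up and the real conversion happens in the driver's C code):

import ctypes

api = ctypes.pythonapi
api.PyLong_AsLongLong.restype = ctypes.c_longlong
api.PyLong_AsLongLong.argtypes = [ctypes.py_object]
api.PyLong_AsVoidPtr.restype = ctypes.c_void_p
api.PyLong_AsVoidPtr.argtypes = [ctypes.py_object]

# Hypothetical device pointer in the upper half of the 64-bit address
# space, i.e. above LLONG_MAX, where XPU allocations can land.
ptr = 0xFF00_0000_0000_1000

try:
    api.PyLong_AsLongLong(ptr)           # the old path: signed conversion overflows
except OverflowError as err:
    print("PyLong_AsLongLong:", err)

print(hex(api.PyLong_AsVoidPtr(ptr)))    # the PR #2192 path: full value round-trips

PyLong_AsLongLong rejects anything above LLONG_MAX, so a pointer in the upper half of the address space never reaches the kernel launch intact, and the runtime then reports it as not referencing XPU device memory.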

victor-eds pushed a commit that referenced this issue Sep 12, 2024
Fixes: #2188

Calling PyLong_AsLongLong causes an overflow and returns a wrong pointer.

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>