Two-hop mutations are not supported. Found registrations from 5 to 5 to 5 #3369

Open
wujingyue opened this issue Nov 7, 2024 · 6 comments · May be fixed by #3379

@wujingyue
Collaborator

I noticed this from the latest CI run of @IvanYashchuk's Lightning-AI/lightning-thunder#1371. Apparently, it failed in some pre-segmenter pass.

Repro:

# CUDA devices:
#  0: NVIDIA GeForce RTX 3090
#  1: NVIDIA GeForce RTX 3090
# torch version: 2.6.0a0+gitd622b49
# cuda version: 12.1
# nvfuser version: 0.2.22+git6912435
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id28(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[5, 5], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[5, 5], contiguity=[True, True], dtype=DataType.Float, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[5], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T3 = fd.ops.linear(T0, T1, T2)
    fd.add_output(T3)

with FusionDefinition() as fd:
    nvfuser_fusion_id28(fd)

inputs = [
    torch.testing.make_tensor((5, 5), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((5, 5), dtype=torch.float32, device='cuda:0'),
    torch.testing.make_tensor((5,), dtype=torch.float32, device='cuda:0'),
]
fd.execute(inputs)
RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/mutator.cpp":45, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Two-hop mutations are not supported. Found registrations from 5 to 5 to 5
Exception raised from maybeMutated at /opt/pytorch/nvfuser/csrc/mutator.cpp:45 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xd6 (0x719d4dc07a76 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x62 (0x719d4dffac82 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #2: nvfuser::OptOutMutator::maybeMutated(nvfuser::Val*) const + 0x183 (0x719d4e256703 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #3: nvfuser::OptOutMutator::mutate(nvfuser::IterDomain*) + 0x4f (0x719d4e257ccf in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x70481c (0x719d4e1d481c in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x7fe60e (0x719d4e2ce60e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #6: nvfuser::preseg_passes::PreSegmenter::runPass(nvfuser::Fusion*) + 0x10d0 (0x719d4e2e0cd0 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x87f500 (0x719d4e34f500 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x873b2c (0x719d4e343b2c in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #9: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xa9 (0x719d4e3442c9 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #10: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x15a (0x719d4e4fa74a in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x1b61ee (0x719d4dc861ee in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x280c16 (0x719d4dd50c16 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x2de233 (0x719d4ddae233 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-312-x86_64-linux-gnu.so)
frame #14: python() [0x5821ef]
<omitting python frames>
frame #18: python() [0x608e12]
frame #19: python() [0x6b5253]
frame #24: <unknown function> + 0x2a1ca (0x719e688c61ca in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #25: __libc_start_main + 0x8b (0x719e688c628b in /usr/lib/x86_64-linux-gnu/libc.so.6)
@wujingyue
Collaborator Author

gdb gives a more useful stack trace. Apparently, this error came from ExactMappedExtentSubstitutionPass. cc @liqiangxl who appears to be the author (#1642).

(gdb) bt
#0  0x00007ffff77c435a in __cxa_throw () from /usr/lib/x86_64-linux-gnu/libstdc++.so.6
#1  0x00007ffedc5d249e in nvfuser::nvfCheckFail (func=0x7ffedd167f76 "maybeMutated", file=0x7ffedd167f50 "/opt/pytorch/nvfuser/csrc/mutator.cpp", line=45, msg=" INTERNAL ASSERT FAILED at \"/opt/pytorch/nvfuser/csrc/mutator.cpp\":45, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Two-hop mutations are not supported. "...) at /opt/pytorch/nvfuser/csrc/exceptions.cpp:275
#2  0x00007ffedc5d2728 in nvfuser::nvfErrorFail (func=0x7ffedd167f76 "maybeMutated", file=0x7ffedd167f50 "/opt/pytorch/nvfuser/csrc/mutator.cpp", line=45, condMsg=0x7ffedd167ea8 " INTERNAL ASSERT FAILED at \"/opt/pytorch/nvfuser/csrc/mutator.cpp\":45, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. ", userMsg="Two-hop mutations are not supported. Found registrations from 5 to 5 to 5") at /opt/pytorch/nvfuser/csrc/exceptions.cpp:301
#3  0x00007ffedca2c1d0 in nvfuser::OptOutMutator::maybeMutated (this=0x7fffffffbaf0, val=0x7fff38b8b180) at /opt/pytorch/nvfuser/csrc/mutator.cpp:45
#4  0x00007ffedca2c8d8 in nvfuser::OptOutMutator::mutate (this=0x7fffffffbaf0, id=0x7ffec6923580) at /opt/pytorch/nvfuser/csrc/mutator.cpp:101
#5  0x00007ffedc4d4e97 in nvfuser::Val::mutatorDispatch<nvfuser::OptOutMutator*> (mutator=0x7fffffffbaf0, val=0x7ffec6923580) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:206
#6  0x00007ffedca2bfc1 in nvfuser::OptOutMutator::dispatchMutate (this=0x7fffffffbaf0, v=0x7ffec6923580) at /opt/pytorch/nvfuser/csrc/mutator.cpp:33
#7  0x00007ffedc935b84 in nvfuser::ir_utils::(anonymous namespace)::ValReplacementMutator::dispatchMutate (this=0x7fffffffbaf0, val=0x7ffec6923580) at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:492
#8  0x00007ffedc4d4ae4 in nvfuser::Statement::mutatorDispatch<nvfuser::OptOutMutator*> (mutator=0x7fffffffbaf0, stmt=0x7ffec6923580) at /opt/pytorch/nvfuser/csrc/dispatch.cpp:230
#9  0x00007ffedca2bf97 in nvfuser::OptOutMutator::dispatchMutate (this=0x7fffffffbaf0, s=0x7ffec6923580) at /opt/pytorch/nvfuser/csrc/mutator.cpp:29
#10 0x00007ffedc9358f5 in nvfuser::ir_utils::(anonymous namespace)::ValReplacementMutator::ValReplacementMutator (this=0x7fffffffbaf0, fusion=0x7ffede572c00, replacement_map=std::unordered_map with 5 elements = {...}) at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:476
#11 0x00007ffedc936105 in nvfuser::ir_utils::replaceValue (fusion=0x7ffede572c00, replacement_map=std::unordered_map with 5 elements = {...}) at /opt/pytorch/nvfuser/csrc/ir/utils.cpp:533
#12 0x00007ffedcaddddc in nvfuser::preseg_passes::(anonymous namespace)::exactMappedExtentSubstitution (fusion=0x7ffede572c00) at /opt/pytorch/nvfuser/csrc/preseg_passes/exact_mapped_extent_substitution.cpp:81
#13 0x00007ffedcade00c in nvfuser::preseg_passes::ExactMappedExtentSubstitutionPass::runPass (fusion=0x7ffede572c00) at /opt/pytorch/nvfuser/csrc/preseg_passes/exact_mapped_extent_substitution.cpp:95
#14 0x00007ffedcafc39e in nvfuser::preseg_passes::OptimizationPass<nvfuser::preseg_passes::ExactMappedExtentSubstitutionPass>::runPass (fusion=0x7ffede572c00) at /opt/pytorch/nvfuser/csrc/preseg_passes/optimization_pass.h:54
#15 0x00007ffedcafa8a4 in nvfuser::preseg_passes::PreSegmenter::runPass (fusion=0x7ffede572c00) at /opt/pytorch/nvfuser/csrc/preseg_passes/pre_segmenter.cpp:66
#16 0x00007ffedcbc09a5 in nvfuser::preseg_passes::OptimizationPass<nvfuser::preseg_passes::PreSegmenter>::runPass (fusion=0x7ffede572c00) at /opt/pytorch/nvfuser/csrc/preseg_passes/optimization_pass.h:54
#17 0x00007ffedcbb7950 in nvfuser::FusionKernelRuntime::FusionKernelRuntime (this=0x7fff38b8c080, fusion=std::unique_ptr<nvfuser::Fusion> = {...}, args=..., serde_buffer=0x0, forced_index_type=std::optional [no contained value], fusion_id=0, concrete_id=1, runtime_id=0, auto_schedule=true) at /opt/pytorch/nvfuser/csrc/runtime/fusion_kernel_runtime.cpp:75
#18 0x00007ffedcbaa0ad in std::make_unique<nvfuser::FusionKernelRuntime, std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >, nvfuser::KernelArgumentHolder const&, decltype(nullptr), std::optional<nvfuser::PrimDataType>&, long&, long&, unsigned long, bool const&>(std::unique_ptr<nvfuser::Fusion, std::default_delete<nvfuser::Fusion> >&&, nvfuser::KernelArgumentHolder const&, decltype(nullptr)&&, std::optional<nvfuser::PrimDataType>&, long&, long&, unsigned long&&, bool const&) () at /usr/include/c++/13/bits/unique_ptr.h:1070
#19 0x00007ffedcba4e2b in nvfuser::FusionExecutorCache::getKernelRuntimeFor (this=0x7ffeddbd3000, args=..., forced_index_type=std::optional [no contained value]) at /opt/pytorch/nvfuser/csrc/runtime/fusion_executor_cache.cpp:653
#20 0x00007ffedcba191f in nvfuser::FusionExecutorCache::runFusionWithInputs (this=0x7ffeddbd3000, inputs=..., forced_index_type=std::optional [no contained value], selected_device=std::optional [no contained value]) at /opt/pytorch/nvfuser/csrc/runtime/fusion_executor_cache.cpp:58
#21 0x00007ffedceb1e84 in nvfuser::python_frontend::FusionDefinition::execute (this=0x7ffedfdf0580, inputs=..., selected_device=std::optional [no contained value], override_user_schedule=false, capture_debug_output=false, profile=false) at /opt/pytorch/nvfuser/csrc/python_frontend/fusion_definition.cpp:414
#22 0x00007ffedbe9efc8 in operator() (__closure=0x7ffeddc3ce58, self=..., iter=..., device=std::optional [no contained value], override_user_schedule=false, capture_debug_output=false, profile=false) at /opt/pytorch/nvfuser/csrc/python_frontend/python_bindings.cpp:1044
#23 0x00007ffedbfd433e in pybind11::detail::argument_loader<nvfuser::python_frontend::FusionDefinition&, pybind11::iterable const&, std::optional<long>, bool, bool, bool>::call_impl<std::vector<at::Tensor>, nvfuser::python_frontend::initNvFuserPythonBindings(PyObject*)::<lambda(nvfuser::python_frontend::FusionDefinition&, const pybind11::iterable&, std::optional<long int>, bool, bool, bool)>&, 0, 1, 2, 3, 4, 5, pybind11::detail::void_type>(struct {...} &, std::index_sequence, pybind11::detail::void_type &&) (this=0x7fffffffce90, f=...) at /usr/local/lib/python3.12/dist-packages/torch/include/pybind11/cast.h:1631
#24 0x00007ffedbfc4a61 in pybind11::detail::argument_loader<nvfuser::python_frontend::FusionDefinition&, pybind11::iterable const&, std::optional<long>, bool, bool, bool>::call<std::vector<at::Tensor>, pybind11::detail::void_type, nvfuser::python_frontend::initNvFuserPythonBindings(PyObject*)::<lambda(nvfuser::python_frontend::FusionDefinition&, const pybind11::iterable&, std::optional<long int>, bool, bool, bool)>&>(struct {...} &) (this=0x7fffffffce90, f=...)
    at /usr/local/lib/python3.12/dist-packages/torch/include/pybind11/cast.h:1600
#25 0x00007ffedbf68ac6 in operator() (__closure=0x0, call=...) at /usr/local/lib/python3.12/dist-packages/torch/include/pybind11/pybind11.h:278
#26 0x00007ffedbf68bac in _FUN () at /usr/local/lib/python3.12/dist-packages/torch/include/pybind11/pybind11.h:249
#27 0x00007ffedc006906 in pybind11::cpp_function::dispatcher (self=0x7fff31d7f750, args_in=0x7ffff6c4e740, kwargs_in=0x7fff23d1e500) at /usr/local/lib/python3.12/dist-packages/torch/include/pybind11/pybind11.h:971
#28 0x00000000005821ef in ?? ()
#29 0x0000000000548f8e in _PyObject_MakeTpCall ()
#30 0x00000000005d7819 in _PyEval_EvalFrameDefault ()
#31 0x00000000005d5d2b in PyEval_EvalCode ()
#32 0x0000000000608e12 in ?? ()
#33 0x00000000006b5253 in ?? ()
#34 0x00000000006b4fba in _PyRun_SimpleFileObject ()
#35 0x00000000006b4def in _PyRun_AnyFileObject ()
#36 0x00000000006bce95 in Py_RunMain ()
#37 0x00000000006bc97d in Py_BytesMain ()
#38 0x00007ffff79b31ca in ?? () from /usr/lib/x86_64-linux-gnu/libc.so.6
#39 0x00007ffff79b328b in __libc_start_main () from /usr/lib/x86_64-linux-gnu/libc.so.6
#40 0x00000000006584a5 in _start ()

@wujingyue
Collaborator Author

cc @Priya2698 as well because this is related to linear.

@jacobhinkle
Collaborator

jacobhinkle commented Nov 7, 2024

@liqiangxl I think at this line we should do a loop over replacement_map and chase references so that the map's entries all point to the leaves.

// Replace non-const extents with const extents
ir_utils::replaceValue(fusion, replacement_map);

I wonder if we should just do this automatically in replaceValue so that we don't have to handle this in every use case.
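The reference-chasing idea above can be sketched in Python. This is only a model of the suggestion, not nvFuser code: `resolve_to_leaves` is a hypothetical helper, and plain strings stand in for the `Val*` keys and values of `replacement_map`.

```python
def resolve_to_leaves(replacement_map):
    """Return a copy of the map in which every key points at its final target.

    Hypothetical helper: keys and values are hashable stand-ins for Val pointers.
    """
    resolved = {}
    for key, target in replacement_map.items():
        seen = {key}
        # Follow key -> a -> b -> ... until reaching a value that is not itself a key.
        while target in replacement_map:
            if target in seen:
                raise ValueError(f"cycle in replacement map at {target!r}")
            seen.add(target)
            target = replacement_map[target]
        resolved[key] = target
    return resolved


# A two-hop chain: i0 is replaced by i1, which is itself replaced by i2.
chained = {"i0": "i1", "i1": "i2"}
flat = resolve_to_leaves(chained)
```

After resolving, no value in `flat` is also a key, so a mutator that rejects two-hop registrations would have nothing to complain about. Whether this belongs inside `replaceValue` or in each caller is the open design question raised in the comment.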

@liqiangxl
Collaborator

The issue arises because the extents in all three of these disjoint sets are 5. We don't need to add constant values to the replacement map, because if they are in the same disjoint set, they must have the same extent.

disjoint sets{
  { iS5{5}; iS0{5}; iS2{5} }
  { rS7{5}; iS1{5}; iS3{5} }
  { iS6{5}; iS4{5} }
}
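A minimal Python sketch of that idea follows, assuming the pass picks one representative extent per disjoint set and maps the others to it. The `Extent` class and `build_replacement_map` are hypothetical stand-ins, not the code in exact_mapped_extent_substitution.cpp; `eq=False` keeps identity-based hashing so each object behaves like a distinct `Val*`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)  # eq=False keeps identity-based hashing, like a raw pointer
class Extent:
    name: str
    const_value: Optional[int] = None

    @property
    def is_const(self):
        return self.const_value is not None


def build_replacement_map(extent_groups):
    """Map each non-constant extent to its group's representative.

    Constant extents are never used as keys, so a constant cannot be
    registered for replacement by another constant.
    """
    replacement = {}
    for group in extent_groups:
        consts = [e for e in group if e.is_const]
        representative = consts[0] if consts else group[0]
        for extent in group:
            if extent is representative or extent.is_const:
                continue
            replacement[extent] = representative
    return replacement


# Groups shaped like the issue: every extent is the constant 5.
issue_groups = [
    [Extent("e5", 5), Extent("e0", 5), Extent("e2", 5)],
    [Extent("e7", 5), Extent("e1", 5), Extent("e3", 5)],
    [Extent("e6", 5), Extent("e4", 5)],
]
issue_map = build_replacement_map(issue_groups)

# A mixed group: the symbolic extent is replaced by the constant one.
sym, five = Extent("i0"), Extent("c5", 5)
mixed_map = build_replacement_map([[sym, five]])
```

With all-constant groups the map comes out empty, so no constant-to-constant chain can form in the first place.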

@liqiangxl
Collaborator

@naoyam also suggested

we may want to build a DisjointSets of extents based on the DisjointSets of IterDomains. The latter doesn't guarantee their extents are also disjoint.
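The suggestion can be illustrated with a small union-find in Python. Everything here is an assumption for illustration: the `DisjointSets` class is a toy, not nvFuser's, and the `extent_of` mapping is invented so that one extent object is shared by IterDomains from two different ID sets.

```python
class DisjointSets:
    """Toy union-find over hashable items (not nvFuser's DisjointSets)."""

    def __init__(self):
        self._parent = {}

    def find(self, item):
        self._parent.setdefault(item, item)
        while self._parent[item] != item:
            item = self._parent[item]
        return item

    def union(self, a, b):
        root_a, root_b = self.find(a), self.find(b)
        if root_a != root_b:
            self._parent[root_b] = root_a

    def groups(self):
        out = {}
        for item in self._parent:
            out.setdefault(self.find(item), []).append(item)
        return list(out.values())


def extent_sets_from_id_sets(id_sets, extent_of):
    """Union the extents of all IterDomains that share an ID set."""
    extent_sets = DisjointSets()
    for id_set in id_sets:
        first = extent_of[id_set[0]]
        extent_sets.find(first)  # register singleton groups too
        for iter_domain in id_set[1:]:
            extent_sets.union(first, extent_of[iter_domain])
    return extent_sets


id_sets = [["iS5", "iS0", "iS2"], ["rS7", "iS1", "iS3"], ["iS6", "iS4"]]
# Hypothetical mapping: extent "y" is shared by iS2 (set 1) and rS7 (set 2).
extent_of = {
    "iS5": "x", "iS0": "x", "iS2": "y",
    "rS7": "y", "iS1": "z", "iS3": "z",
    "iS6": "w", "iS4": "w",
}
extent_groups = extent_sets_from_id_sets(id_sets, extent_of).groups()
```

Because "y" appears in two ID sets, the first two ID sets collapse into a single extent set, which is the situation the ID-level sets alone cannot rule out.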

@liqiangxl
Collaborator

liqiangxl commented Nov 11, 2024

I need to use customized hash and equality functions with DisjointSets to ensure that const extents are hashed by their constant value instead of their pointer address. Results when using DisjointSets<Val*, ValPtrHash, ValPtrEqual> extent_sets:

============id_sets==================
disjoint sets{
  { iS5{5}; iS0{5}; iS2{5} }
  { rS7{5}; iS1{5}; iS3{5} }
  { iS6{5}; iS4{5} }
}
==============================
============extent_set==================
disjoint sets{
  { 5 }
}

Otherwise, these extents are hashed differently. Results when using DisjointSets<Val*> extent_sets:

============extent_set==================
Extent sets: disjoint sets{
  { 5; 5; 5 }
  { 5; 5 }
}

The plain DisjointSets<Val*> version can still solve the original bug, since these Vals have different addresses, so DisjointSets<Val*, ValPtrHash, ValPtrEqual> extent_sets is not required.
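The difference between the two hashing schemes can be modeled in Python. This is a sketch under assumptions: `Const` stands in for a constant `Val`, and `ByValue` is a hypothetical analogue of the `ValPtrHash`/`ValPtrEqual` pair named above, whose actual implementation is not shown in this thread.

```python
class Const:
    """Stand-in for a constant Val; default hashing and equality are by identity."""

    def __init__(self, value):
        self.value = value


class ByValue:
    """Wrapper key that hashes and compares by the wrapped constant's value."""

    def __init__(self, const):
        self.const = const

    def __hash__(self):
        return hash(self.const.value)

    def __eq__(self, other):
        return isinstance(other, ByValue) and self.const.value == other.const.value


# Five distinct objects that all hold the constant 5.
fives = [Const(5) for _ in range(5)]

by_identity = set(fives)                # pointer-like: each object is its own key
by_value = {ByValue(c) for c in fives}  # value-based: all collapse into one key
```

Here `by_identity` keeps five separate entries while `by_value` keeps one, mirroring the `{ 5; 5; 5 } { 5; 5 }` versus `{ 5 }` outputs shown above.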
