From c20cf973669a4d5df32f14851f6b609c473ce874 Mon Sep 17 00:00:00 2001
From: Michael Lazos
Date: Thu, 21 Mar 2024 01:02:20 +0000
Subject: [PATCH] Move some cudagraphs checks into C++ (#122251)

Based off of https://github.com/pytorch/pytorch/pull/111094

This, together with the C++ guards, improves TIMM geomean optimizer
performance by about 20%.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122251
Approved by: https://github.com/eellison
---
 test/inductor/test_cudagraph_trees.py | 17 +++++++
 torch/_C/__init__.pyi.in              |  1 +
 torch/_inductor/cudagraph_trees.py    | 68 ++++++++++++++++++---------
 torch/csrc/cuda/Module.cpp            | 14 ++++++
 4 files changed, 77 insertions(+), 23 deletions(-)

diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py
index 3ba47db57cf3b..01400c55eeffc 100644
--- a/test/inductor/test_cudagraph_trees.py
+++ b/test/inductor/test_cudagraph_trees.py
@@ -1230,6 +1230,23 @@ def foo(mod, inp):
             node = self.get_manager().current_node
             self.assertEqual(len(list(node.path_live_weakrefs())), 1)
 
+    def test_unstable_ptr(self):
+        import torch
+
+        @torch.compile(mode="reduce-overhead")
+        def foo(m, inp):
+            return m(inp)
+
+        def f():
+            l = []
+            m = torch.nn.Linear(20, 20).cuda()
+            for _ in range(4):
+                inp = torch.rand([20, 20], device="cuda")
+                foo(m, inp)
+                m.weight.data = torch.rand([20, 20], device="cuda")
+
+        self.assertRaises(RuntimeError, f)
+
     @requires_multigpu()
     def test_manager_per_device(self):
         def test():
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 002788131019d..e23c928f8db5e 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1695,6 +1695,7 @@ def _cuda_getCheckpointState(device: _int, mempool: Tuple[_int, _int]) -> _cuda_
 def _set_cached_tensors_enabled(enabled: _bool) -> None: ...
 def _add_cached_tensor(t: Tensor) -> None: ...
 def _remove_cached_tensor(t: Tensor) -> None: ...
+def _tensors_data_ptrs_at_indices_equal(tensors: List[Tensor], ptrs: List[Optional[_int]], indices: List[_int]) -> _bool: ...
 def _construct_CUDA_Tensor_From_Storage_And_Metadata(metadata: dict, storage: Storage) -> Tensor: ...
 def _storage_Use_Count(storage_ptr: _int) -> _int: ...
 def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ...
diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py
index 9cb19bb72ad71..12b855f53dbb4 100644
--- a/torch/_inductor/cudagraph_trees.py
+++ b/torch/_inductor/cudagraph_trees.py
@@ -58,7 +58,6 @@
     Iterator,
     List,
     Optional,
-    Sequence,
     Set,
     Tuple,
     Union,
@@ -128,7 +127,7 @@ class WrappedFunction:
     """
 
     model: Callable[..., Any]
-    static_input_idxs: Sequence[int]
+    static_input_idxs: List[int]
     id: FunctionID
     constants: Tuple[torch.Tensor, ...]
 
@@ -787,6 +786,16 @@ def __init__(
             set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
         )
 
+        self.non_static_input_idx: LevelList[int] = [
+            i for i in range(len(inputs)) if i not in self.static_input_idxs
+        ]
+
+        self.non_managed_static_input_idxs: LevelList[int] = [
+            i
+            for i in wrapped_function.static_input_idxs
+            if i not in self.cudagraph_managed_idxs
+        ]
+
         self.static_input_data_ptrs: InputList[Optional[int]] = [
             (
                 inputs[i].data_ptr()
@@ -924,6 +933,23 @@ def _copy_input(self, idx, dst, src):
         # TODO - one jit kernel across multiple inputs
         dst.copy_(src)
 
+    def check_static_inputs_are_stable(self, new_inputs):
+        # avoid checking managed tensor static pointers since we already checked those in check_invariants
+        if not torch._C._tensors_data_ptrs_at_indices_equal(
+            new_inputs, self.static_input_data_ptrs, self.non_managed_static_input_idxs
+        ):
+            # this should error
+            static_tensors = [new_inputs[i] for i in self.non_managed_static_input_idxs]
+            data_ptrs = [
+                self.static_input_data_ptrs[i]
+                for i in self.non_managed_static_input_idxs
+            ]
+            for t, data_ptr in zip(static_tensors, data_ptrs):
+                torch._check(
+                    t.data_ptr() == data_ptr,
+                    lambda: f"static input data pointer changed from {data_ptr} to {t.data_ptr()}",
+                )
+
     def run_first_inputs(self, new_inputs):
         if config.triton.fast_path_cudagraph_asserts:
             self.debug_check_invariants_before_invocation()
@@ -936,30 +962,23 @@ def run_first_inputs(self, new_inputs):
         return outputs
 
     def run(self, new_inputs):
-        if config.triton.fast_path_cudagraph_asserts:
-            self.debug_check_invariants_before_invocation()
+        self.check_static_inputs_are_stable(new_inputs)
 
-        assert len(self.static_input_data_ptrs) == len(new_inputs)
-        # NB: this ranges over non-static inputs too
-        for idx, data_ptr in enumerate(self.static_input_data_ptrs):
-            if idx in self.cudagraph_managed_idxs:
-                continue
+        for idx in self.non_static_input_idx:
             if not isinstance(new_inputs[idx], torch.Tensor):
-                pass
-            elif data_ptr is not None:
-                # static input, e.g., parameter
-                assert data_ptr == new_inputs[idx].data_ptr()
-            else:
-                # non-static input, need to copy it into CUDA graph
-                dst = self.reconstructed_inputs[idx]
-                src = new_inputs[idx]
-                self._copy_input(idx, dst, src)
+                continue
+
+            # non-static input, need to copy it into CUDA graph
+            self._copy_input(idx, self.reconstructed_inputs[idx], new_inputs[idx])
 
         new_inputs.clear()
+
         self.run_graph()
 
         outputs = self.reconstruct_outputs()
-        self.debug_check_invariants_after_invocation()
+
+        if config.triton.fast_path_cudagraph_asserts:
+            self.debug_check_invariants_after_invocation()
 
         return outputs
 
@@ -1513,9 +1532,12 @@ def check_invariants(self, inputs: List[Tensor]) -> bool:
         """
 
         # previously managed data pointers remain stable
-        for idx in self.cudagraph_managed_idxs:
-            if inputs[idx].data_ptr() != self.static_input_data_ptrs[idx]:
-                return False
+        # this is on the hot path so moved to C++. equivalent to:
+        # return all(t.data_ptr() == data_ptr for (t, data_ptr) in zip(tensors, data_ptrs))
+        if not torch._C._tensors_data_ptrs_at_indices_equal(
+            inputs, self.static_input_data_ptrs, self.cudagraph_managed_idxs
+        ):
+            return False
 
         if not self._check_liveness(
             self.expected_dead_indices_before_graph, self.path_weakrefs
@@ -1931,7 +1953,7 @@ def add_function(
         self.ids_to_stack_traces[id] = stack_traces
         self.ids_to_funcs[id] = WrappedFunction(
             model,
-            static_input_idxs,
+            list(static_input_idxs),
             id,
             tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda),
         )
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 0ce243c300c3a..e622c254a5003 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -1191,6 +1191,20 @@ static void registerCudaPluggableAllocator(PyObject* module) {
         return c10::raw::weak_intrusive_ptr::use_count(storage_impl);
       });
 
+  m.def(
+      "_tensors_data_ptrs_at_indices_equal",
+      [](py::list& tensors, py::list& data_ptrs, py::list& indices) {
+        for (size_t i = 0, end = indices.size(); i < end; ++i) {
+          auto index = indices[i].cast<size_t>();
+          auto t = tensors[index].cast<at::Tensor>();
+          auto data_ptr = data_ptrs[index].cast<int64_t>();
+          if (reinterpret_cast<int64_t>(t.data_ptr()) != data_ptr) {
+            return false;
+          }
+        }
+        return true;
+      });
+
   m.def(
       "_construct_CUDA_Tensor_From_Storage_And_Metadata",
       [](py::dict& metadata, c10::Storage s) {
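
For reference, a pure-Python sketch of the check `_tensors_data_ptrs_at_indices_equal`
performs; the binding itself is the C++ lambda above, which avoids Python-level
per-tensor overhead on the hot path. The tensor names and CUDA-availability guard
below are illustrative assumptions, not part of the patch.

    import torch

    def tensors_data_ptrs_at_indices_equal(tensors, ptrs, indices):
        # True iff each tensor at the given indices still lives at the data
        # pointer recorded in `ptrs` (a list parallel to `tensors`).
        return all(tensors[i].data_ptr() == ptrs[i] for i in indices)

    if torch.cuda.is_available():
        a = torch.ones(4, device="cuda")
        b = torch.ones(4, device="cuda")
        recorded = [a.data_ptr(), b.data_ptr()]
        assert tensors_data_ptrs_at_indices_equal([a, b], recorded, [0, 1])
        b.data = torch.ones(4, device="cuda")  # reallocates b's storage
        assert not tensors_data_ptrs_at_indices_equal([a, b], recorded, [0, 1])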