Skip to content

Commit

Permalink
Move some cudagraphs checks into C++ (pytorch#122251)
Browse files Browse the repository at this point in the history
Based off of pytorch#111094
This + cpp guards improves TIMM geomean optimizer performance by about 20%

Pull Request resolved: pytorch#122251
Approved by: https://github.com/eellison
  • Loading branch information
mlazos authored and pytorchmergebot committed Mar 21, 2024
1 parent be5863d commit c20cf97
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 23 deletions.
17 changes: 17 additions & 0 deletions test/inductor/test_cudagraph_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,23 @@ def foo(mod, inp):
node = self.get_manager().current_node
self.assertEqual(len(list(node.path_live_weakrefs())), 1)

def test_unstable_ptr(self):
import torch

@torch.compile(mode="reduce-overhead")
def foo(m, inp):
return m(inp)

def f():
l = []
m = torch.nn.Linear(20, 20).cuda()
for _ in range(4):
inp = torch.rand([20, 20], device="cuda")
foo(m, inp)
m.weight.data = torch.rand([20, 20], device="cuda")

self.assertRaises(RuntimeError, f)

@requires_multigpu()
def test_manager_per_device(self):
def test():
Expand Down
1 change: 1 addition & 0 deletions torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -1695,6 +1695,7 @@ def _cuda_getCheckpointState(device: _int, mempool: Tuple[_int, _int]) -> _cuda_
def _set_cached_tensors_enabled(enabled: _bool) -> None: ...
def _add_cached_tensor(t: Tensor) -> None: ...
def _remove_cached_tensor(t: Tensor) -> None: ...
def _tensors_data_ptrs_at_indices_equal(tensors: List[Tensor], ptrs: List[Optional[_int]], indices: List[_int]) -> _bool: ...
def _construct_CUDA_Tensor_From_Storage_And_Metadata(metadata: dict, storage: Storage) -> Tensor: ...
def _storage_Use_Count(storage_ptr: _int) -> _int: ...
def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ...
Expand Down
68 changes: 45 additions & 23 deletions torch/_inductor/cudagraph_trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
Iterator,
List,
Optional,
Sequence,
Set,
Tuple,
Union,
Expand Down Expand Up @@ -128,7 +127,7 @@ class WrappedFunction:
"""

model: Callable[..., Any]
static_input_idxs: Sequence[int]
static_input_idxs: List[int]
id: FunctionID
constants: Tuple[torch.Tensor, ...]

Expand Down Expand Up @@ -787,6 +786,16 @@ def __init__(
set(wrapped_function.static_input_idxs) | set(self.cudagraph_managed_idxs)
)

self.non_static_input_idx: LevelList[int] = [
i for i in range(len(inputs)) if i not in self.static_input_idxs
]

self.non_managed_static_input_idxs: LevelList[int] = [
i
for i in wrapped_function.static_input_idxs
if i not in self.cudagraph_managed_idxs
]

self.static_input_data_ptrs: InputList[Optional[int]] = [
(
inputs[i].data_ptr()
Expand Down Expand Up @@ -924,6 +933,23 @@ def _copy_input(self, idx, dst, src):
# TODO - one jit kernel across multiple inputs
dst.copy_(src)

def check_static_inputs_are_stable(self, new_inputs):
# avoid checking managed tensor static points since we already checked those in check_invariants
if not torch._C._tensors_data_ptrs_at_indices_equal(
new_inputs, self.static_input_data_ptrs, self.non_managed_static_input_idxs
):
# this should error
static_tensors = [new_inputs[i] for i in self.non_managed_static_input_idxs]
data_ptrs = [
self.static_input_data_ptrs[i]
for i in self.non_managed_static_input_idxs
]
for t, data_ptr in zip(static_tensors, data_ptrs):
torch._check(
t.data_ptr() == data_ptr,
lambda: f"static input data pointer changed from {data_ptr} to {t.data_ptr()}",
)

def run_first_inputs(self, new_inputs):
if config.triton.fast_path_cudagraph_asserts:
self.debug_check_invariants_before_invocation()
Expand All @@ -936,30 +962,23 @@ def run_first_inputs(self, new_inputs):
return outputs

def run(self, new_inputs):
if config.triton.fast_path_cudagraph_asserts:
self.debug_check_invariants_before_invocation()
self.check_static_inputs_are_stable(new_inputs)

assert len(self.static_input_data_ptrs) == len(new_inputs)
# NB: this ranges over non-static inputs too
for idx, data_ptr in enumerate(self.static_input_data_ptrs):
if idx in self.cudagraph_managed_idxs:
continue
for idx in self.non_static_input_idx:
if not isinstance(new_inputs[idx], torch.Tensor):
pass
elif data_ptr is not None:
# static input, e.g., parameter
assert data_ptr == new_inputs[idx].data_ptr()
else:
# non-static input, need to copy it into CUDA graph
dst = self.reconstructed_inputs[idx]
src = new_inputs[idx]
self._copy_input(idx, dst, src)
continue

# non-static input, need to copy it into CUDA graph
self._copy_input(idx, self.reconstructed_inputs[idx], new_inputs[idx])

new_inputs.clear()

self.run_graph()

outputs = self.reconstruct_outputs()
self.debug_check_invariants_after_invocation()

if config.triton.fast_path_cudagraph_asserts:
self.debug_check_invariants_after_invocation()

return outputs

Expand Down Expand Up @@ -1513,9 +1532,12 @@ def check_invariants(self, inputs: List[Tensor]) -> bool:
"""

# previously managed data pointers remain stable
for idx in self.cudagraph_managed_idxs:
if inputs[idx].data_ptr() != self.static_input_data_ptrs[idx]:
return False
# this is on the hot path so moved to C++. equivalent to:
# return all(t.data_ptr() == data_ptr for (t, data_ptr) in zip(tensors, data_ptrs))
if not torch._C._tensors_data_ptrs_at_indices_equal(
inputs, self.static_input_data_ptrs, self.cudagraph_managed_idxs
):
return False

if not self._check_liveness(
self.expected_dead_indices_before_graph, self.path_weakrefs
Expand Down Expand Up @@ -1931,7 +1953,7 @@ def add_function(
self.ids_to_stack_traces[id] = stack_traces
self.ids_to_funcs[id] = WrappedFunction(
model,
static_input_idxs,
list(static_input_idxs),
id,
tuple(t for t in constants if isinstance(t, torch.Tensor) and t.is_cuda),
)
Expand Down
14 changes: 14 additions & 0 deletions torch/csrc/cuda/Module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1191,6 +1191,20 @@ static void registerCudaPluggableAllocator(PyObject* module) {
return c10::raw::weak_intrusive_ptr::use_count(storage_impl);
});

m.def(
"_tensors_data_ptrs_at_indices_equal",
[](py::list& tensors, py::list& data_ptrs, py::list& indices) {
for (size_t i = 0, end = indices.size(); i < end; ++i) {
auto index = indices[i].cast<int64_t>();
auto t = tensors[index].cast<at::Tensor>();
auto data_ptr = data_ptrs[index].cast<int64_t>();
if (reinterpret_cast<int64_t>(t.data_ptr()) != data_ptr) {
return false;
}
}
return true;
});

m.def(
"_construct_CUDA_Tensor_From_Storage_And_Metadata",
[](py::dict& metadata, c10::Storage s) {
Expand Down

0 comments on commit c20cf97

Please sign in to comment.