Fix the crash when symint is part of the output with dynamic torch.compile #7735

Merged 3 commits on
Jul 25, 2024
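The scenario this PR fixes, in rough terms: when a model compiled with torch.compile(backend="openxla", dynamic=True) returns a Python int derived from a dynamic dimension (a SymInt) alongside its tensors, the dynamo bridge previously handed that int to tensor-only XLA APIs and crashed. A minimal repro sketch under that assumption (the module, shapes, and names below are illustrative only, not taken from this PR):

import torch
import torch_xla

class ReturnsSymInt(torch.nn.Module):
  # Hypothetical module: returns a tensor plus an int tied to the dynamic
  # batch dimension, so the compiled graph ends up with a SymInt output.
  def forward(self, x):
    return x * 2, x.shape[0]

device = torch_xla.device()
model = ReturnsSymInt().to(device)
compiled = torch.compile(model, backend="openxla", dynamic=True)
for batch_size in [2, 3, 4]:
  out, n = compiled(torch.ones(batch_size, 8, device=device))

With the SpecialReturnHandler change below, such int outputs are cached per input-shape combination and spliced back into the result instead of being sent through the XLA graph.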
27 changes: 27 additions & 0 deletions test/dynamo/test_dynamo_dynamic_shape.py
@@ -1,3 +1,11 @@
import sys
import os
example_folder = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(
sys.argv[0])))) + "/examples"
sys.path.append(example_folder)
from decoder_only_model import DecoderOnlyConfig, DecoderOnlyModel

import unittest
import sys

@@ -183,6 +191,25 @@ def test_dynamic_shape_mix_with_non_dynamic(self):
self.assertEqual(met.metric_data('CompileTime')[0], 1)
self.assertEqual(met.metric_data('ExecuteTime')[0], 1)

def test_dynamic_shape_symint_as_return(self):
device = torch_xla.device()
config = DecoderOnlyConfig()
config.num_hidden_layers = 2
config.hidden_size = 512
seq_len = 512
decoder_model = DecoderOnlyModel(config).to(device)
compiled_decoder_model = torch.compile(
decoder_model, backend="openxla", dynamic=True)
xm.wait_device_ops()
met.clear_all()
for batch_size in [1, 2, 3, 4, 5]:
input = torch.zeros(batch_size, seq_len, dtype=torch.int64).to(device)
res = compiled_decoder_model(input)
# For some reason `batch_size == 1` is a special case where the output does not
# have additional ints. We will have one compile for batch size 1 and one compile
# for the other batch sizes.
self.assertEqual(met.counter_value('DynamoExtractCompiledGraph'), 2)

def test_dynamic_shape_no_retracing(self):
device = torch_xla.device()
# model setup
72 changes: 51 additions & 21 deletions torch_xla/core/dynamo_bridge.py
@@ -204,9 +204,11 @@ def recover(self, deduped_list):
return [deduped_list[i] for i in self.permute_for_orig]


class DumbReturnHandler:
class SpecialReturnHandler:
"""
Define dumb return as an output that is also an input.
In this class we handle 2 types of special outputs
1. dumb return
We define a dumb return as an output that is also an input.
Torch xla does not return such tensors as its graph output. That breaks the
API contract with the caller of the graph. Also AOTAutograd
may generate such a graph quite often.
@@ -221,12 +223,18 @@ class DumbReturnHandler:
(this is a graph generated for a model with a single BatchNorm2d)
XLA will dedup those duplicate items, but we need to recover the duplications to maintain
the contract with the caller.

2. Int output from dynamic compile
In the case of `torch.compile(dynamic=True)` there might be some int outputs related
to the dynamic dimensions of the tensor. These ints are static for a given input shape
combination, so we can cache them and inject them into the final result directly.
"""

def __init__(self, trace_inputs, trace_outputs,
trace_inputs_inplace_update_bool):
trace_inputs_inplace_update_bool, int_outputs_and_indexes):
self.trace_inputs = trace_inputs
self.trace_outputs = trace_outputs
self.int_outputs_and_indexes = int_outputs_and_indexes

# dedup the traced outputs first
self.deduper = Deduper()
@@ -251,6 +259,12 @@ def addDumbReturn(self, real_inputs, real_outputs):
real_outputs.insert(out_pos, real_inputs[in_pos])

ret = self.deduper.recover(real_outputs)

if len(self.int_outputs_and_indexes) != 0:
# insert the int outputs back into the result
for index, value in self.int_outputs_and_indexes:
ret.insert(index, value)

return ret
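To make the splice-back above concrete, a small standalone sketch (the values and names here are illustrative, not part of this module): if the traced output is [t0, 5, t1], where 5 is a cached SymInt value, the XLA graph only returns the tensors [t0, t1], and int_outputs_and_indexes records [(1, 5)] so the int is re-inserted at position 1 at run time:

import torch

int_outputs_and_indexes = [(1, 5)]      # (original output position, cached int value)
ret = [torch.zeros(2), torch.ones(2)]   # tensor-only results from the cached graph
for index, value in int_outputs_and_indexes:
  ret.insert(index, value)
# ret is now [tensor, 5, tensor], matching the output order the caller expects.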


@@ -371,22 +385,36 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
xla_args_need_update.append(tensor)

args_and_out = tuple(xla_args_need_update) + tuple(xla_out)
# args_and_out should contain tensors only; in the dynamic case there might be
# symint returns in the result. In that case we want to extract them and separate
# them from the device computation.
int_outputs_and_indexes = []
args_and_out_tensor_only = []
for i in range(len(args_and_out)):
arg = args_and_out[i]
if not isinstance(arg, torch.Tensor):
assert type(arg) == int
int_outputs_and_indexes.append((i, arg))
else:
args_and_out_tensor_only.append(arg)

# calculate graph hash
dumb_return_handler = DumbReturnHandler(xla_args, args_and_out,
xla_args_need_update_bool)
special_return_handler = SpecialReturnHandler(xla_args,
args_and_out_tensor_only,
xla_args_need_update_bool,
int_outputs_and_indexes)

# There is a `mark_step` at the beginning of this function call; we need to wait
# for that to finish before retrieving the device data nodes.
xm.wait_device_ops()
# Collect all device data nodes that is needed to compute the args_and_out
# Collect all device data nodes that are needed to compute the args_and_out_tensor_only
# and wrap those device data nodes inside an at::tensor(graph_input_xla_values).
# Return the tensor_id corresponding to every device_data node as
# graph_input_tensor_ids.
(
graph_input_tensor_ids,
graph_input_xla_values,
) = torch_xla._XLAC._get_tensors_xla_device_data_node(args_and_out)
) = torch_xla._XLAC._get_tensors_xla_device_data_node(
args_and_out_tensor_only)
assert len(graph_input_tensor_ids) == len(
graph_input_xla_values
), f"{len(graph_input_tensor_ids)} v.s. {len(graph_input_xla_values)}"
@@ -402,18 +430,20 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,
buffer_donor_count += 1 if isinstance(
graph_input, torch.Tensor) and torch_xla._XLAC._get_buffer_donation(
graph_input) else 0
print(f"Number of HLO Output: {len(args_and_out)}")
print(f"Number of HLO Output: {len(args_and_out_tensor_only)}")
print(
f"Number of HLO Input can be aliased with Output: {buffer_donor_count}")
print(
f"XLA IR Text: \n{torch_xla._XLAC._get_xla_tensors_text(args_and_out)}")
f"XLA IR Text: \n{torch_xla._XLAC._get_xla_tensors_text(args_and_out_tensor_only)}"
)

with alias_with_buffer_donor_config() as saved_config:
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out)
# calculate graph hash
graph_hash = torch_xla._XLAC._get_graph_hash(args_and_out_tensor_only)
if dynamo_debug:
print("Graph Hash: ", graph_hash)
# compiles and cache graph rooted at tensors in 'args_and_out'
torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])
# compiles and cache graph rooted at tensors in 'args_and_out_tensor_only'
torch_xla._XLAC._xla_warm_up_cache(args_and_out_tensor_only, [])

# Restore the original `xla_args`. Dynamo passed the real tensors as
# `xla_args` and we performed the tracing on them. During the tracing,
@@ -434,7 +464,7 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule,

vars_to_return = (xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover,
graph_input_matcher, dumb_return_handler,
graph_input_matcher, special_return_handler,
xla_args_need_update)
# populate the cache if model is compiled with `dynamic=True`
if not torch._dynamo.config.assume_static_by_default:
@@ -461,12 +491,12 @@ def extract_internal(xla_model: torch.fx.GraphModule):
# Keys: tuple of input shapes
# Values: tuple of (xla_args_sharding_spec, args_and_out, graph_hash,
# arg_index_to_need_update_index, none_remover, graph_input_matcher,
# dumb_return_handler, xla_args_need_update).
# special_return_handler, xla_args_need_update).
shapes_to_graph_vars: Dict[Tuple[int, ...], Tuple[Any, ...]] = {}

(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
special_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model, shapes_to_graph_vars)
skip_checking_input_sharding_threashold = xu.getenv_as(
'XLA_DYNAMO_INPUT_SHARDING_CHECK_THRESHOLD', int, 5)
@@ -479,7 +509,7 @@ def optimized_mod(*args: tuple):
nonlocal arg_index_to_need_update_index
nonlocal none_remover
nonlocal graph_input_matcher
nonlocal dumb_return_handler
nonlocal special_return_handler
nonlocal xla_args_need_update
nonlocal skip_checking_input_sharding_threashold
nonlocal shapes_to_graph_vars
@@ -491,12 +521,12 @@ def optimized_mod(*args: tuple):
if arg_input_shapes in shapes_to_graph_vars:
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
special_return_handler,
xla_args_need_update) = shapes_to_graph_vars[arg_input_shapes]
else:
(xla_args_sharding_spec, args_and_out, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
special_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model,
shapes_to_graph_vars)

@@ -534,7 +564,7 @@ def optimized_mod(*args: tuple):
xla_model.xla_args = args
(xla_args_sharding_spec, args_and_out_copy, graph_hash,
arg_index_to_need_update_index, none_remover, graph_input_matcher,
dumb_return_handler,
special_return_handler,
xla_args_need_update) = extract_graph_helper(xla_model,
shapes_to_graph_vars)
skip_checking_input_sharding_threashold = xu.getenv_as(
@@ -549,7 +579,7 @@ def optimized_mod(*args: tuple):
graph_input = graph_input_matcher(args)
start_ts = time.time()
res = torch_xla._XLAC._run_cached_graph(graph_hash, graph_input)
res = dumb_return_handler.addDumbReturn(args, res)
res = special_return_handler.addDumbReturn(args, res)

assert len(res) == len(args_and_out), f"{len(res)} v.s. {len(args_and_out)}"
ncopy = 0