Always create ShapeEnv, always apply unspec logic (#103302)
Originally, my goal for this PR was to remove the `dynamic_shapes` tests in torch/_dynamo/variables/builder.py. However, one thing led to another, and it turned out to be easiest to do all of the following in one go:

* Unconditionally allocate a ShapeEnv, whether or not `dynamic_shapes` is enabled (torch/_dynamo/output_graph.py). There is a small adjustment to export in torch/_dynamo/eval_frame.py to account for the fact that a ShapeEnv always exists, even if you're not doing symbolic export (a minimal sketch of this invariant follows the list).
* Remove the `dynamic_shapes` test from the unspec logic (torch/_dynamo/variables/builder.py), which was the original goal.
* Specialize strides and storage offset if all sizes are static (torch/fx/experimental/symbolic_shapes.py). This is required to deal with the unconditional ShapeEnv: if a ShapeEnv exists, fake tensor-ification may choose to allocate symbols. The idea is that with `automatic_dynamic_shapes == False`, Dynamo should never request dynamic sizes, but this invariant was not upheld for nontrivial strides/offset.
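For illustration only (not part of the diff), here is a minimal sketch of the invariant described in the first and third bullets, using the real `ShapeEnv`/`FakeTensorMode` APIs with simplified wiring; the explicit `static_shapes=True` request stands in for Dynamo's per-dimension policy decisions:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# A ShapeEnv is now allocated even when dynamic shapes are off.  Whether
# fake-ification allocates SymInts is a per-dimension policy decision; here we
# explicitly ask for static shapes, so sizes, strides, and storage offset all
# stay concrete ints despite the ShapeEnv being present.
shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
ft = fake_mode.from_tensor(torch.randn(4, 8), static_shapes=True)
print(ft.shape, ft.stride(), ft.storage_offset())
```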

The rest are just auxiliary fixups from the above:

* Work around a bug in FakeTensorProp where it sometimes doesn't return a FakeTensor (torch/fx/passes/fake_tensor_prop.py); see #103395 for follow-up
* Make ShapeProp correctly handle int inputs (torch/fx/passes/shape_prop.py)
* Disable indexing strength reduction if `assume_static_by_default` is False (torch/_inductor/codegen/triton.py)
* Fix hf_T5_generate to NOT toggle `assume_static_by_default` if dynamic shapes are not enabled (benchmarks/dynamo/common.py); technically this is not necessary anymore, but it's in for safety. (A sketch of how these config flags interact follows this list.)
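For context, a hedged sketch of how the three config flags referenced above interact as of this commit; the semantics are summarized from the surrounding code and are not authoritative:

```python
import torch
import torch._dynamo as dynamo

# Flags referenced throughout this PR (defaults may differ across versions):
dynamo.config.dynamic_shapes = True            # master switch for symbolic shapes
dynamo.config.assume_static_by_default = True  # sizes start static unless marked/promoted
dynamo.config.automatic_dynamic_shapes = True  # promote a dim to dynamic after a size change

@torch.compile
def f(x):
    return x * 2

f(torch.randn(4))
f(torch.randn(8))  # size changed: with automatic_dynamic_shapes, dim 0 may be recompiled as dynamic
```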

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: #103302
Approved by: https://github.com/voznesenskym
ezyang authored and pytorchmergebot committed Jun 12, 2023
1 parent f4228e7 commit c3fdfca
Showing 8 changed files with 78 additions and 65 deletions.
7 changes: 6 additions & 1 deletion benchmarks/dynamo/common.py
@@ -2366,7 +2366,12 @@ def run(runner, args, original_dir=None):
torch.use_deterministic_algorithms(True)
if args.only in {"hf_T5_generate"}:
# See https://github.com/pytorch/pytorch/issues/102814
torch._dynamo.config.assume_static_by_default = False
if torch._dynamo.config.dynamic_shapes:
torch._dynamo.config.assume_static_by_default = False
if not torch._dynamo.config.automatic_dynamic_shapes:
log.warning(
"hf_T5_generate compiles extremely slowly without dynamic shapes; consider lowering cache_size_limit"
)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = False
8 changes: 4 additions & 4 deletions torch/_dynamo/eval_frame.py
@@ -988,10 +988,10 @@ def result_capturing_wrapper(*graph_inputs):
remove_from_cache(f)

if (
shape_env := getattr(fake_mode, "shape_env", None)
) is not None and not skipfiles.check(inspect.getsourcefile(call_to_inspect)):
dim_constraints = shape_env.dim_constraints
assert dim_constraints is not None
(shape_env := getattr(fake_mode, "shape_env", None)) is not None
and (dim_constraints := shape_env.dim_constraints) is not None
and not skipfiles.check(inspect.getsourcefile(call_to_inspect))
):
dim_constraints.solve()
msg = dim_constraints.prettify_results(original_signature)
forced_specializations = dim_constraints.forced_specializations()
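The restructured condition above chains walrus assignments inside a single `and` expression. A generic sketch of the same Python idiom, with an unrelated hypothetical helper, just to show how each binding only happens if the earlier conditions held:

```python
import re

def parse_port(url: str):
    # `port` is only bound (and only usable) if the regex matched.
    if (
        (m := re.search(r":(\d+)$", url)) is not None
        and (port := int(m.group(1))) < 65536
    ):
        return port
    return None

print(parse_port("localhost:8080"))  # 8080
print(parse_port("localhost"))       # None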
4 changes: 1 addition & 3 deletions torch/_dynamo/output_graph.py
@@ -232,9 +232,7 @@ def __init__(
allow_scalar_outputs=config.capture_scalar_outputs,
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
frame_id=frame_state["_id"],
)
if config.dynamic_shapes
else None,
),
# TODO (tmanlaibaatar) Remove this once we always lift params and buffers
allow_non_fake_inputs=True if self.export else False,
)
96 changes: 46 additions & 50 deletions torch/_dynamo/variables/builder.py
@@ -753,7 +753,7 @@ def wrap_module(self, value: torch.nn.Module):
)

def wrap_literal(self, value):
unspec = not config.specialize_int and config.dynamic_shapes
unspec = not config.specialize_int
if unspec and type(value) is torch.Size:
return SizeVariable(
[
@@ -930,8 +930,7 @@ def wrap_unspecialized_primitive(self, value):
# but the general idea is that we generate kernels that can
# take unspecialized floats and use them in sizevar computation
if (
config.dynamic_shapes
and isinstance(value, int)
isinstance(value, int)
and not is_constant_source(self.get_source())
and not isinstance(self.get_source(), RandomValueSource)
):
@@ -1218,10 +1217,9 @@ def _clone_input(value):
elif istype(example_value, (list, immutable_list)):
return ListVariable(unpacked, mutable_local=MutableLocal(), **options)
else:
assert (
example_value.__class__.__module__ == "torch.return_types"
or hasattr(example_value, "_fields")
), ("namedtuple?")
assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
example_value, "_fields"
), f"expected {example_value.__class__.__module__} == torch.return_types or named tuple but got {type(example_value)}"
return NamedTupleVariable(unpacked, example_value.__class__, **options)
elif example_value is None or proxy.node.target is torch.manual_seed:
return ConstantVariable(None, **options)
@@ -1338,51 +1336,49 @@ def update_dim2constraint(dim, constraint_range):
constraint.shared.dim, constraint.constraint_range
)

dynamic_dims = None
constraint_dims = None
if tx.fake_mode.shape_env is not None:
dynamic_dims = []
constraint_dims = []
for i in range(e.dim()):
# NB: mark dynamic has precedence over static
marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set())
marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
marked_static = i in getattr(e, "_dynamo_static_indices", set())

# NB: both static and dynamic have precedence over
automatic_dynamic = config.automatic_dynamic_shapes and (
frame_state_entry.size is None or frame_state_entry.size[i] is None
)
dynamic_dims = []
constraint_dims = []
for i in range(e.dim()):
# NB: mark dynamic has precedence over static
marked_dynamic = i in getattr(e, "_dynamo_dynamic_indices", set())
marked_weak_dynamic = i in getattr(e, "_dynamo_weak_dynamic_indices", set())
marked_static = i in getattr(e, "_dynamo_static_indices", set())

# NB: both static and dynamic have precedence over
automatic_dynamic = config.automatic_dynamic_shapes and (
frame_state_entry.size is None or frame_state_entry.size[i] is None
)

# Reflect the user directive in the frame_state
# For dynamic, apply None always
if frame_state_entry.size and marked_dynamic:
frame_state_entry.size[i] = None

# We will process constraints first, as they will imply that we
# have a dynamic dimension
# Precedence: export constraints > eager constraints
constraint = dim2constraint.get(i)
if constraint is None:
if marked_dynamic and not config.allow_ignore_mark_dynamic:
constraint = RelaxedUnspecConstraint(warn_only=False)
elif not marked_static and automatic_dynamic:
constraint = RelaxedUnspecConstraint(warn_only=True)
constraint_dims.append(constraint)

# Now, figure out if the dim is dynamic/duck/static
if constraint is not None or marked_dynamic or marked_weak_dynamic:
# NB: We could assert static_shapes is False here, but it
# seems better to allow the user to override policy in this
# case
dynamic = DimDynamic.DYNAMIC
elif static_shapes or config.assume_static_by_default or marked_static:
dynamic = DimDynamic.STATIC
else:
dynamic = DimDynamic.DUCK
dynamic_dims.append(dynamic)
# Reflect the user directive in the frame_state
# For dynamic, apply None always
if frame_state_entry.size and marked_dynamic:
frame_state_entry.size[i] = None

# We will process constraints first, as they will imply that we
# have a dynamic dimension
# Precedence: export constraints > eager constraints
constraint = dim2constraint.get(i)
if constraint is None:
if marked_dynamic and not config.allow_ignore_mark_dynamic:
constraint = RelaxedUnspecConstraint(warn_only=False)
elif not marked_static and automatic_dynamic:
constraint = RelaxedUnspecConstraint(warn_only=True)
constraint_dims.append(constraint)

# Now, figure out if the dim is dynamic/duck/static
if constraint is not None or marked_dynamic or marked_weak_dynamic:
# NB: We could assert static_shapes is False here, but it
# seems better to allow the user to override policy in this
# case
dynamic = DimDynamic.DYNAMIC
elif static_shapes or config.assume_static_by_default or marked_static:
dynamic = DimDynamic.STATIC
else:
dynamic = DimDynamic.DUCK

dynamic_dims.append(dynamic)

tx.output.frame_state[name] = frame_state_entry
tx.output.frame_state[name] = frame_state_entry

return dynamic_dims, constraint_dims
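The per-dimension precedence above (explicit mark_dynamic/mark_static beats automatic_dynamic, which beats assume_static_by_default) is driven from user code roughly as follows. This is a sketch using the public decorators that set the `_dynamo_dynamic_indices` / `_dynamo_static_indices` attributes checked above, assuming the API names of this era:

```python
import torch
import torch._dynamo as dynamo

x = torch.randn(8, 16)
dynamo.mark_dynamic(x, 0)  # dim 0 -> DimDynamic.DYNAMIC (unless allow_ignore_mark_dynamic)
dynamo.mark_static(x, 1)   # dim 1 -> DimDynamic.STATIC, wins over automatic_dynamic

@torch.compile
def f(t):
    return t.sum(dim=1)

f(x)
```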

8 changes: 7 additions & 1 deletion torch/_inductor/codegen/triton.py
@@ -2133,7 +2133,13 @@ def current_reduction_nodes(nodes):
kernel.set_last_usage(current_reduction_nodes(node_schedule[i:]))
else:
# TODO - mostly works but needs a couple fixes
if not dynamo_config.dynamic_shapes:
# Problem looks like free variables NYI: s0
# We need to detect if the proposed ranges would have
# symbols and bail out on this optimization if so
if (
not dynamo_config.dynamic_shapes
and dynamo_config.assume_static_by_default
):
# TODO - use split ranges ?
indexing_dtype_strength_reduction(node._body)
index_vars = kernel.split_and_set_ranges(node.get_ranges())
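The guard above skips indexing dtype strength reduction when symbols may appear in the proposed ranges. Conceptually, with a hypothetical helper that is not Inductor's actual implementation, the optimization is only sound when the bound is a known integer:

```python
def can_use_int32_indexing(numel) -> bool:
    # Strength reduction rewrites int64 index arithmetic to int32, which is only
    # valid when the largest flat index provably fits; a free symbol such as s0
    # leaves the bound unknown, so the optimization must bail out.
    return isinstance(numel, int) and numel < 2**31
```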
14 changes: 10 additions & 4 deletions torch/fx/experimental/symbolic_shapes.py
@@ -2043,6 +2043,14 @@ def create_symbolic_sizes_strides_storage_offset(
dynamic_dims.append(r)
dynamic_dims = [DimDynamic.DUCK] * dim

# TODO: make this configurable from outside policy; we made a policy
# decision here where if all sizes are static, we are going to
# specialize all of the inner strides/offset too. We don't have to
# do this, and arguably we should ALWAYS allow for dynamic offset,
# this is cheap.
# TODO: This should be DYNAMIC, using DUCK for BC
dynamic_strides_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_dims) else DimDynamic.DUCK

assert len(dynamic_dims) == dim
assert len(constraint_dims) == dim

@@ -2078,8 +2086,7 @@ def create_symbolic_sizes_strides_storage_offset(
stride[i] = self.create_symbol(
val,
TensorPropertySource(source, TensorProperty.STRIDE, i),
# TODO: This should be DYNAMIC, using DUCK for BC
dynamic_dim=DimDynamic.DUCK,
dynamic_dim=dynamic_strides_offset,
constraint_dim=None,
)
assert all(x is not None for x in stride)
@@ -2094,8 +2101,7 @@ def create_symbolic_sizes_strides_storage_offset(
sym_storage_offset = self.create_symintnode(self.create_symbol(
ex.storage_offset(),
TensorPropertySource(source, TensorProperty.STORAGE_OFFSET),
# TODO: This should be DYNAMIC, using DUCK for BC
dynamic_dim=DimDynamic.DUCK,
dynamic_dim=dynamic_strides_offset,
constraint_dim=None,
), hint=ex.storage_offset())
return sym_sizes, sym_stride, sym_storage_offset
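Restating the policy added above as a standalone sketch, using the same `DimDynamic` names as the diff but a hypothetical helper name:

```python
from torch.fx.experimental.symbolic_shapes import DimDynamic

def strides_offset_policy(dynamic_dims):
    # If every size dimension is static, specialize strides and storage offset
    # too; otherwise fall back to duck sizing (the TODO notes this should
    # eventually be DYNAMIC).
    if all(d == DimDynamic.STATIC for d in dynamic_dims):
        return DimDynamic.STATIC
    return DimDynamic.DUCK
```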
4 changes: 3 additions & 1 deletion torch/fx/passes/fake_tensor_prop.py
@@ -36,7 +36,9 @@ def extract_val(obj):
if isinstance(obj, FakeTensor):
return snapshot_fake(obj)
elif isinstance(obj, torch.Tensor):
return snapshot_fake(self._mode.from_tensor(obj))
# TODO: How is it possible that we get a non fake tensor? We
# should be running under the mode...
return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True))
elif isinstance(obj, py_sym_types):
return obj
else:
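A usage sketch for `FakeTensorProp` (the traced function here is a toy example): it runs the graph under the fake mode and records a fake value in each node's `meta["val"]`:

```python
import torch
from torch.fx import symbolic_trace
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp

def f(x):
    return x.relu() + 1

gm = symbolic_trace(f)
mode = FakeTensorMode()
FakeTensorProp(gm, mode).propagate(torch.randn(4))
for node in gm.graph.nodes:
    print(node.name, node.meta.get("val"))
```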
2 changes: 1 addition & 1 deletion torch/fx/passes/shape_prop.py
@@ -186,7 +186,7 @@ def propagate(self, *args):
Any: The value returned from executing the Module
"""
if self.fake_mode is not None:
fake_args = [self.fake_mode.from_tensor(t) for t in args]
fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args]
else:
fake_args = args
return super().run(*fake_args)
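A usage sketch of the fix above (the traced function is a toy example, and the `fake_mode` keyword is assumed to match ShapeProp's constructor of this era): non-tensor inputs such as ints now pass through untouched instead of being handed to `from_tensor`:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp
from torch._subclasses.fake_tensor import FakeTensorMode

def g(x, n):
    return x * n

gm = symbolic_trace(g)
# The second argument is a plain int; before this fix it would also have been
# passed to from_tensor.
ShapeProp(gm, fake_mode=FakeTensorMode()).propagate(torch.randn(3), 2)
for node in gm.graph.nodes:
    print(node.name, node.meta.get("tensor_meta"))
```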
