Enable automatic_dynamic_shapes by default (pytorch#103623)
Some notes:

* I now manually prevent the `_generate` jobs from running with cudagraphs, as it is unrealistic to expect to cudagraph autoregressive generation up to the max sequence length; that would imply compiling the entire unrolled sequence generation. Concretely, cm3leon_generate was timing out after this change, likely due to the compile-time slowdown of dynamic shapes on top of accidentally unrolling all the loops.
* A few `torch._dynamo.reset()` calls are tactically inserted to force recompiles on tests that expected them.
* `expectedFailureAutomaticDynamic` is flipped into patching `automatic_dynamic_shapes=False` (a sketch of the pattern follows below).
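For illustration only, a minimal sketch of the last two points, using a hypothetical test class and test name (nothing below is taken verbatim from the diff):

import torch
import torch._dynamo
import torch._dynamo.test_case


class ExampleTests(torch._dynamo.test_case.TestCase):
    # Previously such a test would carry @expectedFailureAutomaticDynamic; now it simply
    # opts back into the old default via a config patch.
    @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
    def test_example(self):
        torch._dynamo.reset()  # clear cached graphs so the patched config forces a fresh compile
        fn = torch.compile(lambda x: x + 1)
        self.assertEqual(fn(torch.ones(3)), torch.ones(3) + 1)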

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: pytorch#103623
Approved by: https://github.com/voznesenskym
ezyang authored and pytorchmergebot committed Jul 5, 2023
1 parent 2abbed4 commit 2385dad
Showing 14 changed files with 47 additions and 56 deletions.
@@ -9,7 +9,7 @@ basic_gnn_edgecnn,pass,0
 basic_gnn_gcn,pass,6
 basic_gnn_gin,pass,0
 basic_gnn_sage,pass,0
-cm3leon_generate,pass,67
+cm3leon_generate,pass,6
 dcgan,pass,0
 dlrm,pass,0
 doctr_det_predictor,pass,4
@@ -19,7 +19,7 @@ hf_Bert,pass,7
 hf_Bert_large,pass,7
 hf_DistilBert,pass,7
 hf_GPT2,pass,7
-hf_Reformer,pass,64
+hf_Reformer,pass,44
 hf_T5_large,pass_due_to_skip,0
 lennard_jones,pass,8
 maml_omniglot,pass,8
@@ -48,5 +48,5 @@ timm_vision_transformer_large,pass_due_to_skip,0
 timm_vovnet,pass,8
 tts_angular,pass,10
 vgg16,pass,8
-vision_maskrcnn,fail_accuracy,56
-yolov3,pass,11
+vision_maskrcnn,fail_accuracy,42
+yolov3,pass,10
2 changes: 1 addition & 1 deletion benchmarks/dynamo/ci_expected_accuracy/update_expected.py
@@ -59,7 +59,7 @@ def get_artifacts_urls(results, suites):
     for r in results:
         if "inductor" == r["workflowName"] and "test" in r["jobName"]:
             config_str, test_str = parse_job_name(r["jobName"])
-            suite, shard_id, num_shards, machine = parse_test_str(test_str)
+            suite, shard_id, num_shards, machine, *_ = parse_test_str(test_str)
             workflowId = r["workflowId"]
             id = r["id"]
             runAttempt = r["runAttempt"]
6 changes: 6 additions & 0 deletions benchmarks/dynamo/common.py
@@ -2848,6 +2848,12 @@ def run(runner, args, original_dir=None):
         torch.use_deterministic_algorithms(True)
         if args.only in {"hf_T5_generate"}:
             torch._dynamo.config.automatic_dynamic_shapes = True
+        if args.only is not None and args.only.endswith("_generate"):
+            log.warning(
+                "Disabling cudagraphs for autoregressive generation (reenable if selective cudagraphs implemented)"
+            )
+            args.disable_cudagraphs = True
+            torch._inductor.config.triton.cudagraphs = False
         os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
         torch.backends.cudnn.deterministic = True
         torch.backends.cudnn.allow_tf32 = False
9 changes: 2 additions & 7 deletions test/dynamo/test_aot_autograd.py
@@ -5,12 +5,7 @@

 import torch._dynamo
 import torch._dynamo.test_case
-from torch._dynamo.testing import (
-    CompileCounter,
-    expectedFailureAutomaticDynamic,
-    expectedFailureDynamic,
-    rand_strided,
-)
+from torch._dynamo.testing import CompileCounter, expectedFailureDynamic, rand_strided
 from torch.testing._internal.common_utils import compare_equal_outs_and_grads


@@ -656,7 +651,7 @@ def guard_fail_fn(failure):
         self.assertExpectedInline(failure_reason, """L['c'] is L['d']""")

     @expectedFailureDynamic  # https://github.com/pytorch/pytorch/issues/103539
-    @expectedFailureAutomaticDynamic  # as above
+    @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
     @patch("torch._functorch.config.debug_assert", True)
     def test_multiple_aot_autograd_calls_dupe_args(self):
         # this is just dealing with the fact that
14 changes: 3 additions & 11 deletions test/dynamo/test_dynamic_shapes.py
@@ -29,25 +29,18 @@
 test_classes = {}


-def make_dynamic_cls(cls, automatic_dynamic_shapes=False):
+def make_dynamic_cls(cls):
     suffix = "_dynamic_shapes"
-    if automatic_dynamic_shapes:
-        suffix = "_automatic_dynamic_shapes"

     cls_prefix = "DynamicShapes"
-    if automatic_dynamic_shapes:
-        cls_prefix = "AutomaticDynamicShapes"

     test_class = make_test_cls_with_patches(
         cls,
         cls_prefix,
         suffix,
-        (config, "assume_static_by_default", automatic_dynamic_shapes),
-        (config, "automatic_dynamic_shapes", automatic_dynamic_shapes),
+        (config, "assume_static_by_default", False),
         (config, "specialize_int", False),
-        xfail_prop="_expected_failure_automatic_dynamic"
-        if automatic_dynamic_shapes
-        else "_expected_failure_dynamic",
+        xfail_prop="_expected_failure_dynamic",
     )

     test_classes[test_class.__name__] = test_class
@@ -69,7 +62,6 @@ def make_dynamic_cls(cls, automatic_dynamic_shapes=False):
 ]
 for test in tests:
     make_dynamic_cls(test)
-    make_dynamic_cls(test, automatic_dynamic_shapes=True)

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
14 changes: 7 additions & 7 deletions test/dynamo/test_higher_order_ops.py
@@ -40,9 +40,7 @@ def normalize_gm(gm_str):

 def check_dynamic_shape_capture():
     # This also mirrors config from `test/dynamo/test_dynamic_shapes.py:make_dynamic_cls`
-    if config.assume_static_by_default and config.automatic_dynamic_shapes:
-        return True
-    if not config.assume_static_by_default and not config.automatic_dynamic_shapes:
+    if not config.assume_static_by_default:
         return True
     return False

@@ -1464,7 +1462,10 @@ def fn(x):
         if check_dynamic_shape_capture():
             return

-        expected = """\
+        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
+        self.assertExpectedInline(
+            actual,
+            """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
         l_x_ = L_x_
@@ -1486,9 +1487,8 @@ def forward(self, l_x_, l_y_):
         _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
         return sum_1
-"""
-        actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
-        self.assertExpectedInline(actual, expected)
+""",
+        )

     def test_grad_closure_scalar(self):
         counters.clear()
24 changes: 11 additions & 13 deletions test/dynamo/test_misc.py
@@ -30,7 +30,6 @@
 from torch._dynamo.source import GetItemSource, LocalSource
 from torch._dynamo.testing import (
     CompileCounter,
-    expectedFailureAutomaticDynamic,
     expectedFailureDynamic,
     requires_numpy_pytorch_interop,
     same,
@@ -5334,7 +5333,8 @@ def fn(a, b):
         fn(torch.rand(2, 3), torch.rand(2, 3))
         fn(torch.rand(2, 3), (1, 2, 3))

-    @expectedFailureAutomaticDynamic
+    @expectedFailureDynamic
+    @torch._dynamo.config.patch(automatic_dynamic_shapes=False)
     def test_compile_profiler(self):
         class Model(torch.nn.Module):
             def forward(self, input):
@@ -5365,18 +5365,16 @@ def forward(self, input):
         else:
             base_checker().check("No recompilation detected.").run(prof.report())

-        # Ensure correct guard fail message is selected to show to user
-        if torch._dynamo.config.assume_static_by_default:
-            new_shape_input = torch.rand((4, 3, 4))
-            _ = compiled(new_shape_input)
+        new_shape_input = torch.rand((4, 3, 4))
+        _ = compiled(new_shape_input)

-            base_checker().check("Recompile Reasons").check("'forward'").check(
-                "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3"
-            ).check(
-                "tensor 'L['input']' size mismatch at index 0. expected 3, actual 4"
-            ).run(
-                prof.report()
-            )
+        base_checker().check("Recompile Reasons").check("'forward'").check(
+            "tensor 'L['input']' size mismatch at index 0. expected 2, actual 3"
+        ).check(
+            "tensor 'L['input']' size mismatch at index 0. expected 3, actual 4"
+        ).run(
+            prof.report()
+        )

     def test_guards_strip_function_call(self):
         from torch._dynamo.guards import strip_function_call
3 changes: 1 addition & 2 deletions test/inductor/test_mkldnn_pattern_matcher.py
@@ -3,7 +3,6 @@

 import torch
 from torch._dynamo.test_case import run_tests, TestCase
-from torch._dynamo.testing import expectedFailureDynamicWrapper
 from torch._dynamo.utils import counters
 from torch._inductor.utils import run_and_get_code
 from torch.nn import functional as F
@@ -259,7 +258,6 @@ def forward(self, x):
             )
             self._test_common(mod, (v,), 1, match_nodes)

-    @expectedFailureDynamicWrapper
     def test_linear_binary(self):
         class M(torch.nn.Module):
             def __init__(self, binary_fn, in_channels, out_channels, bias, **kwargs):
@@ -279,6 +277,7 @@ def forward(self, x, y):
         out_feature = 30
         if torch.ops.mkldnn._is_mkldnn_bf16_supported():
             for binary_fn, input_shape, bias in options:
+                torch._dynamo.reset()
                 mod = M(binary_fn, input_shape[-1], out_feature, bias).to(dtype).eval()
                 v = torch.randn(input_shape).to(dtype)
                 other = torch.randn(input_shape[:-1] + [out_feature]).to(dtype)
1 change: 1 addition & 0 deletions test/inductor/test_pattern_matcher.py
@@ -76,6 +76,7 @@ def fn(a, b, c):
             (4, torch.randn(16, 16, device="cuda"), torch.randn(16, 16, device="cuda")),
         ]
         for args in args_list:
+            torch._dynamo.reset()
             counters.clear()
             e1, e2 = fn(*args)
             a1, a2 = torch.compile(fn)(*args)
6 changes: 5 additions & 1 deletion torch/_dynamo/config.py
@@ -9,6 +9,10 @@
 from . import external_utils


+def is_fbcode():
+    return not hasattr(torch.version, "git_version")
+
+
 # to configure logging for dynamo, aot, and inductor
 # use the following API in the torch._logging module
 # torch._logging.set_logs(dynamo=<level>, aot=<level>, inductor<level>)
@@ -64,7 +68,7 @@
 # with assume_static_by_default=True.
 # With this flag enabled, we always compile a frame as fully static for the first time, and, if we fail
 # any guards due to wobbles in shape, we recompile with *all* the wobbled shapes as being marked dynamic.
-automatic_dynamic_shapes = False
+automatic_dynamic_shapes = not is_fbcode()

 # Typically, if you mark_dynamic a dimension, we will error if the dimension
 # actually ended up getting specialized. This knob changes the behavior so
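As a rough illustration of the behavior the config comment above describes, here is a minimal sketch (not part of the diff), assuming the new default configuration of a non-fbcode build (automatic_dynamic_shapes=True, assume_static_by_default=True):

import torch
from torch._dynamo.testing import CompileCounter

cnt = CompileCounter()  # counts how many times dynamo compiles the frame

@torch.compile(backend=cnt)
def double(x):
    return x * 2

double(torch.randn(8))   # first call: compiled fully static (frame_count == 1)
double(torch.randn(16))  # static size guard fails: recompiled with the wobbled dim marked dynamic (frame_count == 2)
double(torch.randn(32))  # the dynamic graph is reused, so no further recompile is expected
assert cnt.frame_count == 2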
6 changes: 0 additions & 6 deletions torch/_dynamo/testing.py
@@ -350,12 +350,6 @@ def expectedFailureDynamic(fn):
     return fn


-# Controls tests generated in test/dynamo/test_dynamic_shapes.py
-def expectedFailureAutomaticDynamic(fn):
-    fn._expected_failure_automatic_dynamic = True
-    return fn
-
-
 # Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py
 def expectedFailureCodegenDynamic(fn):
     fn._expected_failure_codegen_dynamic = True
5 changes: 4 additions & 1 deletion torch/_inductor/compile_fx.py
@@ -603,7 +603,10 @@ def remove_unaligned_input_idxs(inputs, static_input_idxs):
     that aren't.
     """
     aligned_static_input_idxs = {
-        idx for idx in static_input_idxs if (inputs[idx].data_ptr() % ALIGNMENT) == 0
+        idx
+        for idx in static_input_idxs
+        if isinstance(inputs[idx], torch.Tensor)
+        and (inputs[idx].data_ptr() % ALIGNMENT) == 0
     }
     if len(aligned_static_input_idxs) != len(static_input_idxs):
         return aligned_static_input_idxs
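A small sketch of why the isinstance guard above helps (my assumption, not stated in the diff: with dynamic shapes enabled, the flattened graph inputs can contain plain ints or SymInts alongside tensors, and only tensors have data_ptr()):

import torch

ALIGNMENT = 16  # stand-in for the module-level constant referenced above
inputs = [torch.randn(4), 8]  # a tensor plus an integer "size" input
static_input_idxs = [0, 1]

aligned_static_input_idxs = {
    idx
    for idx in static_input_idxs
    if isinstance(inputs[idx], torch.Tensor)
    and (inputs[idx].data_ptr() % ALIGNMENT) == 0
}
print(aligned_static_input_idxs)  # the int input at index 1 is filtered out without raising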
5 changes: 2 additions & 3 deletions torch/_inductor/cudagraph_trees.py
@@ -873,7 +873,6 @@ def __init__(
                 self._tensor_metadata(out, ignore_storage_offset=False)
             )
         else:
-            assert out is None
             self.outputs_metadata.append(None)

         self.graph.replay()
@@ -1075,7 +1074,7 @@ def _add_first_outputs(
         self.static_output_tensors = [None for _ in range(len(outputs))]

         for i, o in enumerate(outputs):
-            if o is None:
+            if o is None or not isinstance(o, torch.Tensor):
                 self.output_storage_alias.append(UnaliasedStorage)
                 continue

@@ -1120,7 +1119,7 @@ def _add_first_outputs(

         assert not self.outputs_weakrefs
         for out, static_output_tensor in zip(outputs, self.static_output_tensors):
-            if out is None or static_output_tensor is not None:
+            if not isinstance(out, torch.Tensor) or static_output_tensor is not None:
                 self.outputs_weakrefs.append(None)
                 self.tensor_weakrefs.append(None)
             else:
