Skip to content

Commit

Permalink
[inductor] Minor fixes to various tests before enabling fx graph cach…
Browse files Browse the repository at this point in the history
…ing in OSS by default (pytorch#125258)

Summary: Discovered breakages by enabling codecache by default and doing a CI run. I'll commit these fixes first and eventually enabling caching by default will (hopefully) be a one-liner.

Pull Request resolved: pytorch#125258
Approved by: https://github.com/eellison
  • Loading branch information
masnesral authored and petrex committed May 3, 2024
1 parent 713c44d commit 4cdd82b
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 6 deletions.
6 changes: 3 additions & 3 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@

import numpy as np
import torch

import torch._dynamo.test_case
import torch._dynamo.testing

import torch._inductor.test_case
import torch.onnx.operators

import torch.utils._pytree as pytree
Expand Down Expand Up @@ -151,7 +151,7 @@ def __getattr__(self, key):
return self.__dict__[f"pfx_{key}"]


class MiscTests(torch._dynamo.test_case.TestCase):
class MiscTests(torch._inductor.test_case.TestCase):
def test_get_cache_entry(self):
def f(x):
return x + 1
Expand Down
5 changes: 3 additions & 2 deletions test/dynamo/test_structured_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import torch._logging.structured
import torch.distributed as dist

from torch._inductor.test_case import TestCase

from torch._logging._internal import TorchLogsFormatter
from torch.nn.parallel import DistributedDataParallel as DDP

from torch.testing._internal.common_utils import find_free_port, TestCase
from torch.testing._internal.common_utils import find_free_port
from torch.testing._internal.inductor_utils import HAS_CUDA

requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
Expand Down
12 changes: 11 additions & 1 deletion torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,14 @@ def _reduce_symint(s):
return (_ident, (str(s),))


def _reduce_unsupported(s):
"""
See FxGraphCachePickler. Custom reducer to handle any objects that we don't
support and therefore raise to bypass caching.
"""
raise BypassFxGraphCache


class FxGraphCachePickler(pickle.Pickler):
"""
Custom pickler to customize the pickling of some objects (Tensors), only for the
Expand All @@ -494,6 +502,9 @@ class FxGraphCachePickler(pickle.Pickler):
dispatch_table[FakeTensor] = _reduce_fake_tensor
dispatch_table[torch.Tensor] = _reduce_tensor
dispatch_table[torch.SymInt] = _reduce_symint
dispatch_table[
torch.fx.experimental._backward_state.BackwardState
] = _reduce_unsupported

@classmethod
def dumps(cls, obj) -> bytes:
Expand Down Expand Up @@ -893,7 +904,6 @@ def load(
Load a compiled graph from the cache. If a cached entry does not exist,
compile the graph and save it to the cache.
"""

compiled_graph = None
try:
FxGraphCache._check_can_cache(gm)
Expand Down
1 change: 1 addition & 0 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2385,6 +2385,7 @@ def check_equal(self, other: "ShapeEnv") -> None:
"source_name_to_debug_name",
"_prev_cache_key",
"_version_counter",
"dim_constraints",
)

# Mapping of the value of each to-be-compared field into the values that
Expand Down
2 changes: 2 additions & 0 deletions torch/testing/_internal/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch._logging
import torch._logging._internal
from torch._dynamo.utils import LazyString
from torch._inductor import config as inductor_config
import logging
import io

Expand Down Expand Up @@ -74,6 +75,7 @@ def append_setting(name, level):
# that the logs are setup correctly and capturing the correct records.
def make_logging_test(**kwargs):
def wrapper(fn):
@inductor_config.patch({"fx_graph_cache": False})
def test_fn(self):

torch._dynamo.reset()
Expand Down

0 comments on commit 4cdd82b

Please sign in to comment.