
Stablehlo compiler #338 (Draft)

wants to merge 7 commits into base: main

Conversation

Contributor
@ddilbazTT commented Feb 21, 2025

Ticket

#203

Problem description

Currently, we only support compilation and execution of torch graphs. We do not support compilation of stablehlo graphs, because compilation would fail at the first unsupported stablehlo op. The solution is to create a module for each op; this is done here by adding stablehlo support.
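The op-by-op idea can be sketched in plain Python (illustrative only; the function names here are hypothetical, not the PR's actual API): each op becomes its own compile unit, so one unsupported op is reported individually instead of aborting the whole graph.

```python
# Hypothetical sketch: wrap each op of a graph in its own compile unit so
# an unsupported op fails in isolation rather than stopping the whole run.

def split_op_by_op(ops):
    """Turn a list of (op_name, attrs) pairs into one unit per op."""
    return [{"name": name, "attrs": attrs} for name, attrs in ops]

def compile_unit(unit, supported_ops):
    """Compile a single-op unit, recording unsupported ops instead of raising."""
    status = "compiled" if unit["name"] in supported_ops else "unsupported"
    return {"name": unit["name"], "status": status}

def compile_graph(ops, supported_ops):
    """Compile every op independently and collect per-op results."""
    return [compile_unit(u, supported_ops) for u in split_op_by_op(ops)]
```

With this shape, a graph containing one supported and one unsupported op yields one "compiled" and one "unsupported" result instead of a single hard failure.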

What's changed

| Input \ Compile Depth | Compile Op by Op | Execute Op by Op | Execute | Compile Stablehlo Op by Op | Other |
| --- | --- | --- | --- | --- | --- |
| Stablehlo String | `_shlo_backend` | `_shlo_backend` | `_base_backend` | Overwrite to `CompileDepth.COMPILE_OP_BY_OP` + `_shlo_backend` | No support |
| Torch Graph | `_torch_backend` | `_torch_backend` | `_base_backend` | `torch_to_shlo` + `_shlo_backend` | `_torch_backend` |

  • `_base_backend` has been renamed to `_torch_backend`
  • the new `_base_backend` handles execution; the other compile depths use `_torch_backend` or `_shlo_backend`
  • divided `tt_torch/dynamo/backend.py` into `tt_torch/dynamo/backend.py`, `tt_torch/dynamo/shlo_backend.py`, and `tt_torch/dynamo/torch_backend.py`
  • the old `tt_torch/dynamo/backend.py` contents mostly moved to `tt_torch/dynamo/torch_backend.py` and `tt_torch/dynamo/executor.py`
  • `tt_torch/dynamo/executor.py` provides the base executor class used by both the stablehlo and torch backends
  • `env/activate` edited to add support for the new dependencies
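The dispatch described by the table above can be sketched as follows (names and structure are illustrative, taken from the PR description rather than the exact implementation):

```python
from enum import Enum

class CompileDepth(Enum):
    COMPILE_OP_BY_OP = 1
    EXECUTE_OP_BY_OP = 2
    EXECUTE = 3
    COMPILE_STABLEHLO_OP_BY_OP = 4

def choose_backend(input_kind, depth):
    """Illustrative dispatch mirroring the table; input_kind is
    'shlo_string' or 'torch_graph'."""
    if input_kind == "shlo_string":
        if depth in (CompileDepth.COMPILE_OP_BY_OP, CompileDepth.EXECUTE_OP_BY_OP):
            return "_shlo_backend"
        if depth == CompileDepth.EXECUTE:
            return "_base_backend"
        if depth == CompileDepth.COMPILE_STABLEHLO_OP_BY_OP:
            # Depth is overwritten to COMPILE_OP_BY_OP, then handled by _shlo_backend.
            return "_shlo_backend"
        raise ValueError("stablehlo input: compile depth not supported")
    if input_kind == "torch_graph":
        if depth == CompileDepth.EXECUTE:
            return "_base_backend"
        if depth == CompileDepth.COMPILE_STABLEHLO_OP_BY_OP:
            return "torch_to_shlo + _shlo_backend"
        # COMPILE_OP_BY_OP, EXECUTE_OP_BY_OP, and everything else.
        return "_torch_backend"
    raise ValueError("unknown input kind")
```

This is only a model of the table; the real `backend()` entry point builds and returns executor objects rather than strings.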

Testing strategy:

  • Please see the commit history for how I'm testing stablehlo input. I will attach the output log when finished.

mmanzoorTT and others added 2 commits February 21, 2025 01:22

Add COMPILE_STABLEHLO_OP_BY_OP CompileDepth. Allow compilation/execution starting from stablehlo.
TT-Torch Tests: 435 ran, 428 passed, 7 skipped, 0 failed

@codecov-commenter commented Feb 21, 2025

❌ 4 Tests Failed:

| Tests completed | Failed | Passed | Skipped |
| --- | --- | --- | --- |
| 428 | 4 | 424 | 7 |

View the top 3 failed test(s) by shortest run time
tests.torch.test_basic::test_multiple_users
Stack Traces | 0.017s run time
self = <torch._dynamo.output_graph.OutputGraph object at 0x7f29a4554a90>
gm = GraphModule()

    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
        assert self.compiler_fn is not None
        tot = 0
        placeholders = []
        for node in gm.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                tot += 1
            if node.op == "placeholder":
                placeholders.append(node)
        increment_op_count(tot)
        for pl in placeholders:
            arg = pl.meta["grapharg"]
            # TODO: Why isn't this stored in meta :think:
            pl._dynamo_source = arg.source
    
        gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
        gm._source_to_user_stacks = self.source_to_user_stacks  # type: ignore[assignment]
    
        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiler_fn = self.compiler_fn
            if config.verify_correctness:
                compiler_fn = WrapperBackend(compiler_fn)
>           compiled_fn = compiler_fn(gm, self.example_inputs())

.../venv/lib/python3.11.../torch/_dynamo/output_graph.py:1446: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
.../venv/lib/python3.11.../_dynamo/repro/after_dynamo.py:129: in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
.../venv/lib/python3.11.../site-packages/torch/__init__.py:2280: in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

gm_or_shlo = GraphModule()
example_inputs = [tensor([[ 0.0037, -0.2682,  0.4115,  ...,  0.1462,  0.2968,  0.1849],
        [ 0.4956, -0.2257,  0.2401,  ..., -0.16...116, -0.1209,  ...,  0.4374, -0.2197,  0.3502],
        [-0.3913,  0.1134,  0.4019,  ...,  0.4797,  0.0639, -0.2152]])]
options = <tt_torch.tools.utils.CompilerConfig object at 0x7f29a4554110>

    def backend(gm_or_shlo, example_inputs, options=None):
        if options is None:
            options = CompilerConfig()
        if (
            options.compile_depth == CompileDepth.COMPILE_OP_BY_OP
            or options.compile_depth == CompileDepth.EXECUTE_OP_BY_OP
        ):
>           if options.op_by_op_backend == OpByOpBackend.TORCH:
E           NameError: name 'OpByOpBackend' is not defined

tt_torch/dynamo/backend.py:160: NameError

The above exception was the direct cause of the following exception:

    def test_multiple_users():
        class Basic(nn.Module):
            def __init__(self):
                super().__init__()
    
            def forward(self, x):
                x2 = x + x  # add op
                y1 = x2 + x  # user 1 of add op
                y2 = x2 + x  # user 2 of add op
                z = y1 + y2
                return z
    
        cc = CompilerConfig()
        cc.compile_depth = tt_torch.tools.utils.CompileDepth.EXECUTE_OP_BY_OP
>       verify_module(
            Basic(), input_shapes=[(256, 256)], compiler_config=cc, do_assert=False
        )

tests/torch/test_basic.py:486: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
tt_torch/tools/verify.py:252: in verify_module
    _verify_torch_module(
tt_torch/tools/verify.py:146: in _verify_torch_module
    ret = tt_mod(*inputs)
.../venv/lib/python3.11.../nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../venv/lib/python3.11.../nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
.../venv/lib/python3.11.../torch/_dynamo/eval_frame.py:465: in _fn
    return fn(*args, **kwargs)
.../venv/lib/python3.11.../nn/modules/module.py:1736: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
.../venv/lib/python3.11.../nn/modules/module.py:1747: in _call_impl
    return forward_call(*args, **kwargs)
.../venv/lib/python3.11.../torch/_dynamo/convert_frame.py:1269: in __call__
    return self._torchdynamo_orig_callable(
.../venv/lib/python3.11.../torch/_dynamo/convert_frame.py:1064: in __call__
    result = self._inner_convert(
.../venv/lib/python3.11.../torch/_dynamo/convert_frame.py:526: in __call__
    return _compile(
.../venv/lib/python3.11.../torch/_dynamo/convert_frame.py:924: in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
.../venv/lib/python3.11.../torch/_dynamo/convert_frame.py:666: in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
.../venv/lib/python3.11.../site-packages/torch/_utils_internal.py:87: in wrapper_function
    return function(*args, **kwargs)
.../venv/lib/python3.11.../torch/_dynamo/convert_frame.py:699: in _compile_inner
    out_code = transform_code_object(code, transform)
.../venv/lib/python3.11.../torch/_dynamo/bytecode_transformation.py:1322: in transform_code_object
    transformations(instructions, code_options)
.../venv/lib/python3.11.../torch/_dynamo/convert_frame.py:219: in _fn
    return fn(*args, **kwargs)
.../venv/lib/python3.11.../torch/_dynamo/convert_frame.py:634: in transform
    tracer.run()
.../venv/lib/python3.11.../torch/_dynamo/symbolic_convert.py:2796: in run
    super().run()
.../venv/lib/python3.11.../torch/_dynamo/symbolic_convert.py:983: in run
    while self.step():
.../venv/lib/python3.11.../torch/_dynamo/symbolic_convert.py:895: in step
    self.dispatch_table[inst.opcode](self, inst)
.../venv/lib/python3.11.../torch/_dynamo/symbolic_convert.py:2987: in RETURN_VALUE
    self._return(inst)
.../venv/lib/python3.11.../torch/_dynamo/symbolic_convert.py:2972: in _return
    self.output.compile_subgraph(
.../venv/lib/python3.11.../torch/_dynamo/output_graph.py:1117: in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
.../venv/lib/python3.11.../torch/_dynamo/output_graph.py:1369: in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
.../venv/lib/python3.11.../torch/_dynamo/output_graph.py:1416: in call_user_compiler
    return self._call_user_compiler(gm)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <torch._dynamo.output_graph.OutputGraph object at 0x7f29a4554a90>
gm = GraphModule()

    def _call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
        assert self.compiler_fn is not None
        tot = 0
        placeholders = []
        for node in gm.graph.nodes:
            if node.op in ("call_function", "call_method", "call_module"):
                tot += 1
            if node.op == "placeholder":
                placeholders.append(node)
        increment_op_count(tot)
        for pl in placeholders:
            arg = pl.meta["grapharg"]
            # TODO: Why isn't this stored in meta :think:
            pl._dynamo_source = arg.source
    
        gm._param_name_to_source = self.param_name_to_source  # type: ignore[assignment]
        gm._source_to_user_stacks = self.source_to_user_stacks  # type: ignore[assignment]
    
        try:
            name = (
                self.compiler_fn.__name__
                if hasattr(self.compiler_fn, "__name__")
                else ""
            )
            _step_logger()(logging.INFO, f"calling compiler function {name}")
            compiler_fn = self.compiler_fn
            if config.verify_correctness:
                compiler_fn = WrapperBackend(compiler_fn)
            compiled_fn = compiler_fn(gm, self.example_inputs())
            _step_logger()(logging.INFO, f"done compiler function {name}")
            assert callable(compiled_fn), "compiler_fn did not return callable"
        except exceptions_allowed_to_be_fallback as e:
            if self.has_user_defined_allowed_in_graph:
                raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
                    e.__traceback__
                ) from None
            msg = (
                "Backend compiler failed with a fake tensor exception at \n"
                f"{self.root_tx.format_frame_summary()}"
                "Adding a graph break."
            )
            unimplemented_with_warning(e, self.root_tx.f_code, msg)
        except SkipFrame as e:
            # The backend compiler has requested that we skip the frame, instead of
            # aborting execution.
            raise e
        except Exception as e:
>           raise BackendCompilerFailed(self.compiler_fn, e) from e
E           torch._dynamo.exc.BackendCompilerFailed: backend='backend' raised:
E           NameError: name 'OpByOpBackend' is not defined
E           
E           Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E           
E           
E           You can suppress this exception and fall back to eager by setting:
E               import torch._dynamo
E               torch._dynamo.config.suppress_errors = True

.../venv/lib/python3.11.../torch/_dynamo/output_graph.py:1465: BackendCompilerFailed
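The failures share a single root cause: `tt_torch/dynamo/backend.py:160` references `OpByOpBackend` without it being in scope. A self-contained sketch of the guard, with `OpByOpBackend` stubbed locally (in the repo it would presumably be imported from wherever it is defined, e.g. alongside `CompilerConfig` and `CompileDepth`, though that location is an assumption):

```python
from enum import Enum

# Stand-in for the missing import; the real enum's home module is assumed,
# not confirmed by the trace.
class OpByOpBackend(Enum):
    TORCH = 1
    STABLEHLO = 2

def pick_op_by_op_backend(op_by_op_backend):
    # Mirrors the failing guard: once OpByOpBackend is in scope, the
    # NameError disappears and dispatch proceeds normally.
    if op_by_op_backend == OpByOpBackend.TORCH:
        return "_torch_backend"
    return "_shlo_backend"
```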
tests.torch.test_maxpool2d::test_maxpool2d
Stack Traces | 0.019s run time
gm_or_shlo = GraphModule()
example_inputs = [tensor([[[[-1.1250, -1.1562, -0.2500,  ..., -0.0334, -1.0625, -0.1143],
          [-0.3438,  1.5703,  0.1914,  ..., -...895,  0.2061],
          [ 0.2207, -0.9023,  0.3984,  ...,  1.6953, -0.5156,  0.9375]]]],
       dtype=torch.bfloat16)]
options = <tt_torch.tools.utils.CompilerConfig object at 0x7f298075bd10>

    def backend(gm_or_shlo, example_inputs, options=None):
        if options is None:
            options = CompilerConfig()
        if (
            options.compile_depth == CompileDepth.COMPILE_OP_BY_OP
            or options.compile_depth == CompileDepth.EXECUTE_OP_BY_OP
        ):
>           if options.op_by_op_backend == OpByOpBackend.TORCH:
E           NameError: name 'OpByOpBackend' is not defined

tt_torch/dynamo/backend.py:160: NameError

The above exception was the direct cause of the following exception:

    def test_maxpool2d():
        class Basic(nn.Module):
            def __init__(self):
                super().__init__()
    
            def forward(self, x):
                return torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)
    
        cc = CompilerConfig()
        cc.compile_depth = CompileDepth.EXECUTE_OP_BY_OP
>       verify_module(
            Basic(),
            inputs=[torch.randn(1, 1, 224, 224).to(torch.bfloat16)],
            compiler_config=cc,
        )

tests/torch/test_maxpool2d.py:23: 
E           torch._dynamo.exc.BackendCompilerFailed: backend='backend' raised:
E           NameError: name 'OpByOpBackend' is not defined
E           
E           Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E           
E           
E           You can suppress this exception and fall back to eager by setting:
E               import torch._dynamo
E               torch._dynamo.config.suppress_errors = True

.../venv/lib/python3.11.../torch/_dynamo/output_graph.py:1465: BackendCompilerFailed
tests.torch.test_basic::test_unused_output
Stack Traces | 0.02s run time
gm_or_shlo = GraphModule()
example_inputs = [tensor([[ 0.0037, -0.2682,  0.4115,  ...,  0.1462,  0.2968,  0.1849],
        [ 0.4956, -0.2257,  0.2401,  ..., -0.16...116, -0.1209,  ...,  0.4374, -0.2197,  0.3502],
        [-0.3913,  0.1134,  0.4019,  ...,  0.4797,  0.0639, -0.2152]])]
options = <tt_torch.tools.utils.CompilerConfig object at 0x7f2978532290>

    def backend(gm_or_shlo, example_inputs, options=None):
        if options is None:
            options = CompilerConfig()
        if (
            options.compile_depth == CompileDepth.COMPILE_OP_BY_OP
            or options.compile_depth == CompileDepth.EXECUTE_OP_BY_OP
        ):
>           if options.op_by_op_backend == OpByOpBackend.TORCH:
E           NameError: name 'OpByOpBackend' is not defined

tt_torch/dynamo/backend.py:160: NameError

The above exception was the direct cause of the following exception:

    def test_unused_output():
        class Basic_var_only(nn.Module):
            def __init__(self):
                super().__init__()
    
            def forward(self, x):
                var, mean = torch.var_mean(x)
                return var
    
        class Basic_mean_only(nn.Module):
            def __init__(self):
                super().__init__()
    
            def forward(self, x):
                var, mean = torch.var_mean(x)
                return mean
    
        for module in [Basic_var_only, Basic_mean_only]:
            cc = CompilerConfig()
            cc.compile_depth = tt_torch.tools.utils.CompileDepth.COMPILE_OP_BY_OP
>           verify_module(module(), input_shapes=[(256, 256)], compiler_config=cc)

tests/torch/test_basic.py:469: 
E           torch._dynamo.exc.BackendCompilerFailed: backend='backend' raised:
E           NameError: name 'OpByOpBackend' is not defined
E           
E           Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
E           
E           
E           You can suppress this exception and fall back to eager by setting:
E               import torch._dynamo
E               torch._dynamo.config.suppress_errors = True

.../venv/lib/python3.11.../torch/_dynamo/output_graph.py:1465: BackendCompilerFailed


env/activate Outdated
fi
pip install $TT_TORCH_HOME/dist/torchvision*.whl
pip install --pre torch-mlir torchvision
Contributor

what's torchvision for here?

@@ -275,7 +276,7 @@ def post_init(self):
else:
torch._dynamo.config.inline_inbuilt_nn_modules = True

def save_unique_ops(self):
def save_unique_ops(self, mode=None):
Contributor

perhaps default to mode="torch" as opposed to None.

Contributor Author

I could do that. Just want to confirm: by default, do we want the files to be named {self.results_path}{pytest_test}_torch_unique_ops.json or {self.results_path}{pytest_test}_unique_ops.json? Right now it works like this:
  • default --> {self.results_path}{pytest_test}_unique_ops.json
  • torch --> {self.results_path}{pytest_test}_torch_unique_ops.json
  • stablehlo --> {self.results_path}{pytest_test}_stablehlo_unique_ops.json
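For reference, the naming scheme discussed above amounts to the following (a hypothetical sketch, with `mode="torch"` as the reviewer's suggested default; the function name is illustrative):

```python
def unique_ops_filename(results_path, pytest_test, mode="torch"):
    """Build the unique-ops dump filename. mode=None reproduces the
    current unsuffixed default; "torch"/"stablehlo" add a mode suffix."""
    if mode is None:
        return f"{results_path}{pytest_test}_unique_ops.json"
    return f"{results_path}{pytest_test}_{mode}_unique_ops.json"
```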

executor.add_gm(gm, graph_constants)
return executor


Contributor

This section is duplicated. Can you move it into a helper function?

Contributor Author

I think my latest commit should address this. By complicated, I assumed you meant the isinstance checks. I agree. I now check and assign parsed_module when initializing the Executor object.

self.gm = gm
self.graph_constants = tuple(graph_constants)

def gm_op_by_op(self, *inputs):
Contributor

maybe shlo_op_by_op

Contributor Author

So this is actually the same function from torch_backend, but stripped down. It walks a torch graph op by op and is only called during COMPILE_STABLEHLO_OP_BY_OP, since the input is implied to be a torch graph. Doing so, we create two JSON dumps in one run: one for stablehlo and one for torch.

# No conversion required.
new_inputs = new_inputs + ((input),)
inputs = new_inputs
if self.compiler_config.compile_depth == CompileDepth.EXECUTE:
Contributor

isn't this handled by the base backend?

Contributor Author

This is `_base_backend`, which returns the executor; it parallels the torch executor.

def _base_backend(gm_or_shlo, example_inputs, compiler_config):
    # Called during EXECUTE
    # input is a torch graph
    if isinstance(gm_or_shlo, torch.fx.GraphModule):
        shlo, executor, gm, graph_constants = torch_to_shlo(
            gm_or_shlo, example_inputs, compiler_config
        )
    # input is a stablehlo string module
    elif isinstance(gm_or_shlo, str):
        shlo = parse_module_from_str(gm_or_shlo)
        executor = StablehloExecutor(
            parsed_module=shlo, compiler_config=compiler_config
        )
    else:
        assert False, "Compiler input not valid"

    binary = shlo_to_flatbuffer(shlo, compiler_config)
    executor.set_binary(binary)
    return executor

TT-Torch Tests: 435 ran, 424 passed, 7 skipped, 4 failed

Failed:
  • test_basic.test_multiple_ops
  • test_basic.test_unused_output
  • test_basic.test_multiple_users
  • test_maxpool2d

TT-Torch Tests: 9 ran, 0 passed, 0 skipped, 9 failed

Failed:
  • tests.torch.test_basic
  • tests.torch.test_compare
  • tests.torch.test_constant_fold
  • tests.torch.test_conv2d
  • tests.torch.test_interpolation
  • tests.torch.test_logical
  • tests.torch.test_maxpool2d
  • tests.torch.test_reduction
  • tests.torch.test_softmax
