feat: Module-Acceleration in Dynamo [5 / x] #1979

Merged · 2 commits · Jun 29, 2023
2 changes: 1 addition & 1 deletion py/torch_tensorrt/__init__.py
@@ -95,7 +95,7 @@ def _find_lib(name, paths):

from torch_tensorrt import fx

-if version.parse(torch.__version__) >= version.parse("2.dev"):
+if version.parse(torch.__version__) >= version.parse("2.1.dev"):
    from torch_tensorrt import dynamo
    from torch_tensorrt.dynamo import backend

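As context for the tightened gate (an illustrative sketch, not part of this diff, using the same packaging.version helper imported above): the previous bound "2.dev" already matched the Torch 2.0.x releases, while "2.1.dev" restricts the dynamo/backend imports to 2.1 pre-releases and newer.

from packaging import version

# Illustrative comparison only; not part of this diff.
for v in ("2.0.1", "2.1.0.dev20230601", "2.1.0"):
    print(v, version.parse(v) >= version.parse("2.1.dev"))
# 2.0.1             -> False (dynamo backend is not imported)
# 2.1.0.dev20230601 -> True
# 2.1.0             -> True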
8 changes: 8 additions & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -8,6 +8,9 @@
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
    get_decompositions,
)
+from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
+    pre_aot_substitutions,
+)
from torch_tensorrt.dynamo.backend.lowering._partition import (
    partition,
    get_submod_inputs,
@@ -41,6 +44,9 @@ def aot_torch_tensorrt_aten_backend(
        settings=settings,
    )

+    # Perform Pre-AOT Lowering for Module-Level Replacement
+    gm = pre_aot_substitutions(gm)
+
    # Invoke AOTAutograd to translate operators to aten
    return aot_module_simplified(
        gm,
@@ -65,6 +71,8 @@ def _pretraced_backend(
        Compiled FX GraphModule
    """
    try:
+        logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
+
        trt_compiled = _compile_module(
            gm,
            sample_inputs,
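A usage note (an illustrative sketch, not something this diff adds): since the module logs through logging.getLogger(__name__), the new pre-/post-AOT graph dumps can be surfaced by raising the log level for the backend loggers, for example:

import logging

# Attach a root handler and enable DEBUG output for the Dynamo backend loggers
# (illustrative; the logger names follow from logging.getLogger(__name__)).
logging.basicConfig()
logging.getLogger("torch_tensorrt.dynamo.backend").setLevel(logging.DEBUG)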
10 changes: 6 additions & 4 deletions py/torch_tensorrt/dynamo/backend/lowering/__init__.py
@@ -1,7 +1,9 @@
-from torch_tensorrt.dynamo.backend.lowering._decompositions import (
+from ._decompositions import (
    get_decompositions,
)
-from torch_tensorrt.dynamo.backend.lowering._partition import (
-    partition,
-    get_submod_inputs,
+from ._pre_aot_lowering import (
+    SUBSTITUTION_REGISTRY,
+    register_substitution,
)
+from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
+from .substitutions import *
12 changes: 10 additions & 2 deletions py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -1,9 +1,10 @@
import logging
-from typing import Dict, List, Optional, Sequence
+from typing import Dict, List, Optional, Sequence, Set

import torch

from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
+from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
from torch.fx.graph_module import GraphModule
from torch.fx.node import _get_qualified_name
@@ -14,6 +15,11 @@

logger = logging.getLogger(__name__)

+DEFAULT_SINGLE_NODE_PARTITIONS: Set[str] = set(
+    _get_qualified_name(to_replace.new_operator)
+    for to_replace in SUBSTITUTION_REGISTRY.values()
+)


class TRTPartitioner(CapabilityBasedPartitioner):
"""Partitioner to split an FX graph into subgraphs based on operator support
@@ -35,7 +41,9 @@ def __init__(
        operator_support: OperatorSupport,
        *,
        non_compute_ops: Optional[Sequence[str]] = None,
-        allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
+        allowed_single_node_partition_ops: Optional[
+            Sequence[str]
+        ] = DEFAULT_SINGLE_NODE_PARTITIONS,
        min_block_size=MIN_BLOCK_SIZE,
    ) -> None:
        super().__init__(
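Rough intent of the new default, illustrated below (an assumption based on this diff, not code from the PR): the qualified names of every registered replacement op are pre-approved as single-node partitions, so a graph that reduces to a lone substituted op can still receive its own TensorRT partition even when it falls below MIN_BLOCK_SIZE.

from torch_tensorrt.dynamo.backend.lowering import (
    DEFAULT_SINGLE_NODE_PARTITIONS,
    SUBSTITUTION_REGISTRY,
)

# Illustrative inspection: one qualified name per registered substitution
# (e.g. the einsum and maxpool1d replacement ops registered via the
# substitutions package).
print(sorted(DEFAULT_SINGLE_NODE_PARTITIONS))
print(list(SUBSTITUTION_REGISTRY.keys()))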
139 changes: 139 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/_pre_aot_lowering.py
@@ -0,0 +1,139 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Type, Union
import torch
import logging


logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class Substitution:
    """Class to store key functionality for module replacement"""

    # torch.ops.___ name for replacement function for module
    new_operator: torch._ops.OpOverload

    # Function taking a containing graph, a node, and optionally a submodule (if replacing a module)
    # and returning a replacement node, with type 'call_function', or raising an Error if
    # incompatibility is detected
    # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
    subgraph_insertion_fn: Callable[
        [torch.fx.GraphModule, torch.fx.Node, Optional[torch.nn.Module]], torch.fx.Node
    ]


# Dictionary mapping module to Substitution instance
SUBSTITUTION_REGISTRY: Dict[
    Union[Type[torch.nn.Module], Callable], Substitution
] = dict()


def register_substitution(
    module_or_function_to_replace: Union[Type[torch.nn.Module], Callable],
    new_operator: torch._ops.OpOverload,
    enabled: bool = True,
) -> Callable[[Any], Any]:
    """Decorator to register subgraph insertion functions

    Args:
        module_or_function_to_replace: nn.Module or node target Callable to replace
        new_operator: Custom torch operator to replace with
        enabled: Whether the substitution is enabled or disabled
    Returns:
        A decorator that registers the wrapped subgraph insertion function and returns it unchanged
    """

    def enable_substitution(subgraph_insertion_fn):
        """Function for use if substitution is enabled"""
        replacement = Substitution(
            new_operator=new_operator, subgraph_insertion_fn=subgraph_insertion_fn
        )
        SUBSTITUTION_REGISTRY[module_or_function_to_replace] = replacement
        return subgraph_insertion_fn

    def disable_substitution(subgraph_insertion_fn):
        """Function for use if substitution is disabled"""
        return subgraph_insertion_fn

    return enable_substitution if enabled else disable_substitution


def pre_aot_substitutions(gm: torch.fx.GraphModule):
    """Perform graph substitutions prior to AOT tracing

    Args:
        gm: FX GraphModule to perform substitution on
    Returns:
        torch.fx.GraphModule

    """
    logger.debug("Pre-module replacement graph:\n" + str(gm.graph))

    # Ensure all parameters are in inference mode
    for param in gm.parameters():
        param.requires_grad = False

    # Iterate over graph nodes, extracting module calls, to check for interceptions
    for n in gm.graph.nodes:
        exists_in_registry = False
        to_replace = None

        if n.op == "call_module":
            # Extract submodule from graph, validate in registry
            submodule = gm.get_submodule(n.target)
            to_replace = type(submodule)
            exists_in_registry = to_replace in SUBSTITUTION_REGISTRY
        elif n.op == "call_function":
            # Extract function from graph, validate in registry
            to_replace = n.target
            exists_in_registry = n.target in SUBSTITUTION_REGISTRY

        # If submodule/function is a member of the substitution registry, replace it
        if exists_in_registry:
            try:
                replacement = SUBSTITUTION_REGISTRY[to_replace]
                op, insertion_fn = (
                    replacement.new_operator,
                    replacement.subgraph_insertion_fn,
                )
                logger.debug(f"Replacing node of type {to_replace} with {op}")

                # Insert new node prior to older node
                with gm.graph.inserting_before(n):
                    new_node = insertion_fn(
                        gm, n, submodule if n.op == "call_module" else None
                    )

                # If submodule is not a native torch.nn module, it must be manually excluded
                # from Dynamo tracing
                if n.op == "call_module" and not type(submodule).__module__.startswith(
                    "torch.nn"
                ):
                    torch._dynamo.allowed_functions._allowed_function_ids.add(
                        id(to_replace)
                    )

                # Replace all original node uses and clean up graph
                n.replace_all_uses_with(new_node)
                gm.graph.eliminate_dead_code()
                gm.graph.lint()
                gm.recompile()

            # A replacement can fail in the event that the specific instance of the submodule/function
            # cannot be replaced
            except Exception:
                logger.debug(
                    f"Encountered error while replacing {to_replace}",
                    exc_info=True,
                )
                continue

    # Perform cleanup and recompilation before returning module
    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()

    logger.debug("Post-module replacement graph:\n" + str(gm.graph))

    return gm
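To illustrate the decorator's intended use, here is a hedged sketch of registering a substitution for a custom module; the module and operator names are hypothetical, patterned on the einsum substitution further down, and the custom op is assumed to have been defined separately (e.g. via torch._custom_op, as in that file).

import torch

from torch_tensorrt.dynamo.backend.lowering import register_substitution


# Hypothetical module to be swapped out before AOT tracing
class MyScale(torch.nn.Module):
    def forward(self, x):
        return x * 2


# Hypothetical: assumes torch.ops.tensorrt.my_scale was registered elsewhere,
# mirroring the einsum custom op below
@register_substitution(MyScale, torch.ops.tensorrt.my_scale)
def my_scale_insertion_fn(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    submodule: torch.nn.Module,
) -> torch.fx.Node:
    # Emit a call_function node targeting the custom op; pre_aot_substitutions
    # rewires uses of the original call_module node and cleans up the graph
    return gm.graph.call_function(
        torch.ops.tensorrt.my_scale,
        args=node.args,
        kwargs=node.kwargs,
    )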
@@ -0,0 +1,2 @@
from .maxpool1d import *
from .einsum import *
80 changes: 80 additions & 0 deletions py/torch_tensorrt/dynamo/backend/lowering/substitutions/einsum.py
@@ -0,0 +1,80 @@
from typing import Dict, Tuple
import torch
from torch._custom_op.impl import custom_op
from torch.fx.node import Argument, Target

from torch_tensorrt.fx.converter_registry import tensorrt_converter
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

from torch_tensorrt.dynamo.backend.lowering import register_substitution


@custom_op(
    qualname="tensorrt::einsum",
    manual_schema="(str equation, Tensor[] tensors) -> Tensor",
)
def einsum(equation, tensors):
    # Defines operator schema, name, namespace, and function header
    ...


@einsum.impl("cpu")
@einsum.impl("cuda")
@einsum.impl_abstract()
def einsum_generic(
    *args,
    **kwargs,
):
    # Defines a converter implementation for AOT Autograd to use for shape analysis/propagation
    return torch.einsum(
        *args,
        **kwargs,
    )


@tensorrt_converter(torch.ops.tensorrt.einsum.default)
def aten_ops_einsum(
    network: TRTNetwork,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> TRTTensor:
    # Defines converter replacing the default operator for this function
    for input_trt in args[1]:
        if not isinstance(input_trt, TRTTensor):
            raise RuntimeError(f"Einsum received non-TRTTensor input: {input_trt}")

    einsum_layer = network.add_einsum(inputs=args[1], equation=args[0])

    set_layer_name(einsum_layer, target, name)
    return einsum_layer.get_output(0)


@register_substitution(torch.einsum, torch.ops.tensorrt.einsum)
def einsum_insertion_fn(
    gm: torch.fx.GraphModule,
    node: torch.fx.Node,
    _unused: None = None,
) -> torch.fx.Node:
    equation = node.args[0]

    # Ensure inputs is a list of (Tensor) arguments
    if isinstance(node.args[1], (tuple, list)):
        inputs = node.args[1]
    else:
        inputs = node.args[1:]

    assert (
        1 <= len(inputs) <= 2
    ), f"TRT Einsum currently only supports 1 or 2 Tensors, got {len(inputs)} Tensors"

    # Build the replacement node: the custom op takes the equation string and the list of input tensors
    new_node = gm.graph.call_function(
        torch.ops.tensorrt.einsum,
        args=(equation, inputs),
        kwargs=node.kwargs,
    )

    return new_node
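A quick way to exercise the substitution in isolation (an illustrative sketch, not part of the PR): trace a module that calls torch.einsum, run the pre-AOT pass, and confirm the call now targets torch.ops.tensorrt.einsum.

import torch

from torch_tensorrt.dynamo.backend.lowering._pre_aot_lowering import (
    pre_aot_substitutions,
)


class MatMulEinsum(torch.nn.Module):
    def forward(self, a, b):
        return torch.einsum("ij,jk->ik", a, b)


gm = torch.fx.symbolic_trace(MatMulEinsum())
gm = pre_aot_substitutions(gm)

# The torch.einsum call_function node should now target torch.ops.tensorrt.einsum
gm.graph.print_tabular()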