feat: Improve Dynamo partitioning System Performance on Large Models #2175

Merged: 5 commits, Aug 15, 2023
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -10,3 +10,4 @@
OPTIMIZATION_LEVEL = None
TRUNCATE_LONG_AND_DOUBLE = False
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -10,6 +10,7 @@
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
VERSION_COMPATIBLE,
WORKSPACE_SIZE,
@@ -29,3 +30,4 @@ class CompilationSettings:
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
use_fast_partitioner: bool = USE_FAST_PARTITIONER
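
The flag is also exposed on CompilationSettings alongside the other runtime options. A minimal sketch of toggling it programmatically, assuming every field of the dataclass keeps a default (as all fields visible in this diff do):

# Minimal sketch, assuming all CompilationSettings fields have defaults
# (true for every field visible in this diff).
from torch_tensorrt.dynamo import CompilationSettings

# Fast (adjacency-based) partitioning is enabled by default via USE_FAST_PARTITIONER
settings = CompilationSettings()
assert settings.use_fast_partitioner

# Opt into the global partitioner instead
global_settings = CompilationSettings(use_fast_partitioner=False)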
65 changes: 56 additions & 9 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -5,13 +5,12 @@
import torch
import torch._dynamo as td
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo import CompilationSettings, partitioning
from torch_tensorrt.dynamo.conversion import (
convert_module,
repair_long_or_double_inputs,
)
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs

@@ -109,24 +108,68 @@ def _compile_module(
Returns:
Compiled FX GraphModule
"""
# Partition module into components that can be TRT-accelerated
partitioned_module = partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
gm, settings.debug, settings.torch_executed_ops
)

# If the number of supported operations is 0 or less than the block size, skip the subgraph
# TODO: Add condition to second expression below when require_full_compilation is added
if num_supported_ops == 0 or (num_supported_ops < settings.min_block_size):
logger.warning(
f"{num_supported_ops} supported operations detected in subgraph containing {total_ops} computational nodes. "
f"Skipping this subgraph, since min_block_size was detected to be {settings.min_block_size}"
)
return gm
else:
logger.debug(
f"Detected support for {num_supported_ops} operators out of {total_ops} in subgraph."
)

# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False

# If specified, try using the fast partitioner and fall back to the global one on failure
if settings.use_fast_partitioner:
try:
partitioned_module = partitioning.fast_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)
except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
logger.error(
"Partitioning failed on the subgraph with fast partition. See trace above. "
+ "Retrying with global partition.",
exc_info=True,
)

fast_partitioner_failed = True
settings.use_fast_partitioner = False

if not settings.use_fast_partitioner:
partitioned_module = partitioning.global_partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)

# Store TRT replicas of Torch subgraphs
trt_modules = {}

# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those
for name, _ in partitioned_module.named_children():
# Criteria for a module to be convertible to TRT
if settings.use_fast_partitioner and "_run_on_acc" not in name:
continue

submodule = getattr(partitioned_module, name)

# Get submodule inputs
submodule_inputs = get_submod_inputs(
submodule_inputs = partitioning.get_submod_inputs(
partitioned_module, submodule, sample_inputs
)

@@ -151,4 +194,8 @@ def _compile_module(
for name, trt_mod in trt_modules.items():
setattr(partitioned_module, name, trt_mod)

# Reset settings object to user specification after fallback to global partitioning mode
if fast_partitioner_failed:
settings.use_fast_partitioner = True

return partitioned_module
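
One detail worth calling out for review: the conversion loop above relies on a naming convention rather than an explicit flag on each submodule. With the fast partitioner, only children whose names contain "_run_on_acc" are converted to TRT; the global partitioner emits only convertible children, so nothing is filtered. A condensed sketch of that selection logic, with partitioned_module and settings assumed to be the objects built inside _compile_module:

# Condensed sketch of the submodule filter in _compile_module above;
# partitioned_module and settings are assumed to come from that function.
from typing import Iterator, Tuple

import torch
from torch_tensorrt.dynamo import CompilationSettings


def trt_eligible_submodules(
    partitioned_module: torch.fx.GraphModule, settings: CompilationSettings
) -> Iterator[Tuple[str, torch.nn.Module]]:
    for name, _ in partitioned_module.named_children():
        # The fast (adjacency) partitioner tags convertible blocks with "_run_on_acc";
        # anything else stays in Torch and is skipped here.
        if settings.use_fast_partitioner and "_run_on_acc" not in name:
            continue
        yield name, getattr(partitioned_module, name)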
54 changes: 5 additions & 49 deletions py/torch_tensorrt/dynamo/compile.py
@@ -5,7 +5,6 @@
import torch
import torch_tensorrt
from torch.fx.passes.pass_manager import PassManager
from torch.fx.passes.splitter_base import SplitResult
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import ( # TODO: Should probably be the TRT EngineCapability Enum
EngineCapability,
@@ -19,18 +18,17 @@
PASS_THROUGH_BUILD_FAILURES,
PRECISION,
TRUNCATE_LONG_AND_DOUBLE,
USE_FAST_PARTITIONER,
USE_PYTHON_RUNTIME,
VERSION_COMPATIBLE,
WORKSPACE_SIZE,
)
from torch_tensorrt.dynamo.backend.backends import _compile_module
from torch_tensorrt.dynamo.conversion import convert_module
from torch_tensorrt.dynamo.lowering._fusers import (
fuse_permute_linear,
fuse_permute_matmul,
)
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

logger = logging.getLogger(__name__)

@@ -62,6 +60,7 @@ def compile(
version_compatible: bool = VERSION_COMPATIBLE,
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
use_python_runtime: bool = USE_PYTHON_RUNTIME,
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
**kwargs: Any,
) -> torch.fx.GraphModule:
if debug:
@@ -73,7 +72,7 @@
"The Dynamo backend is an experimental feature, for which only the "
+ "following arguments are supported: "
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
+ "torch_executed_ops, pass_through_build_failures}"
+ "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}"
)

if not isinstance(inputs, collections.abc.Sequence):
@@ -113,55 +112,12 @@
"optimization_level": optimization_level,
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
"use_fast_partitioner": use_fast_partitioner,
}

settings = CompilationSettings(**compilation_options)
if kwargs.get("use_capability_partitioner", None):
model = lower_model(gm, torch_inputs)
return _compile_module(model, torch_inputs, settings)
else:
split_result = lower_model_using_trt_splitter(gm, torch_inputs)
trt_module = _compile_graph(split_result, torch_inputs, settings)

return trt_module


def _compile_graph(
split_result: SplitResult,
inputs: Any,
settings: CompilationSettings = CompilationSettings(),
**kwargs: Any,
) -> torch.fx.GraphModule:
for submod_name, submod_inputs in split_result.submodule_inputs.items():
submod = getattr(split_result.split_module, submod_name)
# Only acc submodules will be lowered.
if not submod_name.startswith(split_result.non_acc_submodule_prefix):
# Create TRT Module from submodule
trt_mod = convert_module(
submod,
submod_inputs,
settings=settings,
name=submod_name,
)
setattr(split_result.split_module, submod_name, trt_mod)

return split_result.split_module


def lower_model_using_trt_splitter(
model: torch.nn.Module, inputs: Any, **kwargs: Any
) -> SplitResult:
# Perform basic lowering
model = lower_model(model, inputs)
splitter_setting = TRTSplitterSetting()
splitter_setting.use_implicit_batch_dim = False
splitter_setting.min_acc_module_size = 1
splitter_setting.use_experimental_rt = False
splitter = TRTSplitter(model, inputs, settings=splitter_setting)
splitter.node_support_preview()
split_result = splitter.generate_split_results()

return split_result
return _compile_module(gm, torch_inputs, settings)


def lower_model(
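With the TRTSplitter path gone, compile funnels every model through _compile_module and the partitioner choice becomes an ordinary keyword argument. A usage sketch follows; the positional model/inputs call form and the toy module are assumptions, since the first parameters of compile sit outside the visible hunk, while min_block_size and use_fast_partitioner are taken from the signature shown above:

# Usage sketch, not verbatim from this PR: ToyModel and the positional
# model/inputs call form are assumptions; min_block_size and
# use_fast_partitioner come from the signature shown in the diff.
import torch
import torch_tensorrt


class ToyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1


model = ToyModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# Default path: fast (adjacency-based) partitioning
trt_model = torch_tensorrt.dynamo.compile(model, inputs, min_block_size=5)

# Explicitly fall back to the global partitioner
trt_model_global = torch_tensorrt.dynamo.compile(
    model, inputs, min_block_size=5, use_fast_partitioner=False
)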
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/converter_registry.py
@@ -347,8 +347,8 @@ def unique_targets(self) -> Set[Target]:
"""Returns the set of unique converter targets stored across all registries"""
return set.union(*[set(registry.keys()) for registry in self.registries])

# TODO: Make this a static method since it does not need state
def qualified_name_or_str(self, target: Target) -> str:
@staticmethod
def qualified_name_or_str(target: Target) -> str:
"""Returns string representation of an FX Node target"""
if isinstance(target, str):
return target
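Since qualified_name_or_str no longer touches instance state, it can be called on the class itself. A small sketch, assuming the enclosing class is the ConverterRegistry type defined in this module (the class name is not visible in the hunk):

# Sketch: the class name ConverterRegistry is an assumption; only the
# staticmethod conversion itself is shown in the diff above.
import torch
from torch_tensorrt.dynamo.conversion.converter_registry import ConverterRegistry

# String targets pass through unchanged
assert ConverterRegistry.qualified_name_or_str("getattr") == "getattr"

# Callable targets (e.g. ATen overloads) are rendered as a qualified name
print(ConverterRegistry.qualified_name_or_str(torch.ops.aten.add.Tensor))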
9 changes: 2 additions & 7 deletions py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -1,10 +1,5 @@
from ._decompositions import get_decompositions # noqa: F401
from ._fusers import * # noqa: F403
from ._partition import ( # noqa: F401
DEFAULT_SINGLE_NODE_PARTITIONS,
get_submod_inputs,
partition,
)
from ._fusers import * # noqa: F401
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
from ._pre_aot_lowering import register_substitution # noqa: F401
from .substitutions import * # noqa: F403
from .substitutions import * # noqa: F401
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/partitioning/__init__.py
@@ -0,0 +1,3 @@
from ._adjacency_partitioner import partition as fast_partition
from ._global_partitioner import partition as global_partition
from .common import get_graph_converter_support, get_submod_inputs
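
The new partitioning package collects the helpers that backends.py now imports. A sketch of exercising it directly; gm is assumed to be an aten-level torch.fx.GraphModule produced by the Dynamo/AOT lowering, sample_inputs its example tensors, and the keyword names mirror the calls shown in _compile_module:

# Sketch of the new partitioning API surface; gm and sample_inputs are assumed
# to come from the Dynamo/AOT lowering step shown in backends.py.
from torch_tensorrt.dynamo import partitioning

# Count converter-supported operations (debug logging on, no forced-Torch ops)
num_supported_ops, total_ops = partitioning.get_graph_converter_support(gm, True, set())

# Adjacency-based (fast) partitioning, the default in _compile_module
partitioned = partitioning.fast_partition(
    gm,
    verbose=True,
    min_block_size=5,
    torch_executed_ops=set(),
)

# Gather example inputs for each TRT-eligible submodule ("_run_on_acc" naming)
for name, submodule in partitioned.named_children():
    if "_run_on_acc" not in name:
        continue
    submodule_inputs = partitioning.get_submod_inputs(partitioned, submodule, sample_inputs)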