feat: Support weight-stripped engine and REFIT_IDENTICAL flag #3167

Status: Open. Wants to merge 27 commits into base: main.

Commits (27)
40349a8
support weight-stripped engine and REFIT_IDENTICAL flag
zewenli98 Sep 19, 2024
5d7c677
refactor with new design
zewenli98 Sep 20, 2024
82b7ddc
lint
zewenli98 Oct 1, 2024
9f6a771
small fix
zewenli98 Oct 1, 2024
7ea3c0f
remove make_refittable
zewenli98 Oct 1, 2024
bf7553b
fast refit -> slow refit
zewenli98 Oct 2, 2024
46e9bc8
fix np.bool_, group_norm
zewenli98 Oct 2, 2024
d783fdd
add immutable_weights
zewenli98 Oct 2, 2024
160588e
skip engine caching for non-refittable engines, slow refit -> fast refit
zewenli98 Oct 2, 2024
493f981
refactored, there are 3 types of engines
zewenli98 Oct 5, 2024
f204104
fix and add tests
zewenli98 Oct 5, 2024
4663c83
fix issues #3206 #3217
zewenli98 Oct 8, 2024
c57ab06
small fix
zewenli98 Oct 15, 2024
402c9b0
resolve comments
zewenli98 Oct 15, 2024
d8e59da
WIP: cache weight-stripped engine
zewenli98 Oct 22, 2024
e8811fd
Merge branch 'main' into weight_stripped_engine
zewenli98 Oct 31, 2024
f2e3f00
redesigned hash func and add constant mapping to fast refit
zewenli98 Nov 4, 2024
31af308
refactor and add tests
zewenli98 Nov 6, 2024
1ae33f4
Merge branch 'main' into weight_stripped_engine
zewenli98 Nov 6, 2024
90bf679
update
zewenli98 Nov 6, 2024
a8a34f6
increase ENGINE_CACHE_SIZE
zewenli98 Nov 6, 2024
285bc90
skip some tests
zewenli98 Nov 7, 2024
2d152cf
fix tests
zewenli98 Nov 7, 2024
d461608
try fixing cumsum
zewenli98 Nov 8, 2024
d57b885
Merge branch 'main' into weight_stripped_engine
zewenli98 Nov 8, 2024
23d68d5
fix windows cross compile, TODO: whether windows support stripping en…
zewenli98 Nov 8, 2024
a928f67
CI debug test 1
zewenli98 Nov 13, 2024
1 change: 0 additions & 1 deletion examples/dynamo/engine_caching_bert_example.py
@@ -52,7 +52,6 @@ def compile_bert(iterations=3):
"truncate_double": True,
"debug": False,
"min_block_size": 1,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
5 changes: 0 additions & 5 deletions examples/dynamo/engine_caching_example.py
@@ -62,8 +62,6 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
# engines are saved to disk tied to a hash of their corresponding PyTorch subgraph. If
# in a subsequent compilation, either as part of this session or a new session, the cache will
# pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude.
# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
# the engine must be refittable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.


def torch_compile(iterations=3):
@@ -97,7 +95,6 @@ def torch_compile(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
},
@@ -157,7 +154,6 @@ def dynamo_compile(iterations=3):
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refittable=True,
cache_built_engines=cache_built_engines,
reuse_cached_engines=reuse_cached_engines,
engine_cache_size=1 << 30, # 1GB
@@ -268,7 +264,6 @@ def torch_compile_my_cache(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"custom_engine_cache": engine_cache,
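To make the caching behavior described in the docstring above concrete, here is a minimal sketch, assuming the ResNet18 model and flag names used elsewhere in this example; engines are now refittable by default, so only the caching flags are passed:

    import torch
    import torchvision.models as models
    import torch_tensorrt

    model = models.resnet18(pretrained=True).eval().to("cuda")
    inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]

    # The first compilation builds the engine and stores it keyed by the subgraph hash;
    # later compilations with the same weights pull the cached engine and refit it.
    trt_gm = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        min_block_size=1,
        cache_built_engines=True,    # insert newly built engines into the cache
        reuse_cached_engines=True,   # reuse and refit cached engines on a hit
        engine_cache_size=1 << 30,   # optional override; the default is now 5 GB
    )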
2 changes: 0 additions & 2 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -31,7 +31,6 @@
settings = {
"use_python": False,
"enabled_precisions": {torch.float32},
"make_refittable": True,
}

model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -80,7 +79,6 @@
"use_python_runtime": True,
"enabled_precisions": {torch.float16},
"debug": True,
"make_refittable": True,
}

model_id = "runwayml/stable-diffusion-v1-5"
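As a rough sketch of the updated mutable-module settings (the model choice mirrors the example; nothing about refittability needs to be requested anymore):

    import torch
    import torchvision.models as models
    import torch_tensorrt

    settings = {
        "use_python": False,
        "enabled_precisions": {torch.float32},
        # "make_refittable" is gone; pass immutable_weights=True only if a
        # non-refittable engine is explicitly wanted.
    }

    model = models.resnet18(pretrained=True).eval().to("cuda")
    mutable_module = torch_tensorrt.MutableTorchTensorRTModule(model, **settings)
    out = mutable_module(torch.rand((1, 3, 224, 224)).to("cuda"))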
7 changes: 1 addition & 6 deletions examples/dynamo/refit_engine_example.py
@@ -46,10 +46,7 @@
# Make a refittable Compilation Program
# ---------------------------------------
#
# The inital step is to compile a module and save it as with a normal. Note that there is an
# additional parameter `make_refittable` that is set to `True`. This parameter is used to
# indicate that the engine being built should support weight refitting later. Engines built without
# these setttings will not be able to be refit.
# The initial step is to compile a module and save it as normal.
#
# In this case we are going to compile a ResNet18 model with randomly initialized weights and save it.

@@ -69,8 +66,6 @@
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refittable=True,
reuse_cached_engines=False,
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
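A condensed sketch of the refit flow this example documents after the change (ResNet18 and the second weight set are illustrative): compile once with no extra flags, then refit with new weights.

    import torch
    import torchvision.models as models
    import torch_tensorrt as torch_trt
    from torch_tensorrt.dynamo import refit_module_weights

    inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

    # Compile a model with randomly initialized weights; engines are refittable by default.
    model = models.resnet18(pretrained=False).eval().to("cuda")
    exp_program = torch.export.export(model, tuple(inputs))
    trt_gm = torch_trt.dynamo.compile(exp_program, inputs=inputs, min_block_size=1)

    # Later, refit the compiled module with trained weights instead of rebuilding it.
    model2 = models.resnet18(pretrained=True).eval().to("cuda")
    exp_program2 = torch.export.export(model2, tuple(inputs))
    new_trt_gm = refit_module_weights(
        compiled_module=trt_gm,
        new_weight_module=exp_program2,
        arg_inputs=inputs,
    )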
4 changes: 2 additions & 2 deletions py/torch_tensorrt/_enums.py
@@ -220,7 +220,7 @@ def _from(
return dtype.f32
elif t == np.float64:
return dtype.f64
elif t == np.bool:
elif t == np.bool_:
return dtype.b
# TODO: Consider using ml_dtypes when issues like this are resolved:
# https://github.com/pytorch/pytorch/issues/109873
@@ -1384,7 +1384,7 @@ def current_platform(cls) -> Platform:
def __str__(self) -> str:
return str(self.name)

@needs_torch_tensorrt_runtime
@needs_torch_tensorrt_runtime # type: ignore
def _to_serialized_rt_platform(self) -> str:
val: str = torch.ops.tensorrt._platform_unknown()

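For context, the np.bool alias was removed in recent NumPy releases, so the comparison only works against np.bool_. A tiny sketch of the behavior this hunk preserves, assuming the public torch_tensorrt.dtype enum:

    import numpy as np
    from torch_tensorrt import dtype

    # Boolean numpy types only exist as np.bool_; referencing the removed np.bool
    # alias raises AttributeError on NumPy >= 1.24.
    assert dtype._from(np.bool_) == dtype.b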
68 changes: 45 additions & 23 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -63,7 +63,6 @@ def cross_compile_for_windows(
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
make_refittable: bool = _defaults.MAKE_REFITTABLE,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -93,6 +92,7 @@ def cross_compile_for_windows(
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
**kwargs: Any,
) -> torch.fx.GraphModule:
@@ -132,7 +132,6 @@ def cross_compile_for_windows(
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
refit (bool): Enable refitting
debug (bool): Enable debuggable engine
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
@@ -164,6 +163,7 @@ def cross_compile_for_windows(
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
**kwargs: Any,
Returns:
@@ -193,14 +193,17 @@ def cross_compile_for_windows(

if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
"`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)

if "make_refittable" in kwargs.keys():
warnings.warn(
"`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)
if make_refittable:
raise ValueError("Use flag make_refittable only. Flag refit is deprecated.")
else:
make_refittable = kwargs["refit"]

engine_capability = EngineCapability._from(engine_capability)

@@ -275,7 +278,6 @@ def cross_compile_for_windows(
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"make_refittable": make_refittable,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
@@ -286,6 +288,7 @@ def cross_compile_for_windows(
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": True,
"enable_weight_streaming": enable_weight_streaming,
}
@@ -342,7 +345,6 @@ def compile(
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
make_refittable: bool = _defaults.MAKE_REFITTABLE,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -372,6 +374,9 @@ def compile(
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS,
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
**kwargs: Any,
) -> torch.fx.GraphModule:
@@ -413,7 +418,6 @@ def compile(
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
refit (bool): Enable refitting
debug (bool): Enable debuggable engine
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
@@ -445,6 +449,9 @@ def compile(
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
refit_identical_engine_weights (bool): Build the engine under the assumption that it will only ever be refit with weights identical to the build-time weights (TensorRT's REFIT_IDENTICAL flag), which lets the builder optimize more aggressively than a generally refittable engine allows.
strip_engine_weights (bool): Strip the weights from the serialized engine so that only the network structure is stored; the weights must be supplied again via refit before execution. This keeps cached or shipped engines small.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
**kwargs: Any,
Returns:
@@ -468,14 +475,17 @@ def compile(

if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
"`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)

if "make_refittable" in kwargs.keys():
warnings.warn(
"`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)
if make_refittable:
raise ValueError("Use flag make_refittable only. Flag refit is deprecated.")
else:
make_refittable = kwargs["refit"]

if (
"enable_cross_compile_for_windows" in kwargs.keys()
@@ -541,9 +551,6 @@ def compile(

engine_cache = None
if cache_built_engines or reuse_cached_engines:
assert (
make_refittable
), "Engine caching requires make_refittable to be set to True"
engine_cache = (
custom_engine_cache
if custom_engine_cache is not None
@@ -574,7 +581,6 @@ def compile(
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"make_refittable": make_refittable,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
@@ -587,6 +593,9 @@ def compile(
"reuse_cached_engines": reuse_cached_engines,
"use_explicit_typing": use_explicit_typing,
"use_fp32_acc": use_fp32_acc,
"refit_identical_engine_weights": refit_identical_engine_weights,
"strip_engine_weights": strip_engine_weights,
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": False,
"enable_weight_streaming": enable_weight_streaming,
}
@@ -861,7 +870,6 @@ def convert_exported_program_to_serialized_trt_engine(
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
disable_tf32: bool = _defaults.DISABLE_TF32,
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
make_refittable: bool = _defaults.MAKE_REFITTABLE,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -872,6 +880,9 @@ def convert_exported_program_to_serialized_trt_engine(
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS,
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
**kwargs: Any,
) -> bytes:
@@ -922,7 +933,6 @@ def convert_exported_program_to_serialized_trt_engine(
Only applicable for `ir="dynamo"`; has no effect for `torch.compile` path
disable_tf32 (bool): Whether to disable TF32 computation for TRT layers
sparse_weights (bool): Whether to allow the builder to use sparse weights
refit (bool): Whether to build a refittable engine
engine_capability (trt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
@@ -933,6 +943,9 @@ def convert_exported_program_to_serialized_trt_engine(
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
refit_identical_engine_weights (bool): Build the engine under the assumption that it will only ever be refit with weights identical to the build-time weights (TensorRT's REFIT_IDENTICAL flag), which lets the builder optimize more aggressively than a generally refittable engine allows.
strip_engine_weights (bool): Strip the weights from the serialized engine so that only the network structure is stored; the weights must be supplied again via refit before execution. This keeps cached or shipped engines small.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
@@ -954,10 +967,17 @@ def convert_exported_program_to_serialized_trt_engine(
)
if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
"`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)
if "make_refittable" in kwargs.keys():
warnings.warn(
"`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)

if arg_inputs is None and inputs is None:
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")

Expand Down Expand Up @@ -1000,7 +1020,6 @@ def convert_exported_program_to_serialized_trt_engine(
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"make_refittable": make_refittable,
"engine_capability": engine_capability,
"num_avg_timing_iters": num_avg_timing_iters,
"dla_sram_size": dla_sram_size,
@@ -1009,6 +1028,9 @@ def convert_exported_program_to_serialized_trt_engine(
"timing_cache_path": timing_cache_path,
"use_explicit_typing": use_explicit_typing,
"use_fp32_acc": use_fp32_acc,
"refit_identical_engine_weights": refit_identical_engine_weights,
"strip_engine_weights": strip_engine_weights,
"immutable_weights": immutable_weights,
"enable_weight_streaming": enable_weight_streaming,
}

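Pulling the new options together, a hedged sketch of how the three flags documented above are expected to combine in a dynamo compile call (the model is a placeholder, and the interactions follow the docstrings: immutable_weights overrides the other two):

    import torch
    import torchvision.models as models
    import torch_tensorrt

    model = models.resnet18().eval().to("cuda")   # placeholder model
    inputs = [torch.randn(1, 3, 224, 224).to("cuda")]
    exp_program = torch.export.export(model, tuple(inputs))

    # Default build: refittable engine; the old refit/make_refittable kwargs only warn.
    trt_default = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs)

    # Weight-stripped build: the plan holds the network structure only and is refit
    # with the original weights before execution.
    trt_stripped = torch_tensorrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        strip_engine_weights=True,
        refit_identical_engine_weights=True,  # builder may assume refit weights match build weights
    )

    # Non-refittable build: strip_engine_weights and refit_identical_engine_weights are ignored.
    trt_immutable = torch_tensorrt.dynamo.compile(
        exp_program,
        inputs=inputs,
        immutable_weights=True,
    )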
6 changes: 4 additions & 2 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -26,7 +26,6 @@
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
MAKE_REFITTABLE = False
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
@@ -38,10 +37,13 @@
CACHE_BUILT_ENGINES = False
REUSE_CACHED_ENGINES = False
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
ENGINE_CACHE_SIZE = 1073741824
ENGINE_CACHE_SIZE = 5368709120 # 5GB
CUSTOM_ENGINE_CACHE = None
USE_EXPLICIT_TYPING = False
USE_FP32_ACC = False
REFIT_IDENTICAL_ENGINE_WEIGHTS = False
STRIP_ENGINE_WEIGHTS = False
IMMUTABLE_WEIGHTS = False
ENABLE_WEIGHT_STREAMING = False
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False

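For quick reference, a small sketch of the changed and added defaults, assuming the module stays importable as torch_tensorrt.dynamo._defaults:

    from torch_tensorrt.dynamo import _defaults

    # New and changed defaults introduced by this PR (values as defined above).
    assert _defaults.ENGINE_CACHE_SIZE == 5368709120          # was 1073741824 (1 GB)
    assert _defaults.IMMUTABLE_WEIGHTS is False                # engines refittable by default
    assert _defaults.STRIP_ENGINE_WEIGHTS is False
    assert _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS is False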