-
Notifications
You must be signed in to change notification settings - Fork 374
feat: improve engine caching and fix bugs #3932
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
6d0a780
d39a29e
9c2faf5
ea81677
2773276
211ca8f
bb93381
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,19 +4,24 @@ | |
| import logging | ||
| from typing import Any, List, NamedTuple, Optional, Sequence | ||
|
|
||
| import tensorrt as trt | ||
| import torch | ||
| from torch_tensorrt._enums import dtype | ||
| from torch_tensorrt._features import ENABLED_FEATURES | ||
| from torch_tensorrt._features import ENABLED_FEATURES, needs_refit | ||
| from torch_tensorrt._Input import Input | ||
| from torch_tensorrt.dynamo._engine_cache import BaseEngineCache | ||
| from torch_tensorrt.dynamo._settings import CompilationSettings | ||
| from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter | ||
| from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible | ||
| from torch_tensorrt.dynamo.conversion._TRTInterpreter import ( | ||
| TRTInterpreter, | ||
| TRTInterpreterResult, | ||
| ) | ||
| from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule | ||
| from torch_tensorrt.dynamo.utils import ( | ||
| get_cpu_memory_usage, | ||
| get_output_dtypes, | ||
| release_host_and_device_memory, | ||
| ) | ||
| from torch_tensorrt.logging import TRT_LOGGER | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -63,6 +68,128 @@ def interpret_module_to_result( | |
| SerializedInterpreterResult | ||
| """ | ||
|
|
||
| def _insert_engine_to_cache( | ||
| hash_val: str, interpreter_result: TRTInterpreterResult | ||
zewenli98 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) -> None: # type: ignore[unused-ignore] | ||
| # Cache the weight-stripped engine regardless of the `strip_engine_weights` setting | ||
| if engine_cache.check(hash_val) is not None: # type: ignore[union-attr] | ||
| logger.info(f"Engine already exists in cache for hash: {hash_val}") | ||
| return | ||
| if not settings.strip_engine_weights: | ||
| # set EXCLUDE_WEIGHTS flag to strip weights | ||
| serialization_config = ( | ||
| interpreter_result.engine.create_serialization_config() | ||
| ) | ||
| serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) | ||
| weight_stripped_serialized_engine = ( | ||
| interpreter_result.engine.serialize_with_config(serialization_config) | ||
| ) | ||
| else: | ||
| weight_stripped_serialized_engine = interpreter_result.engine.serialize() | ||
|
|
||
| # Insert weight-stripped engine to cache | ||
| engine_cache.insert( # type: ignore[union-attr] | ||
| hash_val, | ||
| ( | ||
| weight_stripped_serialized_engine, | ||
| interpreter_result.input_names, | ||
| interpreter_result.output_names, | ||
| inputs, | ||
| settings, | ||
| interpreter_result.weight_name_map, | ||
| interpreter_result.requires_output_allocator, | ||
| ), | ||
| ) | ||
| logger.info(f"Engine was successfully inserted into cache for hash: {hash_val}") | ||
|
|
||
| @needs_refit # type: ignore[misc] | ||
zewenli98 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]: | ||
zewenli98 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # query the cached TRT engine | ||
| cached_data = engine_cache.check(hash_val) # type: ignore[union-attr] | ||
| if cached_data is not None: # hit the cache | ||
| ( | ||
| serialized_engine, # weight-stripped engine | ||
| input_names, | ||
| output_names, | ||
| cached_engine_inputs, | ||
| cached_engine_compilation_settings, | ||
| weight_name_map, | ||
| requires_output_allocator, | ||
| ) = cached_data | ||
|
|
||
| setting_compatiblity, incompattible_settings = settings_are_compatible( | ||
| settings, cached_engine_compilation_settings | ||
| ) | ||
| assert ( | ||
| setting_compatiblity | ||
| ), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {cached_engine_compilation_settings}, new_settings: {settings})" | ||
|
|
||
| for i, e in enumerate( | ||
| [ | ||
| Input.equivalent_spec(c, i) | ||
| for c, i in zip(cached_engine_inputs, inputs) | ||
| ] | ||
| ): | ||
| assert ( | ||
| e | ||
| ), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]}" | ||
|
|
||
| logger.info( | ||
| "Found the cached engine that corresponds to this graph. It is directly loaded." | ||
| ) | ||
|
|
||
| # refit the cached engine with the new graph module | ||
| if not settings.strip_engine_weights: | ||
| runtime = trt.Runtime(TRT_LOGGER) | ||
| engine = runtime.deserialize_cuda_engine( | ||
| serialized_engine | ||
| ) # weight-stripped engine | ||
|
|
||
| from torch_tensorrt.dynamo._refit import ( | ||
| _refit_single_trt_engine_with_gm, | ||
| ) | ||
|
|
||
| # weight-stripped engine --in place--> weight-included engine | ||
| _refit_single_trt_engine_with_gm( | ||
| new_gm=module, | ||
| old_engine=engine, | ||
| input_list=inputs, | ||
| settings=settings, | ||
| weight_name_map=weight_name_map, | ||
| ) | ||
|
|
||
| # EXCLUDE_WEIGHTS flag must be cleared and INCLUDE_REFIT flag must be set | ||
| serialization_config = engine.create_serialization_config() | ||
| serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) | ||
| serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT) | ||
| serialized_engine = engine.serialize_with_config(serialization_config) | ||
| # Start from here, the engine is weight-included and refittable | ||
|
|
||
| with io.BytesIO() as engine_bytes: | ||
| engine_bytes.write(serialized_engine) | ||
| serialized_engine = engine_bytes.getvalue() | ||
|
|
||
| return SerializedInterpreterResult( | ||
| serialized_engine=serialized_engine, | ||
| input_names=input_names, | ||
| output_names=output_names, | ||
| weight_name_map=weight_name_map, | ||
| requires_output_allocator=requires_output_allocator, | ||
| ) | ||
| return None | ||
|
|
||
| # engine_cache could be None if: | ||
| # 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or | ||
| # 2) both cache_built_engines and reuse_cached_engines are False | ||
| if engine_cache is not None and not settings.immutable_weights: | ||
| if settings.cache_built_engines or settings.reuse_cached_engines: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this code is unclear. Would recommend something like this hash_val = engine_cache.get_hash(module, inputs, settings) if (settings.cache_built_engines or settings.reuse_cached_engines) else None
if settings.reuse_cached_engines:
serialized_interpreter_result = pull_cached_engine(
hash_val, module, engine_cache, settings, inputs
)
if serialized_interpreter_result is not None: # hit the cache
return serialized_interpreter_result
...
if (
ENABLED_FEATURES.refit
and not settings.immutable_weights
and settings.cache_built_engines
and engine_cache is not None
):
_ = insert_engine_to_cache(
hash_val, interpreter_result, engine_cache, settings, inputs
)
serialized_engine = interpreter_result.engine.serialize()
|
||
| hash_val = engine_cache.get_hash(module, inputs, settings) | ||
|
|
||
| if settings.reuse_cached_engines: | ||
| serialized_interpreter_result = _pull_cached_engine(hash_val) | ||
| if serialized_interpreter_result is not None: # hit the cache | ||
| return serialized_interpreter_result # type: ignore[no-any-return] | ||
|
|
||
| output_dtypes = infer_module_output_dtypes( | ||
| module, truncate_double=settings.truncate_double | ||
| ) | ||
|
|
@@ -86,32 +213,20 @@ def interpret_module_to_result( | |
| f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB" | ||
| ) | ||
|
|
||
| serialized_engine = interpreter_result.engine.serialize() | ||
| with io.BytesIO() as engine_bytes: | ||
| engine_bytes.write(serialized_engine) | ||
| serialized_engine = engine_bytes.getvalue() | ||
| logger.debug( | ||
| f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB" | ||
| ) | ||
|
|
||
| # Engine caching only for refittable engines | ||
| if ( | ||
| not settings.immutable_weights | ||
| and settings.cache_built_engines | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably throw a warning or something if |
||
| and engine_cache is not None | ||
| ): | ||
| hash_val = engine_cache.get_hash(module, inputs, settings) | ||
| engine_cache.insert( | ||
| hash_val, | ||
| ( | ||
| serialized_engine, | ||
| interpreter_result.input_names, | ||
| interpreter_result.output_names, | ||
| inputs, | ||
| settings, | ||
| interpreter_result.weight_name_map, | ||
| interpreter_result.requires_output_allocator, | ||
| ), | ||
| _insert_engine_to_cache(hash_val, interpreter_result) | ||
|
|
||
| serialized_engine = interpreter_result.engine.serialize() | ||
| with io.BytesIO() as engine_bytes: | ||
| engine_bytes.write(serialized_engine) | ||
| serialized_engine = engine_bytes.getvalue() | ||
| logger.debug( | ||
| f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB" | ||
| ) | ||
|
|
||
| serialized_interpreter_result = SerializedInterpreterResult( | ||
|
|
@@ -122,7 +237,7 @@ def interpret_module_to_result( | |
| requires_output_allocator=interpreter_result.requires_output_allocator, | ||
| ) | ||
|
|
||
| return serialized_interpreter_result | ||
| return serialized_interpreter_result # type: ignore[no-any-return] | ||
|
|
||
|
|
||
| def convert_module( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When would a torch.compile use try to use strip weights?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added the warning back. Not sure why strip_engine_weights arg doesn't work for torch.compile()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It just doenst make sense. Torch compile is not serializable. So why would you ever want a callable that doesnt have the weights in it