Wrapper module around TRT + pytorch subgraphs #3270
Conversation
core/runtime/TRTEngine.h
@@ -102,7 +103,8 @@ struct TRTEngine : torch::CustomClassHolder {
  std::vector<at::Tensor> input_buffers = {};
  std::vector<at::Tensor> output_buffers = {};
  std::string shape_key;

  bool cudagraphs_enabled = false;
What if we made this an enum that describes the runtime mode (STANDARD, SUBGRAPH_CUDAGRAPHS, WHOLE_GRAPH_CUDAGRAPHS), so we avoid the checks in execute_engine?
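A minimal sketch of the suggested mode enum, written in Python purely for illustration (the names are placeholders and the actual field lives in the C++ TRTEngine):

from enum import Enum, auto

class CudaGraphsMode(Enum):
    STANDARD = auto()                # no CUDA graphs anywhere
    SUBGRAPH_CUDAGRAPHS = auto()     # each TRT submodule records its own graph
    WHOLE_GRAPH_CUDAGRAPHS = auto()  # the wrapper records one graph over TRT + Torch

# execute_engine could then branch on a single mode value instead of
# checking several booleans.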
Actually, why does the engine need to know if it's in whole-graph CUDA graph mode? It seems like it doesn't actually change the behavior from "STANDARD" mode.
This instance variable is required to check the previous CUDA graph status. If CUDA graphs are enabled on the second forward() and the input shape doesn't change, need_cudagraphs_record stays False, because only the shape change is checked. It's not related to the wrapper module, but it needs to be addressed together.

cuda graph | input shape
disabled   | (2, 3)
enabled    | (2, 3) -> need_cudagraphs_record=False, as only the shape change was checked.
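A minimal sketch of the fix being described, with hypothetical attribute names, where the re-record decision also keys on the previous CUDA graphs state instead of the shape key alone:

shape_changed = self.shape_key != new_shape_key
need_cudagraphs_record = cudagraphs_enabled and (
    not self.prev_cudagraphs_enabled or shape_changed
)
self.prev_cudagraphs_enabled = cudagraphs_enabled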
    output_dtypes: List[torch.dtype],
):
    super(WrapperTorchTensorRTModule, self).__init__()
    self.original_module = original_module
Would this module be the original PyTorch one or the compiled one?
It is the compiled module. I renamed it to be descriptive.
original_module - unmodified FX GraphModule. It will be used to infer output shapes when dynamic inputs are used.
compiled_module - compiled FX GraphModule that will be wrapped.
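A rough sketch, assuming this naming, of how the two handles sit side by side in the wrapper:

import torch

class WrapperTorchTensorRTModule(torch.nn.Module):
    def __init__(self, original_module, compiled_module):
        super().__init__()
        # Unmodified FX GraphModule, only used for fake-mode output shape inference
        self.original_module = original_module
        # Compiled TRT + Torch mixed GraphModule that is actually executed
        self.compiled_module = compiled_module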
@keehyuna can you add an example on how this would be used by a user?
if self.input_is_dynamic:
    with FakeTensorMode() as mode:
        fake_inputs = [mode.from_tensor(input) for input in inputs]
        tmp_outputs = self.original_module(*fake_inputs)
Maybe I can use compiled_module() instead. I'm not sure which one is better.
compiled module is probably better
@@ -832,6 +835,15 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

    dryrun_stats_display(dryrun_tracker, settings.dryrun)

    if len(dryrun_tracker.to_run_in_torch) > 0:
I enabled the wrapper module when there is a graph break.
@narendasan Do you think we need another runtime API to enable the wrapped module?
I think we could have a compile setting that is something like run_with_cudagraphs, or something that returns the wrapped module. But I think it's confusing that we have both; at some point we might want to just pick one.
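A hedged sketch of what such a compile setting might look like; run_with_cudagraphs is only the name suggested above, not an existing option:

trt_model = torch_tensorrt.compile(
    model,
    inputs=inputs,
    run_with_cudagraphs=True,  # hypothetical flag: return the cudagraph-wrapped module directly
)
out = trt_model(*inputs)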
        tmp_outputs = self.original_module(*fake_inputs)
        if not isinstance(tmp_outputs, (list, tuple)):
            tmp_outputs = [tmp_outputs]
        self.output_shapes = [tuple(output.shape) for output in tmp_outputs]
What is the purpose of performing dummy inference to get output shapes here? Can we not infer the shapes from the compiled_module, e.g. using output_node.meta["val"]?
When I check meta["val"] from compiled_module with a dynamic input shape, I get a SymInt node. Please let me know if I'm missing something.
FakeTensor(..., device='cuda:0', size=(s0, 3, s1, 224))
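A minimal sketch of the check being discussed; with dynamic shapes the output node's metadata carries FakeTensors whose sizes contain SymInts (s0, s1, ...), so concrete output shapes still require a fake-mode forward pass:

output_node = next(n for n in compiled_module.graph.nodes if n.op == "output")
vals = output_node.meta.get("val")  # FakeTensor(s) with symbolic sizes under dynamic input shapes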
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        self.compiled_module(*inputs_tensor)
Should we wrap only the warm-up runs at an error log level? The scenario I'm thinking of is a user running their model in debug mode using
with torch_tensorrt.logging.debug():
    trt_model = torch_tensorrt.compile(model, inputs, debug=True)
This would print the warm-up logs as well, which can be confusing.
Great suggestion! I'll update it accordingly
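One possible sketch of that change, raising the log level only around the warm-up loop shown above (using the torch_tensorrt.logging context managers):

with torch_tensorrt.logging.errors():  # suppress info/debug output during warm-up only
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            self.compiled_module(*inputs_tensor)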
def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
    # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
    contiguous_inputs: List[torch.Tensor] = [
Seems like some of the code here is similar to PythonTorchTensorRTModule. Can we functionalize it and reuse it across these modules?
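A hypothetical sketch of the kind of shared helper being suggested; the name and signature are illustrative only:

import torch

def prepare_contiguous_inputs(inputs):
    # Shared input preparation both runtime modules could call instead of
    # duplicating it: cast symbolic ints to tensors, make tensors contiguous.
    return [
        i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()
        for i in inputs
    ]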
OK, I will check once the API is finalized.
# Disable cudagraphs in submodules as it will be enabled in the wrapper
for name, rt_mod in self.compiled_module.named_children():
    if "_run_on_acc" in name:
        rt_mod.set_whole_cudagraphs(True)
How are we disabling them in submodules? Also, do we now have two modes, like intra CUDA graphs (only within TRT modules) and inter CUDA graphs? I believe we should just have one concept of use_cuda_graphs exposed to users. What do you think?
If set_whole_cudagraphs(True) is set on a TRT submodule, CUDA graph capture is skipped there and is only active in the wrapper module.
Currently, compilation returns the wrapper module only when there is a graph break, with no change to the CUDA graphs runtime API.
Alternatively, we can add a compiler option or an additional API to create the wrapper module.
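A rough sketch, with hypothetical method names, of the submodule-side behaviour being described: the whole-graph flag makes the TRT submodule fall back to plain execution and leaves CUDA graph capture to the wrapper.

def forward(self, *inputs):
    if self.whole_cudagraphs:
        # Whole-graph mode: the wrapper records/replays the CUDA graph,
        # so this submodule just executes its engine directly.
        return self.execute_engine(inputs)
    if self.cudagraphs_enabled:
        return self.record_or_replay_cudagraph(inputs)
    return self.execute_engine(inputs)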
@narendasan / @peri044
I would appreciate your feedback on three candidate API drafts.
Please let me know if you have any alternative ideas or suggestions for the API.
- Compiler option: enable_wrapper_module=True
  wrapped_module = torchtrt.compile(
      ...
      enable_wrapper_module=True,
  )
  wrapped_module(*input)
  Simple usage, but it is difficult to handle torch.export(wrapped_module) when a torch module is on the input side.
- Optional parameter in the existing CUDA graphs runtime API
  with torchtrt.runtime.enable_cudagraphs(optimized_model) as wrapped_module:
      wrapped_module(*input)
  Reuses the existing API; better usability for turning on both CUDA graphs and the wrapped module.
- Import wrapper module
  wrapped_module = WrapperTorchTensorRTModule(optimized_model)
  with torchtrt.runtime.enable_cudagraphs():
      wrapped_module(*input)
  Feels more intuitive.
Description
Wrapper module around TRT + PyTorch subgraphs that records/replays CUDA graphs.
Fixes #3277