Wrapper module around TRT + pytorch subgraphs #3270

Open · wants to merge 10 commits into base: main

Conversation

@keehyuna (Collaborator) commented Oct 31, 2024

Description

Wrapper module around TRT + PyTorch subgraphs, with record/replay of CUDA graphs.

Fixes #3277

Type of change

Please delete options that are not relevant and/or add your own.

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR so that relevant reviewers are notified

@github-actions github-actions bot added component: tests Issues re: Tests component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Oct 31, 2024
@keehyuna keehyuna self-assigned this Oct 31, 2024
@github-actions github-actions bot added the component: core Issues re: The core compiler label Nov 6, 2024
@github-actions github-actions bot removed the component: tests Issues re: Tests label Nov 15, 2024
@keehyuna keehyuna marked this pull request as ready for review November 15, 2024 10:36
@@ -102,7 +103,8 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key;

bool cudagraphs_enabled = false;
Collaborator:
What if we made this like an enum that describes the runtime mode STANDARD, SUBGRAPH_CUDAGRAPHS, WHOLE_GRAPH_CUDAGRAPHS so we avoid the checks in execute engine?
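
For illustration, a minimal Python sketch of that idea, assuming an illustrative RuntimeMode enum and execute_engine dispatcher (neither is part of this PR's actual API):

from enum import Enum, auto

class RuntimeMode(Enum):
    STANDARD = auto()                # no CUDA graphs
    SUBGRAPH_CUDAGRAPHS = auto()     # each TRT engine records/replays its own graph
    WHOLE_GRAPH_CUDAGRAPHS = auto()  # the wrapper module owns a single captured graph

def execute_engine(mode: RuntimeMode) -> None:
    # A single explicit mode avoids re-deriving the behavior from several booleans per call.
    if mode is RuntimeMode.SUBGRAPH_CUDAGRAPHS:
        pass  # record/replay inside this engine
    else:
        pass  # run directly; in WHOLE_GRAPH_CUDAGRAPHS the wrapper performs the capture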

Collaborator:

Actually, why does the engine need to know if it's in whole-graph CUDA graph mode? It seems like it doesn't actually change the behavior from "STANDARD" mode.

Collaborator (Author):

This instance variable is required to check the previous state of CUDA graphs.

If CUDA graphs are enabled in the second forward() and the input shape doesn't change, need_cudagraphs_record ends up False because only the shape change is checked. It's not related to the wrapper module, but it needs to be addressed together with it:

cuda graphs | input shape
disabled    | (2, 3)
enabled     | (2, 3) -> need_cudagraphs_record=False, as only the shape change was checked

https://github.com/pytorch/TensorRT/pull/3270/files#diff-12ec24175cd347c79cd109427bde289f503e2dec9fc9d1a27871e5523a218638R119-R123
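
A hedged sketch of the state tracking described above (prev_cudagraphs_enabled and prev_shape_key are illustrative names, not the PR's fields): re-recording is needed when either the shape key changes or CUDA graphs have just been switched on.

def need_cudagraphs_record(self, inputs, cudagraphs_enabled):
    shape_key = tuple(tuple(t.shape) for t in inputs)
    # Record again when CUDA graphs were just enabled, even if the shape is unchanged.
    need_record = cudagraphs_enabled and (
        not self.prev_cudagraphs_enabled or shape_key != self.prev_shape_key
    )
    self.prev_cudagraphs_enabled = cudagraphs_enabled
    self.prev_shape_key = shape_key
    return need_record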

output_dtypes: List[torch.dtype],
):
super(WrapperTorchTensorRTModule, self).__init__()
self.original_module = original_module
Collaborator:

Would this module be the original PyTorch one or the compiled one?

@keehyuna (Collaborator, Author) commented Nov 18, 2024:

It is the compiled module. I renamed the parameters to be descriptive:
original_module - the unmodified FX GraphModule. It is used to infer output shapes when dynamic inputs are used.
compiled_module - the compiled FX GraphModule that will be wrapped.
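
A minimal sketch of the renamed constructor described above; the signature and stored fields are illustrative, not the exact PR code.

from typing import List

import torch

class WrapperTorchTensorRTModule(torch.nn.Module):
    def __init__(
        self,
        original_module: torch.fx.GraphModule,  # unmodified FX graph, used to infer output shapes for dynamic inputs
        compiled_module: torch.fx.GraphModule,  # compiled TRT + PyTorch graph that gets wrapped
        output_dtypes: List[torch.dtype],
    ) -> None:
        super().__init__()
        self.original_module = original_module
        self.compiled_module = compiled_module
        self.output_dtypes = output_dtypes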

@narendasan (Collaborator) left a comment:

@keehyuna can you add an example of how this would be used by a user?

@github-actions github-actions bot added the component: tests Issues re: Tests label Nov 18, 2024
if self.input_is_dynamic:
    with FakeTensorMode() as mode:
        fake_inputs = [mode.from_tensor(input) for input in inputs]
        tmp_outputs = self.original_module(*fake_inputs)
Collaborator (Author):

Maybe I can use compiled_module() instead. I'm not sure which one is better.

Collaborator:

compiled module is probably better

@@ -832,6 +835,15 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

dryrun_stats_display(dryrun_tracker, settings.dryrun)

if len(dryrun_tracker.to_run_in_torch) > 0:
Collaborator (Author):

I enabled the wrapper module when there is a graph break.
@narendasan do you think we need another runtime API to enable the wrapped module?
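
A hedged sketch of that behavior (the maybe_wrap helper and the constructor arguments are illustrative): the wrapper is only returned when some subgraphs still fall back to PyTorch, i.e. when there was a graph break.

def maybe_wrap(original_module, compiled_module, dryrun_tracker):
    # Wrap only when at least one subgraph runs in PyTorch (graph break present).
    if len(dryrun_tracker.to_run_in_torch) > 0:
        return WrapperTorchTensorRTModule(original_module, compiled_module)
    return compiled_module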

Collaborator:

I think we could have a compile setting, something like run_with_cudagraphs, that returns the wrapped module. But I think it's confusing that we have both; at some point we might want to just pick one.
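
For illustration, a hedged sketch of that idea, assuming a hypothetical run_with_cudagraphs compile setting (not an existing torch_tensorrt option):

import torch_tensorrt

def compile_with_wrapper(model, inputs):
    # run_with_cudagraphs is hypothetical; it would return the wrapped module directly.
    return torch_tensorrt.compile(model, inputs=inputs, run_with_cudagraphs=True)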

tmp_outputs = self.original_module(*fake_inputs)
if not isinstance(tmp_outputs, (list, tuple)):
    tmp_outputs = [tmp_outputs]
self.output_shapes = [tuple(output.shape) for output in tmp_outputs]
Collaborator:

What is the purpose of performing dummy inference to get the output shapes here? Can we not infer the shapes from the compiled_module, e.g. using output_node.meta["val"]?

Collaborator (Author):

When I checked meta["val"] from compiled_module with a dynamic input shape, I got a SymInt node. Please let me know if I'm missing something.
FakeTensor(..., device='cuda:0', size=(s0, 3, s1, 224))
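
A hedged sketch of the two shape sources discussed here (infer_output_shapes is an illustrative helper, not PR code): with dynamic shapes, output_node.meta["val"] holds FakeTensors whose sizes contain SymInts, so concrete shapes are resolved by running the original module under FakeTensorMode with fake tensors built from the real inputs.

from torch._subclasses.fake_tensor import FakeTensorMode

def infer_output_shapes(original_module, inputs):
    with FakeTensorMode() as mode:
        fake_inputs = [mode.from_tensor(t) for t in inputs]
        outputs = original_module(*fake_inputs)
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
    # Shapes are concrete here because the fake tensors mirror the real input sizes.
    return [tuple(out.shape) for out in outputs]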

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py (outdated review thread, resolved)
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        self.compiled_module(*inputs_tensor)
Collaborator:

Should we wrap only the warm-up runs in an error log level? The scenario I'm thinking of is: let's say a user ran their model in debug mode using

with torch_tensorrt.logging.debug():
    trt_model = torch_tensorrt.compile(model, inputs, debug=True)

This would print the warm-up logs as well, which can be confusing.
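
A hedged sketch of that suggestion, using the standard logging module to temporarily raise the level around the warm-up iterations (the logger name and run_warmup helper are illustrative):

import logging

logger = logging.getLogger("torch_tensorrt")

def run_warmup(module, inputs, iterations=3):
    previous_level = logger.level
    logger.setLevel(logging.ERROR)  # silence debug/info output during warm-up
    try:
        for _ in range(iterations):
            module(*inputs)
    finally:
        logger.setLevel(previous_level)  # restore the caller's log level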

Collaborator (Author):

Great suggestion! I'll update it accordingly.
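
For context, a hedged sketch of the record/replay pattern the warm-up code above leads into; this is standard torch.cuda.CUDAGraph usage, not the exact PR implementation.

import torch

def capture_and_replay(module, static_inputs, warmup_iters=3):
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(warmup_iters):  # warm up outside the graph
            module(*static_inputs)
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):  # record
        static_outputs = module(*static_inputs)

    graph.replay()  # replay using the captured input/output buffers
    return static_outputs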


def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
    # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
    contiguous_inputs: List[torch.Tensor] = [
Collaborator:

Seems like some of the code here is similar to PythonTorchTensorRTModule. Can we functionalize it and reuse it across these modules?

Collaborator (Author):

OK, I will check once the API is finalized.

# Disable cudagraphs in submodules as it will be enabled in the wrapper
for name, rt_mod in self.compiled_module.named_children():
    if "_run_on_acc" in name:
        rt_mod.set_whole_cudagraphs(True)
Collaborator:

How are we disabling them in the submodules? Also, do we now have two modes, like intra CUDA graphs (only within TRT modules) and inter CUDA graphs? I believe we should just have one concept of use_cuda_graphs exposed to users. What do you think?

Collaborator (Author):

If set_whole_cudagraphs(True) is set on a TRT submodule, its own CUDA graph is ignored and CUDA graphs are only active in the wrapper module.
Currently, the optimized module is returned as a wrapper module only when there is a graph break; there is no change to the CUDA graphs runtime API.
Alternatively, we could add a compiler option or an additional API to create the wrapper module.
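
A hedged sketch of the behavior described above; set_whole_cudagraphs mirrors the method from the diff, while the forward logic and the cudagraphs_enabled helper are illustrative.

def cudagraphs_enabled() -> bool:
    # Illustrative stand-in for the runtime's global CUDA graphs switch.
    return True

class TRTSubmodule:
    def __init__(self):
        self.whole_cudagraphs = False

    def set_whole_cudagraphs(self, enabled: bool) -> None:
        # When True, the wrapper owns the CUDA graph, so this submodule skips its own capture.
        self.whole_cudagraphs = enabled

    def forward(self, *inputs):
        if cudagraphs_enabled() and not self.whole_cudagraphs:
            return self._record_or_replay(*inputs)  # submodule-level CUDA graph
        return self._run_directly(*inputs)          # wrapper captures the whole graph

    def _record_or_replay(self, *inputs):
        ...  # per-engine CUDA graph path (details omitted)

    def _run_directly(self, *inputs):
        ...  # plain execution path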

@keehyuna (Collaborator, Author) left a comment:

@narendasan / @peri044
I would appreciate your feedback on three candidate API drafts.
Please let me know if you have any alternative ideas or suggestions for the API.

  1. Compiler option: enable_wrapper_module=True

wrapped_module = torchtrt.compile(
    ...
    enable_wrapper_module=True,
)
wrapped_module(*input)

Simple usage, but it is difficult to handle torch.export(wrapped_module) when a torch module is on the input side.

  2. Optional parameter in the existing CUDA graphs runtime API

with torchtrt.runtime.enable_cudagraphs(optimized_model) as wrapped_module:
    wrapped_module(*input)

Reuses the existing API; better usability for turning on both CUDA graphs and the wrapped module.

  3. Import the wrapper module directly

wrapped_module = WrapperTorchTensorRTModule(optimized_model)
with torchtrt.runtime.enable_cudagraphs():
    wrapped_module(*input)

Feels more intuitive.

https://github.com/pytorch/TensorRT/pull/3270/files#diff-0355c793a29056c378c96cae4e041427e64ca145a9199310c25b43ac1091bbecR235-R282

Labels
cla signed · component: api [Python] Issues re: Python API · component: core Issues re: The core compiler · component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths · component: runtime · component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

✨[Feature] Performance optimization of PyTorch + TRT subgraphs
4 participants