chore: example fixes #3176

Open
wants to merge 10 commits into base: main
Changes from 7 commits
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -119,6 +119,8 @@ Tutorials
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/torch_export_gpt2
tutorials/_rendered_examples/dynamo/torch_export_llama2

Python API Documentation
------------------------
31 changes: 24 additions & 7 deletions examples/dynamo/README.rst
@@ -1,15 +1,24 @@
.. _torch_compile:

Dynamo / ``torch.compile``
----------------------------
Torch-TensorRT Examples
====================================

Torch-TensorRT provides a backend for the new ``torch.compile`` API released in PyTorch 2.0. In the following examples we describe
a number of ways you can leverage this backend to accelerate inference.
Please refer to the following examples which demonstrate the usage of different features of Torch-TensorRT. We also provide
examples of Torch-TensorRT compilation of select computer vision and language models.

* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
Dependencies
------------------------------------

Please install the following external dependencies (assuming you already have ``torch_tensorrt`` installed):

.. code-block:: bash

pip install -r requirements.txt


Compiler Features
------------------------------------
* :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`torch_export_cudagraphs`: Using the Cudagraphs integration with `ir="dynamo"`
* :ref:`converter_overloading`: How to write custom converters and overload existing ones
* :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
@@ -18,3 +27,11 @@ a number of ways you can leverage this backend to accelerate inference.
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
* :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times
* :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT

Model Zoo
------------------------------------
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
* :ref:`torch_export_gpt2`: Compiling a GPT2 model using the AOT workflow (`ir=dynamo`)
* :ref:`torch_export_llama2`: Compiling a Llama2 model using the AOT workflow (`ir=dynamo`); a minimal sketch of this AOT workflow follows the list
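
For orientation, the AOT (`ir=dynamo`) workflow referenced by the GPT2 and Llama2 entries follows this general pattern; the toy model and shapes below are placeholders rather than the exact code from those tutorials:

.. code-block:: python

    import torch
    import torch_tensorrt

    # Placeholder model and inputs; the linked tutorials compile Hugging Face LLMs instead.
    model = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.randn((8, 64), dtype=torch.float32).cuda()]

    # Compile ahead of time through the dynamo IR (torch.export under the hood),
    # then call the returned module like a regular nn.Module.
    trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, min_block_size=1)
    outputs = trt_model(*inputs)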
4 changes: 2 additions & 2 deletions examples/dynamo/requirements.txt
@@ -1,4 +1,4 @@
cupy==13.1.0
torch>=2.4.0.dev20240503+cu121
torch-tensorrt>=2.4.0.dev20240503+cu121
triton==2.3.0
diffusers==0.30.3
transformers==4.44.2
100 changes: 100 additions & 0 deletions examples/dynamo/torch_compile_gpt2.py
@@ -0,0 +1,100 @@
"""
.. _torch_compile_gpt2:

Compiling GPT2 using the Torch-TensorRT `torch.compile` Backend
==================================================================

This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a GPT2 model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer

# %%

# Define the parameters
MAX_TOKENS = 32
DEVICE = torch.device("cuda:0")

# Define the GPT2 model from Hugging Face.
# kv_cache is not supported in Torch-TRT currently, so use_cache=False is set below.
with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = (
        AutoModelForCausalLM.from_pretrained(
            "gpt2",
            pad_token_id=tokenizer.eos_token_id,
            use_cache=False,
            attn_implementation="eager",
        )
        .eval()
        .cuda()
    )

# %%
# Tokenize a sample input prompt and get pytorch model outputs
prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"].cuda()

# Auto-regressive generation loop for greedy search using the PyTorch model.
pyt_gen_tokens = model.generate(
input_ids,
max_length=MAX_TOKENS,
use_cache=False,
pad_token_id=tokenizer.eos_token_id,
)

# %%
# Compilation with `torch.compile` using the tensorrt backend and TensorRT output generation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Compile the model and mark the input sequence length to be dynamic
torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
model.forward = torch.compile(
model.forward,
backend="tensorrt",
dynamic=None,
options={
"enabled_precisions": {torch.float32},
"disable_tf32": True,
"min_block_size": 1,
"debug": True,
},
)

# Auto-regressive generation loop for greedy decoding using the TensorRT model.
# The first token generation compiles the model using TensorRT and the second token
# encounters recompilation
trt_gen_tokens = model.generate(
inputs=input_ids,
max_length=MAX_TOKENS,
use_cache=False,
pad_token_id=tokenizer.eos_token_id,
)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

print("=============================")
print(
"Pytorch model generated text: ",
tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
)
print("=============================")
print(
"TensorRT model generated text: ",
tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
)

# %%
# The output sentences should look like

# Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
# =============================
# TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
89 changes: 89 additions & 0 deletions examples/dynamo/torch_compile_llama2.py
@@ -0,0 +1,89 @@
"""
.. _torch_compile_gpt2:

Compiling GPT2 using the Torch-TensorRT `torch.compile` Backend
==========================================================

This interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a GPT2 model."""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import generate

# %%

# Define the parameters
MAX_TOKENS = 32
DEVICE = torch.device("cuda:0")

# Define the Llama2 model from Hugging Face.
# kv_cache is not supported in Torch-TRT currently, so use_cache=False is set below.
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        llama_path, use_cache=False, attn_implementation="eager"
    ).eval().cuda()

    tokenizer = AutoTokenizer.from_pretrained(llama_path)

# %%
# Tokenize a sample input prompt and get pytorch model outputs
prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"].cuda()

# Auto-regressive generation loop for greedy search using the PyTorch model.
# We use a custom generate function which is very similar to the HuggingFace one.
pyt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Compilation with `torch.compile` using the tensorrt backend and TensorRT output generation
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Compile the model and mark the input sequence length to be dynamic
torch._dynamo.mark_dynamic(input_ids, 1, min=7, max=1023)
model.forward = torch.compile(
    model.forward,
    backend="tensorrt",
    dynamic=None,
    options={
        "enabled_precisions": {torch.float32},
        "disable_tf32": True,
        "debug": True,
    },
)
# Auto-regressive generation loop for greedy decoding using the TensorRT model.
# We use a custom generate function which is very similar to the HuggingFace one.
trt_gen_tokens = generate(model, input_ids, MAX_TOKENS, tokenizer.eos_token_id)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

print("=============================")
print(
"Pytorch model generated text: ",
tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
)
print("=============================")
print(
"TensorRT model generated text: ",
tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
)

# %%
# The output sentences should look like
#
#
9 changes: 8 additions & 1 deletion examples/dynamo/utils.py
@@ -51,7 +51,14 @@ def generate(model, input_seq, max_tokens, eos_token_id):
)

while True:
outputs = model(input_seq)
outputs = model(
input_seq,
past_key_values=None,
position_ids=None,
attention_mask=None,
use_cache=False,
token_type_ids=None,
)
logits = outputs.logits
next_token_logits = logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
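
The `generate` helper being modified here is, at its core, a greedy decoding loop. A simplified standalone sketch (not the exact `utils.py` implementation, whose signature and kwargs may differ):

.. code-block:: python

    import torch

    def greedy_generate(model, input_seq, max_tokens, eos_token_id):
        # Append the argmax token at each step until EOS or the length cap is reached.
        with torch.no_grad():
            while input_seq.shape[1] < max_tokens:
                outputs = model(input_seq, use_cache=False)
                next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1)
                input_seq = torch.cat([input_seq, next_token[:, None]], dim=-1)
                if (next_token == eos_token_id).all():
                    break
        return input_seq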
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -288,6 +288,7 @@ def compile(
trt_gm = compile_module(
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
)

return trt_gm


5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -80,7 +80,8 @@ def _pretraced_backend(
repair_input_aliasing(gm)

# Remove sym_int placeholders and inputs
remove_sym_nodes(gm)
remove_sym_nodes(gm, sample_inputs)

torch_inputs = [
input for input in sample_inputs if isinstance(input, torch.Tensor)
]
@@ -91,7 +92,7 @@ def _pretraced_backend(
# Invoke AOTAutograd to translate operators to aten
gm = aot_export_joint_simple(
gm,
torch_inputs,
sample_inputs,
trace_joint=False,
decompositions=get_decompositions(
settings.enable_experimental_decompositions
14 changes: 9 additions & 5 deletions py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py
@@ -1,18 +1,21 @@
import logging
from typing import Any, Sequence

import torch

logger = logging.getLogger(__name__)


def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
def remove_sym_nodes(
gm: torch.fx.GraphModule, sample_inputs: Sequence[Any]
) -> torch.fx.GraphModule:
"""Remove sym_int placeholders which get inserted due to torch.compile's
dynamic=True behavior
"""
# Extract SymInt placeholder Tensors
placeholder_sym_ints = [
node
for node in gm.graph.nodes
placeholder_idx_sym_ints = [
(idx, node)
for idx, node in enumerate(gm.graph.nodes)
if (
node.op == "placeholder"
and isinstance(node.type, type)
@@ -21,8 +24,9 @@ def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
)
]

for node in placeholder_sym_ints:
for idx, node in placeholder_idx_sym_ints:
gm.graph.erase_node(node)
sample_inputs.pop(idx)

gm.graph.lint()
gm.recompile()
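
To see the SymInt placeholders that `remove_sym_nodes` strips (and why `sample_inputs` must now be filtered in lock-step with the graph), here is a small sketch of a custom `torch.compile` backend under `dynamic=True`; the exact placeholder names and input types printed vary across PyTorch versions, so treat the output as illustrative:

.. code-block:: python

    import torch

    def inspect_backend(gm, sample_inputs):
        # Under dynamic=True, dynamo may prepend symbolic-size placeholders to the
        # graph, so sample_inputs can contain SymInt entries ahead of the tensors.
        placeholders = [n.name for n in gm.graph.nodes if n.op == "placeholder"]
        print(placeholders, [type(i).__name__ for i in sample_inputs])
        return gm.forward

    def toy(x):
        return x * 2

    compiled = torch.compile(toy, backend=inspect_backend, dynamic=True)
    compiled(torch.randn(4))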