Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[5/N][torch.compile] torch.jit.script --> torch.compile #10406

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def _smallest_positive_value(self) -> float:
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
@torch.compile(dynamic=True)
def _multinomial(
probs: torch.Tensor,
num_samples: int,
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/vocab_parallel_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ def __post_init__(self):
assert self.num_added_elements <= self.num_added_elements_padded


@torch.jit.script
@torch.compile(dynamic=True)
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.jit.script will fuse all of the pointwise ops below
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
org_vocab_end_index)
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/phi3_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def weight_loader(self, param: torch.nn.Parameter,
return load_column_parallel_weight(param, loaded_weight)


@torch.jit.script
@torch.compile(dynamic=True)
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)


@torch.jit.script
@torch.compile(dynamic=True)
def gegelu(input, limit: Optional[float] = None):
a_gelu, a_linear = input[..., ::2], input[..., 1::2]
if limit is not None:
Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,7 +1770,7 @@ def capture(
# Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# Note one iteration is not enough for torch.jit.script
# Note one iteration is not enough for torch.compile
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think one warmup iteration should be enough for torch.compile , but we can investigate and confirm later.

for _ in range(_NUM_WARMUP_ITERS):
self.model(
input_ids=input_ids,
Expand Down