Skip to content

Commit

Permalink
[5/N][torch.compile] torch.jit.script --> torch.compile (vllm-project#10406)

Browse files Browse the repository at this point in the history

Signed-off-by: youkaichao <youkaichao@gmail.com>
Signed-off-by: rickyx <rickyx@anyscale.com>
  • Loading branch information
youkaichao authored and rickyyx committed Nov 20, 2024
1 parent 25ff0d9 commit 8ce4c4f
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/rejection_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def _smallest_positive_value(self) -> float:
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
@torch.compile(dynamic=True)
def _multinomial(
probs: torch.Tensor,
num_samples: int,
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/vocab_parallel_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ def __post_init__(self):
assert self.num_added_elements <= self.num_added_elements_padded


@torch.jit.script
@torch.compile(dynamic=True)
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.jit.script will fuse all of the pointwise ops below
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
org_vocab_end_index)
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/models/phi3_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ def weight_loader(self, param: torch.nn.Parameter,
return load_column_parallel_weight(param, loaded_weight)


@torch.jit.script
@torch.compile(dynamic=True)
def quick_gelu(x):
return x * torch.sigmoid(1.702 * x)


@torch.jit.script
@torch.compile(dynamic=True)
def gegelu(input, limit: Optional[float] = None):
a_gelu, a_linear = input[..., ::2], input[..., 1::2]
if limit is not None:
Expand Down
2 changes: 1 addition & 1 deletion vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1769,7 +1769,7 @@ def capture(
# Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# Note one iteration is not enough for torch.jit.script
# Note one iteration is not enough for torch.compile
for _ in range(_NUM_WARMUP_ITERS):
self.model(
input_ids=input_ids,
Expand Down

0 comments on commit 8ce4c4f

Please sign in to comment.