Commit 185d41f: sampler memory

Signed-off-by: Jennifer Zhao <7443418+JenZhao@users.noreply.github.com>
JenZhao committed Feb 20, 2025
1 parent a4c402a commit 185d41f
Showing 1 changed file with 26 additions and 3 deletions.

vllm/v1/worker/gpu_model_runner.py

@@ -36,7 +36,7 @@
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-
+from vllm.v1.sample.metadata import SamplingMetadata
 if TYPE_CHECKING:
     from vllm.v1.core.scheduler_output import SchedulerOutput

@@ -1303,11 +1303,34 @@ def profile_run(self) -> None:
         if get_pp_group().is_last_rank:
             hidden_states = hidden_states[logit_indices]
             logits = self.model.compute_logits(hidden_states, None)
-            # TODO(woosuk): Consider the memory usage of the sampler.
+            penalties = torch.full((logits.size(0),), 0.0, device=self.device)
+            dummy_metadata = SamplingMetadata(
+                temperature=torch.full((logits.size(0),), 0.5, device=self.device),
+                all_greedy=False,
+                all_random=False,
+                spec_token_ids=None,
+                top_p=torch.full((logits.size(0),), 0.99, device=self.device),
+                top_k=torch.full((logits.size(0),), logits.size(1) - 1, device=self.device),
+                min_p=None,
+                generators={},
+                max_num_logprobs=None,
+                no_penalties=True,
+                prompt_token_ids=None,
+                frequency_penalties=penalties,
+                presence_penalties=penalties,
+                repetition_penalties=penalties,
+                output_token_ids=[[] for _ in range(logits.size(0))],
+                min_tokens={},
+                logit_bias=[None for _ in range(logits.size(0))]
+            )
+            sampler_output = self.model.sample(logits=logits, sampling_metadata=dummy_metadata)
         else:
             logits = None
+            sampler_output = None
+            penalties = None
+            dummy_metadata = None
         torch.cuda.synchronize()
-        del hidden_states, logits
+        del hidden_states, logits, sampler_output, penalties, dummy_metadata
         self.encoder_cache.clear()
         gc.collect()
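
For readers following along: the dummy SamplingMetadata above deliberately picks near-worst-case settings — non-greedy sampling at temperature 0.5, top_p just below 1.0, and top_k keeping all but one vocabulary entry — so the sampling pass in profile_run materializes the same logits-sized intermediates a pessimal real batch would, and the profiling measurement accounts for them. Below is a minimal, self-contained sketch of such a pass in plain PyTorch; it is not vLLM's Sampler, the name dummy_sampling_pass is invented for illustration, and it simplifies top_k to a scalar where the diff uses a per-request tensor.

import torch

def dummy_sampling_pass(logits: torch.Tensor) -> torch.Tensor:
    """Temperature + top-k + top-p sampling, written out to show the
    logits-sized temporaries such a pass allocates."""
    num_reqs, vocab_size = logits.shape
    device = logits.device
    # Mirrors the diff's settings: temperature=0.5, top_p=0.99,
    # top_k=vocab_size - 1 (scalar here for simplicity).
    temperature = torch.full((num_reqs,), 0.5, device=device)
    top_p = torch.full((num_reqs,), 0.99, device=device)
    top_k = vocab_size - 1

    scaled = logits / temperature.unsqueeze(dim=1)
    # Sorting materializes two more (num_reqs, vocab_size) tensors.
    sorted_logits, sorted_idx = scaled.sort(dim=-1, descending=True)
    # Top-k: mask out everything past the k-th sorted entry per row.
    topk_mask = torch.arange(vocab_size, device=device).expand(num_reqs, -1) >= top_k
    # Top-p: mask the tail once cumulative probability exceeds top_p,
    # always keeping at least the most likely token.
    probs = sorted_logits.softmax(dim=-1)
    topp_mask = (probs.cumsum(dim=-1) - probs) > top_p.unsqueeze(dim=1)
    filtered = sorted_logits.masked_fill(topk_mask | topp_mask, float("-inf"))
    sampled = torch.multinomial(filtered.softmax(dim=-1), num_samples=1)
    # Map positions in the sorted order back to vocabulary ids.
    return sorted_idx.gather(dim=1, index=sampled).squeeze(dim=1)

The scaled logits, the two sort outputs, the softmax and cumsum results, and the boolean masks are each fresh (num_reqs, vocab_size) allocations, which is why a profile run that stopped at compute_logits understated peak memory.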

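Why this matters for profiling at all: vLLM sizes the KV cache from whatever GPU memory remains after the profile run's peak, so allocations the profile misses are effectively double-booked at serving time. A hypothetical usage sketch — measure_profile_peak and its runner argument are invented here, though the torch.cuda calls are standard PyTorch APIs:

import torch

def measure_profile_peak(runner, device: torch.device) -> int:
    """Run one profiling pass and report its peak allocated bytes."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    runner.profile_run()  # with this commit, also exercises the sampler
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device)

With the dummy sampling pass included, the returned peak covers the sampler's worst-case intermediates, so a KV-cache budget derived from it leaves headroom for sampling at serving time.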