[PERF] Fixes (#221)
Comment out a per-token debug log in the engine step loop, and move the sampling-tensor preparation in tvm_model.py to after the model forward call so the copy-stream CPU->GPU transfer overlaps with GPU computation.
vvchernov authored Feb 26, 2024
1 parent 8ee6aaa commit 941320c
Showing 2 changed files with 25 additions and 25 deletions.
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/staging_engine.py
@@ -229,7 +229,7 @@ def step(self) -> InferenceStepResult:
 
             gen_seq = state.generation_sequences[seq_output.id.sequence_index]
             new_token_ids = seq_output.new_tokens
-            LOG.debug(f"New token ids: {new_token_ids}")
+            # LOG.debug(f"New token ids: {new_token_ids}")
             if new_token_ids:
                 delta, logprob_info = prepare_output(
                     gen_seq,
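The staging_engine.py change simply comments out a per-token debug log in the hot serving loop. A lighter-touch alternative (a sketch, not part of this commit) is lazy %-style formatting, which defers the string work unless a DEBUG-level record will actually be emitted:

import logging

LOG = logging.getLogger(__name__)

def report_new_tokens(new_token_ids):
    # An f-string is formatted eagerly, even when DEBUG logging is off:
    #   LOG.debug(f"New token ids: {new_token_ids}")
    # With %-style arguments, logging builds the message only when the
    # record is actually emitted, so the hot path stays cheap.
    LOG.debug("New token ids: %s", new_token_ids)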
48 changes: 24 additions & 24 deletions serve/mlc_serve/model/tvm_model.py
@@ -285,18 +285,6 @@ def generate_multi_query(
             # Use `vocab_size` as a padding
             past_decode_tokens.append([self.vocab_size, *request.queries.token_ids])
 
-        # Prepare sampling tensors in another stream to overlap
-        # CPU<->GPU data transfer with GPU computation in forward pass.
-        with torch.cuda.stream(self._copy_stream):
-            sampling_state = SamplingState.from_sampling_params(
-                sampling_params,
-                past_decode_tokens,
-                prompt_masks,
-                self.torch_dtype,
-                self.torch_dev,
-                self.vocab_size,
-            )
-
         (
             input_ids,
             positions,
@@ -335,6 +323,18 @@ def generate_multi_query(
             self.params,
         )
 
+        # Prepare sampling tensors in another stream to overlap
+        # CPU<->GPU data transfer with GPU computation in forward pass.
+        with torch.cuda.stream(self._copy_stream):
+            sampling_state = SamplingState.from_sampling_params(
+                sampling_params,
+                past_decode_tokens,
+                prompt_masks,
+                self.torch_dtype,
+                self.torch_dev,
+                self.vocab_size,
+            )
+
         if self.disco_session:
             logits, _ = out.debug_get_from_remote(0)
         else:
@@ -415,18 +415,6 @@ def generate(
             all_token_ids.append(request.token_ids)
             sampling_params.append(request.sampling_params)
 
-        # Prepare sampling tensors in another stream to overlap
-        # CPU<->GPU data transfer with GPU computation in forward pass.
-        with torch.cuda.stream(self._copy_stream):
-            sampling_state = SamplingState.from_sampling_params(
-                sampling_params,
-                past_decode_tokens,
-                prompt_masks,
-                self.torch_dtype,
-                self.torch_dev,
-                self.vocab_size,
-            )
-
         (
             input_ids,
             positions,
@@ -508,6 +496,18 @@ def generate(
             self.params,
         )
 
+        # Prepare sampling tensors in another stream to overlap
+        # CPU<->GPU data transfer with GPU computation in forward pass.
+        with torch.cuda.stream(self._copy_stream):
+            sampling_state = SamplingState.from_sampling_params(
+                sampling_params,
+                past_decode_tokens,
+                prompt_masks,
+                self.torch_dtype,
+                self.torch_dev,
+                self.vocab_size,
+            )
+
         if self.disco_session:
             logits, _ = out.debug_get_from_remote(0)
         else:
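The relocated block in tvm_model.py is the substance of the fix: the sampling tensors were previously staged on self._copy_stream before the forward pass was launched, so there was no GPU work for the transfer to overlap with. Issuing the forward pass first and only then enqueueing the copies on the side stream lets the CPU->GPU transfer proceed while the forward kernels are still running. A minimal sketch of the pattern, with illustrative names (only the side-stream context manager comes from the diff; real overlap also needs pinned host memory and non_blocking copies):

import torch

copy_stream = torch.cuda.Stream()  # stands in for self._copy_stream

def decode_step(model, input_ids, host_sampling_tensors):
    # 1. Enqueue the forward pass on the default stream. Kernel launches
    #    are asynchronous, so this returns while the GPU is still busy.
    logits = model(input_ids)

    # 2. Enqueue host->device copies on the side stream; they overlap
    #    with the forward kernels still executing on the default stream.
    with torch.cuda.stream(copy_stream):
        sampling_tensors = {
            name: t.pin_memory().to("cuda", non_blocking=True)
            for name, t in host_sampling_tensors.items()
        }

    # 3. Make the default stream wait for the side stream before the
    #    sampler reads the freshly copied tensors.
    torch.cuda.current_stream().wait_stream(copy_stream)
    return logits, sampling_tensors

In the actual code the copies happen inside SamplingState.from_sampling_params, which already ran under torch.cuda.stream(self._copy_stream); the reorder alone delivers the overlap.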
