diff --git a/serve/mlc_serve/engine/staging_engine.py b/serve/mlc_serve/engine/staging_engine.py
index 52eafd9831..a457ff8385 100644
--- a/serve/mlc_serve/engine/staging_engine.py
+++ b/serve/mlc_serve/engine/staging_engine.py
@@ -229,7 +229,7 @@ def step(self) -> InferenceStepResult:
                 gen_seq = state.generation_sequences[seq_output.id.sequence_index]
                 new_token_ids = seq_output.new_tokens
-                LOG.debug(f"New token ids: {new_token_ids}")
+                # LOG.debug(f"New token ids: {new_token_ids}")
                 if new_token_ids:
                     delta, logprob_info = prepare_output(
                         gen_seq,
diff --git a/serve/mlc_serve/model/tvm_model.py b/serve/mlc_serve/model/tvm_model.py
index 4b0d9ff228..df24309f06 100644
--- a/serve/mlc_serve/model/tvm_model.py
+++ b/serve/mlc_serve/model/tvm_model.py
@@ -285,18 +285,6 @@ def generate_multi_query(
             # Use `vocab_size` as a padding
             past_decode_tokens.append([self.vocab_size, *request.queries.token_ids])
 
-        # Prepare sampling tensors in another stream to overlap
-        # CPU<->GPU data transfer with GPU computation in forward pass.
-        with torch.cuda.stream(self._copy_stream):
-            sampling_state = SamplingState.from_sampling_params(
-                sampling_params,
-                past_decode_tokens,
-                prompt_masks,
-                self.torch_dtype,
-                self.torch_dev,
-                self.vocab_size,
-            )
-
         (
             input_ids,
             positions,
@@ -335,6 +323,18 @@ def generate_multi_query(
             self.params,
         )
 
+        # Prepare sampling tensors in another stream to overlap
+        # CPU<->GPU data transfer with GPU computation in forward pass.
+        with torch.cuda.stream(self._copy_stream):
+            sampling_state = SamplingState.from_sampling_params(
+                sampling_params,
+                past_decode_tokens,
+                prompt_masks,
+                self.torch_dtype,
+                self.torch_dev,
+                self.vocab_size,
+            )
+
         if self.disco_session:
             logits, _ = out.debug_get_from_remote(0)
         else:
@@ -415,18 +415,6 @@ def generate(
             all_token_ids.append(request.token_ids)
             sampling_params.append(request.sampling_params)
 
-        # Prepare sampling tensors in another stream to overlap
-        # CPU<->GPU data transfer with GPU computation in forward pass.
-        with torch.cuda.stream(self._copy_stream):
-            sampling_state = SamplingState.from_sampling_params(
-                sampling_params,
-                past_decode_tokens,
-                prompt_masks,
-                self.torch_dtype,
-                self.torch_dev,
-                self.vocab_size,
-            )
-
         (
             input_ids,
             positions,
@@ -508,6 +496,18 @@ def generate(
             self.params,
         )
 
+        # Prepare sampling tensors in another stream to overlap
+        # CPU<->GPU data transfer with GPU computation in forward pass.
+        with torch.cuda.stream(self._copy_stream):
+            sampling_state = SamplingState.from_sampling_params(
+                sampling_params,
+                past_decode_tokens,
+                prompt_masks,
+                self.torch_dtype,
+                self.torch_dev,
+                self.vocab_size,
+            )
+
         if self.disco_session:
             logits, _ = out.debug_get_from_remote(0)
         else:
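
Note on the tvm_model.py hunks: the SamplingState preparation is unchanged, only moved from before the forward-pass launch to after it. Issued first, it kept the CPU busy building sampling tensors before any GPU work had been enqueued; issued after, the tensor construction and the host-to-device copies on self._copy_stream proceed while the GPU is already executing the forward pass. Below is a minimal, self-contained sketch of this launch-first, copy-on-a-side-stream pattern, assuming a CUDA device; the names (run_step, temperatures) and the plain-torch sampling stand in for the mlc_serve types and are illustrative, not part of the patch.

import torch


def run_step(model, input_ids, temperatures):
    """Launch GPU compute first, then stage sampling tensors on a copy stream."""
    copy_stream = torch.cuda.Stream()

    # 1. Enqueue the forward pass on the current (compute) stream first, so
    #    the GPU starts working before we spend CPU time building tensors.
    logits = model(input_ids)

    # 2. While the forward kernels run, build the sampling tensors on the
    #    side stream. Pinned host memory plus non_blocking=True is what lets
    #    the host->device copy actually overlap with the compute stream.
    with torch.cuda.stream(copy_stream):
        host = torch.tensor(temperatures, dtype=torch.float32).pin_memory()
        temps = host.to(input_ids.device, non_blocking=True)

    # 3. Order the streams: the compute stream must not read `temps` until
    #    the copy stream has finished writing it.
    torch.cuda.current_stream().wait_stream(copy_stream)
    return torch.softmax(logits / temps.unsqueeze(-1), dim=-1)

The wait_stream call is the piece that makes this reordering safe: sampling on the compute stream only begins once the copy stream has delivered the staged tensors. The patched code is expected to perform the equivalent synchronization before sampling_state is consumed.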