From fd3eed9533a2227f4fb9ce27400be97efc89492f Mon Sep 17 00:00:00 2001
From: UranusSeven <109661872+UranusSeven@users.noreply.github.com>
Date: Fri, 13 Oct 2023 16:36:01 +0800
Subject: [PATCH] Fix non-stream

---
 xinference/model/llm/pytorch/spec.py             | 21 ++++++++++++-------
 .../model/llm/pytorch/spec_decoding_utils.py     |  4 ++++
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/xinference/model/llm/pytorch/spec.py b/xinference/model/llm/pytorch/spec.py
index 48fa00af83..f64bbe1b83 100644
--- a/xinference/model/llm/pytorch/spec.py
+++ b/xinference/model/llm/pytorch/spec.py
@@ -104,6 +104,18 @@ def load(self):
     def generate(
         self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
     ) -> Union[Completion, Iterator[CompletionChunk]]:
+        def generator_wrapper(
+            _prompt: str, _generate_config: PytorchGenerateConfig
+        ) -> Iterator[CompletionChunk]:
+            for _completion_chunk, _completion_usage in speculative_generate_stream(
+                draft_model=self._draft_model,
+                model=self._model,
+                tokenizer=self._tokenizer,
+                prompt=_prompt,
+                generate_config=_generate_config,
+            ):
+                yield _completion_chunk
+
         from .spec_decoding_utils import speculative_generate_stream
 
         generate_config = self._sanitize_generate_config(generate_config)
@@ -133,14 +145,7 @@ def generate(
             )
             return completion
         else:
-            for completion_chunk, completion_usage in speculative_generate_stream(
-                draft_model=self._draft_model,
-                model=self._model,
-                tokenizer=self._tokenizer,
-                prompt=prompt,
-                generate_config=generate_config,
-            ):
-                yield completion_chunk
+            return generator_wrapper(prompt, generate_config)
 
     def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
         raise NotImplementedError
diff --git a/xinference/model/llm/pytorch/spec_decoding_utils.py b/xinference/model/llm/pytorch/spec_decoding_utils.py
index 61d82a42c0..a4fc2cc51f 100644
--- a/xinference/model/llm/pytorch/spec_decoding_utils.py
+++ b/xinference/model/llm/pytorch/spec_decoding_utils.py
@@ -252,6 +252,10 @@ def speculative_generate_stream(
     prompt: str,
     generate_config: Dict[str, Any],
 ) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
+    logger.debug(
+        f"Enter speculative_generate_stream, prompt: {prompt}, generate_config: {generate_config}"
+    )
+
     # TODO: currently, repetition penalty leads to garbled outputs.
     if float(generate_config.get("repetition_penalty", 1.0)) != 1.0:
         raise ValueError(
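Note (not part of the patch): the non-stream path appears to have been broken because the old generate() body contained a bare `yield`, and any Python function whose body contains `yield` is compiled as a generator function; every call then returns a generator object, so the `return completion` branch never handed a Completion back to the caller. Moving the `yield` into the inner generator_wrapper keeps generate() a regular function that can return either a Completion or an iterator of chunks. Below is a minimal, self-contained sketch of that Python behavior, using hypothetical names rather than xinference code.

from typing import Iterator, Union


def broken_generate(stream: bool) -> Union[str, Iterator[str]]:
    # The body contains `yield`, so Python compiles this as a generator
    # function: every call returns a generator object, even when
    # stream=False, and the plain string below is never returned directly.
    if not stream:
        return "whole completion"
    for chunk in ("a", "b", "c"):
        yield chunk


def fixed_generate(stream: bool) -> Union[str, Iterator[str]]:
    # Same shape as the patch: the `yield` lives in an inner helper, so the
    # outer function stays a regular function and can return either value.
    def chunk_iterator() -> Iterator[str]:
        for chunk in ("a", "b", "c"):
            yield chunk

    if not stream:
        return "whole completion"
    return chunk_iterator()


print(type(broken_generate(stream=False)))  # <class 'generator'>, the bug
print(fixed_generate(stream=False))         # "whole completion", the fix
print(list(fixed_generate(stream=True)))    # ['a', 'b', 'c'], streaming still works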