Skip to content

Commit

Permalink
Fix non-stream
Browse files Browse the repository at this point in the history
  • Loading branch information
UranusSeven committed Oct 13, 2023
1 parent 16f950a commit fd3eed9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 8 deletions.
21 changes: 13 additions & 8 deletions xinference/model/llm/pytorch/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ def load(self):
def generate(
self, prompt: str, generate_config: Optional[PytorchGenerateConfig] = None
) -> Union[Completion, Iterator[CompletionChunk]]:
def generator_wrapper(
    _prompt: str, _generate_config: PytorchGenerateConfig
) -> Iterator[CompletionChunk]:
    """Stream completion chunks produced by speculative decoding.

    Being a generator function, nothing runs until the caller starts
    iterating. This lets the enclosing ``generate`` hand back an iterator
    for streaming requests while still using a plain ``return`` for the
    non-streaming path.

    Args:
        _prompt: Prompt text forwarded to ``speculative_generate_stream``.
        _generate_config: Generation options forwarded unchanged.

    Yields:
        The completion chunk of every ``(chunk, usage)`` pair emitted by
        ``speculative_generate_stream``; the usage half is discarded.
    """
    stream = speculative_generate_stream(
        draft_model=self._draft_model,
        model=self._model,
        tokenizer=self._tokenizer,
        prompt=_prompt,
        generate_config=_generate_config,
    )
    # Each item is a (chunk, usage) pair; only the chunk is surfaced.
    for _chunk, _usage in stream:
        yield _chunk

from .spec_decoding_utils import speculative_generate_stream

generate_config = self._sanitize_generate_config(generate_config)
Expand Down Expand Up @@ -133,14 +145,7 @@ def generate(
)
return completion
else:
for completion_chunk, completion_usage in speculative_generate_stream(
draft_model=self._draft_model,
model=self._model,
tokenizer=self._tokenizer,
prompt=prompt,
generate_config=generate_config,
):
yield completion_chunk
return generator_wrapper(prompt, generate_config)

def create_embedding(self, input: Union[str, List[str]]) -> Embedding:
raise NotImplementedError
4 changes: 4 additions & 0 deletions xinference/model/llm/pytorch/spec_decoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ def speculative_generate_stream(
prompt: str,
generate_config: Dict[str, Any],
) -> Iterator[Tuple[CompletionChunk, CompletionUsage]]:
logger.debug(
f"Enter speculative_generate_stream, prompt: {prompt}, generate_config: {generate_config}"
)

# TODO: currently, repetition penalty leads to garbled outputs.
if float(generate_config.get("repetition_penalty", 1.0)) != 1.0:
raise ValueError(
Expand Down

0 comments on commit fd3eed9

Please sign in to comment.