diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py
index b6db453b8..e5c741124 100644
--- a/llama_cpp/server/app.py
+++ b/llama_cpp/server/app.py
@@ -267,86 +267,83 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
+    async with contextlib.AsyncExitStack() as exit_stack:
+        llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+        if llama_proxy is None:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="Service is not available",
+            )
+        if isinstance(body.prompt, list):
+            assert len(body.prompt) <= 1
+            body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
+
+        llama = llama_proxy(
+            body.model
+            if request.url.path != "/v1/engines/copilot-codex/completions"
+            else "copilot-codex"
         )
-    if isinstance(body.prompt, list):
-        assert len(body.prompt) <= 1
-        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
-
-    llama = llama_proxy(
-        body.model
-        if request.url.path != "/v1/engines/copilot-codex/completions"
-        else "copilot-codex"
-    )
-
-    exclude = {
-        "n",
-        "best_of",
-        "logit_bias_type",
-        "user",
-        "min_tokens",
-    }
-    kwargs = body.model_dump(exclude=exclude)
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+        exclude = {
+            "n",
+            "best_of",
+            "logit_bias_type",
+            "user",
+            "min_tokens",
+        }
+        kwargs = body.model_dump(exclude=exclude)
+
+        if body.logit_bias is not None:
+            kwargs["logit_bias"] = (
+                _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+                if body.logit_bias_type == "tokens"
+                else body.logit_bias
+            )
+
+        if body.grammar is not None:
+            kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+        if body.min_tokens > 0:
+            _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+                [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+            )
+            if "logits_processor" not in kwargs:
+                kwargs["logits_processor"] = _min_tokens_logits_processor
+            else:
+                kwargs["logits_processor"].extend(_min_tokens_logits_processor)
 
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
+        try:
+            iterator_or_completion: Union[
+                llama_cpp.CreateCompletionResponse,
+                Iterator[llama_cpp.CreateCompletionStreamResponse],
+            ] = await run_in_threadpool(llama, **kwargs)
+        except Exception as err:
+            raise err
+
+        if isinstance(iterator_or_completion, Iterator):
+            # EAFP: It's easier to ask for forgiveness than permission
+            first_response = await run_in_threadpool(next, iterator_or_completion)
+
+            # If no exception was raised from first_response, we can assume that
+            # the iterator is valid and we can use it to stream the response.
+            def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
+                yield first_response
+                yield from iterator_or_completion
+
+            send_chan, recv_chan = anyio.create_memory_object_stream(10)
+            return EventSourceResponse(
+                recv_chan,
+                data_sender_callable=partial(  # type: ignore
+                    get_event_publisher,
+                    request=request,
+                    inner_send_chan=send_chan,
+                    iterator=iterator(),
+                ),
+                sep="\n",
+                ping_message_factory=_ping_message_factory,
+            )
         else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.CreateCompletionResponse,
-            Iterator[llama_cpp.CreateCompletionStreamResponse],
-        ] = await run_in_threadpool(llama, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
-            yield first_response
-            yield from iterator_or_completion
-
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
-        return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(  # type: ignore
-                get_event_publisher,
-                request=request,
-                inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
-            ),
-            sep="\n",
-            ping_message_factory=_ping_message_factory,
-        )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+            return iterator_or_completion
 
 
 @router.post(
@@ -474,74 +471,71 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
-    exclude = {
-        "n",
-        "logit_bias_type",
-        "user",
-        "min_tokens",
-    }
-    kwargs = body.model_dump(exclude=exclude)
-    llama = llama_proxy(body.model)
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+    async with contextlib.AsyncExitStack() as exit_stack:
+        llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+        if llama_proxy is None:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="Service is not available",
+            )
+        exclude = {
+            "n",
+            "logit_bias_type",
+            "user",
+            "min_tokens",
+        }
+        kwargs = body.model_dump(exclude=exclude)
+        llama = llama_proxy(body.model)
+        if body.logit_bias is not None:
+            kwargs["logit_bias"] = (
+                _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+                if body.logit_bias_type == "tokens"
+                else body.logit_bias
+            )
+
+        if body.grammar is not None:
+            kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+        if body.min_tokens > 0:
+            _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+                [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+            )
+            if "logits_processor" not in kwargs:
+                kwargs["logits_processor"] = _min_tokens_logits_processor
+            else:
+                kwargs["logits_processor"].extend(_min_tokens_logits_processor)
 
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
+        try:
+            iterator_or_completion: Union[
+                llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
+            ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+        except Exception as err:
+            raise err
+
+        if isinstance(iterator_or_completion, Iterator):
+            # EAFP: It's easier to ask for forgiveness than permission
+            first_response = await run_in_threadpool(next, iterator_or_completion)
+
+            # If no exception was raised from first_response, we can assume that
+            # the iterator is valid and we can use it to stream the response.
+            def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+                yield first_response
+                yield from iterator_or_completion
+
+            send_chan, recv_chan = anyio.create_memory_object_stream(10)
+            return EventSourceResponse(
+                recv_chan,
+                data_sender_callable=partial(  # type: ignore
+                    get_event_publisher,
+                    request=request,
+                    inner_send_chan=send_chan,
+                    iterator=iterator(),
+                ),
+                sep="\n",
+                ping_message_factory=_ping_message_factory,
+            )
         else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
-        ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
-            yield first_response
-            yield from iterator_or_completion
-
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
-        return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(  # type: ignore
-                get_event_publisher,
-                request=request,
-                inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
-            ),
-            sep="\n",
-            ping_message_factory=_ping_message_factory,
-        )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+            return iterator_or_completion
 
 
 @router.get(
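For context on the pattern both handlers now rely on: inside `async with contextlib.AsyncExitStack()`, anything registered via `enter_async_context()` (here, the `get_llama_proxy` dependency wrapped with `contextlib.asynccontextmanager`) is released automatically when the block exits, whether by `return` or by an exception, which is why the explicit `await exit_stack.aclose()` calls and the `on_complete=exit_stack.aclose` callback disappear. A minimal, self-contained sketch of that lifecycle (the `get_resource` dependency below is illustrative only, not part of llama-cpp-python):

```python
import asyncio
import contextlib


@contextlib.asynccontextmanager
async def get_resource():
    # Stand-in for a FastAPI-style dependency such as get_llama_proxy.
    print("acquire")
    try:
        yield "resource"
    finally:
        # Runs when the enclosing AsyncExitStack unwinds.
        print("release")


async def handler() -> str:
    async with contextlib.AsyncExitStack() as exit_stack:
        resource = await exit_stack.enter_async_context(get_resource())
        # Returning from inside the block still triggers cleanup: the stack
        # unwinds before control leaves the coroutine.
        return f"used {resource}"


if __name__ == "__main__":
    print(asyncio.run(handler()))  # prints: acquire, release, used resource
```

Note that the stack unwinds as soon as the handler returns, i.e. before a streaming response body has been consumed; the comment retained at the top of the second hunk (tiangolo/fastapi#11143) describes the dependency-lifetime issue this code path has been working around.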