Fix: Refactor AsyncExitStack usage to resolve lock handling errors #1862

Closed
278 changes: 136 additions & 142 deletions llama_cpp/server/app.py
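At a glance, both handlers change in the same way: a hand-managed contextlib.AsyncExitStack, which had to be closed with await exit_stack.aclose() on the error path and the non-streaming path (and, for the streaming path, deferred through an on_complete=exit_stack.aclose callback), is replaced by an async with contextlib.AsyncExitStack() block that closes the stack automatically. The sketch below only illustrates that shape change; acquire_dependency and do_work are placeholders, not code from this repository:

import contextlib

# Shape before this PR: the handler creates the stack itself and must
# close it explicitly on every path that leaves the function.
async def handler_before(acquire_dependency, do_work):
    exit_stack = contextlib.AsyncExitStack()
    resource = await exit_stack.enter_async_context(acquire_dependency())
    try:
        result = await do_work(resource)
    except Exception:
        await exit_stack.aclose()  # easy to miss on new error paths
        raise
    await exit_stack.aclose()      # and again on the normal path
    return result

# Shape after this PR: the `async with` block closes the stack for us,
# both on normal exit and when an exception escapes.
async def handler_after(acquire_dependency, do_work):
    async with contextlib.AsyncExitStack() as exit_stack:
        resource = await exit_stack.enter_async_context(acquire_dependency())
        return await do_work(resource)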
@@ -267,86 +267,83 @@ async def create_completion(
     request: Request,
     body: CreateCompletionRequest,
 ) -> llama_cpp.Completion:
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
-    if isinstance(body.prompt, list):
-        assert len(body.prompt) <= 1
-        body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
-
-    llama = llama_proxy(
-        body.model
-        if request.url.path != "/v1/engines/copilot-codex/completions"
-        else "copilot-codex"
-    )
-
-    exclude = {
-        "n",
-        "best_of",
-        "logit_bias_type",
-        "user",
-        "min_tokens",
-    }
-    kwargs = body.model_dump(exclude=exclude)
-
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.CreateCompletionResponse,
-            Iterator[llama_cpp.CreateCompletionStreamResponse],
-        ] = await run_in_threadpool(llama, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
-            yield first_response
-            yield from iterator_or_completion
-
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
-        return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(  # type: ignore
-                get_event_publisher,
-                request=request,
-                inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
-            ),
-            sep="\n",
-            ping_message_factory=_ping_message_factory,
-        )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+    async with contextlib.AsyncExitStack() as exit_stack:
+        llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+        if llama_proxy is None:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="Service is not available",
+            )
+        if isinstance(body.prompt, list):
+            assert len(body.prompt) <= 1
+            body.prompt = body.prompt[0] if len(body.prompt) > 0 else ""
+
+        llama = llama_proxy(
+            body.model
+            if request.url.path != "/v1/engines/copilot-codex/completions"
+            else "copilot-codex"
+        )
+
+        exclude = {
+            "n",
+            "best_of",
+            "logit_bias_type",
+            "user",
+            "min_tokens",
+        }
+        kwargs = body.model_dump(exclude=exclude)
+
+        if body.logit_bias is not None:
+            kwargs["logit_bias"] = (
+                _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+                if body.logit_bias_type == "tokens"
+                else body.logit_bias
+            )
+
+        if body.grammar is not None:
+            kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+        if body.min_tokens > 0:
+            _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+                [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+            )
+            if "logits_processor" not in kwargs:
+                kwargs["logits_processor"] = _min_tokens_logits_processor
+            else:
+                kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+
+        try:
+            iterator_or_completion: Union[
+                llama_cpp.CreateCompletionResponse,
+                Iterator[llama_cpp.CreateCompletionStreamResponse],
+            ] = await run_in_threadpool(llama, **kwargs)
+        except Exception as err:
+            raise err
+
+        if isinstance(iterator_or_completion, Iterator):
+            # EAFP: It's easier to ask for forgiveness than permission
+            first_response = await run_in_threadpool(next, iterator_or_completion)
+
+            # If no exception was raised from first_response, we can assume that
+            # the iterator is valid and we can use it to stream the response.
+            def iterator() -> Iterator[llama_cpp.CreateCompletionStreamResponse]:
+                yield first_response
+                yield from iterator_or_completion
+
+            send_chan, recv_chan = anyio.create_memory_object_stream(10)
+            return EventSourceResponse(
+                recv_chan,
+                data_sender_callable=partial(  # type: ignore
+                    get_event_publisher,
+                    request=request,
+                    inner_send_chan=send_chan,
+                    iterator=iterator(),
+                ),
+                sep="\n",
+                ping_message_factory=_ping_message_factory,
+            )
+        else:
+            return iterator_or_completion


 @router.post(
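Both hunks enter the llama proxy through contextlib.asynccontextmanager(get_llama_proxy)(). For readers unfamiliar with that idiom, here is a self-contained sketch of how a yield-style dependency becomes an async context manager that an AsyncExitStack can enter; get_thing is a made-up stand-in, not this project's get_llama_proxy:

import asyncio
import contextlib

async def get_thing():
    # A FastAPI-style "yield" dependency: setup, yield the value, teardown.
    print("setup")
    try:
        yield {"client": "fake"}
    finally:
        print("teardown")

async def main() -> None:
    async with contextlib.AsyncExitStack() as stack:
        # asynccontextmanager() turns the generator function into a factory of
        # async context managers; calling it produces one the stack can enter.
        thing = await stack.enter_async_context(
            contextlib.asynccontextmanager(get_thing)()
        )
        print("using", thing)
    # "teardown" has printed by the time the block exits.

asyncio.run(main())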
@@ -474,74 +471,71 @@ async def create_chat_completion(
     # where the dependency is cleaned up before a StreamingResponse
     # is complete.
     # https://github.com/tiangolo/fastapi/issues/11143
-    exit_stack = contextlib.AsyncExitStack()
-    llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
-    if llama_proxy is None:
-        raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Service is not available",
-        )
-    exclude = {
-        "n",
-        "logit_bias_type",
-        "user",
-        "min_tokens",
-    }
-    kwargs = body.model_dump(exclude=exclude)
-    llama = llama_proxy(body.model)
-    if body.logit_bias is not None:
-        kwargs["logit_bias"] = (
-            _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
-            if body.logit_bias_type == "tokens"
-            else body.logit_bias
-        )
-
-    if body.grammar is not None:
-        kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
-
-    if body.min_tokens > 0:
-        _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
-            [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
-        )
-        if "logits_processor" not in kwargs:
-            kwargs["logits_processor"] = _min_tokens_logits_processor
-        else:
-            kwargs["logits_processor"].extend(_min_tokens_logits_processor)
-
-    try:
-        iterator_or_completion: Union[
-            llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
-        ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
-    except Exception as err:
-        await exit_stack.aclose()
-        raise err
-
-    if isinstance(iterator_or_completion, Iterator):
-        # EAFP: It's easier to ask for forgiveness than permission
-        first_response = await run_in_threadpool(next, iterator_or_completion)
-
-        # If no exception was raised from first_response, we can assume that
-        # the iterator is valid and we can use it to stream the response.
-        def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
-            yield first_response
-            yield from iterator_or_completion
-
-        send_chan, recv_chan = anyio.create_memory_object_stream(10)
-        return EventSourceResponse(
-            recv_chan,
-            data_sender_callable=partial(  # type: ignore
-                get_event_publisher,
-                request=request,
-                inner_send_chan=send_chan,
-                iterator=iterator(),
-                on_complete=exit_stack.aclose,
-            ),
-            sep="\n",
-            ping_message_factory=_ping_message_factory,
-        )
-    else:
-        await exit_stack.aclose()
-        return iterator_or_completion
+    async with contextlib.AsyncExitStack() as exit_stack:
+        llama_proxy = await exit_stack.enter_async_context(contextlib.asynccontextmanager(get_llama_proxy)())
+        if llama_proxy is None:
+            raise HTTPException(
+                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+                detail="Service is not available",
+            )
+        exclude = {
+            "n",
+            "logit_bias_type",
+            "user",
+            "min_tokens",
+        }
+        kwargs = body.model_dump(exclude=exclude)
+        llama = llama_proxy(body.model)
+        if body.logit_bias is not None:
+            kwargs["logit_bias"] = (
+                _logit_bias_tokens_to_input_ids(llama, body.logit_bias)
+                if body.logit_bias_type == "tokens"
+                else body.logit_bias
+            )
+
+        if body.grammar is not None:
+            kwargs["grammar"] = llama_cpp.LlamaGrammar.from_string(body.grammar)
+
+        if body.min_tokens > 0:
+            _min_tokens_logits_processor = llama_cpp.LogitsProcessorList(
+                [llama_cpp.MinTokensLogitsProcessor(body.min_tokens, llama.token_eos())]
+            )
+            if "logits_processor" not in kwargs:
+                kwargs["logits_processor"] = _min_tokens_logits_processor
+            else:
+                kwargs["logits_processor"].extend(_min_tokens_logits_processor)
+
+        try:
+            iterator_or_completion: Union[
+                llama_cpp.ChatCompletion, Iterator[llama_cpp.ChatCompletionChunk]
+            ] = await run_in_threadpool(llama.create_chat_completion, **kwargs)
+        except Exception as err:
+            raise err
+
+        if isinstance(iterator_or_completion, Iterator):
+            # EAFP: It's easier to ask for forgiveness than permission
+            first_response = await run_in_threadpool(next, iterator_or_completion)
+
+            # If no exception was raised from first_response, we can assume that
+            # the iterator is valid and we can use it to stream the response.
+            def iterator() -> Iterator[llama_cpp.ChatCompletionChunk]:
+                yield first_response
+                yield from iterator_or_completion
+
+            send_chan, recv_chan = anyio.create_memory_object_stream(10)
+            return EventSourceResponse(
+                recv_chan,
+                data_sender_callable=partial(  # type: ignore
+                    get_event_publisher,
+                    request=request,
+                    inner_send_chan=send_chan,
+                    iterator=iterator(),
+                ),
+                sep="\n",
+                ping_message_factory=_ping_message_factory,
+            )
+        else:
+            return iterator_or_completion


 @router.get(
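The comment kept at the top of the second hunk points at the underlying constraint (FastAPI issue 11143): a dependency obtained through Depends can be torn down before a streaming response has finished sending, which is why these handlers manage the proxy's lifetime themselves. Purely as an illustration of that constraint, and with hypothetical names rather than the server's implementation, one way to tie cleanup to the end of a stream is to let the stream itself own the resource:

import contextlib
from typing import AsyncIterator

@contextlib.asynccontextmanager
async def get_resource() -> AsyncIterator[str]:
    # Stand-in for a dependency such as get_llama_proxy.
    print("acquired")
    try:
        yield "resource"
    finally:
        print("released")

async def stream_with_resource() -> AsyncIterator[str]:
    # The generator enters the context itself, so the resource is released
    # only once the consumer has drained (or closed) the stream.
    async with get_resource() as res:
        for chunk in ("a", "b", "c"):
            yield f"{res}: {chunk}"

Consuming this with "async for chunk in stream_with_resource():" prints "released" only after the last chunk has been yielded.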