Remove trailing Markdown code tags in completion suggestions #726

Merged: 3 commits, Apr 22, 2024
@@ -91,6 +91,8 @@ async def handle_stream_request(self, request: InlineCompletionRequest):
                     continue
                 else:
                     suggestion = self._post_process_suggestion(suggestion, request)
+            elif suggestion.rstrip().endswith("```"):
+                suggestion = self._post_process_suggestion(suggestion, request)
             self.write_message(
                 InlineCompletionStreamChunk(
                     type="stream",
@@ -151,4 +153,9 @@ def _post_process_suggestion(
                 if suggestion.startswith(request.prefix):
                     suggestion = suggestion[len(request.prefix) :]
                 break
+
+        # check if the suggestion ends with a closing markdown identifier and remove it
+        if suggestion.rstrip().endswith("```"):
+            suggestion = suggestion.rstrip()[:-3].rstrip()
+
         return suggestion
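For context, here is a minimal standalone sketch of the fence-stripping behaviour added above. The helper name `strip_trailing_code_fence` is hypothetical; it simply mirrors the logic now applied in `_post_process_suggestion`.

```python
def strip_trailing_code_fence(suggestion: str) -> str:
    """Drop a trailing Markdown code fence that some chat-only models append."""
    # Ignore trailing whitespace/newlines before checking for the closing fence.
    if suggestion.rstrip().endswith("```"):
        # Remove the three backticks and any whitespace left in front of them.
        suggestion = suggestion.rstrip()[:-3].rstrip()
    return suggestion


# Example: a model wraps the completion in a fenced block and leaves the
# closing fence at the end of the stream.
result = strip_trailing_code_fence("return a + b\n```\n")
assert result == "return a + b"
```

Only a fence at the very end of the suggestion is removed, so backticks inside the completion body (as in the `hello```world` test case below) are left untouched.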
packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py (44 additions, 5 deletions)
@@ -1,6 +1,7 @@
 import json
 from types import SimpleNamespace

+import pytest
 from jupyter_ai.completions.handlers.default import DefaultInlineCompletionHandler
 from jupyter_ai.completions.models import InlineCompletionRequest
 from jupyter_ai_magics import BaseProvider
@@ -17,7 +18,8 @@ class MockProvider(BaseProvider, FakeListLLM):
     models = ["model"]

     def __init__(self, **kwargs):
-        kwargs["responses"] = ["Test response"]
+        if not "responses" in kwargs:
+            kwargs["responses"] = ["Test response"]
         super().__init__(**kwargs)


@@ -34,7 +36,7 @@ def __init__(self):
             create_task=lambda x: self.tasks.append(x)
         )
         self.settings["model_parameters"] = {}
-        self.llm_params = {}
+        self.llm_params = {"model_id": "model"}
         self.create_llm_chain(MockProvider, {"model_id": "model"})

     def write_message(self, message: str) -> None:  # type: ignore
@@ -88,8 +90,45 @@ async def test_handle_request(inline_handler):
     assert suggestions[0].insertText == "Test response"


+@pytest.mark.parametrize(
+    "response,expected_suggestion",
+    [
+        ("```python\nTest python code\n```", "Test python code"),
+        ("```\ntest\n```\n \n", "test"),
+        ("```hello```world```", "hello```world"),
+    ],
+)
+async def test_handle_request_with_spurious_fragments(response, expected_suggestion):
+    inline_handler = MockCompletionHandler()
+    inline_handler.create_llm_chain(
+        MockProvider,
+        {
+            "model_id": "model",
+            "responses": [response],
+        },
+    )
+    dummy_request = InlineCompletionRequest(
+        number=1, prefix="", suffix="", mime="", stream=False
+    )
+
+    await inline_handler.handle_request(dummy_request)
+    # should write a single reply
+    assert len(inline_handler.messages) == 1
+    # reply should contain a single suggestion
+    suggestions = inline_handler.messages[0].list.items
+    assert len(suggestions) == 1
+    # the suggestion should include insert text from LLM without spurious fragments
+    assert suggestions[0].insertText == expected_suggestion
+
+
 async def test_handle_stream_request(inline_handler):
-    inline_handler.llm_chain = FakeListLLM(responses=["test"])
+    inline_handler.create_llm_chain(
+        MockProvider,
+        {
+            "model_id": "model",
+            "responses": ["test"],
+        },
+    )
     dummy_request = InlineCompletionRequest(
         number=1, prefix="", suffix="", mime="", stream=True
     )
@@ -106,11 +145,11 @@ async def test_handle_stream_request(inline_handler):
     # second reply should be a chunk containing the token
     second = inline_handler.messages[1]
     assert second.type == "stream"
-    assert second.response.insertText == "Test response"
+    assert second.response.insertText == "test"
     assert second.done == False

     # third reply should be a closing chunk
     third = inline_handler.messages[2]
     assert third.type == "stream"
-    assert third.response.insertText == "Test response"
+    assert third.response.insertText == "test"
     assert third.done == True
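As a rough local sanity check (not part of the PR; it assumes the repository's dev dependencies, pytest and its async plugin, are installed), the new and updated tests can be selected by name:

```python
# Hypothetical convenience snippet: run just the completion-handler tests
# touched in this PR from the repository root.
import pytest

pytest.main(
    [
        "packages/jupyter-ai/jupyter_ai/tests/completions/test_handlers.py",
        "-k",
        "spurious or stream",  # matches the new parametrized test and the stream test
    ]
)
```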