Improve Intermediate Steps in Formulating Chat Response #799

Merged: 4 commits, Jun 9, 2024
Changes from 3 commits
src/khoj/processor/content/org_mode/org_to_entries.py (2 changes: 1 addition, 1 deletion)
@@ -182,7 +182,7 @@ def convert_org_nodes_to_entries(
         # Children nodes do not need ancestors trail as root parent node will have it
         if not entry_heading:
             ancestors_trail = " / ".join(parsed_entry.ancestors) or Path(entry_to_file_map[parsed_entry])
-            heading = f"* Path: {ancestors_trail}\n{heading}" if heading else f"* Path: {ancestors_trail}."
+            heading = f"* {ancestors_trail}\n{heading}" if heading else f"* {ancestors_trail}."

         compiled = heading
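The effect is that compiled org entries now open with the bare ancestor trail rather than a "Path: "-prefixed line. A minimal sketch of the before/after heading format, using hypothetical values for the trail and the child heading:

```python
# Minimal sketch (not part of the PR): how the compiled heading changes,
# with hypothetical values for the ancestor trail and child heading.
ancestors_trail = "notes.org / Projects"  # hypothetical trail
heading = "** Weekly Review"              # hypothetical child heading

old_heading = f"* Path: {ancestors_trail}\n{heading}"  # format before this change
new_heading = f"* {ancestors_trail}\n{heading}"        # format after this change

print(old_heading)  # * Path: notes.org / Projects
                    # ** Weekly Review
print(new_heading)  # * notes.org / Projects
                    # ** Weekly Review
```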
src/khoj/routers/api_chat.py (4 changes: 1 addition, 3 deletions)
@@ -590,9 +590,7 @@ async def send_rate_limit_message(message: str):
     )

     if compiled_references:
-        headings = "\n- " + "\n- ".join(
-            set([" ".join(c.get("compiled", c).split("Path: ")[1:]).split("\n ")[0] for c in compiled_references])
-        )
+        headings = "\n- " + "\n- ".join(set([c.get("compiled", c).split("\n")[0] for c in compiled_references]))
         await send_status_update(f"**📜 Found Relevant Notes**: {headings}")

     online_results: Dict = dict()
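Since compiled entries no longer carry the "Path: " marker, the heading for each reference is now simply the first line of its compiled text. A minimal sketch of the new extraction, using hypothetical compiled references shaped like the dictionaries the endpoint receives:

```python
# Minimal sketch (not part of the PR): deriving the headings shown in the
# "Found Relevant Notes" status update from hypothetical compiled references.
compiled_references = [
    {"compiled": "* notes.org / Projects\n** Weekly Review\nbody text ..."},
    {"compiled": "* journal.org\n** 2024 Week 23\nbody text ..."},
]

# The first line of each compiled entry is now the ancestor trail itself,
# so no splitting on a "Path: " marker is required.
headings = "\n- " + "\n- ".join(
    set(c.get("compiled", c).split("\n")[0] for c in compiled_references)
)
print(f"**📜 Found Relevant Notes**: {headings}")
```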
src/khoj/routers/helpers.py (8 changes: 4 additions, 4 deletions)
@@ -287,16 +287,16 @@ async def aget_relevant_output_modes(query: str, conversation_history: dict, is_
         response = response.strip()

         if is_none_or_empty(response):
-            return ConversationCommand.Default
+            return ConversationCommand.Text

         if response in mode_options.keys():
             # Check whether the tool exists as a valid ConversationCommand
             return ConversationCommand(response)

-        return ConversationCommand.Default
-    except Exception as e:
+        return ConversationCommand.Text
+    except Exception:
         logger.error(f"Invalid response for determining relevant mode: {response}")
-        return ConversationCommand.Default
+        return ConversationCommand.Text


 async def infer_webpage_urls(q: str, conversation_history: dict, location_data: LocationData) -> List[str]:
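The mode classifier now falls back to the new Text mode instead of the generic Default command whenever the model's answer is empty or unrecognized, and the unused exception variable is dropped. A minimal, self-contained sketch of that fallback logic (the enum below is a simplified stand-in for khoj.utils.helpers.ConversationCommand, and the mode map is hypothetical):

```python
# Minimal sketch (not part of the PR): the fallback behavior this change
# introduces, using a simplified stand-in enum and a hypothetical mode map.
from enum import Enum


class ConversationCommand(str, Enum):  # simplified stand-in, not the real class
    Image = "image"
    Automation = "automation"
    Text = "text"


mode_options = {mode.value for mode in ConversationCommand}  # hypothetical


def pick_output_mode(response: str) -> ConversationCommand:
    response = (response or "").strip()
    if not response:
        return ConversationCommand.Text  # empty answer: fall back to text
    if response in mode_options:
        return ConversationCommand(response)  # valid mode name from the model
    return ConversationCommand.Text  # unrecognized answer: also fall back to text


assert pick_output_mode("image") is ConversationCommand.Image
assert pick_output_mode("gibberish") is ConversationCommand.Text
```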
src/khoj/utils/helpers.py (3 changes: 2 additions, 1 deletion)
@@ -304,6 +304,7 @@ class ConversationCommand(str, Enum):
     Online = "online"
     Webpage = "webpage"
     Image = "image"
+    Text = "text"
     Automation = "automation"
     AutomatedTask = "automated_task"

@@ -330,7 +331,7 @@ class ConversationCommand(str, Enum):
 mode_descriptions_for_llm = {
     ConversationCommand.Image: "Use this if the user is requesting an image or visual response to their query.",
     ConversationCommand.Automation: "Use this if the user is requesting a response at a scheduled date or time.",
-    ConversationCommand.Default: "Use this if the other response modes don't seem to fit the query.",
+    ConversationCommand.Text: "Use this if the other response modes don't seem to fit the query.",
 }


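Text becomes both a first-class ConversationCommand member and the catch-all entry in mode_descriptions_for_llm, the dictionary that feeds the mode-selection prompt. A minimal sketch of how such a prompt fragment might be assembled; the rendering is illustrative, not Khoj's actual prompt template:

```python
# Minimal sketch (not part of the PR): turning the mode descriptions into a
# prompt fragment for the mode-selection call; the formatting is illustrative.
mode_descriptions_for_llm = {
    "image": "Use this if the user is requesting an image or visual response to their query.",
    "automation": "Use this if the user is requesting a response at a scheduled date or time.",
    "text": "Use this if the other response modes don't seem to fit the query.",
}

mode_options_text = "\n".join(
    f"- {mode}: {description}" for mode, description in mode_descriptions_for_llm.items()
)
print(mode_options_text)
```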
tests/test_org_to_entries.py (16 changes: 8 additions, 8 deletions)
@@ -50,14 +50,14 @@ def test_entry_split_when_exceeds_max_tokens():
     data = {
         f"{tmp_path}": entry,
     }
-    expected_heading = f"* Path: {tmp_path}\n** Heading"
+    expected_heading = f"* {tmp_path}\n** Heading"

     # Act
     # Extract Entries from specified Org files
     entries = OrgToEntries.extract_org_entries(org_files=data)

     # Split each entry from specified Org files by max tokens
-    entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=6)
+    entries = TextToEntries.split_entries_by_max_tokens(entries, max_tokens=5)

     # Assert
     assert len(entries) == 2
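The expected heading loses its "Path: " prefix and the split budget drops from 6 to 5 tokens, presumably because the compiled heading is now one token shorter. A rough intuition under a naive whitespace tokenizer (the real splitter may count tokens differently):

```python
# Minimal sketch (not part of the PR): a rough intuition for the max_tokens
# change, assuming a naive whitespace split; the real splitter may differ.
old_heading = "* Path: /tmp/test.org\n** Heading"
new_heading = "* /tmp/test.org\n** Heading"

print(len(old_heading.split()))  # 5 whitespace-separated tokens
print(len(new_heading.split()))  # 4 tokens once the "Path:" prefix is gone
```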
@@ -139,7 +139,7 @@ def test_parse_org_entry_with_children_as_single_entry_if_small(tmp_path):
         f"{tmp_path}": entry,
     }
     first_expected_entry = f"""
-* Path: {tmp_path}
+* {tmp_path}
 ** Heading 1.
 body line 1

@@ -148,13 +148,13 @@

 """.lstrip()
     second_expected_entry = f"""
-* Path: {tmp_path}
+* {tmp_path}
 ** Heading 2.
 body line 2

 """.lstrip()
     third_expected_entry = f"""
-* Path: {tmp_path} / Heading 2
+* {tmp_path} / Heading 2
 ** Subheading 2.1.
 longer body line 2.1

@@ -192,7 +192,7 @@ def test_separate_sibling_org_entries_if_all_cannot_fit_in_token_limit(tmp_path)
         f"{tmp_path}": entry,
     }
     first_expected_entry = f"""
-* Path: {tmp_path}
+* {tmp_path}
 ** Heading 1.
 body line 1

@@ -201,7 +201,7 @@

 """.lstrip()
     second_expected_entry = f"""
-* Path: {tmp_path}
+* {tmp_path}
 ** Heading 2.
 body line 2

@@ -210,7 +210,7 @@

 """.lstrip()
     third_expected_entry = f"""
-* Path: {tmp_path}
+* {tmp_path}
 ** Heading 3.
 body line 3
