Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow research mode and other conversation commands in automations #1011

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/interface/web/app/automations/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ const suggestedAutomationsMetadata: AutomationsData[] = [
{
subject: "Weekly Newsletter",
query_to_run:
"Compile a message including: 1. A recap of news from last week 2. An at-home workout I can do before work 3. A quote to inspire me for the week ahead",
"/research Compile a message including: 1. A recap of news from last week 2. An at-home workout I can do before work 3. A quote to inspire me for the week ahead",
schedule: "9AM every Monday",
next: "Next run at 9AM on Monday",
crontime: "0 9 * * 1",
Expand All @@ -185,7 +185,7 @@ const suggestedAutomationsMetadata: AutomationsData[] = [
{
subject: "Front Page of Hacker News",
query_to_run:
"Summarize the top 5 posts from https://news.ycombinator.com/best and share them with me, including links",
"/research Summarize the top 5 posts from https://news.ycombinator.com/best and share them with me, including links",
schedule: "9PM on every Wednesday",
next: "Next run at 9PM on Wednesday",
crontime: "0 21 * * 3",
Expand All @@ -195,7 +195,7 @@ const suggestedAutomationsMetadata: AutomationsData[] = [
{
subject: "Market Summary",
query_to_run:
"Get the market summary for today and share it with me. Focus on tech stocks and the S&P 500.",
"/research Get the market summary for today and share it with me. Focus on tech stocks and the S&P 500.",
schedule: "9AM on every weekday",
next: "Next run at 9AM on Monday",
crontime: "0 9 * * *",
Expand All @@ -214,7 +214,7 @@ const suggestedAutomationsMetadata: AutomationsData[] = [
{
subject: "Round-up of research papers about AI in healthcare",
query_to_run:
"Summarize the top 3 research papers about AI in healthcare that were published in the last week. Include links to the full papers.",
"/research Summarize the top 3 research papers about AI in healthcare that were published in the last week. Include links to the full papers.",
schedule: "9AM every Friday",
next: "Next run at 9AM on Friday",
crontime: "0 9 * * 5",
Expand Down
19 changes: 13 additions & 6 deletions src/khoj/routers/api_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,7 +724,16 @@ def collect_telemetry():
yield result
return

conversation_commands = [get_conversation_command(query=q, any_references=True)]
# Automated tasks are handled before to allow mixing them with other conversation commands
cmds_to_rate_limit = []
is_automated_task = False
if q.startswith("/automated_task"):
is_automated_task = True
q = q.replace("/automated_task", "").lstrip()
cmds_to_rate_limit += [ConversationCommand.AutomatedTask]

# Extract conversation command from query
conversation_commands = [get_conversation_command(query=q)]

conversation = await ConversationAdapters.aget_conversation_by_user(
user,
Expand Down Expand Up @@ -757,11 +766,8 @@ def collect_telemetry():
location = None
if city or region or country or country_code:
location = LocationData(city=city, region=region, country=country, country_code=country_code)

user_message_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

meta_log = conversation.conversation_log
is_automated_task = conversation_commands == [ConversationCommand.AutomatedTask]

researched_results = ""
online_results: Dict = dict()
Expand All @@ -778,7 +784,7 @@ def collect_telemetry():
generated_excalidraw_diagram: str = None
program_execution_context: List[str] = []

if conversation_commands == [ConversationCommand.Default] or is_automated_task:
if conversation_commands == [ConversationCommand.Default]:
chosen_io = await aget_data_sources_and_output_format(
q,
meta_log,
Expand All @@ -799,7 +805,8 @@ def collect_telemetry():
async for result in send_event(ChatEvent.STATUS, f"**Selected Tools:** {conversation_commands_str}"):
yield result

for cmd in conversation_commands:
cmds_to_rate_limit += conversation_commands
for cmd in cmds_to_rate_limit:
try:
await conversation_command_rate_limiter.update_and_check_if_valid(request, cmd)
q = q.replace(f"/{cmd.value}", "").strip()
Expand Down
58 changes: 26 additions & 32 deletions src/khoj/routers/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def get_next_url(request: Request) -> str:
return urljoin(str(request.base_url).rstrip("/"), next_path)


def get_conversation_command(query: str, any_references: bool = False) -> ConversationCommand:
def get_conversation_command(query: str) -> ConversationCommand:
if query.startswith("/notes"):
return ConversationCommand.Notes
elif query.startswith("/help"):
Expand All @@ -254,9 +254,6 @@ def get_conversation_command(query: str, any_references: bool = False) -> Conver
return ConversationCommand.Code
elif query.startswith("/research"):
return ConversationCommand.Research
# If no relevant notes found for the given query
elif not any_references:
return ConversationCommand.General
else:
return ConversationCommand.Default

Expand Down Expand Up @@ -408,42 +405,39 @@ async def aget_data_sources_and_output_format(
response = clean_json(response)
response = json.loads(response)

selected_sources = [q.strip() for q in response.get("source", []) if q.strip()]
selected_output = response.get("output", "text").strip() # Default to text output
chosen_sources = [s.strip() for s in response.get("source", []) if s.strip()]
chosen_output = response.get("output", "text").strip() # Default to text output

if not isinstance(selected_sources, list) or not selected_sources or len(selected_sources) == 0:
if is_none_or_empty(chosen_sources) or not isinstance(chosen_sources, list):
raise ValueError(
f"Invalid response for determining relevant tools: {selected_sources}. Raw Response: {response}"
f"Invalid response for determining relevant tools: {chosen_sources}. Raw Response: {response}"
)

result: Dict = {"sources": [], "output": None if not is_task else ConversationCommand.AutomatedTask}
for selected_source in selected_sources:
# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
if (
selected_source in source_options.keys()
and isinstance(result["sources"], list)
and (len(agent_sources) == 0 or selected_source in agent_sources)
):
# Check whether the tool exists as a valid ConversationCommand
result["sources"].append(ConversationCommand(selected_source))

# Add a double check to verify it's in the agent list, because the LLM sometimes gets confused by the tool options.
if selected_output in output_options.keys() and (len(agent_outputs) == 0 or selected_output in agent_outputs):
# Check whether the tool exists as a valid ConversationCommand
result["output"] = ConversationCommand(selected_output)

if is_none_or_empty(result):
output_mode = ConversationCommand.Text
# Verify selected output mode is enabled for the agent, as the LLM can sometimes get confused by the tool options.
if chosen_output in output_options.keys() and (len(agent_outputs) == 0 or chosen_output in agent_outputs):
# Ensure that the chosen output mode exists as a valid ConversationCommand
output_mode = ConversationCommand(chosen_output)

data_sources = []
# Verify selected data sources are enabled for the agent, as the LLM can sometimes get confused by the tool options.
for chosen_source in chosen_sources:
# Ensure that the chosen data source exists as a valid ConversationCommand
if chosen_source in source_options.keys() and (len(agent_sources) == 0 or chosen_source in agent_sources):
data_sources.append(ConversationCommand(chosen_source))

# Fallback to default sources if the inferred data sources are unset or invalid
if is_none_or_empty(data_sources):
if len(agent_sources) == 0:
result = {"sources": [ConversationCommand.Default], "output": ConversationCommand.Text}
data_sources = [ConversationCommand.Default]
else:
result = {"sources": [ConversationCommand.General], "output": ConversationCommand.Text}
data_sources = [ConversationCommand.General]
except Exception as e:
logger.error(f"Invalid response for determining relevant tools: {response}. Error: {e}", exc_info=True)
sources = agent_sources if len(agent_sources) > 0 else [ConversationCommand.Default]
output = agent_outputs[0] if len(agent_outputs) > 0 else ConversationCommand.Text
result = {"sources": sources, "output": output}
data_sources = agent_sources if len(agent_sources) > 0 else [ConversationCommand.Default]
output_mode = agent_outputs[0] if len(agent_outputs) > 0 else ConversationCommand.Text

return result
return {"sources": data_sources, "output": output_mode}


async def infer_webpage_urls(
Expand Down Expand Up @@ -1686,7 +1680,7 @@ def scheduled_chat(
last_run_time = datetime.strptime(last_run_time, "%Y-%m-%d %I:%M %p %Z").replace(tzinfo=timezone.utc)

# If the last run time was within the last 6 hours, don't run it again. This helps avoid multithreading issues and rate limits.
if (datetime.now(timezone.utc) - last_run_time).total_seconds() < 21600:
if (datetime.now(timezone.utc) - last_run_time).total_seconds() < 6 * 60 * 60:
logger.info(f"Skipping scheduled chat {job_id} as the next run time is in the future.")
return

Expand Down
Loading