fix(product-assistant): streaming on ASGI #26359

Closed · wants to merge 9 commits
324 changes: 104 additions & 220 deletions ee/hogai/assistant.py
@@ -1,37 +1,36 @@
from collections.abc import Generator, Hashable, Iterator
from typing import Any, Literal, Optional, TypedDict, TypeGuard, Union, cast
import asyncio
from collections.abc import AsyncGenerator, Generator
from typing import Any, Literal, Optional, TypedDict, TypeGuard, Union

from asgiref.sync import sync_to_async
from langchain_core.messages import AIMessageChunk
from langfuse.callback import CallbackHandler
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.graph.state import CompiledStateGraph
from pydantic import BaseModel
from sentry_sdk import capture_exception

from ee import settings
from ee.hogai.funnels.nodes import (
FunnelGeneratorNode,
FunnelGeneratorToolsNode,
FunnelPlannerNode,
FunnelPlannerToolsNode,
)
from ee.hogai.router.nodes import RouterNode
from ee.hogai.graph import AssistantGraph
from ee.hogai.schema_generator.nodes import SchemaGeneratorNode
from ee.hogai.summarizer.nodes import SummarizerNode
from ee.hogai.trends.nodes import (
TrendsGeneratorNode,
TrendsGeneratorToolsNode,
TrendsPlannerNode,
TrendsPlannerToolsNode,
)
from ee.hogai.utils import AssistantNodeName, AssistantState, Conversation
from posthog.models.team.team import Team
from posthog.event_usage import report_user_action
from posthog.models import Team, User
from posthog.schema import (
AssistantEventType,
AssistantGenerationStatusEvent,
AssistantGenerationStatusType,
AssistantMessage,
FailureMessage,
HumanMessage,
VisualizationMessage,
)
from posthog.settings import SERVER_GATEWAY_INTERFACE

if settings.LANGFUSE_PUBLIC_KEY:
langfuse_handler = CallbackHandler(
@@ -74,228 +73,113 @@
}


class AssistantGraph:
_team: Team
_graph: StateGraph

def __init__(self, team: Team):
self._team = team
self._graph = StateGraph(AssistantState)
self._has_start_node = False

def add_edge(self, from_node: AssistantNodeName, to_node: AssistantNodeName):
if from_node == AssistantNodeName.START:
self._has_start_node = True
self._graph.add_edge(from_node, to_node)
return self

def compile(self):
if not self._has_start_node:
raise ValueError("Start node not added to the graph")
return self._graph.compile()

def add_start(self):
return self.add_edge(AssistantNodeName.START, AssistantNodeName.ROUTER)

def add_router(
self,
path_map: Optional[dict[Hashable, AssistantNodeName]] = None,
):
builder = self._graph
path_map = path_map or {
"trends": AssistantNodeName.TRENDS_PLANNER,
"funnel": AssistantNodeName.FUNNEL_PLANNER,
}
router_node = RouterNode(self._team)
builder.add_node(AssistantNodeName.ROUTER, router_node.run)
builder.add_conditional_edges(
AssistantNodeName.ROUTER,
router_node.router,
path_map=cast(dict[Hashable, str], path_map),
)
return self

def add_trends_planner(self, next_node: AssistantNodeName = AssistantNodeName.TRENDS_GENERATOR):
builder = self._graph

create_trends_plan_node = TrendsPlannerNode(self._team)
builder.add_node(AssistantNodeName.TRENDS_PLANNER, create_trends_plan_node.run)
builder.add_conditional_edges(
AssistantNodeName.TRENDS_PLANNER,
create_trends_plan_node.router,
path_map={
"tools": AssistantNodeName.TRENDS_PLANNER_TOOLS,
},
)

create_trends_plan_tools_node = TrendsPlannerToolsNode(self._team)
builder.add_node(AssistantNodeName.TRENDS_PLANNER_TOOLS, create_trends_plan_tools_node.run)
builder.add_conditional_edges(
AssistantNodeName.TRENDS_PLANNER_TOOLS,
create_trends_plan_tools_node.router,
path_map={
"continue": AssistantNodeName.TRENDS_PLANNER,
"plan_found": next_node,
},
)

return self

def add_trends_generator(self, next_node: AssistantNodeName = AssistantNodeName.SUMMARIZER):
builder = self._graph

trends_generator = TrendsGeneratorNode(self._team)
builder.add_node(AssistantNodeName.TRENDS_GENERATOR, trends_generator.run)

trends_generator_tools = TrendsGeneratorToolsNode(self._team)
builder.add_node(AssistantNodeName.TRENDS_GENERATOR_TOOLS, trends_generator_tools.run)

builder.add_edge(AssistantNodeName.TRENDS_GENERATOR_TOOLS, AssistantNodeName.TRENDS_GENERATOR)
builder.add_conditional_edges(
AssistantNodeName.TRENDS_GENERATOR,
trends_generator.router,
path_map={
"tools": AssistantNodeName.TRENDS_GENERATOR_TOOLS,
"next": next_node,
},
)

return self

def add_funnel_planner(self, next_node: AssistantNodeName = AssistantNodeName.FUNNEL_GENERATOR):
builder = self._graph

funnel_planner = FunnelPlannerNode(self._team)
builder.add_node(AssistantNodeName.FUNNEL_PLANNER, funnel_planner.run)
builder.add_conditional_edges(
AssistantNodeName.FUNNEL_PLANNER,
funnel_planner.router,
path_map={
"tools": AssistantNodeName.FUNNEL_PLANNER_TOOLS,
},
)

funnel_planner_tools = FunnelPlannerToolsNode(self._team)
builder.add_node(AssistantNodeName.FUNNEL_PLANNER_TOOLS, funnel_planner_tools.run)
builder.add_conditional_edges(
AssistantNodeName.FUNNEL_PLANNER_TOOLS,
funnel_planner_tools.router,
path_map={
"continue": AssistantNodeName.FUNNEL_PLANNER,
"plan_found": next_node,
},
)

return self

def add_funnel_generator(self, next_node: AssistantNodeName = AssistantNodeName.SUMMARIZER):
builder = self._graph

funnel_generator = FunnelGeneratorNode(self._team)
builder.add_node(AssistantNodeName.FUNNEL_GENERATOR, funnel_generator.run)

funnel_generator_tools = FunnelGeneratorToolsNode(self._team)
builder.add_node(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, funnel_generator_tools.run)

builder.add_edge(AssistantNodeName.FUNNEL_GENERATOR_TOOLS, AssistantNodeName.FUNNEL_GENERATOR)
builder.add_conditional_edges(
AssistantNodeName.FUNNEL_GENERATOR,
funnel_generator.router,
path_map={
"tools": AssistantNodeName.FUNNEL_GENERATOR_TOOLS,
"next": next_node,
},
)

return self

def add_summarizer(self, next_node: AssistantNodeName = AssistantNodeName.END):
builder = self._graph
summarizer_node = SummarizerNode(self._team)
builder.add_node(AssistantNodeName.SUMMARIZER, summarizer_node.run)
builder.add_edge(AssistantNodeName.SUMMARIZER, next_node)
return self

def compile_full_graph(self):
return (
self.add_start()
.add_router()
.add_trends_planner()
.add_trends_generator()
.add_funnel_planner()
.add_funnel_generator()
.add_summarizer()
.compile()
)
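
This removed builder now lives in ee.hogai.graph, per this diff's imports. Each add_* method returns self for chaining, and compile() guards against a graph with no start edge. A hedged usage sketch, equivalent to compile_full_graph above; "team" is assumed to be a posthog Team instance:

# Hedged usage sketch of the chained builder above (the class now lives in
# ee.hogai.graph, per this diff's imports). compile() raises ValueError
# unless add_start() ran first, i.e. unless an edge from
# AssistantNodeName.START was added.
from ee.hogai.graph import AssistantGraph

graph = (
    AssistantGraph(team)
    .add_start()
    .add_router()
    .add_trends_planner()
    .add_trends_generator()
    .add_funnel_planner()
    .add_funnel_generator()
    .add_summarizer()
    .compile()
)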


class Assistant:
_team: Team
_graph: CompiledStateGraph
_chunks: AIMessageChunk

def __init__(self, team: Team):
def __init__(self, team: Team, conversation: Conversation, user: Optional[User] = None):
self._team = team
self._conversation = conversation
self._user = user
self._graph = AssistantGraph(team).compile_full_graph()
self._chunks = AIMessageChunk(content="")
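
_chunks is kept as instance state because streamed token deltas arrive across many updates; LangChain chunk messages merge with "+". A minimal runnable illustration:

# Minimal illustration of why _chunks is an AIMessageChunk: chunk messages
# concatenate with "+", merging text content (and tool-call deltas).
from langchain_core.messages import AIMessageChunk

buf = AIMessageChunk(content="")
for delta in (AIMessageChunk(content="Hel"), AIMessageChunk(content="lo")):
    buf = buf + delta
print(buf.content)  # -> "Hello"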

def stream(self, conversation: Conversation) -> Generator[BaseModel, None, None]:
callbacks = [langfuse_handler] if langfuse_handler else []
messages = [message.root for message in conversation.messages]

chunks = AIMessageChunk(content="")
state: AssistantState = {"messages": messages, "intermediate_steps": None, "plan": None}
def stream(self) -> Generator[str, None, None] | AsyncGenerator[str, None]:
if SERVER_GATEWAY_INTERFACE == "ASGI":
return self._astream()

Check failure on line 90 in ee/hogai/assistant.py (GitHub Actions / Python code quality checks): Incompatible return value type (got "AsyncGenerator[BaseModel, None]", expected "Generator[str, None, None] | AsyncGenerator[str, None]")
return self._stream()
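
One plausible way to satisfy the mypy failure above, assuming the intent is for both branches to yield the serialized SSE strings (a sketch, not the PR's final change):

# Sketch: annotate _astream as yielding str, matching what it actually
# yields via _serialize_message, so stream() type-checks as declared.
# Bodies are stand-ins; SERVER_GATEWAY_INTERFACE stands in for the
# posthog.settings value.
from collections.abc import AsyncGenerator, Generator

SERVER_GATEWAY_INTERFACE = "ASGI"

class StreamDispatchSketch:
    def stream(self) -> Generator[str, None, None] | AsyncGenerator[str, None]:
        if SERVER_GATEWAY_INTERFACE == "ASGI":
            return self._astream()
        return self._stream()

    async def _astream(self) -> AsyncGenerator[str, None]:  # str, not BaseModel
        yield 'event: status\ndata: {"type":"ack"}\n\n'

    def _stream(self) -> Generator[str, None, None]:
        yield 'event: status\ndata: {"type":"ack"}\n\n'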

generator: Iterator[Any] = self._graph.stream(
state,
config={"recursion_limit": 24, "callbacks": callbacks},
async def _astream(self) -> AsyncGenerator[BaseModel, None]:
generator = self._graph.astream(
self._initial_state,
config=self._config,

Check failure on line 96 in ee/hogai/assistant.py (GitHub Actions / Python code quality checks): Argument "config" to "astream" of "Pregel" has incompatible type "dict[str, Any]"; expected "RunnableConfig | None"
stream_mode=["messages", "values", "updates"],
)
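
With a list passed as stream_mode, LangGraph yields (mode, payload) tuples; that is what the is_value_update and is_message_update guards used below discriminate on. Illustrative shapes, per LangGraph's multi-mode streaming (exact contents depend on the graph):

# Stand-in payload shapes for stream_mode=["messages", "values", "updates"]:
from langchain_core.messages import AIMessageChunk

values_update = ("values", {"messages": [], "plan": None})           # full state snapshot
node_update = ("updates", {"router": {"messages": []}})              # per-node state delta
message_update = ("messages", (AIMessageChunk(content="…"),
                               {"langgraph_node": "summarizer"}))    # token stream + metadata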

chunks = AIMessageChunk(content="")

# Send a chunk to establish the connection, avoiding the worker's timeout.
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK)
yield self._serialize_message(AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.ACK))

try:
for update in generator:
if is_state_update(update):
_, new_state = update
state = new_state

elif is_value_update(update):
_, state_update = update

if (
AssistantNodeName.ROUTER in state_update
and "messages" in state_update[AssistantNodeName.ROUTER]
):
yield state_update[AssistantNodeName.ROUTER]["messages"][0]
elif intersected_nodes := state_update.keys() & VISUALIZATION_NODES.keys():
# Reset chunks when schema validation fails.
chunks = AIMessageChunk(content="")

node_name = intersected_nodes.pop()
if "messages" in state_update[node_name]:
yield state_update[node_name]["messages"][0]
elif state_update[node_name].get("intermediate_steps", []):
yield AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)
elif AssistantNodeName.SUMMARIZER in state_update:
chunks = AIMessageChunk(content="")
yield state_update[AssistantNodeName.SUMMARIZER]["messages"][0]
elif is_message_update(update):
langchain_message, langgraph_state = update[1]
if isinstance(langchain_message, AIMessageChunk):
if langgraph_state["langgraph_node"] in VISUALIZATION_NODES.keys():
chunks += langchain_message # type: ignore
parsed_message = VISUALIZATION_NODES[langgraph_state["langgraph_node"]].parse_output(
chunks.tool_calls[0]["args"]
)
if parsed_message:
yield VisualizationMessage(
reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
)
elif langgraph_state["langgraph_node"] == AssistantNodeName.SUMMARIZER:
chunks += langchain_message # type: ignore
yield AssistantMessage(content=chunks.content)
last_message = None
async for update in generator:
message = self._process_update(update)
if message is not None:
last_message = message
for serialized_message in self._serialize_message(message):
yield serialized_message
await self._report_user_action(last_message)

Check failure on line 111 in ee/hogai/assistant.py (GitHub Actions / Python code quality checks): Argument 1 to "_report_user_action" of "Assistant" has incompatible type "BaseModel | None"; expected "BaseModel"
except Exception as e:
capture_exception(e)
yield FailureMessage() # This is an unhandled error, so we just stop further generation at this point
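
For the "BaseModel | None" failure flagged above, one plausible fix is to guard the final report so mypy can narrow the Optional. A runnable sketch of the pattern (names are stand-ins):

# Sketch of the Optional narrowing mypy wants before _report_user_action:
import asyncio
from typing import Optional

async def demo() -> None:
    last_message: Optional[str] = None
    async def source():
        yield "final message"
    async for update in source():
        last_message = update
    if last_message is not None:  # narrows Optional[str] to str
        print(f"report: {last_message}")

asyncio.run(demo())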

def _stream(self):
iterator = self._astream()
with asyncio.Runner() as runner:
try:
while True:
result = runner.run(anext(iterator))

Check failure on line 121 in ee/hogai/assistant.py (GitHub Actions / Python code quality checks): Need type annotation for "result"
Check failure on line 121 in ee/hogai/assistant.py (GitHub Actions / Python code quality checks): Argument 1 to "run" of "Runner" has incompatible type "Awaitable[BaseModel]"; expected "Coroutine[Any, Any, Never]"
yield result
except StopAsyncIteration:
pass
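
The two failures flagged inside _stream come from anext() being typed as returning Awaitable while Runner.run expects a coroutine. A hedged sketch of a typed bridge over an async generator, assuming Python 3.11+ for asyncio.Runner:

# Sketch: calling __anext__() directly on a native async generator returns a
# coroutine object, which matches Runner.run()'s expected argument type, and
# an explicit annotation on the result fixes the inference failure.
import asyncio
from collections.abc import AsyncGenerator, Generator

def bridge(agen: AsyncGenerator[str, None]) -> Generator[str, None, None]:
    with asyncio.Runner() as runner:
        while True:
            try:
                result: str = runner.run(agen.__anext__())
            except StopAsyncIteration:
                return
            yield result

async def numbers() -> AsyncGenerator[str, None]:
    for i in range(3):
        yield str(i)

print(list(bridge(numbers())))  # -> ['0', '1', '2']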

@property
def _initial_state(self) -> AssistantState:
messages = [message.root for message in self._conversation.messages]
return {"messages": messages, "intermediate_steps": None, "plan": None}

@property
def _config(self) -> dict[str, Any]:
callbacks = [langfuse_handler] if langfuse_handler else []
return {"recursion_limit": 24, "callbacks": callbacks}

def _process_update(self, update: Any) -> BaseModel | None:

Check failure on line 136 in ee/hogai/assistant.py (GitHub Actions / Python code quality checks): Missing return statement
if is_value_update(update):
_, state_update = update

if AssistantNodeName.ROUTER in state_update and "messages" in state_update[AssistantNodeName.ROUTER]:
return state_update[AssistantNodeName.ROUTER]["messages"][0]
elif intersected_nodes := state_update.keys() & VISUALIZATION_NODES.keys():
# Reset chunks when schema validation fails.
self._chunks = AIMessageChunk(content="")

node_name = intersected_nodes.pop()
if "messages" in state_update[node_name]:
return state_update[node_name]["messages"][0]
elif state_update[node_name].get("intermediate_steps", []):
return AssistantGenerationStatusEvent(type=AssistantGenerationStatusType.GENERATION_ERROR)
elif AssistantNodeName.SUMMARIZER in state_update:
self._chunks = AIMessageChunk(content="")
return state_update[AssistantNodeName.SUMMARIZER]["messages"][0]
elif is_message_update(update):
langchain_message, langgraph_state = update[1]
if isinstance(langchain_message, AIMessageChunk):
if langgraph_state["langgraph_node"] in VISUALIZATION_NODES.keys():
self._chunks += langchain_message # type: ignore
parsed_message = VISUALIZATION_NODES[langgraph_state["langgraph_node"]].parse_output(
self._chunks.tool_calls[0]["args"]
)
if parsed_message:
return VisualizationMessage(
reasoning_steps=parsed_message.reasoning_steps, answer=parsed_message.answer
)
elif langgraph_state["langgraph_node"] == AssistantNodeName.SUMMARIZER:
self._chunks += langchain_message # type: ignore
return AssistantMessage(content=self._chunks.content)
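
The "Missing return statement" failure on _process_update is mypy noting that the branch chain can fall through without returning; under the default warn_no_return setting, a function annotated "-> BaseModel | None" needs an explicit return on every path. A minimal sketch:

# Sketch: an explicit fallthrough satisfies mypy's missing-return check.
from pydantic import BaseModel

def process(update: object) -> BaseModel | None:
    if isinstance(update, BaseModel):
        return update
    return None  # updates matching no branch produce no message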

def _serialize_message(self, message: BaseModel):
if isinstance(message, AssistantGenerationStatusEvent):
yield f"event: {AssistantEventType.STATUS}\n"
else:
yield f"event: {AssistantEventType.MESSAGE}\n"
yield f"data: {message.model_dump_json(exclude_none=True)}\n\n"

@sync_to_async
def _report_user_action(self, last_message: BaseModel):
human_message = self._conversation.messages[-1].root
if isinstance(human_message, HumanMessage) and self._user:
report_user_action(
self._user, # type: ignore

Check failure on line 182 in ee/hogai/assistant.py (GitHub Actions / Python code quality checks): Unused "type: ignore" comment
"chat with ai",
{"prompt": human_message.content, "response": last_message},
)
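
report_user_action is a synchronous helper, so the method is wrapped with asgiref's sync_to_async to be awaitable from _astream. A minimal runnable sketch of the same pattern (blocking_report is a stand-in for report_user_action):

# Sketch of the sync_to_async pattern used above: a blocking helper becomes
# awaitable on the async streaming path.
import asyncio
from asgiref.sync import sync_to_async

@sync_to_async
def blocking_report(event: str, properties: dict) -> None:
    print(f"reported: {event} {properties}")

async def main() -> None:
    await blocking_report("chat with ai", {"prompt": "…"})

asyncio.run(main())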