Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use permchain agent executor, Streaming, API docs #45

Merged
merged 34 commits into from
Nov 20, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
1f538d8
WIP Use permchain agent executor
nfcampos Nov 14, 2023
c363583
Lint
nfcampos Nov 14, 2023
1089ae0
Lint
nfcampos Nov 14, 2023
b557a37
Lint
nfcampos Nov 14, 2023
d8ba886
Use permchain checkpoints
nfcampos Nov 14, 2023
73ce7b0
Fix streaming of messages
nfcampos Nov 15, 2023
1652c94
Lint
nfcampos Nov 15, 2023
a3c5a3b
Lint
nfcampos Nov 15, 2023
30eaddd
Implement stop message, fix tool execution
nfcampos Nov 15, 2023
f1ee430
Comment
nfcampos Nov 15, 2023
b227539
Delete previous impl
nfcampos Nov 15, 2023
5665978
WIP Stream message tokens
nfcampos Nov 16, 2023
226b560
Hack for function messages
nfcampos Nov 16, 2023
65c1b20
Small updates to schemas
nfcampos Nov 17, 2023
9dab4c8
Add partial validation
eyurtsev Nov 17, 2023
121ab8c
Merge branch 'nc/stream-tokens' into eugene/add_docs
eyurtsev Nov 17, 2023
accb44a
Use published permchain
nfcampos Nov 20, 2023
288d275
Merge pull request #56 from langchain-ai/eugene/add_docs
nfcampos Nov 20, 2023
1a6b46c
Merge pull request #52 from langchain-ai/nc/stream-tokens
nfcampos Nov 20, 2023
9391065
Reorganize api
nfcampos Nov 20, 2023
815b312
Add missing endpoints
nfcampos Nov 20, 2023
2d8c4ef
Name
nfcampos Nov 20, 2023
6269b0d
Fix message types
nfcampos Nov 20, 2023
a031c96
Fixes for feedback
nfcampos Nov 20, 2023
40b625e
Update ui files
nfcampos Nov 20, 2023
7757f92
Fix ui files
nfcampos Nov 20, 2023
d38fafe
Fix issues related to 0 messages in state
nfcampos Nov 20, 2023
d0c86b6
Make messages optional
nfcampos Nov 20, 2023
113e815
Build ui
nfcampos Nov 20, 2023
d8a37c3
Add endpoint to post thread messages
nfcampos Nov 20, 2023
bb35d97
Build ui
nfcampos Nov 20, 2023
0480539
Separate stream endpoint
nfcampos Nov 20, 2023
eae5304
Lint
nfcampos Nov 20, 2023
669d957
Pin pydantic
nfcampos Nov 20, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions backend/app/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from datetime import datetime

import orjson
from langchain.schema.messages import messages_from_dict
from langchain.utilities.redis import get_client
from agent_executor.checkpoint import RedisCheckpoint
from redis.client import Redis as RedisType


Expand Down Expand Up @@ -115,13 +115,13 @@ def list_threads(user_id: str):


def get_thread_messages(user_id: str, thread_id: str):
client = _get_redis_client()
messages = client.lrange(thread_messages_key(user_id, thread_id), 0, -1)
client = RedisCheckpoint()
checkpoint = client.get(
{"configurable": {"user_id": user_id, "thread_id": thread_id}}
)
_, messages = checkpoint.get("messages", [[], []])
return {
"messages": [
m.dict()
for m in messages_from_dict([orjson.loads(m) for m in messages[::-1]])
],
"messages": [m.dict() for m in messages],
}


Expand Down
65 changes: 65 additions & 0 deletions backend/packages/agent-executor/agent_executor/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
import pickle
from functools import partial
from typing import Any, Mapping, Sequence

from langchain.pydantic_v1 import Field
from langchain.schema.runnable import RunnableConfig
from langchain.schema.runnable.utils import ConfigurableFieldSpec
from langchain.utilities.redis import get_client
from permchain.checkpoint.base import BaseCheckpointAdapter
from redis.client import Redis as RedisType


def checkpoint_key(user_id: str, thread_id: str) -> str:
    """Build the Redis key under which a thread's checkpoint hash is stored."""
    return ":".join(("opengpts", user_id, "thread", thread_id, "checkpoint"))


def _dump(mapping: dict[str, Any]) -> dict:
    """Pickle every non-None value of *mapping*; None values pass through as-is."""
    serialized: dict[str, Any] = {}
    for key, value in mapping.items():
        serialized[key] = None if value is None else pickle.dumps(value)
    return serialized


def _load(mapping: dict[bytes, bytes]) -> dict:
    """Inverse of ``_dump``: decode byte keys and unpickle non-None values."""
    deserialized: dict = {}
    for raw_key, raw_value in mapping.items():
        deserialized[raw_key.decode()] = (
            None if raw_value is None else pickle.loads(raw_value)
        )
    return deserialized


class RedisCheckpoint(BaseCheckpointAdapter):
    """Checkpoint adapter that persists permchain state in a Redis hash.

    Each (user_id, thread_id) pair maps to one Redis hash (see
    ``checkpoint_key``) whose fields hold pickled checkpoint values
    (``_dump``/``_load``).
    """

    # Redis connection, built from REDIS_URL by default.
    # NOTE(review): the env var is read once, when this class body is
    # evaluated — confirm that later changes to REDIS_URL are not expected
    # to be picked up.
    client: RedisType = Field(
        default_factory=partial(get_client, os.environ.get("REDIS_URL"))
    )

    class Config:
        # Redis client is not a pydantic model; allow it as a field type.
        arbitrary_types_allowed = True

    @property
    def config_specs(self) -> Sequence[ConfigurableFieldSpec]:
        """Declare the configurable keys this adapter reads at runtime."""
        return [
            ConfigurableFieldSpec(
                id="user_id",
                annotation=str,
                name="User ID",
                description=None,
                default=None,
            ),
            ConfigurableFieldSpec(
                id="thread_id",
                annotation=str,
                name="Thread ID",
                description=None,
                default="",
            ),
        ]

    def _hash_key(self, config: RunnableConfig) -> str:
        # Both keys are required; a missing one raises KeyError by design.
        return checkpoint_key(
            config["configurable"]["user_id"], config["configurable"]["thread_id"]
        )

    def get(self, config: RunnableConfig) -> Mapping[str, Any] | None:
        """Load the checkpoint for this config.

        Returns an empty mapping (not None) when no checkpoint exists,
        because HGETALL on a missing key yields an empty hash.
        """
        return _load(self.client.hgetall(self._hash_key(config)))

    def put(self, config: RunnableConfig, checkpoint: Mapping[str, Any]) -> None:
        """Persist *checkpoint*, overwriting the fields present in it."""
        # HMSET is deprecated in redis-py; HSET with mapping= is the
        # supported equivalent. Also return None to match the signature
        # instead of leaking the client's reply.
        self.client.hset(self._hash_key(config), mapping=_dump(checkpoint))
105 changes: 105 additions & 0 deletions backend/packages/agent-executor/agent_executor/permchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import json
from typing import Any

from langchain.schema.agent import AgentAction, AgentActionMessageLog, AgentFinish
from langchain.schema.messages import AIMessage, AnyMessage, FunctionMessage
from langchain.schema.runnable import (
    Runnable,
    RunnableConfig,
    RunnableLambda,
    RunnablePassthrough,
)
from langchain.tools import BaseTool
from permchain import Channel, Pregel
from permchain.channels import Topic
from permchain.checkpoint.base import BaseCheckpointAdapter


def _create_agent_message(
    output: AgentAction | AgentFinish,
) -> list[AnyMessage] | AnyMessage:
    """Convert an agent decision into chat message(s) for the message channel.

    The originating AgentAction/AgentFinish is stashed in ``additional_kwargs``
    under "agent" so the router can later decide whether to run a tool or stop.
    """
    if isinstance(output, AgentFinish):
        return AIMessage(
            content=output.return_values["output"],
            additional_kwargs={"agent": output},
        )
    if isinstance(output, AgentActionMessageLog):
        # Reuse the model's own messages; tag the last one with the action.
        output.message_log[-1].additional_kwargs["agent"] = output
        return output.message_log
    return AIMessage(
        content=output.log,
        additional_kwargs={"agent": output},
    )


def _create_function_message(
    agent_action: AgentAction, observation: Any
) -> FunctionMessage:
    """Wrap a tool's observation in a FunctionMessage named after the tool.

    *observation* is annotated ``Any`` (not ``str``) because tools often
    return non-string payloads; those are JSON-serialized when possible and
    str()-ed otherwise.
    """
    if isinstance(observation, str):
        content = observation
    else:
        try:
            # ensure_ascii=False keeps non-ASCII text readable for the model.
            content = json.dumps(observation, ensure_ascii=False)
        except Exception:
            # Not JSON-serializable (sets, custom objects, ...): best effort.
            content = str(observation)
    return FunctionMessage(
        name=agent_action.tool,
        content=content,
    )


def _run_tool(
    messages: list[AnyMessage], config: RunnableConfig, *, tools: dict[str, BaseTool]
) -> FunctionMessage:
    """Execute the tool requested by the newest message and wrap its output."""
    # The routing layer guarantees the last message carries an AgentAction.
    action: AgentAction = messages[-1].additional_kwargs["agent"]
    observation = tools[action.tool].invoke(action.tool_input, config)
    return _create_function_message(action, observation)


async def _arun_tool(
    messages: list[AnyMessage], config: RunnableConfig, *, tools: dict[str, BaseTool]
) -> FunctionMessage:
    """Async variant of ``_run_tool``: await the tool and wrap its output."""
    action: AgentAction = messages[-1].additional_kwargs["agent"]
    observation = await tools[action.tool].ainvoke(action.tool_input, config)
    return _create_function_message(action, observation)


def get_agent_executor(
    tools: list[BaseTool],
    agent: Runnable[dict[str, list[AnyMessage]], AgentAction | AgentFinish],
    checkpoint: BaseCheckpointAdapter,
) -> Pregel:
    """Wire agent + tools into a Pregel loop over a single "messages" topic.

    Each step inspects the newest message: a non-AI message triggers the
    agent; an AI message tagged with an AgentAction triggers the matching
    tool; one tagged with AgentFinish writes to no channel, which ends the
    run (execution continues only while channels are written to).
    """
    tools_by_name = {t.name: t for t in tools}
    run_tools = RunnableLambda(_run_tool, _arun_tool).bind(tools=tools_by_name)

    # Both branches publish their result back onto the "messages" topic.
    tool_chain = run_tools | Channel.write_to("messages")
    agent_chain = (
        {"messages": RunnablePassthrough()}
        | agent
        | _create_agent_message
        | Channel.write_to("messages")
    )

    def route_last_message(messages: list[AnyMessage]) -> Runnable:
        """Pick the runnable to execute based on the most recent message."""
        last = messages[-1]
        if not isinstance(last, AIMessage):
            # A human/function message means it is the agent's turn.
            return agent_chain
        decision = last.additional_kwargs.get("agent")
        if isinstance(decision, AgentAction):
            # TODO if this is last step, return stop message instead
            return tool_chain
        if isinstance(decision, AgentFinish):
            # Passthrough writes nothing, so the loop halts here.
            return RunnablePassthrough()
        # NOTE(review): an AIMessage with no "agent" tag falls through and
        # returns None, matching the original behavior — confirm such
        # messages cannot occur (e.g. AI messages posted directly to a
        # thread).

    executor = Channel.subscribe_to("messages") | route_last_message

    return Pregel(
        chains={"executor": executor},
        channels={"messages": Topic(AnyMessage, accumulate=True)},
        input=["messages"],
        output=["messages"],
        checkpoint=checkpoint,
    )
14 changes: 1 addition & 13 deletions backend/packages/gizmo-agent/gizmo_agent/agent_types/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os

from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
Expand All @@ -27,7 +26,6 @@ def get_openai_function_agent(
[
("system", system_message),
MessagesPlaceholder(variable_name="messages"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
if tools:
Expand All @@ -36,15 +34,5 @@ def get_openai_function_agent(
)
else:
llm_with_tools = llm
agent = (
{
"messages": lambda x: x["messages"],
"agent_scratchpad": lambda x: format_to_openai_functions(
x["intermediate_steps"]
),
}
| prompt
| llm_with_tools
| OpenAIFunctionsAgentOutputParser()
)
agent = prompt | llm_with_tools | OpenAIFunctionsAgentOutputParser()
return agent
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import boto3
from langchain.agents.format_scratchpad import format_xml
from langchain.chat_models import BedrockChat, ChatAnthropic
from langchain.schema.messages import AIMessage, HumanMessage
from langchain.tools.render import render_text_description
Expand Down Expand Up @@ -61,10 +60,7 @@ def get_xml_agent(tools, system_message, bedrock=False):
llm_with_stop = model.bind(stop=["</tool_input>"])

agent = (
{
"messages": lambda x: construct_chat_history(x["messages"]),
"agent_scratchpad": lambda x: format_xml(x["intermediate_steps"]),
}
{"messages": lambda x: construct_chat_history(x["messages"])}
| prompt
| llm_with_stop
| parse_output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
[
("system", template),
MessagesPlaceholder(variable_name="messages"),
("ai", "{agent_scratchpad}"),
]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yay

)

Expand Down
71 changes: 29 additions & 42 deletions backend/packages/gizmo-agent/gizmo_agent/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import os
from functools import partial
from typing import Any, Mapping, Optional, Sequence
from agent_executor.checkpoint import RedisCheckpoint

from agent_executor import AgentExecutor
from agent_executor.history import RunnableWithMessageHistory
from langchain.memory import RedisChatMessageHistory
from agent_executor.permchain import get_agent_executor
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema.messages import AnyMessage
from langchain.schema.runnable import (
Expand Down Expand Up @@ -64,12 +61,9 @@ def __init__(
_agent = get_xml_agent(_tools, system_message, bedrock=True)
else:
raise ValueError("Unexpected agent type")
agent_executor = AgentExecutor(
agent=_agent,
tools=_tools,
handle_parsing_errors=True,
max_iterations=10,
)
agent_executor = get_agent_executor(
tools=_tools, agent=_agent, checkpoint=RedisCheckpoint()
).with_config({"recursion_limit": 10})
super().__init__(
tools=tools,
agent=agent,
Expand All @@ -81,40 +75,33 @@ def __init__(


class AgentInput(BaseModel):
input: AnyMessage
messages: AnyMessage
nfcampos marked this conversation as resolved.
Show resolved Hide resolved


class AgentOutput(BaseModel):
messages: Sequence[AnyMessage] = Field(..., extra={"widget": {"type": "chat"}})
nfcampos marked this conversation as resolved.
Show resolved Hide resolved
output: str


agent = ConfigurableAgent(
agent=GizmoAgentType.GPT_35_TURBO,
tools=[],
system_message=DEFAULT_SYSTEM_MESSAGE,
assistant_id=None,
).configurable_fields(
agent=ConfigurableField(id="agent_type", name="Agent Type"),
system_message=ConfigurableField(id="system_message", name="System Message"),
assistant_id=ConfigurableField(id="assistant_id", name="Assistant ID"),
tools=ConfigurableFieldMultiOption(
id="tools",
name="Tools",
options=TOOL_OPTIONS,
default=[],
),


agent = (
ConfigurableAgent(
agent=GizmoAgentType.GPT_35_TURBO,
tools=[],
system_message=DEFAULT_SYSTEM_MESSAGE,
assistant_id=None,
)
.configurable_fields(
agent=ConfigurableField(id="agent_type", name="Agent Type"),
system_message=ConfigurableField(id="system_message", name="System Message"),
assistant_id=ConfigurableField(id="assistant_id", name="Assistant ID"),
tools=ConfigurableFieldMultiOption(
id="tools",
name="Tools",
options=TOOL_OPTIONS,
default=[],
),
)
.with_types(input_type=AgentInput, output_type=AgentOutput)
)
agent = RunnableWithMessageHistory(
agent,
# first arg should be a function that
# - accepts a single arg "session_id"
# - returns a BaseChatMessageHistory instance
partial(RedisChatMessageHistory, url=os.environ["REDIS_URL"]),
input_key="input",
output_key="messages",
history_key="messages",
).with_types(input_type=AgentInput, output_type=AgentOutput)

if __name__ == "__main__":
import asyncio
Expand All @@ -123,8 +110,8 @@ class AgentOutput(BaseModel):

async def run():
async for m in agent.astream_log(
{"input": HumanMessage(content="whats your name")},
config={"configurable": {"thread_id": "test1"}},
{"messages": HumanMessage(content="whats your name")},
config={"configurable": {"user_id": "1", "thread_id": "test1"}},
):
print(m)

Expand Down
18 changes: 17 additions & 1 deletion backend/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading