-
Notifications
You must be signed in to change notification settings - Fork 879
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use permchain agent executor, Streaming, API docs #45
Changes from 5 commits
1f538d8
c363583
1089ae0
b557a37
d8ba886
73ce7b0
1652c94
a3c5a3b
30eaddd
f1ee430
b227539
5665978
226b560
65c1b20
9dab4c8
121ab8c
accb44a
288d275
1a6b46c
9391065
815b312
2d8c4ef
6269b0d
a031c96
40b625e
7757f92
d38fafe
d0c86b6
113e815
d8a37c3
bb35d97
0480539
eae5304
669d957
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import os | ||
import pickle | ||
from functools import partial | ||
from typing import Any, Mapping, Sequence | ||
|
||
from langchain.pydantic_v1 import Field | ||
from langchain.schema.runnable import RunnableConfig | ||
from langchain.schema.runnable.utils import ConfigurableFieldSpec | ||
from langchain.utilities.redis import get_client | ||
from permchain.checkpoint.base import BaseCheckpointAdapter | ||
from redis.client import Redis as RedisType | ||
|
||
|
||
def checkpoint_key(user_id: str, thread_id: str): | ||
return f"opengpts:{user_id}:thread:{thread_id}:checkpoint" | ||
|
||
|
||
def _dump(mapping: dict[str, Any]) -> dict: | ||
return {k: pickle.dumps(v) if v is not None else None for k, v in mapping.items()} | ||
|
||
|
||
def _load(mapping: dict[bytes, bytes]) -> dict: | ||
return { | ||
k.decode(): pickle.loads(v) if v is not None else None | ||
for k, v in mapping.items() | ||
} | ||
|
||
|
||
class RedisCheckpoint(BaseCheckpointAdapter): | ||
client: RedisType = Field( | ||
default_factory=partial(get_client, os.environ.get("REDIS_URL")) | ||
) | ||
|
||
class Config: | ||
arbitrary_types_allowed = True | ||
|
||
@property | ||
def config_specs(self) -> Sequence[ConfigurableFieldSpec]: | ||
return [ | ||
ConfigurableFieldSpec( | ||
id="user_id", | ||
annotation=str, | ||
name="User ID", | ||
description=None, | ||
default=None, | ||
), | ||
ConfigurableFieldSpec( | ||
id="thread_id", | ||
annotation=str, | ||
name="Thread ID", | ||
description=None, | ||
default="", | ||
), | ||
] | ||
|
||
def _hash_key(self, config: RunnableConfig) -> str: | ||
return checkpoint_key( | ||
config["configurable"]["user_id"], config["configurable"]["thread_id"] | ||
) | ||
|
||
def get(self, config: RunnableConfig) -> Mapping[str, Any] | None: | ||
return _load(self.client.hgetall(self._hash_key(config))) | ||
|
||
def put(self, config: RunnableConfig, checkpoint: Mapping[str, Any]) -> None: | ||
return self.client.hmset(self._hash_key(config), _dump(checkpoint)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import json | ||
|
||
from permchain import Channel, Pregel | ||
from permchain.channels import Topic | ||
from permchain.checkpoint.base import BaseCheckpointAdapter | ||
from langchain.schema.runnable import ( | ||
Runnable, | ||
RunnableConfig, | ||
RunnableLambda, | ||
RunnablePassthrough, | ||
) | ||
from langchain.schema.agent import AgentAction, AgentFinish, AgentActionMessageLog | ||
from langchain.schema.messages import AIMessage, FunctionMessage, AnyMessage | ||
from langchain.tools import BaseTool | ||
|
||
|
||
def _create_agent_message( | ||
output: AgentAction | AgentFinish | ||
) -> list[AnyMessage] | AnyMessage: | ||
if isinstance(output, AgentAction): | ||
if isinstance(output, AgentActionMessageLog): | ||
output.message_log[-1].additional_kwargs["agent"] = output | ||
return output.message_log | ||
else: | ||
return AIMessage( | ||
content=output.log, | ||
additional_kwargs={"agent": output}, | ||
) | ||
else: | ||
return AIMessage( | ||
content=output.return_values["output"], | ||
additional_kwargs={"agent": output}, | ||
) | ||
|
||
|
||
def _create_function_message( | ||
agent_action: AgentAction, observation: str | ||
) -> FunctionMessage: | ||
if not isinstance(observation, str): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can't see from current code but when is observation not str? maybe update type signature? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. its often not (sadly) type hints need updating |
||
try: | ||
content = json.dumps(observation, ensure_ascii=False) | ||
except Exception: | ||
content = str(observation) | ||
else: | ||
content = observation | ||
return FunctionMessage( | ||
name=agent_action.tool, | ||
content=content, | ||
) | ||
|
||
|
||
def _run_tool( | ||
messages: list[AnyMessage], config: RunnableConfig, *, tools: dict[str, BaseTool] | ||
) -> FunctionMessage: | ||
action: AgentAction = messages[-1].additional_kwargs["agent"] | ||
tool = tools[action.tool] | ||
result = tool.invoke(action.tool_input, config) | ||
return _create_function_message(action, result) | ||
|
||
|
||
async def _arun_tool( | ||
messages: list[AnyMessage], config: RunnableConfig, *, tools: dict[str, BaseTool] | ||
) -> FunctionMessage: | ||
action: AgentAction = messages[-1].additional_kwargs["agent"] | ||
tool = tools[action.tool] | ||
result = await tool.ainvoke(action.tool_input, config) | ||
return _create_function_message(action, result) | ||
|
||
|
||
def get_agent_executor( | ||
tools: list[BaseTool], | ||
agent: Runnable[dict[str, list[AnyMessage]], AgentAction | AgentFinish], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm at some point we could add another executor to support also agents that return a list of AnyMessage (will help remove the boiler plate for working with something like openai directly) |
||
checkpoint: BaseCheckpointAdapter, | ||
) -> Pregel: | ||
tool_map = {tool.name: tool for tool in tools} | ||
tool_lambda = RunnableLambda(_run_tool, _arun_tool).bind(tools=tool_map) | ||
|
||
tool_chain = tool_lambda | Channel.write_to("messages") | ||
agent_chain = ( | ||
{"messages": RunnablePassthrough()} | ||
| agent | ||
| _create_agent_message | ||
| Channel.write_to("messages") | ||
) | ||
|
||
def route_last_message(messages: list[AnyMessage]) -> Runnable: | ||
message = messages[-1] | ||
if isinstance(message, AIMessage): | ||
if isinstance(message.additional_kwargs.get("agent"), AgentAction): | ||
# TODO if this is last step, return stop message instead | ||
return tool_chain | ||
elif isinstance(message.additional_kwargs.get("agent"), AgentFinish): | ||
return RunnablePassthrough() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Confirming -- this stops execution because there was no value update? How does it check for updates under the hood? Does it compare serialized representations? Is RunnablePassthrough() special? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It stops execution because it doesn't write to any channel (execution continues only when channels are written to) |
||
else: | ||
return agent_chain | ||
|
||
executor = Channel.subscribe_to("messages") | route_last_message | ||
|
||
return Pregel( | ||
chains={"executor": executor}, | ||
channels={"messages": Topic(AnyMessage, accumulate=True)}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does accumulation work? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
input=["messages"], | ||
output=["messages"], | ||
checkpoint=checkpoint, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When are values written to the checkpoint? And what is the content of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check the implementation in permchain, we have a choice of values being written to checkpoint storage either at the end of each step, or once at the end of the run. By default is end of each step |
||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,6 @@ | |
[ | ||
("system", template), | ||
MessagesPlaceholder(variable_name="messages"), | ||
("ai", "{agent_scratchpad}"), | ||
] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yay |
||
) | ||
|
||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what happens to these kwargs when we send them to the llm? do they get stripped out?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Our ChatModels dont use additional kwargs they dont know about