Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: truncate messages based on token count #135

Merged
merged 1 commit into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/tablegpt/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import date # noqa: TCH003
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

from langchain_core.messages import BaseMessage # noqa: TCH002
from langgraph.graph import END, START, MessagesState, StateGraph
Expand Down Expand Up @@ -46,6 +46,7 @@ def create_tablegpt_graph(
normalize_llm: BaseLanguageModel | None = None,
locale: str | None = None,
checkpointer: BaseCheckpointSaver | None = None,
trim_message_method: Literal["default", "token_count"] = "default",
verbose: bool = False,
) -> CompiledStateGraph:
"""Creates a state graph for processing datasets.
Expand All @@ -67,6 +68,7 @@ def create_tablegpt_graph(
normalize_llm (BaseLanguageModel | None, optional): Model for data normalization tasks. Defaults to None.
locale (str | None, optional): The locale of the user. Defaults to None.
checkpointer (BaseCheckpointSaver | None, optional): Component for saving checkpoints. Defaults to None.
trim_message_method (Literal["default", "token_count"], optional): Determines the method used to trim the message. Defaults to "default".
verbose (bool, optional): Flag to enable verbose logging. Defaults to False.

Returns:
Expand All @@ -92,6 +94,7 @@ def create_tablegpt_graph(
vlm=vlm,
safety_llm=safety_llm,
dataset_retriever=dataset_retriever,
trim_message_method=trim_message_method,
verbose=verbose,
)

Expand Down
63 changes: 49 additions & 14 deletions src/tablegpt/agent/data_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from copy import deepcopy
from datetime import date # noqa: TCH003
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable, Literal
from uuid import uuid4

from langchain_core.messages import (
Expand All @@ -18,13 +18,18 @@
from tablegpt.agent.output_parser import MarkdownOutputParser
from tablegpt.retriever import format_columns
from tablegpt.safety import create_hazard_classifier
from tablegpt.tools import IPythonTool, markdown_console_template, process_content
from tablegpt.tools import (
IPythonTool,
markdown_console_template,
process_content,
)
from tablegpt.utils import filter_contents

if TYPE_CHECKING:
from pathlib import Path

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from pybox.base import BasePyBoxManager
Expand Down Expand Up @@ -52,7 +57,7 @@
)


def get_data_analyzer_agent(llm: BaseLanguageModel) -> Runnable:
    """Build the data-analyzer chain: prompt -> llm -> markdown code-block parser.

    The parser extracts fenced ``python``/``py`` code blocks from the model
    output and maps both fence labels to the "python" action.
    """
    output_parser = MarkdownOutputParser(language_actions={"python": "python", "py": "python"})
    return PROMPT | llm | output_parser
return PROMPT | llm | MarkdownOutputParser(language_actions={"python": "python", "py": "python"})


Expand All @@ -65,17 +70,36 @@ class AgentState(MessagesState):
parent_id: str | None


def get_messages_truncation_config(
blm: BaseLanguageModel | None,
trim_message_method: Literal["default", "token_count"] = "default",
) -> tuple[
Callable[[list[BaseMessage]], int] | Callable[[BaseMessage], int] | BaseLanguageModel,
int,
]:
match trim_message_method:
case "default":
return len, 20

case "token_count":
return blm, int((blm.metadata or {}).get("max_history_tokens", 4096))
case _:
e_msg = f"The expected value should be one of ['default', 'token_count'], but you provided {trim_message_method}."
raise ValueError(e_msg)


def create_data_analyze_workflow(
llm: Runnable,
llm: BaseLanguageModel,
pybox_manager: BasePyBoxManager,
*,
workdir: Path | None = None,
session_id: str | None = None,
error_trace_cleanup: bool = False,
vlm: Runnable | None = None,
vlm: BaseLanguageModel | None = None,
safety_llm: Runnable | None = None,
dataset_retriever: BaseRetriever | None = None,
verbose: bool = False,
trim_message_method: Literal["default", "token_count"] = "default",
) -> Runnable:
"""Creates a data analysis workflow for processing user input and datasets.

Expand All @@ -85,15 +109,21 @@ def create_data_analyze_workflow(
The workflow is designed to handle multiple types of messages and responses.

Args:
llm (Runnable): The primary language model for processing user input.
llm (BaseLanguageModel): The primary language model for processing user input.
pybox_manager (BasePyBoxManager): A python code sandbox delegator, used to execute the data analysis code generated by llm.
workdir (Path | None, optional): The working directory for `pybox` operations. Defaults to None.
session_id (str | None, optional): An optional session identifier used to associate with `pybox`. Defaults to None.
error_trace_cleanup (bool, optional): Flag to indicate if error traces should be cleaned up. Defaults to False.
vlm (Runnable | None, optional): Optional vision language model for processing images. Defaults to None.
vlm (BaseLanguageModel | None, optional): Optional vision language model for processing images. Defaults to None.
safety_llm (Runnable | None, optional): Model used for safety classification of inputs. Defaults to None.
dataset_retriever (BaseRetriever | None, optional): Component to retrieve dataset columns based on user input. Defaults to None.
verbose (bool, optional): Flag to enable detailed logging. Defaults to False.
trim_message_method (Literal["default", "token_count"], optional): Determines the method used to trim the message. Defaults to "default".
- "default": Applies the default trimming method (Truncate using the length of messages, default max length is 20).
- "token_count": Use token count to truncate messages.
Ensure the `BaseLanguageModel` has the `get_num_tokens_from_messages` method.
And set `max_history_tokens` in `BaseLanguageModel.metadata`, e.g., {"max_history_tokens": 4096} (default 4096).
You can specify the value using: `max_model_len (max tokens the service supports) - max_new_tokens (tokens needed for the request)`.

Returns:
Runnable: A runnable object representing the data analysis workflow.
Expand All @@ -108,7 +138,9 @@ def create_data_analyze_workflow(
if safety_llm is not None:
hazard_classifier = create_hazard_classifier(safety_llm)

async def run_input_guard(state: AgentState) -> dict[str, list[BaseMessage]]:
async def run_input_guard(
state: AgentState,
) -> dict[str, list[BaseMessage]]:
if hazard_classifier is not None:
last_message = state["messages"][-1]
flag, category = await hazard_classifier.ainvoke(input={"input": last_message.content})
Expand Down Expand Up @@ -157,11 +189,13 @@ async def retrieve_columns(state: AgentState) -> dict:

async def arun_tablegpt_agent(state: AgentState) -> dict:
# Truncate messages based on message count.
# TODO: truncate based on token count.

token_counter, max_tokens = get_messages_truncation_config(llm, trim_message_method)

windowed_messages = trim_messages(
state["messages"],
token_counter=len,
max_tokens=20,
token_counter=token_counter,
max_tokens=max_tokens,
start_on="human", # This means that the first message should be from the user after trimming.
# The system message is not in `messages`, so we don't need to specify `include_system`
)
Expand Down Expand Up @@ -190,11 +224,12 @@ async def arun_tablegpt_agent(state: AgentState) -> dict:

async def arun_vlm_agent(state: AgentState) -> dict:
# Truncate messages based on message count.
# TODO: truncate based on token count.
token_counter, max_tokens = get_messages_truncation_config(vlm, trim_message_method)

windowed_messages: list[BaseMessage] = trim_messages(
state["messages"],
token_counter=len,
max_tokens=20,
token_counter=token_counter,
max_tokens=max_tokens,
start_on="human", # This means that the first message should be from the user after trimming.
# The system message is not in `messages`, so we don't need to specify `include_system`
)
Expand Down
Loading