From ea1f208d5fdf2703d59590c2c51be0c53ad4a4a1 Mon Sep 17 00:00:00 2001
From: zt
Date: Tue, 3 Dec 2024 16:51:16 +0800
Subject: [PATCH] feat: truncate messages based on token count

---
 src/tablegpt/agent/__init__.py      |  7 +++--
 src/tablegpt/agent/data_analyzer.py | 66 +++++++++++++++++++++--------
 2 files changed, 55 insertions(+), 18 deletions(-)

diff --git a/src/tablegpt/agent/__init__.py b/src/tablegpt/agent/__init__.py
index 2476d79..3ce693a 100644
--- a/src/tablegpt/agent/__init__.py
+++ b/src/tablegpt/agent/__init__.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from datetime import date  # noqa: TCH003
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal
 
 from langchain_core.messages import BaseMessage  # noqa: TCH002
 from langgraph.graph import END, START, MessagesState, StateGraph
@@ -46,6 +46,7 @@ def create_tablegpt_graph(
     normalize_llm: BaseLanguageModel | None = None,
     locale: str | None = None,
     checkpointer: BaseCheckpointSaver | None = None,
+    trim_message_method: Literal["default", "token_count"] = "default",
     verbose: bool = False,
 ) -> CompiledStateGraph:
     """Creates a state graph for processing datasets.
@@ -67,6 +68,7 @@
         normalize_llm (BaseLanguageModel | None, optional): Model for data normalization tasks. Defaults to None.
-        locate (str | None, optional): The locale of the user. Defaults to None.
+        locale (str | None, optional): The locale of the user. Defaults to None.
         checkpointer (BaseCheckpointSaver | None, optional): Component for saving checkpoints. Defaults to None.
+        trim_message_method (Literal["default", "token_count"], optional): Determines how messages are trimmed. Defaults to "default".
         verbose (bool, optional): Flag to enable verbose logging. Defaults to False.
 
     Returns:
@@ -92,6 +94,7 @@
         vlm=vlm,
         safety_llm=safety_llm,
         dataset_retriever=dataset_retriever,
+        trim_message_method=trim_message_method,
         verbose=verbose,
     )
 
diff --git a/src/tablegpt/agent/data_analyzer.py b/src/tablegpt/agent/data_analyzer.py
index 5da0263..8a00bb7 100644
--- a/src/tablegpt/agent/data_analyzer.py
+++ b/src/tablegpt/agent/data_analyzer.py
@@ -2,7 +2,7 @@
 
 from copy import deepcopy
 from datetime import date  # noqa: TCH003
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable, Literal
 from uuid import uuid4
 
 from langchain_core.messages import (
@@ -18,13 +18,18 @@
 from tablegpt.agent.output_parser import MarkdownOutputParser
 from tablegpt.retriever import format_columns
 from tablegpt.safety import create_hazard_classifier
-from tablegpt.tools import IPythonTool, markdown_console_template, process_content
+from tablegpt.tools import (
+    IPythonTool,
+    markdown_console_template,
+    process_content,
+)
 from tablegpt.utils import filter_contents
 
 if TYPE_CHECKING:
     from pathlib import Path
 
     from langchain_core.agents import AgentAction, AgentFinish
+    from langchain_core.language_models import BaseLanguageModel
     from langchain_core.retrievers import BaseRetriever
     from langchain_core.runnables import Runnable
     from pybox.base import BasePyBoxManager
@@ -52,7 +57,7 @@
 )
 
 
-def get_data_analyzer_agent(llm: Runnable) -> Runnable:
+def get_data_analyzer_agent(llm: BaseLanguageModel) -> Runnable:
     return PROMPT | llm | MarkdownOutputParser(language_actions={"python": "python", "py": "python"})
 
 
@@ -65,17 +70,36 @@ class AgentState(MessagesState):
     parent_id: str | None
 
 
+def get_messages_truncation_config(
+    blm: BaseLanguageModel | None,
+    trim_message_method: Literal["default", "token_count"] = "default",
+) -> tuple[
+    Callable[[list[BaseMessage]], int] | Callable[[BaseMessage], int] | BaseLanguageModel,
+    int,
+]:
+    match trim_message_method:
+        case "default":
+            return len, 20
+
+        case "token_count":
+            return blm, int((blm.metadata or {}).get("max_history_tokens", 4096))
+        case _:
+            e_msg = f"Expected trim_message_method to be one of ['default', 'token_count'], got {trim_message_method!r}."
+            raise ValueError(e_msg)
+
+
 def create_data_analyze_workflow(
-    llm: Runnable,
+    llm: BaseLanguageModel,
     pybox_manager: BasePyBoxManager,
     *,
     workdir: Path | None = None,
     session_id: str | None = None,
     error_trace_cleanup: bool = False,
-    vlm: Runnable | None = None,
+    vlm: BaseLanguageModel | None = None,
     safety_llm: Runnable | None = None,
     dataset_retriever: BaseRetriever | None = None,
     verbose: bool = False,
+    trim_message_method: Literal["default", "token_count"] = "default",
 ) -> Runnable:
     """Creates a data analysis workflow for processing user input and datasets.
 
@@ -85,15 +109,21 @@ def create_data_analyze_workflow(
     The workflow is designed to handle multiple types of messages and responses.
 
     Args:
-        llm (Runnable): The primary language model for processing user input.
+        llm (BaseLanguageModel): The primary language model for processing user input.
         pybox_manager (BasePyBoxManager): A python code sandbox delegator, used to execute the data analysis code generated by llm.
         workdir (Path | None, optional): The working directory for `pybox` operations. Defaults to None.
         session_id (str | None, optional): An optional session identifier used to associate with `pybox`. Defaults to None.
         error_trace_cleanup (bool, optional): Flag to indicate if error traces should be cleaned up. Defaults to False.
-        vlm (Runnable | None, optional): Optional vision language model for processing images. Defaults to None.
+        vlm (BaseLanguageModel | None, optional): Optional vision language model for processing images. Defaults to None.
         safety_llm (Runnable | None, optional): Model used for safety classification of inputs. Defaults to None.
         dataset_retriever (BaseRetriever | None, optional): Component to retrieve dataset columns based on user input. Defaults to None.
         verbose (bool, optional): Flag to enable detailed logging. Defaults to False.
+        trim_message_method (Literal["default", "token_count"], optional): Determines how messages are trimmed. Defaults to "default".
+            - "default": Trim by message count; at most 20 messages are kept.
+            - "token_count": Trim by token count.
+              Ensure the `BaseLanguageModel` implements the `get_num_tokens_from_messages` method,
+              and set `max_history_tokens` in `BaseLanguageModel.metadata`, e.g., {"max_history_tokens": 4096} (defaults to 4096).
+              A reasonable value is `max_model_len` (the maximum tokens the service supports) minus `max_new_tokens` (tokens reserved for generation).
 
     Returns:
         Runnable: A runnable object representing the data analysis workflow.
@@ -108,7 +138,9 @@ def create_data_analyze_workflow(
     if safety_llm is not None:
         hazard_classifier = create_hazard_classifier(safety_llm)
 
-    async def run_input_guard(state: AgentState) -> dict[str, list[BaseMessage]]:
+    async def run_input_guard(
+        state: AgentState,
+    ) -> dict[str, list[BaseMessage]]:
         if hazard_classifier is not None:
             last_message = state["messages"][-1]
             flag, category = await hazard_classifier.ainvoke(input={"input": last_message.content})
@@ -157,11 +189,13 @@ async def retrieve_columns(state: AgentState) -> dict:
 
     async def arun_tablegpt_agent(state: AgentState) -> dict:
-        # Truncate messages based on message count.
-        # TODO: truncate based on token count.
+        # Truncate messages using the configured method: message count or token count.
+        token_counter, max_tokens = get_messages_truncation_config(llm, trim_message_method)
+
         windowed_messages = trim_messages(
             state["messages"],
-            token_counter=len,
-            max_tokens=20,
+            token_counter=token_counter,
+            max_tokens=max_tokens,
             start_on="human",  # This means that the first message should be from the user after trimming.
             # The system message is not in `messages`, so we don't need to specify `include_system`
         )
@@ -190,11 +224,12 @@
 
     async def arun_vlm_agent(state: AgentState) -> dict:
-        # Truncate messages based on message count.
-        # TODO: truncate based on token count.
+        # Truncate messages using the configured method: message count or token count.
+        token_counter, max_tokens = get_messages_truncation_config(vlm, trim_message_method)
+
         windowed_messages: list[BaseMessage] = trim_messages(
             state["messages"],
-            token_counter=len,
-            max_tokens=20,
+            token_counter=token_counter,
+            max_tokens=max_tokens,
             start_on="human",  # This means that the first message should be from the user after trimming.
             # The system message is not in `messages`, so we don't need to specify `include_system`
         )
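Reviewer note (not part of the patch): a minimal usage sketch of the new
trim_message_method="token_count" option. The model name, endpoint, credential,
token budget, and the LocalPyBoxManager setup below are illustrative
assumptions, not values defined by this patch; any LangChain chat model whose
`get_num_tokens_from_messages` works against your deployment should do.

    # Sketch: enable token-count history truncation (placeholder values marked).
    from langchain_openai import ChatOpenAI
    from pybox import LocalPyBoxManager  # local sandbox; swap in your own manager

    from tablegpt.agent import create_tablegpt_graph

    llm = ChatOpenAI(
        model="TableGPT2-7B",                 # placeholder model name
        base_url="http://127.0.0.1:8000/v1",  # placeholder OpenAI-compatible endpoint
        api_key="none-required-locally",      # placeholder credential
        # Per the docstring: max_history_tokens ~= max_model_len - max_new_tokens.
        metadata={"max_history_tokens": 7168},  # e.g. 8192 context - 1024 for generation
    )

    graph = create_tablegpt_graph(
        llm=llm,
        pybox_manager=LocalPyBoxManager(),
        trim_message_method="token_count",  # "default" keeps the 20-message window
    )

With "token_count", trim_messages receives the model itself as the token
counter, so history is trimmed to fit the max_history_tokens budget instead of
a fixed 20-message window.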