Skip to content

Commit

Permalink
feat: truncate messages based on token count
Browse files Browse the repository at this point in the history
  • Loading branch information
zt committed Dec 3, 2024
1 parent b69e782 commit 1be1d35
Showing 1 changed file with 34 additions and 12 deletions.
46 changes: 34 additions & 12 deletions src/tablegpt/agent/data_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from copy import deepcopy
from datetime import date # noqa: TCH003
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from uuid import uuid4

from langchain_core.messages import (
Expand All @@ -18,13 +18,18 @@
from tablegpt.agent.output_parser import MarkdownOutputParser
from tablegpt.retriever import format_columns
from tablegpt.safety import create_hazard_classifier
from tablegpt.tools import IPythonTool, markdown_console_template, process_content
from tablegpt.tools import (
IPythonTool,
markdown_console_template,
process_content,
)
from tablegpt.utils import filter_contents

if TYPE_CHECKING:
from pathlib import Path

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from pybox.base import BasePyBoxManager
Expand Down Expand Up @@ -52,7 +57,7 @@
)


def get_data_analyzer_agent(llm: Runnable) -> Runnable:
def get_data_analyzer_agent(llm: BaseLanguageModel) -> Runnable:
return PROMPT | llm | MarkdownOutputParser(language_actions={"python": "python", "py": "python"})


Expand All @@ -65,14 +70,26 @@ class AgentState(MessagesState):
parent_id: str | None


def get_messages_truncation_config(
blm: BaseLanguageModel | None,
) -> tuple[
Callable[[list[BaseMessage]], int] | Callable[[BaseMessage], int] | BaseLanguageModel,
int,
]:
if hasattr(blm, "get_num_tokens_from_messages") and hasattr(blm, "get_max_tokens_after_trimming"):
return blm, blm.get_max_tokens_after_trimming()

return len, 20


def create_data_analyze_workflow(
llm: Runnable,
llm: BaseLanguageModel,
pybox_manager: BasePyBoxManager,
*,
workdir: Path | None = None,
session_id: str | None = None,
error_trace_cleanup: bool = False,
vlm: Runnable | None = None,
vlm: BaseLanguageModel | None = None,
safety_llm: Runnable | None = None,
dataset_retriever: BaseRetriever | None = None,
verbose: bool = False,
Expand Down Expand Up @@ -108,7 +125,9 @@ def create_data_analyze_workflow(
if safety_llm is not None:
hazard_classifier = create_hazard_classifier(safety_llm)

async def run_input_guard(state: AgentState) -> dict[str, list[BaseMessage]]:
async def run_input_guard(
state: AgentState,
) -> dict[str, list[BaseMessage]]:
if hazard_classifier is not None:
last_message = state["messages"][-1]
flag, category = await hazard_classifier.ainvoke(input={"input": last_message.content})
Expand Down Expand Up @@ -157,11 +176,13 @@ async def retrieve_columns(state: AgentState) -> dict:

async def arun_tablegpt_agent(state: AgentState) -> dict:
# Truncate messages based on message count.
# TODO: truncate based on token count.

token_counter, max_tokens = get_messages_truncation_config(llm)

windowed_messages = trim_messages(
state["messages"],
token_counter=len,
max_tokens=20,
token_counter=token_counter,
max_tokens=max_tokens,
start_on="human", # This means that the first message should be from the user after trimming.
# The system message is not in `messages`, so we don't need to specify `include_system`
)
Expand Down Expand Up @@ -190,11 +211,12 @@ async def arun_tablegpt_agent(state: AgentState) -> dict:

async def arun_vlm_agent(state: AgentState) -> dict:
# Truncate messages based on message count.
# TODO: truncate based on token count.
token_counter, max_tokens = get_messages_truncation_config(vlm)

windowed_messages: list[BaseMessage] = trim_messages(
state["messages"],
token_counter=len,
max_tokens=20,
token_counter=token_counter,
max_tokens=max_tokens,
start_on="human", # This means that the first message should be from the user after trimming.
# The system message is not in `messages`, so we don't need to specify `include_system`
)
Expand Down

0 comments on commit 1be1d35

Please sign in to comment.