Skip to content

Commit

Permalink
refactor: get raw table information from the kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
vegetablest committed Nov 27, 2024
1 parent d29e951 commit d1c1c36
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 255 deletions.
32 changes: 25 additions & 7 deletions src/tablegpt/agent/file_reading/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import ast
import logging
from ast import literal_eval
from enum import Enum
from typing import TYPE_CHECKING, Literal
from uuid import uuid4

import pandas as pd
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
Expand All @@ -19,7 +21,6 @@
from tablegpt.errors import NoAttachmentsError
from tablegpt.tools import IPythonTool, markdown_console_template
from tablegpt.translation import create_translator
from tablegpt.utils import get_raw_table_info

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -84,9 +85,8 @@ def create_file_reading_workflow(
translation_chain = None
if locale is not None:
translation_chain = create_translator(llm=llm)

tools = [IPythonTool(pybox_manager=pybox_manager, cwd=workdir, session_id=session_id)]
tool_executor = ToolNode(tools)
ipython_tool = IPythonTool(pybox_manager=pybox_manager, cwd=workdir, session_id=session_id)
tool_executor = ToolNode([ipython_tool])

async def agent(state: AgentState) -> dict:
if state.get("processing_stage", Stage.UPLOADED) == Stage.UPLOADED:
Expand All @@ -103,15 +103,18 @@ async def generate_normalization_code(state: AgentState) -> str:
else:
raise NoAttachmentsError

filepath = workdir.joinpath(filename)
var_name = state["entry_message"].additional_kwargs.get("var_name", "df")

# TODO: refactor the data normalization to langgraph
raw_table_info = get_raw_table_info(filepath=filepath)
content = await ipython_tool.ainvoke(
input=RAW_TABLE_INFO_CODE.format_map({"filepath": filename, "var_name": f"{var_name}_5rows"})
)
raw_table_info = ast.literal_eval(next(x["text"] for x in content if x["type"] == "text"))
table_reformat_chain = get_table_reformat_chain(llm=normalize_llm)
reformatted_table = await table_reformat_chain.ainvoke(input={"table": raw_table_info})

if reformatted_table == raw_table_info:
# TODO: Replace pandas dependency with a lightweight alternative or custom implementation.
if pd.DataFrame(reformatted_table).astype(str).equals(pd.DataFrame(raw_table_info)):
return ""

normalize_chain = get_data_normalize_chain(llm=normalize_llm)
Expand Down Expand Up @@ -316,3 +319,18 @@ def should_continue(state: AgentState) -> Literal["tools", "end"]:
)

return workflow.compile(debug=verbose)


RAW_TABLE_INFO_CODE = """import numpy as np
from datetime import datetime
{var_name} = read_df('{filepath}', nrows=5, header=None)
# Replace NaN with None and format datetime cells
{var_name} = {var_name}.where({var_name}.notnull(), None).map(
lambda cell: (
cell.strftime('%Y-%m-%d') if isinstance(cell, (pd.Timestamp, datetime)) and pd.notnull(cell) else cell
)
)
# Convert DataFrame to a list of lists
print({var_name}.replace(np.nan, None).values.tolist())"""
40 changes: 1 addition & 39 deletions src/tablegpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import concurrent.futures
import os
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, cast
from typing import TYPE_CHECKING, NamedTuple, cast

import numpy as np
import pandas as pd

from tablegpt.errors import (
Expand Down Expand Up @@ -224,39 +222,3 @@ def filter_content(message: BaseMessage, keep: Sequence[str] | None = None) -> B
if not isinstance(part, dict) or part.get("type") in keep:
cloned.content.append(part)
return cloned


def get_raw_table_info(
filepath: Path,
nrows: int = 5,
date_format: str = "%Y-%m-%d",
) -> list[list[Any]]:
"""Reads the first nrows from the specified sheet of an Excel or CSV file.
This function reads the data from a table stored in a file, allowing
for the selection of a specific sheet and limiting the number of rows
to be read.
Args:
filepath: The path to the file containing the table data.
sheet_name: The name or index of the sheet within the file
to read the data from. If not specified, the first sheet will
be used.
nrows: The number of rows to read from the table.
date_format: The format of the datetime cells.
Returns:
list[list[Any]]: A 2D array where each sublist represents a row from the table,
with each element in the sublist corresponding to a column value in that row.
"""
read_kwargs = {"nrows": nrows, "header": None}

df = read_df(filepath, **read_kwargs)

# Replace NaN with None and format datetime cells
df = df.where(df.notnull(), None).map(
lambda cell: (
cell.strftime(date_format) if isinstance(cell, (pd.Timestamp, datetime)) and pd.notnull(cell) else cell
)
)
# Convert DataFrame to a list of lists
return df.replace({np.nan: None}).values.tolist()
209 changes: 0 additions & 209 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
from langchain_core.messages import BaseMessage
from tablegpt.utils import (
filter_content,
get_raw_table_info,
path_from_uri,
)

Expand Down Expand Up @@ -63,211 +59,6 @@ def test_valid_file_uri_with_encoded_characters(self):
assert path_from_uri(uri) == expected_path


class TestGetRawTableInfo(unittest.TestCase):
def test_unsupported_formats(self):
"""Test list sheets when Excel file contains sheets."""
with self.assertRaises(ValueError) as e: # noqa: PT027
get_raw_table_info(Path("/home/user/file.text"))
assert str(e) == "Unsupported file format: .text"

@patch("tablegpt.utils.pd.ExcelFile")
@patch("tablegpt.utils.read_df")
def test_xlsx_with_sheetname(self, mock_read_df, mock_excel_file):
mock_xls = MagicMock()
mock_excel_file.return_value.__enter__.return_value = mock_xls # Mock ExcelFile as a context manager
mock_xls.sheet_names = ["Sheet1", "Sheet2"] # Mock the sheet names in the file

# Define side_effect for read_df based on sheet_name argument
def mock_read_df_side_effect(filepath, **kwargs): # noqa: ARG001
return pd.DataFrame(
[
["Header1", "Header2", "Header3"],
["Data1", 123, "MoreData1"],
["Data2", 456, "MoreData2"],
]
)

mock_read_df.side_effect = mock_read_df_side_effect

# First test for "Sheet1"
raw_table_info_sheet1 = get_raw_table_info(filepath=Path("/home/user/file.xlsx"))

# Expected output for Sheet1
expected_df_info_sheet1 = [
["Header1", "Header2", "Header3"],
["Data1", 123, "MoreData1"],
["Data2", 456, "MoreData2"],
]

# Assert Sheet1
assert raw_table_info_sheet1 == expected_df_info_sheet1

@patch("tablegpt.utils.pd.ExcelFile")
@patch("tablegpt.utils.read_df")
def test_xlsx_without_sheetname(self, mock_read_df, mock_excel_file):
mock_xls = MagicMock()
mock_excel_file.return_value.__enter__.return_value = mock_xls # Mock ExcelFile as a context manager
mock_xls.sheet_names = ["Sheet1", "Sheet2"] # Mock the sheet names in the file

# Define side_effect for read_df based on sheet_name argument
def mock_read_df_side_effect(filepath, **kwargs): # noqa: ARG001
return pd.DataFrame(
[
["Header1", "Header2", "Header3"],
["Data1", 123, "MoreData1"],
["Data2", 456, "MoreData2"],
]
)

mock_read_df.side_effect = mock_read_df_side_effect

# First test for "Sheet1"
raw_table_info_sheet1 = get_raw_table_info(filepath=Path("/home/user/file.xls"))

# Expected output for Sheet1
expected_df_info_sheet1 = [
["Header1", "Header2", "Header3"],
["Data1", 123, "MoreData1"],
["Data2", 456, "MoreData2"],
]
# Assert Sheet1
assert raw_table_info_sheet1 == expected_df_info_sheet1

@patch("tablegpt.utils.pd.ExcelFile")
@patch("tablegpt.utils.read_df")
def test_xls_with_sheetname(self, mock_read_df, mock_excel_file):
mock_xls = MagicMock()
mock_excel_file.return_value.__enter__.return_value = mock_xls # Mock ExcelFile as a context manager
mock_xls.sheet_names = ["Sheet1", "Sheet2"] # Mock the sheet names in the file

# Define side_effect for read_df based on sheet_name argument
def mock_read_df_side_effect(filepath, **kwargs): # noqa: ARG001
return pd.DataFrame(
[
["Header1", "Header2", "Header3"],
["Data1", 123, "MoreData1"],
["Data2", 456, "MoreData2"],
]
)

mock_read_df.side_effect = mock_read_df_side_effect

# First test for "Sheet1"
raw_table_info_sheet1 = get_raw_table_info(filepath=Path("/home/user/file.xls"))

# Expected output for Sheet1
expected_df_info_sheet1 = [
["Header1", "Header2", "Header3"],
["Data1", 123, "MoreData1"],
["Data2", 456, "MoreData2"],
]

# Assert Sheet1
assert raw_table_info_sheet1 == expected_df_info_sheet1

@patch("tablegpt.utils.pd.ExcelFile")
@patch("tablegpt.utils.read_df")
def test_xls_without_sheetname(self, mock_read_df, mock_excel_file):
mock_xls = MagicMock()
mock_excel_file.return_value.__enter__.return_value = mock_xls # Mock ExcelFile as a context manager
mock_xls.sheet_names = ["Sheet1", "Sheet2"] # Mock the sheet names in the file

# Define side_effect for read_df based on sheet_name argument
def mock_read_df_side_effect(filepath, **kwargs): # noqa: ARG001
return pd.DataFrame(
[
["Header1", "Header2", "Header3"],
["Data1", 123, "MoreData1"],
["Data2", 456, "MoreData2"],
]
)

mock_read_df.side_effect = mock_read_df_side_effect

# First test for "Sheet1"
raw_table_info_sheet1 = get_raw_table_info(filepath=Path("/home/user/file.xls"))

# Expected output for Sheet1
expected_df_info_sheet1 = [
["Header1", "Header2", "Header3"],
["Data1", 123, "MoreData1"],
["Data2", 456, "MoreData2"],
]
# Assert Sheet1
assert raw_table_info_sheet1 == expected_df_info_sheet1

@patch("tablegpt.utils.read_df")
def test_csv(self, mock_read_df):
# Set up the mock for csv.reader to return an iterator over the expected rows
mock_read_df.return_value = pd.DataFrame(
[
["Header1", "Header2", "Header3"], # Header row
["Data1", "123", "MoreData1"], # Data row 1
["Data2", "456", "MoreData2"], # Data row 2
]
)

# Call the function with the mocked file
raw_table_info = get_raw_table_info(filepath=Path("/home/user/file.csv"))

expected_df_info = [
["Header1", "Header2", "Header3"], # Header row
["Data1", "123", "MoreData1"], # Row 1 data
["Data2", "456", "MoreData2"], # Row 2 data
]
# Assert that the function output matches the expected result
assert raw_table_info == expected_df_info

@patch("tablegpt.utils.read_df")
def test_with_nan(self, mock_read_df):
# Set up the mock for csv.reader to return an iterator over the expected rows
mock_read_df.return_value = pd.DataFrame(
{
"str_col": [None, "Data1", "Data2"],
"int_col": [123, 456, 789],
"float_col": [1.1, np.nan, 3.3],
}
)

# Call the function with the mocked file
raw_table_info = get_raw_table_info(filepath=Path("/home/user/file.csv"))

expected_df_info = [
[None, 123, 1.1],
["Data1", 456, None],
["Data2", 789, 3.3],
]
# Assert that the function output matches the expected result
assert raw_table_info == expected_df_info

@patch("tablegpt.utils.read_df")
def test_with_timestamp(self, mock_read_df):
# Set up the mock for csv.reader to return an iterator over the expected rows
mock_read_df.return_value = pd.DataFrame(
{
"numbers": [1, 2, np.nan, 4],
"dates": [
pd.Timestamp("2023-01-01 12:30"),
pd.NaT,
pd.Timestamp("2023-01-02"),
pd.NaT,
],
}
)

# Call the function with the mocked file
raw_table_info = get_raw_table_info(filepath=Path("/home/user/file.csv"))

expected_df_info = [
[1, "2023-01-01"],
[2, None],
[None, "2023-01-02"],
[4, None],
]
# Assert that the function output matches the expected result
assert raw_table_info == expected_df_info


class TestFilterContent(unittest.TestCase):
def test_filter_content_with_string_content(self):
message = BaseMessage(content="Hello, World!", type="ai")
Expand Down

0 comments on commit d1c1c36

Please sign in to comment.