Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: get raw table information from the kernel #120

Merged
merged 1 commit on
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions src/tablegpt/agent/file_reading/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import ast
import logging
from ast import literal_eval
from enum import Enum
from typing import TYPE_CHECKING, Literal
from uuid import uuid4

import pandas as pd
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
Expand All @@ -19,7 +21,6 @@
from tablegpt.errors import NoAttachmentsError
from tablegpt.tools import IPythonTool, markdown_console_template
from tablegpt.translation import create_translator
from tablegpt.utils import get_raw_table_info

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -84,9 +85,8 @@ def create_file_reading_workflow(
translation_chain = None
if locale is not None:
translation_chain = create_translator(llm=llm)

tools = [IPythonTool(pybox_manager=pybox_manager, cwd=workdir, session_id=session_id)]
tool_executor = ToolNode(tools)
ipython_tool = IPythonTool(pybox_manager=pybox_manager, cwd=workdir, session_id=session_id)
tool_executor = ToolNode([ipython_tool])

async def agent(state: AgentState) -> dict:
if state.get("processing_stage", Stage.UPLOADED) == Stage.UPLOADED:
Expand All @@ -103,15 +103,18 @@ async def generate_normalization_code(state: AgentState) -> str:
else:
raise NoAttachmentsError

filepath = workdir.joinpath(filename)
var_name = state["entry_message"].additional_kwargs.get("var_name", "df")

# TODO: refactor the data normalization to langgraph
raw_table_info = get_raw_table_info(filepath=filepath)
content = await ipython_tool.ainvoke(
input=RAW_TABLE_INFO_CODE.format_map({"filepath": filename, "var_name": f"{var_name}_5rows"})
)
raw_table_info = ast.literal_eval(next(x["text"] for x in content if x["type"] == "text"))
table_reformat_chain = get_table_reformat_chain(llm=normalize_llm)
reformatted_table = await table_reformat_chain.ainvoke(input={"table": raw_table_info})

if reformatted_table == raw_table_info:
# TODO: Replace pandas dependency with a lightweight alternative or custom implementation.
if pd.DataFrame(reformatted_table).astype(str).equals(pd.DataFrame(raw_table_info)):
vegetablest marked this conversation as resolved.
Show resolved Hide resolved
return ""

normalize_chain = get_data_normalize_chain(llm=normalize_llm)
Expand Down Expand Up @@ -316,3 +319,18 @@ def should_continue(state: AgentState) -> Literal["tools", "end"]:
)

return workflow.compile(debug=verbose)


RAW_TABLE_INFO_CODE = """import numpy as np
from datetime import datetime

{var_name} = read_df('{filepath}', nrows=5, header=None)

# Replace NaN with None and format datetime cells
{var_name} = {var_name}.where({var_name}.notnull(), None).map(
lambda cell: (
cell.strftime('%Y-%m-%d') if isinstance(cell, (pd.Timestamp, datetime)) and pd.notnull(cell) else cell
)
)
# Convert DataFrame to a list of lists
print({var_name}.replace(np.nan, None).values.tolist())"""
40 changes: 1 addition & 39 deletions src/tablegpt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import concurrent.futures
import os
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, cast
from typing import TYPE_CHECKING, NamedTuple, cast

import numpy as np
import pandas as pd

from tablegpt.errors import (
Expand Down Expand Up @@ -224,39 +222,3 @@ def filter_content(message: BaseMessage, keep: Sequence[str] | None = None) -> B
if not isinstance(part, dict) or part.get("type") in keep:
cloned.content.append(part)
return cloned


def get_raw_table_info(
    filepath: Path,
    nrows: int = 5,
    date_format: str = "%Y-%m-%d",
) -> list[list[Any]]:
    """Read the first ``nrows`` rows of a tabular file as a plain 2D list.

    The file is read with ``header=None``, so the real header row (if any)
    appears as the first sublist. Missing values are normalized to ``None``
    and datetime-like cells are rendered with ``date_format``.

    Args:
        filepath: Path to the file containing the table data; the concrete
            format is dispatched by ``read_df`` based on the file suffix.
        nrows: The number of rows to read from the top of the table.
        date_format: ``strftime`` format applied to datetime cells.

    Returns:
        list[list[Any]]: A 2D array where each sublist represents a row from
        the table, with each element corresponding to a column value.
    """
    read_kwargs = {"nrows": nrows, "header": None}

    df = read_df(filepath, **read_kwargs)

    # Replace NaN with None and format datetime cells. The trailing
    # pd.notnull guard keeps NaT out of the strftime branch.
    df = df.where(df.notnull(), None).map(
        lambda cell: (
            cell.strftime(date_format) if isinstance(cell, (pd.Timestamp, datetime)) and pd.notnull(cell) else cell
        )
    )
    # Convert to a list of lists; any NaN that survived the `where` above
    # (e.g. in float columns, where None is stored back as NaN) is
    # normalized to None here.
    return df.replace({np.nan: None}).values.tolist()
209 changes: 0 additions & 209 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
import unittest
from pathlib import Path
from unittest.mock import MagicMock, patch

import numpy as np
import pandas as pd
from langchain_core.messages import BaseMessage
from tablegpt.utils import (
filter_content,
get_raw_table_info,
path_from_uri,
)

Expand Down Expand Up @@ -63,211 +59,6 @@ def test_valid_file_uri_with_encoded_characters(self):
assert path_from_uri(uri) == expected_path


class TestGetRawTableInfo(unittest.TestCase):
    """Tests for ``get_raw_table_info``.

    ``read_df`` is patched in every test that reaches it, so only the logic
    of ``get_raw_table_info`` itself (NaN -> None normalization, datetime
    formatting, DataFrame -> list-of-lists conversion) is exercised here.
    """

    # Shared fixture used by the per-extension passthrough tests below.
    SAMPLE_ROWS = [
        ["Header1", "Header2", "Header3"],
        ["Data1", 123, "MoreData1"],
        ["Data2", 456, "MoreData2"],
    ]

    def test_unsupported_formats(self):
        """An unknown extension raises ValueError naming the bad suffix."""
        with self.assertRaises(ValueError) as e:  # noqa: PT027
            get_raw_table_info(Path("/home/user/file.text"))
        # BUG FIX: the original compared str(e) -- the assertRaises context
        # object, whose str() is an object repr and can never equal the
        # message -- instead of the captured exception.
        assert str(e.exception) == "Unsupported file format: .text"

    def _check_passthrough(self, mock_read_df, filename):
        """Patch read_df to return SAMPLE_ROWS and assert they round-trip."""
        mock_read_df.return_value = pd.DataFrame(self.SAMPLE_ROWS)
        assert get_raw_table_info(filepath=Path(filename)) == self.SAMPLE_ROWS

    # NOTE: get_raw_table_info takes no sheet-name argument and never calls
    # pd.ExcelFile (read_df is fully mocked), so the original ExcelFile
    # patches were dead weight and have been removed. The with/without
    # "sheetname" split is kept only for test-name stability.
    @patch("tablegpt.utils.read_df")
    def test_xlsx_with_sheetname(self, mock_read_df):
        self._check_passthrough(mock_read_df, "/home/user/file.xlsx")

    @patch("tablegpt.utils.read_df")
    def test_xlsx_without_sheetname(self, mock_read_df):
        self._check_passthrough(mock_read_df, "/home/user/file.xlsx")

    @patch("tablegpt.utils.read_df")
    def test_xls_with_sheetname(self, mock_read_df):
        self._check_passthrough(mock_read_df, "/home/user/file.xls")

    @patch("tablegpt.utils.read_df")
    def test_xls_without_sheetname(self, mock_read_df):
        self._check_passthrough(mock_read_df, "/home/user/file.xls")

    @patch("tablegpt.utils.read_df")
    def test_csv(self, mock_read_df):
        """String cells (as produced by CSV parsing) pass through untouched."""
        rows = [
            ["Header1", "Header2", "Header3"],  # Header row
            ["Data1", "123", "MoreData1"],  # Data row 1
            ["Data2", "456", "MoreData2"],  # Data row 2
        ]
        mock_read_df.return_value = pd.DataFrame(rows)
        assert get_raw_table_info(filepath=Path("/home/user/file.csv")) == rows

    @patch("tablegpt.utils.read_df")
    def test_with_nan(self, mock_read_df):
        """NaN and None cells are both normalized to None in the output."""
        mock_read_df.return_value = pd.DataFrame(
            {
                "str_col": [None, "Data1", "Data2"],
                "int_col": [123, 456, 789],
                "float_col": [1.1, np.nan, 3.3],
            }
        )
        assert get_raw_table_info(filepath=Path("/home/user/file.csv")) == [
            [None, 123, 1.1],
            ["Data1", 456, None],
            ["Data2", 789, 3.3],
        ]

    @patch("tablegpt.utils.read_df")
    def test_with_timestamp(self, mock_read_df):
        """Timestamps are formatted as dates; NaT/NaN become None."""
        mock_read_df.return_value = pd.DataFrame(
            {
                "numbers": [1, 2, np.nan, 4],
                "dates": [
                    pd.Timestamp("2023-01-01 12:30"),
                    pd.NaT,
                    pd.Timestamp("2023-01-02"),
                    pd.NaT,
                ],
            }
        )
        assert get_raw_table_info(filepath=Path("/home/user/file.csv")) == [
            [1, "2023-01-01"],
            [2, None],
            [None, "2023-01-02"],
            [4, None],
        ]


class TestFilterContent(unittest.TestCase):
def test_filter_content_with_string_content(self):
message = BaseMessage(content="Hello, World!", type="ai")
Expand Down