feat: add agent-eval #51

Merged · 1 commit merged on Nov 5, 2024
50 changes: 50 additions & 0 deletions eval/agent_eval/README.md
@@ -0,0 +1,50 @@
# TableGPT Evaluation

This document will guide you through the process of setting up the evaluation environment and running evaluations.

## Evaluation Datasets

Before running the evaluation, you need to create the evaluation datasets locally.

In the evaluation context, the term "dataset" can be confusing because it has two different meanings. The first refers to the evaluation datasets, which contain the samples you wish to evaluate: each sample must have an 'input' field representing the user input, and may optionally include an 'expected output' field if there is a ground-truth answer for that input. The second refers to the data the user wants to analyze, which we call 'reference data'.

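For illustration, here is a minimal sketch of an evaluation dataset. The field names (`input`, `expected_output`, `status`, `attachments`) follow what the evaluator in this PR expects; the concrete values, including the non-archived `status` value, are placeholders only:

```json
[
  {
    "input": "Which product category generated the highest total sales in 2023?",
    "expected_output": "Electronics, with total sales of 1,234,567.",
    "status": "ACTIVE",
    "attachments": ["sales_2023.csv"]
  }
]
```
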
### Input

We use an LLM to assist in generating questions based on the input dataset. You can find the script [here](./questioner.py).

Please note that while our goal was to create a one-click solution for question generation, the current implementation may require some manual adjustments. Depending on your dataset, you might need to tweak the prompt accordingly. For instance, the default prompt aims to "uncover business value," which is not suitable for datasets related to diseases.

### Expected Output

While not all samples require an 'expected output' field, certain inputs, particularly those related to data analysis, do need a ground-truth answer for comparison during evaluation. We use agent apps (such as ChatGPT or ChatGLM) to assist in generating the 'expected output'.

It's crucial to be meticulous when crafting the 'expected output' because it serves as the ground truth for evaluation. If the 'expected output' is incorrect, the evaluation results will be inaccurate.

## Installation

Create a virtual environment:

```sh
python -m venv venv
source ./venv/bin/activate # On Windows, use `.\venv\Scripts\activate`
```

Install the evaluation dependencies:

```sh
pip install -r requirements.txt
```

## Configuration

The evaluation configuration is a YAML file (`config.yaml` by default). Refer to [example-config.yaml](./example-config.yaml) for detailed information.

## Run the Evaluation Script

Besides the config file, you need to set up some environment variables, either by exporting them or by creating a `.env` file in the root directory.

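As a sketch, a `.env` file might look like the one below. `LOG_LEVEL` is read by `agent_eval/__main__.py` in this PR; the API key variable is a hypothetical example for an OpenAI-compatible endpoint, and the variables you actually need depend on your deployment:

```sh
# Read by agent_eval/__main__.py (defaults to INFO).
LOG_LEVEL=INFO
# Hypothetical: credentials for your model endpoint, if not supplied via config.yaml.
OPENAI_API_KEY=your-api-key
```
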
To run the evaluation script, use the following command:

```sh
python -m agent_eval --config path/to/your/config.yaml
```
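Results are appended to a timestamped JSONL file in the working directory (`eval_run_<timestamp>.jsonl`), one evaluated sample per line. As a rough sketch, each record looks like the following; the keys come from `evaluator.py` in this PR, the values are illustrative placeholders, and the exact shape of `score` depends on the grader chain:

```json
{
  "input": "Which product category generated the highest total sales in 2023?",
  "score": {"score": 0, "explanation": "..."},
  "reference_answer": "Electronics, with total sales of 1,234,567.",
  "student_answer": "...",
  "criteria": "...",
  "redlines": [],
  "messages": []
}
```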
Empty file added eval/agent_eval/__init__.py
Empty file.
52 changes: 52 additions & 0 deletions eval/agent_eval/__main__.py
@@ -0,0 +1,52 @@
import asyncio
import logging
import os
import signal
import sys

from agent_eval.config import load_config
from agent_eval.evaluator import Evaluator
from dotenv import find_dotenv, load_dotenv
from langchain.globals import set_debug
from traitlets.log import get_logger

LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")


logger = logging.getLogger(__name__)
logger.setLevel(level=LOG_LEVEL)

set_debug(LOG_LEVEL.upper() == "TRACE")

# Silence traitlets logs
traitlets_logger = get_logger()
traitlets_logger.setLevel("ERROR")


async def main() -> None:
# Set up signal handling for graceful shutdown
stop_event = asyncio.Event()
    # asyncio signal handlers are not supported on Windows; we handle KeyboardInterrupt instead
if sys.platform != "win32":
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGINT, stop_event.set)
loop.add_signal_handler(signal.SIGTERM, stop_event.set)

config = load_config()
evaluator = Evaluator(config)
try:
await evaluator.run(stop_event)
except asyncio.exceptions.CancelledError:
stop_event.set()
except KeyboardInterrupt:
        # TODO: On Windows we should end up here, but currently the CancelledError branch above is hit instead.
logger.warning("Received CTRL+C, stopping...")
stop_event.set()


if __name__ == "__main__":
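    # The default proactor event loop on Windows is incompatible with some async
    # libraries, so prefer the selector event loop here.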
if sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())

load_dotenv(find_dotenv())
asyncio.run(main())
49 changes: 49 additions & 0 deletions eval/agent_eval/config.py
@@ -0,0 +1,49 @@
import argparse
from pathlib import Path
from typing import Any
from uuid import uuid4

import yaml
from pydantic import BaseModel, Field, PositiveInt
from pydantic_settings import BaseSettings, SettingsConfigDict


class DatasetSettings(BaseModel):
name: str


class EvalSettings(BaseSettings):
model_config = SettingsConfigDict(extra="ignore")

run_name: str = Field(default_factory=lambda: f"eval-run-{uuid4()}")
metadata: dict[str, Any]
user: str = "eval-user"
datasets: list[DatasetSettings]

max_concurrency: PositiveInt = 1
num_repetitions: PositiveInt = 1

grader: dict[str, Any]


def load_config() -> EvalSettings:
parser = argparse.ArgumentParser(description="Run the evaluation script.")
parser.add_argument(
"--config",
type=str,
default="config.yaml",
help="Config file location.",
)
args = parser.parse_args()
config_path = Path(args.config).absolute()
if not config_path.exists():
raise RuntimeError(f"Config file '{args.config}' not found")

print(f"Using config file: {config_path}")
with open(str(config_path), "r") as file:
try:
config = yaml.safe_load(file)
        except Exception as e:
            raise ValueError(f"Error loading config file: {e}") from e

return EvalSettings(**config)
209 changes: 209 additions & 0 deletions eval/agent_eval/evaluator.py
@@ -0,0 +1,209 @@
import asyncio
import json
import logging
import traceback
from datetime import datetime
from typing import Any

import aiofiles
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.base import Checkpoint
from langgraph.checkpoint.memory import MemorySaver
from tqdm.asyncio import tqdm

from agent_eval.config import EvalSettings
from agent_eval.grader import grader_chain
from agent_eval.grader.prompt import (
DEFAULT_CRITERIA_WITH_REFERENCE_ANSWER,
DEFAULT_CRITERIA_WITHOUT_REFERENCE_ANSWER,
)
from agent_eval.student import create_student_graph, student_context
from agent_eval.workflow import create_eval_workflow

logger = logging.getLogger(__name__)

# TODO: make this configurable so that a run can be resumed after an error
eval_run_output_file = f"eval_run_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl"


class Evaluator:
    """TableGPT Evaluator.

    Attributes:
        config (config.EvalSettings): Evaluator configuration.
        grader (langchain_core.runnables.Runnable): Grader used to grade the student's answer.
    """

def __init__(self, config: EvalSettings) -> None:
"""Initialize the Evaluator with the given configuration.

Args:
            config (EvalSettings): Evaluator configuration.
"""

        logger.info("Initializing evaluator with config: %s", config)
self.config = config
self.grader = grader_chain(ChatOpenAI(**config.grader))
logger.info("Evaluator initialized")

async def run_eval(
self,
payload: dict[str, Any],
student_context: dict[str, Any] | None = None,
) -> None:
"""Run the evaluation workflow.
Usually a student runnable will be executed, followed by a grader runnable.

Args:
payload (dict[str, Any]): Evaluation payload.
student_context (dict[str, Any] | None, optional): Context to be passed to the student. Defaults to None.
"""

if student_context is None:
student_context = {}

with MemorySaver() as checkpointer:
student = await create_student_graph(
datasets=payload.get("datasets"),
checkpointer=checkpointer,
**student_context,
)

eval_wf = create_eval_workflow(student=student, grader=self.grader)

item: dict[str, Any] = payload["item"]
criteria = payload.get("criteria")
if not criteria:
criteria = (
DEFAULT_CRITERIA_WITH_REFERENCE_ANSWER
if item["expected_output"]
else DEFAULT_CRITERIA_WITHOUT_REFERENCE_ANSWER
)
            # `res` is unavailable if the workflow fails, so keep a fallback for the student answer.
            student_answer = None
            try:
                res = await eval_wf.ainvoke(
                    input={
                        "input": item["input"],
                        "reference_answer": item["expected_output"],
                        "criteria": criteria,
                        "redlines": payload.get("redlines", []),
                    },
                )
                grader_result = res["grader_result"]
                student_answer = res["student_answer"]
            except Exception:
                logger.exception(
                    "Student Workflow failed, item: %s, context: %s",
                    item["input"],
                    student_context,
                )
                # We treat any exception in agent invocation as a bad case
                err_info = traceback.format_exc()
                grader_result = {
                    "score": 0,
                    "explanation": err_info,
                }

checkpoint: Checkpoint = checkpointer.get(
config={
"configurable": {"thread_id": student_context["session_id"]},
}
)
messages = checkpoint["channel_values"].get("messages", [])
messages = [message.dict() for message in messages]

            eval_result = {
                "input": item["input"],
                "score": grader_result,
                "reference_answer": item["expected_output"],
                "student_answer": student_answer,
                "criteria": criteria,
                "redlines": payload.get("redlines", []),
                "messages": messages,
            }

async with aiofiles.open(eval_run_output_file, mode="a") as f:
await f.write(json.dumps(eval_result, ensure_ascii=False) + "\n")

async def worker(
self,
queue: asyncio.Queue,
stop_event: asyncio.Event,
pbar: tqdm | None = None,
) -> None:
"""Worker to process tasks from the task queue.

Args:
queue (asyncio.Queue): Task queue.
            stop_event (asyncio.Event): Stop event used to signal the worker to stop.
pbar (tqdm | None, optional): Progress bar to update task progress. Defaults to None.
"""
logger.info("Worker started")
async with student_context() as context:
while True:
if stop_event.is_set():
logger.warning("Worker received stop event, cancelling...")
break
try:
payload = queue.get_nowait()
await self.run_eval(payload=payload, student_context=context)
if pbar is not None:
pbar.update(1)
except asyncio.QueueEmpty:
# No more tasks in the queue, quit current worker
logger.info("Worker finished")
break
except Exception as e:
logger.exception("Worker encountered an error")
stop_event.set() # Set the stop event to cancel other workers
break

async def run(self, stop_event: asyncio.Event) -> None:
"""Gather evaluation samples and run the evaluation process, in parallel."""
logger.info("Gathering evaluation samples...")
queue = asyncio.Queue()
for dataset_config in self.config.datasets:
logger.debug("Gathering samples from dataset: %s...", dataset_config.name)

with open(dataset_config.name, "r") as f:
dataset = json.load(f)
_samples = gather_samples(dataset)
logger.debug(
"Gathered %d samples from dataset %s",
len(_samples),
dataset_config.name,
)
for sample in _samples:
for _ in range(self.config.num_repetitions):
await queue.put(sample)
total_samples = queue.qsize()
logger.info("Gathered %s samples for evaluation", total_samples)

with tqdm(total=total_samples, desc="Evaluation samples") as pbar:
try:
eval_tasks = [
asyncio.create_task(
self.worker(queue, stop_event, pbar),
name=f"worker-{i}",
)
for i in range(self.config.max_concurrency)
]
await asyncio.gather(
*eval_tasks, return_exceptions=True
) # Ensure all consumers exit
except Exception:
logger.exception("Error in evaluator")
finally:
logger.info("Shutting down evaluator...")


def gather_samples(dataset: list[dict[str, Any]]) -> list[dict[str, Any]]:
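    """Build evaluation payloads from a raw dataset, skipping archived items."""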
samples = []
active_samples = [item for item in dataset if item["status"] != "ARCHIVED"]

for item in active_samples:
samples.append(
{
"item": item,
"datasets": item.get("attachments", []),
}
)
return samples
22 changes: 22 additions & 0 deletions eval/agent_eval/example-config.yaml
@@ -0,0 +1,22 @@
user: eval-example

metadata:
name: tablegpt eval
llm:
name: qwen2.5-7b-instruct
temperature: 0.1
top_p: 0.3

datasets:
- name: /datasets/tablegpt-eval-normal.json

grader:
openai_api_base: http://localhost:8080/v1
openai_api_key: nothing
model_name: qwen2.5-72b-instruct
temperature: 0.1
top_p: 0.3
max_tokens: 1024

max_concurrency: 1
num_repetitions: 1