-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #51 from weekenthralling/main
feat: add agent-eval
- Loading branch information
Showing
14 changed files
with
890 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# TableGPT Evaluation | ||
|
||
This document will guide you through the process of setting up the evaluation environment and running evaluations. | ||
|
||
## Evaluation Datasets | ||
|
||
Before running the evaluation, you need to create the evaluation datasets on Local. | ||
|
||
In the evaluation context, the term "dataset" can be confusing because it has two different meanings. The first refers to evaluation datasets, which contain the samples you wish to evaluate. Each sample must have an 'input' field representing the user input and may optionally include an 'expected output' field if there is a ground truth answer to that input. The second definition refers to the dataset on which the user wants to perform analysis, which we refer to as 'reference data'. | ||
|
||
### Input | ||
|
||
We use LLM to assist in generating questions based on the input dataset. You can find the script [here](./questioner.py). | ||
|
||
Please note that while our goal was to create a one-click solution for question generation, the current implementation may require some manual adjustments. Depending on your dataset, you might need to tweak the prompt accordingly. For instance, the default prompt aims to "uncover business value," which is not suitable for datasets related to diseases. | ||
|
||
### Expected Output | ||
|
||
While not all samples require an 'expected output' field, certain inputs—particularly those related to data analysis—do need a ground truth answer for comparison during evaluation. We use Agent Apps (such as ChatGPT, ChatGLM, etc.) to assist in generating the 'expected output.' | ||
|
||
It's crucial to be meticulous when crafting the 'expected output' because it serves as the ground truth for evaluation. If the 'expected output' is incorrect, the evaluation results will be inaccurate. | ||
|
||
## Installation | ||
|
||
Create a virtual environment | ||
|
||
```sh | ||
python -m venv venv | ||
source ./venv/bin/activate # On Windows, use `.\venv\Scripts\activate` | ||
``` | ||
|
||
Install dependencies for eval | ||
|
||
```sh | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Configuration | ||
|
||
The configuration file for evaluation is a YAML file (config.yaml by default). Refer to [example-config.yaml](./example-config.yaml) for detailed information. | ||
|
||
## Run the evaluation script | ||
|
||
Besides the config file, you need to set up some environment variables, either by exporting them or by creating a `.env` file in the root directory. | ||
|
||
To run the evaluation script, use the following command: | ||
|
||
```sh | ||
python -m agent_eval --config path/to/your/config.yaml | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import asyncio | ||
import logging | ||
import os | ||
import signal | ||
import sys | ||
|
||
from agent_eval.config import load_config | ||
from agent_eval.evaluator import Evaluator | ||
from dotenv import find_dotenv, load_dotenv | ||
from langchain.globals import set_debug | ||
from traitlets.log import get_logger | ||
|
||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO") | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(level=LOG_LEVEL) | ||
|
||
set_debug(LOG_LEVEL.upper() == "TRACE") | ||
|
||
# silent traitlets logs | ||
traitlets_logger = get_logger() | ||
traitlets_logger.setLevel("ERROR") | ||
|
||
|
||
async def main() -> None: | ||
# Set up signal handling for graceful shutdown | ||
stop_event = asyncio.Event() | ||
# Windows does not support signal handling, we handle KeyboardInterrupt instead | ||
if sys.platform != "win32": | ||
loop = asyncio.get_running_loop() | ||
loop.add_signal_handler(signal.SIGINT, stop_event.set) | ||
loop.add_signal_handler(signal.SIGTERM, stop_event.set) | ||
|
||
config = load_config() | ||
evaluator = Evaluator(config) | ||
try: | ||
await evaluator.run(stop_event) | ||
except asyncio.exceptions.CancelledError: | ||
stop_event.set() | ||
except KeyboardInterrupt: | ||
# TODO: On Windows we should enter here. However we went to the except block above. | ||
logger.warning("Received CTRL+C, stopping...") | ||
stop_event.set() | ||
|
||
|
||
if __name__ == "__main__": | ||
if sys.platform == "win32": | ||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) | ||
|
||
load_dotenv(find_dotenv()) | ||
asyncio.run(main()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import argparse | ||
from pathlib import Path | ||
from typing import Any | ||
from uuid import uuid4 | ||
|
||
import yaml | ||
from pydantic import BaseModel, Field, PositiveInt | ||
from pydantic_settings import BaseSettings, SettingsConfigDict | ||
|
||
|
||
class DatasetSettings(BaseModel): | ||
name: str | ||
|
||
|
||
class EvalSettings(BaseSettings): | ||
model_config = SettingsConfigDict(extra="ignore") | ||
|
||
run_name: str = Field(default_factory=lambda: f"eval-run-{uuid4()}") | ||
metadata: dict[str, Any] | ||
user: str = "eval-user" | ||
datasets: list[DatasetSettings] | ||
|
||
max_concurrency: PositiveInt = 1 | ||
num_repetitions: PositiveInt = 1 | ||
|
||
grader: dict[str, Any] | ||
|
||
|
||
def load_config() -> dict[str, Any]: | ||
parser = argparse.ArgumentParser(description="Run the evaluation script.") | ||
parser.add_argument( | ||
"--config", | ||
type=str, | ||
default="config.yaml", | ||
help="Config file location.", | ||
) | ||
args = parser.parse_args() | ||
config_path = Path(args.config).absolute() | ||
if not config_path.exists(): | ||
raise RuntimeError(f"Config file '{args.config}' not found") | ||
|
||
print(f"Using config file: {config_path}") | ||
with open(str(config_path), "r") as file: | ||
try: | ||
config = yaml.safe_load(file) | ||
except Exception as e: | ||
raise ValueError(f"Error loading config file: {e}") | ||
|
||
return EvalSettings(**config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
import asyncio | ||
import json | ||
import logging | ||
import traceback | ||
from datetime import datetime | ||
from typing import Any | ||
|
||
import aiofiles | ||
from langchain_openai import ChatOpenAI | ||
from langgraph.checkpoint.base import Checkpoint | ||
from langgraph.checkpoint.memory import MemorySaver | ||
from tqdm.asyncio import tqdm | ||
|
||
from agent_eval.config import EvalSettings | ||
from agent_eval.grader import grader_chain | ||
from agent_eval.grader.prompt import ( | ||
DEFAULT_CRITERIA_WITH_REFERENCE_ANSWER, | ||
DEFAULT_CRITERIA_WITHOUT_REFERENCE_ANSWER, | ||
) | ||
from agent_eval.student import create_student_graph, student_context | ||
from agent_eval.workflow import create_eval_workflow | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# TODO: make this configurable, and we can continue running after an error | ||
eval_run_output_file = f"eval_run_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jsonl" | ||
|
||
|
||
class Evaluator: | ||
"""TableGPT Evaluator. | ||
config(config.EvalSettings): evaluator configuration. | ||
client(langfuse.Langfuse): Langfuse client. | ||
grader(langchain_core.runnables.Runnable): Grader used to grade the student's answer. | ||
""" | ||
|
||
def __init__(self, config: EvalSettings) -> None: | ||
"""Initialize the Evaluator with the given configuration. | ||
Args: | ||
config (dict): Configuration dictionary for the Evaluator. | ||
""" | ||
|
||
logger.info("Initializing evaluator with config: %s}", config) | ||
self.config = config | ||
self.grader = grader_chain(ChatOpenAI(**config.grader)) | ||
logger.info("Evaluator initialized") | ||
|
||
async def run_eval( | ||
self, | ||
payload: dict[str, Any], | ||
student_context: dict[str, Any] | None = None, | ||
) -> None: | ||
"""Run the evaluation workflow. | ||
Usually a student runnable will be executed, followed by a grader runnable. | ||
Args: | ||
payload (dict[str, Any]): Evaluation payload. | ||
student_context (dict[str, Any] | None, optional): Context to be passed to the student. Defaults to None. | ||
""" | ||
|
||
if student_context is None: | ||
student_context = {} | ||
|
||
with MemorySaver() as checkpointer: | ||
student = await create_student_graph( | ||
datasets=payload.get("datasets"), | ||
checkpointer=checkpointer, | ||
**student_context, | ||
) | ||
|
||
eval_wf = create_eval_workflow(student=student, grader=self.grader) | ||
|
||
item: dict[str, Any] = payload["item"] | ||
criteria = payload.get("criteria") | ||
if not criteria: | ||
criteria = ( | ||
DEFAULT_CRITERIA_WITH_REFERENCE_ANSWER | ||
if item["expected_output"] | ||
else DEFAULT_CRITERIA_WITHOUT_REFERENCE_ANSWER | ||
) | ||
try: | ||
res = await eval_wf.ainvoke( | ||
input={ | ||
"input": item["input"], | ||
"reference_answer": item["expected_output"], | ||
"criteria": criteria, | ||
"redlines": payload.get("redlines", []), | ||
}, | ||
) | ||
grader_result = res["grader_result"] | ||
except Exception: | ||
logger.exception( | ||
"Student Workflow failed, item: %s, context: %s", | ||
item["input"], | ||
student_context, | ||
) | ||
# We treat any exception in agent invocation as a bad case | ||
err_info = traceback.format_exc() | ||
grader_result = { | ||
"score": 0, | ||
"explaination": err_info, | ||
} | ||
|
||
checkpoint: Checkpoint = checkpointer.get( | ||
config={ | ||
"configurable": {"thread_id": student_context["session_id"]}, | ||
} | ||
) | ||
messages = checkpoint["channel_values"].get("messages", []) | ||
messages = [message.dict() for message in messages] | ||
|
||
eval_result = { | ||
"input": item["input"], | ||
"score": grader_result, | ||
"reference_answer": item["expected_output"], | ||
"student_answer": res["student_answer"], | ||
"criteria": criteria, | ||
"redlines": payload.get("redlines", []), | ||
"messages": messages, | ||
} | ||
|
||
async with aiofiles.open(eval_run_output_file, mode="a") as f: | ||
await f.write(json.dumps(eval_result, ensure_ascii=False) + "\n") | ||
|
||
async def worker( | ||
self, | ||
queue: asyncio.Queue, | ||
stop_event: asyncio.Event, | ||
pbar: tqdm | None = None, | ||
) -> None: | ||
"""Worker to process tasks from the task queue. | ||
Args: | ||
queue (asyncio.Queue): Task queue. | ||
stop_event (asyncio.Event): Stop events to signal the worker to stop. | ||
pbar (tqdm | None, optional): Progress bar to update task progress. Defaults to None. | ||
""" | ||
logger.info("Worker started") | ||
async with student_context() as context: | ||
while True: | ||
if stop_event.is_set(): | ||
logger.warning("Worker received stop event, cancelling...") | ||
break | ||
try: | ||
payload = queue.get_nowait() | ||
await self.run_eval(payload=payload, student_context=context) | ||
if pbar is not None: | ||
pbar.update(1) | ||
except asyncio.QueueEmpty: | ||
# No more tasks in the queue, quit current worker | ||
logger.info("Worker finished") | ||
break | ||
except Exception as e: | ||
logger.exception("Worker encountered an error") | ||
stop_event.set() # Set the stop event to cancel other workers | ||
break | ||
|
||
async def run(self, stop_event: asyncio.Event) -> None: | ||
"""Gather evaluation samples and run the evaluation process, in parallel.""" | ||
logger.info("Gathering evaluation samples...") | ||
queue = asyncio.Queue() | ||
for dataset_config in self.config.datasets: | ||
logger.debug("Gathering samples from dataset: %s...", dataset_config.name) | ||
|
||
with open(dataset_config.name, "r") as f: | ||
dataset = json.load(f) | ||
_samples = gather_samples(dataset) | ||
logger.debug( | ||
"Gathered %d samples from dataset %s", | ||
len(_samples), | ||
dataset_config.name, | ||
) | ||
for sample in _samples: | ||
for _ in range(self.config.num_repetitions): | ||
await queue.put(sample) | ||
total_samples = queue.qsize() | ||
logger.info("Gathered %s samples for evaluation", total_samples) | ||
|
||
with tqdm(total=total_samples, desc="Evaluation samples") as pbar: | ||
try: | ||
eval_tasks = [ | ||
asyncio.create_task( | ||
self.worker(queue, stop_event, pbar), | ||
name=f"worker-{i}", | ||
) | ||
for i in range(self.config.max_concurrency) | ||
] | ||
await asyncio.gather( | ||
*eval_tasks, return_exceptions=True | ||
) # Ensure all consumers exit | ||
except Exception: | ||
logger.exception("Error in evaluator") | ||
finally: | ||
logger.info("Shutting down evaluator...") | ||
|
||
|
||
def gather_samples(dataset: list[dict[str, Any]]) -> list[dict[str, Any]]: | ||
samples = [] | ||
active_samples = [item for item in dataset if item["status"] != "ARCHIVED"] | ||
|
||
for item in active_samples: | ||
samples.append( | ||
{ | ||
"item": item, | ||
"datasets": item.get("attachments", []), | ||
} | ||
) | ||
return samples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
user: eval-example | ||
|
||
metadata: | ||
name: tablegpt eval | ||
llm: | ||
name: qwen2.5-7b-instruct | ||
temperature: 0.1 | ||
top_p: 0.3 | ||
|
||
datasets: | ||
- name: /datasets/tablegpt-eval-normal.json | ||
|
||
grader: | ||
openai_api_base: http://localhost:8080/v1 | ||
openai_api_key: nothing | ||
model_name: qwen2.5-72b-instruct | ||
temperature: 0.1 | ||
top_p: 0.3 | ||
max_tokens: 1024 | ||
|
||
max_concurrency: 1 | ||
num_repetitions: 1 |
Oops, something went wrong.