-
Notifications
You must be signed in to change notification settings - Fork 106
/
Copy pathhf-dataset-llm-eval.py
68 lines (58 loc) · 2.07 KB
/
hf-dataset-llm-eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from huggingface_hub import InferenceClient
from requests import HTTPError
from datachain import C, DataChain, DataModel
# Instruction placed before each dialog in the request. The leading and
# trailing newlines are intentional and part of the text sent to the model.
PROMPT = (
    "\n"
    "Was this dialog successful? Put result as a single word: Success or Failure.\n"
    "Explain the reason in a few words.\n"
)
class DialogEval(DataModel):
    # Structured verdict for one dialog. eval_dialog passes this model's JSON
    # schema as the ``response_format`` and returns instances of it; the chain
    # stores them in the ``response`` column (filtered as "response.result").
    # NOTE(review): documented with comments instead of a docstring on purpose;
    # with pydantic v2 a class docstring would likely become the schema's
    # "description" and thereby change the schema sent to the model.
    result: str  # "Success"/"Failure" as requested by PROMPT; eval_dialog also uses "Error"
    reason: str  # short explanation of the verdict, or of why evaluation failed
# DataChain function to evaluate a dialog.
# DataChain uses the type annotations of the inputs and the result to infer the schema automatically.
def eval_dialog(
    client: InferenceClient,
    user_input: str,
    bot_response: str,
) -> DialogEval:
    """Ask the LLM whether a single user/bot exchange was successful.

    The request combines ``PROMPT`` with the dialog and constrains the answer
    to JSON matching the ``DialogEval`` schema via ``response_format``; the
    reply is then parsed into a ``DialogEval``.

    Args:
        client: Hugging Face inference client used for the chat completion.
        user_input: The user's message in the dialog.
        bot_response: The bot's reply to ``user_input``.

    Returns:
        The parsed ``DialogEval``. For an ``HTTPError`` from the API call, or
        a response that is empty or cannot be parsed, it returns a
        ``DialogEval`` with ``result="Error"`` and an explanatory ``reason``
        instead of raising, so a single bad row does not abort the ``map``.
    """
    parse_error = "Failed to parse response."

    try:
        completion = client.chat_completion(
            messages=[
                {
                    "role": "user",
                    "content": f"{PROMPT}\n\nUser: {user_input}\nBot: {bot_response}",
                },
            ],
            response_format={"type": "json", "value": DialogEval.model_json_schema()},
        )
    except HTTPError:
        return DialogEval(
            result="Error", reason="Error while interacting with the Hugging Face API."
        )

    # Guard against an empty ``choices`` list (would otherwise raise
    # IndexError and abort the chain) and a missing/empty message body.
    if not completion.choices:
        return DialogEval(result="Error", reason=parse_error)
    content = completion.choices[0].message.content
    if not content:
        return DialogEval(result="Error", reason=parse_error)

    try:
        return DialogEval.model_validate_json(content)
    except ValueError:
        # pydantic's ValidationError subclasses ValueError, so malformed JSON
        # or a schema mismatch ends up here.
        return DialogEval(result="Error", reason=parse_error)
# Run HF inference in parallel for each example.
# Get the result as a Pydantic model that DataChain can understand and serialize.
# Save it to HF as Parquet. The dataset can be previewed here:
# https://huggingface.co/datasets/dvcorg/test-datachain-llm-eval/viewer
# Source dialogs: a CSV hosted on the Hugging Face Hub.
dialogs = DataChain.from_csv(
    "hf://datasets/infinite-dataset-hub/MobilePlanAssistant/data.csv"
)
# Ten parallel workers, each with its own InferenceClient created by setup(),
# store eval_dialog's DialogEval result in the "response" column.
evaluated = (
    dialogs.settings(parallel=10)
    .setup(client=lambda: InferenceClient("meta-llama/Llama-3.1-70B-Instruct"))
    .map(response=eval_dialog)
)
evaluated.to_parquet("hf://datasets/dvcorg/test-datachain-llm-eval/data.parquet")
# Read the dataset back, then filter it and show a few rows.
# DataChain restores the Pydantic model from Parquet under the hood.
# Keep only dialogs the model judged as failures, then print the first three.
failures = DataChain.from_parquet(
    "hf://datasets/dvcorg/test-datachain-llm-eval/data.parquet", source=False
).filter(C("response.result") == "Failure")
failures.show(3)