Skip to content

Commit

Permalink
fix: allow evaluation of multi-turn conversations containing tool cal…
Browse files Browse the repository at this point in the history
…ls (GoogleCloudPlatform#1313)


Main author: @mariagpuyol 
co-author: @eliasecchig 

We simplify the evaluation process on the starter pack.
1. Updated the poetry environment to use the new extra:
google-cloud-aiplatform[evaluation]
2. Removed batch scoring utils functions to leverage the `batch` method
with LangChain
3. Introduce support for tool calling processing within the ground truth

---------

Co-authored-by: Elia Secchi <eliasecchi@google.com>
Co-authored-by: Holt Skinner <holtskinner@google.com>
  • Loading branch information
3 people authored Oct 25, 2024
1 parent dac293d commit 97e6a0f
Show file tree
Hide file tree
Showing 10 changed files with 1,188 additions and 646 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,5 @@
/generative-ai/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb @alvarobartt @philschmid @pagezyhf @jeffboudier
/generative-ai/gemini/use-cases/applying-llms-to-data/semantic-search-in-bigquery/stackoverflow_questions_semantic_search.ipynb @sethijaideep @GoogleCloudPlatform/generative-ai-devrel
/generative-ai/gemini/use-cases/retrieval-augmented-generation/raw_with_bigquery.ipynb @jeffonelson @GoogleCloudPlatform/generative-ai-devrel
/generative-ai/gemini/sample-apps/e2e-gen-ai-app-starter-pack @eliasecchig @lspatarog @GoogleCloudPlatform/generative-ai-devrel
/generative-ai/gemini/sample-apps/e2e-gen-ai-app-starter-pack @eliasecchig @lspatarog @mariagpuyol @GoogleCloudPlatform/generative-ai-devrel
/generative-ai/vision/use-cases/hey_llm @tushuhei @GoogleCloudPlatform/generative-ai-devrel
21 changes: 16 additions & 5 deletions .github/actions/spelling/allow.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ AIP
AMNOSH
ANZ
APIENTRY
appspot
APSTUDIO
AUVs
Adidas
Expand Down Expand Up @@ -66,7 +65,6 @@ Dexin
Disturbia
Doaa
Doogler
dotprompt
Dreesen
Durafast
Durmus
Expand Down Expand Up @@ -109,7 +107,6 @@ Gameplay
Gandalf
Gatace
GenTwo
genkit
Gfm
Gisting
Glickman
Expand Down Expand Up @@ -206,6 +203,7 @@ Mosher
Mvar
NARI
NCCREATE
NDCG
NDEBUG
NGRAM
NGRAMS
Expand All @@ -218,7 +216,6 @@ NVIDIA
Nagasu
Niitsuma
Nintendo
noabe
Nominatim
Noogler
ODb
Expand Down Expand Up @@ -251,6 +248,7 @@ Qwiklab
Qwiklabs
RAGAS
RLHF
RMSE
RNNs
ROOTSPAN
RRF
Expand Down Expand Up @@ -350,6 +348,7 @@ Wehn
Welwyn
Wnd
Womens
XSum
XXE
Xiang
Youxi
Expand Down Expand Up @@ -378,6 +377,7 @@ alloydb
antiword
apikey
apikeys
appspot
apredict
aquery
arXiv
Expand Down Expand Up @@ -457,6 +457,7 @@ dino
diy
docai
docstore
dotprompt
dpi
draig
drinkware
Expand All @@ -474,6 +475,7 @@ embs
embvs
emojis
ename
engi
epath
epoc
erty
Expand All @@ -484,6 +486,7 @@ evals
faiss
fastapi
fda
fea
fect
fewshot
ffi
Expand Down Expand Up @@ -516,6 +519,7 @@ gcsfs
gdk
gdkx
genai
genkit
geocoded
getdata
getexif
Expand Down Expand Up @@ -556,6 +560,7 @@ icudtl
idk
idks
idxs
ience
iloc
imagefont
imageno
Expand All @@ -572,6 +577,7 @@ itable
itables
iterrows
ivf
ized
jegadesh
jetbrains
jiwer
Expand Down Expand Up @@ -626,12 +632,14 @@ nbfmt
nbformat
ncols
ndarray
neering
newaxis
newaxisngram
ngrams
nlp
nmade
nmilitary
noabe
nobserved
norigin
notetaker
Expand Down Expand Up @@ -678,6 +686,7 @@ podfile
podhelper
powerups
preds
produc
projectid
protobuf
pstotext
Expand Down Expand Up @@ -712,6 +721,7 @@ reranked
reranker
reranking
reranks
resil
ribeye
ringspun
roboto
Expand Down Expand Up @@ -763,6 +773,7 @@ thelook
throug
tiktoken
timechart
tion
titlebar
tobytes
toself
Expand All @@ -772,6 +783,7 @@ traceloop
tritan
tseslint
tsv
tures
ubuntu
undst
unigram
Expand Down Expand Up @@ -821,4 +833,3 @@ youtube
ytd
yticks
zaxis
XSum
181 changes: 40 additions & 141 deletions gemini/sample-apps/e2e-gen-ai-app-starter-pack/app/eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from concurrent.futures import ThreadPoolExecutor
from functools import partial
import glob
import logging
from typing import Any, Callable, Dict, Iterator, List
from typing import Any, Dict, List

import nest_asyncio
import pandas as pd
from tqdm import tqdm
import yaml

nest_asyncio.apply()


def load_chats(path: str) -> List[Dict[str, Any]]:
"""
Expand All @@ -45,27 +38,47 @@ def load_chats(path: str) -> List[Dict[str, Any]]:
return chats


def pairwise(iterable: List[Any]) -> Iterator[tuple[Any, Any]]:
"""Creates an iterable with tuples paired together
e.g s -> (s0, s1), (s2, s3), (s4, s5), ...
def _process_conversation(row: Dict[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
"""Processes a single conversation row to extract messages and build conversation history.
Most human-ai interactions are composed of a human message followed by an ai message.
But when there's a tool call, the interactions are as follows:
- human message
- ai message with empty content and tool_calls set
- tool message with tool call arguments
- ai message with non-empty content and tool_calls empty.
In any case the human message is the first in the set and the final answer is the last in the set.
"""
a = iter(iterable)
return zip(a, a)
conversation_history: List[Dict] = []
messages: List[Dict[str, Any]] = []
messages_since_last_human_message: List[Dict[str, Any]] = []

for message in row["messages"]:
if message["type"] == "human":
# Reset for new human message
messages_since_last_human_message = []

def _process_conversation(row: Dict[str, List[str]]) -> List[Dict[str, Any]]:
"""Processes a single conversation row to extract messages and build conversation history."""
conversation_history: List[Dict] = []
messages = []
for human_message, ai_message in pairwise(row["messages"]):
messages.append(
{
"human_message": human_message,
"ai_message": ai_message,
"conversation_history": conversation_history.copy(),
}
)
conversation_history.extend([human_message, ai_message])
# Add current message to temporary storage
messages_since_last_human_message.append(message)

# Check if this is a final AI response (not a tool call)
if message["type"] == "ai" and (
"tool_calls" not in message or len(message["tool_calls"]) == 0
):
# Process the completed exchange
messages.append(
{
"human_message": messages_since_last_human_message[
0
], # First message is human
"ai_message": messages_since_last_human_message[
-1
], # Last message is AI's final response
"conversation_history": conversation_history.copy(), # Include previous conversation
}
)

# Update overall conversation history
conversation_history.extend(messages_since_last_human_message)
return messages


Expand All @@ -89,121 +102,7 @@ def generate_multiturn_history(df: pd.DataFrame) -> pd.DataFrame:
- human_message: The human message in that turn.
- ai_message: The AI message in that turn.
- conversation_history: A list of all messages in the conversation
up to and including the current turn.
up to the current turn (excluded).
"""
processed_messages = df.apply(_process_conversation, axis=1).explode().tolist()
return pd.DataFrame(processed_messages)


def generate_message(row: tuple[int, Dict[str, Any]], runnable: Any) -> Dict[str, Any]:
"""Generates a response message using a given runnable and updates the row dictionary.
This function takes a row dictionary containing message data and a runnable object.
It extracts conversation history and the current human message from the row,
then uses the runnable to generate a response based on the conversation history.
The generated response content and usage metadata are then added to the original
message dictionary within the row.
Args:
row (tuple[int, Dict[str, Any]]): A tuple containing the index and a dictionary
with message data, including:
- "conversation_history" (List[str]): Optional. List of previous
messages
in the conversation.
- "human_message" (str): The current human message.
runnable (Any): A runnable object that takes a dictionary with a "messages" key
and returns a response object with "content" and
"usage_metadata" attributes.
Returns:
Dict[str, Any]: The updated row dictionary with the generated response added to the message.
The message will now contain:
- "response" (str): The generated response content.
- "response_obj" (Any): The usage metadata of the response from the runnable.
"""
_, message = row
messages = (
message["conversation_history"] if "conversation_history" in message else []
)
messages.append(message["human_message"])
input_runnable = {"messages": messages}
response = runnable.invoke(input_runnable)
message["response"] = response.content
message["response_obj"] = response.usage_metadata
return message


def batch_generate_messages(
messages: pd.DataFrame,
runnable: Callable[[List[Dict[str, Any]]], Dict[str, Any]],
max_workers: int = 4,
) -> pd.DataFrame:
"""Generates AI responses to user messages using a provided runnable.
Processes a Pandas DataFrame containing conversation histories and user messages, utilizing
the specified runnable to predict AI responses in parallel.
Args:
messages (pd.DataFrame): DataFrame with a 'messages' column. Each row
represents a conversation and contains a list of dictionaries, where
each dictionary
represents a message turn in the format:
```json
[
{"type": "human", "content": "user's message"},
{"type": "ai", "content": "AI's response"},
{"type": "human", "content": "current user's message"},
...
]
```
runnable (Callable[[List[Dict[str, Any]]], Dict[str, Any]]): Runnable object
(e.g., LangChain Chain) used
for response generation. It should accept a list of message dictionaries
(as described above) and return a dictionary with the following structure:
```json
{
"response": "AI's response",
"response_obj": { ... } # optional response metadata
}
```
max_workers (int, optional): Number of worker processes for parallel
prediction. Defaults to 4.
Returns:
pd.DataFrame: DataFrame with the original 'messages' column and two new
columns: 'response' containing the predicted AI responses, and
'response_obj' containing optional response metadata.
Example:
```python
import pandas as pd
messages_df = pd.DataFrame({
"messages": [
[
{"type": "human", "content": "What's the weather today?"}
],
[
{"type": "human", "content": "Tell me a joke."},
{"type": "ai", "content": "Why did the scarecrow win an award?"},
{"type": "human", "content": "I don't know, why?"}
]
]
})
responses_df = batch_generate_messages(my_runnable, messages_df)
```
"""
logging.info("Executing batch scoring")
predicted_messages = []
with ThreadPoolExecutor(max_workers) as pool:
partial_func = partial(generate_message, runnable=runnable)
for message in tqdm(
pool.map(partial_func, messages.iterrows()), total=len(messages)
):
predicted_messages.append(message)
return pd.DataFrame(predicted_messages)
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ def should_continue(state: MessagesState) -> str:
return "tools" if last_message.tool_calls else END


async def call_model(
state: MessagesState, config: RunnableConfig
) -> Dict[str, BaseMessage]:
def call_model(state: MessagesState, config: RunnableConfig) -> Dict[str, BaseMessage]:
"""Calls the language model and returns the response."""
system_message = "You are a helpful AI assistant."
messages_with_system = [{"type": "system", "content": system_message}] + state[
Expand Down
Loading

0 comments on commit 97e6a0f

Please sign in to comment.