feat: Fix file type mismatch in uploading eval results to GCS, supported types: CSV, JSON.

PiperOrigin-RevId: 698873912
vertex-sdk-bot authored and copybara-github committed Nov 21, 2024
1 parent 97df5fc commit 905c766
Showing 2 changed files with 69 additions and 13 deletions.
tests/unit/vertexai/test_evaluation.py (24 additions, 0 deletions)
@@ -68,6 +68,8 @@

 _TEST_PROJECT = "test-project"
 _TEST_LOCATION = "us-central1"
+_TEST_BUCKET = "gs://test-bucket"
+_TEST_FILE_NAME = "test-file-name.csv"
 _AUTORATER_INSTRUCTION = """
 You are an expert evaluator. Your task is to evaluate the quality of the responses generated by AI models.
 """
@@ -181,6 +183,12 @@
 text,text,text\n
 """
 _TEST_EXPERIMENT = "test-experiment"
+_TEST_CSV = pd.DataFrame(
+    {
+        "response": ["text"],
+        "reference": ["ref"],
+    }
+)
 _EXPECTED_POINTWISE_PROMPT_TEMPLATE = """
 # Instruction
 hello
@@ -549,6 +557,16 @@ def mock_experiment_tracker():
         yield mock_experiment_tracker


+@pytest.fixture
+def mock_storage_blob_upload_from_filename():
+    with mock.patch(
+        "google.cloud.storage.Blob.upload_from_filename"
+    ) as mock_blob_upload_from_filename, mock.patch(
+        "google.cloud.storage.Bucket.exists", return_value=True
+    ):
+        yield mock_blob_upload_from_filename
+
+
 @pytest.mark.usefixtures("google_auth_mock")
 class TestEvaluation:
     def setup_method(self):
@@ -1929,3 +1947,9 @@ def test_pairtwise_metric_prompt_template_with_default_values(self):
             str(pairwise_metric_prompt_template)
             == _EXPECTED_PAIRWISE_PROMPT_TEMPLATE_WITH_DEFAULT_VALUES.strip()
         )
+
+    def test_upload_results(self, mock_storage_blob_upload_from_filename):
+        evaluation.utils.upload_evaluation_results(
+            _TEST_CSV, _TEST_BUCKET, _TEST_FILE_NAME
+        )
+        mock_storage_blob_upload_from_filename.assert_called_once()
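
Not part of the commit: a minimal sketch of a companion test for the new JSONL path, assuming the same fixture and module-level constants as above; the "test-file-name.jsonl" file name is hypothetical.

    def test_upload_results_jsonl(self, mock_storage_blob_upload_from_filename):
        # Hypothetical JSONL variant: the ".jsonl" suffix should route the upload
        # through df.to_json(orient="records", lines=True) before the GCS write.
        evaluation.utils.upload_evaluation_results(
            _TEST_CSV, _TEST_BUCKET, "test-file-name.jsonl"
        )
        mock_storage_blob_upload_from_filename.assert_called_once()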
vertexai/evaluation/utils.py (45 additions, 13 deletions)
@@ -18,9 +18,10 @@

 import functools
 import io
 import os
+import tempfile
 import threading
 import time
-from typing import Any, Dict, Optional, TYPE_CHECKING, Union, Callable
+from typing import Any, Dict, Optional, TYPE_CHECKING, Union, Callable, Literal

 from google.cloud import bigquery
 from google.cloud import storage
@@ -250,25 +251,56 @@ def _read_gcs_file_contents(filepath: str) -> str:
     return blob.download_as_string().decode("utf-8")


+def _upload_pandas_df_to_gcs(
+    df: "pd.DataFrame", upload_gcs_path: str, file_type: Literal["csv", "jsonl"]
+) -> None:
+    """Uploads the provided Pandas DataFrame to a GCS bucket.
+
+    Args:
+        df: The Pandas DataFrame to upload.
+        upload_gcs_path: The GCS path to upload the data file.
+        file_type: The file type of the data file.
+    """
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        if file_type == "csv":
+            local_dataset_path = os.path.join(temp_dir, "metrics_table.csv")
+            df.to_csv(path_or_buf=local_dataset_path)
+        elif file_type == "jsonl":
+            local_dataset_path = os.path.join(temp_dir, "metrics_table.jsonl")
+            df.to_json(path_or_buf=local_dataset_path, orient="records", lines=True)
+        else:
+            raise ValueError(
+                f"Unsupported file type: {file_type} from {upload_gcs_path}."
+                " Please provide a valid GCS path with `jsonl` or `csv` suffix."
+            )
+
+        storage_client = storage.Client(
+            project=initializer.global_config.project,
+            credentials=initializer.global_config.credentials,
+        )
+        storage.Blob.from_string(
+            uri=upload_gcs_path, client=storage_client
+        ).upload_from_filename(filename=local_dataset_path)
+
+
 def upload_evaluation_results(
     dataset: "pd.DataFrame", destination_uri_prefix: str, file_name: str
-):
-    """Uploads eval results to GCS CSV destination."""
-    supported_file_types = ["csv"]
+) -> None:
+    """Uploads eval results to GCS destination.
+
+    Args:
+        dataset: Pandas dataframe to upload.
+        destination_uri_prefix: GCS folder to store the data.
+        file_name: File name to store the data.
+    """
     if not destination_uri_prefix:
         return
     if destination_uri_prefix.startswith(_GCS_PREFIX):
         _, extension = os.path.splitext(file_name)
         file_type = extension.lower()[1:]
-        if file_type in supported_file_types:
-            output_path = destination_uri_prefix + "/" + file_name
-            utils.gcs_utils._upload_pandas_df_to_gcs(dataset, output_path)
-        else:
-            raise ValueError(
-                "Unsupported file type in the GCS destination URI:"
-                f" {file_name}, please provide a valid GCS"
-                f" file name with a file type in {supported_file_types}."
-            )
+        output_path = destination_uri_prefix + "/" + file_name
+        _upload_pandas_df_to_gcs(dataset, output_path, file_type)
     else:
         raise ValueError(
             f"Unsupported destination URI: {destination_uri_prefix}."
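
For orientation only (not part of the commit): a minimal usage sketch of the changed helper, assuming `vertexai.init(...)` has been configured and that the destination bucket exists; the bucket path and output file names below are hypothetical.

    import pandas as pd

    # Module path follows the changed file above (vertexai/evaluation/utils.py).
    from vertexai.evaluation import utils

    results_df = pd.DataFrame({"response": ["text"], "reference": ["ref"]})

    # A ".csv" suffix selects the CSV branch (df.to_csv) before the GCS upload.
    utils.upload_evaluation_results(results_df, "gs://my-bucket/eval", "results.csv")

    # A ".jsonl" suffix selects the JSONL branch (df.to_json, one JSON record per line).
    utils.upload_evaluation_results(results_df, "gs://my-bucket/eval", "results.jsonl")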
