tests/system/small/ml/test_llm.py: 72 additions, 0 deletions
@@ -17,6 +17,7 @@
import pandas as pd
import pytest

import bigframes
from bigframes import exceptions
from bigframes.ml import core, llm
import bigframes.pandas as bpd
@@ -260,6 +261,44 @@ def test_text_embedding_generator_multi_cols_predict_success(
    assert len(pd_df["ml_generate_embedding_result"][0]) == 768


def test_create_load_multimodal_embedding_generator_model(
    dataset_id, session, bq_connection
):
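    # Multimodal (blob) support in bigframes is experimental and must be enabled explicitly.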
    bigframes.options.experiments.blob = True

    mm_embedding_model = llm.MultimodalEmbeddingGenerator(
        connection_name=bq_connection, session=session
    )
    assert mm_embedding_model is not None
    assert mm_embedding_model._bqml_model is not None

    # Save, then reload the model to ensure the configuration is preserved.
    reloaded_model = mm_embedding_model.to_gbq(
        f"{dataset_id}.temp_mm_model", replace=True
    )
    assert f"{dataset_id}.temp_mm_model" == reloaded_model._bqml_model.model_name
    assert reloaded_model.connection_name == bq_connection


@pytest.mark.flaky(retries=2)
def test_multimodal_embedding_generator_predict_default_params_success(
    images_mm_df, session, bq_connection
):
    bigframes.options.experiments.blob = True

    mm_embedding_model = llm.MultimodalEmbeddingGenerator(
        connection_name=bq_connection, session=session
    )
    df = mm_embedding_model.predict(images_mm_df).to_pandas()
    utils.check_pandas_df_schema_and_index(
        df,
        columns=utils.ML_MULTIMODAL_GENERATE_EMBEDDING_OUTPUT,
        index=2,
        col_exact=False,
    )
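    # The multimodal embedding model returns 1408-dimensional embedding vectors.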
    assert len(df["ml_generate_embedding_result"][0]) == 1408


@pytest.mark.parametrize(
    "model_name",
    (
@@ -273,6 +312,9 @@ def test_text_embedding_generator_multi_cols_predict_success(
        "gemini-2.0-flash-exp",
    ),
)
@pytest.mark.flaky(
    retries=2
)  # Creating a model usually isn't flaky, but this one is, due to the limited quota of gemini-2.0-flash-exp.
def test_create_load_gemini_text_generator_model(
    dataset_id, model_name, session, bq_connection
):
@@ -375,6 +417,36 @@ def test_gemini_text_generator_multi_cols_predict_success(
    )


@pytest.mark.parametrize(
    "model_name",
    (
        "gemini-1.5-pro-001",
        "gemini-1.5-pro-002",
        "gemini-1.5-flash-001",
        "gemini-1.5-flash-002",
        "gemini-2.0-flash-exp",
    ),
)
@pytest.mark.flaky(retries=2)
def test_gemini_text_generator_multimodal_input(
    images_mm_df: bpd.DataFrame, model_name, session, bq_connection
):
    bigframes.options.experiments.blob = True

    gemini_text_generator_model = llm.GeminiTextGenerator(
        model_name=model_name, connection_name=bq_connection, session=session
    )
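    # The prompt mixes a literal string with a blob (image) column: each row
    # pairs "Describe" with that row's image.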
    pd_df = gemini_text_generator_model.predict(
        images_mm_df, prompt=["Describe", images_mm_df["blob_col"]]
    ).to_pandas()
    utils.check_pandas_df_schema_and_index(
        pd_df,
        columns=utils.ML_GENERATE_TEXT_OUTPUT + ["blob_col"],
        index=2,
        col_exact=False,
    )
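
# For reference, a minimal sketch of how a multimodal DataFrame like the
# `images_mm_df` fixture might be constructed. This assumes the experimental
# `bpd.from_glob_path` blob helper and a hypothetical bucket path; the
# fixture's actual definition lives in the test fixtures and may differ:
#
#   bigframes.options.experiments.blob = True
#   images_mm_df = bpd.from_glob_path(
#       "gs://my-bucket/images/*.jpg", name="blob_col"
#   )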


# Overrides __eq__ so instances can be compared as mock.call parameters.
class EqCmpAllDataFrame(bpd.DataFrame):
    def __eq__(self, other):
tests/system/utils.py: 8 additions, 0 deletions
@@ -56,6 +56,14 @@
"ml_generate_embedding_status",
"content",
]
ML_MULTIMODAL_GENERATE_EMBEDDING_OUTPUT = [
    "ml_generate_embedding_result",
    "ml_generate_embedding_status",
    # Start and end seconds depend on the input format; image and video
    # inputs include these two columns.
    "ml_generate_embedding_start_sec",
    "ml_generate_embedding_end_sec",
    "content",
]


def skip_legacy_pandas(test):