 import pandas as pd
 import pytest

+import bigframes
 from bigframes import exceptions
 from bigframes.ml import core, llm
 import bigframes.pandas as bpd
@@ -260,6 +261,44 @@ def test_text_embedding_generator_multi_cols_predict_success(
     assert len(pd_df["ml_generate_embedding_result"][0]) == 768


+def test_create_load_multimodal_embedding_generator_model(
+    dataset_id, session, bq_connection
+):
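+    # Multimodal (blob) support is experimental, so the test opts in via the flag below.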
+    bigframes.options.experiments.blob = True
+
+    mm_embedding_model = llm.MultimodalEmbeddingGenerator(
+        connection_name=bq_connection, session=session
+    )
+    assert mm_embedding_model is not None
+    assert mm_embedding_model._bqml_model is not None
+
+    # Save and reload the model to ensure the configuration is preserved.
+    reloaded_model = mm_embedding_model.to_gbq(
+        f"{dataset_id}.temp_mm_model", replace=True
+    )
+    assert f"{dataset_id}.temp_mm_model" == reloaded_model._bqml_model.model_name
+    assert reloaded_model.connection_name == bq_connection
+
+
+@pytest.mark.flaky(retries=2)
+def test_multimodal_embedding_generator_predict_default_params_success(
+    images_mm_df, session, bq_connection
+):
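+    # Predict with default parameters on the image blob DataFrame; the multimodal
+    # embedding result is expected to be 1408-dimensional (checked below).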
+    bigframes.options.experiments.blob = True
+
+    text_embedding_model = llm.MultimodalEmbeddingGenerator(
+        connection_name=bq_connection, session=session
+    )
+    df = text_embedding_model.predict(images_mm_df).to_pandas()
+    utils.check_pandas_df_schema_and_index(
+        df,
+        columns=utils.ML_MULTIMODAL_GENERATE_EMBEDDING_OUTPUT,
+        index=2,
+        col_exact=False,
+    )
+    assert len(df["ml_generate_embedding_result"][0]) == 1408
+
+
 @pytest.mark.parametrize(
     "model_name",
     (
@@ -273,6 +312,9 @@ def test_text_embedding_generator_multi_cols_predict_success(
273312 "gemini-2.0-flash-exp" ,
274313 ),
275314)
+@pytest.mark.flaky(
+    retries=2
+)  # Creating a model usually isn't flaky, but this one can be due to the limited quota of gemini-2.0-flash-exp.
 def test_create_load_gemini_text_generator_model(
     dataset_id, model_name, session, bq_connection
 ):
@@ -375,6 +417,36 @@ def test_gemini_text_generator_multi_cols_predict_success(
     )


+@pytest.mark.parametrize(
+    "model_name",
+    (
+        "gemini-1.5-pro-001",
+        "gemini-1.5-pro-002",
+        "gemini-1.5-flash-001",
+        "gemini-1.5-flash-002",
+        "gemini-2.0-flash-exp",
+    ),
+)
+@pytest.mark.flaky(retries=2)
+def test_gemini_text_generator_multimodal_input(
+    images_mm_df: bpd.DataFrame, model_name, session, bq_connection
+):
+    bigframes.options.experiments.blob = True
+
+    gemini_text_generator_model = llm.GeminiTextGenerator(
+        model_name=model_name, connection_name=bq_connection, session=session
+    )
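+    # The prompt mixes a literal instruction with the DataFrame's blob column, so
+    # each row sends both text and its image to the model.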
+    pd_df = gemini_text_generator_model.predict(
+        images_mm_df, prompt=["Describe", images_mm_df["blob_col"]]
+    ).to_pandas()
+    utils.check_pandas_df_schema_and_index(
+        pd_df,
+        columns=utils.ML_GENERATE_TEXT_OUTPUT + ["blob_col"],
+        index=2,
+        col_exact=False,
+    )
+
+
 # Overrides __eq__ function for comparing as mock.call parameter
 class EqCmpAllDataFrame(bpd.DataFrame):
     def __eq__(self, other):