From a1aa9146df8d60c9c8761b323eb2c00fc2338e31 Mon Sep 17 00:00:00 2001 From: vatsrahul1001 Date: Wed, 4 Jun 2025 16:29:15 +0530 Subject: [PATCH 1/3] make cohere provider AF3 compatible --- .../cohere/src/airflow/providers/cohere/hooks/cohere.py | 9 +++++---- .../src/airflow/providers/cohere/operators/embedding.py | 8 +++++--- .../tests/system/pinecone/example_pinecone_cohere.py | 2 +- .../tests/system/weaviate/example_weaviate_cohere.py | 1 - 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py index 16b313c1d71d8..bfaba2d658a58 100644 --- a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py +++ b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from cohere.core.request_options import RequestOptions - from cohere.types import ChatMessages, EmbedByTypeResponseEmbeddings + from cohere.types import ChatMessages logger = logging.getLogger(__name__) @@ -91,7 +91,7 @@ def get_conn(self) -> cohere.ClientV2: def create_embeddings( self, texts: list[str], model: str = "embed-multilingual-v3.0" - ) -> EmbedByTypeResponseEmbeddings: + ) -> list[list[float]]: logger.info("Creating embeddings with model: embed-multilingual-v3.0") response = self.get_conn().embed( texts=texts, @@ -100,8 +100,9 @@ def create_embeddings( embedding_types=["float"], request_options=self.request_options, ) - embeddings = response.embeddings - return embeddings + if response.embeddings.float_ is None: + raise ValueError("Embeddings response is missing float_ field") + return response.embeddings.float_ @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: diff --git a/providers/cohere/src/airflow/providers/cohere/operators/embedding.py b/providers/cohere/src/airflow/providers/cohere/operators/embedding.py index b06f13ab02194..1504858fc217e 100644 --- 
a/providers/cohere/src/airflow/providers/cohere/operators/embedding.py +++ b/providers/cohere/src/airflow/providers/cohere/operators/embedding.py @@ -26,7 +26,6 @@ if TYPE_CHECKING: from cohere.core.request_options import RequestOptions - from cohere.types import EmbedByTypeResponseEmbeddings try: from airflow.sdk.definitions.context import Context @@ -91,6 +90,9 @@ def hook(self) -> CohereHook: request_options=self.request_options, ) - def execute(self, context: Context) -> EmbedByTypeResponseEmbeddings: + def execute(self, context: Context) -> list[list[float]]: """Embed texts using Cohere embed services.""" - return self.hook.create_embeddings(self.input_text) + embedding_response = self.hook.create_embeddings(self.input_text) + + # The hook already returns the plain list of float embeddings, which is serializable + return embedding_response diff --git a/providers/pinecone/tests/system/pinecone/example_pinecone_cohere.py b/providers/pinecone/tests/system/pinecone/example_pinecone_cohere.py index 496c0a40154bf..2e74fe870a07d 100644 --- a/providers/pinecone/tests/system/pinecone/example_pinecone_cohere.py +++ b/providers/pinecone/tests/system/pinecone/example_pinecone_cohere.py @@ -54,7 +54,7 @@ def create_index(): @task def transform_output(embedding_output) -> list[dict]: # Convert each embedding to a map with an ID and the embedding vector - return [dict(id=str(i), values=embedding) for i, embedding in enumerate(embedding_output.float_)] + return [dict(id=str(i), values=embedding) for i, embedding in enumerate(embedding_output)] transformed_output = transform_output(embed_task.output) diff --git a/providers/weaviate/tests/system/weaviate/example_weaviate_cohere.py b/providers/weaviate/tests/system/weaviate/example_weaviate_cohere.py index 9376a9235152f..11b45c2cb45be 100644 --- a/providers/weaviate/tests/system/weaviate/example_weaviate_cohere.py +++ b/providers/weaviate/tests/system/weaviate/example_weaviate_cohere.py @@ -72,7 +72,6 @@ def update_vector_data_in_json(**kwargs): data = 
json.load(Path("jeopardy_data_without_vectors.json").open()) embedded_data = ti.xcom_pull(task_ids="embedding_using_xcom_data", key="return_value") for i, vector in enumerate(embedded_data): - vector = vector.float_ data[i]["Vector"] = vector[0] return data From 8d9db65c573b4906bcab9f1726201c10aa2e1495 Mon Sep 17 00:00:00 2001 From: vatsrahul1001 Date: Thu, 5 Jun 2025 12:04:12 +0530 Subject: [PATCH 2/3] fix unit test --- .../tests/unit/cohere/operators/test_embedding.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/providers/cohere/tests/unit/cohere/operators/test_embedding.py b/providers/cohere/tests/unit/cohere/operators/test_embedding.py index 640690f1f1d64..1b6bf810e93d7 100644 --- a/providers/cohere/tests/unit/cohere/operators/test_embedding.py +++ b/providers/cohere/tests/unit/cohere/operators/test_embedding.py @@ -27,12 +27,12 @@ def test_cohere_embedding_operator(cohere_client, get_connection): """ Test Cohere client is getting called with the correct key and that - the execute methods returns expected response. + the execute method returns expected response. 
""" - embedded_obj = [1, 2, 3] + embedded_obj = [[1.0, 2.0, 3.0]] - class resp: - embeddings = embedded_obj + mock_response = MagicMock() + mock_response.embeddings.float_ = embedded_obj api_key = "test" base_url = "http://some_host.com" @@ -43,7 +43,7 @@ class resp: get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=base_url) client_obj = MagicMock() cohere_client.return_value = client_obj - client_obj.embed.return_value = resp + client_obj.embed.return_value = mock_response op = CohereEmbeddingOperator( task_id="embed", From 27acf604aee25279a11ee8f83b346e24381660f1 Mon Sep 17 00:00:00 2001 From: vatsrahul1001 Date: Fri, 6 Jun 2025 08:31:54 +0530 Subject: [PATCH 3/3] removing changes for pinecone and weaviate --- .../pinecone/tests/system/pinecone/example_pinecone_cohere.py | 2 +- .../weaviate/tests/system/weaviate/example_weaviate_cohere.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/providers/pinecone/tests/system/pinecone/example_pinecone_cohere.py b/providers/pinecone/tests/system/pinecone/example_pinecone_cohere.py index 2e74fe870a07d..496c0a40154bf 100644 --- a/providers/pinecone/tests/system/pinecone/example_pinecone_cohere.py +++ b/providers/pinecone/tests/system/pinecone/example_pinecone_cohere.py @@ -54,7 +54,7 @@ def create_index(): @task def transform_output(embedding_output) -> list[dict]: # Convert each embedding to a map with an ID and the embedding vector - return [dict(id=str(i), values=embedding) for i, embedding in enumerate(embedding_output)] + return [dict(id=str(i), values=embedding) for i, embedding in enumerate(embedding_output.float_)] transformed_output = transform_output(embed_task.output) diff --git a/providers/weaviate/tests/system/weaviate/example_weaviate_cohere.py b/providers/weaviate/tests/system/weaviate/example_weaviate_cohere.py index 11b45c2cb45be..9376a9235152f 100644 --- a/providers/weaviate/tests/system/weaviate/example_weaviate_cohere.py +++ 
b/providers/weaviate/tests/system/weaviate/example_weaviate_cohere.py @@ -72,6 +72,7 @@ def update_vector_data_in_json(**kwargs): data = json.load(Path("jeopardy_data_without_vectors.json").open()) embedded_data = ti.xcom_pull(task_ids="embedding_using_xcom_data", key="return_value") for i, vector in enumerate(embedded_data): + vector = vector.float_ data[i]["Vector"] = vector[0] return data