diff --git a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py index 16b313c1d71d8..bfaba2d658a58 100644 --- a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py +++ b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from cohere.core.request_options import RequestOptions - from cohere.types import ChatMessages, EmbedByTypeResponseEmbeddings + from cohere.types import ChatMessages logger = logging.getLogger(__name__) @@ -91,7 +91,7 @@ def get_conn(self) -> cohere.ClientV2: def create_embeddings( self, texts: list[str], model: str = "embed-multilingual-v3.0" - ) -> EmbedByTypeResponseEmbeddings: + ) -> list[list[float]]: logger.info("Creating embeddings with model: embed-multilingual-v3.0") response = self.get_conn().embed( texts=texts, @@ -100,8 +100,9 @@ def create_embeddings( embedding_types=["float"], request_options=self.request_options, ) - embeddings = response.embeddings - return embeddings + if response.embeddings.float_ is None: + raise ValueError("Embeddings response is missing float_ field") + return response.embeddings.float_ @classmethod def get_ui_field_behaviour(cls) -> dict[str, Any]: diff --git a/providers/cohere/src/airflow/providers/cohere/operators/embedding.py b/providers/cohere/src/airflow/providers/cohere/operators/embedding.py index b06f13ab02194..1504858fc217e 100644 --- a/providers/cohere/src/airflow/providers/cohere/operators/embedding.py +++ b/providers/cohere/src/airflow/providers/cohere/operators/embedding.py @@ -26,7 +26,6 @@ if TYPE_CHECKING: from cohere.core.request_options import RequestOptions - from cohere.types import EmbedByTypeResponseEmbeddings try: from airflow.sdk.definitions.context import Context @@ -91,6 +90,9 @@ def hook(self) -> CohereHook: request_options=self.request_options, ) - def execute(self, context: Context) -> EmbedByTypeResponseEmbeddings: + def execute(self, context: Context) -> list[list[float]]: """Embed texts using Cohere embed services.""" - return self.hook.create_embeddings(self.input_text) + embedding_response = self.hook.create_embeddings(self.input_text) + + # Extract just the embeddings list, which is serializable + return embedding_response diff --git a/providers/cohere/tests/unit/cohere/operators/test_embedding.py b/providers/cohere/tests/unit/cohere/operators/test_embedding.py index 640690f1f1d64..1b6bf810e93d7 100644 --- a/providers/cohere/tests/unit/cohere/operators/test_embedding.py +++ b/providers/cohere/tests/unit/cohere/operators/test_embedding.py @@ -27,12 +27,12 @@ def test_cohere_embedding_operator(cohere_client, get_connection): """ Test Cohere client is getting called with the correct key and that - the execute methods returns expected response. + the execute method returns expected response. """ - embedded_obj = [1, 2, 3] + embedded_obj = [[1.0, 2.0, 3.0]] - class resp: - embeddings = embedded_obj + mock_response = MagicMock() + mock_response.embeddings.float_ = embedded_obj api_key = "test" base_url = "http://some_host.com" @@ -43,7 +43,7 @@ class resp: get_connection.return_value = Connection(conn_type="cohere", password=api_key, host=base_url) client_obj = MagicMock() cohere_client.return_value = client_obj - client_obj.embed.return_value = resp + client_obj.embed.return_value = mock_response op = CohereEmbeddingOperator( task_id="embed",