diff --git a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py index b2d8c5d4476ff..16b313c1d71d8 100644 --- a/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py +++ b/providers/cohere/src/airflow/providers/cohere/hooks/cohere.py @@ -19,7 +19,6 @@ import logging import warnings -from functools import cached_property from typing import TYPE_CHECKING, Any import cohere @@ -65,6 +64,7 @@ def __init__( self.timeout = timeout self.max_retries = max_retries self.request_options = request_options + self._client: cohere.ClientV2 | None = None if self.max_retries: warnings.warn( @@ -77,20 +77,23 @@ def __init__( else: self.request_options.update({"max_retries": self.max_retries}) - @cached_property - def get_conn(self) -> cohere.ClientV2: # type: ignore[override] - conn = self.get_connection(self.conn_id) - return cohere.ClientV2( - api_key=conn.password, - timeout=self.timeout, - base_url=conn.host or None, - ) + def get_conn(self) -> cohere.ClientV2: + """Return a new or cached Cohere client instance.""" + if self._client is None: + # create a new client instance if there is no existing client + conn = self.get_connection(self.conn_id) + self._client = cohere.ClientV2( + api_key=conn.password, + timeout=self.timeout, + base_url=conn.host or None, + ) + return self._client def create_embeddings( self, texts: list[str], model: str = "embed-multilingual-v3.0" ) -> EmbedByTypeResponseEmbeddings: logger.info("Creating embeddings with model: embed-multilingual-v3.0") - response = self.get_conn.embed( + response = self.get_conn().embed( texts=texts, model=model, input_type="search_document", @@ -117,7 +120,7 @@ def test_connection( try: if messages is None: messages = [UserChatMessageV2(role="user", content="hello world!")] - self.get_conn.chat(model=model, messages=messages) + self.get_conn().chat(model=model, messages=messages) return True, "Connection successfully established." except Exception as e: return False, f"Unexpected error: {str(e)}" diff --git a/providers/cohere/tests/unit/cohere/hooks/test_cohere.py b/providers/cohere/tests/unit/cohere/hooks/test_cohere.py index 00c73cd1f2981..f4e7a218b13bc 100644 --- a/providers/cohere/tests/unit/cohere/hooks/test_cohere.py +++ b/providers/cohere/tests/unit/cohere/hooks/test_cohere.py @@ -42,5 +42,5 @@ def test__get_api_key(self): patch("cohere.ClientV2") as client, ): hook = CohereHook(timeout=timeout) - _ = hook.get_conn + _ = hook.get_conn() client.assert_called_once_with(api_key=api_key, timeout=timeout, base_url=base_url)