From c41bb594c32a38de152b716de23adc5829453f21 Mon Sep 17 00:00:00 2001 From: David Kurokawa Date: Mon, 25 Nov 2024 10:55:07 -0800 Subject: [PATCH] Fix Cortex provider for tests. (#1666) --- Makefile | 12 +++++++ .../trulens/providers/cortex/provider.py | 31 ++++++++++++++++--- tests/e2e/data/staged_packages.ipynb | 2 +- tests/util/snowflake_test_case.py | 2 ++ 4 files changed, 41 insertions(+), 6 deletions(-) diff --git a/Makefile b/Makefile index 25f4ebcf3..f48a785a9 100644 --- a/Makefile +++ b/Makefile @@ -168,6 +168,18 @@ write-golden-%: tests/e2e/test_%.py WRITE_GOLDEN=1 $(PYTEST) tests/e2e/test_$*.py || true write-golden: write-golden-dummy write-golden-serial +# Snowflake-specific tests. +test-snowflake: + rm -rf ./dist \ + && rm -rf ./src/core/trulens/data/snowflake_stage_zips \ + && make build \ + && make zip-wheels \ + && make build \ + && make env \ + && TEST_OPTIONAL=1 ALLOW_OPTIONALS=1 $(PYTEST) \ + ./tests/e2e/test_context_variables.py \ + ./tests/e2e/test_snowflake_* + # Run the tests for a specific file. test-file-%: tests/e2e/% $(PYTEST) tests/e2e/$* diff --git a/src/providers/cortex/trulens/providers/cortex/provider.py b/src/providers/cortex/trulens/providers/cortex/provider.py index 5d7235d7f..789161dd6 100644 --- a/src/providers/cortex/trulens/providers/cortex/provider.py +++ b/src/providers/cortex/trulens/providers/cortex/provider.py @@ -8,6 +8,8 @@ from snowflake.cortex import Complete from snowflake.snowpark import Session from snowflake.snowpark import context +from snowflake.snowpark.exceptions import SnowparkSessionException +from trulens.core.utils import pyschema as pyschema_utils from trulens.feedback import llm_provider from trulens.feedback import prompts as feedback_prompts from trulens.providers.cortex import endpoint as cortex_endpoint @@ -18,6 +20,7 @@ class Cortex( ): # require `pip install snowflake-snowpark-python snowflake-ml-python>=1.7.1` and a active Snowflake account with proper privileges # https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#availability + DEFAULT_SNOWPARK_SESSION: Optional[Session] = None DEFAULT_MODEL_ENGINE: ClassVar[str] = "llama3.1-8b" model_engine: str @@ -100,12 +103,30 @@ def __init__( *args, **kwargs ) - self_kwargs["snowpark_session"] = ( + if snowpark_session is None or pyschema_utils.is_noserio( snowpark_session - if snowpark_session is not None - else context.get_active_session() # context.get_active_session() will fail if there is no or more than one active session. This should be fine - # for server side eval in the warehouse as there should only be one active session in the execution context. - ) + ): + if ( + hasattr(self, "DEFAULT_SNOWPARK_SESSION") + and self.DEFAULT_SNOWPARK_SESSION is not None + ): + snowpark_session = self.DEFAULT_SNOWPARK_SESSION + else: + # context.get_active_session() will fail if there is no or more + # than one active session. This should be fine for server side + # eval in the warehouse as there should only be one active + # session in the execution context. + try: + snowpark_session = context.get_active_session() + except SnowparkSessionException: + class_name = ( + f"{self.__module__}.{self.__class__.__qualname__}" + ) + raise ValueError( + "Cannot infer snowpark session to use! Try setting " + f"`{class_name}.DEFAULT_SNOWPARK_SESSION`." + ) + self_kwargs["snowpark_session"] = snowpark_session super().__init__(**self_kwargs) diff --git a/tests/e2e/data/staged_packages.ipynb b/tests/e2e/data/staged_packages.ipynb index 09a3874c8..2cae82798 100644 --- a/tests/e2e/data/staged_packages.ipynb +++ b/tests/e2e/data/staged_packages.ipynb @@ -164,7 +164,7 @@ "tru_session = TruSession(connector=connector)\n", "\n", "# Set up feedback functions.\n", - "relevance = Cortex(snowpark_session.connection).relevance\n", + "relevance = Cortex(snowpark_session).relevance\n", "f_regular = Feedback(relevance).on_input_output()\n", "f_snowflake = SnowflakeFeedback(relevance).on_input_output()\n", "feedbacks = [\n", diff --git a/tests/util/snowflake_test_case.py b/tests/util/snowflake_test_case.py index da1905185..95a82876d 100644 --- a/tests/util/snowflake_test_case.py +++ b/tests/util/snowflake_test_case.py @@ -13,6 +13,7 @@ from snowflake.snowpark.row import Row from trulens.connectors import snowflake as snowflake_connector from trulens.core import session as core_session +from trulens.providers.cortex.provider import Cortex class SnowflakeTestCase(TestCase): @@ -31,6 +32,7 @@ def setUp(self): self._snowflake_connection_parameters ).create() self._snowflake_schemas_to_delete = set() + Cortex.DEFAULT_SNOWPARK_SESSION = self._snowpark_session def tearDown(self): # [HACK!] Clean up any instances of `TruSession` so tests don't interfere with each other.