Skip to content

Commit

Permalink
Fix Cortex provider for tests. (#1666)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-dkurokawa authored Nov 25, 2024
1 parent 1b6b985 commit c41bb59
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
12 changes: 12 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ write-golden-%: tests/e2e/test_%.py
WRITE_GOLDEN=1 $(PYTEST) tests/e2e/test_$*.py || true
write-golden: write-golden-dummy write-golden-serial

# Snowflake-specific tests.
test-snowflake:
rm -rf ./dist \
&& rm -rf ./src/core/trulens/data/snowflake_stage_zips \
&& make build \
&& make zip-wheels \
&& make build \
&& make env \
&& TEST_OPTIONAL=1 ALLOW_OPTIONALS=1 $(PYTEST) \
./tests/e2e/test_context_variables.py \
./tests/e2e/test_snowflake_*

# Run the tests for a specific file.
test-file-%: tests/e2e/%
$(PYTEST) tests/e2e/$*
Expand Down
31 changes: 26 additions & 5 deletions src/providers/cortex/trulens/providers/cortex/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from snowflake.cortex import Complete
from snowflake.snowpark import Session
from snowflake.snowpark import context
from snowflake.snowpark.exceptions import SnowparkSessionException
from trulens.core.utils import pyschema as pyschema_utils
from trulens.feedback import llm_provider
from trulens.feedback import prompts as feedback_prompts
from trulens.providers.cortex import endpoint as cortex_endpoint
Expand All @@ -18,6 +20,7 @@ class Cortex(
): # require `pip install snowflake-snowpark-python snowflake-ml-python>=1.7.1` and a active Snowflake account with proper privileges
# https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions#availability

DEFAULT_SNOWPARK_SESSION: Optional[Session] = None
DEFAULT_MODEL_ENGINE: ClassVar[str] = "llama3.1-8b"

model_engine: str
Expand Down Expand Up @@ -100,12 +103,30 @@ def __init__(
*args, **kwargs
)

self_kwargs["snowpark_session"] = (
if snowpark_session is None or pyschema_utils.is_noserio(
snowpark_session
if snowpark_session is not None
else context.get_active_session() # context.get_active_session() will fail if there is no or more than one active session. This should be fine
# for server side eval in the warehouse as there should only be one active session in the execution context.
)
):
if (
hasattr(self, "DEFAULT_SNOWPARK_SESSION")
and self.DEFAULT_SNOWPARK_SESSION is not None
):
snowpark_session = self.DEFAULT_SNOWPARK_SESSION
else:
# context.get_active_session() will fail if there is no or more
# than one active session. This should be fine for server side
# eval in the warehouse as there should only be one active
# session in the execution context.
try:
snowpark_session = context.get_active_session()
except SnowparkSessionException:
class_name = (
f"{self.__module__}.{self.__class__.__qualname__}"
)
raise ValueError(
"Cannot infer snowpark session to use! Try setting "
f"`{class_name}.DEFAULT_SNOWPARK_SESSION`."
)
self_kwargs["snowpark_session"] = snowpark_session

super().__init__(**self_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/data/staged_packages.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@
"tru_session = TruSession(connector=connector)\n",
"\n",
"# Set up feedback functions.\n",
"relevance = Cortex(snowpark_session.connection).relevance\n",
"relevance = Cortex(snowpark_session).relevance\n",
"f_regular = Feedback(relevance).on_input_output()\n",
"f_snowflake = SnowflakeFeedback(relevance).on_input_output()\n",
"feedbacks = [\n",
Expand Down
2 changes: 2 additions & 0 deletions tests/util/snowflake_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from snowflake.snowpark.row import Row
from trulens.connectors import snowflake as snowflake_connector
from trulens.core import session as core_session
from trulens.providers.cortex.provider import Cortex


class SnowflakeTestCase(TestCase):
Expand All @@ -31,6 +32,7 @@ def setUp(self):
self._snowflake_connection_parameters
).create()
self._snowflake_schemas_to_delete = set()
Cortex.DEFAULT_SNOWPARK_SESSION = self._snowpark_session

def tearDown(self):
# [HACK!] Clean up any instances of `TruSession` so tests don't interfere with each other.
Expand Down

0 comments on commit c41bb59

Please sign in to comment.