Skip to content

Commit

Permalink
Differentiate test behavior depending on env
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek committed Oct 9, 2024
1 parent 9c6821a commit a252f86
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 38 deletions.
41 changes: 27 additions & 14 deletions src/snowflake/cli/_plugins/snowpark/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,8 @@ def _check_if_replace_is_required(
)
return True

if (
resource_json["handler"].lower() != entity.handler.lower()
or resource_json["returns"].lower()
!= user_to_sql_type_mapper(entity.returns).lower()
if resource_json["handler"].lower() != entity.handler.lower() or not same_type(
resource_json["returns"], entity.returns
):
log.info(
"Return type or handler types do not match. Replacing the %s.", object_type
Expand Down Expand Up @@ -216,16 +214,31 @@ def _standardize(packages: List[str]) -> Set[str]:
return _standardize(old_dependencies) != _standardize(new_dependencies)


def user_to_sql_type_mapper(user_provided_type: str) -> str:
def _cast(user_type: str, sql_type: str, default: str) -> str | None:
if user_type == sql_type:
# TEXT -> VARCHAR(16777216)
return default
if user_type.startswith(sql_type):
# TEXT(30) -> VARCHAR(30)
return user_type.replace(sql_type, "VARCHAR")
return None
def same_type(sf_type: str, local_type: str) -> bool:
sf_type, local_type = sf_type.upper(), local_type.upper()

# 1. Types are equal out of the box
if sf_type == local_type:
return True

# 2. Local type is alias for Snowflake type
local_type = user_to_sql_type_mapper(local_type).upper()
if sf_type == local_type:
return True

# 3. Local type is a subset of Snowflake type, e.g. VARCHAR(N) == VARCHAR
# We solved for local VARCHAR(N) in point 1 & 2 as those are explicit types
if sf_type.startswith(local_type):
return True

# 4. Snowflake types is subset of local type
if local_type.startswith(sf_type):
return True

return False


def user_to_sql_type_mapper(user_provided_type: str) -> str:
mapping = {
("VARCHAR", "(16777216)"): ("CHAR", "TEXT", "STRING"),
("BINARY", "(8388608)"): ("BINARY", "VARBINARY"),
Expand Down Expand Up @@ -258,7 +271,7 @@ def _cast(user_type: str, sql_type: str, default: str) -> str | None:
for type_ in matching_types:
if user_provided_type == type_:
# TEXT -> VARCHAR(16777216)
return default
return cast_type + default
if user_provided_type.startswith(type_):
# TEXT(30) -> VARCHAR(30)
return user_provided_type.replace(type_, cast_type + default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def find_packages_available_in_snowflake_anaconda(self) -> AnacondaPackages:

def _query_snowflake_for_available_packages(self) -> dict[str, AvailablePackage]:
cursor = self._execute_query(
"select package_name, version from information_schema.packages where language = 'python'",
"select package_name, version from snowflake.information_schema.packages where language = 'python'",
cursor_class=DictCursor,
)
if cursor.rowcount is None or cursor.rowcount == 0:
Expand Down
47 changes: 29 additions & 18 deletions tests/snowpark/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@

from __future__ import annotations

from typing import Tuple

import pytest
from snowflake.cli._plugins.snowpark.common import (
_check_if_replace_is_required,
_convert_resource_details_to_dict,
_snowflake_dependencies_differ,
_sql_to_python_return_type_mapper,
is_name_a_templated_one,
same_type,
)
from snowflake.cli._plugins.snowpark.snowpark_entity_model import (
ProcedureEntityModel,
Expand Down Expand Up @@ -63,21 +61,6 @@ def test_convert_resource_details_to_dict():
}


@pytest.mark.parametrize(
"argument",
[
("NUMBER(38,0)", "int"),
("TIMESTAMP_NTZ(9)", "datetime"),
("TIMESTAMP_TZ(9)", "datetime"),
("VARCHAR(16777216)", "string"),
("FLOAT", "float"),
("ARRAY", "array"),
],
)
def test_sql_to_python_return_type_mapper(argument: Tuple[str, str]):
assert _sql_to_python_return_type_mapper(argument[0]) == argument[1]


@pytest.mark.parametrize(
"arguments, expected",
[
Expand Down Expand Up @@ -183,3 +166,31 @@ def test_check_if_replace_is_required_file_changes(
)
def test_is_name_is_templated_one(name: str, expected: bool):
assert is_name_a_templated_one(name) == expected


@pytest.mark.parametrize(
"sf_type, local_type",
[
("VARCHAR", "STRING"),
("VARCHAR(16777216)", "STRING"),
("VARCHAR(16777216)", "VARCHAR"),
("VARCHAR(16777216)", "VARCHAR(16777216)"),
("NUMBER(38,0)", "int"),
("TIMESTAMP_NTZ", "datetime"),
("FLOAT", "float"),
("ARRAY", "array"),
],
)
def test_the_same_type(sf_type, local_type):
assert same_type(sf_type, local_type)


@pytest.mark.parametrize(
"sf_type, local_type",
[
("VARCHAR(25)", "STRING"),
("VARCHAR(25)", "VARCHAR(16777216)"),
],
)
def test_is_not_the_same_type(sf_type, local_type):
assert not same_type(sf_type, local_type)
2 changes: 2 additions & 0 deletions tests_integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import functools
import json
import os
import shutil
import tempfile
from contextlib import contextmanager
Expand Down Expand Up @@ -55,6 +56,7 @@
TEST_DIR = Path(__file__).parent
DEFAULT_TEST_CONFIG = "connection_configs.toml"
WORLD_READABLE_CONFIG = "world_readable.toml"
IS_QA = "qa" in os.getenv("SNOWFLAKE_CONNECTIONS_INTEGRATION_HOST", "").lower()


@dataclass
Expand Down
11 changes: 6 additions & 5 deletions tests_integration/test_snowpark.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pytest


from tests_integration.conftest import IS_QA
from tests_integration.testing_utils import (
SnowparkTestSteps,
)
Expand All @@ -32,6 +32,7 @@


STAGE_NAME = "dev_deployment"
RETURN_TYPE = "VARCHAR" if IS_QA else "VARCHAR(16777216)"


@pytest.mark.integration
Expand Down Expand Up @@ -92,14 +93,14 @@ def test_snowpark_flow(
object_type="procedure",
identifier="hello_procedure(VARCHAR)",
signature="(NAME VARCHAR)",
returns="VARCHAR",
returns=RETURN_TYPE,
)

_test_steps.object_describe_should_return_entity_description(
object_type="function",
identifier="hello_function(VARCHAR)",
signature="(NAME VARCHAR)",
returns="VARCHAR",
returns=RETURN_TYPE,
)

# Grants are given correctly
Expand Down Expand Up @@ -520,13 +521,13 @@ def test_snowpark_default_arguments(
object_type="function",
identifier="WHOLE_NEW_WORD(VARCHAR, NUMBER, VARCHAR)",
signature="(BASE VARCHAR, MULT NUMBER, SUFFIX VARCHAR)",
returns="VARCHAR",
returns=RETURN_TYPE,
)
_test_steps.object_describe_should_return_entity_description(
object_type="procedure",
identifier="WHOLE_NEW_WORD_PROCEDURE(VARCHAR, NUMBER, VARCHAR)",
signature="(BASE VARCHAR, MULT NUMBER, SUFFIX VARCHAR)",
returns="VARCHAR",
returns=RETURN_TYPE,
)

# execute with default arguments
Expand Down

0 comments on commit a252f86

Please sign in to comment.