Skip to content

Commit

Permalink
fixup! Add default_connection option
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek committed Jul 25, 2023
1 parent aa603ef commit d0c934a
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 118 deletions.
12 changes: 12 additions & 0 deletions src/snowcli/cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ def _version_callback(value: bool):
raise typer.Exit()


def _info_callback(value: bool):
if not value:
return


def setup_global_context(debug: bool):
"""
Setup global state (accessible in whole CLI code) using options passed in SNOW CLI invocation.
Expand All @@ -48,6 +53,13 @@ def default(
callback=_version_callback,
is_eager=True,
),
info: bool = typer.Option(
None,
"--info",
help="Prints information about the snowcli",
callback=_info_callback,
is_eager=True,
),
output_format: OutputFormat = typer.Option(
OutputFormat.TABLE.value,
"--format",
Expand Down
3 changes: 1 addition & 2 deletions src/snowcli/cli/common/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

import typer

from snowcli.config import get_default_connection
from snowcli.utils import check_for_connection

DEFAULT_CONTEXT_SETTINGS = {"help_option_names": ["--help", "-h"]}


ConnectionOption = typer.Option(
get_default_connection(),
None,
"-c",
"--connection",
"--environment",
Expand Down
224 changes: 112 additions & 112 deletions src/snowcli/cli/snowpark_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,142 +150,142 @@ def snowpark_update(
execute_as_caller: bool = False,
install_coverage_wrapper: bool = False,
) -> None:
conn = connect_to_snowflake(connection_name=environment)

if type == "function" and install_coverage_wrapper:
log.error(
"You cannot install a code coverage wrapper on a function, only a procedure."
)
raise typer.Abort()
updated_package_list = []
try:
log.info(f"Updating {type} {name}...")

conn = connect_to_snowflake(connection_name=environment)
updated_package_list = []
try:
log.info(f"Updating {type} {name}...")
if type == "function":
resource_details = conn.describe_function(
name=name,
input_parameters=input_parameters,
database=conn.ctx.database,
schema=conn.ctx.schema,
role=conn.ctx.role,
warehouse=conn.ctx.warehouse,
show_exceptions=False,
)
elif type == "procedure":
resource_details = conn.describe_procedure(
name=name,
input_parameters=input_parameters,
database=conn.ctx.database,
schema=conn.ctx.schema,
role=conn.ctx.role,
warehouse=conn.ctx.warehouse,
show_exceptions=False,
)
log.info("Checking if any new packages to update...")
resource_json = utils.convert_resource_details_to_dict(
resource_details,
) # type: ignore
anaconda_packages = resource_json["packages"]
log.info(
f"Found {len(anaconda_packages)} defined Anaconda "
"packages in deployed {type}..."
)
log.info(
"Checking if any packages defined or missing from "
"requirements.snowflake.txt..."
)
updated_package_list = utils.get_snowflake_packages_delta(
anaconda_packages,
)
if install_coverage_wrapper:
# if we're installing a coverage wrapper, ensure the coverage package included as a dependency
if (
"coverage" not in anaconda_packages
and "coverage" not in updated_package_list
):
updated_package_list.append("coverage")
log.info("Checking if app configuration has changed...")
if (
resource_json["handler"].lower() != handler.lower()
or resource_json["returns"].lower() != return_type.lower()
):
log.info(
"Return type or handler types do not match. Replacing"
"function configuration..."
)
replace = True
except Exception:
log.info(f"Existing {type} not found, creating new {type}...")
replace = True

finally:
deploy_dict = utils.get_deploy_names(
conn.ctx.database,
conn.ctx.schema,
generate_deploy_stage_name(name, input_parameters),
)
with tempfile.TemporaryDirectory() as temp_dir:
temp_app_zip_path = utils.prepare_app_zip(file, temp_dir)
stage_path = deploy_dict["directory"] + "/coverage"
if install_coverage_wrapper:
handler = replace_handler_in_zip(
proc_name=name,
proc_signature=input_parameters,
handler=handler,
coverage_reports_stage=deploy_dict["stage"],
coverage_reports_stage_path=stage_path,
temp_dir=temp_dir,
zip_file_path=temp_app_zip_path,
)
deploy_response = conn.upload_file_to_stage(
file_path=temp_app_zip_path,
destination_stage=deploy_dict["stage"],
path=deploy_dict["directory"],
database=conn.ctx.database,
schema=conn.ctx.schema,
overwrite=True,
role=conn.ctx.role,
warehouse=conn.ctx.warehouse,
)
log.info(f"{deploy_response[0]} uploaded to stage {deploy_dict['full_path']}")

if updated_package_list or replace:
log.info(f"Replacing {type} with updated values...")
if type == "function":
resource_details = conn.describe_function(
conn.create_function(
name=name,
input_parameters=input_parameters,
return_type=return_type,
handler=handler,
imports=deploy_dict["full_path"],
database=conn.ctx.database,
schema=conn.ctx.schema,
role=conn.ctx.role,
warehouse=conn.ctx.warehouse,
show_exceptions=False,
overwrite=True,
packages=utils.get_snowflake_packages(),
)
elif type == "procedure":
resource_details = conn.describe_procedure(
conn.create_procedure(
name=name,
input_parameters=input_parameters,
return_type=return_type,
handler=handler,
imports=deploy_dict["full_path"],
database=conn.ctx.database,
schema=conn.ctx.schema,
role=conn.ctx.role,
warehouse=conn.ctx.warehouse,
show_exceptions=False,
)
log.info("Checking if any new packages to update...")
resource_json = utils.convert_resource_details_to_dict(
resource_details,
) # type: ignore
anaconda_packages = resource_json["packages"]
log.info(
f"Found {len(anaconda_packages)} defined Anaconda "
"packages in deployed {type}..."
)
log.info(
"Checking if any packages defined or missing from "
"requirements.snowflake.txt..."
)
updated_package_list = utils.get_snowflake_packages_delta(
anaconda_packages,
)
if install_coverage_wrapper:
# if we're installing a coverage wrapper, ensure the coverage package included as a dependency
if (
"coverage" not in anaconda_packages
and "coverage" not in updated_package_list
):
updated_package_list.append("coverage")
log.info("Checking if app configuration has changed...")
if (
resource_json["handler"].lower() != handler.lower()
or resource_json["returns"].lower() != return_type.lower()
):
log.info(
"Return type or handler types do not match. Replacing"
"function configuration..."
)
replace = True
except Exception:
log.info(f"Existing {type} not found, creating new {type}...")
replace = True

finally:
deploy_dict = utils.get_deploy_names(
conn.ctx.database,
conn.ctx.schema,
generate_deploy_stage_name(name, input_parameters),
)
with tempfile.TemporaryDirectory() as temp_dir:
temp_app_zip_path = utils.prepare_app_zip(file, temp_dir)
stage_path = deploy_dict["directory"] + "/coverage"
if install_coverage_wrapper:
handler = replace_handler_in_zip(
proc_name=name,
proc_signature=input_parameters,
handler=handler,
coverage_reports_stage=deploy_dict["stage"],
coverage_reports_stage_path=stage_path,
temp_dir=temp_dir,
zip_file_path=temp_app_zip_path,
)
deploy_response = conn.upload_file_to_stage(
file_path=temp_app_zip_path,
destination_stage=deploy_dict["stage"],
path=deploy_dict["directory"],
database=conn.ctx.database,
schema=conn.ctx.schema,
overwrite=True,
role=conn.ctx.role,
warehouse=conn.ctx.warehouse,
packages=utils.get_snowflake_packages(),
execute_as_caller=execute_as_caller,
)
log.info(
f"{deploy_response[0]} uploaded to stage {deploy_dict['full_path']}"
f"{type.capitalize()} {name} updated with new packages. "
"Deployment complete!"
)

if updated_package_list or replace:
log.info(f"Replacing {type} with updated values...")
if type == "function":
conn.create_function(
name=name,
input_parameters=input_parameters,
return_type=return_type,
handler=handler,
imports=deploy_dict["full_path"],
database=conn.ctx.database,
schema=conn.ctx.schema,
role=conn.ctx.role,
warehouse=conn.ctx.warehouse,
overwrite=True,
packages=utils.get_snowflake_packages(),
)
elif type == "procedure":
conn.create_procedure(
name=name,
input_parameters=input_parameters,
return_type=return_type,
handler=handler,
imports=deploy_dict["full_path"],
database=conn.ctx.database,
schema=conn.ctx.schema,
role=conn.ctx.role,
warehouse=conn.ctx.warehouse,
overwrite=True,
packages=utils.get_snowflake_packages(),
execute_as_caller=execute_as_caller,
)
log.info(
f"{type.capitalize()} {name} updated with new packages. "
"Deployment complete!"
)
else:
log.info("No packages to update. Deployment complete!")
else:
log.info("No packages to update. Deployment complete!")


def replace_handler_in_zip(
Expand Down
1 change: 0 additions & 1 deletion src/snowcli/snow_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import hashlib
from io import StringIO

import typer
from jinja2 import Environment, FileSystemLoader
from pathlib import Path
from typing import Optional
Expand Down
4 changes: 3 additions & 1 deletion src/snowcli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import typer
from jinja2 import Environment, FileSystemLoader

from snowcli.config import cli_config
from snowcli.config import cli_config, get_default_connection

warnings.filterwarnings("ignore", category=UserWarning)

Expand Down Expand Up @@ -572,6 +572,8 @@ def convert_resource_details_to_dict(function_details: list[tuple]) -> dict:


def check_for_connection(connection_name: str):
if not connection_name:
connection_name = get_default_connection()
cli_config.get_connection(connection_name=connection_name)
return connection_name

Expand Down
6 changes: 5 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ def test_streamlit_help(runner):
@mock.patch.dict(os.environ, {}, clear=True)
def test_custom_config_path(mock_conn, runner):
config_file = Path(__file__).parent / "test.toml"
runner.invoke(["--config-file", str(config_file), "warehouse", "status"])
result = runner.invoke(
["--config-file", str(config_file), "warehouse", "status"],
catch_exceptions=False,
)
assert result.exit_code == 0, result.output
mock_conn.assert_called_once_with(
connection_parameters={
"database": "db_for_test",
Expand Down
3 changes: 2 additions & 1 deletion tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test_sql_fails_for_both_query_and_file(runner):
@mock.patch(MOCK_CONNECTION)
@mock.patch("snowcli.config.cli_config")
def test_sql_overrides_connection_configuration(mock_config, mock_conn, runner):
mock_config.get.return_value = "dev" # mock of get_default_connection
mock_config.get_connection.return_value = {}
result = runner.invoke(
[
Expand All @@ -87,7 +88,7 @@ def test_sql_overrides_connection_configuration(mock_config, mock_conn, runner):
]
)

assert result.exit_code == 0
assert result.exit_code == 0, result.output
mock_conn.assert_called_once_with(
connection_name="dev",
account="accountnameValue",
Expand Down

0 comments on commit d0c934a

Please sign in to comment.