From d0c934a18a6580079c1be0916ad1def5ca363013 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Tue, 25 Jul 2023 10:37:54 +0200 Subject: [PATCH] fixup! Add default_connection option --- src/snowcli/cli/app.py | 12 ++ src/snowcli/cli/common/flags.py | 3 +- src/snowcli/cli/snowpark_shared.py | 224 ++++++++++++++--------------- src/snowcli/snow_connector.py | 1 - src/snowcli/utils.py | 4 +- tests/test_main.py | 6 +- tests/test_sql.py | 3 +- 7 files changed, 135 insertions(+), 118 deletions(-) diff --git a/src/snowcli/cli/app.py b/src/snowcli/cli/app.py index 17edc10da6..cf43bb5c1f 100644 --- a/src/snowcli/cli/app.py +++ b/src/snowcli/cli/app.py @@ -27,6 +27,11 @@ def _version_callback(value: bool): raise typer.Exit() +def _info_callback(value: bool): + if not value: + return + + def setup_global_context(debug: bool): """ Setup global state (accessible in whole CLI code) using options passed in SNOW CLI invocation. @@ -48,6 +53,13 @@ def default( callback=_version_callback, is_eager=True, ), + info: bool = typer.Option( + None, + "--info", + help="Prints information about the snowcli", + callback=_info_callback, + is_eager=True, + ), output_format: OutputFormat = typer.Option( OutputFormat.TABLE.value, "--format", diff --git a/src/snowcli/cli/common/flags.py b/src/snowcli/cli/common/flags.py index 2cee3357e7..661bee40ab 100644 --- a/src/snowcli/cli/common/flags.py +++ b/src/snowcli/cli/common/flags.py @@ -2,14 +2,13 @@ import typer -from snowcli.config import get_default_connection from snowcli.utils import check_for_connection DEFAULT_CONTEXT_SETTINGS = {"help_option_names": ["--help", "-h"]} ConnectionOption = typer.Option( - get_default_connection(), + None, "-c", "--connection", "--environment", diff --git a/src/snowcli/cli/snowpark_shared.py b/src/snowcli/cli/snowpark_shared.py index df5045a661..9fc5fdd083 100644 --- a/src/snowcli/cli/snowpark_shared.py +++ b/src/snowcli/cli/snowpark_shared.py @@ -150,142 +150,142 @@ def snowpark_update( execute_as_caller: bool = False, install_coverage_wrapper: bool = False, ) -> None: - conn = connect_to_snowflake(connection_name=environment) + if type == "function" and install_coverage_wrapper: log.error( "You cannot install a code coverage wrapper on a function, only a procedure." ) raise typer.Abort() - updated_package_list = [] - try: - log.info(f"Updating {type} {name}...") + + conn = connect_to_snowflake(connection_name=environment) + updated_package_list = [] + try: + log.info(f"Updating {type} {name}...") + if type == "function": + resource_details = conn.describe_function( + name=name, + input_parameters=input_parameters, + database=conn.ctx.database, + schema=conn.ctx.schema, + role=conn.ctx.role, + warehouse=conn.ctx.warehouse, + show_exceptions=False, + ) + elif type == "procedure": + resource_details = conn.describe_procedure( + name=name, + input_parameters=input_parameters, + database=conn.ctx.database, + schema=conn.ctx.schema, + role=conn.ctx.role, + warehouse=conn.ctx.warehouse, + show_exceptions=False, + ) + log.info("Checking if any new packages to update...") + resource_json = utils.convert_resource_details_to_dict( + resource_details, + ) # type: ignore + anaconda_packages = resource_json["packages"] + log.info( + f"Found {len(anaconda_packages)} defined Anaconda " + "packages in deployed {type}..." + ) + log.info( + "Checking if any packages defined or missing from " + "requirements.snowflake.txt..." + ) + updated_package_list = utils.get_snowflake_packages_delta( + anaconda_packages, + ) + if install_coverage_wrapper: + # if we're installing a coverage wrapper, ensure the coverage package included as a dependency + if ( + "coverage" not in anaconda_packages + and "coverage" not in updated_package_list + ): + updated_package_list.append("coverage") + log.info("Checking if app configuration has changed...") + if ( + resource_json["handler"].lower() != handler.lower() + or resource_json["returns"].lower() != return_type.lower() + ): + log.info( + "Return type or handler types do not match. Replacing" + "function configuration..." + ) + replace = True + except Exception: + log.info(f"Existing {type} not found, creating new {type}...") + replace = True + + finally: + deploy_dict = utils.get_deploy_names( + conn.ctx.database, + conn.ctx.schema, + generate_deploy_stage_name(name, input_parameters), + ) + with tempfile.TemporaryDirectory() as temp_dir: + temp_app_zip_path = utils.prepare_app_zip(file, temp_dir) + stage_path = deploy_dict["directory"] + "/coverage" + if install_coverage_wrapper: + handler = replace_handler_in_zip( + proc_name=name, + proc_signature=input_parameters, + handler=handler, + coverage_reports_stage=deploy_dict["stage"], + coverage_reports_stage_path=stage_path, + temp_dir=temp_dir, + zip_file_path=temp_app_zip_path, + ) + deploy_response = conn.upload_file_to_stage( + file_path=temp_app_zip_path, + destination_stage=deploy_dict["stage"], + path=deploy_dict["directory"], + database=conn.ctx.database, + schema=conn.ctx.schema, + overwrite=True, + role=conn.ctx.role, + warehouse=conn.ctx.warehouse, + ) + log.info(f"{deploy_response[0]} uploaded to stage {deploy_dict['full_path']}") + + if updated_package_list or replace: + log.info(f"Replacing {type} with updated values...") if type == "function": - resource_details = conn.describe_function( + conn.create_function( name=name, input_parameters=input_parameters, + return_type=return_type, + handler=handler, + imports=deploy_dict["full_path"], database=conn.ctx.database, schema=conn.ctx.schema, role=conn.ctx.role, warehouse=conn.ctx.warehouse, - show_exceptions=False, + overwrite=True, + packages=utils.get_snowflake_packages(), ) elif type == "procedure": - resource_details = conn.describe_procedure( + conn.create_procedure( name=name, input_parameters=input_parameters, + return_type=return_type, + handler=handler, + imports=deploy_dict["full_path"], database=conn.ctx.database, schema=conn.ctx.schema, role=conn.ctx.role, warehouse=conn.ctx.warehouse, - show_exceptions=False, - ) - log.info("Checking if any new packages to update...") - resource_json = utils.convert_resource_details_to_dict( - resource_details, - ) # type: ignore - anaconda_packages = resource_json["packages"] - log.info( - f"Found {len(anaconda_packages)} defined Anaconda " - "packages in deployed {type}..." - ) - log.info( - "Checking if any packages defined or missing from " - "requirements.snowflake.txt..." - ) - updated_package_list = utils.get_snowflake_packages_delta( - anaconda_packages, - ) - if install_coverage_wrapper: - # if we're installing a coverage wrapper, ensure the coverage package included as a dependency - if ( - "coverage" not in anaconda_packages - and "coverage" not in updated_package_list - ): - updated_package_list.append("coverage") - log.info("Checking if app configuration has changed...") - if ( - resource_json["handler"].lower() != handler.lower() - or resource_json["returns"].lower() != return_type.lower() - ): - log.info( - "Return type or handler types do not match. Replacing" - "function configuration..." - ) - replace = True - except Exception: - log.info(f"Existing {type} not found, creating new {type}...") - replace = True - - finally: - deploy_dict = utils.get_deploy_names( - conn.ctx.database, - conn.ctx.schema, - generate_deploy_stage_name(name, input_parameters), - ) - with tempfile.TemporaryDirectory() as temp_dir: - temp_app_zip_path = utils.prepare_app_zip(file, temp_dir) - stage_path = deploy_dict["directory"] + "/coverage" - if install_coverage_wrapper: - handler = replace_handler_in_zip( - proc_name=name, - proc_signature=input_parameters, - handler=handler, - coverage_reports_stage=deploy_dict["stage"], - coverage_reports_stage_path=stage_path, - temp_dir=temp_dir, - zip_file_path=temp_app_zip_path, - ) - deploy_response = conn.upload_file_to_stage( - file_path=temp_app_zip_path, - destination_stage=deploy_dict["stage"], - path=deploy_dict["directory"], - database=conn.ctx.database, - schema=conn.ctx.schema, overwrite=True, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, + packages=utils.get_snowflake_packages(), + execute_as_caller=execute_as_caller, ) log.info( - f"{deploy_response[0]} uploaded to stage {deploy_dict['full_path']}" + f"{type.capitalize()} {name} updated with new packages. " + "Deployment complete!" ) - - if updated_package_list or replace: - log.info(f"Replacing {type} with updated values...") - if type == "function": - conn.create_function( - name=name, - input_parameters=input_parameters, - return_type=return_type, - handler=handler, - imports=deploy_dict["full_path"], - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - overwrite=True, - packages=utils.get_snowflake_packages(), - ) - elif type == "procedure": - conn.create_procedure( - name=name, - input_parameters=input_parameters, - return_type=return_type, - handler=handler, - imports=deploy_dict["full_path"], - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - overwrite=True, - packages=utils.get_snowflake_packages(), - execute_as_caller=execute_as_caller, - ) - log.info( - f"{type.capitalize()} {name} updated with new packages. " - "Deployment complete!" - ) - else: - log.info("No packages to update. Deployment complete!") + else: + log.info("No packages to update. Deployment complete!") def replace_handler_in_zip( diff --git a/src/snowcli/snow_connector.py b/src/snowcli/snow_connector.py index e4beebfb48..6273f14f99 100644 --- a/src/snowcli/snow_connector.py +++ b/src/snowcli/snow_connector.py @@ -7,7 +7,6 @@ import hashlib from io import StringIO -import typer from jinja2 import Environment, FileSystemLoader from pathlib import Path from typing import Optional diff --git a/src/snowcli/utils.py b/src/snowcli/utils.py index 3af91a7926..ac0b0939f6 100644 --- a/src/snowcli/utils.py +++ b/src/snowcli/utils.py @@ -20,7 +20,7 @@ import typer from jinja2 import Environment, FileSystemLoader -from snowcli.config import cli_config +from snowcli.config import cli_config, get_default_connection warnings.filterwarnings("ignore", category=UserWarning) @@ -572,6 +572,8 @@ def convert_resource_details_to_dict(function_details: list[tuple]) -> dict: def check_for_connection(connection_name: str): + if not connection_name: + connection_name = get_default_connection() cli_config.get_connection(connection_name=connection_name) return connection_name diff --git a/tests/test_main.py b/tests/test_main.py index 001239c0bb..6c0eed953b 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -20,7 +20,11 @@ def test_streamlit_help(runner): @mock.patch.dict(os.environ, {}, clear=True) def test_custom_config_path(mock_conn, runner): config_file = Path(__file__).parent / "test.toml" - runner.invoke(["--config-file", str(config_file), "warehouse", "status"]) + result = runner.invoke( + ["--config-file", str(config_file), "warehouse", "status"], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.output mock_conn.assert_called_once_with( connection_parameters={ "database": "db_for_test", diff --git a/tests/test_sql.py b/tests/test_sql.py index 18949395a1..39330d5d90 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -66,6 +66,7 @@ def test_sql_fails_for_both_query_and_file(runner): @mock.patch(MOCK_CONNECTION) @mock.patch("snowcli.config.cli_config") def test_sql_overrides_connection_configuration(mock_config, mock_conn, runner): + mock_config.get.return_value = "dev" # mock of get_default_connection mock_config.get_connection.return_value = {} result = runner.invoke( [ @@ -87,7 +88,7 @@ def test_sql_overrides_connection_configuration(mock_config, mock_conn, runner): ] ) - assert result.exit_code == 0 + assert result.exit_code == 0, result.output mock_conn.assert_called_once_with( connection_name="dev", account="accountnameValue",