From 8d03fdbc80ce98f01a16ef71badacd3c3cccae94 Mon Sep 17 00:00:00 2001 From: Tomasz Urbaszek Date: Wed, 2 Aug 2023 10:26:02 +0200 Subject: [PATCH] Add global connection flags (#267) --- src/snowcli/cli/common/decorators.py | 108 ++++++++++++++++++ src/snowcli/cli/common/flags.py | 25 +++- .../cli/common/snow_cli_global_context.py | 63 +++++++++- src/snowcli/cli/snowpark/package.py | 78 +++++++------ src/snowcli/cli/sql.py | 31 +---- tests/test_sql.py | 17 +-- tox.ini | 2 +- 7 files changed, 253 insertions(+), 71 deletions(-) create mode 100644 src/snowcli/cli/common/decorators.py diff --git a/src/snowcli/cli/common/decorators.py b/src/snowcli/cli/common/decorators.py new file mode 100644 index 0000000000..219c3373d0 --- /dev/null +++ b/src/snowcli/cli/common/decorators.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import inspect +from typing import Callable, Optional, get_type_hints + +from snowcli.cli.common.flags import ( + ConnectionOption, + AccountOption, + UserOption, + DatabaseOption, + SchemaOption, + RoleOption, + WarehouseOption, + PasswordOption, +) + + +def global_options(func: Callable): + """ + Decorator providing default flags for overriding global parameters. Values are + updated in global SnowCLI state. + + To use this decorator your command needs to accept **kwargs as last argument. + """ + + def wrapper(**kwargs): + return func(**kwargs) + + wrapper.__signature__ = _extend_signature_with_global_options(func) # type: ignore + return wrapper + + +GLOBAL_CONNECTION_OPTIONS = [ + inspect.Parameter( + "connection", + inspect.Parameter.KEYWORD_ONLY, + annotation=Optional[str], + default=ConnectionOption, + ), + inspect.Parameter( + "account", + inspect.Parameter.KEYWORD_ONLY, + annotation=Optional[str], + default=AccountOption, + ), + inspect.Parameter( + "user", + inspect.Parameter.KEYWORD_ONLY, + annotation=Optional[str], + default=UserOption, + ), + inspect.Parameter( + "password", + inspect.Parameter.KEYWORD_ONLY, + annotation=Optional[str], + default=PasswordOption, + ), + inspect.Parameter( + "database", + inspect.Parameter.KEYWORD_ONLY, + annotation=Optional[str], + default=DatabaseOption, + ), + inspect.Parameter( + "schema", + inspect.Parameter.KEYWORD_ONLY, + annotation=Optional[str], + default=SchemaOption, + ), + inspect.Parameter( + "role", + inspect.Parameter.KEYWORD_ONLY, + annotation=Optional[str], + default=RoleOption, + ), + inspect.Parameter( + "warehouse", + inspect.Parameter.KEYWORD_ONLY, + annotation=Optional[str], + default=WarehouseOption, + ), +] + + +def _extend_signature_with_global_options(func): + """Extends function signature with global options""" + sig = inspect.signature(func) + + # Remove **kwargs from signature + existing_parameters = tuple(sig.parameters.values())[:-1] + + type_hints = get_type_hints(func) + existing_parameters_with_evaluated_types = [ + inspect.Parameter( + name=p.name, + kind=p.kind, + annotation=type_hints[p.name], + default=p.default, + ) + for p in existing_parameters + ] + sig = sig.replace( + parameters=[ + *existing_parameters_with_evaluated_types, + *GLOBAL_CONNECTION_OPTIONS, + ] + ) + return sig diff --git a/src/snowcli/cli/common/flags.py b/src/snowcli/cli/common/flags.py index 661bee40ab..4018732e3e 100644 --- a/src/snowcli/cli/common/flags.py +++ b/src/snowcli/cli/common/flags.py @@ -2,6 +2,8 @@ import typer + +from snowcli.cli.common.snow_cli_global_context import ConnectionDetails from snowcli.utils import check_for_connection DEFAULT_CONTEXT_SETTINGS = {"help_option_names": ["--help", "-h"]} @@ -22,6 +24,7 @@ "--accountname", "--account", help="Name assigned to your Snowflake account.", + callback=ConnectionDetails.update_callback("account"), ) UserOption = typer.Option( @@ -30,6 +33,7 @@ "--username", "--user", help="Username to connect to Snowflake.", + callback=ConnectionDetails.update_callback("user"), ) PasswordOption = typer.Option( @@ -38,6 +42,7 @@ "--password", help="Snowflake password.", hide_input=True, + callback=ConnectionDetails.update_callback("password"), ) DatabaseOption = typer.Option( @@ -46,6 +51,7 @@ "--dbname", "--database", help="Database to use.", + callback=ConnectionDetails.update_callback("database"), ) SchemaOption = typer.Option( @@ -54,8 +60,23 @@ "--schemaname", "--schema", help=" Schema in the database to use.", + callback=ConnectionDetails.update_callback("schema"), ) -RoleOption = typer.Option(None, "-r", "--rolename", "--role", help="Role to be used.") -WarehouseOption = typer.Option(None, "-w", "--warehouse", help="Warehouse to use.") +RoleOption = typer.Option( + None, + "-r", + "--rolename", + "--role", + help="Role to be used.", + callback=ConnectionDetails.update_callback("role"), +) + +WarehouseOption = typer.Option( + None, + "-w", + "--warehouse", + help="Warehouse to use.", + callback=ConnectionDetails.update_callback("warehouse"), +) diff --git a/src/snowcli/cli/common/snow_cli_global_context.py b/src/snowcli/cli/common/snow_cli_global_context.py index 89668c7542..89bf6ffe68 100644 --- a/src/snowcli/cli/common/snow_cli_global_context.py +++ b/src/snowcli/cli/common/snow_cli_global_context.py @@ -1,6 +1,53 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Callable +from typing import Callable, Optional + +from snowcli.config import cli_config, get_default_connection +from snowcli.snow_connector import connect_to_snowflake + + +@dataclass +class ConnectionDetails: + _connection: Optional[str] = None + account: Optional[str] = None + database: Optional[str] = None + role: Optional[str] = None + schema: Optional[str] = None + user: Optional[str] = None + warehouse: Optional[str] = None + + @property + def connection(self): + self._connection = get_default_connection() + return self._connection + + def connection_params(self): + from snowcli.cli.common.decorators import GLOBAL_CONNECTION_OPTIONS + + params = cli_config.get_connection(self.connection) + for option in GLOBAL_CONNECTION_OPTIONS: + override = option.name + if override == "connection": + continue + override_value = getattr(self, override) + if override_value is not None: + params[override] = override_value + return params + + @staticmethod + def _connection_update(param_name: str, value: str): + def modifications(context: SnowCliGlobalContext) -> SnowCliGlobalContext: + setattr(context.connection, param_name, value) + return context + + snow_cli_global_context_manager.update_global_context(modifications) + return value + + @staticmethod + def update_callback(param_name: str): + return lambda value: ConnectionDetails._connection_update( + param_name=param_name, value=value + ) @dataclass @@ -10,6 +57,7 @@ class SnowCliGlobalContext: """ enable_tracebacks: bool + connection: ConnectionDetails class SnowCliGlobalContextManager: @@ -35,12 +83,23 @@ def update_global_context( """ self._global_context = deepcopy(update(self.get_global_context_copy())) + def get_connection(self): + connection = self.get_global_context_copy().connection + return connect_to_snowflake( + connection_name=connection.connection, **connection.connection_params() + ) + def _create_snow_cli_global_context_manager_with_default_values() -> SnowCliGlobalContextManager: """ Creates a manager with global state filled with default values. """ - return SnowCliGlobalContextManager(SnowCliGlobalContext(enable_tracebacks=True)) + return SnowCliGlobalContextManager( + SnowCliGlobalContext( + enable_tracebacks=True, + connection=ConnectionDetails(), + ) + ) snow_cli_global_context_manager = ( diff --git a/src/snowcli/cli/snowpark/package.py b/src/snowcli/cli/snowpark/package.py index 3747ef25fd..bbd2f3831b 100644 --- a/src/snowcli/cli/snowpark/package.py +++ b/src/snowcli/cli/snowpark/package.py @@ -4,14 +4,26 @@ import tempfile from pathlib import Path from shutil import rmtree +from typing import Optional import click import logging import typer from requirements.requirement import Requirement -from snowcli import config, utils -from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS, ConnectionOption +from snowcli import utils +from snowcli.cli.common.flags import ( + DEFAULT_CONTEXT_SETTINGS, + ConnectionOption, + AccountOption, + UserOption, + DatabaseOption, + SchemaOption, + RoleOption, + WarehouseOption, +) +from snowcli.cli.common.decorators import global_options +from snowcli.cli.common.snow_cli_global_context import snow_cli_global_context_manager from snowcli.snow_connector import connect_to_snowflake app = typer.Typer( @@ -76,34 +88,8 @@ def package_lookup( return packages_string -@app.command("create") -def package_create( - name: str = typer.Argument( - ..., - help="Name of the package", - ), - install_packages: bool = typer.Option( - False, - "--yes", - "-y", - help="Install packages that are not available on the Snowflake anaconda channel", - ), -): - """ - Create a python package as a zip file that can be uploaded to a stage and imported for a Snowpark python app. - """ - results_string = package_lookup(name, install_packages, _run_nested=True) - if os.path.exists(".packages"): - utils.recursive_zip_packages_dir(".packages", name + ".zip") - rmtree(".packages") - log.info( - f"Package {name}.zip created. You can now upload it to a stage (`snow package upload -f {name}.zip -s packages`) and reference it in your procedure or function." - ) - if results_string is not None: - log.info(results_string) - - @app.command("upload") +@global_options def package_upload( file: Path = typer.Option( ..., @@ -124,12 +110,13 @@ def package_upload( "-o", help="Overwrite the file if it already exists", ), - environment: str = ConnectionOption, + **kwargs, ): """ - Upload a python package zip file to a Snowflake stage so it can be referenced in the imports of a procedure or function. + Upload a python package zip file to a Snowflake stage, so it can be referenced in the imports of a procedure or function. """ - conn = connect_to_snowflake(connection_name=environment) + conn = snow_cli_global_context_manager.get_connection() + log.info(f"Uploading {file} to Snowflake @{stage}/{file}...") with tempfile.TemporaryDirectory() as temp_dir: temp_app_zip_path = utils.prepare_app_zip(file, temp_dir) @@ -150,3 +137,30 @@ def package_upload( log.info( "Package already exists on stage. Consider using --overwrite to overwrite the file." ) + + +@app.command("create") +def package_create( + name: str = typer.Argument( + ..., + help="Name of the package", + ), + install_packages: bool = typer.Option( + False, + "--yes", + "-y", + help="Install packages that are not available on the Snowflake anaconda channel", + ), +): + """ + Create a python package as a zip file that can be uploaded to a stage and imported for a Snowpark python app. + """ + results_string = package_lookup(name, install_packages, _run_nested=True) + if os.path.exists(".packages"): + utils.recursive_zip_packages_dir(".packages", name + ".zip") + rmtree(".packages") + log.info( + f"Package {name}.zip created. You can now upload it to a stage (`snow package upload -f {name}.zip -s packages`) and reference it in your procedure or function." + ) + if results_string is not None: + log.info(results_string) diff --git a/src/snowcli/cli/sql.py b/src/snowcli/cli/sql.py index 2fa140328c..b99c56eef0 100644 --- a/src/snowcli/cli/sql.py +++ b/src/snowcli/cli/sql.py @@ -5,19 +5,12 @@ import typer from click import UsageError -from snowcli.snow_connector import connect_to_snowflake -from snowcli.cli.common.flags import ( - ConnectionOption, - AccountOption, - UserOption, - DatabaseOption, - SchemaOption, - RoleOption, - WarehouseOption, -) +from snowcli.cli.common.snow_cli_global_context import snow_cli_global_context_manager +from snowcli.cli.common.decorators import global_options from snowcli.output.printing import print_db_cursor +@global_options def execute_sql( query: Optional[str] = typer.Option( None, @@ -35,13 +28,7 @@ def execute_sql( readable=True, help="File to execute.", ), - connection: Optional[str] = ConnectionOption, - account: Optional[str] = AccountOption, - user: Optional[str] = UserOption, - database: Optional[str] = DatabaseOption, - schema: Optional[str] = SchemaOption, - role: Optional[str] = RoleOption, - warehouse: Optional[str] = WarehouseOption, + **options ): """ Executes Snowflake query. @@ -69,15 +56,7 @@ def execute_sql( else: sql = query if query else file.read_text() # type: ignore - conn = connect_to_snowflake( - connection_name=connection, - account=account, - user=user, - role=role, - warehouse=warehouse, - database=database, - schema=schema, - ) + conn = snow_cli_global_context_manager.get_connection() results = conn.ctx.execute_string( sql_text=sql, diff --git a/tests/test_sql.py b/tests/test_sql.py index 39330d5d90..d806e6f3c8 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -4,7 +4,7 @@ from tests.testing_utils.result_assertions import assert_that_result_is_usage_error -MOCK_CONNECTION = "snowcli.cli.sql.connect_to_snowflake" +MOCK_CONNECTION = "snowcli.cli.sql.snow_cli_global_context_manager.get_connection" @mock.patch(MOCK_CONNECTION) @@ -63,12 +63,9 @@ def test_sql_fails_for_both_query_and_file(runner): assert_that_result_is_usage_error(result, "Both query and file provided") -@mock.patch(MOCK_CONNECTION) -@mock.patch("snowcli.config.cli_config") -def test_sql_overrides_connection_configuration(mock_config, mock_conn, runner): - mock_config.get.return_value = "dev" # mock of get_default_connection - mock_config.get_connection.return_value = {} - result = runner.invoke( +@mock.patch("snowcli.cli.common.snow_cli_global_context.connect_to_snowflake") +def test_sql_overrides_connection_configuration(mock_conn, runner): + result = runner.invoke_with_config( [ "sql", "-q", @@ -85,7 +82,10 @@ def test_sql_overrides_connection_configuration(mock_config, mock_conn, runner): "rolenameValue", "--warehouse", "warehouseValue", - ] + "--password", + "passFromTest", + ], + catch_exceptions=False, ) assert result.exit_code == 0, result.output @@ -97,4 +97,5 @@ def test_sql_overrides_connection_configuration(mock_config, mock_conn, runner): database="dbnameValue", schema="schemanameValue", role="rolenameValue", + password="passFromTest", ) diff --git a/tox.ini b/tox.ini index a473ef4222..1660986aeb 100644 --- a/tox.ini +++ b/tox.ini @@ -10,7 +10,7 @@ deps = requests-mock extras = tests commands: - coverage run --source=snowcli -m pytest + coverage run --source=snowcli -m pytest tests/ coverage report [tox:.package]