Skip to content

Commit

Permalink
Add global connection flags (#267)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-turbaszek authored Aug 2, 2023
1 parent f0104d9 commit 8d03fdb
Show file tree
Hide file tree
Showing 7 changed files with 253 additions and 71 deletions.
108 changes: 108 additions & 0 deletions src/snowcli/cli/common/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from __future__ import annotations

import inspect
from typing import Callable, Optional, get_type_hints

from snowcli.cli.common.flags import (
ConnectionOption,
AccountOption,
UserOption,
DatabaseOption,
SchemaOption,
RoleOption,
WarehouseOption,
PasswordOption,
)


def global_options(func: Callable):
"""
Decorator providing default flags for overriding global parameters. Values are
updated in global SnowCLI state.
To use this decorator your command needs to accept **kwargs as last argument.
"""

def wrapper(**kwargs):
return func(**kwargs)

wrapper.__signature__ = _extend_signature_with_global_options(func) # type: ignore
return wrapper


GLOBAL_CONNECTION_OPTIONS = [
inspect.Parameter(
"connection",
inspect.Parameter.KEYWORD_ONLY,
annotation=Optional[str],
default=ConnectionOption,
),
inspect.Parameter(
"account",
inspect.Parameter.KEYWORD_ONLY,
annotation=Optional[str],
default=AccountOption,
),
inspect.Parameter(
"user",
inspect.Parameter.KEYWORD_ONLY,
annotation=Optional[str],
default=UserOption,
),
inspect.Parameter(
"password",
inspect.Parameter.KEYWORD_ONLY,
annotation=Optional[str],
default=PasswordOption,
),
inspect.Parameter(
"database",
inspect.Parameter.KEYWORD_ONLY,
annotation=Optional[str],
default=DatabaseOption,
),
inspect.Parameter(
"schema",
inspect.Parameter.KEYWORD_ONLY,
annotation=Optional[str],
default=SchemaOption,
),
inspect.Parameter(
"role",
inspect.Parameter.KEYWORD_ONLY,
annotation=Optional[str],
default=RoleOption,
),
inspect.Parameter(
"warehouse",
inspect.Parameter.KEYWORD_ONLY,
annotation=Optional[str],
default=WarehouseOption,
),
]


def _extend_signature_with_global_options(func):
"""Extends function signature with global options"""
sig = inspect.signature(func)

# Remove **kwargs from signature
existing_parameters = tuple(sig.parameters.values())[:-1]

type_hints = get_type_hints(func)
existing_parameters_with_evaluated_types = [
inspect.Parameter(
name=p.name,
kind=p.kind,
annotation=type_hints[p.name],
default=p.default,
)
for p in existing_parameters
]
sig = sig.replace(
parameters=[
*existing_parameters_with_evaluated_types,
*GLOBAL_CONNECTION_OPTIONS,
]
)
return sig
25 changes: 23 additions & 2 deletions src/snowcli/cli/common/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import typer


from snowcli.cli.common.snow_cli_global_context import ConnectionDetails
from snowcli.utils import check_for_connection

DEFAULT_CONTEXT_SETTINGS = {"help_option_names": ["--help", "-h"]}
Expand All @@ -22,6 +24,7 @@
"--accountname",
"--account",
help="Name assigned to your Snowflake account.",
callback=ConnectionDetails.update_callback("account"),
)

UserOption = typer.Option(
Expand All @@ -30,6 +33,7 @@
"--username",
"--user",
help="Username to connect to Snowflake.",
callback=ConnectionDetails.update_callback("user"),
)

PasswordOption = typer.Option(
Expand All @@ -38,6 +42,7 @@
"--password",
help="Snowflake password.",
hide_input=True,
callback=ConnectionDetails.update_callback("password"),
)

DatabaseOption = typer.Option(
Expand All @@ -46,6 +51,7 @@
"--dbname",
"--database",
help="Database to use.",
callback=ConnectionDetails.update_callback("database"),
)

SchemaOption = typer.Option(
Expand All @@ -54,8 +60,23 @@
"--schemaname",
"--schema",
help=" Schema in the database to use.",
callback=ConnectionDetails.update_callback("schema"),
)

RoleOption = typer.Option(None, "-r", "--rolename", "--role", help="Role to be used.")

WarehouseOption = typer.Option(None, "-w", "--warehouse", help="Warehouse to use.")
RoleOption = typer.Option(
None,
"-r",
"--rolename",
"--role",
help="Role to be used.",
callback=ConnectionDetails.update_callback("role"),
)

WarehouseOption = typer.Option(
None,
"-w",
"--warehouse",
help="Warehouse to use.",
callback=ConnectionDetails.update_callback("warehouse"),
)
63 changes: 61 additions & 2 deletions src/snowcli/cli/common/snow_cli_global_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,53 @@
from copy import deepcopy
from dataclasses import dataclass
from typing import Callable
from typing import Callable, Optional

from snowcli.config import cli_config, get_default_connection
from snowcli.snow_connector import connect_to_snowflake


@dataclass
class ConnectionDetails:
_connection: Optional[str] = None
account: Optional[str] = None
database: Optional[str] = None
role: Optional[str] = None
schema: Optional[str] = None
user: Optional[str] = None
warehouse: Optional[str] = None

@property
def connection(self):
self._connection = get_default_connection()
return self._connection

def connection_params(self):
from snowcli.cli.common.decorators import GLOBAL_CONNECTION_OPTIONS

params = cli_config.get_connection(self.connection)
for option in GLOBAL_CONNECTION_OPTIONS:
override = option.name
if override == "connection":
continue
override_value = getattr(self, override)
if override_value is not None:
params[override] = override_value
return params

@staticmethod
def _connection_update(param_name: str, value: str):
def modifications(context: SnowCliGlobalContext) -> SnowCliGlobalContext:
setattr(context.connection, param_name, value)
return context

snow_cli_global_context_manager.update_global_context(modifications)
return value

@staticmethod
def update_callback(param_name: str):
return lambda value: ConnectionDetails._connection_update(
param_name=param_name, value=value
)


@dataclass
Expand All @@ -10,6 +57,7 @@ class SnowCliGlobalContext:
"""

enable_tracebacks: bool
connection: ConnectionDetails


class SnowCliGlobalContextManager:
Expand All @@ -35,12 +83,23 @@ def update_global_context(
"""
self._global_context = deepcopy(update(self.get_global_context_copy()))

def get_connection(self):
connection = self.get_global_context_copy().connection
return connect_to_snowflake(
connection_name=connection.connection, **connection.connection_params()
)


def _create_snow_cli_global_context_manager_with_default_values() -> SnowCliGlobalContextManager:
"""
Creates a manager with global state filled with default values.
"""
return SnowCliGlobalContextManager(SnowCliGlobalContext(enable_tracebacks=True))
return SnowCliGlobalContextManager(
SnowCliGlobalContext(
enable_tracebacks=True,
connection=ConnectionDetails(),
)
)


snow_cli_global_context_manager = (
Expand Down
78 changes: 46 additions & 32 deletions src/snowcli/cli/snowpark/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,26 @@
import tempfile
from pathlib import Path
from shutil import rmtree
from typing import Optional

import click
import logging
import typer
from requirements.requirement import Requirement

from snowcli import config, utils
from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS, ConnectionOption
from snowcli import utils
from snowcli.cli.common.flags import (
DEFAULT_CONTEXT_SETTINGS,
ConnectionOption,
AccountOption,
UserOption,
DatabaseOption,
SchemaOption,
RoleOption,
WarehouseOption,
)
from snowcli.cli.common.decorators import global_options
from snowcli.cli.common.snow_cli_global_context import snow_cli_global_context_manager
from snowcli.snow_connector import connect_to_snowflake

app = typer.Typer(
Expand Down Expand Up @@ -76,34 +88,8 @@ def package_lookup(
return packages_string


@app.command("create")
def package_create(
name: str = typer.Argument(
...,
help="Name of the package",
),
install_packages: bool = typer.Option(
False,
"--yes",
"-y",
help="Install packages that are not available on the Snowflake anaconda channel",
),
):
"""
Create a python package as a zip file that can be uploaded to a stage and imported for a Snowpark python app.
"""
results_string = package_lookup(name, install_packages, _run_nested=True)
if os.path.exists(".packages"):
utils.recursive_zip_packages_dir(".packages", name + ".zip")
rmtree(".packages")
log.info(
f"Package {name}.zip created. You can now upload it to a stage (`snow package upload -f {name}.zip -s packages`) and reference it in your procedure or function."
)
if results_string is not None:
log.info(results_string)


@app.command("upload")
@global_options
def package_upload(
file: Path = typer.Option(
...,
Expand All @@ -124,12 +110,13 @@ def package_upload(
"-o",
help="Overwrite the file if it already exists",
),
environment: str = ConnectionOption,
**kwargs,
):
"""
Upload a python package zip file to a Snowflake stage so it can be referenced in the imports of a procedure or function.
Upload a python package zip file to a Snowflake stage, so it can be referenced in the imports of a procedure or function.
"""
conn = connect_to_snowflake(connection_name=environment)
conn = snow_cli_global_context_manager.get_connection()

log.info(f"Uploading {file} to Snowflake @{stage}/{file}...")
with tempfile.TemporaryDirectory() as temp_dir:
temp_app_zip_path = utils.prepare_app_zip(file, temp_dir)
Expand All @@ -150,3 +137,30 @@ def package_upload(
log.info(
"Package already exists on stage. Consider using --overwrite to overwrite the file."
)


@app.command("create")
def package_create(
name: str = typer.Argument(
...,
help="Name of the package",
),
install_packages: bool = typer.Option(
False,
"--yes",
"-y",
help="Install packages that are not available on the Snowflake anaconda channel",
),
):
"""
Create a python package as a zip file that can be uploaded to a stage and imported for a Snowpark python app.
"""
results_string = package_lookup(name, install_packages, _run_nested=True)
if os.path.exists(".packages"):
utils.recursive_zip_packages_dir(".packages", name + ".zip")
rmtree(".packages")
log.info(
f"Package {name}.zip created. You can now upload it to a stage (`snow package upload -f {name}.zip -s packages`) and reference it in your procedure or function."
)
if results_string is not None:
log.info(results_string)
Loading

0 comments on commit 8d03fdb

Please sign in to comment.