diff --git a/src/snowcli/cli/common/sql_execution.py b/src/snowcli/cli/common/sql_execution.py
index e84bc772c6..ad616e7d92 100644
--- a/src/snowcli/cli/common/sql_execution.py
+++ b/src/snowcli/cli/common/sql_execution.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 from functools import cached_property
+from textwrap import dedent
 
 from snowcli.cli.common.snow_cli_global_context import snow_cli_global_context_manager
 
@@ -17,6 +18,6 @@ def _execute_template(self, template_name: str, payload: dict):
         return self._conn.run_sql(template_name, payload)
 
     def _execute_query(self, query: str):
-        results = self._conn.ctx.execute_string(query)
+        results = self._conn.ctx.execute_string(dedent(query))
         *_, last_result = results
         return last_result
diff --git a/src/snowcli/cli/snowpark/__init__.py b/src/snowcli/cli/snowpark/__init__.py
index 359438e8d7..01b21cf3b7 100644
--- a/src/snowcli/cli/snowpark/__init__.py
+++ b/src/snowcli/cli/snowpark/__init__.py
@@ -4,9 +4,12 @@
 from snowcli.cli.snowpark.function import app as function_app
 from snowcli.cli.snowpark.package import app as package_app
 from snowcli.cli.snowpark.procedure import app as procedure_app
-from snowcli.cli.snowpark.cp import app as compute_pools_app, app_cp as cp_app
-from snowcli.cli.snowpark.services import app as services_app
-from snowcli.cli.snowpark.jobs import app as jobs_app
+from snowcli.cli.snowpark.compute_pool.commands import (
+    app as compute_pools_app,
+    app_cp as cp_app,
+)
+from snowcli.cli.snowpark.services.commands import app as services_app
+from snowcli.cli.snowpark.jobs.commands import app as jobs_app
 from snowcli.cli.snowpark.registry import app as registry_app
 
 app = typer.Typer(
diff --git a/src/snowcli/cli/snowpark/common.py b/src/snowcli/cli/snowpark/common.py
new file mode 100644
index 0000000000..adb1134070
--- /dev/null
+++ b/src/snowcli/cli/snowpark/common.py
@@ -0,0 +1,36 @@
+import sys
+from typing import TextIO
+
+
+if not sys.stdout.closed and sys.stdout.isatty():
+    GREEN = "\033[32m"
+    BLUE = "\033[34m"
+    ORANGE = "\033[38:2:238:76:44m"
+    GRAY = "\033[2m"
+    ENDC = "\033[0m"
+else:
+    GREEN = ""
+    ORANGE = ""
+    BLUE = ""
+    GRAY = ""
+    ENDC = ""
+
+
+def _prefix_line(prefix: str, line: str) -> str:
+    """
+    _prefix_line ensures the prefix is still present even when the line contains carriage returns or embedded newlines
+    """
+    if "\r" in line:
+        line = line.replace("\r", f"\r{prefix}")
+    if "\n" in line[:-1]:
+        line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
+    if not line.startswith("\r"):
+        line = f"{prefix}{line}"
+    return line
+
+
+def print_log_lines(file: TextIO, name, id, logs):
+    prefix = f"{GREEN}{name}/{id}{ENDC} "
+    logs = logs[0:-1]
+    for log in logs:
+        print(_prefix_line(prefix, log + "\n"), file=file, end="", flush=True)
diff --git a/src/snowcli/cli/snowpark/compute_pool/__init__.py b/src/snowcli/cli/snowpark/compute_pool/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/snowcli/cli/snowpark/compute_pool/commands.py b/src/snowcli/cli/snowpark/compute_pool/commands.py
new file mode 100644
index 0000000000..aa139fe5c9
--- /dev/null
+++ b/src/snowcli/cli/snowpark/compute_pool/commands.py
@@ -0,0 +1,68 @@
+import typer
+
+from snowcli.cli.common.alias import build_alias
+from snowcli.cli.common.decorators import global_options
+from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS
+from snowcli.cli.snowpark.compute_pool.manager import ComputePoolManager
+from snowcli.output.decorators import with_output
+
+
+app = typer.Typer(
+    context_settings=DEFAULT_CONTEXT_SETTINGS,
+    
name="compute-pool", + help="Manage compute pools. You can also use cp as alias for this command", +) + + +@app.command() +@with_output +@global_options +def create( + name: str = typer.Option(..., "--name", "-n", help="Compute pool name"), + num_instances: int = typer.Option(..., "--num", "-d", help="Number of instances"), + instance_family: str = typer.Option(..., "--family", "-f", help="Instance family"), + **options, +): + """ + Create compute pool + """ + return ComputePoolManager().create( + pool_name=name, num_instances=num_instances, instance_family=instance_family + ) + + +@app.command() +@with_output +@global_options +def list(**options): + """ + List compute pools + """ + return ComputePoolManager().show() + + +@app.command() +@with_output +@global_options +def drop(name: str = typer.Argument(..., help="Compute Pool Name"), **options): + """ + Drop compute pool + """ + return ComputePoolManager().drop(pool_name=name) + + +@app.command() +@with_output +@global_options +def stop(name: str = typer.Argument(..., help="Compute Pool Name"), **options): + """ + Stop and delete all services running on Compute Pool + """ + return ComputePoolManager().stop(pool_name=name) + + +app_cp = build_alias( + app, + name="cp", + help_str="Manage compute pools. This is alias for compute-pool command", +) diff --git a/src/snowcli/cli/snowpark/compute_pool/manager.py b/src/snowcli/cli/snowpark/compute_pool/manager.py new file mode 100644 index 0000000000..66cf5d2733 --- /dev/null +++ b/src/snowcli/cli/snowpark/compute_pool/manager.py @@ -0,0 +1,25 @@ +from snowcli.cli.common.sql_execution import SqlExecutionMixin + + +class ComputePoolManager(SqlExecutionMixin): + def create(self, pool_name: str, num_instances: int, instance_family: str): + return self._execute_query( + f"""\ + CREATE COMPUTE POOL {pool_name} + MIN_NODES = {num_instances} + MAX_NODES = {num_instances} + INSTANCE_FAMILY = {instance_family}; + """ + ) + + def show(self): + return self._execute_query("show compute pools;") + + def drop( + self, + pool_name: str, + ): + return self._execute_query(f"drop compute pool {pool_name};") + + def stop(self, pool_name: str): + return self._execute_query(f"alter compute pool {pool_name} stop all services;") diff --git a/src/snowcli/cli/snowpark/cp.py b/src/snowcli/cli/snowpark/cp.py deleted file mode 100644 index c8d95bea71..0000000000 --- a/src/snowcli/cli/snowpark/cp.py +++ /dev/null @@ -1,99 +0,0 @@ -import typer - -from snowcli.cli.common.alias import build_alias -from snowcli.cli.common.flags import ConnectionOption, DEFAULT_CONTEXT_SETTINGS -from snowcli.snow_connector import connect_to_snowflake -from snowcli.output.printing import print_db_cursor - -app = typer.Typer( - context_settings=DEFAULT_CONTEXT_SETTINGS, - name="compute-pool", - help="Manage compute pools. 
You can also use cp as alias for this command", -) - - -@app.command() -def create( - environment: str = ConnectionOption, - name: str = typer.Option(..., "--name", "-n", help="Compute pool name"), - num_instances: int = typer.Option(..., "--num", "-d", help="Number of instances"), - instance_family: str = typer.Option(..., "--family", "-f", help="Instance family"), -): - """ - Create compute pool - """ - conn = connect_to_snowflake(connection_name=environment) - - results = conn.create_compute_pool( - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - name=name, - num_instances=num_instances, - instance_family=instance_family, - ) - print_db_cursor(results) - - -@app.command() -def list(environment: str = ConnectionOption): - """ - List compute pools - """ - conn = connect_to_snowflake(connection_name=environment) - - results = conn.list_compute_pools( - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - ) - print_db_cursor(results) - - -@app.command() -def drop( - environment: str = ConnectionOption, - name: str = typer.Argument(..., help="Compute Pool Name"), -): - """ - Drop compute pool - """ - conn = connect_to_snowflake(connection_name=environment) - - results = conn.drop_compute_pool( - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - name=name, - ) - print_db_cursor(results) - - -@app.command() -def stop( - environment: str = ConnectionOption, - name: str = typer.Argument(..., help="Compute Pool Name"), -): - """ - Stop and delete all services running on Compute Pool - """ - conn = connect_to_snowflake(connection_name=environment) - - results = conn.stop_compute_pool( - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - name=name, - ) - print_db_cursor(results) - - -app_cp = build_alias( - app, - name="cp", - help_str="Manage compute pools. 
This is alias for compute-pool command", -) diff --git a/src/snowcli/cli/snowpark/jobs.py b/src/snowcli/cli/snowpark/jobs.py deleted file mode 100644 index bbb9651e60..0000000000 --- a/src/snowcli/cli/snowpark/jobs.py +++ /dev/null @@ -1,155 +0,0 @@ -import sys -import typer -from typing import TextIO - - -from snowcli.cli.common.flags import ConnectionOption, DEFAULT_CONTEXT_SETTINGS -from snowcli.snow_connector import connect_to_snowflake -from snowcli.output.printing import print_db_cursor - -app = typer.Typer( - context_settings=DEFAULT_CONTEXT_SETTINGS, name="jobs", help="Manage jobs" -) - -if not sys.stdout.closed and sys.stdout.isatty(): - GREEN = "\033[32m" - BLUE = "\033[34m" - ORANGE = "\033[38:2:238:76:44m" - GRAY = "\033[2m" - ENDC = "\033[0m" -else: - GREEN = "" - ORANGE = "" - BLUE = "" - GRAY = "" - ENDC = "" - - -@app.command() -def create( - environment: str = ConnectionOption, - compute_pool: str = typer.Option(..., "--compute_pool", "-c", help="Compute Pool"), - spec_path: str = typer.Option(..., "--spec_path", "-s", help="Spec.yaml file path"), - stage: str = typer.Option("SOURCE_STAGE", "--stage", "-l", help="Stage name"), -): - """ - Create Job - """ - conn = connect_to_snowflake(connection_name=environment) - - results = conn.create_job( - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - compute_pool=compute_pool, - spec_path=spec_path, - stage=stage, - ) - print_db_cursor(results) - - -@app.command() -def desc( - environment: str = ConnectionOption, - id: str = typer.Argument(..., help="Job id"), -): - """ - Desc Service - """ - conn = connect_to_snowflake(connection_name=environment) - - results = conn.desc_job( - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - id=id, - ) - print_db_cursor(results) - - -def _prefix_line(prefix: str, line: str) -> str: - """ - _prefix_line ensure the prefix is still present even when dealing with return characters - """ - if "\r" in line: - line = line.replace("\r", f"\r{prefix}") - if "\n" in line[:-1]: - line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:] - if not line.startswith("\r"): - line = f"{prefix}{line}" - return line - - -def print_log_lines(file: TextIO, name, id, logs): - prefix = f"{GREEN}{name}/{id}{ENDC} " - logs = logs[0:-1] - for log in logs: - print(_prefix_line(prefix, log + "\n"), file=file, end="", flush=True) - - -@app.command() -def logs( - environment: str = ConnectionOption, - id: str = typer.Argument(..., help="Job id"), - container_name: str = typer.Option( - ..., "--container-name", "-c", help="Container Name" - ), -): - """ - Logs Service - """ - conn = connect_to_snowflake(connection_name=environment) - - results = conn.logs_job( - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - id=id, - container_name=container_name, - ) - cursor = results.fetchone() - logs = next(iter(cursor)).split("\n") - print_log_lines(sys.stdout, id, "0", logs) - - -@app.command() -def status( - environment: str = ConnectionOption, - id: str = typer.Argument(..., help="Job id"), -): - """ - Logs Service - """ - conn = connect_to_snowflake(connection_name=environment) - - results = conn.status_job( - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - id=id, - ) - print_db_cursor(results) - - -@app.command() -def drop( - environment: str = ConnectionOption, - id: str = 
typer.Argument(..., help="Job id"), -): - """ - Drop Service - """ - conn = connect_to_snowflake(connection_name=environment) - - results = conn.drop_job( - database=conn.ctx.database, - schema=conn.ctx.schema, - role=conn.ctx.role, - warehouse=conn.ctx.warehouse, - id=id, - ) - print_db_cursor(results) diff --git a/src/snowcli/cli/snowpark/jobs/__init__.py b/src/snowcli/cli/snowpark/jobs/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/snowcli/cli/snowpark/jobs/commands.py b/src/snowcli/cli/snowpark/jobs/commands.py new file mode 100644 index 0000000000..bf2409d1b9 --- /dev/null +++ b/src/snowcli/cli/snowpark/jobs/commands.py @@ -0,0 +1,92 @@ +import sys +from pathlib import Path + +import typer + +from snowcli.cli.common.decorators import global_options +from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS +from snowcli.cli.snowpark.common import print_log_lines +from snowcli.cli.snowpark.jobs.manager import JobManager +from snowcli.cli.stage.manager import StageManager +from snowcli.output.decorators import with_output + +app = typer.Typer( + context_settings=DEFAULT_CONTEXT_SETTINGS, name="jobs", help="Manage jobs" +) + + +@app.command() +@with_output +@global_options +def create( + compute_pool: str = typer.Option(..., "--compute-pool", "-c", help="Compute Pool"), + spec_path: Path = typer.Option( + ..., + "--spec-path", + "-s", + help="Spec.yaml file path", + file_okay=True, + dir_okay=False, + exists=True, + ), + stage: str = typer.Option("SOURCE_STAGE", "--stage", "-l", help="Stage name"), + **options, +): + """ + Create Job + """ + stage_manager = StageManager() + stage_manager.create(stage_name=stage) + stage_manager.put(local_path=str(spec_path), stage_name=stage, overwrite=True) + + return JobManager().create( + compute_pool=compute_pool, spec_path=spec_path, stage=stage + ) + + +@app.command() +@with_output +@global_options +def desc(id: str = typer.Argument(..., help="Job id"), **options): + """ + Desc Service + """ + return JobManager().desc(job_name=id) + + +@app.command() +@global_options +def logs( + id: str = typer.Argument(..., help="Job id"), + container_name: str = typer.Option( + ..., "--container-name", "-c", help="Container Name" + ), + **options, +): + """ + Logs Service + """ + results = JobManager().logs(job_name=id, container_name=container_name) + cursor = results.fetchone() + logs = next(iter(cursor)).split("\n") + print_log_lines(sys.stdout, id, "0", logs) + + +@app.command() +@with_output +@global_options +def status(id: str = typer.Argument(..., help="Job id"), **options): + """ + Returns status of a job. 
+    """
+    return JobManager().status(job_name=id)
+
+
+@app.command()
+@with_output
+@global_options
+def drop(id: str = typer.Argument(..., help="Job id"), **options):
+    """
+    Drop Job
+    """
+    return JobManager().drop(job_name=id)
diff --git a/src/snowcli/cli/snowpark/jobs/manager.py b/src/snowcli/cli/snowpark/jobs/manager.py
new file mode 100644
index 0000000000..8d8dd449e0
--- /dev/null
+++ b/src/snowcli/cli/snowpark/jobs/manager.py
@@ -0,0 +1,33 @@
+import hashlib
+import os
+from pathlib import Path
+
+from snowcli.cli.common.sql_execution import SqlExecutionMixin
+
+
+class JobManager(SqlExecutionMixin):
+    def create(self, compute_pool: str, spec_path: Path, stage: str):
+        spec_filename = os.path.basename(spec_path)
+        file_hash = hashlib.md5(spec_path.read_bytes()).hexdigest()
+        stage_dir = os.path.join("jobs", file_hash)
+        return self._execute_query(
+            f"""\
+            EXECUTE SERVICE
+            COMPUTE_POOL = {compute_pool}
+            spec=@{stage}/{stage_dir}/{spec_filename};
+            """
+        )
+
+    def desc(self, job_name: str):
+        return self._execute_query(f"desc service {job_name}")
+
+    def status(self, job_name: str):
+        return self._execute_query(f"CALL SYSTEM$GET_JOB_STATUS('{job_name}')")
+
+    def drop(self, job_name: str):
+        return self._execute_query(f"CALL SYSTEM$CANCEL_JOB('{job_name}')")
+
+    def logs(self, job_name: str, container_name: str):
+        return self._execute_query(
+            f"call SYSTEM$GET_JOB_LOGS('{job_name}', '{container_name}')"
+        )
diff --git a/src/snowcli/cli/snowpark/services.py b/src/snowcli/cli/snowpark/services.py
deleted file mode 100644
index 30886e47c2..0000000000
--- a/src/snowcli/cli/snowpark/services.py
+++ /dev/null
@@ -1,179 +0,0 @@
-import sys
-
-import typer
-from typing import TextIO
-from typing_extensions import Annotated
-
-from snowcli.cli.common.flags import ConnectionOption, DEFAULT_CONTEXT_SETTINGS
-from snowcli.snow_connector import connect_to_snowflake
-from snowcli.output.printing import print_db_cursor
-
-app = typer.Typer(
-    context_settings=DEFAULT_CONTEXT_SETTINGS, name="services", help="Manage services"
-)
-
-if not sys.stdout.closed and sys.stdout.isatty():
-    GREEN = "\033[32m"
-    BLUE = "\033[34m"
-    ORANGE = "\033[38:2:238:76:44m"
-    GRAY = "\033[2m"
-    ENDC = "\033[0m"
-else:
-    GREEN = ""
-    ORANGE = ""
-    BLUE = ""
-    GRAY = ""
-    ENDC = ""
-
-
-@app.command()
-def create(
-    environment: str = ConnectionOption,
-    name: str = typer.Option(..., "--name", "-n", help="Job Name"),
-    compute_pool: str = typer.Option(..., "--compute_pool", "-c", help="Compute Pool"),
-    spec_path: str = typer.Option(..., "--spec_path", "-s", help="Spec Path"),
-    num_instances: Annotated[
-        int, typer.Option("--num_instances", "-num", help="Number of instances")
-    ] = 1,
-    stage: str = typer.Option("SOURCE_STAGE", "--stage", "-l", help="Stage name"),
-):
-    """
-    Create service
-    """
-    conn = connect_to_snowflake(connection_name=environment)
-
-    results = conn.create_service(
-        database=conn.ctx.database,
-        schema=conn.ctx.schema,
-        role=conn.ctx.role,
-        warehouse=conn.ctx.warehouse,
-        name=name,
-        compute_pool=compute_pool,
-        num_instances=num_instances,
-        spec_path=spec_path,
-        stage=stage,
-    )
-    print_db_cursor(results)
-
-
-@app.command()
-def desc(
-    environment: str = ConnectionOption,
-    name: str = typer.Argument(..., help="Service Name"),
-):
-    """
-    Desc Service
-    """
-    conn = connect_to_snowflake(connection_name=environment)
-
-    results = conn.desc_service(
-        database=conn.ctx.database,
-        schema=conn.ctx.schema,
-        role=conn.ctx.role,
-        warehouse=conn.ctx.warehouse,
-        name=name,
-    )
-    print_db_cursor(results)
-
-
-def _prefix_line(prefix: str, line: str) -> str:
-    """
-    _prefix_line ensure the prefix is still present even when dealing with return characters
-    """
-    if "\r" in line:
-        line = line.replace("\r", f"\r{prefix}")
-    if "\n" in line[:-1]:
-        line = line[:-1].replace("\n", f"\n{prefix}") + line[-1:]
-    if not line.startswith("\r"):
-        line = f"{prefix}{line}"
-    return line
-
-
-def print_log_lines(file: TextIO, name, id, logs):
-    prefix = f"{GREEN}{name}/{id}{ENDC} "
-    logs = logs[0:-1]
-    for log in logs:
-        print(_prefix_line(prefix, log + "\n"), file=file, end="", flush=True)
-
-
-@app.command()
-def logs(
-    environment: str = ConnectionOption,
-    name: str = typer.Argument(..., help="Service Name"),
-    container_name: str = typer.Option(
-        ..., "--container_name", "-c", help="Container Name"
-    ),
-):
-    """
-    Logs Service
-    """
-    conn = connect_to_snowflake(connection_name=environment)
-
-    results = conn.logs_service(
-        database=conn.ctx.database,
-        schema=conn.ctx.schema,
-        role=conn.ctx.role,
-        warehouse=conn.ctx.warehouse,
-        name=name,
-        instance_id="0",
-        container_name=container_name,
-    )
-    cursor = results.fetchone()
-    logs = next(iter(cursor)).split("\n")
-    print_log_lines(sys.stdout, name, "0", logs)
-
-
-@app.command()
-def status(
-    environment: str = ConnectionOption,
-    name: str = typer.Argument(..., help="Service Name"),
-):
-    """
-    Logs Service
-    """
-    conn = connect_to_snowflake(connection_name=environment)
-
-    results = conn.status_service(
-        database=conn.ctx.database,
-        schema=conn.ctx.schema,
-        role=conn.ctx.role,
-        warehouse=conn.ctx.warehouse,
-        name=name,
-    )
-    print_db_cursor(results)
-
-
-@app.command()
-def list(environment: str = ConnectionOption):
-    """
-    List Service
-    """
-    conn = connect_to_snowflake(connection_name=environment)
-
-    results = conn.list_service(
-        database=conn.ctx.database,
-        schema=conn.ctx.schema,
-        role=conn.ctx.role,
-        warehouse=conn.ctx.warehouse,
-    )
-    print_db_cursor(results)
-
-
-@app.command()
-def drop(
-    environment: str = ConnectionOption,
-    name: str = typer.Argument(..., help="Service Name"),
-):
-    """
-    Drop Service
-    """
-    conn = connect_to_snowflake(connection_name=environment)
-
-    results = conn.drop_service(
-        database=conn.ctx.database,
-        schema=conn.ctx.schema,
-        role=conn.ctx.role,
-        warehouse=conn.ctx.warehouse,
-        name=name,
-    )
-    print_db_cursor(results)
diff --git a/src/snowcli/cli/snowpark/services/__init__.py b/src/snowcli/cli/snowpark/services/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/snowcli/cli/snowpark/services/commands.py b/src/snowcli/cli/snowpark/services/commands.py
new file mode 100644
index 0000000000..138beb77de
--- /dev/null
+++ b/src/snowcli/cli/snowpark/services/commands.py
@@ -0,0 +1,110 @@
+import sys
+from pathlib import Path
+
+import typer
+
+from snowcli.cli.common.decorators import global_options
+from snowcli.cli.common.flags import DEFAULT_CONTEXT_SETTINGS
+from snowcli.cli.snowpark.common import print_log_lines
+from snowcli.cli.snowpark.services.manager import ServiceManager
+from snowcli.cli.stage.manager import StageManager
+from snowcli.output.decorators import with_output
+
+app = typer.Typer(
+    context_settings=DEFAULT_CONTEXT_SETTINGS, name="services", help="Manage services"
+)
+
+
+@app.command()
+@with_output
+@global_options
+def create(
+    name: str = typer.Option(..., "--name", "-n", help="Service Name"),
+    compute_pool: str = typer.Option(..., "--compute_pool", "-c", help="Compute Pool"),
+    spec_path: Path 
= typer.Option(
+        ...,
+        "--spec_path",
+        "-s",
+        help="Spec Path",
+        file_okay=True,
+        dir_okay=False,
+        exists=True,
+    ),
+    num_instances: int = typer.Option(
+        1, "--num_instances", "-num", help="Number of instances"
+    ),
+    stage: str = typer.Option("SOURCE_STAGE", "--stage", "-l", help="Stage name"),
+    **options,
+):
+    """
+    Create service
+    """
+    stage_manager = StageManager()
+    stage_manager.create(stage_name=stage)
+    stage_manager.put(local_path=str(spec_path), stage_name=stage, overwrite=True)
+
+    return ServiceManager().create(
+        service_name=name,
+        num_instances=num_instances,
+        compute_pool=compute_pool,
+        spec_path=spec_path,
+        stage=stage,
+    )
+
+
+@app.command()
+@with_output
+@global_options
+def desc(name: str = typer.Argument(..., help="Service Name"), **options):
+    """
+    Desc Service
+    """
+    return ServiceManager().desc(service_name=name)
+
+
+@app.command()
+@with_output
+@global_options
+def status(name: str = typer.Argument(..., help="Service Name"), **options):
+    """
+    Returns status of a service
+    """
+    return ServiceManager().status(service_name=name)
+
+
+@app.command()
+@with_output
+@global_options
+def list(**options):
+    """
+    List Service
+    """
+    return ServiceManager().show()
+
+
+@app.command()
+@with_output
+@global_options
+def drop(name: str = typer.Argument(..., help="Service Name"), **options):
+    """
+    Drop Service
+    """
+    return ServiceManager().drop(service_name=name)
+
+
+@app.command()
+@global_options
+def logs(
+    name: str = typer.Argument(..., help="Service Name"),
+    container_name: str = typer.Option(
+        ..., "--container_name", "-c", help="Container Name"
+    ),
+    **options,
+):
+    """
+    Logs Service
+    """
+    results = ServiceManager().logs(service_name=name, container_name=container_name)
+    cursor = results.fetchone()
+    logs = next(iter(cursor)).split("\n")
+    print_log_lines(sys.stdout, name, "0", logs)
diff --git a/src/snowcli/cli/snowpark/services/manager.py b/src/snowcli/cli/snowpark/services/manager.py
new file mode 100644
index 0000000000..93631fc892
--- /dev/null
+++ b/src/snowcli/cli/snowpark/services/manager.py
@@ -0,0 +1,45 @@
+import hashlib
+import os
+from pathlib import Path
+
+from snowcli.cli.common.sql_execution import SqlExecutionMixin
+
+
+class ServiceManager(SqlExecutionMixin):
+    def create(
+        self,
+        service_name: str,
+        compute_pool: str,
+        spec_path: Path,
+        num_instances: int,
+        stage: str,
+    ):
+        spec_filename = os.path.basename(spec_path)
+        file_hash = hashlib.md5(spec_path.read_bytes()).hexdigest()
+        stage_dir = os.path.join("services", file_hash)
+        return self._execute_query(
+            f"""\
+            CREATE SERVICE IF NOT EXISTS {service_name}
+            MIN_INSTANCES = {num_instances}
+            MAX_INSTANCES = {num_instances}
+            COMPUTE_POOL = {compute_pool}
+            spec=@{stage}/{stage_dir}/{spec_filename};
+            """
+        )
+
+    def desc(self, service_name: str):
+        return self._execute_query(f"desc service {service_name}")
+
+    def show(self):
+        return self._execute_query("show services")
+
+    def status(self, service_name: str):
+        return self._execute_query(f"CALL SYSTEM$GET_SERVICE_STATUS('{service_name}')")
+
+    def drop(self, service_name: str):
+        return self._execute_query(f"drop service {service_name}")
+
+    def logs(self, service_name: str, container_name: str):
+        return self._execute_query(
+            f"call SYSTEM$GET_SERVICE_LOGS('{service_name}', '0', '{container_name}');"
+        )
diff --git a/src/snowcli/cli/stage/manager.py b/src/snowcli/cli/stage/manager.py
index e596482f2a..edef20c30a 100644
--- a/src/snowcli/cli/stage/manager.py
+++ b/src/snowcli/cli/stage/manager.py
@@ 
-22,7 +22,13 @@ def get(self, stage_name: str, dest_path: Path): stage_name = self.get_standard_stage_name(stage_name) return self._execute_query(f"get {stage_name} file://{dest_path}/") - def put(self, local_path: str, stage_name: str, parallel: int, overwrite: bool): + def put( + self, + local_path: str, + stage_name: str, + parallel: int = 4, + overwrite: bool = False, + ): stage_name = self.get_standard_stage_name(stage_name) return self._execute_query( f"put file://{local_path} {stage_name} " diff --git a/src/snowcli/snow_connector.py b/src/snowcli/snow_connector.py index 3bd73d269b..4ded027e6d 100644 --- a/src/snowcli/snow_connector.py +++ b/src/snowcli/snow_connector.py @@ -509,272 +509,6 @@ def describe_streamlit(self, name, database, schema, role, warehouse): ) return (description, url) - def create_service( - self, - name: str, - compute_pool: str, - spec_path: str, - role: str, - warehouse: str, - database: str, - num_instances: int, - schema: str, - stage: str, - ) -> SnowflakeCursor: - spec_filename = os.path.basename(spec_path) - file_hash = hashlib.md5(open(spec_path, "rb").read()).hexdigest() - stage_dir = os.path.join("services", file_hash) - return self.run_sql( - "snowservices/services/create_service", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "name": name, - "num_instances": num_instances, - "compute_pool": compute_pool, - "spec_path": spec_path, - "stage_dir": stage_dir, - "stage_filename": spec_filename, - "stage": stage, - }, - ) - - def desc_service( - self, name: str, role: str, warehouse: str, database: str, schema: str - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/services/desc_service", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "name": name, - }, - ) - - def status_service( - self, name: str, role: str, warehouse: str, database: str, schema: str - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/services/status_service", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "name": name, - }, - ) - - def list_service( - self, role: str, warehouse: str, database: str, schema: str - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/services/list_service", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - }, - ) - - def drop_service( - self, name: str, role: str, warehouse: str, database: str, schema: str - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/services/drop_service", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "name": name, - }, - ) - - def logs_service( - self, - name: str, - instance_id: str, - container_name: str, - role: str, - warehouse: str, - database: str, - schema: str, - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/services/logs_service", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "name": name, - "instance_id": instance_id, - "container_name": container_name, - }, - ) - - def create_job( - self, - compute_pool: str, - spec_path: str, - role: str, - warehouse: str, - database: str, - schema: str, - stage: str, - ) -> SnowflakeCursor: - spec_filename = os.path.basename(spec_path) - file_hash = hashlib.md5(open(spec_path, "rb").read()).hexdigest() - stage_dir = os.path.join("jobs", file_hash) - return self.run_sql( - "snowservices/jobs/create_job", - { - "database": database, - "schema": schema, - "role": 
role, - "warehouse": warehouse, - "compute_pool": compute_pool, - "spec_path": spec_path, - "stage_dir": stage_dir, - "stage_filename": spec_filename, - "stage": stage, - }, - ) - - def desc_job( - self, id: str, role: str, warehouse: str, database: str, schema: str - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/jobs/desc_job", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "id": id, - }, - ) - - def status_job( - self, id: str, role: str, warehouse: str, database: str, schema: str - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/jobs/status_job", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "id": id, - }, - ) - - def drop_job( - self, id: str, role: str, warehouse: str, database: str, schema: str - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/jobs/drop_job", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "id": id, - }, - ) - - def logs_job( - self, - id: str, - container_name: str, - role: str, - warehouse: str, - database: str, - schema: str, - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/jobs/logs_job", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "id": id, - "container_name": container_name, - }, - ) - - def create_compute_pool( - self, - name: str, - num_instances: int, - instance_family: str, - role: str, - warehouse: str, - database: str, - schema: str, - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/compute_pool/create_compute_pool", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "name": name, - "min_node": num_instances, - "max_node": num_instances, - "instance_family": instance_family, - }, - ) - - def stop_compute_pool( - self, role: str, warehouse: str, database: str, schema: str, name: str - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/compute_pool/stop_compute_pools", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "name": name, - }, - ) - - def drop_compute_pool( - self, name: str, role: str, warehouse: str, database: str, schema: str - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/compute_pool/drop_compute_pool", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - "name": name, - }, - ) - - def list_compute_pools( - self, role: str, warehouse: str, database: str, schema: str - ) -> SnowflakeCursor: - return self.run_sql( - "snowservices/compute_pool/list_compute_pools", - { - "database": database, - "schema": schema, - "role": role, - "warehouse": warehouse, - }, - ) - def run_sql( self, command, diff --git a/src/snowcli/sql/snowservices/compute_pool/create_compute_pool.sql b/src/snowcli/sql/snowservices/compute_pool/create_compute_pool.sql deleted file mode 100644 index f74e8080d3..0000000000 --- a/src/snowcli/sql/snowservices/compute_pool/create_compute_pool.sql +++ /dev/null @@ -1,6 +0,0 @@ -{% include "set_env.sql" %} - -CREATE COMPUTE POOL {{ name }} - MIN_NODES = {{ min_node }} - MAX_NODES = {{ max_node }} - INSTANCE_FAMILY = {{ instance_family }}; diff --git a/src/snowcli/sql/snowservices/compute_pool/drop_compute_pool.sql b/src/snowcli/sql/snowservices/compute_pool/drop_compute_pool.sql deleted file mode 100644 index 43d7b5cff3..0000000000 --- a/src/snowcli/sql/snowservices/compute_pool/drop_compute_pool.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include 
"set_env.sql" %} - -drop compute pool {{ name }}; diff --git a/src/snowcli/sql/snowservices/compute_pool/list_compute_pools.sql b/src/snowcli/sql/snowservices/compute_pool/list_compute_pools.sql deleted file mode 100644 index bc65df368a..0000000000 --- a/src/snowcli/sql/snowservices/compute_pool/list_compute_pools.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} - -show compute pools; diff --git a/src/snowcli/sql/snowservices/compute_pool/stop_compute_pools.sql b/src/snowcli/sql/snowservices/compute_pool/stop_compute_pools.sql deleted file mode 100644 index 5007864b3f..0000000000 --- a/src/snowcli/sql/snowservices/compute_pool/stop_compute_pools.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} - -alter compute pool {{ name }} stop all services; diff --git a/src/snowcli/sql/snowservices/jobs/create_job.sql b/src/snowcli/sql/snowservices/jobs/create_job.sql deleted file mode 100644 index cd8eca6b5c..0000000000 --- a/src/snowcli/sql/snowservices/jobs/create_job.sql +++ /dev/null @@ -1,9 +0,0 @@ -{% include "set_env.sql" %} - -CREATE STAGE IF NOT EXISTS {{ stage }}; - -put file://{{ spec_path }} @{{ stage }}/{{ stage_dir }} auto_compress=false OVERWRITE = TRUE; - -EXECUTE SERVICE - COMPUTE_POOL = {{ compute_pool }} - spec=@{{ stage }}/{{ stage_dir }}/{{ stage_filename }}; diff --git a/src/snowcli/sql/snowservices/jobs/desc_job.sql b/src/snowcli/sql/snowservices/jobs/desc_job.sql deleted file mode 100644 index e8549c59a3..0000000000 --- a/src/snowcli/sql/snowservices/jobs/desc_job.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} - -desc service {{ id }}; diff --git a/src/snowcli/sql/snowservices/jobs/drop_job.sql b/src/snowcli/sql/snowservices/jobs/drop_job.sql deleted file mode 100644 index 69f7a91737..0000000000 --- a/src/snowcli/sql/snowservices/jobs/drop_job.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} - -call SYSTEM$CANCEL_JOB('{{ id }}'); diff --git a/src/snowcli/sql/snowservices/jobs/logs_job.sql b/src/snowcli/sql/snowservices/jobs/logs_job.sql deleted file mode 100644 index 9318381729..0000000000 --- a/src/snowcli/sql/snowservices/jobs/logs_job.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} - -call SYSTEM$GET_JOB_LOGS('{{ id }}', '{{ container_name }}'); diff --git a/src/snowcli/sql/snowservices/jobs/status_job.sql b/src/snowcli/sql/snowservices/jobs/status_job.sql deleted file mode 100644 index 4ab2ad67c4..0000000000 --- a/src/snowcli/sql/snowservices/jobs/status_job.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} - -CALL SYSTEM$GET_JOB_STATUS('{{ id }}'); diff --git a/src/snowcli/sql/snowservices/services/create_service.sql b/src/snowcli/sql/snowservices/services/create_service.sql deleted file mode 100644 index d7fc274815..0000000000 --- a/src/snowcli/sql/snowservices/services/create_service.sql +++ /dev/null @@ -1,11 +0,0 @@ -{% include "set_env.sql" %} - -CREATE STAGE IF NOT EXISTS {{ stage }}; - -put file://{{ spec_path }} @{{ stage }}/{{ stage_dir }} auto_compress=false OVERWRITE = TRUE; - -CREATE SERVICE IF NOT EXISTS {{ name }} - MIN_INSTANCES = {{ num_instances }} - MAX_INSTANCES = {{ num_instances }} - COMPUTE_POOL = {{ compute_pool }} - spec=@{{ stage }}/{{ stage_dir }}/{{ stage_filename }}; diff --git a/src/snowcli/sql/snowservices/services/desc_service.sql b/src/snowcli/sql/snowservices/services/desc_service.sql deleted file mode 100644 index 3269d11a67..0000000000 --- a/src/snowcli/sql/snowservices/services/desc_service.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} 
- -desc service {{ name }}; diff --git a/src/snowcli/sql/snowservices/services/drop_service.sql b/src/snowcli/sql/snowservices/services/drop_service.sql deleted file mode 100644 index 6229044517..0000000000 --- a/src/snowcli/sql/snowservices/services/drop_service.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} - -DROP SERVICE {{ name }}; diff --git a/src/snowcli/sql/snowservices/services/list_service.sql b/src/snowcli/sql/snowservices/services/list_service.sql deleted file mode 100644 index ece16bfe48..0000000000 --- a/src/snowcli/sql/snowservices/services/list_service.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} - -show services; diff --git a/src/snowcli/sql/snowservices/services/logs_service.sql b/src/snowcli/sql/snowservices/services/logs_service.sql deleted file mode 100644 index 80c4b4912b..0000000000 --- a/src/snowcli/sql/snowservices/services/logs_service.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} - -call SYSTEM$GET_SERVICE_LOGS('{{ name }}', '{{ instance_id }}', '{{ container_name }}'); diff --git a/src/snowcli/sql/snowservices/services/status_service.sql b/src/snowcli/sql/snowservices/services/status_service.sql deleted file mode 100644 index 692b2011f7..0000000000 --- a/src/snowcli/sql/snowservices/services/status_service.sql +++ /dev/null @@ -1,3 +0,0 @@ -{% include "set_env.sql" %} - -CALL SYSTEM$GET_SERVICE_STATUS('{{ name }}'); diff --git a/tests/__snapshots__/test_snow_connector.ambr b/tests/__snapshots__/test_snow_connector.ambr index 5ed2a66c9e..b45d8cae4a 100644 --- a/tests/__snapshots__/test_snow_connector.ambr +++ b/tests/__snapshots__/test_snow_connector.ambr @@ -1,17 +1,4 @@ # serializer version: 1 -# name: test_create_cp - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - CREATE COMPUTE POOL nameValue - MIN_NODES = 42 - MAX_NODES = 42 - INSTANCE_FAMILY = instance_familyValue; - ''' -# --- # name: test_create_function ''' use role roleValue; @@ -47,24 +34,6 @@ describe PROCEDURE nameValue(a, b); ''' # --- -# name: test_create_service - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - CREATE STAGE IF NOT EXISTS stageValue; - - put file://test_spec.yaml @stageValue/services/4231 auto_compress=false OVERWRITE = TRUE; - - CREATE SERVICE IF NOT EXISTS nameValue - MIN_INSTANCES = 42 - MAX_INSTANCES = 42 - COMPUTE_POOL = compute_poolValue - spec=@stageValue/services/4231/test_spec.yaml; - ''' -# --- # name: test_create_streamlit ''' use role roleValue; @@ -109,26 +78,6 @@ CALL SYSTEM$GENERATE_STREAMLIT_URL_FROM_NAME('nameValue'); ''' # --- -# name: test_desc_job - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - desc service idValue; - ''' -# --- -# name: test_desc_services - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - desc service nameValue; - ''' -# --- # name: test_describe_function ''' use role roleValue; @@ -157,16 +106,6 @@ CALL SYSTEM$GENERATE_STREAMLIT_URL_FROM_NAME('nameValue'); ''' # --- -# name: test_drop_cp - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - drop compute pool nameValue; - ''' -# --- # name: test_drop_function ''' use role roleValue; @@ -176,16 +115,6 @@ drop function signatureValue; ''' # --- -# name: test_drop_job - ''' - use role roleValue; - use 
warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - call SYSTEM$CANCEL_JOB('idValue'); - ''' -# --- # name: test_drop_procedure ''' use role roleValue; @@ -195,16 +124,6 @@ drop procedure signatureValue; ''' # --- -# name: test_drop_services - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - DROP SERVICE nameValue; - ''' -# --- # name: test_drop_streamlit ''' use role roleValue; @@ -232,32 +151,6 @@ call procedureValue; ''' # --- -# name: test_job_service - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - CREATE STAGE IF NOT EXISTS stageValue; - - put file://test_spec.yaml @stageValue/jobs/4231 auto_compress=false OVERWRITE = TRUE; - - EXECUTE SERVICE - COMPUTE_POOL = compute_poolValue - spec=@stageValue/jobs/4231/test_spec.yaml; - ''' -# --- -# name: test_list_cp - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - show compute pools; - ''' -# --- # name: test_list_functions ''' use role roleValue; @@ -278,16 +171,6 @@ select "name", "created_on", "arguments" from table(result_scan(last_query_id())); ''' # --- -# name: test_list_services - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - show services; - ''' -# --- # name: test_list_streamlits ''' use role roleValue; @@ -297,26 +180,6 @@ show streamlits; ''' # --- -# name: test_logs_job - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - call SYSTEM$GET_JOB_LOGS('idValue', 'container_nameValue'); - ''' -# --- -# name: test_logs_services - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - call SYSTEM$GET_SERVICE_LOGS('nameValue', '0', 'container_nameValue'); - ''' -# --- # name: test_set_procedure_comment ''' use role roleValue; @@ -336,36 +199,6 @@ grant usage on streamlit nameValue to role to_roleValue; ''' # --- -# name: test_status_job - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - CALL SYSTEM$GET_JOB_STATUS('idValue'); - ''' -# --- -# name: test_status_services - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - CALL SYSTEM$GET_SERVICE_STATUS('nameValue'); - ''' -# --- -# name: test_stop_cp - ''' - use role roleValue; - use warehouse warehouseValue; - use database databaseValue; - use schema schemaValue; - - alter compute pool nameValue stop all services; - ''' -# --- # name: test_upload_file_to_stage[namedStageValue-False] ''' use role roleValue; diff --git a/tests/conftest.py b/tests/conftest.py index b8be549320..3b9a8fa108 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,11 @@ from __future__ import annotations import functools +from io import StringIO from pathlib import Path from tempfile import NamedTemporaryFile from typing import List, NamedTuple +from unittest import mock import pytest @@ -30,6 +32,7 @@ def __init__(self, app: Typer, test_snowcli_config: str): @functools.wraps(CliRunner.invoke) def invoke(self, *a, **kw): + kw.update(catch_exceptions=False) return super().invoke(self.app, *a, **kw) def invoke_with_config(self, *args, **kwargs): @@ -56,6 +59,9 @@ def __init__(self, rows: List[tuple], columns: List[str]): self._rows = 
rows self._columns = [MockResultMetadata(c) for c in columns] + def fetchone(self): + return self.fetchall() + def fetchall(self): yield from self._rows @@ -68,3 +74,39 @@ def from_input(cls, rows, columns): return cls(rows, columns) return _MockCursor.from_input + + +@pytest.fixture() +def mock_ctx(mock_cursor): + class _MockConnectionCtx(mock.MagicMock): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.queries: List[str] = [] + + def get_query(self): + return "\n".join(self.queries) + + @property + def warehouse(self): + return "MockWarehouse" + + @property + def database(self): + return "MockDatabase" + + @property + def schema(self): + return "MockSchema" + + @property + def role(self): + return "mockRole" + + def execute_string(self, query: str): + self.queries.append(query) + return (mock_cursor(["row"], []),) + + def execute_stream(self, query: StringIO): + return self.execute_string(query.read()) + + return _MockConnectionCtx() diff --git a/tests/snowpark/test_compute_pool.py b/tests/snowpark/test_compute_pool.py new file mode 100644 index 0000000000..c10b42bf8a --- /dev/null +++ b/tests/snowpark/test_compute_pool.py @@ -0,0 +1,65 @@ +from unittest import mock + + +@mock.patch("snowflake.connector.connect") +def test_create_cp(mock_connector, runner, mock_ctx, snapshot): + ctx = mock_ctx() + mock_connector.return_value = ctx + + result = runner.invoke( + [ + "snowpark", + "cp", + "create", + "--name", + "cpName", + "--num", + "42", + "--family", + "familyValue", + ] + ) + + assert result.exit_code == 0 + assert ( + ctx.get_query() + == """\ +CREATE COMPUTE POOL cpName +MIN_NODES = 42 +MAX_NODES = 42 +INSTANCE_FAMILY = familyValue; +""" + ) + + +@mock.patch("snowflake.connector.connect") +def test_list_cp(mock_connector, runner, mock_ctx): + ctx = mock_ctx() + mock_connector.return_value = ctx + + result = runner.invoke(["snowpark", "cp", "list"]) + + assert result.exit_code == 0 + assert ctx.get_query() == "show compute pools;" + + +@mock.patch("snowflake.connector.connect") +def test_drop_cp(mock_connector, runner, mock_ctx): + ctx = mock_ctx() + mock_connector.return_value = ctx + + result = runner.invoke(["snowpark", "cp", "drop", "cpNameToDrop"]) + + assert result.exit_code == 0 + assert ctx.get_query() == "drop compute pool cpNameToDrop;" + + +@mock.patch("snowflake.connector.connect") +def test_stop_cp(mock_connector, runner, mock_ctx): + ctx = mock_ctx() + mock_connector.return_value = ctx + + result = runner.invoke(["snowpark", "cp", "stop", "cpNameToStop"]) + + assert result.exit_code == 0 + assert ctx.get_query() == "alter compute pool cpNameToStop stop all services;" diff --git a/tests/snowpark/test_jobs.py b/tests/snowpark/test_jobs.py new file mode 100644 index 0000000000..43e387556d --- /dev/null +++ b/tests/snowpark/test_jobs.py @@ -0,0 +1,83 @@ +from pathlib import Path +from tempfile import NamedTemporaryFile +from unittest import mock + + +@mock.patch("snowflake.connector.connect") +def test_create_job(mock_connector, runner, mock_ctx): + ctx = mock_ctx() + mock_connector.return_value = ctx + + with NamedTemporaryFile( + prefix="spec", suffix="yaml", dir=Path(__file__).parent + ) as fh: + name = fh.name + result = runner.invoke( + [ + "snowpark", + "jobs", + "create", + "--compute-pool", + "jobName", + "--spec-path", + fh.name, + "--stage", + "stageValue", + ] + ) + + assert result.exit_code == 0, result.output + assert ctx.get_query() == ( + "create stage if not exists stageValue\n" + f"put file://{name} " + "@stageValue 
auto_compress=false parallel=4 overwrite=True\n"
+        "EXECUTE SERVICE\n"
+        "COMPUTE_POOL = jobName\n"
+        f"spec=@stageValue/jobs/d41d8cd98f00b204e9800998ecf8427e/{Path(name).stem};\n"
+    )
+
+
+@mock.patch("snowflake.connector.connect")
+def test_desc_job(mock_connector, runner, mock_ctx):
+    ctx = mock_ctx()
+    mock_connector.return_value = ctx
+
+    result = runner.invoke(["snowpark", "jobs", "desc", "jobName"])
+
+    assert result.exit_code == 0, result.output
+    assert ctx.get_query() == "desc service jobName"
+
+
+@mock.patch("snowflake.connector.connect")
+def test_job_status(mock_connector, runner, mock_ctx):
+    ctx = mock_ctx()
+    mock_connector.return_value = ctx
+
+    result = runner.invoke(["snowpark", "jobs", "status", "jobName"])
+
+    assert result.exit_code == 0, result.output
+    assert ctx.get_query() == "CALL SYSTEM$GET_JOB_STATUS('jobName')"
+
+
+@mock.patch("snowflake.connector.connect")
+def test_job_logs(mock_connector, runner, mock_ctx):
+    ctx = mock_ctx()
+    mock_connector.return_value = ctx
+
+    result = runner.invoke(
+        ["snowpark", "jobs", "logs", "--container-name", "containerName", "jobName"]
+    )
+
+    assert result.exit_code == 0, result.output
+    assert ctx.get_query() == "call SYSTEM$GET_JOB_LOGS('jobName', 'containerName')"
+
+
+@mock.patch("snowflake.connector.connect")
+def test_drop_job(mock_connector, runner, mock_ctx):
+    ctx = mock_ctx()
+    mock_connector.return_value = ctx
+
+    result = runner.invoke(["snowpark", "jobs", "drop", "cpNameToDrop"])
+
+    assert result.exit_code == 0, result.output
+    assert ctx.get_query() == "CALL SYSTEM$CANCEL_JOB('cpNameToDrop')"
diff --git a/tests/snowpark/test_services.py b/tests/snowpark/test_services.py
new file mode 100644
index 0000000000..4161195866
--- /dev/null
+++ b/tests/snowpark/test_services.py
@@ -0,0 +1,97 @@
+from pathlib import Path
+from tempfile import NamedTemporaryFile
+from unittest import mock
+
+
+@mock.patch("snowflake.connector.connect")
+def test_create_service(mock_connector, runner, mock_ctx):
+    ctx = mock_ctx()
+    mock_connector.return_value = ctx
+    with NamedTemporaryFile() as fh:
+        result = runner.invoke(
+            [
+                "snowpark",
+                "services",
+                "create",
+                "--name",
+                "serviceName",
+                "--compute_pool",
+                "computePoolValue",
+                "--spec_path",
+                fh.name,
+                "--num_instances",
+                42,
+                "--stage",
+                "stagName",
+            ]
+        )
+
+    assert result.exit_code == 0, result.output
+    assert ctx.get_query() == (
+        (
+            "create stage if not exists stagName\n"
+            f"put file://{fh.name} "
+            "@stagName auto_compress=false parallel=4 overwrite=True\n"
+            "CREATE SERVICE IF NOT EXISTS serviceName\n"
+            "MIN_INSTANCES = 42\n"
+            "MAX_INSTANCES = 42\n"
+            "COMPUTE_POOL = computePoolValue\n"
+            f"spec=@stagName/services/d41d8cd98f00b204e9800998ecf8427e/{Path(fh.name).stem};\n"
+        )
+    )
+
+
+@mock.patch("snowflake.connector.connect")
+def test_list_service(mock_connector, runner, mock_ctx):
+    ctx = mock_ctx()
+    mock_connector.return_value = ctx
+
+    result = runner.invoke(["snowpark", "services", "list"])
+
+    assert result.exit_code == 0, result.output
+    assert ctx.get_query() == "show services"
+
+
+@mock.patch("snowflake.connector.connect")
+def test_drop_service(mock_connector, runner, mock_ctx):
+    ctx = mock_ctx()
+    mock_connector.return_value = ctx
+
+    result = runner.invoke(["snowpark", "services", "drop", "serviceName"])
+
+    assert result.exit_code == 0, result.output
+    assert ctx.get_query() == "drop service serviceName"
+
+
+@mock.patch("snowflake.connector.connect")
+def test_service_status(mock_connector, runner, mock_ctx):
+    ctx = mock_ctx()
+    
mock_connector.return_value = ctx
+
+    result = runner.invoke(["snowpark", "services", "status", "serviceName"])
+
+    assert result.exit_code == 0, result.output
+    assert ctx.get_query() == "CALL SYSTEM$GET_SERVICE_STATUS('serviceName')"
+
+
+@mock.patch("snowflake.connector.connect")
+def test_service_logs(mock_connector, runner, mock_ctx, snapshot):
+    ctx = mock_ctx()
+    mock_connector.return_value = ctx
+
+    result = runner.invoke(
+        [
+            "snowpark",
+            "services",
+            "logs",
+            "--container_name",
+            "containerName",
+            "serviceName",
+        ]
+    )
+
+    assert result.exit_code == 0, result.output
+    assert (
+        ctx.get_query()
+        == "call SYSTEM$GET_SERVICE_LOGS('serviceName', '0', 'containerName');"
+    )
diff --git a/tests/test_snow_connector.py b/tests/test_snow_connector.py
index ac7b595f17..c7c82fd01c 100644
--- a/tests/test_snow_connector.py
+++ b/tests/test_snow_connector.py
@@ -385,157 +385,6 @@ def test_describe_streamlit(_, snapshot):
     assert query.getvalue() == snapshot
 
 
-@mock.patch("snowflake.connector")
-def test_create_cp(_, snapshot):
-    connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION)
-    connector.ctx.execute_stream.return_value = (None, None)
-
-    connector.create_compute_pool(
-        database="databaseValue",
-        schema="schemaValue",
-        role="roleValue",
-        warehouse="warehouseValue",
-        name="nameValue",
-        num_instances=42,
-        instance_family="instance_familyValue",
-    )
-    query, *_ = connector.ctx.execute_stream.call_args.args
-    assert query.getvalue() == snapshot
-
-
-@mock.patch("snowflake.connector")
-def test_list_cp(_, snapshot):
-    connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION)
-    connector.ctx.execute_stream.return_value = (None, None)
-
-    connector.list_compute_pools(
-        database="databaseValue",
-        schema="schemaValue",
-        role="roleValue",
-        warehouse="warehouseValue",
-    )
-    query, *_ = connector.ctx.execute_stream.call_args.args
-    assert query.getvalue() == snapshot
-
-
-@mock.patch("snowflake.connector")
-def test_drop_cp(_, snapshot):
-    connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION)
-    connector.ctx.execute_stream.return_value = (None, None)
-
-    connector.drop_compute_pool(
-        database="databaseValue",
-        schema="schemaValue",
-        role="roleValue",
-        warehouse="warehouseValue",
-        name="nameValue",
-    )
-    query, *_ = connector.ctx.execute_stream.call_args.args
-    assert query.getvalue() == snapshot
-
-
-@mock.patch("snowflake.connector")
-def test_stop_cp(_, snapshot):
-    connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION)
-    connector.ctx.execute_stream.return_value = (None, None)
-
-    connector.stop_compute_pool(
-        database="databaseValue",
-        schema="schemaValue",
-        role="roleValue",
-        warehouse="warehouseValue",
-        name="nameValue",
-    )
-    query, *_ = connector.ctx.execute_stream.call_args.args
-    assert query.getvalue() == snapshot
-
-
-@mock.patch("snowcli.snow_connector.hashlib.md5")
-@mock.patch("snowcli.snow_connector.open")
-@mock.patch("snowflake.connector")
-def test_job_service(_, __, mock_md5, snapshot):
-    connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION)
-    connector.ctx.execute_stream.return_value = (None, None)
-    mock_md5.return_value.hexdigest.return_value = "4231"
-
-    connector.create_job(
-        database="databaseValue",
-        schema="schemaValue",
-        role="roleValue",
-        warehouse="warehouseValue",
-        compute_pool="compute_poolValue",
-        spec_path="test_spec.yaml",
-        stage="stageValue",
-    )
-    query, *_ = connector.ctx.execute_stream.call_args.args
-    assert query.getvalue() == snapshot
-
-
-@mock.patch("snowflake.connector") -def test_desc_job(_, snapshot): - connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION) - connector.ctx.execute_stream.return_value = (None, None) - - connector.desc_job( - database="databaseValue", - schema="schemaValue", - role="roleValue", - warehouse="warehouseValue", - id="idValue", - ) - query, *_ = connector.ctx.execute_stream.call_args.args - assert query.getvalue() == snapshot - - -@mock.patch("snowflake.connector") -def test_logs_job(_, snapshot): - connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION) - connector.ctx.execute_stream.return_value = (None, None) - - connector.logs_job( - database="databaseValue", - schema="schemaValue", - role="roleValue", - warehouse="warehouseValue", - id="idValue", - container_name="container_nameValue", - ) - query, *_ = connector.ctx.execute_stream.call_args.args - assert query.getvalue() == snapshot - - -@mock.patch("snowflake.connector") -def test_status_job(_, snapshot): - connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION) - connector.ctx.execute_stream.return_value = (None, None) - - connector.status_job( - database="databaseValue", - schema="schemaValue", - role="roleValue", - warehouse="warehouseValue", - id="idValue", - ) - query, *_ = connector.ctx.execute_stream.call_args.args - assert query.getvalue() == snapshot - - -@mock.patch("snowflake.connector") -def test_drop_job(_, snapshot): - connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION) - connector.ctx.execute_stream.return_value = (None, None) - - connector.drop_job( - database="databaseValue", - schema="schemaValue", - role="roleValue", - warehouse="warehouseValue", - id="idValue", - ) - query, *_ = connector.ctx.execute_stream.call_args.args - assert query.getvalue() == snapshot - - @mock.patch("snowcli.cli.snowpark.registry.connect_to_snowflake") def test_registry_get_token(mock_conn, runner): mock_conn.return_value.ctx._rest._token_request.return_value = { @@ -549,109 +398,6 @@ def test_registry_get_token(mock_conn, runner): assert result.stdout == '{"token": "token1234", "expires_in": 42}' -@mock.patch("snowcli.snow_connector.hashlib.md5") -@mock.patch("snowcli.snow_connector.open") -@mock.patch("snowflake.connector") -def test_create_service(_, __, mock_md5, snapshot): - connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION) - connector.ctx.execute_stream.return_value = (None, None) - mock_md5.return_value.hexdigest.return_value = "4231" - connector.create_service( - database="databaseValue", - schema="schemaValue", - role="roleValue", - warehouse="warehouseValue", - name="nameValue", - compute_pool="compute_poolValue", - num_instances=42, - spec_path="test_spec.yaml", - stage="stageValue", - ) - query, *_ = connector.ctx.execute_stream.call_args.args - assert query.getvalue() == snapshot - - -@mock.patch("snowflake.connector") -def test_desc_services(_, snapshot): - connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION) - connector.ctx.execute_stream.return_value = (None, None) - - connector.desc_service( - database="databaseValue", - schema="schemaValue", - role="roleValue", - warehouse="warehouseValue", - name="nameValue", - ) - query, *_ = connector.ctx.execute_stream.call_args.args - assert query.getvalue() == snapshot - - -@mock.patch("snowflake.connector") -def test_logs_services(_, snapshot): - connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION) - connector.ctx.execute_stream.return_value = (None, None) - - 
connector.logs_service( - database="databaseValue", - schema="schemaValue", - role="roleValue", - warehouse="warehouseValue", - name="nameValue", - instance_id="0", - container_name="container_nameValue", - ) - query, *_ = connector.ctx.execute_stream.call_args.args - assert query.getvalue() == snapshot - - -@mock.patch("snowflake.connector") -def test_status_services(_, snapshot): - connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION) - connector.ctx.execute_stream.return_value = (None, None) - - connector.status_service( - database="databaseValue", - schema="schemaValue", - role="roleValue", - warehouse="warehouseValue", - name="nameValue", - ) - query, *_ = connector.ctx.execute_stream.call_args.args - assert query.getvalue() == snapshot - - -@mock.patch("snowflake.connector") -def test_list_services(_, snapshot): - connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION) - connector.ctx.execute_stream.return_value = (None, None) - - connector.list_service( - database="databaseValue", - schema="schemaValue", - role="roleValue", - warehouse="warehouseValue", - ) - query, *_ = connector.ctx.execute_stream.call_args.args - assert query.getvalue() == snapshot - - -@mock.patch("snowflake.connector") -def test_drop_services(_, snapshot): - connector = SnowflakeConnector(connection_parameters=MOCK_CONNECTION) - connector.ctx.execute_stream.return_value = (None, None) - - connector.drop_service( - database="databaseValue", - schema="schemaValue", - role="roleValue", - warehouse="warehouseValue", - name="nameValue", - ) - query, *_ = connector.ctx.execute_stream.call_args.args - assert query.getvalue() == snapshot - - @mock.patch.dict(os.environ, {}, clear=True) def test_returns_nice_error_in_case_of_connectivity_error(runner): result = runner.invoke_with_config(["sql", "-q", "select 1"])
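
A note for reviewers, not part of the patch itself: the refactor replaces the per-command SnowflakeConnector methods and Jinja SQL templates with small manager classes built on SqlExecutionMixin, and _execute_query now passes every statement through textwrap.dedent() (first hunk). That is what lets the managers keep their SQL as indented triple-quoted f-strings while the tests assert flush-left statements. A minimal, self-contained sketch of the pattern follows; ExampleManager, its method names, and its query are hypothetical, only the dedent behaviour is taken from the patch:

    from textwrap import dedent

    class ExampleManager:
        """Hypothetical manager mirroring the structure of ComputePoolManager."""

        def _execute_query(self, query: str) -> str:
            # Stand-in for SqlExecutionMixin._execute_query; the real method
            # sends dedent(query) to the Snowflake connection instead of
            # returning it.
            return dedent(query)

        def create(self, name: str, nodes: int) -> str:
            # The backslash after the opening quotes keeps the first SQL line
            # aligned with the others, so dedent() strips the common indent.
            return self._execute_query(
                f"""\
                CREATE COMPUTE POOL {name}
                MIN_NODES = {nodes}
                MAX_NODES = {nodes};
                """
            )

    assert ExampleManager().create("my_pool", 2) == (
        "CREATE COMPUTE POOL my_pool\nMIN_NODES = 2\nMAX_NODES = 2;\n"
    )

The log-prefixing helper that moved into snowpark/common.py keeps its old contract; assuming the module path introduced by this patch, it can be exercised directly:

    from snowcli.cli.snowpark.common import _prefix_line

    # Every visual line start (plain, after \r, or after \n) gets the prefix,
    # so carriage-return progress updates keep the service/instance label.
    assert _prefix_line("svc/0 ", "hello\n") == "svc/0 hello\n"
    assert _prefix_line("svc/0 ", "50%\rdone\n") == "svc/0 50%\rsvc/0 done\n"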