Skip to content

Commit

Permalink
Defer fetching role and warehouse until they're needed (#1631)
Browse files Browse the repository at this point in the history
We have some commands, like `snow app bundle`/`snow ws bundle` that don't even need a connection, so let's defer connecting to Snowflake until it's actually needed by the particular action being run.

For now, all `snow ws` commands are still marked as requiring a connection, that will be fixed separately.
  • Loading branch information
sfc-gh-fcampbell authored Sep 27, 2024
1 parent 9a234a3 commit d8358a9
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 14 deletions.
15 changes: 12 additions & 3 deletions src/snowflake/cli/_plugins/workspace/action_context.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Callable, Optional
from typing import Callable

from snowflake.cli.api.console.abc import AbstractConsole

Expand All @@ -13,6 +14,14 @@ class ActionContext:

console: AbstractConsole
project_root: Path
default_role: str
default_warehouse: Optional[str]
get_default_role: Callable[[], str]
get_default_warehouse: Callable[[], str | None]
get_entity: Callable

@cached_property
def default_role(self) -> str:
return self.get_default_role()

@cached_property
def default_warehouse(self) -> str | None:
return self.get_default_warehouse()
25 changes: 16 additions & 9 deletions src/snowflake/cli/_plugins/workspace/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,6 @@ def __init__(self, project_definition: ProjectDefinition, project_root: Path):
self._entities_cache: Dict[str, Entity] = {}
self._project_definition: DefinitionV20 = project_definition
self._project_root = project_root
self._default_role = default_role()
if self._default_role is None:
self._default_role = get_sql_executor().current_role()
self.default_warehouse = None
cli_context = get_cli_context()
if cli_context.connection.warehouse:
self.default_warehouse = to_identifier(cli_context.connection.warehouse)

def get_entity(self, entity_id: str):
"""
Expand All @@ -62,8 +55,8 @@ def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs)
action_ctx = ActionContext(
console=cc,
project_root=self.project_root(),
default_role=self._default_role,
default_warehouse=self.default_warehouse,
get_default_role=_get_default_role,
get_default_warehouse=_get_default_warehouse,
get_entity=self.get_entity,
)
return entity.perform(action, action_ctx, *args, **kwargs)
Expand All @@ -72,3 +65,17 @@ def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs)

def project_root(self) -> Path:
return self._project_root


def _get_default_role() -> str:
role = default_role()
if role is None:
role = get_sql_executor().current_role()
return role


def _get_default_warehouse() -> str | None:
warehouse = get_cli_context().connection.warehouse
if warehouse:
warehouse = to_identifier(warehouse)
return warehouse
4 changes: 2 additions & 2 deletions tests/workspace/test_application_package_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def _get_app_pkg_entity(project_directory):
action_ctx = ActionContext(
console=mock_console,
project_root=project_root,
default_role="app_role",
default_warehouse="wh",
get_default_role=lambda: "app_role",
get_default_warehouse=lambda: "wh",
get_entity=lambda *args: None,
)
return ApplicationPackageEntity(model), action_ctx, mock_console
Expand Down

0 comments on commit d8358a9

Please sign in to comment.