Skip to content

Commit

Permalink
Pass context in entity init instead of only to actions
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-fcampbell committed Oct 1, 2024
1 parent 56041f1 commit 65bd816
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 35 deletions.
5 changes: 2 additions & 3 deletions src/snowflake/cli/_plugins/nativeapp/entities/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
SameAccountInstallMethod,
)
from snowflake.cli._plugins.nativeapp.utils import needs_confirmation
from snowflake.cli._plugins.workspace.action_context import ActionContext
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli.api.console.abc import AbstractConsole
from snowflake.cli.api.entities.common import EntityBase, get_sql_executor
Expand Down Expand Up @@ -118,7 +117,6 @@ class ApplicationEntity(EntityBase[ApplicationEntityModel]):

def action_deploy(
self,
ctx: ActionContext,
from_release_directive: bool,
prune: bool,
recursive: bool,
Expand All @@ -133,6 +131,7 @@ def action_deploy(
**kwargs,
):
model = self._entity_model
ctx = self._action_ctx
app_name = model.fqn.identifier
debug_mode = model.debug
if model.meta:
Expand Down Expand Up @@ -202,14 +201,14 @@ def deploy_package():

def action_drop(
self,
ctx: ActionContext,
interactive: bool,
force_drop: bool = False,
cascade: Optional[bool] = None,
*args,
**kwargs,
):
model = self._entity_model
ctx = self._action_ctx
app_name = model.fqn.identifier
if model.meta and model.meta.role:
app_role = model.meta.role
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@
from snowflake.cli._plugins.nativeapp.utils import needs_confirmation
from snowflake.cli._plugins.stage.diff import DiffResult
from snowflake.cli._plugins.stage.manager import StageManager
from snowflake.cli._plugins.workspace.action_context import ActionContext
from snowflake.cli.api.cli_global_context import get_cli_context
from snowflake.cli.api.console.abc import AbstractConsole
from snowflake.cli.api.entities.common import EntityBase, get_sql_executor
Expand Down Expand Up @@ -150,8 +149,9 @@ class ApplicationPackageEntity(EntityBase[ApplicationPackageEntityModel]):
A Native App application package.
"""

def action_bundle(self, ctx: ActionContext, *args, **kwargs):
def action_bundle(self, *args, **kwargs):
model = self._entity_model
ctx = self._action_ctx
return self.bundle(
project_root=ctx.project_root,
deploy_root=Path(model.deploy_root),
Expand All @@ -163,7 +163,6 @@ def action_bundle(self, ctx: ActionContext, *args, **kwargs):

def action_deploy(
self,
ctx: ActionContext,
prune: bool,
recursive: bool,
paths: List[Path],
Expand All @@ -175,6 +174,7 @@ def action_deploy(
**kwargs,
):
model = self._entity_model
ctx = self._action_ctx
package_name = model.fqn.identifier

if force:
Expand Down Expand Up @@ -209,8 +209,9 @@ def action_deploy(
policy=policy,
)

def action_drop(self, ctx: ActionContext, force_drop: bool, *args, **kwargs):
def action_drop(self, force_drop: bool, *args, **kwargs):
model = self._entity_model
ctx = self._action_ctx
package_name = model.fqn.identifier
if model.meta and model.meta.role:
package_role = model.meta.role
Expand All @@ -224,10 +225,9 @@ def action_drop(self, ctx: ActionContext, force_drop: bool, *args, **kwargs):
force_drop=force_drop,
)

def action_validate(
self, ctx: ActionContext, interactive: bool, force: bool, *args, **kwargs
):
def action_validate(self, interactive: bool, force: bool, *args, **kwargs):
model = self._entity_model
ctx = self._action_ctx
package_name = model.fqn.identifier
if force:
policy = AllowAlwaysPolicy()
Expand Down Expand Up @@ -261,18 +261,16 @@ def action_validate(
)
ctx.console.message("Setup script is valid")

def action_version_list(
self, ctx: ActionContext, *args, **kwargs
) -> SnowflakeCursor:
def action_version_list(self, *args, **kwargs) -> SnowflakeCursor:
model = self._entity_model
ctx = self._action_ctx
return self.version_list(
package_name=model.fqn.identifier,
package_role=(model.meta and model.meta.role) or ctx.default_role,
)

def action_version_create(
self,
ctx: ActionContext,
version: Optional[str],
patch: Optional[int],
skip_git_check: bool,
Expand All @@ -282,6 +280,7 @@ def action_version_create(
**kwargs,
):
model = self._entity_model
ctx = self._action_ctx
package_name = model.fqn.identifier
return self.version_create(
console=ctx.console,
Expand Down Expand Up @@ -313,14 +312,14 @@ def action_version_create(

def action_version_drop(
self,
ctx: ActionContext,
version: Optional[str],
interactive: bool,
force: bool,
*args,
**kwargs,
):
model = self._entity_model
ctx = self._action_ctx
package_name = model.fqn.identifier
return self.version_drop(
console=ctx.console,
Expand Down
18 changes: 9 additions & 9 deletions src/snowflake/cli/_plugins/workspace/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ def get_entity(self, entity_id: str):
raise ValueError(f"No such entity ID: {entity_id}")
entity_model_cls = entity_model.__class__
entity_cls = v2_entity_model_to_entity_map[entity_model_cls]
self._entities_cache[entity_id] = entity_cls(entity_model)
action_ctx = ActionContext(
console=cc,
project_root=self.project_root(),
get_default_role=_get_default_role,
get_default_warehouse=_get_default_warehouse,
get_entity=self.get_entity,
)
self._entities_cache[entity_id] = entity_cls(entity_model, action_ctx)
return self._entities_cache[entity_id]

def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs):
Expand All @@ -52,14 +59,7 @@ def perform_action(self, entity_id: str, action: EntityActions, *args, **kwargs)
"""
entity = self.get_entity(entity_id)
if entity.supports(action):
action_ctx = ActionContext(
console=cc,
project_root=self.project_root(),
get_default_role=_get_default_role,
get_default_warehouse=_get_default_warehouse,
get_entity=self.get_entity,
)
return entity.perform(action, action_ctx, *args, **kwargs)
return entity.perform(action, *args, **kwargs)
else:
raise ValueError(f'This entity type does not support "{action.value}"')

Expand Down
9 changes: 4 additions & 5 deletions src/snowflake/cli/api/entities/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ class EntityBase(Generic[T]):
Base class for the fully-featured entity classes.
"""

def __init__(self, entity_model: T):
def __init__(self, entity_model: T, action_ctx: ActionContext):
self._entity_model = entity_model
self._action_ctx = action_ctx

@classmethod
def get_entity_model_type(cls) -> Type[T]:
Expand All @@ -42,13 +43,11 @@ def supports(self, action: EntityActions) -> bool:
"""
return callable(getattr(self, action, None))

def perform(
self, action: EntityActions, action_ctx: ActionContext, *args, **kwargs
):
def perform(self, action: EntityActions, *args, **kwargs):
"""
Performs the requested action.
"""
return getattr(self, action)(action_ctx, *args, **kwargs)
return getattr(self, action)(*args, **kwargs)


def get_sql_executor() -> SqlExecutor:
Expand Down
11 changes: 5 additions & 6 deletions tests/workspace/test_application_package_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def _get_app_pkg_entity(project_directory):
get_default_warehouse=lambda: "wh",
get_entity=lambda *args: None,
)
return ApplicationPackageEntity(model), action_ctx, mock_console
return ApplicationPackageEntity(model, action_ctx), mock_console


def test_bundle(project_directory):
app_pkg, bundle_ctx, mock_console = _get_app_pkg_entity(project_directory)
app_pkg, mock_console = _get_app_pkg_entity(project_directory)

bundle_result = app_pkg.action_bundle(bundle_ctx)
bundle_result = app_pkg.action_bundle()

deploy_root = bundle_result.deploy_root()
assert (deploy_root / "README.md").exists()
Expand Down Expand Up @@ -130,10 +130,9 @@ def test_deploy(
)
mock_execute.side_effect = side_effects

app_pkg, bundle_ctx, mock_console = _get_app_pkg_entity(project_directory)
app_pkg, mock_console = _get_app_pkg_entity(project_directory)

app_pkg.action_deploy(
bundle_ctx,
prune=False,
recursive=False,
paths=["a/b", "c"],
Expand All @@ -158,7 +157,7 @@ def test_deploy(
mock_validate.assert_called_once()
mock_execute_post_deploy_hooks.assert_called_once_with(
console=mock_console,
project_root=bundle_ctx.project_root,
project_root=app_pkg._action_ctx.project_root, # noqa SLF001
post_deploy_hooks=[
SqlScriptHookType(sql_script="scripts/package_post_deploy1.sql"),
SqlScriptHookType(sql_script="scripts/package_post_deploy2.sql"),
Expand Down

0 comments on commit 65bd816

Please sign in to comment.