diff --git a/pyproject.toml b/pyproject.toml index 45ef703f..587add6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ test = [ "twine", "types-Flask", ] +snowflake = ["snowflake-cli"] [project.urls] Repository = "http://github.com/posit-dev/rsconnect-python" diff --git a/rsconnect/actions.py b/rsconnect/actions.py index 2c57b664..917588e6 100644 --- a/rsconnect/actions.py +++ b/rsconnect/actions.py @@ -172,6 +172,14 @@ def test_rstudio_server(server: api.PositServer): raise RSConnectException("Failed to verify with {} ({}).".format(server.remote_name, exc)) +def test_spcs_server(server: api.SPCSConnectServer): + with api.RSConnectClient(server) as client: + try: + client.me() + except RSConnectException as exc: + raise RSConnectException("Failed to verify with {} ({}).".format(server.remote_name, exc)) + + def test_api_key(connect_server: api.RSConnectServer) -> str: """ Test that an API Key may be used to authenticate with the given Posit Connect server. diff --git a/rsconnect/actions_content.py b/rsconnect/actions_content.py index 562c09a7..928e55ee 100644 --- a/rsconnect/actions_content.py +++ b/rsconnect/actions_content.py @@ -13,7 +13,7 @@ import semver -from .api import RSConnectClient, RSConnectServer, emit_task_log +from .api import PositConnectServer, RSConnectClient, emit_task_log from .exception import RSConnectException from .log import logger from .metadata import ContentBuildStore, ContentItemWithBuildState @@ -33,7 +33,7 @@ def content_build_store() -> ContentBuildStore: return _content_build_store -def ensure_content_build_store(connect_server: RSConnectServer) -> ContentBuildStore: +def ensure_content_build_store(connect_server: PositConnectServer) -> ContentBuildStore: global _content_build_store if not _content_build_store: logger.info("Initializing ContentBuildStore for %s" % connect_server.url) @@ -42,7 +42,7 @@ def ensure_content_build_store(connect_server: RSConnectServer) -> ContentBuildS def build_add_content( - connect_server: RSConnectServer, + connect_server: PositConnectServer, content_guids_with_bundle: Sequence[ContentGuidWithBundle], ): """ @@ -85,7 +85,7 @@ def _validate_build_rm_args(guid: Optional[str], all: bool, purge: bool): def build_remove_content( - connect_server: RSConnectServer, + connect_server: PositConnectServer, guid: Optional[str], all: bool, purge: bool, @@ -109,7 +109,7 @@ def build_remove_content( return guids -def build_list_content(connect_server: RSConnectServer, guid: str, status: Optional[str]): +def build_list_content(connect_server: PositConnectServer, guid: str, status: Optional[str]): build_store = ensure_content_build_store(connect_server) if guid: return [build_store.get_content_item(g) for g in guid] @@ -117,12 +117,12 @@ def build_list_content(connect_server: RSConnectServer, guid: str, status: Optio return build_store.get_content_items(status=status) -def build_history(connect_server: RSConnectServer, guid: str): +def build_history(connect_server: PositConnectServer, guid: str): return ensure_content_build_store(connect_server).get_build_history(guid) def build_start( - connect_server: RSConnectServer, + connect_server: PositConnectServer, parallelism: int, aborted: bool = False, error: bool = False, @@ -251,7 +251,7 @@ def build_start( build_monitor.shutdown() -def _monitor_build(connect_server: RSConnectServer, content_items: list[ContentItemWithBuildState]): +def _monitor_build(connect_server: PositConnectServer, content_items: list[ContentItemWithBuildState]): """ :return bool: True if the build completed without errors, False otherwise """ @@ -296,7 +296,7 @@ def _monitor_build(connect_server: RSConnectServer, content_items: list[ContentI return True -def _build_content_item(connect_server: RSConnectServer, content: ContentItemWithBuildState, poll_wait: int): +def _build_content_item(connect_server: PositConnectServer, content: ContentItemWithBuildState, poll_wait: int): build_store = ensure_content_build_store(connect_server) with RSConnectClient(connect_server) as client: # Pending futures will still try to execute when ThreadPoolExecutor.shutdown() is called @@ -351,7 +351,7 @@ def write_log(line: str): def emit_build_log( - connect_server: RSConnectServer, + connect_server: PositConnectServer, guid: str, format: str, task_id: Optional[str] = None, @@ -369,7 +369,7 @@ def emit_build_log( raise RSConnectException("Log file not found for content: %s" % guid) -def download_bundle(connect_server: RSConnectServer, guid_with_bundle: ContentGuidWithBundle): +def download_bundle(connect_server: PositConnectServer, guid_with_bundle: ContentGuidWithBundle): """ :param guid_with_bundle: models.ContentGuidWithBundle """ @@ -387,7 +387,7 @@ def download_bundle(connect_server: RSConnectServer, guid_with_bundle: ContentGu return client.download_bundle(guid_with_bundle.guid, guid_with_bundle.bundle_id) -def get_content(connect_server: RSConnectServer, guid: str | list[str]): +def get_content(connect_server: PositConnectServer, guid: str | list[str]): """ :param guid: a single guid as a string or list of guids. :return: a list of content items. @@ -401,7 +401,7 @@ def get_content(connect_server: RSConnectServer, guid: str | list[str]): def search_content( - connect_server: RSConnectServer, + connect_server: PositConnectServer, published: bool, unpublished: bool, content_type: Sequence[str], diff --git a/rsconnect/api.py b/rsconnect/api.py index 4d455876..eb3dd237 100644 --- a/rsconnect/api.py +++ b/rsconnect/api.py @@ -32,7 +32,7 @@ overload, ) from urllib import parse -from urllib.parse import urlparse +from urllib.parse import urlencode, urlparse from warnings import warn import click @@ -52,8 +52,8 @@ from . import validation from .bundle import _default_title -from .environment import fake_module_file_from_directory from .certificates import read_certificate_file +from .environment import fake_module_file_from_directory from .exception import DeploymentFailedException, RSConnectException from .http_support import CookieJar, HTTPResponse, HTTPServer, JsonData, append_to_path from .log import cls_logged, connect_logger, console_logger, logger @@ -76,6 +76,7 @@ TaskStatusV1, UserRecord, ) +from .snowflake import generate_jwt, get_connection_parameters from .timeouts import get_task_timeout, get_task_timeout_help_message if TYPE_CHECKING: @@ -235,9 +236,94 @@ def __init__( self.ca_data = ca_data # This is specifically not None. self.cookie_jar = CookieJar() + # for compatibility with RSconnectClient + self.snowflake_connection_name = None + + +class SPCSConnectServer(AbstractRemoteServer): + """ """ + + def __init__( + self, + url: str, + snowflake_connection_name: Optional[str], + insecure: bool = False, + ca_data: Optional[str | bytes] = None, + ): + super().__init__(url, "Posit Connect (SPCS)") + self.snowflake_connection_name = snowflake_connection_name + self.insecure = insecure + self.ca_data = ca_data + # for compatibility with RSConnectClient + self.cookie_jar = CookieJar() + self.api_key = None + self.bootstrap_jwt = None + + def token_endpoint(self) -> str: + params = get_connection_parameters(self.snowflake_connection_name) + + if params is None: + raise RSConnectException("No Snowflake connection found.") + + return "https://{}.snowflakecomputing.com/".format(params["account"]) + + def fmt_payload(self) -> str: + params = get_connection_parameters(self.snowflake_connection_name) + + if params is None: + raise RSConnectException("No Snowflake connection found.") + + spcs_url = urlparse(self.url) + scope = "session:role:{} {}".format(params["role"], spcs_url.netloc) + jwt = generate_jwt(self.snowflake_connection_name) + grant_type = "urn:ietf:params:oauth:grant-type:jwt-bearer" + + payload = {"scope": scope, "assertion": jwt, "grant_type": grant_type} + payload = urlencode(payload) + return payload + + def exchange_token(self) -> str: + try: + server = HTTPServer(url=self.token_endpoint()) + payload = self.fmt_payload() + + response = server.request( + method="POST", + path="/oauth/token", + body=payload, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + response = cast(HTTPResponse, response) + # borrowed from AbstractRemoteServer.handle_bad_response + # since we don't want to pick up its json decoding assumptions + if response.status < 200 or response.status > 299: + raise RSConnectException( + "Received an unexpected response from %s (calling %s): %s %s" + % ( + self.url, + response.full_uri, + response.status, + response.reason, + ) + ) + + # Validate response body exists + if not response.response_body: + raise RSConnectException("Token exchange returned empty response") + + # Ensure we return a string + if isinstance(response.response_body, bytes): + return response.response_body.decode("utf-8") + return response.response_body + + except RSConnectException as e: + raise RSConnectException(f"Failed to exchange Snowflake token: {str(e)}") from e -TargetableServer = typing.Union[ShinyappsServer, RSConnectServer, CloudServer] + +TargetableServer = typing.Union[ShinyappsServer, RSConnectServer, CloudServer, SPCSConnectServer] +PositConnectServer = typing.Union[RSConnectServer, SPCSConnectServer] class S3Server(AbstractRemoteServer): @@ -254,7 +340,7 @@ class RSConnectClientDeployResult(TypedDict): class RSConnectClient(HTTPServer): - def __init__(self, server: RSConnectServer, cookies: Optional[CookieJar] = None): + def __init__(self, server: PositConnectServer, cookies: Optional[CookieJar] = None): if cookies is None: cookies = server.cookie_jar super().__init__( @@ -271,6 +357,10 @@ def __init__(self, server: RSConnectServer, cookies: Optional[CookieJar] = None) if server.bootstrap_jwt: self.bootstrap_authorization(server.bootstrap_jwt) + if server.snowflake_connection_name and isinstance(server, SPCSConnectServer): + token = server.exchange_token() + self.snowflake_authorization(token) + def _tweak_response(self, response: HTTPResponse) -> JsonData | HTTPResponse: return ( response.json_data @@ -555,6 +645,7 @@ def __init__( name: Optional[str] = None, url: Optional[str] = None, api_key: Optional[str] = None, + snowflake_connection_name: Optional[str] = None, insecure: bool = False, cacert: Optional[str] = None, ca_data: Optional[str | bytes] = None, @@ -604,6 +695,7 @@ def __init__( name=name, url=url or server, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, ca_data=ca_data, @@ -689,6 +781,7 @@ def setup_remote_server( name: Optional[str] = None, url: Optional[str] = None, api_key: Optional[str] = None, + snowflake_connection_name: Optional[str] = None, insecure: bool = False, cacert: Optional[str] = None, ca_data: Optional[str | bytes] = None, @@ -700,6 +793,7 @@ def setup_remote_server( ctx=ctx, url=url, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, account_name=account_name, @@ -721,6 +815,8 @@ def setup_remote_server( if self.logger: if server_data.api_key and api_key: header_output = self.output_overlap_details("api-key", header_output) + if server_data.snowflake_connection_name and snowflake_connection_name: + header_output = self.output_overlap_details("snowflake_connection_name", header_output) if server_data.insecure and insecure: header_output = self.output_overlap_details("insecure", header_output) if server_data.ca_data and ca_data: @@ -736,6 +832,7 @@ def setup_remote_server( # TODO: Is this logic backward? Seems like the provided value should override the stored value. api_key = server_data.api_key or api_key + snowflake_connection_name = server_data.snowflake_connection_name or snowflake_connection_name insecure = server_data.insecure or insecure ca_data = server_data.ca_data or ca_data account_name = server_data.account_name or account_name @@ -747,6 +844,9 @@ def setup_remote_server( if api_key: url = cast(str, url) self.remote_server = RSConnectServer(url, api_key, insecure, ca_data) + elif snowflake_connection_name: + url = cast(str, url) + self.remote_server = SPCSConnectServer(url, snowflake_connection_name) elif token and secret: if url and ("rstudio.cloud" in url or "posit.cloud" in url): account_name = cast(str, account_name) @@ -761,6 +861,8 @@ def setup_remote_server( def setup_client(self, cookies: Optional[CookieJar] = None): if isinstance(self.remote_server, RSConnectServer): self.client = RSConnectClient(self.remote_server, cookies) + elif isinstance(self.remote_server, SPCSConnectServer): + self.client = RSConnectClient(self.remote_server) elif isinstance(self.remote_server, PositServer): self.client = PositClient(self.remote_server) else: @@ -774,8 +876,11 @@ def validate_server(self): """ Validate that there is enough information to talk to shinyapps.io or a Connect server. """ - if isinstance(self.remote_server, RSConnectServer): + if isinstance(self.remote_server, SPCSConnectServer): + self.validate_spcs_server() + elif isinstance(self.remote_server, RSConnectServer): self.validate_connect_server() + elif isinstance(self.remote_server, PositServer): self.validate_posit_server() else: @@ -815,6 +920,23 @@ def validate_connect_server(self): return self + def validate_spcs_server(self): + if not isinstance(self.remote_server, SPCSConnectServer): + raise RSConnectException("remote_server must be a Connect server in SPCS") + + url = self.remote_server.url + snowflake_connection_name = self.remote_server.snowflake_connection_name + server = SPCSConnectServer(url, snowflake_connection_name) + + with RSConnectClient(server) as client: + try: + result = client.me() + result = server.handle_bad_response(result) + except RSConnectException as exc: + raise RSConnectException(f"Failed to verify with {server.remote_name} ({exc})") + + return self + def validate_posit_server(self): if not isinstance(self.remote_server, PositServer): raise RSConnectException("remote_server is not a Posit server.") @@ -885,7 +1007,7 @@ def deploy_bundle(self): if self.bundle is None: raise RSConnectException("A bundle must be created before deploying it.") - if isinstance(self.remote_server, RSConnectServer): + if isinstance(self.remote_server, PositConnectServer): if not isinstance(self.client, RSConnectClient): raise RSConnectException("client must be an RSConnectClient.") result = self.client.deploy( @@ -964,7 +1086,7 @@ def emit_task_log( :param raise_on_error: whether to raise an exception when a task is failed, otherwise we return the task_result so we can record the exit code. """ - if isinstance(self.remote_server, RSConnectServer): + if isinstance(self.remote_server, PositConnectServer): if not isinstance(self.client, RSConnectClient): raise RSConnectException("To emit task log, client must be a RSConnectClient.") @@ -1006,7 +1128,7 @@ def save_deployed_info(self): @cls_logged("Verifying deployed content...") def verify_deployment(self): - if isinstance(self.remote_server, RSConnectServer): + if isinstance(self.remote_server, PositConnectServer): if not isinstance(self.client, RSConnectClient): raise RSConnectException("To verify deployment, client must be a RSConnectClient.") deployed_info = self.deployed_info @@ -1761,7 +1883,7 @@ def verify_api_key(connect_server: RSConnectServer) -> str: return result["username"] -def get_python_info(connect_server: RSConnectServer): +def get_python_info(connect_server: PositConnectServer): """ Return information about versions of Python that are installed on the indicated Connect server. @@ -1775,7 +1897,7 @@ def get_python_info(connect_server: RSConnectServer): return result -def get_app_info(connect_server: RSConnectServer, app_id: str): +def get_app_info(connect_server: PositConnectServer, app_id: str): """ Return information about an application that has been created in Connect. @@ -1796,7 +1918,7 @@ def get_posit_app_info(server: PositServer, app_id: str): return response["source"] -def get_app_config(connect_server: RSConnectServer, app_id: str): +def get_app_config(connect_server: PositConnectServer, app_id: str): """ Return the configuration information for an application that has been created in Connect. @@ -1812,7 +1934,7 @@ def get_app_config(connect_server: RSConnectServer, app_id: str): def emit_task_log( - connect_server: RSConnectServer, + connect_server: PositConnectServer, app_id: str, task_id: str, log_callback: Optional[Callable[[str], None]], @@ -1848,7 +1970,7 @@ def emit_task_log( def retrieve_matching_apps( - connect_server: RSConnectServer, + connect_server: PositConnectServer, filters: Optional[dict[str, str | int]] = None, limit: Optional[int] = None, mapping_function: Optional[Callable[[RSConnectClient, ContentItemV0], AbbreviatedAppItem | None]] = None, @@ -1924,7 +2046,7 @@ class AbbreviatedAppItem(TypedDict): config_url: str -def override_title_search(connect_server: RSConnectServer, app_id: str, app_title: str): +def override_title_search(connect_server: PositConnectServer, app_id: str, app_title: str): """ Returns a list of abbreviated app data that contains apps with a title that matches the given one and/or the specific app noted by its ID. @@ -2005,7 +2127,7 @@ def find_unique_name(remote_server: TargetableServer, name: str): :param name: the default name for an app. :return: the name, potentially with a suffixed number to guarantee uniqueness. """ - if isinstance(remote_server, RSConnectServer): + if isinstance(remote_server, PositConnectServer): existing_names = retrieve_matching_apps( remote_server, filters={"search": name}, diff --git a/rsconnect/http_support.py b/rsconnect/http_support.py index 28d424a0..2bf6bd47 100644 --- a/rsconnect/http_support.py +++ b/rsconnect/http_support.py @@ -199,7 +199,12 @@ def __init__( and self.response_body is not None and len(self.response_body) > 0 ): - self.json_data = json.loads(self.response_body) + try: + self.json_data = json.loads(self.response_body) + # if non-empty response body is described by response headers as JSON but JSON decoding fails + # return the response body + except json.decoder.JSONDecodeError: + self.response_body class HTTPServer(object): @@ -256,6 +261,9 @@ def key_authorization(self, key: str): def bootstrap_authorization(self, key: str): self.authorization("Connect-Bootstrap %s" % key) + def snowflake_authorization(self, token: str): + self.authorization('Snowflake Token="%s"' % token) + def _get_full_path(self, path: str): return append_to_path(self._url.path, path) diff --git a/rsconnect/main.py b/rsconnect/main.py index a979c0e4..02e72e7c 100644 --- a/rsconnect/main.py +++ b/rsconnect/main.py @@ -34,6 +34,7 @@ test_api_key, test_rstudio_server, test_server, + test_spcs_server, validate_quarto_engines, which_quarto, ) @@ -48,8 +49,13 @@ get_content, search_content, ) -from .environment import Environment, fake_module_file_from_directory -from .api import RSConnectClient, RSConnectExecutor, RSConnectServer +from .api import ( + PositConnectServer, + RSConnectClient, + RSConnectExecutor, + RSConnectServer, + SPCSConnectServer, +) from .bundle import ( default_title_from_manifest, make_api_bundle, @@ -57,8 +63,8 @@ make_manifest_bundle, make_notebook_html_bundle, make_notebook_source_bundle, - make_voila_bundle, make_tensorflow_bundle, + make_voila_bundle, read_manifest_app_mode, validate_entry_point, validate_extra_files, @@ -71,6 +77,7 @@ write_tensorflow_manifest_json, write_voila_manifest_json, ) +from .environment import Environment, fake_module_file_from_directory from .exception import RSConnectException from .json_web_token import ( TokenGenerator, @@ -178,6 +185,15 @@ def wrapper(*args: P.args, **kwargs: P.kwargs): return wrapper +def spcs_args(func: Callable[P, T]) -> Callable[P, T]: + @click.option("--snowflake-connection-name", help="The name of the Snowflake connection in the configuration file") + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs): + return func(*args, **kwargs) + + return wrapper + + def cloud_shinyapps_args(func: Callable[P, T]) -> Callable[P, T]: @click.option( "--account", @@ -403,6 +419,11 @@ def _test_rstudio_creds(server: api.PositServer): test_rstudio_server(server) +def _test_spcs_creds(server: SPCSConnectServer): + with cli_feedback(f"Checking {server.remote_name} credential"): + test_spcs_server(server) + + @cli.command( short_help="Create an initial admin user to bootstrap a Connect instance.", help="Creates an initial admin user to bootstrap a Connect instance. Returns the provisionend API key.", @@ -491,37 +512,8 @@ def bootstrap( ), no_args_is_help=True, ) -@click.option("--name", "-n", required=True, help="The nickname of the Posit Connect server to deploy to.") -@click.option( - "--server", - "-s", - envvar="CONNECT_SERVER", - help="The URL for the Posit Connect server to deploy to, OR \ -rstudio.cloud OR shinyapps.io. (Also settable via CONNECT_SERVER \ -environment variable.)", -) -@click.option( - "--api-key", - "-k", - envvar="CONNECT_API_KEY", - help="The API key to use to authenticate with Posit Connect. \ -(Also settable via CONNECT_API_KEY environment variable.)", -) -@click.option( - "--insecure", - "-i", - envvar="CONNECT_INSECURE", - is_flag=True, - help="Disable TLS certification/host validation. (Also settable via CONNECT_INSECURE environment variable.)", -) -@click.option( - "--cacert", - "-c", - envvar="CONNECT_CA_CERTIFICATE", - type=click.Path(exists=True, file_okay=True, dir_okay=False), - help="The path to trusted TLS CA certificates. (Also settable via CONNECT_CA_CERTIFICATE environment variable.)", -) -@click.option("--verbose", "-v", count=True, help="Enable verbose output. Use -vv for very verbose (debug) output.") +@server_args +@spcs_args @cloud_shinyapps_args @click.pass_context def add( @@ -529,6 +521,7 @@ def add( name: str, server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], account: Optional[str], @@ -548,6 +541,7 @@ def add( account_name=account, token=token, secret=secret, + snowflake_connection_name=snowflake_connection_name, ) # The validation.validate_connection_options() function ensures that certain # combinations of arguments are present; the cast() calls inside of the @@ -580,24 +574,39 @@ def add( else: click.echo('Added {} credential "{}".'.format(real_server.remote_name, name)) else: - server = cast(str, server) - api_key = cast(str, api_key) - # If we're in this code path - # Server must be pingable and the API key must work to be added. - real_server_rsc, _ = _test_server_and_api(server, api_key, insecure, cacert) - server_store.set( - name, - real_server_rsc.url, - real_server_rsc.api_key, - real_server_rsc.insecure, - real_server_rsc.ca_data, - ) + if server and ("snowflakecomputing.app" in server or snowflake_connection_name): + + real_server_spcs = api.SPCSConnectServer(server, snowflake_connection_name) + + _test_spcs_creds(real_server_spcs) + + server_store.set(name, server, snowflake_connection_name=snowflake_connection_name) + if old_server: + click.echo('Updated {} credential "{}".'.format(real_server_spcs.remote_name, name)) + else: + click.echo('Added {} credential "{}".'.format(real_server_spcs.remote_name, name)) - if old_server: - click.echo('Updated Connect server "%s" with URL %s' % (name, real_server_rsc.url)) else: - click.echo('Added Connect server "%s" with URL %s' % (name, real_server_rsc.url)) + + server = cast(str, server) + api_key = cast(str, api_key) + # If we're in this code path + # Server must be pingable and the API key must work to be added. + real_server_rsc, _ = _test_server_and_api(server, api_key, insecure, cacert) + + server_store.set( + name, + real_server_rsc.url, + api_key=real_server_rsc.api_key, + insecure=real_server_rsc.insecure, + ca_data=real_server_rsc.ca_data, + ) + + if old_server: + click.echo('Updated Connect server "%s" with URL %s' % (name, real_server_rsc.url)) + else: + click.echo('Added Connect server "%s" with URL %s' % (name, real_server_rsc.url)) @cli.command( @@ -626,6 +635,10 @@ def list_servers(verbose: int): click.echo(" Insecure mode (TLS host/certificate validation disabled)") if server.get("ca_cert"): click.echo(" Client TLS certificate data provided") + if server.get("snowflake_connection_name"): + snowflake_connection_name = server.get("snowflake_connection_name") + if snowflake_connection_name: + click.echo(' Snowflake Connection Name: "%s"' % snowflake_connection_name) click.echo() @@ -641,6 +654,7 @@ def list_servers(verbose: int): no_args_is_help=True, ) @server_args +@spcs_args @cli_exception_handler @click.pass_context def details( @@ -648,19 +662,20 @@ def details( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], verbose: int, ): set_verbosity(verbose) - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert).validate_server() - if not isinstance(ce.remote_server, RSConnectServer): + ce = RSConnectExecutor(ctx, name, server, api_key, snowflake_connection_name, insecure, cacert).validate_server() + if not isinstance(ce.remote_server, PositConnectServer): raise RSConnectException("`rsconnect details` requires a Posit Connect server.") click.echo(" Posit Connect URL: %s" % ce.remote_server.url) - if not ce.remote_server.api_key: + if not (ce.remote_server.api_key or ce.remote_server.snowflake_connection_name): return with cli_feedback("Gathering details"): @@ -834,6 +849,7 @@ def _warn_on_ignored_requirements(directory: str, requirements_file_name: str): no_args_is_help=True, ) @server_args +@spcs_args @content_args @runtime_environment_args @click.option( @@ -883,6 +899,7 @@ def deploy_notebook( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], static: bool, @@ -928,6 +945,7 @@ def deploy_notebook( ctx=ctx, name=name, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, path=file, @@ -973,6 +991,7 @@ def deploy_notebook( no_args_is_help=True, ) @server_args +@spcs_args @content_args @runtime_environment_args @click.option( @@ -1045,6 +1064,7 @@ def deploy_voila( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], multi_notebook: bool, @@ -1063,6 +1083,7 @@ def deploy_voila( path=path, name=name, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, server=server, @@ -1103,6 +1124,7 @@ def deploy_voila( no_args_is_help=True, ) @server_args +@spcs_args @content_args @cloud_shinyapps_args @click.argument("file", type=click.Path(exists=True, dir_okay=True, file_okay=True)) @@ -1114,6 +1136,7 @@ def deploy_manifest( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], account: Optional[str], @@ -1139,6 +1162,7 @@ def deploy_manifest( ctx=ctx, name=name, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, account=account, @@ -1181,6 +1205,7 @@ def deploy_manifest( no_args_is_help=True, ) @server_args +@spcs_args @content_args @runtime_environment_args @click.option( @@ -1232,6 +1257,7 @@ def deploy_quarto( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], new: bool, @@ -1279,6 +1305,7 @@ def deploy_quarto( ctx=ctx, name=name, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, path=file_or_directory, @@ -1325,6 +1352,7 @@ def deploy_quarto( no_args_is_help=True, ) @server_args +@spcs_args @content_args @click.option( "--image", @@ -1355,6 +1383,7 @@ def deploy_tensorflow( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], new: bool, @@ -1377,6 +1406,7 @@ def deploy_tensorflow( ctx=ctx, name=name, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, path=directory, @@ -1413,6 +1443,7 @@ def deploy_tensorflow( no_args_is_help=True, ) @server_args +@spcs_args @content_args @cloud_shinyapps_args @click.option( @@ -1452,6 +1483,7 @@ def deploy_html( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], account: Optional[str], @@ -1483,6 +1515,7 @@ def deploy_html( ctx=ctx, name=name, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, account=account, @@ -1533,6 +1566,7 @@ def generate_deploy_python(app_mode: AppMode, alias: str, min_version: str, desc no_args_is_help=True, ) @server_args + @spcs_args @content_args @cloud_shinyapps_args @runtime_environment_args @@ -1587,6 +1621,7 @@ def deploy_app( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], entrypoint: Optional[str], @@ -1626,6 +1661,7 @@ def deploy_app( ctx=ctx, name=name, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, account=account, @@ -2317,6 +2353,7 @@ def content(): short_help="Search for content on Posit Connect.", ) @server_args +@spcs_args @click.option( "--published", is_flag=True, @@ -2360,6 +2397,7 @@ def content_search( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], published: bool, @@ -2374,8 +2412,17 @@ def content_search( set_verbosity(verbose) output_params(ctx, locals().items()) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert, logger=None).validate_server() - if not isinstance(ce.remote_server, RSConnectServer): + ce = RSConnectExecutor( + ctx=ctx, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ).validate_server() + if not isinstance(ce.remote_server, PositConnectServer): raise RSConnectException("`rsconnect content search` requires a Posit Connect server.") result = search_content( ce.remote_server, published, unpublished, content_type, r_version, py_version, title_contains, order_by @@ -2389,6 +2436,7 @@ def content_search( short_help="Describe a content item on Posit Connect.", ) @server_args +@spcs_args @click.option( "--guid", "-g", @@ -2405,6 +2453,7 @@ def content_describe( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], guid: str, @@ -2413,8 +2462,17 @@ def content_describe( set_verbosity(verbose) output_params(ctx, locals().items()) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert, logger=None).validate_server() - if not isinstance(ce.remote_server, RSConnectServer): + ce = RSConnectExecutor( + ctx=ctx, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ).validate_server() + if not isinstance(ce.remote_server, PositConnectServer): raise RSConnectException("`rsconnect content describe` requires a Posit Connect server.") result = get_content(ce.remote_server, guid) json.dump(result, sys.stdout, indent=2) @@ -2426,6 +2484,7 @@ def content_describe( short_help="Download a content item's source bundle.", ) @server_args +@spcs_args @click.option( "--guid", "-g", @@ -2452,6 +2511,7 @@ def content_bundle_download( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], guid: ContentGuidWithBundle, @@ -2462,8 +2522,17 @@ def content_bundle_download( set_verbosity(verbose) output_params(ctx, locals().items()) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert, logger=None).validate_server() - if not isinstance(ce.remote_server, RSConnectServer): + ce = RSConnectExecutor( + ctx=ctx, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ).validate_server() + if not isinstance(ce.remote_server, PositConnectServer): raise RSConnectException("`rsconnect content download-bundle` requires a Posit Connect server.") if exists(output) and not overwrite: raise RSConnectException("The output file already exists: %s" % output) @@ -2485,6 +2554,7 @@ def build(): name="add", short_help="Mark a content item for build. Use `build run` to invoke the build on the Connect server." ) @server_args +@spcs_args @click.option( "--guid", "-g", @@ -2500,6 +2570,7 @@ def add_content_build( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], guid: tuple[ContentGuidWithBundle, ...], @@ -2508,8 +2579,17 @@ def add_content_build( set_verbosity(verbose) output_params(ctx, locals().items()) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert, logger=None).validate_server() - if not isinstance(ce.remote_server, RSConnectServer): + ce = RSConnectExecutor( + ctx=ctx, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ).validate_server() + if not isinstance(ce.remote_server, PositConnectServer): raise RSConnectException("`rsconnect content build add` requires a Posit Connect server.") build_add_content(ce.remote_server, guid) if len(guid) == 1: @@ -2525,6 +2605,7 @@ def add_content_build( + "Use `build ls` to view the tracked content.", ) @server_args +@spcs_args @click.option( "--guid", "-g", @@ -2550,6 +2631,7 @@ def remove_content_build( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], guid: Optional[str], @@ -2560,7 +2642,16 @@ def remove_content_build( set_verbosity(verbose) output_params(ctx, locals().items()) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert, logger=None).validate_server() + ce = RSConnectExecutor( + ctx=ctx, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ).validate_server() if not isinstance(ce.remote_server, RSConnectServer): raise RSConnectException("`rsconnect content build rm` requires a Posit Connect server.") guids = build_remove_content(ce.remote_server, guid, all, purge) @@ -2575,6 +2666,7 @@ def remove_content_build( name="ls", short_help="List the content items that are being tracked for build on a given Connect server." ) @server_args +@spcs_args @click.option( "--status", type=click.Choice(BuildStatus._all), @@ -2596,6 +2688,7 @@ def list_content_build( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], status: Optional[str], @@ -2605,8 +2698,17 @@ def list_content_build( set_verbosity(verbose) output_params(ctx, locals().items()) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert, logger=None).validate_server() - if not isinstance(ce.remote_server, RSConnectServer): + ce = RSConnectExecutor( + ctx=ctx, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ).validate_server() + if not isinstance(ce.remote_server, PositConnectServer): raise RSConnectException("`rsconnect content build ls` requires a Posit Connect server.") result = build_list_content(ce.remote_server, guid, status) json.dump(result, sys.stdout, indent=2) @@ -2615,6 +2717,7 @@ def list_content_build( # noinspection SpellCheckingInspection,DuplicatedCode @build.command(name="history", short_help="Get the build history for a content item.") @server_args +@spcs_args @click.option( "--guid", "-g", @@ -2630,6 +2733,7 @@ def get_build_history( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], guid: str, @@ -2638,9 +2742,18 @@ def get_build_history( set_verbosity(verbose) output_params(ctx, locals().items()) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert) + ce = RSConnectExecutor( + ctx=ctx, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ) ce.validate_server() - if not isinstance(ce.remote_server, RSConnectServer): + if not isinstance(ce.remote_server, PositConnectServer): raise RSConnectException("`rsconnect content build history` requires a Posit Connect server.") result = build_history(ce.remote_server, guid) json.dump(result, sys.stdout, indent=2) @@ -2652,6 +2765,7 @@ def get_build_history( short_help="Print the logs for a content build.", ) @server_args +@spcs_args @click.option( "--guid", "-g", @@ -2680,6 +2794,7 @@ def get_build_logs( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], guid: str, @@ -2690,8 +2805,17 @@ def get_build_logs( set_verbosity(verbose) output_params(ctx, locals().items()) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert, logger=None).validate_server() - if not isinstance(ce.remote_server, RSConnectServer): + ce = RSConnectExecutor( + ctx=ctx, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ).validate_server() + if not isinstance(ce.remote_server, PositConnectServer): raise RSConnectException("`rsconnect content build logs` requires a Posit Connect server.") for line in emit_build_log(ce.remote_server, guid, format, task_id): sys.stdout.write(line) @@ -2703,6 +2827,7 @@ def get_build_logs( short_help="Start building content on a given Connect server.", ) @server_args +@spcs_args @click.option( "--parallelism", type=click.IntRange(min=1, clamp=True), @@ -2747,6 +2872,7 @@ def start_content_build( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], parallelism: int, @@ -2765,8 +2891,17 @@ def start_content_build( output_params(ctx, locals().items()) logger.set_log_output_format(format) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert, logger=None).validate_server() - if not isinstance(ce.remote_server, RSConnectServer): + ce = RSConnectExecutor( + ctx=ctx, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ).validate_server() + if not isinstance(ce.remote_server, PositConnectServer): raise RSConnectException("rsconnect content build run` requires a Posit Connect server.") build_start(ce.remote_server, parallelism, aborted, error, running, retry, all, poll_wait, debug, force) @@ -2787,17 +2922,28 @@ def caches(): short_help="List runtime caches present on a Posit Connect server.", ) @server_args +@spcs_args def system_caches_list( name: str, server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], verbose: int, ): set_verbosity(verbose) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(None, name, server, api_key, insecure, cacert, logger=None).validate_server() + ce = RSConnectExecutor( + None, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ).validate_server() result = ce.runtime_caches json.dump(result, sys.stdout, indent=2) @@ -2808,6 +2954,7 @@ def system_caches_list( short_help="Delete a runtime cache on a Posit Connect server.", ) @server_args +@spcs_args @click.option( "--language", "-l", @@ -2838,6 +2985,7 @@ def system_caches_delete( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], verbose: int, @@ -2849,10 +2997,20 @@ def system_caches_delete( set_verbosity(verbose) output_params(ctx, locals().items()) with cli_feedback("", stderr=True): - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert, logger=None).validate_server() + ce = RSConnectExecutor( + ctx=ctx, + name=name, + server=server, + api_key=api_key, + snowflake_connection_name=snowflake_connection_name, + insecure=insecure, + cacert=cacert, + logger=None, + ).validate_server() ce.delete_runtime_cache(language, version, image_name, dry_run) if __name__ == "__main__": cli() click.echo() + click.echo() diff --git a/rsconnect/metadata.py b/rsconnect/metadata.py index b1b04780..2f819e61 100644 --- a/rsconnect/metadata.py +++ b/rsconnect/metadata.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: - from .api import RSConnectServer + from .api import PositConnectServer from .exception import RSConnectException from .log import logger @@ -244,6 +244,7 @@ class ServerDataDict(TypedDict): name: str url: str api_key: NotRequired[str] + snowflake_connection_name: NotRequired[str] insecure: NotRequired[bool] ca_cert: NotRequired[str] account_name: NotRequired[str] @@ -263,6 +264,7 @@ def __init__( url: str, from_store: bool, api_key: Optional[str] = None, + snowflake_connection_name: Optional[str] = None, insecure: Optional[bool] = None, ca_data: Optional[str] = None, account_name: Optional[str] = None, @@ -273,6 +275,7 @@ def __init__( self.url = url self.from_store = from_store self.api_key = api_key + self.snowflake_connection_name = snowflake_connection_name self.insecure = insecure self.ca_data = ca_data self.account_name = account_name @@ -320,6 +323,7 @@ def set( name: str, url: str, api_key: Optional[str] = None, + snowflake_connection_name: Optional[str] = None, insecure: Optional[bool] = False, ca_data: Optional[str] = None, account_name: Optional[str] = None, @@ -332,6 +336,7 @@ def set( :param name: the nickname for the Connect server. :param url: the full URL for the Connect server. :param api_key: the API key to use to authenticate with the Connect server. + :param snowflake_connection_name: the snowflake connection name :param insecure: a flag to disable TLS verification. :param ca_data: client side certificate data to use for TLS. :param account_name: shinyapps.io account name. @@ -344,6 +349,8 @@ def set( } if api_key: target_data = dict(api_key=api_key, insecure=insecure, ca_cert=ca_data) + elif snowflake_connection_name: + target_data = dict(snowflake_connection_name=snowflake_connection_name) elif account_name: target_data = dict(account_name=account_name, token=token, secret=secret) else: @@ -409,6 +416,7 @@ def resolve(self, name: Optional[str], url: Optional[str]) -> ServerData: insecure=entry.get("insecure"), ca_data=entry.get("ca_cert"), api_key=entry.get("api_key"), + snowflake_connection_name=entry.get("snowflake_connection_name"), account_name=entry.get("account_name"), token=entry.get("token"), secret=entry.get("secret"), @@ -594,7 +602,7 @@ class ContentBuildStore(DataStore[Dict[str, object]]): def __init__( self, - server: RSConnectServer, + server: PositConnectServer, base_dir: str = os.getenv("CONNECT_CONTENT_BUILD_DIR", DEFAULT_BUILD_DIR), ): # This type declaration is a bit of a hack. It is needed because data model used diff --git a/rsconnect/snowflake.py b/rsconnect/snowflake.py new file mode 100644 index 00000000..bc637891 --- /dev/null +++ b/rsconnect/snowflake.py @@ -0,0 +1,75 @@ +# pyright: reportMissingTypeStubs=false, reportUnusedImport=false +import json +from subprocess import CalledProcessError, CompletedProcess, run +from typing import Any, Dict, List, Optional + +from .exception import RSConnectException +from .log import logger + + +def snow(*args: str) -> CompletedProcess[str]: + ensure_snow_installed() + return run(["snow"] + list(args), capture_output=True, text=True, check=True) + + +def ensure_snow_installed() -> None: + try: + import snowflake.cli # noqa: F401 + + logger.debug("snowflake-cli is installed.") + + except ImportError: + logger.warning("snowflake-cli is not installed.") + try: + run(["snow", "--version"], capture_output=True, check=True) + except CalledProcessError: + raise RSConnectException("snow is installed but could not be run.") + except FileNotFoundError: + raise RSConnectException("snow cannot be found.") + + +def list_connections() -> List[Dict[str, Any]]: + + try: + res = snow("connection", "list", "--format", "json") + connection_list = json.loads(res.stdout) + return connection_list + except CalledProcessError: + raise RSConnectException("Could not list snowflake connections.") + + +def get_connection_parameters(name: Optional[str] = None) -> Optional[Dict[str, Any]]: + + connection_list = list_connections() + # return parameters for default connection if configured + # otherwise return named connection + + if not connection_list: + raise RSConnectException("No Snowflake connections found.") + + try: + if not name: + return next((x["parameters"] for x in connection_list if x.get("is_default")), None) + else: + return next((x["parameters"] for x in connection_list if x.get("connection_name") == name)) + except StopIteration: + raise RSConnectException(f"No Snowflake connection found with name '{name}'.") + + +def generate_jwt(name: Optional[str] = None) -> str: + + _ = get_connection_parameters(name) + connection_name = "" if name is None else name + + try: + res = snow("connection", "generate-jwt", "--connection", connection_name, "--format", "json") + try: + output = json.loads(res.stdout) + except json.JSONDecodeError: + raise RSConnectException(f"Failed to parse JSON from snow-cli: {res.stdout}") + jwt = output.get("message") + if jwt is None: + raise RSConnectException(f"Failed to generate JWT: Missing 'message' field in response: {output}") + return jwt + except CalledProcessError as e: + raise RSConnectException(f"Failed to generate JWT for connection '{name}': {e.stderr}") diff --git a/rsconnect/validation.py b/rsconnect/validation.py index 8c5f455d..7976e9d5 100644 --- a/rsconnect/validation.py +++ b/rsconnect/validation.py @@ -45,6 +45,7 @@ def validate_connection_options( token: Optional[str], secret: Optional[str], name: Optional[str] = None, + snowflake_connection_name: Optional[str] = None, ): """ Validates provided Connect or shinyapps.io connection options and returns which target to use given the provided @@ -63,6 +64,7 @@ def validate_connection_options( -T/--token or SHINYAPPS_TOKEN or RSCLOUD_TOKEN -S/--secret or SHINYAPPS_SECRET or RSCLOUD_SECRET -A/--account or SHINYAPPS_ACCOUNT + --snowflake-connection-name FAILURE if any of: -k/--api-key or CONNECT_API_KEY @@ -82,10 +84,16 @@ def validate_connection_options( -T/--token or SHINYAPPS_TOKEN or RSCLOUD_TOKEN -S/--secret or SHINYAPPS_SECRET or RSCLOUD_SECRET -A/--account or SHINYAPPS_ACCOUNT + + + FAILURE if -s/--server or CONNECT_SERVER include "snowflakecomputing.app" + and not + --snowflake-connection-name """ connect_options = {"-k/--api-key": api_key, "-i/--insecure": insecure, "-c/--cacert": cacert} shinyapps_options = {"-T/--token": token, "-S/--secret": secret, "-A/--account": account_name} cloud_options = {"-T/--token": token, "-S/--secret": secret} + spcs_options = {"--snowflake-connection-name": snowflake_connection_name} options_mutually_exclusive_with_name = {"-s/--server": url, **shinyapps_options} present_options_mutually_exclusive_with_name = _get_present_options(options_mutually_exclusive_with_name, ctx) @@ -105,6 +113,7 @@ def validate_connection_options( present_connect_options = _get_present_options(connect_options, ctx) present_shinyapps_options = _get_present_options(shinyapps_options, ctx) present_cloud_options = _get_present_options(cloud_options, ctx) + present_spcs_options = _get_present_options(spcs_options, ctx) if present_connect_options and present_shinyapps_options: raise RSConnectException( @@ -113,6 +122,19 @@ def validate_connection_options( See command help for further details." ) + if snowflake_connection_name and not url: + raise RSConnectException( + "--snowflake-connection-name requires -s/--server to be specified. \ +See command help for further details." + ) + + if present_shinyapps_options and present_spcs_options: + raise RSConnectException( + f"Shinyapps.io/Cloud options ({', '.join(present_shinyapps_options)}) may not be passed \ +alongside SPCS options ({', '.join(present_spcs_options)}). \ + See command help for further details." + ) + if url and ("posit.cloud" in url or "rstudio.cloud" in url): if len(present_cloud_options) != len(cloud_options): raise RSConnectException( diff --git a/tests/test_api.py b/tests/test_api.py index 4eda8c24..910efb09 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -16,6 +16,7 @@ RSConnectServer, ShinyappsServer, ShinyappsService, + SPCSConnectServer, ) from rsconnect.exception import DeploymentFailedException, RSConnectException from rsconnect.models import AppModes @@ -508,3 +509,132 @@ def test_do_deploy_failure(self): self.cloud_client.deploy_application.assert_called_with(bundle_id, app_id) self.cloud_client.wait_until_task_is_successful.assert_called_with(task_id) self.cloud_client.get_task_logs.assert_called_with(task_id) + + +class SPCSConnectServerTestCase(TestCase): + def test_init(self): + server = SPCSConnectServer("https://spcs.example.com", "example_connection") + assert server.url == "https://spcs.example.com" + assert server.remote_name == "Posit Connect (SPCS)" + assert server.snowflake_connection_name == "example_connection" + assert server.api_key is None + + @patch("rsconnect.api.SPCSConnectServer.token_endpoint") + def test_token_endpoint(self, mock_token_endpoint): + server = SPCSConnectServer("https://spcs.example.com", "example_connection") + mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/" + endpoint = server.token_endpoint() + assert endpoint == "https://example.snowflakecomputing.com/" + + @patch("rsconnect.api.get_connection_parameters") + def test_token_endpoint_with_account(self, mock_get_connection_parameters): + server = SPCSConnectServer("https://spcs.example.com", "example_connection") + mock_get_connection_parameters.return_value = {"account": "test_account"} + endpoint = server.token_endpoint() + assert endpoint == "https://test_account.snowflakecomputing.com/" + mock_get_connection_parameters.assert_called_once_with("example_connection") + + @patch("rsconnect.api.get_connection_parameters") + def test_token_endpoint_with_none_params(self, mock_get_connection_parameters): + server = SPCSConnectServer("https://spcs.example.com", "example_connection") + mock_get_connection_parameters.return_value = None + with pytest.raises(RSConnectException, match="No Snowflake connection found."): + server.token_endpoint() + + @patch("rsconnect.api.get_connection_parameters") + def test_fmt_payload(self, mock_get_connection_parameters): + server = SPCSConnectServer("https://spcs.example.com", "example_connection") + mock_get_connection_parameters.return_value = {"account": "test_account", "role": "test_role"} + + with patch("rsconnect.api.generate_jwt") as mock_generate_jwt: + mock_generate_jwt.return_value = "mocked_jwt" + payload = server.fmt_payload() + + assert "scope=session%3Arole%3Atest_role+spcs.example.com" in payload + assert "assertion=mocked_jwt" in payload + assert "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer" in payload + + mock_get_connection_parameters.assert_called_once_with("example_connection") + mock_generate_jwt.assert_called_once_with("example_connection") + + @patch("rsconnect.api.get_connection_parameters") + def test_fmt_payload_with_none_params(self, mock_get_connection_parameters): + server = SPCSConnectServer("https://spcs.example.com", "example_connection") + mock_get_connection_parameters.return_value = None + with pytest.raises(RSConnectException, match="No Snowflake connection found."): + server.fmt_payload() + + @patch("rsconnect.api.HTTPServer") + @patch("rsconnect.api.SPCSConnectServer.token_endpoint") + @patch("rsconnect.api.SPCSConnectServer.fmt_payload") + def test_exchange_token_success(self, mock_fmt_payload, mock_token_endpoint, mock_http_server): + server = SPCSConnectServer("https://spcs.example.com", "example_connection") + + # Mock the HTTP request + mock_server_instance = mock_http_server.return_value + mock_response = Mock() + mock_response.status = 200 + mock_response.response_body = "token_data" + mock_server_instance.request.return_value = mock_response + + # Mock the token endpoint and payload + mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/" + mock_fmt_payload.return_value = "mocked_payload" + + # Call the method + result = server.exchange_token() + + # Verify the results + assert result == "token_data" + mock_http_server.assert_called_once_with(url="https://example.snowflakecomputing.com/") + mock_server_instance.request.assert_called_once_with( + method="POST", + path="/oauth/token", + body="mocked_payload", + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + @patch("rsconnect.api.HTTPServer") + @patch("rsconnect.api.SPCSConnectServer.token_endpoint") + @patch("rsconnect.api.SPCSConnectServer.fmt_payload") + def test_exchange_token_error_status(self, mock_fmt_payload, mock_token_endpoint, mock_http_server): + server = SPCSConnectServer("https://spcs.example.com", "example_connection") + + # Mock the HTTP request with error status + mock_server_instance = mock_http_server.return_value + mock_response = Mock() + mock_response.status = 401 + mock_response.full_uri = "https://example.snowflakecomputing.com/oauth/token" + mock_response.reason = "Unauthorized" + mock_server_instance.request.return_value = mock_response + + # Mock the token endpoint and payload + mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/" + mock_fmt_payload.return_value = "mocked_payload" + + # Call the method and verify it raises the expected exception + with pytest.raises(RSConnectException, match="Failed to exchange Snowflake token"): + server.exchange_token() + + @patch("rsconnect.api.HTTPServer") + @patch("rsconnect.api.SPCSConnectServer.token_endpoint") + @patch("rsconnect.api.SPCSConnectServer.fmt_payload") + def test_exchange_token_empty_response(self, mock_fmt_payload, mock_token_endpoint, mock_http_server): + server = SPCSConnectServer("https://spcs.example.com", "example_connection") + + # Mock the HTTP request with empty response body + mock_server_instance = mock_http_server.return_value + mock_response = Mock() + mock_response.status = 200 + mock_response.response_body = None + mock_server_instance.request.return_value = mock_response + + # Mock the token endpoint and payload + mock_token_endpoint.return_value = "https://example.snowflakecomputing.com/" + mock_fmt_payload.return_value = "mocked_payload" + + # Call the method and verify it raises the expected exception + with pytest.raises( + RSConnectException, match="Failed to exchange Snowflake token: Token exchange returned empty response" + ): + server.exchange_token() diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 6b1e5342..24ae5a84 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -31,7 +31,8 @@ def setUp(self): token="someToken", secret="c29tZVNlY3JldAo=", ) - self.assertEqual(len(self.server_store.get_all_servers()), 3, "Unexpected servers after setup") + self.server_store.set("qux", "https://example.snowflakecomputing.app", snowflake_connection_name="dev") + self.assertEqual(len(self.server_store.get_all_servers()), 4, "Unexpected servers after setup") def tearDown(self): # clean up our temp test directory created with tempfile.mkdtemp() @@ -71,6 +72,11 @@ def test_add(self): ), ) + self.assertEqual( + self.server_store.get_by_name("qux"), + dict(name="qux", url="https://example.snowflakecomputing.app", snowflake_connection_name="dev"), + ) + def test_remove_by_name(self): self.server_store.remove_by_name("foo") self.assertIsNone(self.server_store.get_by_name("foo")) @@ -87,19 +93,21 @@ def test_remove_by_url(self): def test_remove_not_found(self): self.assertFalse(self.server_store.remove_by_name("frazzle")) - self.assertEqual(len(self.server_store.get_all_servers()), 3) + self.assertEqual(len(self.server_store.get_all_servers()), 4) self.assertFalse(self.server_store.remove_by_url("http://frazzle")) - self.assertEqual(len(self.server_store.get_all_servers()), 3) + self.assertEqual(len(self.server_store.get_all_servers()), 4) def test_list(self): servers = self.server_store.get_all_servers() - self.assertEqual(len(servers), 3) + self.assertEqual(len(servers), 4) self.assertEqual(servers[0]["name"], "bar") self.assertEqual(servers[0]["url"], "http://connect.remote") self.assertEqual(servers[1]["name"], "baz") self.assertEqual(servers[1]["url"], "https://shinyapps.io") self.assertEqual(servers[2]["name"], "foo") self.assertEqual(servers[2]["url"], "http://connect.local") + self.assertEqual(servers[3]["name"], "qux") + self.assertEqual(servers[3]["url"], "https://example.snowflakecomputing.app") def check_resolve_call(self, name, server, api_key, insecure, ca_cert, should_be_from_store): server_data = self.server_store.resolve(name, server) @@ -124,6 +132,7 @@ def test_resolve_by_default(self): # with only a single entry, server None will resolve to that entry self.server_store.remove_by_url("http://connect.remote") self.server_store.remove_by_url("https://shinyapps.io") + self.server_store.remove_by_name("qux") self.check_resolve_call(None, None, None, None, None, True) def test_resolve_from_args(self): diff --git a/tests/test_snowflake.py b/tests/test_snowflake.py new file mode 100644 index 00000000..0bbf0267 --- /dev/null +++ b/tests/test_snowflake.py @@ -0,0 +1,330 @@ +import json +import logging +import sys +from subprocess import CalledProcessError + +import pytest +from pytest import LogCaptureFixture, MonkeyPatch + +from rsconnect.exception import RSConnectException +from rsconnect.snowflake import ( + ensure_snow_installed, + generate_jwt, + get_connection_parameters, + list_connections, +) + +SAMPLE_CONNECTIONS = [ + { + "connection_name": "dev", + "parameters": { + "account": "example-dev-acct", + "user": "alice@example.com", + "database": "EXAMPLE_DB", + "warehouse": "DEV_WH", + "role": "ACCOUNTADMIN", + "authenticator": "SNOWFLAKE_JWT", + }, + "is_default": False, + }, + { + "connection_name": "prod", + "parameters": { + "account": "example-prod-acct", + "user": "alice@example.com", + "database": "EXAMPLE_DB_PROD", + "schema": "DATA", + "warehouse": "DEFAULT_WH", + "role": "DEVELOPER", + "authenticator": "SNOWFLAKE_JWT", + "private_key_file": "/home/alice/snowflake/rsa_key.p8", + }, + "is_default": True, + }, +] + + +@pytest.fixture(autouse=True) +def setup_caplog(caplog: LogCaptureFixture): + # Set the log level to debug to capture all logs + caplog.set_level(logging.DEBUG) + + +def test_ensure_snow_installed_success(monkeypatch: MonkeyPatch): + # Test when snowflake-cli is installed - simpler approach + # Just check that the function doesn't raise an exception + + # Let's directly mock snowflake.cli to simulate it being installed + # Create a fake module to return + class MockModule: + pass + + # Create a fake snowflake module with a cli attribute + mock_snowflake = MockModule() + mock_snowflake.cli = MockModule() + + # Add to sys.modules before test + sys.modules["snowflake"] = mock_snowflake + sys.modules["snowflake.cli"] = mock_snowflake.cli + + try: + # Should not raise an exception + ensure_snow_installed() + # If we get here, test passes + assert True + finally: + # Clean up + if "snowflake" in sys.modules: + del sys.modules["snowflake"] + if "snowflake.cli" in sys.modules: + del sys.modules["snowflake.cli"] + + +class MockRunResult: + def __init__(self, returncode: int = 0): + self.returncode = returncode + + +def test_ensure_snow_installed_binary(monkeypatch: MonkeyPatch, caplog: LogCaptureFixture): + # Test when import fails but snow binary is available + + monkeypatch.setattr("builtins.__import__", mock_failed_import) + + # Mock run to return success + def mock_run(cmd: list[str], **kwargs): + assert cmd == ["snow", "--version"] + assert kwargs.get("capture_output") is True + assert kwargs.get("check") is True + return MockRunResult(returncode=0) + + monkeypatch.setattr("rsconnect.snowflake.run", mock_run) + + # Should not raise exception + ensure_snow_installed() + + # Verify log message + assert "snowflake-cli is not installed" in caplog.text + + +def test_ensure_snow_installed_nobinary(monkeypatch: MonkeyPatch, caplog: LogCaptureFixture): + # Test when import fails and snow binary is not found + + # Remove snowflake modules if they exist + monkeypatch.delitem(sys.modules, "snowflake.cli", raising=False) + monkeypatch.delitem(sys.modules, "snowflake", raising=False) + + monkeypatch.setattr("builtins.__import__", mock_failed_import) + + # Mock run to raise FileNotFoundError + def mock_run(cmd: list[str], **kwargs): + if cmd == ["snow", "--version"]: + raise FileNotFoundError("No such file or directory: 'snow'") + return MockRunResult(returncode=0) + + monkeypatch.setattr("rsconnect.snowflake.run", mock_run) + + with pytest.raises(RSConnectException) as excinfo: + ensure_snow_installed() + + assert "snow cannot be found" in str(excinfo.value) + + # Verify log message + assert "snowflake-cli is not installed" in caplog.text + + +def test_ensure_snow_installed_failing_binary(monkeypatch: MonkeyPatch, caplog: LogCaptureFixture): + # Test when import fails and snow binary exits with error + + # Remove snowflake modules if they exist + monkeypatch.delitem(sys.modules, "snowflake.cli", raising=False) + monkeypatch.delitem(sys.modules, "snowflake", raising=False) + + monkeypatch.setattr("builtins.__import__", mock_failed_import) + + # Mock run to raise CalledProcessError + def mock_run(cmd: list[str], **kwargs): + if cmd == ["snow", "--version"]: + raise CalledProcessError(returncode=1, cmd=cmd, output="", stderr="Command failed with exit code 1") + return MockRunResult(returncode=0) + + monkeypatch.setattr("rsconnect.snowflake.run", mock_run) + + with pytest.raises(RSConnectException) as excinfo: + ensure_snow_installed() + + assert "snow is installed but could not be run" in str(excinfo.value) + + # Verify log message + assert "snowflake-cli is not installed" in caplog.text + + +# Patch the import to raise ImportError +original_import = __import__ + + +def mock_failed_import(name: str, *args, **kwargs): + if name.startswith("snowflake"): + raise ImportError(f"No module named '{name}'") + return original_import(name, *args, **kwargs) + + +def test_list_connections(monkeypatch: MonkeyPatch): + + class MockCompletedProcess: + returncode = 0 + stdout = json.dumps(SAMPLE_CONNECTIONS) + + def mock_snow(*args): + assert args == ("connection", "list", "--format", "json") + return MockCompletedProcess() + + monkeypatch.setattr("rsconnect.snowflake.snow", mock_snow) + + connections = list_connections() + + assert len(connections) == 2 + assert connections[1]["is_default"] is True + + +def test_get_connection_noname_default(monkeypatch: MonkeyPatch): + # Test that get_connection_parameters() returns parameters from + # the default connection when no name is provided + + monkeypatch.setattr("rsconnect.snowflake.list_connections", lambda: SAMPLE_CONNECTIONS) + monkeypatch.setattr("rsconnect.snowflake.ensure_snow_installed", lambda: None) + + connection = get_connection_parameters() + + assert connection["account"] == "example-prod-acct" + assert connection["role"] == "DEVELOPER" + + +def test_get_connection_named(monkeypatch: MonkeyPatch): + # Test that get_connection_parameters() returns the specified connection when a name is provided + + monkeypatch.setattr("rsconnect.snowflake.list_connections", lambda: SAMPLE_CONNECTIONS) + monkeypatch.setattr("rsconnect.snowflake.ensure_snow_installed", lambda: None) + + connection = get_connection_parameters("dev") + + # Should return the connection with the specified name + assert connection["account"] == "example-dev-acct" + assert connection["role"] == "ACCOUNTADMIN" + + +def test_get_connection_errs_if_none(monkeypatch: MonkeyPatch): + # Test that get_connection_parameters() raises an exception when no matching connection is found + + # Test with empty connections list + monkeypatch.setattr("rsconnect.snowflake.list_connections", lambda: []) + monkeypatch.setattr("rsconnect.snowflake.ensure_snow_installed", lambda: None) + + with pytest.raises(RSConnectException) as excinfo: + get_connection_parameters() + assert "No Snowflake connections found" in str(excinfo.value) + + # Test with connections but non-existent name + monkeypatch.setattr("rsconnect.snowflake.list_connections", lambda: SAMPLE_CONNECTIONS) + + with pytest.raises(RSConnectException) as excinfo: + get_connection_parameters("nexiste") + assert "No Snowflake connection found with name 'nexiste'" in str(excinfo.value) + + +def test_generate_jwt(monkeypatch: MonkeyPatch): + """Test the JWT generation for Snowflake connections.""" + # Mock the generate_jwt subprocess call + sample_jwt = '{"message": "header.payload.signature"}' + + class MockSnowGenerateJWT: + returncode = 0 + stdout = sample_jwt + + def mock_snow(*args): + assert args[0:3] == ("connection", "generate-jwt", "--connection") + + # Check which connection we're generating a JWT for + conn_name = args[3] + + # Empty string means default connection + if conn_name == "": + return MockSnowGenerateJWT() + elif conn_name == "dev": + return MockSnowGenerateJWT() + elif conn_name == "prod": + return MockSnowGenerateJWT() + else: + raise CalledProcessError( + returncode=1, + cmd=["snow"] + list(args), + output="", + stderr=f"Error: No connection found with name '{conn_name}'", + ) + + monkeypatch.setattr("rsconnect.snowflake.snow", mock_snow) + monkeypatch.setattr("rsconnect.snowflake.list_connections", lambda: SAMPLE_CONNECTIONS) + + # Case 1: Test with default connection (no name parameter) + jwt = generate_jwt() + assert jwt == "header.payload.signature" + + # Case 2: Test with a valid connection name + jwt = generate_jwt("dev") + assert jwt == "header.payload.signature" + + # Case 3: Test with an invalid connection name + with pytest.raises(RSConnectException) as excinfo: + generate_jwt("nexiste") + assert "No Snowflake connection found with name 'nexiste'" in str(excinfo.value) + + +def test_generate_jwt_command_failure(monkeypatch: MonkeyPatch): + """Test error handling when snow command fails.""" + + def mock_snow(*args): + raise CalledProcessError( + returncode=1, cmd=["snow"] + list(args), output="", stderr="Error: Authentication failed" + ) + + monkeypatch.setattr("rsconnect.snowflake.snow", mock_snow) + monkeypatch.setattr("rsconnect.snowflake.get_connection_parameters", lambda name=None: {}) + + with pytest.raises(RSConnectException) as excinfo: + generate_jwt() + assert "Failed to generate JWT" in str(excinfo.value) + + +def test_generate_jwt_invalid_json(monkeypatch: MonkeyPatch): + """Test handling of invalid JSON output.""" + + class MockProcessInvalidJSON: + returncode = 0 + stdout = "Not a JSON string" + + def mock_snow(*args): + return MockProcessInvalidJSON() + + monkeypatch.setattr("rsconnect.snowflake.snow", mock_snow) + monkeypatch.setattr("rsconnect.snowflake.get_connection_parameters", lambda name=None: {}) + + with pytest.raises(RSConnectException) as excinfo: + generate_jwt() + assert "Failed to parse JSON" in str(excinfo.value) + + +def test_generate_jwt_missing_message(monkeypatch: MonkeyPatch): + """Test handling of JSON without the expected message field.""" + + class MockProcessNoMessage: + returncode = 0 + stdout = '{"status": "success", "data": {}}' + + def mock_snow(*args): + return MockProcessNoMessage() + + monkeypatch.setattr("rsconnect.snowflake.snow", mock_snow) + monkeypatch.setattr("rsconnect.snowflake.get_connection_parameters", lambda name=None: {}) + + with pytest.raises(RSConnectException) as excinfo: + generate_jwt() + assert "Failed to generate JWT" in str(excinfo.value)