diff --git a/pyproject.toml b/pyproject.toml index 45ef703f..587add6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ test = [ "twine", "types-Flask", ] +snowflake = ["snowflake-cli"] [project.urls] Repository = "http://github.com/posit-dev/rsconnect-python" diff --git a/rsconnect/actions.py b/rsconnect/actions.py index 2c57b664..917588e6 100644 --- a/rsconnect/actions.py +++ b/rsconnect/actions.py @@ -172,6 +172,14 @@ def test_rstudio_server(server: api.PositServer): raise RSConnectException("Failed to verify with {} ({}).".format(server.remote_name, exc)) +def test_spcs_server(server: api.SPCSConnectServer): + with api.RSConnectClient(server) as client: + try: + client.me() + except RSConnectException as exc: + raise RSConnectException("Failed to verify with {} ({}).".format(server.remote_name, exc)) + + def test_api_key(connect_server: api.RSConnectServer) -> str: """ Test that an API Key may be used to authenticate with the given Posit Connect server. diff --git a/rsconnect/api.py b/rsconnect/api.py index 4d455876..df8ba85a 100644 --- a/rsconnect/api.py +++ b/rsconnect/api.py @@ -235,9 +235,33 @@ def __init__( self.ca_data = ca_data # This is specifically not None. self.cookie_jar = CookieJar() + # 🤡 + self.snowflake_connection_name = None -TargetableServer = typing.Union[ShinyappsServer, RSConnectServer, CloudServer] +class SPCSConnectServer(AbstractRemoteServer): + """ + A class encapsulating the information required to interact with Connect in SPCS. + """ + + def __init__( + self, + url: str, + snowflake_connection_name: Optional[str], + insecure: bool = False, + ca_data: Optional[str | bytes] = None, + ): + super().__init__(url, "Posit Connect") + self.insecure = insecure + self.ca_data = ca_data + self.snowflake_connection_name = snowflake_connection_name + # for compatibility with RSConnectClient + self.cookie_jar = CookieJar() + self.api_key = None + self.bootstrap_jwt = None + + +TargetableServer = typing.Union[ShinyappsServer, RSConnectServer, CloudServer, SPCSConnectServer] class S3Server(AbstractRemoteServer): @@ -254,7 +278,7 @@ class RSConnectClientDeployResult(TypedDict): class RSConnectClient(HTTPServer): - def __init__(self, server: RSConnectServer, cookies: Optional[CookieJar] = None): + def __init__(self, server: RSConnectServer | SPCSConnectServer, cookies: Optional[CookieJar] = None): if cookies is None: cookies = server.cookie_jar super().__init__( @@ -271,6 +295,14 @@ def __init__(self, server: RSConnectServer, cookies: Optional[CookieJar] = None) if server.bootstrap_jwt: self.bootstrap_authorization(server.bootstrap_jwt) + if server.snowflake_connection_name: + from .snowflake import SnowflakeExchangeClient, get_token_endpoint + + token_endpoint = get_token_endpoint(server.snowflake_connection_name) + snowflake_client = SnowflakeExchangeClient(token_endpoint) + token = snowflake_client.exchange_token(server.url, server.snowflake_connection_name) + self.snowflake_authorization(token) + def _tweak_response(self, response: HTTPResponse) -> JsonData | HTTPResponse: return ( response.json_data @@ -555,6 +587,7 @@ def __init__( name: Optional[str] = None, url: Optional[str] = None, api_key: Optional[str] = None, + snowflake_connection_name: Optional[str] = None, insecure: bool = False, cacert: Optional[str] = None, ca_data: Optional[str | bytes] = None, @@ -604,6 +637,7 @@ def __init__( name=name, url=url or server, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, ca_data=ca_data, @@ -689,6 +723,7 @@ def setup_remote_server( name: Optional[str] = None, url: Optional[str] = None, api_key: Optional[str] = None, + snowflake_connection_name: Optional[str] = None, insecure: bool = False, cacert: Optional[str] = None, ca_data: Optional[str | bytes] = None, @@ -700,6 +735,7 @@ def setup_remote_server( ctx=ctx, url=url, api_key=api_key, + snowflake_connection_name=snowflake_connection_name, insecure=insecure, cacert=cacert, account_name=account_name, @@ -741,12 +777,16 @@ def setup_remote_server( account_name = server_data.account_name or account_name token = server_data.token or token secret = server_data.secret or secret + snowflake_connection_name = server_data.snowflake_connection_name or snowflake_connection_name self.is_server_from_store = server_data.from_store if api_key: url = cast(str, url) self.remote_server = RSConnectServer(url, api_key, insecure, ca_data) + elif snowflake_connection_name: + url = cast(str, url) + self.remote_server = SPCSConnectServer(url, snowflake_connection_name) elif token and secret: if url and ("rstudio.cloud" in url or "posit.cloud" in url): account_name = cast(str, account_name) @@ -761,6 +801,8 @@ def setup_remote_server( def setup_client(self, cookies: Optional[CookieJar] = None): if isinstance(self.remote_server, RSConnectServer): self.client = RSConnectClient(self.remote_server, cookies) + elif isinstance(self.remote_server, SPCSConnectServer): + self.client = RSConnectClient(self.remote_server) elif isinstance(self.remote_server, PositServer): self.client = PositClient(self.remote_server) else: @@ -774,7 +816,9 @@ def validate_server(self): """ Validate that there is enough information to talk to shinyapps.io or a Connect server. """ - if isinstance(self.remote_server, RSConnectServer): + if isinstance(self.remote_server, SPCSConnectServer): + self.validate_spcs_server() + elif isinstance(self.remote_server, RSConnectServer): self.validate_connect_server() elif isinstance(self.remote_server, PositServer): self.validate_posit_server() @@ -815,6 +859,23 @@ def validate_connect_server(self): return self + def validate_spcs_server(self): + if not isinstance(self.remote_server, SPCSConnectServer): + raise RSConnectException("remote_server must be a Connect server in SPCS") + + url = self.remote_server.url + snowflake_connection_name = self.remote_server.snowflake_connection_name + server = SPCSConnectServer(url, snowflake_connection_name) + + with RSConnectClient(server) as client: + try: + result = client.me() + result = server.handle_bad_response(result) + except RSConnectException as exc: + raise RSConnectException(f"Failed to verify with {server.remote_name} ({exc})") + + return self + def validate_posit_server(self): if not isinstance(self.remote_server, PositServer): raise RSConnectException("remote_server is not a Posit server.") @@ -885,7 +946,7 @@ def deploy_bundle(self): if self.bundle is None: raise RSConnectException("A bundle must be created before deploying it.") - if isinstance(self.remote_server, RSConnectServer): + if isinstance(self.remote_server, RSConnectServer | SPCSConnectServer): if not isinstance(self.client, RSConnectClient): raise RSConnectException("client must be an RSConnectClient.") result = self.client.deploy( diff --git a/rsconnect/http_support.py b/rsconnect/http_support.py index 28d424a0..2127a1cc 100644 --- a/rsconnect/http_support.py +++ b/rsconnect/http_support.py @@ -199,7 +199,11 @@ def __init__( and self.response_body is not None and len(self.response_body) > 0 ): - self.json_data = json.loads(self.response_body) + try: + self.json_data = json.loads(self.response_body) + # snowflake crudo + except json.decoder.JSONDecodeError: + self.response_body class HTTPServer(object): @@ -256,6 +260,9 @@ def key_authorization(self, key: str): def bootstrap_authorization(self, key: str): self.authorization("Connect-Bootstrap %s" % key) + def snowflake_authorization(self, token: str): + self.authorization('Snowflake Token="%s"' % token) + def _get_full_path(self, path: str): return append_to_path(self._url.path, path) diff --git a/rsconnect/main.py b/rsconnect/main.py index a979c0e4..f4891073 100644 --- a/rsconnect/main.py +++ b/rsconnect/main.py @@ -34,6 +34,7 @@ test_api_key, test_rstudio_server, test_server, + test_spcs_server, validate_quarto_engines, which_quarto, ) @@ -48,8 +49,7 @@ get_content, search_content, ) -from .environment import Environment, fake_module_file_from_directory -from .api import RSConnectClient, RSConnectExecutor, RSConnectServer +from .api import RSConnectClient, RSConnectExecutor, RSConnectServer, SPCSConnectServer from .bundle import ( default_title_from_manifest, make_api_bundle, @@ -57,8 +57,8 @@ make_manifest_bundle, make_notebook_html_bundle, make_notebook_source_bundle, - make_voila_bundle, make_tensorflow_bundle, + make_voila_bundle, read_manifest_app_mode, validate_entry_point, validate_extra_files, @@ -71,6 +71,7 @@ write_tensorflow_manifest_json, write_voila_manifest_json, ) +from .environment import Environment, fake_module_file_from_directory from .exception import RSConnectException from .json_web_token import ( TokenGenerator, @@ -178,6 +179,15 @@ def wrapper(*args: P.args, **kwargs: P.kwargs): return wrapper +def spcs_args(func: Callable[P, T]) -> Callable[P, T]: + @click.option("--snowflake-connection-name", help="The name of the Snowflake connection in the configuration file") + @functools.wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs): + return func(*args, **kwargs) + + return wrapper + + def cloud_shinyapps_args(func: Callable[P, T]) -> Callable[P, T]: @click.option( "--account", @@ -403,6 +413,11 @@ def _test_rstudio_creds(server: api.PositServer): test_rstudio_server(server) +def _test_spcs_creds(server: api.SPCSConnectServer): + with cli_feedback(f"Checking {server.remote_name} credential"): + test_spcs_server(server) + + @cli.command( short_help="Create an initial admin user to bootstrap a Connect instance.", help="Creates an initial admin user to bootstrap a Connect instance. Returns the provisionend API key.", @@ -491,37 +506,8 @@ def bootstrap( ), no_args_is_help=True, ) -@click.option("--name", "-n", required=True, help="The nickname of the Posit Connect server to deploy to.") -@click.option( - "--server", - "-s", - envvar="CONNECT_SERVER", - help="The URL for the Posit Connect server to deploy to, OR \ -rstudio.cloud OR shinyapps.io. (Also settable via CONNECT_SERVER \ -environment variable.)", -) -@click.option( - "--api-key", - "-k", - envvar="CONNECT_API_KEY", - help="The API key to use to authenticate with Posit Connect. \ -(Also settable via CONNECT_API_KEY environment variable.)", -) -@click.option( - "--insecure", - "-i", - envvar="CONNECT_INSECURE", - is_flag=True, - help="Disable TLS certification/host validation. (Also settable via CONNECT_INSECURE environment variable.)", -) -@click.option( - "--cacert", - "-c", - envvar="CONNECT_CA_CERTIFICATE", - type=click.Path(exists=True, file_okay=True, dir_okay=False), - help="The path to trusted TLS CA certificates. (Also settable via CONNECT_CA_CERTIFICATE environment variable.)", -) -@click.option("--verbose", "-v", count=True, help="Enable verbose output. Use -vv for very verbose (debug) output.") +@server_args +@spcs_args @cloud_shinyapps_args @click.pass_context def add( @@ -529,6 +515,7 @@ def add( name: str, server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], account: Optional[str], @@ -548,6 +535,7 @@ def add( account_name=account, token=token, secret=secret, + snowflake_connection_name=snowflake_connection_name, ) # The validation.validate_connection_options() function ensures that certain # combinations of arguments are present; the cast() calls inside of the @@ -580,24 +568,39 @@ def add( else: click.echo('Added {} credential "{}".'.format(real_server.remote_name, name)) else: - server = cast(str, server) - api_key = cast(str, api_key) - # If we're in this code path - # Server must be pingable and the API key must work to be added. - real_server_rsc, _ = _test_server_and_api(server, api_key, insecure, cacert) - server_store.set( - name, - real_server_rsc.url, - real_server_rsc.api_key, - real_server_rsc.insecure, - real_server_rsc.ca_data, - ) + if server and ("snowflakecomputing.app" in server or snowflake_connection_name): + + real_server_spcs = api.SPCSConnectServer(server, snowflake_connection_name) + + _test_spcs_creds(real_server_spcs) + + server_store.set(name, server, snowflake_connection_name=snowflake_connection_name) + if old_server: + click.echo('Updated {} credential "{}".'.format(real_server_spcs.remote_name, name)) + else: + click.echo('Added {} credential "{}".'.format(real_server_spcs.remote_name, name)) - if old_server: - click.echo('Updated Connect server "%s" with URL %s' % (name, real_server_rsc.url)) else: - click.echo('Added Connect server "%s" with URL %s' % (name, real_server_rsc.url)) + + server = cast(str, server) + api_key = cast(str, api_key) + # If we're in this code path + # Server must be pingable and the API key must work to be added. + real_server_rsc, _ = _test_server_and_api(server, api_key, insecure, cacert) + + server_store.set( + name, + real_server_rsc.url, + api_key=real_server_rsc.api_key, + insecure=real_server_rsc.insecure, + ca_data=real_server_rsc.ca_data, + ) + + if old_server: + click.echo('Updated Connect server "%s" with URL %s' % (name, real_server_rsc.url)) + else: + click.echo('Added Connect server "%s" with URL %s' % (name, real_server_rsc.url)) @cli.command( @@ -641,6 +644,7 @@ def list_servers(verbose: int): no_args_is_help=True, ) @server_args +@spcs_args @cli_exception_handler @click.pass_context def details( @@ -648,19 +652,20 @@ def details( name: Optional[str], server: Optional[str], api_key: Optional[str], + snowflake_connection_name: Optional[str], insecure: bool, cacert: Optional[str], verbose: int, ): set_verbosity(verbose) + ce = RSConnectExecutor(ctx, name, server, api_key, snowflake_connection_name, insecure, cacert).validate_server() - ce = RSConnectExecutor(ctx, name, server, api_key, insecure, cacert).validate_server() - if not isinstance(ce.remote_server, RSConnectServer): + if not isinstance(ce.remote_server, (RSConnectServer, SPCSConnectServer)): raise RSConnectException("`rsconnect details` requires a Posit Connect server.") click.echo(" Posit Connect URL: %s" % ce.remote_server.url) - if not ce.remote_server.api_key: + if not (ce.remote_server.api_key or ce.remote_server.snowflake_connection_name): return with cli_feedback("Gathering details"): diff --git a/rsconnect/metadata.py b/rsconnect/metadata.py index b1b04780..d47d11b5 100644 --- a/rsconnect/metadata.py +++ b/rsconnect/metadata.py @@ -244,6 +244,7 @@ class ServerDataDict(TypedDict): name: str url: str api_key: NotRequired[str] + snowflake_connection_name: NotRequired[str] insecure: NotRequired[bool] ca_cert: NotRequired[str] account_name: NotRequired[str] @@ -263,6 +264,7 @@ def __init__( url: str, from_store: bool, api_key: Optional[str] = None, + snowflake_connection_name: Optional[str] = None, insecure: Optional[bool] = None, ca_data: Optional[str] = None, account_name: Optional[str] = None, @@ -273,6 +275,7 @@ def __init__( self.url = url self.from_store = from_store self.api_key = api_key + self.snowflake_connection_name = snowflake_connection_name self.insecure = insecure self.ca_data = ca_data self.account_name = account_name @@ -320,6 +323,7 @@ def set( name: str, url: str, api_key: Optional[str] = None, + snowflake_connection_name: Optional[str] = None, insecure: Optional[bool] = False, ca_data: Optional[str] = None, account_name: Optional[str] = None, @@ -332,6 +336,7 @@ def set( :param name: the nickname for the Connect server. :param url: the full URL for the Connect server. :param api_key: the API key to use to authenticate with the Connect server. + :param snowflake_connection_name: the snowflake connection name :param insecure: a flag to disable TLS verification. :param ca_data: client side certificate data to use for TLS. :param account_name: shinyapps.io account name. @@ -344,6 +349,8 @@ def set( } if api_key: target_data = dict(api_key=api_key, insecure=insecure, ca_cert=ca_data) + elif snowflake_connection_name: + target_data = dict(snowflake_connection_name=snowflake_connection_name) elif account_name: target_data = dict(account_name=account_name, token=token, secret=secret) else: @@ -406,6 +413,7 @@ def resolve(self, name: Optional[str], url: Optional[str]) -> ServerData: name, entry["url"], True, + snowflake_connection_name=entry.get("snowflake_connection_name"), insecure=entry.get("insecure"), ca_data=entry.get("ca_cert"), api_key=entry.get("api_key"), diff --git a/rsconnect/snowflake.py b/rsconnect/snowflake.py new file mode 100644 index 00000000..e6c9c30e --- /dev/null +++ b/rsconnect/snowflake.py @@ -0,0 +1,132 @@ +import json +import subprocess +from typing import Any, Dict, Optional, cast +from urllib.parse import urlencode, urlparse + +from .exception import RSConnectException +from .http_support import HTTPResponse, HTTPServer + + +def is_snow_installed() -> bool: + try: + import snowflake.cli # noqa + + return True + except ImportError: + try: + subprocess.run(["snow", "--help"], capture_output=True) + return True + except OSError: + return False + + +def list_connections(): + + if not is_snow_installed(): + raise RSConnectException( + "The snowflake-cli is required but not installed." + "Install it with 'pip install rsconnect-python[snowflake]'" + ) + snow_cx_list = subprocess.run( + ["snow", "connection", "list", "--format", "json"], + capture_output=True, + text=True, + check=True, + ) + connection_list = json.loads(snow_cx_list.stdout) + return connection_list + + +def get_connection(name: Optional[str] = None) -> Optional[Dict[str, Any]]: + connection_list = list_connections() + + if not name: + return next((x["parameters"] for x in connection_list if x.get("is_default")), None) + else: + return next((x["parameters"] for x in connection_list if x.get("connection_name") == name), None) + + +def get_jwt(snowflake_connection_name: Optional[str] = None): + connection_name = "" if snowflake_connection_name is None else snowflake_connection_name + snow_cx_jwt = subprocess.run( + args=["snow", "connection", "generate-jwt", "--connection", connection_name, "--format", "json"], + capture_output=True, + text=True, + check=True, + ) + output = json.loads(snow_cx_jwt.stdout) + jwt = output.get("message") + return jwt + + +def get_token_endpoint(snowflake_connection_name: Optional[str] = None) -> str: + cx = get_connection(snowflake_connection_name) + if cx is None: + raise RSConnectException("No Snowflake connection found") + + return "https://{}.snowflakecomputing.com/".format(cx["account"]) + + +class SnowflakeExchangeClient(HTTPServer): + + def fmt_payload(self, spcs_endpoint: str, snowflake_connection_name: Optional[str] = None): + cx = get_connection(snowflake_connection_name) + if cx is None: + raise RSConnectException("No Snowflake connection found") + spcs_url = urlparse(spcs_endpoint) + + scope = "session:role:{} {}".format(cx["role"], spcs_url.netloc) + jwt = get_jwt(snowflake_connection_name) + grant_type = "urn:ietf:params:oauth:grant-type:jwt-bearer" + + payload = {"scope": scope, "assertion": jwt, "grant_type": grant_type} + payload = urlencode(payload) + return payload + + def exchange_token(self, spcs_endpoint: str, snowflake_connection_name: Optional[str] = None) -> str: + """ + Exchange Snowflake JWT for an OAuth token. + + Args: + spcs_endpoint: The SPCS endpoint URL + snowflake_connection_name: Optional name of the Snowflake connection + + Returns: + The OAuth token response or None if the exchange fails + + Raises: + RSConnectException: If the token exchange fails + """ + try: + payload = self.fmt_payload(spcs_endpoint, snowflake_connection_name) + + response = self.request( + method="POST", + path="/oauth/token", + body=payload, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + response = cast(HTTPResponse, response) + + # borrowed from AbstractRemoteServer.handle_bad_response + # since we don't want to pick up its json decoding assumptions + if response.status < 200 or response.status > 299: + raise RSConnectException( + "Received an unexpected response from %s (calling %s): %s %s" + % ( + self._url, + response.full_uri, + response.status, + response.reason, + ) + ) + + # Validate response body exists + if not response.response_body: + raise RSConnectException("Token exchange returned empty response") + + return response.response_body + + except RSConnectException as e: + raise RSConnectException(f"Failed to exchange Snowflake token: {str(e)}") from e diff --git a/rsconnect/validation.py b/rsconnect/validation.py index 8c5f455d..bb622fa8 100644 --- a/rsconnect/validation.py +++ b/rsconnect/validation.py @@ -45,6 +45,7 @@ def validate_connection_options( token: Optional[str], secret: Optional[str], name: Optional[str] = None, + snowflake_connection_name: Optional[str] = None, ): """ Validates provided Connect or shinyapps.io connection options and returns which target to use given the provided @@ -63,6 +64,7 @@ def validate_connection_options( -T/--token or SHINYAPPS_TOKEN or RSCLOUD_TOKEN -S/--secret or SHINYAPPS_SECRET or RSCLOUD_SECRET -A/--account or SHINYAPPS_ACCOUNT + --snowflake-connection-name FAILURE if any of: -k/--api-key or CONNECT_API_KEY @@ -72,6 +74,7 @@ def validate_connection_options( -T/--token or SHINYAPPS_TOKEN or RSCLOUD_TOKEN -S/--secret or SHINYAPPS_SECRET or RSCLOUD_SECRET -A/--account or SHINYAPPS_ACCOUNT + --snowflake-connection-name FAILURE if specify -s/--server or CONNECT_SERVER and it includes "posit.cloud" or "rstudio.cloud" and not specified all of following: @@ -82,10 +85,15 @@ def validate_connection_options( -T/--token or SHINYAPPS_TOKEN or RSCLOUD_TOKEN -S/--secret or SHINYAPPS_SECRET or RSCLOUD_SECRET -A/--account or SHINYAPPS_ACCOUNT + + FAILURE if -s/--server or CONNECT_SERVER include "snowflakecomputing.app" + and not + --snowflake-connection-name """ connect_options = {"-k/--api-key": api_key, "-i/--insecure": insecure, "-c/--cacert": cacert} shinyapps_options = {"-T/--token": token, "-S/--secret": secret, "-A/--account": account_name} cloud_options = {"-T/--token": token, "-S/--secret": secret} + spcs_options = {"--snowflake-connection-name": snowflake_connection_name} options_mutually_exclusive_with_name = {"-s/--server": url, **shinyapps_options} present_options_mutually_exclusive_with_name = _get_present_options(options_mutually_exclusive_with_name, ctx) @@ -105,11 +113,25 @@ def validate_connection_options( present_connect_options = _get_present_options(connect_options, ctx) present_shinyapps_options = _get_present_options(shinyapps_options, ctx) present_cloud_options = _get_present_options(cloud_options, ctx) + present_spcs_options = _get_present_options(spcs_options, ctx) if present_connect_options and present_shinyapps_options: raise RSConnectException( f"Connect options ({', '.join(present_connect_options)}) may not be passed \ alongside shinyapps.io or Posit Cloud options ({', '.join(present_shinyapps_options)}). \ +See command help for further details." + ) + + if snowflake_connection_name and not url: + raise RSConnectException( + "--snowflake-connection-name requires -s/--server to be specified. \ +See command help for further details." + ) + + if present_shinyapps_options and present_spcs_options: + raise RSConnectException( + f"Shinyapps.io/Cloud options ({', '.join(present_shinyapps_options)}) may not be passed \ +alongside SPCS options ({', '.join(present_spcs_options)}). \ See command help for further details." )