Skip to content

spcs support #651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ test = [
"twine",
"types-Flask",
]
snowflake = ["snowflake-cli"]

[project.urls]
Repository = "http://github.com/posit-dev/rsconnect-python"
Expand Down
8 changes: 8 additions & 0 deletions rsconnect/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,14 @@ def test_rstudio_server(server: api.PositServer):
raise RSConnectException("Failed to verify with {} ({}).".format(server.remote_name, exc))


def test_spcs_server(server: api.SPCSConnectServer):
with api.RSConnectClient(server) as client:
try:
client.me()
except RSConnectException as exc:
raise RSConnectException("Failed to verify with {} ({}).".format(server.remote_name, exc))


def test_api_key(connect_server: api.RSConnectServer) -> str:
"""
Test that an API Key may be used to authenticate with the given Posit Connect server.
Expand Down
69 changes: 65 additions & 4 deletions rsconnect/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,33 @@ def __init__(
self.ca_data = ca_data
# This is specifically not None.
self.cookie_jar = CookieJar()
# 🤡
self.snowflake_connection_name = None

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious as to why this is needed here but not in the other non SPCS targetable servers.



TargetableServer = typing.Union[ShinyappsServer, RSConnectServer, CloudServer]
class SPCSConnectServer(AbstractRemoteServer):
"""
A class encapsulating the information required to interact with Connect in SPCS.
"""

def __init__(
self,
url: str,
snowflake_connection_name: Optional[str],
insecure: bool = False,
ca_data: Optional[str | bytes] = None,
):
super().__init__(url, "Posit Connect")
self.insecure = insecure
self.ca_data = ca_data
self.snowflake_connection_name = snowflake_connection_name
# for compatibility with RSConnectClient
self.cookie_jar = CookieJar()
self.api_key = None
self.bootstrap_jwt = None


TargetableServer = typing.Union[ShinyappsServer, RSConnectServer, CloudServer, SPCSConnectServer]


class S3Server(AbstractRemoteServer):
Expand All @@ -254,7 +278,7 @@ class RSConnectClientDeployResult(TypedDict):


class RSConnectClient(HTTPServer):
def __init__(self, server: RSConnectServer, cookies: Optional[CookieJar] = None):
def __init__(self, server: RSConnectServer | SPCSConnectServer, cookies: Optional[CookieJar] = None):
if cookies is None:
cookies = server.cookie_jar
super().__init__(
Expand All @@ -271,6 +295,14 @@ def __init__(self, server: RSConnectServer, cookies: Optional[CookieJar] = None)
if server.bootstrap_jwt:
self.bootstrap_authorization(server.bootstrap_jwt)

if server.snowflake_connection_name:
from .snowflake import SnowflakeExchangeClient, get_token_endpoint

token_endpoint = get_token_endpoint(server.snowflake_connection_name)
snowflake_client = SnowflakeExchangeClient(token_endpoint)
token = snowflake_client.exchange_token(server.url, server.snowflake_connection_name)
self.snowflake_authorization(token)

def _tweak_response(self, response: HTTPResponse) -> JsonData | HTTPResponse:
return (
response.json_data
Expand Down Expand Up @@ -555,6 +587,7 @@ def __init__(
name: Optional[str] = None,
url: Optional[str] = None,
api_key: Optional[str] = None,
snowflake_connection_name: Optional[str] = None,
insecure: bool = False,
cacert: Optional[str] = None,
ca_data: Optional[str | bytes] = None,
Expand Down Expand Up @@ -604,6 +637,7 @@ def __init__(
name=name,
url=url or server,
api_key=api_key,
snowflake_connection_name=snowflake_connection_name,
insecure=insecure,
cacert=cacert,
ca_data=ca_data,
Expand Down Expand Up @@ -689,6 +723,7 @@ def setup_remote_server(
name: Optional[str] = None,
url: Optional[str] = None,
api_key: Optional[str] = None,
snowflake_connection_name: Optional[str] = None,
insecure: bool = False,
cacert: Optional[str] = None,
ca_data: Optional[str | bytes] = None,
Expand All @@ -700,6 +735,7 @@ def setup_remote_server(
ctx=ctx,
url=url,
api_key=api_key,
snowflake_connection_name=snowflake_connection_name,
insecure=insecure,
cacert=cacert,
account_name=account_name,
Expand Down Expand Up @@ -741,12 +777,16 @@ def setup_remote_server(
account_name = server_data.account_name or account_name
token = server_data.token or token
secret = server_data.secret or secret
snowflake_connection_name = server_data.snowflake_connection_name or snowflake_connection_name

self.is_server_from_store = server_data.from_store

if api_key:
url = cast(str, url)
self.remote_server = RSConnectServer(url, api_key, insecure, ca_data)
elif snowflake_connection_name:
url = cast(str, url)
self.remote_server = SPCSConnectServer(url, snowflake_connection_name)
elif token and secret:
if url and ("rstudio.cloud" in url or "posit.cloud" in url):
account_name = cast(str, account_name)
Expand All @@ -761,6 +801,8 @@ def setup_remote_server(
def setup_client(self, cookies: Optional[CookieJar] = None):
if isinstance(self.remote_server, RSConnectServer):
self.client = RSConnectClient(self.remote_server, cookies)
elif isinstance(self.remote_server, SPCSConnectServer):
self.client = RSConnectClient(self.remote_server)
elif isinstance(self.remote_server, PositServer):
self.client = PositClient(self.remote_server)
else:
Expand All @@ -774,7 +816,9 @@ def validate_server(self):
"""
Validate that there is enough information to talk to shinyapps.io or a Connect server.
"""
if isinstance(self.remote_server, RSConnectServer):
if isinstance(self.remote_server, SPCSConnectServer):
self.validate_spcs_server()
elif isinstance(self.remote_server, RSConnectServer):
self.validate_connect_server()
elif isinstance(self.remote_server, PositServer):
self.validate_posit_server()
Expand Down Expand Up @@ -815,6 +859,23 @@ def validate_connect_server(self):

return self

def validate_spcs_server(self):
if not isinstance(self.remote_server, SPCSConnectServer):
raise RSConnectException("remote_server must be a Connect server in SPCS")

url = self.remote_server.url
snowflake_connection_name = self.remote_server.snowflake_connection_name
server = SPCSConnectServer(url, snowflake_connection_name)

with RSConnectClient(server) as client:
try:
result = client.me()
result = server.handle_bad_response(result)
except RSConnectException as exc:
raise RSConnectException(f"Failed to verify with {server.remote_name} ({exc})")

return self

def validate_posit_server(self):
if not isinstance(self.remote_server, PositServer):
raise RSConnectException("remote_server is not a Posit server.")
Expand Down Expand Up @@ -885,7 +946,7 @@ def deploy_bundle(self):
if self.bundle is None:
raise RSConnectException("A bundle must be created before deploying it.")

if isinstance(self.remote_server, RSConnectServer):
if isinstance(self.remote_server, RSConnectServer | SPCSConnectServer):
if not isinstance(self.client, RSConnectClient):
raise RSConnectException("client must be an RSConnectClient.")
result = self.client.deploy(
Expand Down
9 changes: 8 additions & 1 deletion rsconnect/http_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,11 @@ def __init__(
and self.response_body is not None
and len(self.response_body) > 0
):
self.json_data = json.loads(self.response_body)
try:
self.json_data = json.loads(self.response_body)
# snowflake crudo
except json.decoder.JSONDecodeError:
self.response_body


class HTTPServer(object):
Expand Down Expand Up @@ -256,6 +260,9 @@ def key_authorization(self, key: str):
def bootstrap_authorization(self, key: str):
self.authorization("Connect-Bootstrap %s" % key)

def snowflake_authorization(self, token: str):
self.authorization('Snowflake Token="%s"' % token)

def _get_full_path(self, path: str):
return append_to_path(self._url.path, path)

Expand Down
Loading
Loading