Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add alisa config
Browse files Browse the repository at this point in the history
lhw362950217 committed Oct 28, 2020
1 parent a11f794 commit 4621bb7
Showing 2 changed files with 150 additions and 0 deletions.
84 changes: 84 additions & 0 deletions python/runtime/dbapi/pyalisa/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2020 The SQLFlow Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import base64
import json
import re
import urllib
from collections import OrderedDict

from six.moves.urllib.parse import parse_qs, urlencode, urlparse


class Config(object):
"""Alisa config object, this can be parsed from an alisa dsn
Args:
url(string): a connection url like "alisa://user:pwd@host/path?env=AAB&with=XSE".
There are three required params in the url: current_project, env and with.
The env and with params are maps, which is dumpped to json and then encoded
in base64 format, that is: env=base64(json.dumps({"a":1, "b":2}))
"""
def __init__(self, url):
urlpts = urlparse(url)
kvs = parse_qs(urlpts.query)
required = ["env", "with", "curr_project"]
for k in required:
if k not in kvs:
raise ValueError("Given dsn does not contain: %s" % k)
# extract the param if it's only has one element
for k, v in kvs.items():
if len(v) == 1:
kvs[k] = v[0]

self.pop_access_id = urlpts.username
self.pop_access_secret = urlpts.password
self.pop_url = urlpts.hostname + urlpts.path
self.pop_scheme = urlpts.scheme

self.env = Config._decode_json_base64(kvs["env"])
self.withs = Config._decode_json_base64(kvs["with"])
self.scheme = kvs["scheme"] or "http"
self.verbose = kvs["verbose"] == "true"
self.curr_project = kvs["curr_project"]

@staticmethod
def _encode_json_base64(env):
# We sort the env params to ensure the consistent encoding
jstr = json.dumps(OrderedDict(env))
b64 = base64.urlsafe_b64encode(jstr.encode("utf8")).decode("utf8")
return b64.rstrip("=")

@staticmethod
def _decode_json_base64(b64env):
padded = b64env + "=" * (len(b64env) % 4)
jstr = base64.urlsafe_b64decode(padded).decode("utf8")
return json.loads(jstr)

def to_url(self):
"""Serialize a config to connection url
Returns:
(string) a connection url build from this config
"""
parts = (
self.pop_access_id,
self.pop_access_secret,
self.pop_url,
self.scheme,
"true" if self.verbose else "false",
self.curr_project,
Config._encode_json_base64(self.env),
Config._encode_json_base64(self.withs),
)
return "alisa://%s:%s@%s?scheme=%s&verbose=%s&curr_project=%s&env=%s&with=%s" % parts
66 changes: 66 additions & 0 deletions python/runtime/dbapi/pyalisa/config_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2020 The SQLFlow Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import unittest

from runtime.dbapi.pyalisa.config import Config

test_url = ("alisa://pid:psc@dw.a.hk/?scheme=http&verbose=true&"
"curr_project=jtest_env&env=eyJTS1lORVRfT05EVVRZIjog"
"IlNLWSIsICJTS1lORVRfQUNDRVNTSUQiOiAiU0tZIiwgIlNLWU5"
"FVF9TWVNURU1JRCI6ICJTS1kiLCAiQUxJU0FfVEFTS19JRCI6IC"
"JBTEkiLCAiU0tZTkVUX0VORFBPSU5UIjogIlNLWSIsICJTS1lOR"
"VRfU1lTVEVNX0VOViI6ICJTS1kiLCAiU0tZTkVUX0JJWkRBVEUi"
"OiAiU0tZIiwgIlNLWU5FVF9BQ0NFU1NLRVkiOiAiU0tZIiwgIlNL"
"WU5FVF9QQUNLQUdFSUQiOiAiU0tZIiwgIkFMSVNBX1RBU0tfRVhF"
"Q19UQVJHRVQiOiAiQUxJIn0&with=eyJFeGVjIjogIndlYy5zaCI"
"sICJQbHVnaW5OYW1lIjogIndwZSIsICJDdXN0b21lcklkIjogIndjZCJ9")


class TestConfig(unittest.TestCase):
def test_encode_json_base64(self):
params = dict()
params["key1"] = "val1"
params["key2"] = "val2"
b64 = Config._encode_json_base64(params)
self.assertEqual("eyJrZXkxIjogInZhbDEiLCAia2V5MiI6ICJ2YWwyIn0", b64)

params = Config._decode_json_base64(b64)
self.assertEqual(2, len(params))
self.assertEqual("val1", params["key1"])
self.assertEqual("val2", params["key2"])

def test_dsn_parsing(self):
cfg = Config(test_url)
self.assertEqual("pid", cfg.pop_access_id)
self.assertEqual("psc", cfg.pop_access_secret)
self.assertEqual("jtest_env", cfg.curr_project)
self.assertEqual("http", cfg.scheme)
self.assertEqual("wcd", cfg.withs["CustomerId"])
self.assertEqual("wpe", cfg.withs["PluginName"])
self.assertEqual("wec.sh", cfg.withs["Exec"])
self.assertEqual("SKY", cfg.env["SKYNET_ACCESSKEY"])

def test_to_dsn(self):
cfg = Config(test_url)
url = cfg.to_url()
self.assertEqual(test_url, url)

def test_parse_error(self):
# no env and with
dsn = "alisa://pid:psc@dw.a.hk/?scheme=http&verbose=true"
self.assertRaises(ValueError, lambda: Config(dsn))


if __name__ == "__main__":
unittest.main()

0 comments on commit 4621bb7

Please sign in to comment.