diff --git a/README.md b/README.md index c6502885..2cc79623 100644 --- a/README.md +++ b/README.md @@ -72,8 +72,13 @@ NOTE: `password` and `schema` are optional from sqlalchemy import create_engine from sqlalchemy.schema import Table, MetaData from sqlalchemy.sql.expression import select, text +from trino.sqlalchemy import URL -engine = create_engine('trino://user@localhost:8080/system') +engine = create_engine(URL( + host="localhost", + port=8080, + catalog="system" +)) connection = engine.connect() rows = connection.execute(text("SELECT * FROM runtime.nodes")).fetchall() @@ -93,6 +98,7 @@ Attributes can also be passed in the connection string. ```python from sqlalchemy import create_engine +from trino.sqlalchemy import URL engine = create_engine( 'trino://user@localhost:8080/system', @@ -110,6 +116,14 @@ engine = create_engine( '&client_tags=["tag1", "tag2"]' '&experimental_python_types=true', ) + +# or using the URL factory method +engine = create_engine(URL( + host="localhost", + port=8080, + client_tags=["tag1", "tag2"], + experimental_python_types=True +)) ``` ## Authentication mechanisms diff --git a/tests/unit/sqlalchemy/test_dialect.py b/tests/unit/sqlalchemy/test_dialect.py index b17f8cfe..490d4792 100644 --- a/tests/unit/sqlalchemy/test_dialect.py +++ b/tests/unit/sqlalchemy/test_dialect.py @@ -9,6 +9,7 @@ from trino.dbapi import Connection from trino.sqlalchemy.dialect import CertificateAuthentication, JWTAuthentication, TrinoDialect from trino.transaction import IsolationLevel +from trino.sqlalchemy import URL as trino_url class TestTrinoDialect: @@ -16,20 +17,35 @@ def setup(self): self.dialect = TrinoDialect() @pytest.mark.parametrize( - "url, expected_args, expected_kwargs", + "url, generated_url, expected_args, expected_kwargs", [ ( - make_url("trino://user@localhost"), + make_url(trino_url( + user="user", + host="localhost", + )), + 'trino://user@localhost:8080?source=trino-sqlalchemy', list(), - dict(host="localhost", catalog="system", user="user", source="trino-sqlalchemy"), + dict(host="localhost", catalog="system", user="user", port=8080, source="trino-sqlalchemy"), ), ( - make_url("trino://user@localhost:8080"), + make_url(trino_url( + user="user", + host="localhost", + port=443, + )), + 'trino://user@localhost:443?source=trino-sqlalchemy', list(), - dict(host="localhost", port=8080, catalog="system", user="user", source="trino-sqlalchemy"), + dict(host="localhost", port=443, catalog="system", user="user", source="trino-sqlalchemy"), ), ( - make_url("trino://user:pass@localhost:8080?source=trino-rulez"), + make_url(trino_url( + user="user", + password="pass", + host="localhost", + source="trino-rulez", + )), + 'trino://user:***@localhost:8080?source=trino-rulez', list(), dict( host="localhost", @@ -42,13 +58,64 @@ def setup(self): ), ), ( - make_url( - 'trino://user@localhost:8080?' - 'session_properties={"query_max_run_time": "1d"}' - '&http_headers={"trino": 1}' - '&extra_credential=[("a", "b"), ("c", "d")]' - '&client_tags=[1, "sql"]' - '&experimental_python_types=true'), + make_url(trino_url( + user="user", + host="localhost", + cert="/my/path/to/cert", + key="afdlsdfk%4#'", + )), + 'trino://user@localhost:8080' + '?cert=%2Fmy%2Fpath%2Fto%2Fcert' + '&key=afdlsdfk%254%23%27' + '&source=trino-sqlalchemy', + list(), + dict( + host="localhost", + port=8080, + catalog="system", + user="user", + auth=CertificateAuthentication("/my/path/to/cert", "afdlsdfk%4#'"), + http_scheme="https", + source="trino-sqlalchemy" + ), + ), + ( + make_url(trino_url( + user="user", + host="localhost", + access_token="afdlsdfk%4#'", + )), + 'trino://user@localhost:8080' + '?access_token=afdlsdfk%254%23%27' + '&source=trino-sqlalchemy', + list(), + dict( + host="localhost", + port=8080, + catalog="system", + user="user", + auth=JWTAuthentication("afdlsdfk%4#'"), + http_scheme="https", + source="trino-sqlalchemy" + ), + ), + ( + make_url(trino_url( + user="user", + host="localhost", + session_properties={"query_max_run_time": "1d"}, + http_headers={"trino": 1}, + extra_credential=[("a", "b"), ("c", "d")], + client_tags=["1", "sql"], + experimental_python_types=True, + )), + 'trino://user@localhost:8080' + '?client_tags=%5B%221%22%2C+%22sql%22%5D' + '&experimental_python_types=true' + '&extra_credential=%5B%28%27a%27%2C+%27b%27%29%2C+%28%27c%27%2C+%27d%27%29%5D' + '&http_headers=%7B%22trino%22%3A+1%7D' + '&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D' + '&source=trino-sqlalchemy', list(), dict( host="localhost", @@ -59,13 +126,66 @@ def setup(self): session_properties={"query_max_run_time": "1d"}, http_headers={"trino": 1}, extra_credential=[("a", "b"), ("c", "d")], - client_tags=[1, "sql"], + client_tags=["1", "sql"], experimental_python_types=True, ), ), + # url encoding + ( + make_url(trino_url( + user="user@test.org/my_role", + password="pass /*&", + host="localhost", + session_properties={"query_max_run_time": "1d"}, + http_headers={"trino": 1}, + extra_credential=[ + ("user1@test.org/my_role", "user2@test.org/my_role"), + ("user3@test.org/my_role", "user36@test.org/my_role")], + experimental_python_types=True, + client_tags=["1 @& /\"", "sql"], + verify=False, + )), + 'trino://user%40test.org%2Fmy_role:***@localhost:8080' + '?client_tags=%5B%221+%40%26+%2F%5C%22%22%2C+%22sql%22%5D' + '&experimental_python_types=true' + '&extra_credential=%5B%28%27user1%40test.org%2Fmy_role%27%2C' + '+%27user2%40test.org%2Fmy_role%27%29%2C' + '+%28%27user3%40test.org%2Fmy_role%27%2C' + '+%27user36%40test.org%2Fmy_role%27%29%5D' + '&http_headers=%7B%22trino%22%3A+1%7D' + '&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D' + '&source=trino-sqlalchemy' + '&verify=false', + list(), + dict( + host="localhost", + port=8080, + catalog="system", + user="user@test.org/my_role", + auth=BasicAuthentication("user@test.org/my_role", "pass /*&"), + http_scheme="https", + source="trino-sqlalchemy", + session_properties={"query_max_run_time": "1d"}, + http_headers={"trino": 1}, + extra_credential=[ + ("user1@test.org/my_role", "user2@test.org/my_role"), + ("user3@test.org/my_role", "user36@test.org/my_role")], + experimental_python_types=True, + client_tags=["1 @& /\"", "sql"], + verify=False, + ), + ), ], ) - def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]): + def test_create_connect_args( + self, + url: URL, + generated_url: str, + expected_args: List[Any], + expected_kwargs: Dict[str, Any] + ): + assert repr(url) == generated_url + actual_args, actual_kwargs = self.dialect.create_connect_args(url) assert actual_args == expected_args diff --git a/trino/sqlalchemy/__init__.py b/trino/sqlalchemy/__init__.py index 000d3e08..3c10f0b8 100644 --- a/trino/sqlalchemy/__init__.py +++ b/trino/sqlalchemy/__init__.py @@ -10,5 +10,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from sqlalchemy.dialects import registry +from .util import _url as URL # noqa registry.register("trino", "trino.sqlalchemy.dialect", "TrinoDialect") diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index e967cb6b..ad0536fd 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -13,6 +13,7 @@ from ast import literal_eval from textwrap import dedent from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple +from urllib.parse import unquote_plus, unquote from sqlalchemy import exc, sql from sqlalchemy.engine.base import Connection @@ -80,49 +81,52 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any db_parts = (url.database or "system").split("/") if len(db_parts) == 1: - kwargs["catalog"] = db_parts[0] + kwargs["catalog"] = unquote_plus(db_parts[0]) elif len(db_parts) == 2: - kwargs["catalog"] = db_parts[0] - kwargs["schema"] = db_parts[1] + kwargs["catalog"] = unquote_plus(db_parts[0]) + kwargs["schema"] = unquote_plus(db_parts[1]) else: raise ValueError(f"Unexpected database format {url.database}") if url.username: - kwargs["user"] = url.username + kwargs["user"] = unquote(url.username) if url.password: if not url.username: raise ValueError("Username is required when specify password in connection URL") kwargs["http_scheme"] = "https" - kwargs["auth"] = BasicAuthentication(url.username, url.password) + kwargs["auth"] = BasicAuthentication(unquote(url.username), unquote(url.password)) if "access_token" in url.query: kwargs["http_scheme"] = "https" - kwargs["auth"] = JWTAuthentication(url.query["access_token"]) + kwargs["auth"] = JWTAuthentication(unquote(url.query["access_token"])) if "cert" and "key" in url.query: kwargs["http_scheme"] = "https" - kwargs["auth"] = CertificateAuthentication(url.query['cert'], url.query['key']) + kwargs["auth"] = CertificateAuthentication(unquote(url.query['cert']), unquote(url.query['key'])) if "source" in url.query: - kwargs["source"] = url.query["source"] + kwargs["source"] = unquote(url.query["source"]) else: kwargs["source"] = "trino-sqlalchemy" if "session_properties" in url.query: - kwargs["session_properties"] = json.loads(url.query["session_properties"]) + kwargs["session_properties"] = json.loads(unquote(url.query["session_properties"])) if "http_headers" in url.query: - kwargs["http_headers"] = json.loads(url.query["http_headers"]) + kwargs["http_headers"] = json.loads(unquote(url.query["http_headers"])) if "extra_credential" in url.query: - kwargs["extra_credential"] = literal_eval(url.query["extra_credential"]) + kwargs["extra_credential"] = literal_eval(unquote(url.query["extra_credential"])) if "client_tags" in url.query: - kwargs["client_tags"] = json.loads(url.query["client_tags"]) + kwargs["client_tags"] = json.loads(unquote(url.query["client_tags"])) if "experimental_python_types" in url.query: - kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"]) + kwargs["experimental_python_types"] = json.loads(unquote(url.query["experimental_python_types"])) + + if "verify" in url.query: + kwargs["verify"] = json.loads(unquote(url.query["verify"])) return args, kwargs diff --git a/trino/sqlalchemy/util.py b/trino/sqlalchemy/util.py new file mode 100644 index 00000000..4a33ba35 --- /dev/null +++ b/trino/sqlalchemy/util.py @@ -0,0 +1,94 @@ +import json +from urllib.parse import quote_plus + +from typing import Optional, Dict, List, Union, Tuple +from sqlalchemy import exc +from sqlalchemy.engine.url import _rfc_1738_quote # noqa + + +def _url( + host: str, + port: Optional[int] = 8080, + user: Optional[str] = None, + password: Optional[str] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + source: Optional[str] = "trino-sqlalchemy", + session_properties: Dict[str, str] = None, + http_headers: Dict[str, Union[str, int]] = None, + extra_credential: Optional[List[Tuple[str, str]]] = None, + client_tags: Optional[List[str]] = None, + experimental_python_types: Optional[bool] = None, + access_token: Optional[str] = None, + cert: Optional[str] = None, + key: Optional[str] = None, + verify: Optional[bool] = None, +) -> str: + """ + Composes a SQLAlchemy connection string from the given database connection + parameters. + Parameters containing special characters (e.g., '@', '%') need to be encoded to be parsed correctly. + """ + + trino_url = "trino://" + + if user is not None: + trino_url += _rfc_1738_quote(user) + + if password is not None: + if user is None: + raise exc.ArgumentError("user must be specified when specifying a password.") + trino_url += f":{_rfc_1738_quote(password)}" + + if user is not None: + trino_url += "@" + + if not host: + raise exc.ArgumentError("host must be specified.") + + trino_url += host + + if not port: + raise exc.ArgumentError("port must be specified.") + + trino_url += f":{port}" + + if catalog is not None: + trino_url += f"/{quote_plus(catalog)}" + + if schema is not None: + if catalog is None: + raise exc.ArgumentError("catalog must be specified when specifying a default schema.") + trino_url += f"/{quote_plus(schema)}" + + assert source + trino_url += f"?source={quote_plus(source)}" + + if session_properties is not None: + trino_url += f"&session_properties={quote_plus(json.dumps(session_properties))}" + + if http_headers is not None: + trino_url += f"&http_headers={quote_plus(json.dumps(http_headers))}" + + if extra_credential is not None: + trino_url += f"&extra_credential={quote_plus(repr(extra_credential))}" + + if client_tags is not None: + trino_url += f"&client_tags={quote_plus(json.dumps(client_tags))}" + + if experimental_python_types is not None: + trino_url += f"&experimental_python_types={json.dumps(experimental_python_types)}" + + if access_token is not None: + trino_url += f"&access_token={quote_plus(access_token)}" + + if cert is not None: + trino_url += f"&cert={quote_plus(cert)}" + + if key is not None: + trino_url += f"&key={quote_plus(key)}" + + if verify is not None: + trino_url += f"&verify={json.dumps(verify)}" + + return trino_url