Skip to content

Commit

Permalink
Escape url parameters in sqlalchemy connection strings
Browse files Browse the repository at this point in the history
  • Loading branch information
mdesmet committed Sep 26, 2022
1 parent cd614ff commit 54d8603
Show file tree
Hide file tree
Showing 5 changed files with 262 additions and 29 deletions.
16 changes: 15 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,13 @@ NOTE: `password` and `schema` are optional
from sqlalchemy import create_engine
from sqlalchemy.schema import Table, MetaData
from sqlalchemy.sql.expression import select, text
from trino.sqlalchemy import URL

engine = create_engine('trino://user@localhost:8080/system')
engine = create_engine(URL(
host="localhost",
port=8080,
catalog="system"
))
connection = engine.connect()

rows = connection.execute(text("SELECT * FROM runtime.nodes")).fetchall()
Expand All @@ -93,6 +98,7 @@ Attributes can also be passed in the connection string.

```python
from sqlalchemy import create_engine
from trino.sqlalchemy import URL

engine = create_engine(
'trino://user@localhost:8080/system',
Expand All @@ -110,6 +116,14 @@ engine = create_engine(
'&client_tags=["tag1", "tag2"]'
'&experimental_python_types=true',
)

# or using the URL factory method
engine = create_engine(URL(
host="localhost",
port=8080,
client_tags=["tag1", "tag2"],
experimental_python_types=True
))
```

## Authentication mechanisms
Expand Down
150 changes: 135 additions & 15 deletions tests/unit/sqlalchemy/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,43 @@
from trino.dbapi import Connection
from trino.sqlalchemy.dialect import CertificateAuthentication, JWTAuthentication, TrinoDialect
from trino.transaction import IsolationLevel
from trino.sqlalchemy import URL as trino_url


class TestTrinoDialect:
def setup(self):
self.dialect = TrinoDialect()

@pytest.mark.parametrize(
"url, expected_args, expected_kwargs",
"url, generated_url, expected_args, expected_kwargs",
[
(
make_url("trino://user@localhost"),
make_url(trino_url(
user="user",
host="localhost",
)),
'trino://user@localhost:8080?source=trino-sqlalchemy',
list(),
dict(host="localhost", catalog="system", user="user", source="trino-sqlalchemy"),
dict(host="localhost", catalog="system", user="user", port=8080, source="trino-sqlalchemy"),
),
(
make_url("trino://user@localhost:8080"),
make_url(trino_url(
user="user",
host="localhost",
port=443,
)),
'trino://user@localhost:443?source=trino-sqlalchemy',
list(),
dict(host="localhost", port=8080, catalog="system", user="user", source="trino-sqlalchemy"),
dict(host="localhost", port=443, catalog="system", user="user", source="trino-sqlalchemy"),
),
(
make_url("trino://user:pass@localhost:8080?source=trino-rulez"),
make_url(trino_url(
user="user",
password="pass",
host="localhost",
source="trino-rulez",
)),
'trino://user:***@localhost:8080?source=trino-rulez',
list(),
dict(
host="localhost",
Expand All @@ -42,13 +58,64 @@ def setup(self):
),
),
(
make_url(
'trino://user@localhost:8080?'
'session_properties={"query_max_run_time": "1d"}'
'&http_headers={"trino": 1}'
'&extra_credential=[("a", "b"), ("c", "d")]'
'&client_tags=[1, "sql"]'
'&experimental_python_types=true'),
make_url(trino_url(
user="user",
host="localhost",
cert="/my/path/to/cert",
key="afdlsdfk%4#'",
)),
'trino://user@localhost:8080'
'?cert=%2Fmy%2Fpath%2Fto%2Fcert'
'&key=afdlsdfk%254%23%27'
'&source=trino-sqlalchemy',
list(),
dict(
host="localhost",
port=8080,
catalog="system",
user="user",
auth=CertificateAuthentication("/my/path/to/cert", "afdlsdfk%4#'"),
http_scheme="https",
source="trino-sqlalchemy"
),
),
(
make_url(trino_url(
user="user",
host="localhost",
access_token="afdlsdfk%4#'",
)),
'trino://user@localhost:8080'
'?access_token=afdlsdfk%254%23%27'
'&source=trino-sqlalchemy',
list(),
dict(
host="localhost",
port=8080,
catalog="system",
user="user",
auth=JWTAuthentication("afdlsdfk%4#'"),
http_scheme="https",
source="trino-sqlalchemy"
),
),
(
make_url(trino_url(
user="user",
host="localhost",
session_properties={"query_max_run_time": "1d"},
http_headers={"trino": 1},
extra_credential=[("a", "b"), ("c", "d")],
client_tags=["1", "sql"],
experimental_python_types=True,
)),
'trino://user@localhost:8080'
'?client_tags=%5B%221%22%2C+%22sql%22%5D'
'&experimental_python_types=true'
'&extra_credential=%5B%28%27a%27%2C+%27b%27%29%2C+%28%27c%27%2C+%27d%27%29%5D'
'&http_headers=%7B%22trino%22%3A+1%7D'
'&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D'
'&source=trino-sqlalchemy',
list(),
dict(
host="localhost",
Expand All @@ -59,13 +126,66 @@ def setup(self):
session_properties={"query_max_run_time": "1d"},
http_headers={"trino": 1},
extra_credential=[("a", "b"), ("c", "d")],
client_tags=[1, "sql"],
client_tags=["1", "sql"],
experimental_python_types=True,
),
),
# url encoding
(
make_url(trino_url(
user="user@test.org/my_role",
password="pass /*&",
host="localhost",
session_properties={"query_max_run_time": "1d"},
http_headers={"trino": 1},
extra_credential=[
("user1@test.org/my_role", "user2@test.org/my_role"),
("user3@test.org/my_role", "user36@test.org/my_role")],
experimental_python_types=True,
client_tags=["1 @& /\"", "sql"],
verify=False,
)),
'trino://user%40test.org%2Fmy_role:***@localhost:8080'
'?client_tags=%5B%221+%40%26+%2F%5C%22%22%2C+%22sql%22%5D'
'&experimental_python_types=true'
'&extra_credential=%5B%28%27user1%40test.org%2Fmy_role%27%2C'
'+%27user2%40test.org%2Fmy_role%27%29%2C'
'+%28%27user3%40test.org%2Fmy_role%27%2C'
'+%27user36%40test.org%2Fmy_role%27%29%5D'
'&http_headers=%7B%22trino%22%3A+1%7D'
'&session_properties=%7B%22query_max_run_time%22%3A+%221d%22%7D'
'&source=trino-sqlalchemy'
'&verify=false',
list(),
dict(
host="localhost",
port=8080,
catalog="system",
user="user@test.org/my_role",
auth=BasicAuthentication("user@test.org/my_role", "pass /*&"),
http_scheme="https",
source="trino-sqlalchemy",
session_properties={"query_max_run_time": "1d"},
http_headers={"trino": 1},
extra_credential=[
("user1@test.org/my_role", "user2@test.org/my_role"),
("user3@test.org/my_role", "user36@test.org/my_role")],
experimental_python_types=True,
client_tags=["1 @& /\"", "sql"],
verify=False,
),
),
],
)
def test_create_connect_args(self, url: URL, expected_args: List[Any], expected_kwargs: Dict[str, Any]):
def test_create_connect_args(
self,
url: URL,
generated_url: str,
expected_args: List[Any],
expected_kwargs: Dict[str, Any]
):
assert repr(url) == generated_url

actual_args, actual_kwargs = self.dialect.create_connect_args(url)

assert actual_args == expected_args
Expand Down
1 change: 1 addition & 0 deletions trino/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlalchemy.dialects import registry
from .util import _url as URL # noqa

registry.register("trino", "trino.sqlalchemy.dialect", "TrinoDialect")
30 changes: 17 additions & 13 deletions trino/sqlalchemy/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ast import literal_eval
from textwrap import dedent
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
from urllib.parse import unquote_plus, unquote

from sqlalchemy import exc, sql
from sqlalchemy.engine.base import Connection
Expand Down Expand Up @@ -80,49 +81,52 @@ def create_connect_args(self, url: URL) -> Tuple[Sequence[Any], Mapping[str, Any

db_parts = (url.database or "system").split("/")
if len(db_parts) == 1:
kwargs["catalog"] = db_parts[0]
kwargs["catalog"] = unquote_plus(db_parts[0])
elif len(db_parts) == 2:
kwargs["catalog"] = db_parts[0]
kwargs["schema"] = db_parts[1]
kwargs["catalog"] = unquote_plus(db_parts[0])
kwargs["schema"] = unquote_plus(db_parts[1])
else:
raise ValueError(f"Unexpected database format {url.database}")

if url.username:
kwargs["user"] = url.username
kwargs["user"] = unquote(url.username)

if url.password:
if not url.username:
raise ValueError("Username is required when specify password in connection URL")
kwargs["http_scheme"] = "https"
kwargs["auth"] = BasicAuthentication(url.username, url.password)
kwargs["auth"] = BasicAuthentication(unquote(url.username), unquote(url.password))

if "access_token" in url.query:
kwargs["http_scheme"] = "https"
kwargs["auth"] = JWTAuthentication(url.query["access_token"])
kwargs["auth"] = JWTAuthentication(unquote(url.query["access_token"]))

if "cert" and "key" in url.query:
kwargs["http_scheme"] = "https"
kwargs["auth"] = CertificateAuthentication(url.query['cert'], url.query['key'])
kwargs["auth"] = CertificateAuthentication(unquote(url.query['cert']), unquote(url.query['key']))

if "source" in url.query:
kwargs["source"] = url.query["source"]
kwargs["source"] = unquote(url.query["source"])
else:
kwargs["source"] = "trino-sqlalchemy"

if "session_properties" in url.query:
kwargs["session_properties"] = json.loads(url.query["session_properties"])
kwargs["session_properties"] = json.loads(unquote(url.query["session_properties"]))

if "http_headers" in url.query:
kwargs["http_headers"] = json.loads(url.query["http_headers"])
kwargs["http_headers"] = json.loads(unquote(url.query["http_headers"]))

if "extra_credential" in url.query:
kwargs["extra_credential"] = literal_eval(url.query["extra_credential"])
kwargs["extra_credential"] = literal_eval(unquote(url.query["extra_credential"]))

if "client_tags" in url.query:
kwargs["client_tags"] = json.loads(url.query["client_tags"])
kwargs["client_tags"] = json.loads(unquote(url.query["client_tags"]))

if "experimental_python_types" in url.query:
kwargs["experimental_python_types"] = json.loads(url.query["experimental_python_types"])
kwargs["experimental_python_types"] = json.loads(unquote(url.query["experimental_python_types"]))

if "verify" in url.query:
kwargs["verify"] = json.loads(unquote(url.query["verify"]))

return args, kwargs

Expand Down
Loading

0 comments on commit 54d8603

Please sign in to comment.