Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: trino support server-cert #16346

Merged
merged 1 commit into from
Nov 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from urllib import parse

import simplejson as json
Expand All @@ -24,6 +24,9 @@
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils

if TYPE_CHECKING:
from superset.models.core import Database


class TrinoEngineSpec(BaseEngineSpec):
engine = "trino"
Expand Down Expand Up @@ -81,7 +84,6 @@ def update_impersonation_config(
that can set the correct properties for impersonating users
:param connect_args: config to be updated
:param uri: URI string
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
:return: None
"""
Expand Down Expand Up @@ -116,9 +118,7 @@ def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
Run a SQL query that estimates the cost of a given statement.

:param statement: A single SQL statement
:param database: Database instance
:param cursor: Cursor instance
:param username: Effective username
:return: JSON response from Trino
"""
sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}"
Expand Down Expand Up @@ -183,3 +183,22 @@ def humanize(value: Any, suffix: str) -> str:
cost.append(statement_cost)

return cost

@staticmethod
def get_extra_params(database: "Database") -> Dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.

:param database: database instance from which to extract extras
:raises CertificateException: If certificate is not valid/unparseable
"""
extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
engine_params: Dict[str, Any] = extra.setdefault("engine_params", {})
connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {})

if database.server_cert:
connect_args["http_scheme"] = "https"
connect_args["verify"] = utils.create_ssl_cert_file(database.server_cert)

return extra
35 changes: 35 additions & 0 deletions tests/integration_tests/db_engine_specs/trino_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
from unittest.mock import Mock, patch

from sqlalchemy.engine.url import URL

from superset.db_engine_specs.trino import TrinoEngineSpec
Expand Down Expand Up @@ -52,3 +55,35 @@ def test_adjust_database_uri_when_selected_schema_is_none(self):
url.database = "hive/default"
TrinoEngineSpec.adjust_database_uri(url, selected_schema=None)
self.assertEqual(url.database, "hive/default")

def test_get_extra_params(self):
database = Mock()

database.extra = json.dumps({})
database.server_cert = None
extra = TrinoEngineSpec.get_extra_params(database)
expected = {"engine_params": {"connect_args": {}}}
self.assertEqual(extra, expected)

expected = {
"first": 1,
"engine_params": {"second": "two", "connect_args": {"third": "three"}},
}
database.extra = json.dumps(expected)
database.server_cert = None
extra = TrinoEngineSpec.get_extra_params(database)
self.assertEqual(extra, expected)

@patch("superset.utils.core.create_ssl_cert_file")
def test_get_extra_params_with_server_cert(self, create_ssl_cert_file_func: Mock):
database = Mock()

database.extra = json.dumps({})
database.server_cert = "TEST_CERT"
create_ssl_cert_file_func.return_value = "/path/to/tls.crt"
extra = TrinoEngineSpec.get_extra_params(database)

connect_args = extra.get("engine_params", {}).get("connect_args", {})
self.assertEqual(connect_args.get("http_scheme"), "https")
self.assertEqual(connect_args.get("verify"), "/path/to/tls.crt")
create_ssl_cert_file_func.assert_called_once_with(database.server_cert)