From ccb293a083acbaae8c975059f74cac341c7dbe9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=90=E1=BA=B7ng=20Minh=20D=C5=A9ng?= Date: Mon, 29 Aug 2022 14:21:28 +0700 Subject: [PATCH] fix(Trino): create `PrestoBaseEngineSpec` base class to share common code between Trino and Presto (#21066) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore: create `PrestoBaseEngineSpec` class that share common functions between Presto and Trino Signed-off-by: Đặng Minh Dũng * feat(Trino): support CertificateAuthentication * chore(Presto): move `get_function_names` to `PrestoBaseEngineSpec` Signed-off-by: Đặng Minh Dũng * chores(Presto): remove `is_readonly_query` * feat(Trino): implement `extra_table_metadata` * feat(Trino): specify `User-Agent` Signed-off-by: Đặng Minh Dũng * fix: pylint Signed-off-by: Đặng Minh Dũng * chores(Presto): move `PrestoBaseEngineSpec` to `presto.py` Signed-off-by: Đặng Minh Dũng * fix(Presto): typing annotations Signed-off-by: Đặng Minh Dũng Signed-off-by: Đặng Minh Dũng --- docs/docs/databases/trino.mdx | 18 +- superset/db_engine_specs/hive.py | 5 - superset/db_engine_specs/presto.py | 308 +++++++++--------- superset/db_engine_specs/trino.py | 73 ++--- .../db_engine_specs/trino_tests.py | 24 +- tests/integration_tests/model_tests.py | 4 +- .../unit_tests/db_engine_specs/test_presto.py | 6 +- 7 files changed, 233 insertions(+), 205 deletions(-) diff --git a/docs/docs/databases/trino.mdx b/docs/docs/databases/trino.mdx index 50ccf1f27123d..4d6bfcf343205 100644 --- a/docs/docs/databases/trino.mdx +++ b/docs/docs/databases/trino.mdx @@ -56,7 +56,21 @@ In `Secure Extra` field, config as following example: All fields in `auth_params` are passed directly to the [`KerberosAuthentication`](https://github.com/trinodb/trino-python-client/blob/0.306.0/trino/auth.py#L40) class. -#### 3. JWT Authentication +#### 3. Certificate Authentication +In `Secure Extra` field, config as following example: +```json +{ + "auth_method": "certificate", + "auth_params": { + "cert": "/path/to/cert.pem", + "key": "/path/to/key.pem" + } +} +``` + +All fields in `auth_params` are passed directly to the [`CertificateAuthentication`](https://github.com/trinodb/trino-python-client/blob/0.315.0/trino/auth.py#L416) class. + +#### 4. JWT Authentication Config `auth_method` and provide token in `Secure Extra` field ```json { @@ -67,7 +81,7 @@ Config `auth_method` and provide token in `Secure Extra` field } ``` -#### 4. Custom Authentication +#### 5. Custom Authentication To use custom authentication, first you need to add it into `ALLOWED_EXTRA_AUTHENTICATIONS` allow list in Superset config file: ```python diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index df42bf2492606..8ea1bfddae686 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -49,7 +49,6 @@ # prevent circular imports from superset.models.core import Database - logger = logging.getLogger(__name__) @@ -262,10 +261,6 @@ def convert_dttm( .isoformat(sep=" ", timespec="microseconds")}' AS TIMESTAMP)""" return None - @classmethod - def epoch_to_dttm(cls) -> str: - return "from_unixtime({col})" - @classmethod def adjust_database_uri( cls, uri: URL, selected_schema: Optional[str] = None diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index c5fc06b62f332..f0c9a349916f8 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -15,10 +15,13 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=too-many-lines +from __future__ import annotations + import logging import re import textwrap import time +from abc import ABCMeta from collections import defaultdict, deque from contextlib import closing from datetime import datetime @@ -55,7 +58,6 @@ TinyInteger, ) from superset.result_set import destringify -from superset.sql_parse import ParsedQuery from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils from superset.utils.core import ColumnSpec, GenericDataType @@ -148,11 +150,12 @@ def get_children(column: ResultSetColumnType) -> List[ResultSetColumnType]: raise Exception(f"Unknown type {type_}!") -class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods - engine = "presto" - engine_name = "Presto" - allows_alias_to_source_column = False +class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): + """ + A base class that share common functions between Presto and Trino + """ + # pylint: disable=line-too-long _time_grain_expressions = { None: "{col}", "PT1S": "date_trunc('second', CAST({col} AS TIMESTAMP))", @@ -163,12 +166,146 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho "P1M": "date_trunc('month', CAST({col} AS TIMESTAMP))", "P3M": "date_trunc('quarter', CAST({col} AS TIMESTAMP))", "P1Y": "date_trunc('year', CAST({col} AS TIMESTAMP))", - "P1W/1970-01-03T00:00:00Z": "date_add('day', 5, date_trunc('week', " - "date_add('day', 1, CAST({col} AS TIMESTAMP))))", - "1969-12-28T00:00:00Z/P1W": "date_add('day', -1, date_trunc('week', " - "date_add('day', 1, CAST({col} AS TIMESTAMP))))", + # Week starting Sunday + "1969-12-28T00:00:00Z/P1W": "date_trunc('week', CAST({col} AS TIMESTAMP) + interval '1' day) - interval '1' day", # noqa + # Week starting Monday + "1969-12-29T00:00:00Z/P1W": "date_trunc('week', CAST({col} AS TIMESTAMP))", + # Week ending Saturday + "P1W/1970-01-03T00:00:00Z": "date_trunc('week', CAST({col} AS TIMESTAMP) + interval '1' day) + interval '5' day", # noqa + # Week ending Sunday + "P1W/1970-01-04T00:00:00Z": "date_trunc('week', CAST({col} AS TIMESTAMP)) + interval '6' day", # noqa } + @classmethod + def convert_dttm( + cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + ) -> Optional[str]: + """ + Convert a Python `datetime` object to a SQL expression. + :param target_type: The target type of expression + :param dttm: The datetime object + :param db_extra: The database extra object + :return: The SQL expression + Superset only defines time zone naive `datetime` objects, though this method + handles both time zone naive and aware conversions. + """ + tt = target_type.upper() + if tt == utils.TemporalType.DATE: + return f"DATE '{dttm.date().isoformat()}'" + if tt in ( + utils.TemporalType.TIMESTAMP, + utils.TemporalType.TIMESTAMP_WITH_TIME_ZONE, + ): + return f"""TIMESTAMP '{dttm.isoformat(timespec="microseconds", sep=" ")}'""" + return None + + @classmethod + def epoch_to_dttm(cls) -> str: + return "from_unixtime({col})" + + @classmethod + def adjust_database_uri( + cls, uri: URL, selected_schema: Optional[str] = None + ) -> URL: + database = uri.database + if selected_schema and database: + selected_schema = parse.quote(selected_schema, safe="") + if "/" in database: + database = database.split("/")[0] + "/" + selected_schema + else: + database += "/" + selected_schema + uri = uri.set(database=database) + + return uri + + @classmethod + def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + """ + Run a SQL query that estimates the cost of a given statement. + :param statement: A single SQL statement + :param cursor: Cursor instance + :return: JSON response from Trino + """ + sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}" + cursor.execute(sql) + + # the output from Trino is a single column and a single row containing + # JSON: + # + # { + # ... + # "estimate" : { + # "outputRowCount" : 8.73265878E8, + # "outputSizeInBytes" : 3.41425774958E11, + # "cpuCost" : 3.41425774958E11, + # "maxMemory" : 0.0, + # "networkCost" : 3.41425774958E11 + # } + # } + result = json.loads(cursor.fetchone()[0]) + return result + + @classmethod + def query_cost_formatter( + cls, raw_cost: List[Dict[str, Any]] + ) -> List[Dict[str, str]]: + """ + Format cost estimate. + :param raw_cost: JSON estimate from Trino + :return: Human readable cost estimate + """ + + def humanize(value: Any, suffix: str) -> str: + try: + value = int(value) + except ValueError: + return str(value) + + prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"] + prefix = "" + to_next_prefix = 1000 + while value > to_next_prefix and prefixes: + prefix = prefixes.pop(0) + value //= to_next_prefix + + return f"{value} {prefix}{suffix}" + + cost = [] + columns = [ + ("outputRowCount", "Output count", " rows"), + ("outputSizeInBytes", "Output size", "B"), + ("cpuCost", "CPU cost", ""), + ("maxMemory", "Max memory", "B"), + ("networkCost", "Network cost", ""), + ] + for row in raw_cost: + estimate: Dict[str, float] = row.get("estimate", {}) + statement_cost = {} + for key, label, suffix in columns: + if key in estimate: + statement_cost[label] = humanize(estimate[key], suffix).strip() + cost.append(statement_cost) + + return cost + + @classmethod + @cache_manager.data_cache.memoize() + def get_function_names(cls, database: Database) -> List[str]: + """ + Get a list of function names that are able to be called on the database. + Used for SQL Lab autocomplete. + + :param database: The database to get functions for + :return: A list of function names useable in the database + """ + return database.get_df("SHOW FUNCTIONS")["Function"].tolist() + + +class PrestoEngineSpec(PrestoBaseEngineSpec): # pylint: disable=too-many-public-methods + engine = "presto" + engine_name = "Presto" + allows_alias_to_source_column = False + custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __( @@ -255,7 +392,7 @@ def update_impersonation_config( @classmethod def get_table_names( - cls, database: "Database", inspector: Inspector, schema: Optional[str] + cls, database: Database, inspector: Inspector, schema: Optional[str] ) -> List[str]: tables = super().get_table_names(database, inspector, schema) if not is_feature_enabled("PRESTO_SPLIT_VIEWS_FROM_TABLES"): @@ -267,7 +404,7 @@ def get_table_names( @classmethod def get_view_names( - cls, database: "Database", inspector: Inspector, schema: Optional[str] + cls, database: Database, inspector: Inspector, schema: Optional[str] ) -> List[str]: """Returns an empty list @@ -625,7 +762,7 @@ def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]: @classmethod def select_star( # pylint: disable=too-many-arguments cls, - database: "Database", + database: Database, table_name: str, engine: Engine, schema: Optional[str] = None, @@ -659,125 +796,9 @@ def select_star( # pylint: disable=too-many-arguments presto_cols, ) - @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: - """ - Run a SQL query that estimates the cost of a given statement. - - :param statement: A single SQL statement - :param cursor: Cursor instance - :return: JSON response from Presto - """ - sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}" - cursor.execute(sql) - - # the output from Presto is a single column and a single row containing - # JSON: - # - # { - # ... - # "estimate" : { - # "outputRowCount" : 8.73265878E8, - # "outputSizeInBytes" : 3.41425774958E11, - # "cpuCost" : 3.41425774958E11, - # "maxMemory" : 0.0, - # "networkCost" : 3.41425774958E11 - # } - # } - result = json.loads(cursor.fetchone()[0]) - return result - - @classmethod - def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: - """ - Format cost estimate. - - :param raw_cost: JSON estimate from Presto - :return: Human readable cost estimate - """ - - def humanize(value: Any, suffix: str) -> str: - try: - value = int(value) - except ValueError: - return str(value) - - prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"] - prefix = "" - to_next_prefix = 1000 - while value > to_next_prefix and prefixes: - prefix = prefixes.pop(0) - value //= to_next_prefix - - return f"{value} {prefix}{suffix}" - - cost = [] - columns = [ - ("outputRowCount", "Output count", " rows"), - ("outputSizeInBytes", "Output size", "B"), - ("cpuCost", "CPU cost", ""), - ("maxMemory", "Max memory", "B"), - ("networkCost", "Network cost", ""), - ] - for row in raw_cost: - estimate: Dict[str, float] = row.get("estimate", {}) - statement_cost = {} - for key, label, suffix in columns: - if key in estimate: - statement_cost[label] = humanize(estimate[key], suffix).strip() - cost.append(statement_cost) - - return cost - - @classmethod - def adjust_database_uri( - cls, uri: URL, selected_schema: Optional[str] = None - ) -> URL: - database = uri.database - if selected_schema and database: - selected_schema = parse.quote(selected_schema, safe="") - if "/" in database: - database = database.split("/")[0] + "/" + selected_schema - else: - database += "/" + selected_schema - uri = uri.set(database=database) - - return uri - - @classmethod - def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: - """ - Convert a Python `datetime` object to a SQL expression. - - :param target_type: The target type of expression - :param dttm: The datetime object - :param db_extra: The database extra object - :return: The SQL expression - - Superset only defines time zone naive `datetime` objects, though this method - handles both time zone naive and aware conversions. - """ - tt = target_type.upper() - if tt == utils.TemporalType.DATE: - return f"""DATE '{dttm.date().isoformat()}'""" - if tt in ( - utils.TemporalType.TIMESTAMP, - utils.TemporalType.TIMESTAMP_WITH_TIME_ZONE, - ): - return f"""TIMESTAMP '{dttm.isoformat(timespec="milliseconds", sep=" ")}'""" - return None - - @classmethod - def epoch_to_dttm(cls) -> str: - return "from_unixtime({col})" - @classmethod def get_all_datasource_names( - cls, database: "Database", datasource_type: str + cls, database: Database, datasource_type: str ) -> List[utils.DatasourceName]: datasource_df = database.get_df( "SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S " @@ -906,7 +927,7 @@ def expand_data( # pylint: disable=too-many-locals @classmethod def extra_table_metadata( - cls, database: "Database", table_name: str, schema_name: Optional[str] + cls, database: Database, table_name: str, schema_name: Optional[str] ) -> Dict[str, Any]: metadata = {} @@ -938,7 +959,7 @@ def extra_table_metadata( @classmethod def get_create_view( - cls, database: "Database", schema: Optional[str], table: str + cls, database: Database, schema: Optional[str], table: str ) -> Optional[str]: """ Return a CREATE VIEW statement, or `None` if not a view. @@ -1043,7 +1064,7 @@ def _extract_error_message(cls, ex: Exception) -> str: def _partition_query( # pylint: disable=too-many-arguments,too-many-locals cls, table_name: str, - database: "Database", + database: Database, limit: int = 0, order_by: Optional[List[Tuple[str, bool]]] = None, filters: Optional[Dict[Any, Any]] = None, @@ -1101,7 +1122,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments cls, table_name: str, schema: Optional[str], - database: "Database", + database: Database, query: Select, columns: Optional[List[Dict[str, str]]] = None, ) -> Optional[Select]: @@ -1142,7 +1163,7 @@ def latest_partition( cls, table_name: str, schema: Optional[str], - database: "Database", + database: Database, show_first: bool = False, ) -> Tuple[List[str], Optional[List[str]]]: """Returns col name and the latest (max) partition value for a table @@ -1185,7 +1206,7 @@ def latest_partition( @classmethod def latest_sub_partition( - cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any + cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any ) -> Any: """Returns the latest (max) partition value for a table @@ -1236,23 +1257,6 @@ def latest_sub_partition( return "" return df.to_dict()[field_to_return][0] - @classmethod - @cache_manager.data_cache.memoize() - def get_function_names(cls, database: "Database") -> List[str]: - """ - Get a list of function names that are able to be called on the database. - Used for SQL Lab autocomplete. - - :param database: The database to get functions for - :return: A list of function names useable in the database - """ - return database.get_df("SHOW FUNCTIONS")["Function"].tolist() - - @classmethod - def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool: - """Pessimistic readonly, 100% sure statement won't mutate anything""" - return super().is_readonly_query(parsed_query) or parsed_query.is_show() - @classmethod def get_column_spec( cls, diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 9aa89ce34a06f..f0e15c982900c 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -17,17 +17,17 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, Dict, Optional, TYPE_CHECKING import simplejson as json from flask import current_app -from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.orm import Session +from superset.constants import USER_AGENT from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import BaseEngineSpec -from superset.db_engine_specs.presto import PrestoEngineSpec +from superset.db_engine_specs.presto import PrestoBaseEngineSpec from superset.models.sql_lab import Query from superset.utils import core as utils @@ -42,11 +42,34 @@ logger = logging.getLogger(__name__) -class TrinoEngineSpec(PrestoEngineSpec): +class TrinoEngineSpec(PrestoBaseEngineSpec): engine = "trino" - engine_aliases = {"trinonative"} # Required for backwards compatibility. engine_name = "Trino" + @classmethod + def extra_table_metadata( + cls, + database: Database, + table_name: str, + schema_name: Optional[str], + ) -> Dict[str, Any]: + metadata = {} + + indexes = database.get_indexes(table_name, schema_name) + if indexes: + partitions_columns = [] + for index in indexes: + if index.get("name") == "partition": + partitions_columns += index.get("column_names", []) + metadata["partitions"] = {"cols": partitions_columns} + + if database.has_view_by_name(table_name, schema_name): + metadata["view"] = database.inspector.get_view_definition( + table_name, schema_name + ) + + return metadata + @classmethod def update_impersonation_config( cls, @@ -89,32 +112,6 @@ def get_url_for_impersonation( def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: return True - @classmethod - def get_table_names( - cls, - database: Database, - inspector: Inspector, - schema: Optional[str], - ) -> List[str]: - return BaseEngineSpec.get_table_names( - database=database, - inspector=inspector, - schema=schema, - ) - - @classmethod - def get_view_names( - cls, - database: Database, - inspector: Inspector, - schema: Optional[str], - ) -> List[str]: - return BaseEngineSpec.get_view_names( - database=database, - inspector=inspector, - schema=schema, - ) - @classmethod def get_tracking_url(cls, cursor: Cursor) -> Optional[str]: try: @@ -138,11 +135,7 @@ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: query.set_extra_json_key("cancel_query", cursor.stats["queryId"]) session.commit() - BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session) - - @classmethod - def has_implicit_cancel(cls) -> bool: - return False + super().handle_cursor(cursor=cursor, query=query, session=session) @classmethod def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: @@ -166,7 +159,7 @@ def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool: return True @staticmethod - def get_extra_params(database: "Database") -> Dict[str, Any]: + def get_extra_params(database: Database) -> Dict[str, Any]: """ Some databases require adding elements to connection parameters, like passing certificates to `extra`. This can be done here. @@ -178,6 +171,8 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + connect_args.setdefault("source", USER_AGENT) + if database.server_cert: connect_args["http_scheme"] = "https" connect_args["verify"] = utils.create_ssl_cert_file(database.server_cert) @@ -186,7 +181,7 @@ def get_extra_params(database: "Database") -> Dict[str, Any]: @staticmethod def update_encrypted_extra_params( - database: "Database", params: Dict[str, Any] + database: Database, params: Dict[str, Any] ) -> None: if not database.encrypted_extra: return @@ -204,6 +199,8 @@ def update_encrypted_extra_params( from trino.auth import BasicAuthentication as trino_auth # noqa elif auth_method == "kerberos": from trino.auth import KerberosAuthentication as trino_auth # noqa + elif auth_method == "certificate": + from trino.auth import CertificateAuthentication as trino_auth # noqa elif auth_method == "jwt": from trino.auth import JWTAuthentication as trino_auth # noqa else: diff --git a/tests/integration_tests/db_engine_specs/trino_tests.py b/tests/integration_tests/db_engine_specs/trino_tests.py index fc83b8c64c3d7..7b745e8a1c3d9 100644 --- a/tests/integration_tests/db_engine_specs/trino_tests.py +++ b/tests/integration_tests/db_engine_specs/trino_tests.py @@ -19,9 +19,9 @@ from unittest.mock import Mock, patch import pytest -from sqlalchemy.engine.url import URL import superset.config +from superset.constants import USER_AGENT from superset.db_engine_specs.trino import TrinoEngineSpec from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec @@ -33,12 +33,15 @@ def test_get_extra_params(self): database.extra = json.dumps({}) database.server_cert = None extra = TrinoEngineSpec.get_extra_params(database) - expected = {"engine_params": {"connect_args": {}}} + expected = {"engine_params": {"connect_args": {"source": USER_AGENT}}} self.assertEqual(extra, expected) expected = { "first": 1, - "engine_params": {"second": "two", "connect_args": {"third": "three"}}, + "engine_params": { + "second": "two", + "connect_args": {"source": "foobar", "third": "three"}, + }, } database.extra = json.dumps(expected) database.server_cert = None @@ -93,6 +96,21 @@ def test_auth_kerberos(self, auth: Mock): self.assertEqual(connect_args.get("http_scheme"), "https") auth.assert_called_once_with(**auth_params) + @patch("trino.auth.CertificateAuthentication") + def test_auth_certificate(self, auth: Mock): + database = Mock() + + auth_params = {"cert": "/path/to/cert.pem", "key": "/path/to/key.pem"} + database.encrypted_extra = json.dumps( + {"auth_method": "certificate", "auth_params": auth_params} + ) + + params: Dict[str, Any] = {} + TrinoEngineSpec.update_encrypted_extra_params(database, params) + connect_args = params.setdefault("connect_args", {}) + self.assertEqual(connect_args.get("http_scheme"), "https") + auth.assert_called_once_with(**auth_params) + @patch("trino.auth.JWTAuthentication") def test_auth_jwt(self, auth: Mock): database = Mock() diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index 4b1e6e9978047..85014f34d8f2d 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -188,7 +188,7 @@ def test_impersonate_user_trino(self, mocked_create_engine): call_args = mocked_create_engine.call_args assert str(call_args[0][0]) == "trino://localhost" - assert call_args[1]["connect_args"] == {"user": "gamma"} + assert call_args[1]["connect_args"]["user"] == "gamma" model = Database( database_name="test_database", @@ -203,7 +203,7 @@ def test_impersonate_user_trino(self, mocked_create_engine): str(call_args[0][0]) == "trino://original_user:original_user_password@localhost" ) - assert call_args[1]["connect_args"] == {"user": "gamma"} + assert call_args[1]["connect_args"]["user"] == "gamma" @mock.patch("superset.models.core.create_engine") def test_impersonate_user_hive(self, mocked_create_engine): diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index 11ab176ff0b20..0f0777d0cb726 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -29,17 +29,17 @@ ( "TIMESTAMP", datetime(2022, 1, 1, 1, 23, 45, 600000), - "TIMESTAMP '2022-01-01 01:23:45.600'", + "TIMESTAMP '2022-01-01 01:23:45.600000'", ), ( "TIMESTAMP WITH TIME ZONE", datetime(2022, 1, 1, 1, 23, 45, 600000), - "TIMESTAMP '2022-01-01 01:23:45.600'", + "TIMESTAMP '2022-01-01 01:23:45.600000'", ), ( "TIMESTAMP WITH TIME ZONE", datetime(2022, 1, 1, 1, 23, 45, 600000, tzinfo=pytz.UTC), - "TIMESTAMP '2022-01-01 01:23:45.600+00:00'", + "TIMESTAMP '2022-01-01 01:23:45.600000+00:00'", ), ], )