Skip to content

Commit

Permalink
Run Black Formatting on Ocient DB Engine Spec (#9)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexclavel-ocient authored Jan 30, 2023
1 parent 6c1d956 commit 6979aed
Showing 1 changed file with 68 additions and 57 deletions.
125 changes: 68 additions & 57 deletions superset/db_engine_specs/ocient.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,28 +31,27 @@
import threading

from superset.models.sql_lab import Query

# Ensure pyocient inherits Superset's logging level
superset_log_level = app.config['LOG_LEVEL']
superset_log_level = app.config["LOG_LEVEL"]
pyocient.logger.setLevel(superset_log_level)


# Regular expressions to catch custom errors

CONNECTION_INVALID_USERNAME_REGEX = re.compile(
"The referenced user does not exist \(User \'(?P<username>.*?)\' not found\)"
"The referenced user does not exist \(User '(?P<username>.*?)' not found\)"
)
CONNECTION_INVALID_PASSWORD_REGEX = re.compile(
'The userid/password combination was not valid \(Incorrect password for user\)'
"The userid/password combination was not valid \(Incorrect password for user\)"
)
CONNECTION_INVALID_HOSTNAME_REGEX = re.compile(
r"Unable to connect to (?P<host>.*?):(?P<port>.*?):"
r"Unable to connect to (?P<host>.*?):(?P<port>.*?):"
)
CONNECTION_UNKNOWN_DATABASE_REGEX = re.compile(
"No database named '(?P<database>.*?)' exists"
)
CONNECTION_INVALID_PORT_ERROR = re.compile(
"Port out of range 0-65535"
)
CONNECTION_INVALID_PORT_ERROR = re.compile("Port out of range 0-65535")
INVALID_CONNECTION_STRING_REGEX = re.compile(
"An invalid connection string attribute was specified \(failed to decrypt cipher text\)"
)
Expand All @@ -67,9 +66,9 @@
)



# Custom datatype conversion functions


def _to_hex(data: bytes) -> str:
"""
Converts the bytes object into a string of hexadecimal digits.
Expand All @@ -79,28 +78,31 @@ def _to_hex(data: bytes) -> str:
"""
return data.hex()


def _polygon_to_json(polygon: _STPolygon) -> str:
"""
Converts the _STPolygon object into its JSON representation.
:param data: the polygon object
:returns: JSON representation of the polygon
"""
json_value = f'{str([[p.long, p.lat] for p in polygon.exterior])}'
json_value = f"{str([[p.long, p.lat] for p in polygon.exterior])}"
if polygon.holes:
for hole in polygon.holes:
json_value += f', {str([[p.long, p.lat] for p in hole])}'
json_value = f'[{json_value}]'
json_value += f", {str([[p.long, p.lat] for p in hole])}"
json_value = f"[{json_value}]"
return json_value


def _linestring_to_json(linestring: _STLinestring) -> str:
"""
Converts the _STLinestring object into its JSON representation.
:param data: the linestring object
:returns: JSON representation of the linestring
"""
return f'{str([[p.long, p.lat] for p in linestring.points])}'
return f"{str([[p.long, p.lat] for p in linestring.points])}"


def _point_to_comma_delimited(point: _STPoint) -> str:
"""
Expand All @@ -109,12 +111,13 @@ def _point_to_comma_delimited(point: _STPoint) -> str:
:param data: the point object
:returns: the x and y coordinates as a comma delimited string
"""
return f'{point.long}, {point.lat}'
return f"{point.long}, {point.lat}"


# Sanitization function for column values
SanitizeFunc = Callable[[Any], Any]

# Represents a pair of a column index and the sanitization function
# Represents a pair of a column index and the sanitization function
# to apply to its values.
PlacedSanitizeFunc = NamedTuple(
"PlacedSanitizeFunc",
Expand All @@ -124,11 +127,11 @@ def _point_to_comma_delimited(point: _STPoint) -> str:
],
)

# This map contains functions used to sanitize values for column types
# This map contains functions used to sanitize values for column types
# that cannot be processed natively by Superset.
#
# Superset serializes temporal objects using a custom serializer
# defined in superset/utils/core.py (#json_int_dttm_ser(...)). Other
#
# Superset serializes temporal objects using a custom serializer
# defined in superset/utils/core.py (#json_int_dttm_ser(...)). Other
# are serialized by the default JSON encoder.
_sanitized_ocient_type_codes: Dict[int, SanitizeFunc] = {
TypeCodes.BINARY: _to_hex,
Expand All @@ -138,11 +141,12 @@ def _point_to_comma_delimited(point: _STPoint) -> str:
TypeCodes.ST_LINESTRING: _linestring_to_json,
TypeCodes.ST_POLYGON: _polygon_to_json,
}



def _find_columns_to_sanitize(cursor: Any) -> List[PlacedSanitizeFunc]:
"""
Cleans the column value for consumption by Superset.
:param cursor: the result set cursor
:returns: the list of tuples consisting of the column index and sanitization function
"""
Expand All @@ -152,67 +156,71 @@ def _find_columns_to_sanitize(cursor: Any) -> List[PlacedSanitizeFunc]:
if cursor.description[i][1] in _sanitized_ocient_type_codes
]


class OcientEngineSpec(BaseEngineSpec):
engine = 'ocient'
engine = "ocient"
engine_name = "Ocient"
#limit_method = LimitMethod.WRAP_SQL
# limit_method = LimitMethod.WRAP_SQL
force_column_alias_quotes = True
max_column_name_length = 30

# Store mapping of superset Query id -> Ocient ID
# These are inserted into the cache when executing the query
# They are then removed, either upon cancellation or query completion
query_id_mapping: Dict[str, str]= dict()
query_id_mapping: Dict[str, str] = dict()
query_id_mapping_lock = threading.Lock()

custom_errors : Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
CONNECTION_INVALID_USERNAME_REGEX: (
__('The username "%(username)s" does not exist.'),
SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR,
{},
),
CONNECTION_INVALID_PASSWORD_REGEX: (
__('The user/password combination is not valid (Incorrect password for user).'),
__(
"The user/password combination is not valid (Incorrect password for user)."
),
SupersetErrorType.CONNECTION_INVALID_PASSWORD_ERROR,
{},
),
CONNECTION_UNKNOWN_DATABASE_REGEX: (
__('Could not connect to database: "%(database)s"'),
SupersetErrorType.CONNECTION_UNKNOWN_DATABASE_ERROR,
{}
{},
),
CONNECTION_INVALID_HOSTNAME_REGEX: (
__('Could not resolve hostname: "%(host)s".'),
SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR,
{}
{},
),
CONNECTION_INVALID_PORT_ERROR: (
__('Port out of range 0-65535'),
__("Port out of range 0-65535"),
SupersetErrorType.CONNECTION_INVALID_PORT_ERROR,
{}
{},
),
INVALID_CONNECTION_STRING_REGEX: (
__('Invalid Connection String: Expecting String of the form \'ocient://user:pass@host:port/database\'.'),
__(
"Invalid Connection String: Expecting String of the form 'ocient://user:pass@host:port/database'."
),
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
{}
),
{},
),
SYNTAX_ERROR_REGEX: (
__('Syntax Error: %(qualifier)s input "%(input)s".'),
SupersetErrorType.SYNTAX_ERROR,
{}
{},
),
TABLE_DOES_NOT_EXIST_REGEX: (
__('Table or View "%(table)s" does not exist.'),
SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR,
{}
{},
),
COLUMN_DOES_NOT_EXIST_REGEX: (
__('Invalid reference to column: "%(column)s"'),
SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR,
{}
{},
),
}
}
_time_grain_expressions = {
None: "{col}",
"PT1S": "ROUND({col}, 'SECOND')",
Expand All @@ -224,70 +232,73 @@ class OcientEngineSpec(BaseEngineSpec):
"P0.25Y": "ROUND({col}, 'QUARTER')",
"P1Y": "ROUND({col}, 'YEAR')",
}


@classmethod
def get_table_names(
cls, database: Database, inspector: Inspector, schema: Optional[str]
) -> List[str]:
return sorted(inspector.get_table_names(schema))


@classmethod
def fetch_data(cls, cursor, lim=None):
try:
rows = super(OcientEngineSpec, cls).fetch_data(cursor)
except Exception as exception:
with OcientEngineSpec.query_id_mapping_lock:
del OcientEngineSpec.query_id_mapping[getattr(cursor, 'superset_query_id')]
del OcientEngineSpec.query_id_mapping[
getattr(cursor, "superset_query_id")
]
raise exception



if len(rows) > 0 and type(rows[0]).__name__ == rows:
# Peek at the schema to determine which column values, if any,
# require sanitization.
columns_to_sanitize: List[PlacedSanitizeFunc] = _find_columns_to_sanitize(cursor)
columns_to_sanitize: List[PlacedSanitizeFunc] = _find_columns_to_sanitize(
cursor
)

if columns_to_sanitize:
# At least 1 column has to be sanitized.
def do_nothing(x):
def do_nothing(x):
return x

sanitization_functions = [do_nothing for _ in range(len(cursor.description))]

sanitization_functions = [
do_nothing for _ in range(len(cursor.description))
]
for info in columns_to_sanitize:
sanitization_functions[info.column_index] = info.sanitize_func

# Rows from pyocient are given as NamedTuple, so we need to recreate the whole table
rows = [[sanitization_functions[i](row[i]) for i in range(len(row))] for row in rows]
rows = [
[sanitization_functions[i](row[i]) for i in range(len(row))]
for row in rows
]
return rows


@classmethod
def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
# Return a Non-None value
# If None is returned, Superset will not call cancel_query
return 'DUMMY_VALUE'

return "DUMMY_VALUE"

@classmethod
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
with OcientEngineSpec.query_id_mapping_lock:
OcientEngineSpec.query_id_mapping[query.id] = cursor.query_id
OcientEngineSpec.query_id_mapping[query.id] = cursor.query_id

# Add the query id to the cursor
setattr(cursor, "superset_query_id", query.id)
return super().handle_cursor(cursor, query, session)



@classmethod
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
with OcientEngineSpec.query_id_mapping_lock:
if query.id in OcientEngineSpec.query_id_mapping:
cursor.execute(f'CANCEL {OcientEngineSpec.query_id_mapping[query.id]}')
cursor.execute(f"CANCEL {OcientEngineSpec.query_id_mapping[query.id]}")
# Query has been cancelled, so we can safely remove the cursor from the cache
del OcientEngineSpec.query_id_mapping[query.id]

return True
# If the query is not in the cache, it must have either been cancelled elsewhere or completed
else:
return False
return False

0 comments on commit 6979aed

Please sign in to comment.