Skip to content

Commit

Permalink
Refactor validation decorators (pinterest#1354)
Browse files Browse the repository at this point in the history
* Refactor validation decorators
  • Loading branch information
kgopal492 authored and aidenprice committed Jan 3, 2024
1 parent 1c9aa1d commit 974abff
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 102 deletions.
11 changes: 9 additions & 2 deletions querybook/server/lib/elasticsearch/search_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def get_column_name_suggestion(
return get_matching_objects(search_query, ES_CONFIG["tables"]["index_name"], True)


def get_table_name_suggestion(fuzzy_table_name: str) -> Tuple[Dict, int]:
def get_table_name_suggestion(
fuzzy_table_name: str, metastore_id: int
) -> Tuple[Dict, int]:
"""Given an invalid table name use fuzzy search to search the correctly-spelled table name"""

schema_name, fuzzy_name = None, fuzzy_table_name
Expand All @@ -229,7 +231,12 @@ def get_table_name_suggestion(fuzzy_table_name: str) -> Tuple[Dict, int]:
{
"match": {
"name": {"query": fuzzy_name, "fuzziness": "AUTO"},
}
},
},
{
"match": {
"metastore_id": metastore_id,
},
},
]
if schema_name:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Any, Dict, List, Tuple
from typing import List, Tuple
from sqlglot import Tokenizer
from sqlglot.tokens import Token

Expand All @@ -8,13 +8,12 @@
QueryValidationResultObjectType,
QueryValidationSeverity,
)
from lib.query_analysis.validation.base_query_validator import BaseQueryValidator

from lib.query_analysis.validation.decorators.base_validation_decorator import (
BaseValidationDecorator,
)

class BaseSQLGlotValidator(BaseQueryValidator):
def __init__(self, name: str = "", config: Dict[str, Any] = {}):
super(BaseSQLGlotValidator, self).__init__(name, config)

class BaseSQLGlotValidationDecorator(BaseValidationDecorator):
@property
@abstractmethod
def message(self) -> str:
Expand Down Expand Up @@ -65,7 +64,6 @@ def _get_query_validation_result(
suggestion=suggestion,
)

@abstractmethod
def validate(
self,
query: str,
Expand All @@ -74,20 +72,8 @@ def validate(
raw_tokens: List[Token] = None,
**kwargs,
) -> List[QueryValidationResult]:
raise NotImplementedError()


class BaseSQLGlotDecorator(BaseSQLGlotValidator):
def __init__(self, validator: BaseQueryValidator):
self._validator = validator

def validate(
self,
query: str,
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
**kwargs,
):
"""Override this method to add suggestions to validation results"""
return self._validator.validate(query, uid, engine_id, **kwargs)
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
return super(BaseSQLGlotValidationDecorator, self).validate(
query, uid, engine_id, raw_tokens=raw_tokens, **kwargs
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from abc import ABCMeta, abstractmethod
from typing import List

from lib.query_analysis.validation.base_query_validator import (
QueryValidationResult,
)
from lib.query_analysis.validation.base_query_validator import BaseQueryValidator


class BaseValidationDecorator(metaclass=ABCMeta):
    """Wrap a query validator and post-process the results it returns.

    ``validate`` first runs the wrapped validator, then hands its results to
    ``decorate_validation_results``, which concrete subclasses must implement.
    """

    def __init__(self, validator: BaseQueryValidator):
        # The wrapped validator; its ``validate`` output is what gets decorated.
        self._validator = validator

    @abstractmethod
    def decorate_validation_results(
        self,
        validation_results: List[QueryValidationResult],
        query: str,
        uid: int,
        engine_id: int,
        **kwargs,
    ) -> List[QueryValidationResult]:
        """Return the validation results after this decorator has processed them.

        Args:
            validation_results: Results produced by the wrapped validator.
            query: The query text that was validated.
            uid: Passed through unchanged from ``validate``.
            engine_id: Passed through unchanged from ``validate``.
            **kwargs: Extra keyword arguments forwarded from ``validate``.
        """
        raise NotImplementedError()

    def validate(
        self,
        query: str,
        uid: int,
        engine_id: int,
        **kwargs,
    ) -> List[QueryValidationResult]:
        """Run the wrapped validator, then decorate whatever it returned.

        The same positional arguments and ``kwargs`` are given to both the
        wrapped validator and ``decorate_validation_results``.
        """
        call_args = (query, uid, engine_id)
        wrapped_results = self._validator.validate(*call_args, **kwargs)
        return self.decorate_validation_results(
            wrapped_results, *call_args, **kwargs
        )
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,14 @@
from lib.query_analysis.lineage import process_query
from lib.query_analysis.validation.base_query_validator import (
QueryValidationResult,
QueryValidationSeverity,
)
from lib.query_analysis.validation.validators.base_sqlglot_validator import (
BaseSQLGlotDecorator,
from lib.query_analysis.validation.decorators.base_sqlglot_validation_decorator import (
BaseValidationDecorator,
)
from logic.admin import get_query_engine_by_id
from logic import admin as admin_logic


class BaseColumnNameSuggester(BaseSQLGlotDecorator):
@property
def severity(self):
return QueryValidationSeverity.WARNING # Unused, severity is not changed

@property
def message(self):
return "" # Unused, message is not changed

class BaseColumnNameSuggester(BaseValidationDecorator):
@abstractmethod
def get_column_name_from_error(
self, validation_result: QueryValidationResult
Expand All @@ -32,7 +23,7 @@ def get_column_name_from_error(
raise NotImplementedError()

def _get_tables_in_query(self, query: str, engine_id: int) -> List[str]:
engine = get_query_engine_by_id(engine_id)
engine = admin_logic.get_query_engine_by_id(engine_id)
tables_per_statement, _ = process_query(query, language=engine.language)
return list(chain.from_iterable(tables_per_statement))

Expand Down Expand Up @@ -69,49 +60,43 @@ def _suggest_column_name_if_needed(
validation_result.start_ch + len(fuzzy_column_name) - 1
)

def validate(
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
raw_tokens: List[QueryValidationResult] = None,
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
tables_in_query = self._get_tables_in_query(query, engine_id)
for result in validation_results:
self._suggest_column_name_if_needed(result, tables_in_query)
return validation_results


class BaseTableNameSuggester(BaseSQLGlotDecorator):
@property
def severity(self):
return QueryValidationSeverity.WARNING # Unused, severity is not changed

@property
def message(self):
return "" # Unused, message is not changed

class BaseTableNameSuggester(BaseValidationDecorator):
@abstractmethod
def get_full_table_name_from_error(self, validation_result: QueryValidationResult):
"""Returns invalid table name if the validation result is a table name error, otherwise
returns None"""
raise NotImplementedError()

def _suggest_table_name_if_needed(
self, validation_result: QueryValidationResult
self,
validation_result: QueryValidationResult,
engine_id: int,
) -> Optional[str]:
"""Takes validation result and tables in query to update validation result to provide table
name suggestion"""
fuzzy_table_name = self.get_full_table_name_from_error(validation_result)
if not fuzzy_table_name:
return
results, count = search_table.get_table_name_suggestion(fuzzy_table_name)
metastore_id = admin_logic.get_query_metastore_id_by_engine_id(engine_id)
if metastore_id is None:
return
results, count = search_table.get_table_name_suggestion(
fuzzy_table_name, metastore_id
)
if count > 0:
table_result = results[0] # Get top match
table_suggestion = f"{table_result['schema']}.{table_result['name']}"
Expand All @@ -121,19 +106,14 @@ def _suggest_table_name_if_needed(
validation_result.start_ch + len(fuzzy_table_name) - 1
)

def validate(
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
raw_tokens: List[QueryValidationResult] = None,
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
for result in validation_results:
self._suggest_table_name_if_needed(result)
self._suggest_table_name_if_needed(result, engine_id)
return validation_results
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,16 @@
from lib.query_analysis.validation.validators.presto_explain_validator import (
PrestoExplainValidator,
)
from lib.query_analysis.validation.validators.base_sqlglot_validator import (
BaseSQLGlotDecorator,
from lib.query_analysis.validation.decorators.base_sqlglot_validation_decorator import (
BaseSQLGlotValidationDecorator,
)
from lib.query_analysis.validation.validators.metadata_suggesters import (
from lib.query_analysis.validation.decorators.metadata_decorators import (
BaseColumnNameSuggester,
BaseTableNameSuggester,
)


class BasePrestoSQLGlotDecorator(BaseSQLGlotDecorator):
def languages(self):
return ["presto", "trino"]

class BasePrestoSQLGlotDecorator(BaseSQLGlotValidationDecorator):
@property
def tokenizer(self) -> Tokenizer:
return Trino.Tokenizer()
Expand All @@ -39,19 +36,15 @@ def message(self):
def severity(self) -> str:
return QueryValidationSeverity.WARNING

def validate(
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
raw_tokens: List[Token] = [],
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)
validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
for i, token in enumerate(raw_tokens):
if token.token_type == TokenType.UNION:
if (
Expand All @@ -77,20 +70,15 @@ def message(self):
def severity(self) -> str:
return QueryValidationSeverity.WARNING

def validate(
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
raw_tokens: List[Token] = [],
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)

validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)
for i, token in enumerate(raw_tokens):
if (
i < len(raw_tokens) - 2
Expand Down Expand Up @@ -125,21 +113,15 @@ def _get_regexp_like_suggestion(self, column_name: str, like_strings: List[str])
]
return f"REGEXP_LIKE({column_name}, '{'|'.join(sanitized_like_strings)}')"

def validate(
def decorate_validation_results(
self,
validation_results: List[QueryValidationResult],
query: str,
uid: int,
engine_id: int,
raw_tokens: List[Token] = None,
raw_tokens: List[Token] = [],
**kwargs,
) -> List[QueryValidationResult]:
if raw_tokens is None:
raw_tokens = self._tokenize_query(query)

validation_results = self._validator.validate(
query, uid, engine_id, raw_tokens=raw_tokens
)

start_column_token = None
like_strings = []
token_idx = 0
Expand Down Expand Up @@ -203,15 +185,15 @@ def validate(
return validation_results


class PrestoColumnNameSuggester(BasePrestoSQLGlotDecorator, BaseColumnNameSuggester):
class PrestoColumnNameSuggester(BaseColumnNameSuggester):
def get_column_name_from_error(self, validation_result: QueryValidationResult):
regex_result = re.match(
r"line \d+:\d+: Column '(.*)' cannot be resolved", validation_result.message
)
return regex_result.groups()[0] if regex_result else None


class PrestoTableNameSuggester(BasePrestoSQLGlotDecorator, BaseTableNameSuggester):
class PrestoTableNameSuggester(BaseTableNameSuggester):
def get_full_table_name_from_error(self, validation_result: QueryValidationResult):
regex_result = re.match(
r"line \d+:\d+: Table '(.*)' does not exist", validation_result.message
Expand Down
6 changes: 6 additions & 0 deletions querybook/server/logic/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,12 @@ def get_query_metastore_by_name(name, session=None):
return session.query(QueryMetastore).filter(QueryMetastore.name == name).first()


@with_session
def get_query_metastore_id_by_engine_id(engine_id: int, session=None):
    """Return the ``metastore_id`` of the query engine with the given id.

    Looks the engine up via ``get_query_engine_by_id`` using the same
    ``session``. Returns ``None`` when that lookup yields a falsy value.
    """
    engine = get_query_engine_by_id(engine_id, session=session)
    if not engine:
        return None
    return engine.metastore_id


@with_session
def get_all_query_metastore(session=None):
return session.query(QueryMetastore).all()
Expand Down
Loading

0 comments on commit 974abff

Please sign in to comment.