Skip to content

Commit

Permalink
Merge branch 'master' into check_with
Browse files Browse the repository at this point in the history
  • Loading branch information
macbre authored Jun 10, 2021
2 parents a2ed841 + bbe8d25 commit 7336e45
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 4 deletions.
19 changes: 16 additions & 3 deletions sql_metadata/parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
This module provides SQL query parsing functions
"""
import logging
import re
from typing import Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -29,6 +30,8 @@ class Parser: # pylint: disable=R0902
"""

def __init__(self, sql: str = "") -> None:
self._logger = logging.getLogger(self.__class__.__name__)

self._raw_query = sql
self._query = self._preprocess_query()
self._query_type = None
Expand Down Expand Up @@ -85,12 +88,22 @@ def query_type(self) -> str:
return self._query_type
if not self._tokens:
_ = self.tokens
if self._tokens[0].normalized in ["CREATE", "ALTER"]:
switch = self._tokens[0].normalized + self._tokens[1].normalized

# remove comment tokens to not confuse the logic below (see #163)
tokens: List[SQLToken] = list(
filter(lambda token: not token.is_comment, self._tokens or [])
)

if not tokens:
raise ValueError("Empty queries are not supported!")

if tokens[0].normalized in ["CREATE", "ALTER"]:
switch = tokens[0].normalized + tokens[1].normalized
else:
switch = self._tokens[0].normalized
switch = tokens[0].normalized
self._query_type = SUPPORTED_QUERY_TYPES.get(switch, "UNSUPPORTED")
if self._query_type == "UNSUPPORTED":
self._logger.error("Not supported query type: %s", self._raw_query)
raise ValueError("Not supported query type!")
return self._query_type

Expand Down
12 changes: 12 additions & 0 deletions test/test_getting_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,3 +444,15 @@ def test_get_tables_with_leading_digits():
assert ["0020_big_table"] == Parser(
"SELECT t.val as value, count(*) FROM 0020_big_table"
).tables


def test_insert_ignore_with_comments():
queries = [
"INSERT IGNORE /* foo */ INTO bar VALUES (1, '123', '2017-01-01');",
"/* foo */ INSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');"
"-- foo\nINSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');"
"# foo\nINSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');",
]

for query in queries:
assert ["bar"] == Parser(query).tables
50 changes: 50 additions & 0 deletions test/test_query_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest

from sql_metadata import Parser


def test_insert_query():
queries = [
"INSERT IGNORE /* foo */ INTO bar VALUES (1, '123', '2017-01-01');",
"/* foo */ INSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');"
"-- foo\nINSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');"
"# foo\nINSERT IGNORE INTO bar VALUES (1, '123', '2017-01-01');",
]

for query in queries:
assert "INSERT" == Parser(query).query_type


def test_select_query():
queries = [
"SELECT /* foo */ foo FROM bar",
"/* foo */ SELECT foo FROM bar"
"-- foo\nSELECT foo FROM bar"
"# foo\nSELECT foo FROM bar",
]

for query in queries:
assert "SELECT" == Parser(query).query_type


def test_unsupported_query():
queries = [
"FOO BAR",
"DO SOMETHING",
]

for query in queries:
with pytest.raises(ValueError) as ex:
_ = Parser(query).query_type

assert "Not supported query type!" in str(ex.value)


def test_empty_query():
queries = ["", "/* empty query */"]

for query in queries:
with pytest.raises(ValueError) as ex:
_ = Parser(query).query_type

assert "Empty queries are not supported!" in str(ex.value)
3 changes: 2 additions & 1 deletion test/test_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ def test_getting_values():
}

parser = Parser(
"INSERT IGNORE INTO `0070_insert_ignore_table` VALUES (9, 2.15, '123', '2017-01-01');"
"/* method */ INSERT IGNORE INTO `0070_insert_ignore_table` VALUES (9, 2.15, '123', '2017-01-01');"
)
assert parser.query_type == "INSERT"
assert parser.values == [9, 2.15, "123", "2017-01-01"]
assert parser.values_dict == {
"column_1": 9,
Expand Down

0 comments on commit 7336e45

Please sign in to comment.