Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SET commands #673

Merged
merged 1 commit into from
Nov 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ class Tokenizer(tokens.Tokenizer):
"_UTF32": TokenType.INTRODUCER,
"_UTF8MB3": TokenType.INTRODUCER,
"_UTF8MB4": TokenType.INTRODUCER,
"@@": TokenType.SESSION_PARAMETER,
}

class Parser(parser.Parser):
Expand Down Expand Up @@ -246,6 +247,17 @@ class Parser(parser.Parser):
"WARNINGS": _show_parser("WARNINGS"),
}

SET_PARSERS = {
"GLOBAL": lambda self: self._parse_set_item_kind("GLOBAL"),
"PERSIST": lambda self: self._parse_set_item_kind("PERSIST"),
"PERSIST_ONLY": lambda self: self._parse_set_item_kind("PERSIST_ONLY"),
"SESSION": lambda self: self._parse_set_item_kind("SESSION"),
"LOCAL": lambda self: self._parse_set_item_kind("LOCAL"),
"CHARACTER SET": lambda self: self._parse_set_item_kind("CHARACTER SET"),
"CHARSET": lambda self: self._parse_set_item_kind("CHARACTER SET"),
"NAMES": lambda self: self._parse_set_item_names(),
}

PROFILE_TYPES = {
"ALL",
"BLOCK IO",
Expand Down Expand Up @@ -329,6 +341,28 @@ def _parse_oldstyle_limit(self):
offset = parts[0]
return offset, limit

def _parse_set_item_kind(self, kind):
this = self._parse_statement()

return self.expression(
exp.SetItem,
this=this,
kind=kind,
)

def _parse_set_item_names(self):
charset = self._parse_string() or self._parse_id_var()
if self._match_text("COLLATE"):
collate = self._parse_string()
else:
collate = None
return self.expression(
exp.SetItem,
this=charset,
collate=collate,
kind="NAMES",
)

class Generator(generator.Generator):
NULL_ORDERING_SUPPORTED = False

Expand Down Expand Up @@ -413,3 +447,14 @@ def _oldstyle_limit_sql(self, expression):
limit_offset = f"{offset}, {limit}" if offset else limit
return f" LIMIT {limit_offset}"
return ""

def setitem_sql(self, expression):
kind = self.sql(expression, "kind")
kind = f"{kind} " if kind else ""
this = self.sql(expression, "this")
collate = self.sql(expression, "collate")
collate = f" COLLATE {collate}" if collate else ""
return f"{kind}{this}{collate}"

def set_sql(self, expression):
return f"SET {self.expressions(expression)}"
16 changes: 16 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,18 @@ class Describe(Expression):
pass


class Set(Expression):
arg_types = {"expressions": True}


class SetItem(Expression):
arg_types = {
"this": True,
"kind": False,
"collate": False, # MySQL SET NAMES statement
}


class Show(Expression):
arg_types = {
"this": True,
Expand Down Expand Up @@ -1933,6 +1945,10 @@ class Parameter(Expression):
pass


class SessionParameter(Expression):
arg_types = {"this": True, "kind": False}


class Placeholder(Expression):
arg_types = {"this": False}

Expand Down
13 changes: 13 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,13 @@ def structkwarg_sql(self, expression):
def parameter_sql(self, expression):
return f"@{self.sql(expression, 'this')}"

def sessionparameter_sql(self, expression):
this = self.sql(expression, "this")
kind = expression.text("kind")
if kind:
kind = f"{kind}."
return f"@@{kind}{this}"

def placeholder_sql(self, expression):
return f":{expression.name}" if expression.name else "?"

Expand Down Expand Up @@ -1246,6 +1253,12 @@ def use_sql(self, expression):
def show_sql(self, expression):
return f"SHOW {self.sql(expression, 'this')}"

def setitem_sql(self, expression):
return self.sql(expression, "this")

def set_sql(self, expression):
return f"SET {self.expressions(expression)}"

def binary(self, expression, op):
return f"{self.sql(expression, 'this')} {op} {self.sql(expression, 'expression')}"

Expand Down
47 changes: 42 additions & 5 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class _Parser(type):
def __new__(cls, clsname, bases, attrs):
klass = super().__new__(cls, clsname, bases, attrs)
klass._show_trie = new_trie(key.split(" ") for key in klass.SHOW_PARSERS)
klass._set_trie = new_trie(key.split(" ") for key in klass.SET_PARSERS)
return klass


Expand Down Expand Up @@ -376,6 +377,7 @@ class Parser(metaclass=_Parser):
TokenType.UNCACHE: lambda self: self._parse_uncache(),
TokenType.USE: lambda self: self._parse_use(),
TokenType.SHOW: lambda self: self._parse_show(),
TokenType.SET: lambda self: self._parse_set(),
}

PRIMARY_PARSERS = {
Expand All @@ -394,6 +396,7 @@ class Parser(metaclass=_Parser):
TokenType.HEX_STRING: lambda _, token: exp.HexString(this=token.text),
TokenType.BYTE_STRING: lambda _, token: exp.ByteString(this=token.text),
TokenType.INTRODUCER: lambda self, token: self._parse_introducer(token),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
}

RANGE_PARSERS = {
Expand Down Expand Up @@ -487,6 +490,7 @@ class Parser(metaclass=_Parser):
}

SHOW_PARSERS: t.Dict[str, t.Callable] = {}
SET_PARSERS: t.Dict[str, t.Callable] = {}

MODIFIABLES = (exp.Subquery, exp.Subqueryable, exp.Table)

Expand Down Expand Up @@ -518,6 +522,7 @@ class Parser(metaclass=_Parser):
"_next",
"_prev",
"_show_trie",
"_set_trie",
)

def __init__(
Expand Down Expand Up @@ -1978,6 +1983,18 @@ def _parse_introducer(self, token):

return self.expression(exp.Identifier, this=token.text)

def _parse_session_parameter(self):
kind = None
this = self._parse_id_var() or self._parse_primary()
if self._match(TokenType.DOT):
kind = this.name
this = self._parse_var() or self._parse_primary()
return self.expression(
exp.SessionParameter,
this=this,
kind=kind,
)

def _parse_udf_kwarg(self):
this = self._parse_id_var()
kind = self._parse_types()
Expand Down Expand Up @@ -2536,8 +2553,28 @@ def _parse_use(self):
return self.expression(exp.Use, this=self._parse_id_var())

def _parse_show(self):
parser = self._find_parser(self.SHOW_PARSERS, self._show_trie)
if parser:
return parser(self)
self._advance()
return self.expression(exp.Show, this=self._prev.text.upper())

def _default_parse_set_item(self):
return self.expression(
exp.SetItem,
this=self._parse_statement(),
)

def _parse_set_item(self):
parser = self._find_parser(self.SET_PARSERS, self._set_trie)
return parser(self) if parser else self._default_parse_set_item()

def _parse_set(self):
return self.expression(exp.Set, expressions=self._parse_csv(self._parse_set_item))

def _find_parser(self, parsers, trie):
index = self._index
this = []
trie = self._show_trie
while True:
# The current token might be multiple words
key = self._curr.text.split(" ")
Expand All @@ -2547,10 +2584,10 @@ def _parse_show(self):
if result == 0:
break
if result == 2:
subparser = self.SHOW_PARSERS[" ".join(this)]
return subparser(self)

return self.expression(exp.Show, this=" ".join(this))
subparser = parsers[" ".join(this)]
return subparser
self._retreat(index)
return None

def _match(self, token_type):
if not self._curr:
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class TokenType(AutoName):
ANNOTATION = auto()
DOLLAR = auto()
PARAMETER = auto()
SESSION_PARAMETER = auto()

SPACE = auto()
BREAK = auto()
Expand Down Expand Up @@ -674,7 +675,6 @@ class Tokenizer(metaclass=_Tokenizer):
TokenType.COMMIT,
TokenType.EXPLAIN,
TokenType.OPTIMIZE,
TokenType.SET,
TokenType.TRUNCATE,
TokenType.VACUUM,
TokenType.ROLLBACK,
Expand Down
9 changes: 6 additions & 3 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
class Validator(unittest.TestCase):
dialect = None

def parse_one(self, sql):
return parse_one(sql, read=self.dialect)

def validate_identity(self, sql, write_sql=None):
expression = parse_one(sql, read=self.dialect)
self.assertEqual(expression.sql(dialect=self.dialect), write_sql or sql)
expression = self.parse_one(sql)
self.assertEqual(write_sql or sql, expression.sql(dialect=self.dialect))
return expression

def validate_all(self, sql, read=None, write=None, pretty=False):
Expand All @@ -23,7 +26,7 @@ def validate_all(self, sql, read=None, write=None, pretty=False):
read (dict): Mapping of dialect -> SQL
write (dict): Mapping of dialect -> SQL
"""
expression = parse_one(sql, read=self.dialect)
expression = self.parse_one(sql)

for read_dialect, read_sql in (read or {}).items():
with self.subTest(f"{read_dialect} -> {sql}"):
Expand Down
63 changes: 63 additions & 0 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,40 @@ def test_identity(self):
self.validate_identity("SELECT TRIM(TRAILING 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM(BOTH 'bla' FROM ' XXX ')")
self.validate_identity("SELECT TRIM('bla' FROM ' XXX ')")
self.validate_identity("@@GLOBAL.max_connections")

# SET Commands
self.validate_identity("SET @var_name = expr")
self.validate_identity("SET @name = 43")
self.validate_identity("SET @total_tax = (SELECT SUM(tax) FROM taxable_transactions)")
self.validate_identity("SET GLOBAL max_connections = 1000")
self.validate_identity("SET @@GLOBAL.max_connections = 1000")
self.validate_identity("SET SESSION sql_mode = 'TRADITIONAL'")
self.validate_identity("SET LOCAL sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@SESSION.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@LOCAL.sql_mode = 'TRADITIONAL'")
self.validate_identity("SET @@sql_mode = 'TRADITIONAL'")
self.validate_identity("SET sql_mode = 'TRADITIONAL'")
self.validate_identity("SET PERSIST max_connections = 1000")
self.validate_identity("SET @@PERSIST.max_connections = 1000")
self.validate_identity("SET PERSIST_ONLY back_log = 100")
self.validate_identity("SET @@PERSIST_ONLY.back_log = 100")
self.validate_identity("SET @@SESSION.max_join_size = DEFAULT")
self.validate_identity("SET @@SESSION.max_join_size = @@GLOBAL.max_join_size")
self.validate_identity("SET @x = 1, SESSION sql_mode = ''")
self.validate_identity(
"SET GLOBAL sort_buffer_size = 1000000, SESSION sort_buffer_size = 1000000"
)
self.validate_identity(
"SET @@GLOBAL.sort_buffer_size = 1000000, @@LOCAL.sort_buffer_size = 1000000"
)
self.validate_identity("SET GLOBAL max_connections = 1000, sort_buffer_size = 1000000")
self.validate_identity("SET @@GLOBAL.sort_buffer_size = 50000, sort_buffer_size = 1000000")
self.validate_identity("SET CHARACTER SET 'utf8'")
self.validate_identity("SET CHARACTER SET DEFAULT")
self.validate_identity("SET NAMES 'utf8'")
self.validate_identity("SET NAMES DEFAULT")
self.validate_identity("SET NAMES 'utf8' COLLATE 'utf8_unicode_ci'")

def test_escape(self):
self.validate_all(
Expand Down Expand Up @@ -393,3 +427,32 @@ def test_show_tables(self):
self.assertEqual(show.text("db"), "db_name")
self.assertIsInstance(show.args["like"], exp.Literal)
self.assertEqual(show.text("like"), "%foo%")

def test_set_variable(self):
cmd = self.parse_one("SET SESSION x = 1")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "SESSION")
self.assertIsInstance(item.this, exp.EQ)
self.assertEqual(item.this.left.name, "x")
self.assertEqual(item.this.right.name, "1")

cmd = self.parse_one("SET @@GLOBAL.x = @@GLOBAL.y")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "")
self.assertIsInstance(item.this, exp.EQ)
self.assertIsInstance(item.this.left, exp.SessionParameter)
self.assertIsInstance(item.this.right, exp.SessionParameter)

cmd = self.parse_one("SET NAMES 'charset_name' COLLATE 'collation_name'")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "NAMES")
self.assertEqual(item.name, "charset_name")
self.assertEqual(item.text("collate"), "collation_name")

cmd = self.parse_one("SET CHARSET DEFAULT")
item = cmd.expressions[0]
self.assertEqual(item.text("kind"), "CHARACTER SET")
self.assertEqual(item.this.name, "DEFAULT")

cmd = self.parse_one("SET x = 1, y = 2")
self.assertEqual(len(cmd.expressions), 2)