Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Refactor: preserve the full text of hex/bin literals #1552

Merged
merged 6 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ["`"]
STRING_ESCAPES = ["\\"]
HEX_STRINGS = [("0x", ""), ("0X", "")]
BYTE_STRINGS = [("b'", "'"), ("B'", "'")]

KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ class ClickHouse(Dialect):
class Tokenizer(tokens.Tokenizer):
COMMENTS = ["--", "#", "#!", ("/*", "*/")]
IDENTIFIERS = ['"', "`"]
BIT_STRINGS = [("0b", "")]
HEX_STRINGS = [("0x", ""), ("0X", "")]

KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class Generator(Spark.Generator):
PARAMETER_TOKEN = "$"

class Tokenizer(Spark.Tokenizer):
HEX_STRINGS = []

SINGLE_TOKENS = {
**Spark.Tokenizer.SINGLE_TOKENS,
"$": TokenType.PARAMETER,
Expand Down
41 changes: 17 additions & 24 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,30 +70,17 @@ def __new__(cls, clsname, bases, attrs):
klass.tokenizer_class._IDENTIFIERS.items()
)[0]

if (
klass.tokenizer_class._BIT_STRINGS
and exp.BitString not in klass.generator_class.TRANSFORMS
):
bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.BitString
] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}"
if (
klass.tokenizer_class._HEX_STRINGS
and exp.HexString not in klass.generator_class.TRANSFORMS
):
hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.HexString
] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}"
if (
klass.tokenizer_class._BYTE_STRINGS
and exp.ByteString not in klass.generator_class.TRANSFORMS
):
be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0]
klass.generator_class.TRANSFORMS[
exp.ByteString
] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}"
klass.bit_start, klass.bit_end = seq_get(
list(klass.tokenizer_class._BIT_STRINGS.items()), 0
) or (None, None)

klass.hex_start, klass.hex_end = seq_get(
list(klass.tokenizer_class._HEX_STRINGS.items()), 0
) or (None, None)

klass.byte_start, klass.byte_end = seq_get(
list(klass.tokenizer_class._BYTE_STRINGS.items()), 0
) or (None, None)

return klass

Expand Down Expand Up @@ -199,6 +186,12 @@ def generator(self, **opts) -> Generator:
**{
"quote_start": self.quote_start,
"quote_end": self.quote_end,
"bit_start": self.bit_start,
"bit_end": self.bit_end,
"hex_start": self.hex_start,
"hex_end": self.hex_end,
"byte_start": self.byte_start,
"byte_end": self.byte_end,
"identifier_start": self.identifier_start,
"identifier_end": self.identifier_end,
"string_escape": self.tokenizer_class.STRING_ESCAPES[0],
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]:
return this

class Tokenizer(Postgres.Tokenizer):
BIT_STRINGS = []
HEX_STRINGS = []
STRING_ESCAPES = ["\\"]

KEYWORDS = {
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression:
class Tokenizer(tokens.Tokenizer):
QUOTES = ["'", "$$"]
STRING_ESCAPES = ["\\", "'"]
HEX_STRINGS = [("x'", "'"), ("X'", "'")]

KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,8 @@ class TSQL(Dialect):

class Tokenizer(tokens.Tokenizer):
IDENTIFIERS = ['"', ("[", "]")]

QUOTES = ["'", '"']
HEX_STRINGS = [("0x", ""), ("0X", "")]

KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
Expand Down
39 changes: 36 additions & 3 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ class Generator:
quote_end (str): specifies which ending character to use to delimit quotes. Default: '.
identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ".
identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ".
bit_start (str): specifies which starting character to use to delimit bit literals. Default: None.
bit_end (str): specifies which ending character to use to delimit bit literals. Default: None.
hex_start (str): specifies which starting character to use to delimit hex literals. Default: None.
hex_end (str): specifies which ending character to use to delimit hex literals. Default: None.
byte_start (str): specifies which starting character to use to delimit byte literals. Default: None.
byte_end (str): specifies which ending character to use to delimit byte literals. Default: None.
identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always.
normalize (bool): if set to True, all identifiers will be lowercased
string_escape (str): specifies a string escape character. Default: '.
Expand Down Expand Up @@ -227,6 +233,12 @@ class Generator:
"quote_end",
"identifier_start",
"identifier_end",
"bit_start",
"bit_end",
"hex_start",
"hex_end",
"byte_start",
"byte_end",
"identify",
"normalize",
"string_escape",
Expand Down Expand Up @@ -258,6 +270,12 @@ def __init__(
quote_end=None,
identifier_start=None,
identifier_end=None,
bit_start=None,
bit_end=None,
hex_start=None,
hex_end=None,
byte_start=None,
byte_end=None,
identify=False,
normalize=False,
string_escape=None,
Expand All @@ -284,6 +302,12 @@ def __init__(
self.quote_end = quote_end or "'"
self.identifier_start = identifier_start or '"'
self.identifier_end = identifier_end or '"'
self.bit_start = bit_start
self.bit_end = bit_end
self.hex_start = hex_start
self.hex_end = hex_end
self.byte_start = byte_start
self.byte_end = byte_end
self.identify = identify
self.normalize = normalize
self.string_escape = string_escape or "'"
Expand Down Expand Up @@ -716,13 +740,22 @@ def tablealias_sql(self, expression: exp.TableAlias) -> str:
return f"{alias}{columns}"

def bitstring_sql(self, expression: exp.BitString) -> str:
return self.sql(expression, "this")
this = self.sql(expression, "this")
if self.bit_start:
return f"{self.bit_start}{this}{self.bit_end}"
return f"{int(this, 2)}"

def hexstring_sql(self, expression: exp.HexString) -> str:
return self.sql(expression, "this")
this = self.sql(expression, "this")
if self.hex_start:
return f"{self.hex_start}{this}{self.hex_end}"
return f"{int(this, 16)}"

def bytestring_sql(self, expression: exp.ByteString) -> str:
return self.sql(expression, "this")
this = self.sql(expression, "this")
if self.byte_start:
return f"{self.byte_start}{this}{self.byte_end}"
return this

def datatype_sql(self, expression: exp.DataType) -> str:
type_value = expression.this
Expand Down
21 changes: 12 additions & 9 deletions sqlglot/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,9 +988,9 @@ def _scan_number(self) -> None:
if self._char == "0":
peek = self._peek.upper()
if peek == "B":
return self._scan_bits()
return self._scan_bits() if self._BIT_STRINGS else self._add(TokenType.NUMBER)
elif peek == "X":
return self._scan_hex()
return self._scan_hex() if self._HEX_STRINGS else self._add(TokenType.NUMBER)

decimal = False
scientific = 0
Expand Down Expand Up @@ -1033,15 +1033,19 @@ def _scan_bits(self) -> None:
self._advance()
value = self._extract_value()
try:
self._add(TokenType.BIT_STRING, f"{int(value, 2)}")
# If `value` can't be converted to a binary, fallback to tokenizing it as an identifier
int(value, 2)
self._add(TokenType.BIT_STRING, value[2:]) # Drop the 0b
tobymao marked this conversation as resolved.
Show resolved Hide resolved
except ValueError:
self._add(TokenType.IDENTIFIER)

def _scan_hex(self) -> None:
self._advance()
value = self._extract_value()
try:
self._add(TokenType.HEX_STRING, f"{int(value, 16)}")
# If `value` can't be converted to a hex, fallback to tokenizing it as an identifier
int(value, 16)
self._add(TokenType.HEX_STRING, value[2:]) # Drop the 0x
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
except ValueError:
self._add(TokenType.IDENTIFIER)

Expand All @@ -1066,7 +1070,7 @@ def _scan_string(self, quote: str) -> bool:
self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text)
return True

# X'1234, b'0110', E'\\\\\' etc.
# X'1234', b'0110', E'\\\\\' etc.
def _scan_formatted_string(self, string_start: str) -> bool:
if string_start in self._HEX_STRINGS:
delimiters = self._HEX_STRINGS
Expand All @@ -1087,16 +1091,15 @@ def _scan_formatted_string(self, string_start: str) -> bool:
string_end = delimiters[string_start]
text = self._extract_string(string_end)

if base is None:
self._add(token_type, text)
else:
if base:
try:
self._add(token_type, f"{int(text, base)}")
int(text, base)
except:
raise RuntimeError(
f"Numeric string contains invalid characters from {self._line}:{self._start}"
)

self._add(token_type, text)
return True

def _scan_identifier(self, identifier_end: str) -> None:
Expand Down
7 changes: 4 additions & 3 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ class TestBigQuery(Validator):
dialect = "bigquery"

def test_bigquery(self):
self.validate_identity(
"""CREATE TABLE x (a STRING OPTIONS (description='x')) OPTIONS (table_expiration_days=1)"""
)
self.validate_identity("SELECT b'abc'")
self.validate_identity("""SELECT * FROM UNNEST(ARRAY<STRUCT<x INT64>>[1, 2])""")
self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b")
self.validate_identity("SELECT AS VALUE STRUCT(1 AS a, 2 AS b)")
self.validate_identity("SELECT STRUCT<ARRAY<STRING>>(['2023-01-17'])")
self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))")
self.validate_identity(
"""CREATE TABLE x (a STRING OPTIONS (description='x')) OPTIONS (table_expiration_days=1)"""
)
self.validate_identity(
"SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))"
)
Expand Down
8 changes: 5 additions & 3 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,20 @@ def test_duckdb(self):
self.validate_identity("SELECT {'a': 1} AS x")
self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x")
self.validate_identity("SELECT {'x': 1, 'y': 2, 'z': 3}")
self.validate_identity(
"SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}"
)
self.validate_identity("SELECT {'key1': 'string', 'key2': 1, 'key3': 12.345}")
self.validate_identity("SELECT ROW(x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)")
self.validate_identity("SELECT (x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)")
self.validate_identity("SELECT a.x FROM (SELECT {'x': 1, 'y': 2, 'z': 3} AS a)")
self.validate_identity("ATTACH DATABASE ':memory:' AS new_database")
self.validate_identity(
"SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}"
)
self.validate_identity(
"SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)"
)

self.validate_all("0b1010", write={"": "0 AS b1010"})
self.validate_all("0x1010", write={"": "0 AS x1010"})
self.validate_all(
"""SELECT DATEDIFF('day', t1."A", t1."B") FROM "table" AS t1""",
write={
Expand Down
Loading