From b83d1d59e3667b2cabdcdd231068b8f81df96d13 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Thu, 4 May 2023 22:50:07 +0300 Subject: [PATCH 1/6] Feat: refactor hex/bit literals so that their text is preserved --- sqlglot/dialects/dialect.py | 4 ++-- sqlglot/dialects/oracle.py | 2 ++ sqlglot/tokens.py | 17 ++++++++++------- tests/dialects/test_mysql.py | 11 +++++++++++ 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 3b0f9dcf26..1dea96ab8a 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -77,7 +77,7 @@ def __new__(cls, clsname, bases, attrs): bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0] klass.generator_class.TRANSFORMS[ exp.BitString - ] = lambda self, e: f"{bs_start}{int(self.sql(e, 'this')):b}{bs_end}" + ] = lambda self, e: f"{bs_start}{self.sql(e, 'this')}{bs_end}" if ( klass.tokenizer_class._HEX_STRINGS and exp.HexString not in klass.generator_class.TRANSFORMS @@ -85,7 +85,7 @@ def __new__(cls, clsname, bases, attrs): hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0] klass.generator_class.TRANSFORMS[ exp.HexString - ] = lambda self, e: f"{hs_start}{int(self.sql(e, 'this')):X}{hs_end}" + ] = lambda self, e: f"{hs_start}{self.sql(e, 'this')}{hs_end}" if ( klass.tokenizer_class._BYTE_STRINGS and exp.ByteString not in klass.generator_class.TRANSFORMS diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index c8af1c6c08..2347308c59 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -123,10 +123,12 @@ class Generator(generator.Generator): TRANSFORMS = { **generator.Generator.TRANSFORMS, # type: ignore + exp.BitString: lambda self, e: f"{int(e.this, 2)}", exp.DateStrToDate: lambda self, e: self.func( "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD") ), exp.Group: transforms.preprocess([transforms.unalias_group]), + exp.HexString: lambda self, e: f"{int(e.this, 16)}", 
exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.ILike: no_ilike_sql, exp.IfNull: rename_func("NVL"), diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 64c1f92349..990cbb5c49 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -1033,7 +1033,9 @@ def _scan_bits(self) -> None: self._advance() value = self._extract_value() try: - self._add(TokenType.BIT_STRING, f"{int(value, 2)}") + # If `value` can't be converted to a binary, fallback to tokenizing it as an identifier + int(value, 2) + self._add(TokenType.BIT_STRING, value[2:]) # Drop the 0b except ValueError: self._add(TokenType.IDENTIFIER) @@ -1041,7 +1043,9 @@ def _scan_hex(self) -> None: self._advance() value = self._extract_value() try: - self._add(TokenType.HEX_STRING, f"{int(value, 16)}") + # If `value` can't be converted to a hex, fallback to tokenizing it as an identifier + int(value, 16) + self._add(TokenType.HEX_STRING, value[2:]) # Drop the 0x except ValueError: self._add(TokenType.IDENTIFIER) @@ -1066,7 +1070,7 @@ def _scan_string(self, quote: str) -> bool: self._add(TokenType.NATIONAL if quote[0].upper() == "N" else TokenType.STRING, text) return True - # X'1234, b'0110', E'\\\\\' etc. + # X'1234', b'0110', E'\\\\\' etc. 
def _scan_formatted_string(self, string_start: str) -> bool: if string_start in self._HEX_STRINGS: delimiters = self._HEX_STRINGS @@ -1087,16 +1091,15 @@ def _scan_formatted_string(self, string_start: str) -> bool: string_end = delimiters[string_start] text = self._extract_string(string_end) - if base is None: - self._add(token_type, text) - else: + if base: try: - self._add(token_type, f"{int(text, base)}") + int(text, base) except: raise RuntimeError( f"Numeric string contains invalid characters from {self._line}:{self._start}" ) + self._add(token_type, text) return True def _scan_identifier(self, identifier_end: str) -> None: diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 524d95e5c0..2ee23ca7e8 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -161,6 +161,17 @@ def test_hexadecimal_literal(self): "oracle": "SELECT 204", }, ) + self.validate_all( + "SELECT 0x0000CC", + write={ + "mysql": "SELECT x'0000CC'", + "sqlite": "SELECT x'0000CC'", + "spark": "SELECT X'0000CC'", + "trino": "SELECT X'0000CC'", + "bigquery": "SELECT 0x0000CC", + "oracle": "SELECT 204", + }, + ) self.validate_all( "SELECT X'1A'", write={ From 272dfb3ed6ed96f08e5fdfcabbdc00fcd5aab1ec Mon Sep 17 00:00:00 2001 From: George Sittas Date: Fri, 5 May 2023 18:11:28 +0300 Subject: [PATCH 2/6] Convert hex/bin values by default, make meta class attr setting more robust --- sqlglot/dialects/bigquery.py | 1 + sqlglot/dialects/clickhouse.py | 2 + sqlglot/dialects/databricks.py | 2 + sqlglot/dialects/dialect.py | 56 ++++++++------- sqlglot/dialects/oracle.py | 2 - sqlglot/dialects/redshift.py | 2 + sqlglot/dialects/snowflake.py | 1 + sqlglot/dialects/tsql.py | 2 +- sqlglot/generator.py | 9 --- tests/dialects/test_bigquery.py | 7 +- tests/dialects/test_mysql.py | 122 +++++++++++++++++++------------- 11 files changed, 117 insertions(+), 89 deletions(-) diff --git a/sqlglot/dialects/bigquery.py b/sqlglot/dialects/bigquery.py index 
dcd1326034..7f6ad8f7c8 100644 --- a/sqlglot/dialects/bigquery.py +++ b/sqlglot/dialects/bigquery.py @@ -128,6 +128,7 @@ class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ["`"] STRING_ESCAPES = ["\\"] HEX_STRINGS = [("0x", ""), ("0X", "")] + BYTE_STRINGS = [("b'", "'"), ("B'", "'")] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/clickhouse.py b/sqlglot/dialects/clickhouse.py index e91b0bf7a6..bb6a4a4fcd 100644 --- a/sqlglot/dialects/clickhouse.py +++ b/sqlglot/dialects/clickhouse.py @@ -22,6 +22,8 @@ class ClickHouse(Dialect): class Tokenizer(tokens.Tokenizer): COMMENTS = ["--", "#", "#!", ("/*", "*/")] IDENTIFIERS = ['"', "`"] + BIT_STRINGS = [("0b", "")] + HEX_STRINGS = [("0x", ""), ("0X", "")] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/databricks.py b/sqlglot/dialects/databricks.py index 62664e1ba5..51112a0f0e 100644 --- a/sqlglot/dialects/databricks.py +++ b/sqlglot/dialects/databricks.py @@ -41,6 +41,8 @@ class Generator(Spark.Generator): PARAMETER_TOKEN = "$" class Tokenizer(Spark.Tokenizer): + HEX_STRINGS = [] + SINGLE_TOKENS = { **Spark.Tokenizer.SINGLE_TOKENS, "$": TokenType.PARAMETER, diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 1dea96ab8a..b776c2574b 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -70,30 +70,38 @@ def __new__(cls, clsname, bases, attrs): klass.tokenizer_class._IDENTIFIERS.items() )[0] - if ( - klass.tokenizer_class._BIT_STRINGS - and exp.BitString not in klass.generator_class.TRANSFORMS - ): - bs_start, bs_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0] - klass.generator_class.TRANSFORMS[ - exp.BitString - ] = lambda self, e: f"{bs_start}{self.sql(e, 'this')}{bs_end}" - if ( - klass.tokenizer_class._HEX_STRINGS - and exp.HexString not in klass.generator_class.TRANSFORMS - ): - hs_start, hs_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0] - klass.generator_class.TRANSFORMS[ - exp.HexString - ] = lambda self, 
e: f"{hs_start}{self.sql(e, 'this')}{hs_end}" - if ( - klass.tokenizer_class._BYTE_STRINGS - and exp.ByteString not in klass.generator_class.TRANSFORMS - ): - be_start, be_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0] - klass.generator_class.TRANSFORMS[ - exp.ByteString - ] = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}" + tokenizer_dict = klass.tokenizer_class.__dict__ + tokenizer_transforms = klass.generator_class.TRANSFORMS + + if exp.BitString not in tokenizer_transforms and "bitstring_sql" not in tokenizer_dict: + bit_strings = tokenizer_dict.get("_BIT_STRINGS") + if bit_strings: + bs_start, bs_end = list(bit_strings.items())[0] + bitstring_sql = lambda self, e: f"{bs_start}{self.sql(e, 'this')}{bs_end}" + else: + bitstring_sql = lambda self, e: f"{int(self.sql(e, 'this'), 2)}" + + setattr(klass.generator_class, "bitstring_sql", bitstring_sql) + + if exp.HexString not in tokenizer_transforms and "hexstring_sql" not in tokenizer_dict: + hex_strings = tokenizer_dict.get("_HEX_STRINGS") + if hex_strings: + hs_start, hs_end = list(hex_strings.items())[0] + hexstring_sql = lambda self, e: f"{hs_start}{self.sql(e, 'this')}{hs_end}" + else: + hexstring_sql = lambda self, e: f"{int(self.sql(e, 'this'), 16)}" + + setattr(klass.generator_class, "hexstring_sql", hexstring_sql) + + if exp.ByteString not in tokenizer_transforms and "bytestring_sql" not in tokenizer_dict: + byte_strings = tokenizer_dict.get("_BYTE_STRINGS") + if byte_strings: + be_start, be_end = list(byte_strings.items())[0] + bytestring_sql = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}" + else: + bytestring_sql = lambda self, e: self.sql(e, "this") + + setattr(klass.generator_class, "bytestring_sql", bytestring_sql) return klass diff --git a/sqlglot/dialects/oracle.py b/sqlglot/dialects/oracle.py index 2347308c59..c8af1c6c08 100644 --- a/sqlglot/dialects/oracle.py +++ b/sqlglot/dialects/oracle.py @@ -123,12 +123,10 @@ class Generator(generator.Generator): TRANSFORMS 
= { **generator.Generator.TRANSFORMS, # type: ignore - exp.BitString: lambda self, e: f"{int(e.this, 2)}", exp.DateStrToDate: lambda self, e: self.func( "TO_DATE", e.this, exp.Literal.string("YYYY-MM-DD") ), exp.Group: transforms.preprocess([transforms.unalias_group]), - exp.HexString: lambda self, e: f"{int(e.this, 16)}", exp.Hint: lambda self, e: f" /*+ {self.expressions(e).strip()} */", exp.ILike: no_ilike_sql, exp.IfNull: rename_func("NVL"), diff --git a/sqlglot/dialects/redshift.py b/sqlglot/dialects/redshift.py index 4568a412bf..1b7cf3175a 100644 --- a/sqlglot/dialects/redshift.py +++ b/sqlglot/dialects/redshift.py @@ -52,6 +52,8 @@ def _parse_types(self, check_func: bool = False) -> t.Optional[exp.Expression]: return this class Tokenizer(Postgres.Tokenizer): + BIT_STRINGS = [] + HEX_STRINGS = [] STRING_ESCAPES = ["\\"] KEYWORDS = { diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 406aa23e4f..70dcaa90f2 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -252,6 +252,7 @@ def _parse_alter_table_set_tag(self, unset: bool = False) -> exp.Expression: class Tokenizer(tokens.Tokenizer): QUOTES = ["'", "$$"] STRING_ESCAPES = ["\\", "'"] + HEX_STRINGS = [("x'", "'"), ("X'", "'")] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index cc0a196a3e..03de99cd06 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -259,8 +259,8 @@ class TSQL(Dialect): class Tokenizer(tokens.Tokenizer): IDENTIFIERS = ['"', ("[", "]")] - QUOTES = ["'", '"'] + HEX_STRINGS = [("0x", ""), ("0X", "")] KEYWORDS = { **tokens.Tokenizer.KEYWORDS, diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 3657bbee4f..a714e1d897 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -715,15 +715,6 @@ def tablealias_sql(self, expression: exp.TableAlias) -> str: columns = f"({columns})" if columns else "" return f"{alias}{columns}" - def 
bitstring_sql(self, expression: exp.BitString) -> str: - return self.sql(expression, "this") - - def hexstring_sql(self, expression: exp.HexString) -> str: - return self.sql(expression, "this") - - def bytestring_sql(self, expression: exp.ByteString) -> str: - return self.sql(expression, "this") - def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 80edcd0a0c..66ecde49f2 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -6,14 +6,15 @@ class TestBigQuery(Validator): dialect = "bigquery" def test_bigquery(self): - self.validate_identity( - """CREATE TABLE x (a STRING OPTIONS (description='x')) OPTIONS (table_expiration_days=1)""" - ) + self.validate_identity("SELECT b'abc'") self.validate_identity("""SELECT * FROM UNNEST(ARRAY>[1, 2])""") self.validate_identity("SELECT AS STRUCT 1 AS a, 2 AS b") self.validate_identity("SELECT AS VALUE STRUCT(1 AS a, 2 AS b)") self.validate_identity("SELECT STRUCT>(['2023-01-17'])") self.validate_identity("SELECT * FROM q UNPIVOT(values FOR quarter IN (b, c))") + self.validate_identity( + """CREATE TABLE x (a STRING OPTIONS (description='x')) OPTIONS (table_expiration_days=1)""" + ) self.validate_identity( "SELECT * FROM (SELECT * FROM `t`) AS a UNPIVOT((c) FOR c_name IN (v1, v2))" ) diff --git a/tests/dialects/test_mysql.py b/tests/dialects/test_mysql.py index 2ee23ca7e8..ef647f704f 100644 --- a/tests/dialects/test_mysql.py +++ b/tests/dialects/test_mysql.py @@ -150,58 +150,80 @@ def test_introducers(self): ) def test_hexadecimal_literal(self): - self.validate_all( - "SELECT 0xCC", - write={ - "mysql": "SELECT x'CC'", - "sqlite": "SELECT x'CC'", - "spark": "SELECT X'CC'", - "trino": "SELECT X'CC'", - "bigquery": "SELECT 0xCC", - "oracle": "SELECT 204", - }, - ) - self.validate_all( - "SELECT 0x0000CC", - write={ - 
"mysql": "SELECT x'0000CC'", - "sqlite": "SELECT x'0000CC'", - "spark": "SELECT X'0000CC'", - "trino": "SELECT X'0000CC'", - "bigquery": "SELECT 0x0000CC", - "oracle": "SELECT 204", - }, - ) - self.validate_all( - "SELECT X'1A'", - write={ - "mysql": "SELECT x'1A'", - }, - ) - self.validate_all( - "SELECT 0xz", - write={ - "mysql": "SELECT `0xz`", - }, - ) + write_CC = { + "bigquery": "SELECT 0xCC", + "clickhouse": "SELECT 0xCC", + "databricks": "SELECT 204", + "drill": "SELECT 204", + "duckdb": "SELECT 204", + "hive": "SELECT 204", + "mysql": "SELECT x'CC'", + "oracle": "SELECT 204", + "postgres": "SELECT x'CC'", + "presto": "SELECT 204", + "redshift": "SELECT 204", + "snowflake": "SELECT x'CC'", + "spark": "SELECT X'CC'", + "sqlite": "SELECT x'CC'", + "starrocks": "SELECT x'CC'", + "tableau": "SELECT 204", + "teradata": "SELECT 204", + "trino": "SELECT X'CC'", + "tsql": "SELECT 0xCC", + } + write_CC_with_leading_zeros = { + "bigquery": "SELECT 0x0000CC", + "clickhouse": "SELECT 0x0000CC", + "databricks": "SELECT 204", + "drill": "SELECT 204", + "duckdb": "SELECT 204", + "hive": "SELECT 204", + "mysql": "SELECT x'0000CC'", + "oracle": "SELECT 204", + "postgres": "SELECT x'0000CC'", + "presto": "SELECT 204", + "redshift": "SELECT 204", + "snowflake": "SELECT x'0000CC'", + "spark": "SELECT X'0000CC'", + "sqlite": "SELECT x'0000CC'", + "starrocks": "SELECT x'0000CC'", + "tableau": "SELECT 204", + "teradata": "SELECT 204", + "trino": "SELECT X'0000CC'", + "tsql": "SELECT 0x0000CC", + } + + self.validate_all("SELECT X'1A'", write={"mysql": "SELECT x'1A'"}) + self.validate_all("SELECT 0xz", write={"mysql": "SELECT `0xz`"}) + self.validate_all("SELECT 0xCC", write=write_CC) + self.validate_all("SELECT x'CC'", write=write_CC) + self.validate_all("SELECT 0x0000CC", write=write_CC_with_leading_zeros) + self.validate_all("SELECT x'0000CC'", write=write_CC_with_leading_zeros) def test_bits_literal(self): - self.validate_all( - "SELECT 0b1011", - write={ - "mysql": "SELECT 
b'1011'", - "postgres": "SELECT b'1011'", - "oracle": "SELECT 11", - }, - ) - self.validate_all( - "SELECT B'1011'", - write={ - "mysql": "SELECT b'1011'", - "postgres": "SELECT b'1011'", - "oracle": "SELECT 11", - }, - ) + write_1011 = { + "bigquery": "SELECT 11", + "clickhouse": "SELECT 0b1011", + "databricks": "SELECT 11", + "drill": "SELECT 11", + "hive": "SELECT 11", + "mysql": "SELECT b'1011'", + "oracle": "SELECT 11", + "postgres": "SELECT b'1011'", + "presto": "SELECT 11", + "redshift": "SELECT 11", + "snowflake": "SELECT 11", + "spark": "SELECT 11", + "sqlite": "SELECT 11", + "mysql": "SELECT b'1011'", + "tableau": "SELECT 11", + "teradata": "SELECT 11", + "trino": "SELECT 11", + "tsql": "SELECT 11", + } + + self.validate_all("SELECT 0b1011", write=write_1011) + self.validate_all("SELECT b'1011'", write=write_1011) def test_string_literals(self): self.validate_all( From 9a4e226055dacf1eabc885d974196a729b1f3af2 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Fri, 5 May 2023 18:30:49 +0300 Subject: [PATCH 3/6] Fix parsing for dialects that don't support HEX/BIT_STRINGS --- sqlglot/tokens.py | 4 ++-- tests/dialects/test_duckdb.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 990cbb5c49..0758af270b 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -988,9 +988,9 @@ def _scan_number(self) -> None: if self._char == "0": peek = self._peek.upper() if peek == "B": - return self._scan_bits() + return self._scan_bits() if self._BIT_STRINGS else self._add(TokenType.NUMBER) elif peek == "X": - return self._scan_hex() + return self._scan_hex() if self._HEX_STRINGS else self._add(TokenType.NUMBER) decimal = False scientific = 0 diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index c7e6e85c3a..8c1b748691 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -127,18 +127,20 @@ def test_duckdb(self): self.validate_identity("SELECT {'a': 1} 
AS x") self.validate_identity("SELECT {'a': {'b': {'c': 1}}, 'd': {'e': 2}} AS x") self.validate_identity("SELECT {'x': 1, 'y': 2, 'z': 3}") - self.validate_identity( - "SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}" - ) self.validate_identity("SELECT {'key1': 'string', 'key2': 1, 'key3': 12.345}") self.validate_identity("SELECT ROW(x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)") self.validate_identity("SELECT (x, x + 1, y) FROM (SELECT 1 AS x, 'a' AS y)") self.validate_identity("SELECT a.x FROM (SELECT {'x': 1, 'y': 2, 'z': 3} AS a)") self.validate_identity("ATTACH DATABASE ':memory:' AS new_database") + self.validate_identity( + "SELECT {'yes': 'duck', 'maybe': 'goose', 'huh': NULL, 'no': 'heron'}" + ) self.validate_identity( "SELECT a['x space'] FROM (SELECT {'x space': 1, 'y': 2, 'z': 3} AS a)" ) + self.validate_all("0b1010", write={"": "0 AS b1010"}) + self.validate_all("0x1010", write={"": "0 AS x1010"}) self.validate_all( """SELECT DATEDIFF('day', t1."A", t1."B") FROM "table" AS t1""", write={ From 36298b4a9320b3fce44e554cd2975df01b861896 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Fri, 5 May 2023 20:04:10 +0300 Subject: [PATCH 4/6] Refactor --- sqlglot/dialects/dialect.py | 49 +++++++++++++------------------------ sqlglot/generator.py | 42 +++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 32 deletions(-) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index b776c2574b..4b06488df2 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -70,38 +70,17 @@ def __new__(cls, clsname, bases, attrs): klass.tokenizer_class._IDENTIFIERS.items() )[0] - tokenizer_dict = klass.tokenizer_class.__dict__ - tokenizer_transforms = klass.generator_class.TRANSFORMS - - if exp.BitString not in tokenizer_transforms and "bitstring_sql" not in tokenizer_dict: - bit_strings = tokenizer_dict.get("_BIT_STRINGS") - if bit_strings: - bs_start, bs_end = list(bit_strings.items())[0] - bitstring_sql = 
lambda self, e: f"{bs_start}{self.sql(e, 'this')}{bs_end}" - else: - bitstring_sql = lambda self, e: f"{int(self.sql(e, 'this'), 2)}" - - setattr(klass.generator_class, "bitstring_sql", bitstring_sql) - - if exp.HexString not in tokenizer_transforms and "hexstring_sql" not in tokenizer_dict: - hex_strings = tokenizer_dict.get("_HEX_STRINGS") - if hex_strings: - hs_start, hs_end = list(hex_strings.items())[0] - hexstring_sql = lambda self, e: f"{hs_start}{self.sql(e, 'this')}{hs_end}" - else: - hexstring_sql = lambda self, e: f"{int(self.sql(e, 'this'), 16)}" - - setattr(klass.generator_class, "hexstring_sql", hexstring_sql) - - if exp.ByteString not in tokenizer_transforms and "bytestring_sql" not in tokenizer_dict: - byte_strings = tokenizer_dict.get("_BYTE_STRINGS") - if byte_strings: - be_start, be_end = list(byte_strings.items())[0] - bytestring_sql = lambda self, e: f"{be_start}{self.sql(e, 'this')}{be_end}" - else: - bytestring_sql = lambda self, e: self.sql(e, "this") - - setattr(klass.generator_class, "bytestring_sql", bytestring_sql) + klass.bit_start, klass.bit_end = None, None + if klass.tokenizer_class.__dict__.get("BIT_STRINGS"): + klass.bit_start, klass.bit_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0] + + klass.hex_start, klass.hex_end = None, None + if klass.tokenizer_class.__dict__.get("HEX_STRINGS"): + klass.hex_start, klass.hex_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0] + + klass.byte_start, klass.byte_end = None, None + if klass.tokenizer_class.__dict__.get("BYTE_STRINGS"): + klass.byte_start, klass.byte_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0] return klass @@ -207,6 +186,12 @@ def generator(self, **opts) -> Generator: **{ "quote_start": self.quote_start, "quote_end": self.quote_end, + "bit_start": self.bit_start, + "bit_end": self.bit_end, + "hex_start": self.hex_start, + "hex_end": self.hex_end, + "byte_start": self.byte_start, + "byte_end": self.byte_end, "identifier_start": self.identifier_start, 
"identifier_end": self.identifier_end, "string_escape": self.tokenizer_class.STRING_ESCAPES[0], diff --git a/sqlglot/generator.py b/sqlglot/generator.py index a714e1d897..996e46280c 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -25,6 +25,12 @@ class Generator: quote_end (str): specifies which ending character to use to delimit quotes. Default: '. identifier_start (str): specifies which starting character to use to delimit identifiers. Default: ". identifier_end (str): specifies which ending character to use to delimit identifiers. Default: ". + bit_start (str): specifies which starting character to use to delimit bit literals. Default: None. + bit_end (str): specifies which ending character to use to delimit bit literals. Default: None. + hex_start (str): specifies which starting character to use to delimit hex literals. Default: None. + hex_end (str): specifies which ending character to use to delimit hex literals. Default: None. + byte_start (str): specifies which starting character to use to delimit byte literals. Default: None. + byte_end (str): specifies which ending character to use to delimit byte literals. Default: None. identify (bool | str): 'always': always quote, 'safe': quote identifiers if they don't contain an upcase, True defaults to always. normalize (bool): if set to True all identifiers will lower cased string_escape (str): specifies a string escape character. Default: '. 
@@ -227,6 +233,12 @@ class Generator: "quote_end", "identifier_start", "identifier_end", + "bit_start", + "bit_end", + "hex_start", + "hex_end", + "byte_start", + "byte_end", "identify", "normalize", "string_escape", @@ -258,6 +270,12 @@ def __init__( quote_end=None, identifier_start=None, identifier_end=None, + bit_start=None, + bit_end=None, + hex_start=None, + hex_end=None, + byte_start=None, + byte_end=None, identify=False, normalize=False, string_escape=None, @@ -284,6 +302,12 @@ def __init__( self.quote_end = quote_end or "'" self.identifier_start = identifier_start or '"' self.identifier_end = identifier_end or '"' + self.bit_start = bit_start + self.bit_end = bit_end + self.hex_start = hex_start + self.hex_end = hex_end + self.byte_start = byte_start + self.byte_end = byte_end self.identify = identify self.normalize = normalize self.string_escape = string_escape or "'" @@ -715,6 +739,24 @@ def tablealias_sql(self, expression: exp.TableAlias) -> str: columns = f"({columns})" if columns else "" return f"{alias}{columns}" + def bitstring_sql(self, expression: exp.BitString) -> str: + this = self.sql(expression, "this") + if self.bit_start or self.bit_end: + return f"{self.bit_start}{this}{self.bit_end}" + return f"{int(this, 2)}" + + def hexstring_sql(self, expression: exp.HexString) -> str: + this = self.sql(expression, "this") + if self.hex_start or self.hex_end: + return f"{self.hex_start}{this}{self.hex_end}" + return f"{int(this, 16)}" + + def bytestring_sql(self, expression: exp.ByteString) -> str: + this = self.sql(expression, "this") + if self.byte_start or self.byte_end: + return f"{self.byte_start}{this}{self.byte_end}" + return this + def datatype_sql(self, expression: exp.DataType) -> str: type_value = expression.this type_sql = self.TYPE_MAPPING.get(type_value, type_value.value) From fb1e5b87b061b88ba667e7d86c921307f80252a1 Mon Sep 17 00:00:00 2001 From: George Sittas Date: Fri, 5 May 2023 20:14:44 +0300 Subject: [PATCH 5/6] Apply suggestion --- 
sqlglot/dialects/dialect.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sqlglot/dialects/dialect.py b/sqlglot/dialects/dialect.py index 4b06488df2..71269f256b 100644 --- a/sqlglot/dialects/dialect.py +++ b/sqlglot/dialects/dialect.py @@ -70,17 +70,17 @@ def __new__(cls, clsname, bases, attrs): klass.tokenizer_class._IDENTIFIERS.items() )[0] - klass.bit_start, klass.bit_end = None, None - if klass.tokenizer_class.__dict__.get("BIT_STRINGS"): - klass.bit_start, klass.bit_end = list(klass.tokenizer_class._BIT_STRINGS.items())[0] + klass.bit_start, klass.bit_end = seq_get( + list(klass.tokenizer_class._BIT_STRINGS.items()), 0 + ) or (None, None) - klass.hex_start, klass.hex_end = None, None - if klass.tokenizer_class.__dict__.get("HEX_STRINGS"): - klass.hex_start, klass.hex_end = list(klass.tokenizer_class._HEX_STRINGS.items())[0] + klass.hex_start, klass.hex_end = seq_get( + list(klass.tokenizer_class._HEX_STRINGS.items()), 0 + ) or (None, None) - klass.byte_start, klass.byte_end = None, None - if klass.tokenizer_class.__dict__.get("BYTE_STRINGS"): - klass.byte_start, klass.byte_end = list(klass.tokenizer_class._BYTE_STRINGS.items())[0] + klass.byte_start, klass.byte_end = seq_get( + list(klass.tokenizer_class._BYTE_STRINGS.items()), 0 + ) or (None, None) return klass From fee4213783bfccd9b58368b75fff5e1f9b2b8aef Mon Sep 17 00:00:00 2001 From: George Sittas Date: Fri, 5 May 2023 20:23:25 +0300 Subject: [PATCH 6/6] Simplify generator check --- sqlglot/generator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 996e46280c..8cb5a0c4a0 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -741,19 +741,19 @@ def tablealias_sql(self, expression: exp.TableAlias) -> str: def bitstring_sql(self, expression: exp.BitString) -> str: this = self.sql(expression, "this") - if self.bit_start or self.bit_end: + if self.bit_start: return 
f"{self.bit_start}{this}{self.bit_end}" return f"{int(this, 2)}" def hexstring_sql(self, expression: exp.HexString) -> str: this = self.sql(expression, "this") - if self.hex_start or self.hex_end: + if self.hex_start: return f"{self.hex_start}{this}{self.hex_end}" return f"{int(this, 16)}" def bytestring_sql(self, expression: exp.ByteString) -> str: this = self.sql(expression, "this") - if self.byte_start or self.byte_end: + if self.byte_start: return f"{self.byte_start}{this}{self.byte_end}" return this