Skip to content

Commit

Permalink
fixes error when running :NUMBER in duckdb (#910)
Browse files Browse the repository at this point in the history
* basic bug fix

* separate function, add tests

* edit changelog

* fix spelling, test error

* remove print

* last spelling fix

* add test compare duckdb result
  • Loading branch information
bryannho authored Oct 10, 2023
1 parent 2ce74f4 commit 433178c
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

* [Fix] Remove force deleted snippets from dependent snippet's `with` (#717)
* [Fix] Comments added in SQL query to be stripped before saved as snippet (#886)
* [Fix] Fixed bug passing :NUMBER while string slicing in query (#901)

## 0.10.2 (2023-09-22)

Expand Down
30 changes: 29 additions & 1 deletion src/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def escape_string_literals_with_colon_prefix(query):
"""
Given a query, replaces all occurrences of ':variable' with '\:variable' and
":variable" with "\:variable" so that the query can be passed to sqlalchemy.text
without the literals being interpreted as bind parameters. It doesn't replace
without the literals being interpreted as bind parameters. Also calls
escape_string_slicing_with_colon_prefix(). It doesn't replace
the occurrences of :variable (without quotes)
""" # noqa

Expand All @@ -245,9 +246,36 @@ def escape_string_literals_with_colon_prefix(query):
double_found = re.findall(double_quoted_variable_pattern, query)
single_found = re.findall(single_quoted_variable_pattern, query)

# Escape occurrences of : for string slicing
query_quoted, _ = escape_string_slicing_notation(query_quoted)

return query_quoted, double_found + single_found


def escape_string_slicing_notation(query):
    r"""
    Escape the colon in front of a numeric slice end so it is not a bind parameter.

    Every ``:N]`` in the query, where ``N`` consists of digits and/or
    underscores and may be followed by whitespace before the closing bracket,
    is rewritten as ``\:N]``. For example ``'example'[1:5]`` becomes
    ``'example'[1\:5]``. Colons already preceded by a backslash are left alone,
    as are slices with no end bound (``[:]``, ``[2:]``) or a negative end
    bound (``[:-1]``).

    Parameters
    ----------
    query : str
        query to be parsed and cleaned

    Returns
    -------
    tuple of (str, list of str)
        The escaped query, and the end bounds whose colon was escaped, in
        order of appearance.
    """
    # Unescaped colon, the end bound (group 1), optional whitespace before the
    # closing bracket (group 2, preserved verbatim in the output), then "]".
    slicing_pattern = re.compile(r"(?<!\\):([0-9_]+)(\s*)\]")

    # Collect the bounds from the untouched query before rewriting it
    occurrences_found = [match.group(1) for match in slicing_pattern.finditer(query)]

    # Replace [x:y] with [x\:y]
    query_escaped = slicing_pattern.sub(r"\\:\1\2]", query)

    return query_escaped, occurrences_found


def find_named_parameters(input_string):
# Define the regular expression pattern for valid Python identifiers
identifier_pattern = r"\b[a-zA-Z_][a-zA-Z0-9_]*\b"
Expand Down
81 changes: 81 additions & 0 deletions src/tests/test_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
without_sql_comment,
magic_args,
escape_string_literals_with_colon_prefix,
escape_string_slicing_notation,
find_named_parameters,
_connection_string,
ConnectionsFile,
Expand Down Expand Up @@ -424,3 +425,83 @@ def test_connections_file_get_default_connection_url(tmp_empty, content, expecte

cf = ConnectionsFile(path_to_file="conns.ini")
assert cf.get_default_connection_url() == expected


@pytest.mark.parametrize(
    "query_jupysql, expected_duckdb",
    [
        ("select 'hello'[:2]", "he"),
        ("select 'hello'[2:]", "ello"),
        ("select 'hello'[2:4]", "ell"),
        ("select 'hello'[:-1]", "hell"),
    ],
)
def test_slicing_jupysql_matches_duckdb_expected(
    ip_empty, query_jupysql, expected_duckdb
):
    """Run each slicing query through the %sql magic on duckdb and check the value."""
    # Load the extension, then connect to duckdb
    for setup_cell in ("%load_ext sql", "%sql duckdb://"):
        ip_empty.run_cell(setup_cell)

    result_as_dict = ip_empty.run_line_magic("sql", query_jupysql).dict()

    # Compare the first item of the first entry of the result dict
    first_entry = list(result_as_dict.values())[0]
    assert first_entry[0] == expected_duckdb


@pytest.mark.parametrize(
    "query, expected_escaped, expected_found",
    [
        pytest.param(
            "SELECT 'hello'",
            "SELECT 'hello'",
            [],
            id="no-slicing",
        ),
        pytest.param(
            "SELECT 'hello'[:]",
            "SELECT 'hello'[:]",
            [],
            id="slicing-empty",
        ),
        pytest.param(
            "SELECT 'hello'[:2]",
            "SELECT 'hello'[\\:2]",
            ["2"],
            id="end-index-only",
        ),
        pytest.param(
            "SELECT 'hello'[1:5]",
            "SELECT 'hello'[1\\:5]",
            ["5"],
            id="begin-end-index",
        ),
        pytest.param(
            "SELECT 'hello'[1:99]",
            "SELECT 'hello'[1\\:99]",
            ["99"],
            id="end-index-two-digit",
        ),
        pytest.param(
            "SELECT 'hello'[:123456789]",
            "SELECT 'hello'[\\:123456789]",
            ["123456789"],
            id="end-index-many-digit",
        ),
    ],
)
def test_escape_string_slicing_notation(query, expected_escaped, expected_found):
    """Check both the escaped query and the list of end indexes that is returned."""
    actual_escaped, actual_found = escape_string_slicing_notation(query)
    assert actual_escaped == expected_escaped
    assert actual_found == expected_found

0 comments on commit 433178c

Please sign in to comment.