diff --git a/src/odm_sharing/private/queries.py b/src/odm_sharing/private/queries.py index bca5b9ba..ba155288 100644 --- a/src/odm_sharing/private/queries.py +++ b/src/odm_sharing/private/queries.py @@ -1,4 +1,3 @@ -import re from collections import defaultdict from dataclasses import dataclass, field from functools import partial @@ -17,6 +16,7 @@ Node, NodeKind, Op, + ParseError, RangeKind, RuleTree, parse_op, @@ -60,12 +60,19 @@ class TableQuery: OrgTableQueries = Dict[OrgName, Dict[TableName, TableQuery]] -INVALID_IDENT_PATTERN = re.compile(r'\W+', re.ASCII) - - def ident(x: str) -> str: - '''sanitize and quote sql identifier''' - return dqt(INVALID_IDENT_PATTERN.sub('', x)) + '''make a sanitized/quoted sql identifier + + :raises ParseError: + ''' + # Double-quotes should be used as the delimiter for column-name + # identifiers. (https://stackoverflow.com/a/2901499) + # + # It should be enough to simply disallow double-quotes in the name. + if '"' in x: + raise ParseError('the following column-name contains double-quotes, ' + + f'which is not allowed: \'{x}\'') + return dqt(x) def convert(val: str) -> str: @@ -90,6 +97,8 @@ def gen_data_sql( to be generated, but it'll also work on any node in the table-subtree. :param args: (output) sql arguments :param rule_queries: (output) see ``gen_data_query`` + + :raises ParseError: ''' def recurse(node: Node) -> str: @@ -130,7 +139,7 @@ def record(node: Node, sql: str, args: SqlArgs) -> None: elif n.kind == NodeKind.FILTER: # filter has op as value, children define field, kind and literals op = parse_op(n.str_val) - key_ident = ident(recurse(n.sons[0])) + key_ident = recurse(n.sons[0]) def gen_range_sql(range_kind: RangeKind, values: List[str]) -> str: '''generates sql for a range of values''' @@ -177,6 +186,8 @@ def gen_data_query( :param rule_queries: (output) partial filter and select queries :return: complete query + + :raises ParseError: ''' args: List[str] = [] sql = gen_data_sql(table_node, args, rule_queries) @@ -226,7 +237,10 @@ def gen_count_query_sql( rule_id: int, filter_query: PartialQuery, ) -> Tuple[RuleId, Query]: - '''generate count query for table from partial filter query''' + '''generate count query for table from partial filter query + + :raises ParseError: + ''' sql = ( f'SELECT COUNT(*) FROM {ident(table)}' + (f' WHERE {filter_query.sql}' if filter_query.sql else '') @@ -235,7 +249,10 @@ def gen_count_query_sql( def gen_table_query(share_node: Node, table_node: Node) -> TableQuery: - '''generates a table-query for a specific table node of a share node''' + '''generates a table-query for a specific table node of a share node + + :raises ParseError: + ''' assert share_node.kind == NodeKind.SHARE assert table_node.kind == NodeKind.TABLE assert table_node in share_node.sons @@ -281,6 +298,8 @@ def generate(rule_tree: RuleTree) -> OrgTableQueries: :param rule_tree: the tree to generate queries from :return: query-objects for each org and table + + :raises ParseError: ''' def gen_table_query_entry(share_node: Node, table_node: Node diff --git a/src/odm_sharing/tools/share.py b/src/odm_sharing/tools/share.py index 1c654673..300856be 100644 --- a/src/odm_sharing/tools/share.py +++ b/src/odm_sharing/tools/share.py @@ -96,8 +96,13 @@ def write_debug( ruleset: Dict[RuleId, Rule] ) -> None: '''write debug output''' + write_line(file, '') write_header(file, 1, f'org {qt(org_name)} - table {qt(table_name)}') + write_header(file, 2, 'data sql') + write_line(file, table_query.data_query.sql) + write_line(file, '') + (select_id, columns) = sh.get_columns(con, table_query) write_header(file, 2, 'columns') for col in columns: diff --git a/tests/api/issue-69/protocols.csv b/tests/api/issue-69/protocols.csv new file mode 100644 index 00000000..d1d9ff97 --- /dev/null +++ b/tests/api/issue-69/protocols.csv @@ -0,0 +1,4 @@ +Protocol.ID +a +b +c diff --git a/tests/api/issue-69/schema-org1-protocols.csv b/tests/api/issue-69/schema-org1-protocols.csv new file mode 100644 index 00000000..d1d9ff97 --- /dev/null +++ b/tests/api/issue-69/schema-org1-protocols.csv @@ -0,0 +1,4 @@ +Protocol.ID +a +b +c diff --git a/tests/api/issue-69/schema.csv b/tests/api/issue-69/schema.csv new file mode 100644 index 00000000..9ee5bc70 --- /dev/null +++ b/tests/api/issue-69/schema.csv @@ -0,0 +1,3 @@ +ruleID,table,mode,key,operator,value,notes +1,protocols,select,,,Protocol.ID, +2,,share,org1,,1, diff --git a/tests/test_api.py b/tests/test_api.py index d4554ade..ae06cc73 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -64,6 +64,14 @@ def test_excel_string_filter(self) -> None: self.assertEqual(df['str1'].to_list(), ['a']) self.assertEqual(df['str2'].to_list(), ['']) + def test_header_with_dot(self) -> None: + HEADER = 'Protocol.ID' + dir = join(self.dir, 'api', 'issue-69') + res = sh.extract(join(dir, 'schema.csv'), join(dir, 'protocols.csv')) + df = res['org1']['protocols'] + self.assertEqual(df.columns.to_list(), [HEADER]) + self.assertEqual(df[HEADER].to_list(), ['a', 'b', 'c']) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_queries.py b/tests/test_queries.py index cef7156c..c9e7c47a 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -92,8 +92,8 @@ def test_share_table_rule_count_queries(self) -> None: self.assertEqual(actual2, expected2) def test_sanitize(self) -> None: - '''special characters are stripped, and parameter values separated, to - prevent injections''' + '''double-quotes are not allowed in identifiers and parameter values + are separated, to prevent injections''' injection = '" OR 1=1 --' ruleset = [ Rule(id=1, table='t', mode=RuleMode.SELECT, value=injection), @@ -102,10 +102,8 @@ def test_sanitize(self) -> None: Rule(id=3, table='', mode=RuleMode.SHARE, key='ohri', value='1;2'), ] ruletree = trees.parse(ruleset) - q = queries.generate(ruletree)['ohri']['t'] - actual = q.data_query.sql - expected = 'SELECT "OR11" FROM "t" WHERE ("OR11" = ?)' - self.assertEqual(actual, expected) + with self.assertRaisesRegex(rules.ParseError, 'quote.*not allowed'): + queries.generate(ruletree) if __name__ == '__main__':