Skip to content

Commit

Permalink
chore: improve DML check (#30417)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Sep 27, 2024
1 parent 96b0bcf commit cc9fd88
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ dependencies = [
"slack_sdk>=3.19.0, <4",
"sqlalchemy>=1.4, <2",
"sqlalchemy-utils>=0.38.3, <0.39",
"sqlglot>=23.0.2,<24",
"sqlglot>=25.24.0,<26",
"sqlparse>=0.5.0",
"tabulate>=0.8.9, <0.9",
"typing-extensions>=4, <5",
Expand Down
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ sqlalchemy-utils==0.38.3
# via
# apache-superset
# flask-appbuilder
sqlglot==23.6.3
sqlglot==25.24.0
# via apache-superset
sqlparse==0.5.0
# via apache-superset
Expand Down
2 changes: 1 addition & 1 deletion superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def get_settings(self) -> dict[str, str | bool]:
"""
return {
eq.this.sql(): eq.expression.sql()
eq.this.sql(comments=False): eq.expression.sql(comments=False)
for set_item in self._parsed.find_all(exp.SetItem)
for eq in set_item.find_all(exp.EQ)
}
Expand Down
29 changes: 21 additions & 8 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
OAuth2RedirectError,
SupersetErrorException,
SupersetErrorsException,
SupersetParseError,
)
from superset.extensions import celery_app, event_logger
from superset.models.core import Database
Expand Down Expand Up @@ -236,15 +237,27 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca
# We are testing to see if more rows exist than the limit.
increased_limit = None if query.limit is None else query.limit + 1

parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
if parsed_statement.is_mutating() and not database.allow_dml:
raise SupersetErrorException(
SupersetError(
message=__("Only SELECT statements are allowed against this database."),
error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
level=ErrorLevel.ERROR,
if not database.allow_dml:
try:
parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
disallowed = parsed_statement.is_mutating()
except SupersetParseError:
# if we fail to parse teh query, disallow by default
disallowed = True

if disallowed:
raise SupersetErrorException(
SupersetError(
message=__(
"This database does not allow for DDL/DML, and the query "
"could not be parsed to confirm it is a read-only query. Please "
"contact your administrator for more assistance."
),
error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
level=ErrorLevel.ERROR,
)
)
)

if apply_ctas:
if not query.tmp_table_name:
start_dttm = datetime.fromtimestamp(query.start_time)
Expand Down
6 changes: 5 additions & 1 deletion tests/integration_tests/sqllab_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ def test_sql_json_dml_disallowed(self):
assert data == {
"errors": [
{
"message": "Only SELECT statements are allowed against this database.",
"message": (
"This database does not allow for DDL/DML, and the query "
"could not be parsed to confirm it is a read-only query. Please "
"contact your administrator for more assistance."
),
"error_type": SupersetErrorType.DML_NOT_ALLOWED_ERROR,
"level": ErrorLevel.ERROR,
"extra": {
Expand Down
14 changes: 14 additions & 0 deletions tests/unit_tests/sql/parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,3 +918,17 @@ def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
Test the `has_mutation` method.
"""
assert SQLScript(sql, engine).has_mutation() == expected


def test_get_settings() -> None:
"""
Test `get_settings` in some edge cases.
"""
sql = """
set
-- this is a tricky comment
search_path -- another one
= bar;
SELECT * FROM some_table;
"""
assert SQLScript(sql, "postgresql").get_settings() == {"search_path": "bar"}

0 comments on commit cc9fd88

Please sign in to comment.