Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: improve DML check #30417

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ dependencies = [
"slack_sdk>=3.19.0, <4",
"sqlalchemy>=1.4, <2",
"sqlalchemy-utils>=0.38.3, <0.39",
"sqlglot>=23.0.2,<24",
"sqlglot>=25.24.0,<26",
"sqlparse>=0.5.0",
"tabulate>=0.8.9, <0.9",
"typing-extensions>=4, <5",
Expand Down
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ sqlalchemy-utils==0.38.3
# via
# apache-superset
# flask-appbuilder
sqlglot==23.6.3
sqlglot==25.24.0
# via apache-superset
sqlparse==0.5.0
# via apache-superset
Expand Down
2 changes: 1 addition & 1 deletion superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def get_settings(self) -> dict[str, str | bool]:

"""
return {
eq.this.sql(): eq.expression.sql()
eq.this.sql(comments=False): eq.expression.sql(comments=False)
for set_item in self._parsed.find_all(exp.SetItem)
for eq in set_item.find_all(exp.EQ)
}
Expand Down
29 changes: 21 additions & 8 deletions superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
OAuth2RedirectError,
SupersetErrorException,
SupersetErrorsException,
SupersetParseError,
)
from superset.extensions import celery_app, event_logger
from superset.models.core import Database
Expand Down Expand Up @@ -236,15 +237,27 @@ def execute_sql_statement( # pylint: disable=too-many-statements, too-many-loca
# We are testing to see if more rows exist than the limit.
increased_limit = None if query.limit is None else query.limit + 1

parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
if parsed_statement.is_mutating() and not database.allow_dml:
raise SupersetErrorException(
SupersetError(
message=__("Only SELECT statements are allowed against this database."),
error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
level=ErrorLevel.ERROR,
if not database.allow_dml:
try:
parsed_statement = SQLStatement(sql_statement, engine=db_engine_spec.engine)
disallowed = parsed_statement.is_mutating()
except SupersetParseError:
# if we fail to parse teh query, disallow by default
disallowed = True

if disallowed:
raise SupersetErrorException(
SupersetError(
message=__(
"This database does not allow for DDL/DML, and the query "
"could not be parsed to confirm it is a read-only query. Please "
"contact your administrator for more assistance."
),
error_type=SupersetErrorType.DML_NOT_ALLOWED_ERROR,
level=ErrorLevel.ERROR,
)
)
)

if apply_ctas:
if not query.tmp_table_name:
start_dttm = datetime.fromtimestamp(query.start_time)
Expand Down
6 changes: 5 additions & 1 deletion tests/integration_tests/sqllab_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,11 @@ def test_sql_json_dml_disallowed(self):
assert data == {
"errors": [
{
"message": "Only SELECT statements are allowed against this database.",
"message": (
"This database does not allow for DDL/DML, and the query "
"could not be parsed to confirm it is a read-only query. Please "
"contact your administrator for more assistance."
),
"error_type": SupersetErrorType.DML_NOT_ALLOWED_ERROR,
"level": ErrorLevel.ERROR,
"extra": {
Expand Down
14 changes: 14 additions & 0 deletions tests/unit_tests/sql/parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,3 +918,17 @@ def test_has_mutation(engine: str, sql: str, expected: bool) -> None:
Test the `has_mutation` method.
"""
assert SQLScript(sql, engine).has_mutation() == expected


def test_get_settings() -> None:
"""
Test `get_settings` in some edge cases.
"""
sql = """
set
-- this is a tricky comment
search_path -- another one
= bar;
SELECT * FROM some_table;
"""
assert SQLScript(sql, "postgresql").get_settings() == {"search_path": "bar"}
Loading