Skip to content

Commit

Permalink
Feat: allow marking audits as [non-]blocking at use site
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas committed Jun 25, 2024
1 parent bad0457 commit 6609629
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
7 changes: 6 additions & 1 deletion sqlmesh/core/snapshot/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,10 @@ def _audit(
skipped=True,
)

# Model's "blocking" argument takes precedence over the audit's default setting
blocking = audit_args.pop("blocking", None)
blocking = blocking == exp.true() if blocking else audit.blocking

query = audit.render_query(
snapshot,
start=start,
Expand All @@ -854,10 +858,11 @@ def _audit(
query=query,
adapter_dialect=self.adapter.dialect,
)
if audit.blocking:
if blocking:
raise audit_error
else:
logger.warning(f"{audit_error}\nAudit is warn only so proceeding with execution.")

return AuditResult(
audit=audit,
model=snapshot.model_or_none,
Expand Down
48 changes: 47 additions & 1 deletion tests/core/test_snapshot_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from sqlmesh.core.snapshot.evaluator import CustomMaterialization
from sqlmesh.utils.concurrency import NodeExecutionFailedError
from sqlmesh.utils.date import to_timestamp
from sqlmesh.utils.errors import ConfigError
from sqlmesh.utils.errors import AuditError, ConfigError
from sqlmesh.utils.metaprogramming import Executable


Expand Down Expand Up @@ -2014,6 +2014,52 @@ def test_audit_wap(adapter_mock, make_snapshot):
adapter_mock.wap_publish.assert_called_once_with(snapshot.table_name(), wap_id)


def test_audit_set_blocking_at_use_site(adapter_mock, make_snapshot):
evaluator = SnapshotEvaluator(adapter_mock)

always_failing_audit = ModelAudit(
name="always_fail",
query="SELECT * FROM test_schema.test_table",
)

model = SqlModel(
name="test_schema.test_table",
kind=FullKind(),
query=parse_one("SELECT a::int FROM tbl"),
audits=[
("always_fail", {"blocking": exp.false()}),
],
)
snapshot = make_snapshot(model, audits={always_failing_audit.name: always_failing_audit})
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)

# Return a non-zero count to indicate audit failure
adapter_mock.fetchone.return_value = (1,)

logger = logging.getLogger("sqlmesh.core.snapshot.evaluator")
with patch.object(logger, "warning") as mock_logger:
evaluator.audit(snapshot, snapshots={})
assert "Audit is warn only so proceeding with execution." in mock_logger.call_args[0][0]

model = SqlModel(
name="test_schema.test_table",
kind=FullKind(),
query=parse_one("SELECT a::int FROM tbl"),
audits=[
("always_fail", {"blocking": exp.true()}),
],
)
snapshot = make_snapshot(model, audits={always_failing_audit.name: always_failing_audit})
snapshot.categorize_as(SnapshotChangeCategory.BREAKING)
adapter_mock.fetchone.return_value = (1,)

with pytest.raises(
AuditError,
match="Audit 'always_fail' for model 'test_schema.test_table' failed.",
):
evaluator.audit(snapshot, snapshots={})


def test_create_post_statements_use_deployable_table(
mocker: MockerFixture, adapter_mock, make_snapshot
):
Expand Down

0 comments on commit 6609629

Please sign in to comment.