Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Remove AS Keyword for Subquery Aliases in Oracle SQL #44210

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions providers/src/airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,12 +659,18 @@ class SQLTableCheckOperator(BaseSQLOperator):

template_fields_renderers: ClassVar[dict] = {"sql": "sql"}

sql_check_template = """
sql_check_template_default = """
SELECT '{check_name}' AS check_name, MIN({check_name}) AS check_result
FROM (SELECT CASE WHEN {check_statement} THEN 1 ELSE 0 END AS {check_name}
FROM {table} {partition_clause}) AS sq
"""

sql_check_template_oracle = """
SELECT '{check_name}' AS check_name, MIN({check_name}) AS check_result
FROM (SELECT CASE WHEN {check_statement} THEN 1 ELSE 0 END AS {check_name}
FROM {table} {partition_clause}) sq
"""

def __init__(
self,
*,
Expand All @@ -680,7 +686,30 @@ def __init__(
self.table = table
self.checks = checks
self.partition_clause = partition_clause
self.sql = f"SELECT check_name, check_result FROM ({self._generate_sql_query()}) AS check_table"

# Determine SQL template and query structure based on database type
self.sql_check_template, self.sql = self._configure_sql_templates()

def _configure_sql_templates(self):
"""
Configures the SQL template and final query based on the database connection type.
"""
db_hook = self.get_db_hook()

# Example: Check the connection type using the hook
conn_type = db_hook.conn_type.lower()
self.log.info(f"Database connection type: {conn_type}")

if conn_type == "oracle":
sql_check_template = self.sql_check_template_oracle
sql_query = f"SELECT check_name, check_result FROM ({self._generate_sql_query(sql_check_template)}) " \
f"check_table"
else:
sql_check_template = self.sql_check_template_default
sql_query = f"SELECT check_name, check_result FROM ({self._generate_sql_query(sql_check_template)}) AS " \
f"check_table"

return sql_check_template, sql_query

def execute(self, context: Context):
hook = self.get_db_hook()
Expand Down Expand Up @@ -709,7 +738,7 @@ def execute(self, context: Context):

self.log.info("All tests have passed")

def _generate_sql_query(self):
def _generate_sql_query(self, sql_check_template):
self.log.debug("Partition clause: %s", self.partition_clause)

def _generate_partition_clause(check_name):
Expand All @@ -723,7 +752,7 @@ def _generate_partition_clause(check_name):
return ""

return "UNION ALL".join(
self.sql_check_template.format(
sql_check_template.format(
check_statement=value["check_statement"],
check_name=check_name,
table=self.table,
Expand Down
Loading