Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion providers/amazon/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ requires-python = ">=3.10"
dependencies = [
"apache-airflow>=2.11.0",
"apache-airflow-providers-common-compat>=1.13.0",
"apache-airflow-providers-common-sql>=1.27.0",
"apache-airflow-providers-common-sql>=1.27.0", # use next version
"apache-airflow-providers-http",
# We should update minimum version of boto3 and here regularly to avoid `pip` backtracking with the number
# of candidates to consider. Make sure to configure boto3 version here as well as in all the tools below
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage

if TYPE_CHECKING:
from botocore.paginate import PageIterator
Expand Down Expand Up @@ -126,6 +127,11 @@ def run_query(
response = self.get_conn().start_query_execution(**params)
query_execution_id = response["QueryExecutionId"]
self.log.info("Query execution id: %s", query_execution_id)
send_sql_hook_lineage(
context=self,
sql=query,
job_id=query_execution_id,
)
return query_execution_id

def get_query_info(self, query_execution_id: str, use_cache: bool = False) -> dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage

if TYPE_CHECKING:
from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa: F401
Expand Down Expand Up @@ -154,6 +155,19 @@ def execute_query(

statement_id = resp["Id"]

send_sql_hook_lineage(
context=self,
sql="; ".join(sql) if isinstance(sql, list) else sql,
sql_parameters=parameters or None,
job_id=statement_id,
default_db=database,
extra={
"cluster_identifier": cluster_identifier,
"workgroup_name": workgroup_name,
"session_id": session_id or resp.get("SessionId"),
},
)

if wait_for_completion:
self.wait_for_results(statement_id, poll_interval=poll_interval)

Expand Down
15 changes: 15 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/hooks/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ def test_hook_run_query_no_log_query(self, mock_conn, log):
)
assert athena_hook_no_log_query.log.info.call_count == 1

@mock.patch("airflow.providers.amazon.aws.hooks.athena.send_sql_hook_lineage")
@mock.patch.object(AthenaHook, "get_conn")
def test_run_query_hook_lineage(self, mock_conn, mock_send_lineage):
mock_conn.return_value.start_query_execution.return_value = MOCK_QUERY_EXECUTION
self.athena.run_query(
query=MOCK_DATA["query"],
query_context=mock_query_context,
result_configuration=mock_result_configuration,
)
mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is self.athena
assert call_kw["sql"] == MOCK_DATA["query"]
assert call_kw["job_id"] == MOCK_DATA["query_execution_id"]

@mock.patch.object(AthenaHook, "get_conn")
def test_hook_get_query_results_with_non_succeeded_query(self, mock_conn):
mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,30 @@ def test_execute_reuse_session(self, mock_conn):
Id=STATEMENT_ID,
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.send_sql_hook_lineage")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_query_hook_lineage(self, mock_conn, mock_send_lineage):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
hook = RedshiftDataHook()
hook.execute_query(
database=DATABASE,
cluster_identifier="cluster_identifier",
sql=SQL,
wait_for_completion=False,
)
mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is hook
assert call_kw["sql"] == SQL
assert call_kw["sql_parameters"] is None
assert call_kw["job_id"] == STATEMENT_ID
assert call_kw["default_db"] == DATABASE
assert call_kw["extra"] == {
"cluster_identifier": "cluster_identifier",
"workgroup_name": None,
"session_id": None,
}

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_batch_execute(self, mock_conn):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,81 @@ def test_get_openlineage_redshift_authority_part(
assert f"{expected_identity}:{LOGIN_PORT}" == self.db_hook._get_openlineage_redshift_authority_part(
self.connection
)


class TestRedshiftSQLHookLineage:
def setup_method(self):
self.cur = mock.MagicMock(rowcount=0)
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
conn = self.conn

class UnitTestRedshiftSQLHook(RedshiftSQLHook):
conn_name_attr = "test_conn_id"

def get_conn(self):
return conn

self.db_hook = UnitTestRedshiftSQLHook()
self.db_hook.get_connection = mock.Mock(
return_value=Connection(
conn_type="redshift",
login=LOGIN_USER,
password=LOGIN_PASSWORD,
host=LOGIN_HOST,
port=LOGIN_PORT,
schema=LOGIN_SCHEMA,
)
)

@mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
def test_run_hook_lineage(self, mock_send_lineage):
statement = "SELECT 1"

self.db_hook.run(statement)

mock_send_lineage.assert_called()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is self.db_hook
assert call_kw["sql"] == statement
assert call_kw["sql_parameters"] is None
assert call_kw["cur"] is self.cur

@mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
def test_insert_rows_hook_lineage(self, mock_send_lineage):
table = "table"
rows = [("hello",), ("world",)]

self.db_hook.insert_rows(table, rows)

mock_send_lineage.assert_called()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is self.db_hook
assert call_kw["sql"] == "INSERT INTO table VALUES (%s)"
assert call_kw["row_count"] == 2

@mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df")
def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage):
sql = "SELECT 1"
parameters = ("x",)
self.db_hook.get_df(sql, parameters=parameters)

mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is self.db_hook
assert call_kw["sql"] == sql
assert call_kw["sql_parameters"] == parameters

@mock.patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
@mock.patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks")
def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage):
sql = "SELECT 1"
parameters = ("x",)
self.db_hook.get_df_by_chunks(sql, parameters=parameters, chunksize=1)

mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is self.db_hook
assert call_kw["sql"] == sql
assert call_kw["sql_parameters"] == parameters
2 changes: 1 addition & 1 deletion providers/apache/drill/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ requires-python = ">=3.10"
# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
dependencies = [
"apache-airflow>=2.11.0",
"apache-airflow-providers-common-sql>=1.26.0",
"apache-airflow-providers-common-sql>=1.26.0", # use next version
"apache-airflow-providers-common-compat>=1.8.0",
# Workaround until we get https://github.com/JohnOmernik/sqlalchemy-drill/issues/94 fixed.
"sqlalchemy-drill>=1.1.0,!=1.1.6,!=1.1.7",
Expand Down
37 changes: 37 additions & 0 deletions providers/apache/drill/tests/unit/apache/drill/hooks/test_drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,40 @@ def test_insert_rows_raises_not_implemented(self):
db_hook = self.db_hook()
with pytest.raises(NotImplementedError, match=r"There is no INSERT statement in Drill."):
db_hook.insert_rows(table="my_table", rows=[("a",)])

@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
def test_run_hook_lineage(self, mock_send_lineage):
sql = "SELECT 1"
self.db_hook().run(sql)

mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is not None
assert call_kw["sql"] == sql
assert call_kw["sql_parameters"] is None
assert call_kw["cur"] is self.cur

@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
@patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df")
def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage):
sql = "SELECT 1"
self.db_hook().get_df(sql, df_type="pandas")

mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is not None
assert call_kw["sql"] == sql
assert call_kw["sql_parameters"] is None

@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
@patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks")
def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage):
sql = "SELECT 1"
parameters = ("x",)
self.db_hook().get_df_by_chunks(sql, parameters=parameters, chunksize=1)

mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is not None
assert call_kw["sql"] == sql
assert call_kw["sql_parameters"] == parameters
2 changes: 1 addition & 1 deletion providers/apache/druid/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ requires-python = ">=3.10"
# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
dependencies = [
"apache-airflow>=2.11.0",
"apache-airflow-providers-common-sql>=1.26.0",
"apache-airflow-providers-common-sql>=1.26.0", # use next version
"apache-airflow-providers-common-compat>=1.10.1",
"pydruid>=0.6.6",
]
Expand Down
37 changes: 37 additions & 0 deletions providers/apache/druid/tests/unit/apache/druid/hooks/test_druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,3 +509,40 @@ def test_get_df_polars(self):
assert column == df.columns[0]
assert result_sets[0][0] == df.row(0)[0]
assert result_sets[1][0] == df.row(1)[0]

@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
def test_run_hook_lineage(self, mock_send_lineage):
sql = "SELECT 1"
self.db_hook().run(sql)

mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is not None
assert call_kw["sql"] == sql
assert call_kw["sql_parameters"] is None
assert call_kw["cur"] is self.cur

@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
@patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df")
def test_get_df_hook_lineage(self, mock_get_pandas_df, mock_send_lineage):
sql = "SELECT 1"
self.db_hook().get_df(sql, df_type="pandas")

mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is not None
assert call_kw["sql"] == sql
assert call_kw["sql_parameters"] is None

@patch("airflow.providers.common.sql.hooks.sql.send_sql_hook_lineage")
@patch("airflow.providers.common.sql.hooks.sql.DbApiHook._get_pandas_df_by_chunks")
def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_send_lineage):
sql = "SELECT 1"
parameters = ("x",)
self.db_hook().get_df_by_chunks(sql, parameters=parameters, chunksize=1)

mock_send_lineage.assert_called_once()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is not None
assert call_kw["sql"] == sql
assert call_kw["sql_parameters"] == parameters
2 changes: 1 addition & 1 deletion providers/apache/hive/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ requires-python = ">=3.10"
dependencies = [
"apache-airflow>=2.11.0",
"apache-airflow-providers-common-compat>=1.12.0",
"apache-airflow-providers-common-sql>=1.26.0",
"apache-airflow-providers-common-sql>=1.26.0", # use next version
"hmsclient>=0.1.0",
'pandas>=2.1.2; python_version <"3.13"',
'pandas>=2.2.3; python_version >="3.13"',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
BaseHook,
conf,
)
from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.security import utils
from airflow.utils.helpers import as_flattened_list
Expand Down Expand Up @@ -332,6 +333,8 @@ def run_cli(
if sub_process.returncode:
raise AirflowException(stdout)

send_sql_hook_lineage(context=self, sql=hql)

return stdout

def test_hql(self, hql: str) -> None:
Expand Down Expand Up @@ -916,6 +919,7 @@ def _get_results(

for statement in sql:
cur.execute(statement)
send_sql_hook_lineage(context=self, sql=statement, cur=cur, default_schema=schema)
# we only get results of statements that returns
lowered_statement = statement.lower().strip()
if lowered_statement.startswith(("select", "with", "show")) or (
Expand Down
41 changes: 41 additions & 0 deletions providers/apache/hive/tests/unit/apache/hive/hooks/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,33 @@ def test_run_cli(self, mock_popen, mock_temp_dir):
close_fds=True,
)

@mock.patch("airflow.providers.apache.hive.hooks.hive.send_sql_hook_lineage")
@mock.patch("tempfile.tempdir", "/tmp/")
@mock.patch("tempfile._RandomNameSequence.__next__")
@mock.patch("subprocess.Popen")
def test_run_cli_hook_lineage(self, mock_popen, mock_temp_dir, mock_send_lineage):
mock_subprocess = MockSubProcess()
mock_popen.return_value = mock_subprocess
mock_temp_dir.return_value = "test_run_cli"
hql = "SHOW DATABASES"
envron_name = "AIRFLOW_CTX_LOGICAL_DATE" if AIRFLOW_V_3_0_PLUS else "AIRFLOW_CTX_EXECUTION_DATE"
with mock.patch.dict(
"os.environ",
{
"AIRFLOW_CTX_DAG_ID": "test_dag_id",
"AIRFLOW_CTX_TASK_ID": "test_task_id",
envron_name: "2015-01-01T00:00:00+00:00",
"AIRFLOW_CTX_TRY_NUMBER": "1",
"AIRFLOW_CTX_DAG_RUN_ID": "55",
"AIRFLOW_CTX_DAG_OWNER": "airflow",
"AIRFLOW_CTX_DAG_EMAIL": "test@airflow.com",
},
):
hook = MockHiveCliHook()
hook.run_cli(hql, schema="some_schema")

mock_send_lineage.assert_called_once_with(context=hook, sql=f"USE some_schema;\n{hql}\n")

def test_hive_cli_hook_invalid_schema(self):
hook = InvalidHiveCliHook()
with pytest.raises(RuntimeError) as error:
Expand Down Expand Up @@ -767,6 +794,20 @@ def test_get_results_data(self):

assert results["data"] == [(1, 1), (2, 2)]

@mock.patch("airflow.providers.apache.hive.hooks.hive.send_sql_hook_lineage")
def test_get_results_sends_hook_lineage(self, mock_send_lineage):
hook = MockHiveServer2Hook()

query = f"SELECT * FROM {self.table}"
hook.get_results(query, schema=self.database)

mock_send_lineage.assert_called()
call_kw = mock_send_lineage.call_args.kwargs
assert call_kw["context"] is hook
assert call_kw["sql"] == query
assert call_kw["cur"] is hook.mock_cursor
assert call_kw["default_schema"] == self.database

def test_to_csv(self):
hook = MockHiveServer2Hook()
hook._get_results = mock.MagicMock(
Expand Down
2 changes: 1 addition & 1 deletion providers/apache/impala/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ requires-python = ">=3.10"
dependencies = [
"impyla>=0.22.0,<1.0",
"apache-airflow-providers-common-compat>=1.12.0",
"apache-airflow-providers-common-sql>=1.26.0",
"apache-airflow-providers-common-sql>=1.26.0", # use next version
"apache-airflow>=2.11.0",
]

Expand Down
Loading