diff --git a/airflow/lineage/hook.py b/airflow/lineage/hook.py
index 70893516bd9bd..ee12e1624e12d 100644
--- a/airflow/lineage/hook.py
+++ b/airflow/lineage/hook.py
@@ -139,10 +139,10 @@ class NoOpCollector(HookLineageCollector):
     It is used when you want to disable lineage collection.
     """
 
-    def add_input_dataset(self, *_):
+    def add_input_dataset(self, *_, **__):
         pass
 
-    def add_output_dataset(self, *_):
+    def add_output_dataset(self, *_, **__):
         pass
 
     @property
diff --git a/airflow/providers/amazon/aws/datasets/__init__.py b/airflow/providers/amazon/aws/datasets/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/airflow/providers/amazon/aws/datasets/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/airflow/providers/amazon/aws/datasets/s3.py b/airflow/providers/amazon/aws/datasets/s3.py
new file mode 100644
index 0000000000000..89889efe577b3
--- /dev/null
+++ b/airflow/providers/amazon/aws/datasets/s3.py
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.datasets import Dataset
+
+
+def create_dataset(*, bucket: str, key: str, extra=None) -> Dataset:
+    return Dataset(uri=f"s3://{bucket}/{key}", extra=extra)
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 8ca93766e2ed3..5f2c1366404eb 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -41,6 +41,8 @@
 from urllib.parse import urlsplit
 from uuid import uuid4
 
+from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
+
 if TYPE_CHECKING:
     from mypy_boto3_s3.service_resource import Bucket as S3Bucket, Object as S3ResourceObject
 
@@ -1111,6 +1113,12 @@ def load_file(
 
         client = self.get_conn()
         client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="file", dataset_kwargs={"path": filename}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     @unify_bucket_name_and_key
     @provide_bucket_name
@@ -1251,6 +1259,10 @@ def _upload_file_obj(
             ExtraArgs=extra_args,
             Config=self.transfer_config,
         )
+        # No input because file_obj can be anything - handle in calling function if possible
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
 
     def copy_object(
         self,
@@ -1306,6 +1318,12 @@ def copy_object(
         response = self.get_conn().copy_object(
             Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, **kwargs
         )
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": source_bucket_name, "key": source_bucket_key}
+        )
+        get_hook_lineage_collector().add_output_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": dest_bucket_name, "key": dest_bucket_key}
+        )
         return response
 
     @provide_bucket_name
@@ -1425,6 +1443,11 @@ def download_file(
 
             file_path.parent.mkdir(exist_ok=True, parents=True)
 
+            get_hook_lineage_collector().add_output_dataset(
+                context=self,
+                scheme="file",
+                dataset_kwargs={"path": file_path if file_path.is_absolute() else file_path.absolute()},
+            )
             file = open(file_path, "wb")
         else:
             file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore
@@ -1435,7 +1458,9 @@ def download_file(
             ExtraArgs=self.extra_args,
             Config=self.transfer_config,
         )
-
+        get_hook_lineage_collector().add_input_dataset(
+            context=self, scheme="s3", dataset_kwargs={"bucket": bucket_name, "key": key}
+        )
         return file.name
 
     def generate_presigned_url(
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index a7b4d4272f5fb..9dd76ac9fa3b1 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -91,6 +91,7 @@ dependencies:
   - apache-airflow>=2.7.0
   - apache-airflow-providers-common-sql>=1.3.1
   - apache-airflow-providers-http
+  - apache-airflow-providers-common-compat>=1.1.0
   # We should update minimum version of boto3 and here regularly to avoid `pip` backtracking with the number
   # of candidates to consider. Make sure to configure boto3 version here as well as in all the tools below
   # in the `devel-dependencies` section to be the same minimum version.
@@ -561,6 +562,7 @@ sensors:
 dataset-uris:
   - schemes: [s3]
     handler: null
+    factory: airflow.providers.amazon.aws.datasets.s3.create_dataset
 
 filesystems:
   - airflow.providers.amazon.aws.fs.s3
diff --git a/airflow/providers/common/compat/lineage/hook.py b/airflow/providers/common/compat/lineage/hook.py
index 2115c992e7a41..dbdbc5bf86f4d 100644
--- a/airflow/providers/common/compat/lineage/hook.py
+++ b/airflow/providers/common/compat/lineage/hook.py
@@ -32,10 +32,10 @@ class NoOpCollector:
             It is used when you want to disable lineage collection.
             """
 
-            def add_input_dataset(self, *_):
+            def add_input_dataset(self, *_, **__):
                 pass
 
-            def add_output_dataset(self, *_):
+            def add_output_dataset(self, *_, **__):
                 pass
 
         return NoOpCollector()
diff --git a/airflow/providers/common/compat/provider.yaml b/airflow/providers/common/compat/provider.yaml
index 27e610e25f4d3..53527f9204ad4 100644
--- a/airflow/providers/common/compat/provider.yaml
+++ b/airflow/providers/common/compat/provider.yaml
@@ -25,6 +25,7 @@ state: ready
 source-date-epoch: 1716287191
 # note that those versions are maintained by release manager - do not update them manually
 versions:
+  - 1.1.0
   - 1.0.0
 
 dependencies:
diff --git a/airflow/providers/common/io/datasets/file.py b/airflow/providers/common/io/datasets/file.py
index 46c7499037e06..1bc4969762b85 100644
--- a/airflow/providers/common/io/datasets/file.py
+++ b/airflow/providers/common/io/datasets/file.py
@@ -19,6 +19,6 @@
 from airflow.datasets import Dataset
 
 
-def create_dataset(*, path: str) -> Dataset:
+def create_dataset(*, path: str, extra=None) -> Dataset:
     # We assume that we get absolute path starting with /
-    return Dataset(uri=f"file://{path}")
+    return Dataset(uri=f"file://{path}", extra=extra)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 9e9dd4d573ddd..f6d29a51d12ca 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -886,23 +886,23 @@ def _discover_dataset_uri_handlers_and_factories(self) -> None:
 
         for provider_package, provider in self._provider_dict.items():
             for handler_info in provider.data.get("dataset-uris", []):
-                try:
-                    schemes = handler_info["schemes"]
-                    handler_path = handler_info["handler"]
-                except KeyError:
+                schemes = handler_info.get("schemes")
+                handler_path = handler_info.get("handler")
+                factory_path = handler_info.get("factory")
+                if schemes is None:
                     continue
-                if handler_path is None:
+
+                if handler_path is not None and (
+                    handler := _correctness_check(provider_package, handler_path, provider)
+                ):
+                    pass
+                else:
                     handler = normalize_noop
-                elif not (handler := _correctness_check(provider_package, handler_path, provider)):
-                    continue
                 self._dataset_uri_handlers.update((scheme, handler) for scheme in schemes)
-                factory_path = handler_info.get("factory")
-                if not (
-                    factory_path is not None
-                    and (factory := _correctness_check(provider_package, factory_path, provider))
+                if factory_path is not None and (
+                    factory := _correctness_check(provider_package, factory_path, provider)
                 ):
-                    continue
-                self._dataset_factories.update((scheme, factory) for scheme in schemes)
+                    self._dataset_factories.update((scheme, factory) for scheme in schemes)
 
     def _discover_taskflow_decorators(self) -> None:
         for name, info in self._provider_dict.items():
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index 4c215ca3b0aad..c0c40b9be92b0 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -569,7 +569,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             ("airflow/providers/amazon/__init__.py",),
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-                "common.sql exasol ftp google http imap microsoft.azure "
+                "common.compat common.sql exasol ftp google http imap microsoft.azure "
                 "mongo mysql openlineage postgres salesforce ssh teradata",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
@@ -585,7 +585,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "upgrade-to-newer-dependencies": "false",
                 "run-amazon-tests": "true",
                 "parallel-test-types-list-as-string": "Always Providers[amazon] "
-                "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,"
+                "Providers[apache.hive,cncf.kubernetes,common.compat,common.sql,exasol,ftp,http,"
                 "imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]",
                 "needs-mypy": "true",
                 "mypy-folders": "['providers']",
@@ -619,7 +619,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
             ("airflow/providers/amazon/file.py",),
             {
                 "affected-providers-list-as-string": "amazon apache.hive cncf.kubernetes "
-                "common.sql exasol ftp google http imap microsoft.azure "
+                "common.compat common.sql exasol ftp google http imap microsoft.azure "
                 "mongo mysql openlineage postgres salesforce ssh teradata",
                 "all-python-versions": "['3.8']",
                 "all-python-versions-list-as-string": "3.8",
@@ -635,7 +635,7 @@ def assert_outputs_are_printed(expected_outputs: dict[str, str], stderr: str):
                 "run-kubernetes-tests": "false",
                 "upgrade-to-newer-dependencies": "false",
                 "parallel-test-types-list-as-string": "Always Providers[amazon] "
-                "Providers[apache.hive,cncf.kubernetes,common.sql,exasol,ftp,http,"
+                "Providers[apache.hive,cncf.kubernetes,common.compat,common.sql,exasol,ftp,http,"
                 "imap,microsoft.azure,mongo,mysql,openlineage,postgres,salesforce,ssh,teradata] Providers[google]",
                 "needs-mypy": "true",
                 "mypy-folders": "['providers']",
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 9b092cb52f097..48e11a0b8b676 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -28,6 +28,7 @@
   "amazon": {
     "deps": [
       "PyAthena>=3.0.10",
+      "apache-airflow-providers-common-compat>=1.1.0",
      "apache-airflow-providers-common-sql>=1.3.1",
      "apache-airflow-providers-http",
      "apache-airflow>=2.7.0",
@@ -57,6 +58,7 @@
     "cross-providers-deps": [
      "apache.hive",
      "cncf.kubernetes",
+      "common.compat",
      "common.sql",
      "exasol",
      "ftp",
diff --git a/prod_image_installed_providers.txt b/prod_image_installed_providers.txt
index c292b7b83d9b6..7340928738c11 100644
--- a/prod_image_installed_providers.txt
+++ b/prod_image_installed_providers.txt
@@ -2,6 +2,7 @@
 amazon
 celery
 cncf.kubernetes
+common.compat
 common.io
 common.sql
 docker
diff --git a/tests/conftest.py b/tests/conftest.py
index 9027391575e4c..6cb74446dce8b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1326,6 +1326,16 @@ def airflow_root_path() -> Path:
     return Path(airflow.__path__[0]).parent
 
 
+@pytest.fixture
+def hook_lineage_collector():
+    from airflow.lineage import hook
+
+    hook._hook_lineage_collector = None
+    hook._hook_lineage_collector = hook.HookLineageCollector()
+    yield hook.get_hook_lineage_collector()
+    hook._hook_lineage_collector = None
+
+
 # This constant is set to True if tests are run with Airflow installed from Packages rather than running
 # the tests within Airflow sources. While most tests in CI are run using Airflow sources, there are
 # also compatibility tests that only use `tests` package and run against installed packages of Airflow in
diff --git a/tests/providers/amazon/aws/datasets/__init__.py b/tests/providers/amazon/aws/datasets/__init__.py
new file mode 100644
index 0000000000000..13a83393a9124
--- /dev/null
+++ b/tests/providers/amazon/aws/datasets/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/tests/providers/amazon/aws/datasets/test_s3.py b/tests/providers/amazon/aws/datasets/test_s3.py
new file mode 100644
index 0000000000000..c7ffe252401e7
--- /dev/null
+++ b/tests/providers/amazon/aws/datasets/test_s3.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.datasets import Dataset
+from airflow.providers.amazon.aws.datasets.s3 import create_dataset
+
+
+def test_create_dataset():
+    assert create_dataset(bucket="test-bucket", key="test-path") == Dataset(uri="s3://test-bucket/test-path")
+    assert create_dataset(bucket="test-bucket", key="test-dir/test-path") == Dataset(
+        uri="s3://test-bucket/test-dir/test-path"
+    )
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 6b10173d3c6ed..acedf3d011844 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -31,6 +31,7 @@
 from botocore.exceptions import ClientError
 from moto import mock_aws
 
+from airflow.datasets import Dataset
 from airflow.exceptions import AirflowException
 from airflow.models import Connection
 from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
@@ -41,6 +42,7 @@
     unify_bucket_name_and_key,
 )
 from airflow.utils.timezone import datetime
+from tests.test_utils.compat import AIRFLOW_V_2_10_PLUS
 
 
 @pytest.fixture
@@ -388,6 +390,15 @@ def test_load_string(self, s3_bucket):
         resource = boto3.resource("s3").Object(s3_bucket, "my_key")
         assert resource.get()["Body"].read() == b"Cont\xc3\xa9nt"
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    def test_load_string_exposes_lineage(self, s3_bucket, hook_lineage_collector):
+        hook = S3Hook()
+        hook.load_string("Contént", "my_key", s3_bucket)
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            uri=f"s3://{s3_bucket}/my_key"
+        )
+
     def test_load_string_compress(self, s3_bucket):
         hook = S3Hook()
         hook.load_string("Contént", "my_key", s3_bucket, compression="gzip")
@@ -970,6 +981,17 @@ def test_load_file_gzip(self, s3_bucket, tmp_path):
         resource = boto3.resource("s3").Object(s3_bucket, "my_key")
         assert gz.decompress(resource.get()["Body"].read()) == b"Content"
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    def test_load_file_exposes_lineage(self, s3_bucket, tmp_path, hook_lineage_collector):
+        hook = S3Hook()
+        path = tmp_path / "testfile"
+        path.write_text("Content")
+        hook.load_file(path, "my_key", s3_bucket)
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            uri=f"s3://{s3_bucket}/my_key"
+        )
+
     def test_load_file_acl(self, s3_bucket, tmp_path):
         hook = S3Hook()
         path = tmp_path / "testfile"
@@ -1027,6 +1049,26 @@ def test_copy_object_no_acl(
             ACL="private",
         )
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock_aws
+    def test_copy_object_ol_instrumentation(self, s3_bucket, hook_lineage_collector):
+        mock_hook = S3Hook()
+
+        with mock.patch.object(
+            S3Hook,
+            "get_conn",
+        ):
+            mock_hook.copy_object("my_key", "my_key3", s3_bucket, s3_bucket)
+            assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+            assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+                uri=f"s3://{s3_bucket}/my_key"
+            )
+
+            assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+            assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+                uri=f"s3://{s3_bucket}/my_key3"
+            )
+
     @mock_aws
     def test_delete_bucket_if_bucket_exist(self, s3_bucket):
         # assert if the bucket is created
@@ -1140,6 +1182,26 @@ def test_download_file(self, mock_temp_file, tmp_path):
 
         assert path.name == output_file
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
+    def test_download_file_exposes_lineage(self, mock_temp_file, tmp_path, hook_lineage_collector):
+        path = tmp_path / "airflow_tmp_test_s3_hook"
+        mock_temp_file.return_value = path
+        s3_hook = S3Hook(aws_conn_id="s3_test")
+        s3_hook.check_for_key = Mock(return_value=True)
+        s3_obj = Mock()
+        s3_obj.download_fileobj = Mock(return_value=None)
+        s3_hook.get_key = Mock(return_value=s3_obj)
+        key = "test_key"
+        bucket = "test_bucket"
+
+        s3_hook.download_file(key=key, bucket_name=bucket)
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+            uri="s3://test_bucket/test_key"
+        )
+
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
     def test_download_file_with_preserve_name(self, mock_open, tmp_path):
         path = tmp_path / "test.log"
@@ -1152,16 +1214,51 @@ def test_download_file_with_preserve_name(self, mock_open, tmp_path):
         s3_obj.key = f"s3://{bucket}/{key}"
         s3_obj.download_fileobj = Mock(return_value=None)
         s3_hook.get_key = Mock(return_value=s3_obj)
+        local_path = os.fspath(path.parent)
         s3_hook.download_file(
             key=key,
             bucket_name=bucket,
-            local_path=os.fspath(path.parent),
+            local_path=local_path,
             preserve_file_name=True,
             use_autogenerated_subdir=False,
         )
 
         mock_open.assert_called_once_with(path, "wb")
 
+    @pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Hook lineage works in Airflow >= 2.10.0")
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
+    def test_download_file_with_preserve_name_exposes_lineage(
+        self, mock_open, tmp_path, hook_lineage_collector
+    ):
+        path = tmp_path / "test.log"
+        bucket = "test_bucket"
+        key = f"test_key/{path.name}"
+
+        s3_hook = S3Hook(aws_conn_id="s3_test")
+        s3_hook.check_for_key = Mock(return_value=True)
+        s3_obj = Mock()
+        s3_obj.key = f"s3://{bucket}/{key}"
+        s3_obj.download_fileobj = Mock(return_value=None)
+        s3_hook.get_key = Mock(return_value=s3_obj)
+        local_path = os.fspath(path.parent)
+        s3_hook.download_file(
+            key=key,
+            bucket_name=bucket,
+            local_path=local_path,
+            preserve_file_name=True,
+            use_autogenerated_subdir=False,
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.inputs) == 1
+        assert hook_lineage_collector.collected_datasets.inputs[0][0] == Dataset(
+            uri="s3://test_bucket/test_key/test.log"
+        )
+
+        assert len(hook_lineage_collector.collected_datasets.outputs) == 1
+        assert hook_lineage_collector.collected_datasets.outputs[0][0] == Dataset(
+            uri=f"file://{local_path}/test.log",
+        )
+
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
     def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open, tmp_path):
         path = tmp_path / "test.log"
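Note: the snippet below is not part of the patch; it is a minimal usage sketch showing how the lineage wiring added above can be exercised. It mirrors the hook_lineage_collector fixture from tests/conftest.py, assumes Airflow >= 2.10 with the amazon, common.compat and common.io providers installed, and all bucket, key and file names are illustrative only.

    from airflow.lineage import hook
    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    # Install a real collector (as the test fixture above does), so datasets are
    # recorded even when no hook lineage reader plugin is registered.
    hook._hook_lineage_collector = hook.HookLineageCollector()

    s3_hook = S3Hook()
    # Illustrative names - any readable local file and writable bucket would do.
    s3_hook.load_file(filename="/tmp/report.csv", key="reports/report.csv", bucket_name="my-bucket")

    collected = hook.get_hook_lineage_collector().collected_datasets
    # load_file reports the local file as an input dataset and the S3 object as an
    # output dataset, matching the add_input_dataset/add_output_dataset calls above.
    print(collected.inputs[0][0].uri)   # file:///tmp/report.csv
    print(collected.outputs[0][0].uri)  # s3://my-bucket/reports/report.csv

    # Reset the module-level collector afterwards, as the fixture does.
    hook._hook_lineage_collector = None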