Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,12 @@ def execute_on_dataflow(self, context: Context):
)

location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION
DataflowJobLink.persist(context=context, region=location)
DataflowJobLink.persist(
context=context,
region=self.dataflow_config.location,
job_id=self.dataflow_job_id,
project_id=self.dataflow_config.project_id,
)

if self.deferrable:
trigger_args = {
Expand Down Expand Up @@ -648,7 +653,12 @@ def execute_on_dataflow(self, context: Context):
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
)
if self.dataflow_job_name and self.dataflow_config.location:
DataflowJobLink.persist(context=context)
DataflowJobLink.persist(
context=context,
region=self.dataflow_config.location,
job_id=self.dataflow_job_id,
project_id=self.dataflow_config.project_id,
)
if self.deferrable:
trigger_args = {
"job_id": self.dataflow_job_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,25 @@ def file_has_gcs_path(file_path: str):
@staticmethod
async def provide_gcs_tempfile(gcs_file, gcp_conn_id):
try:
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.hooks.gcs import GCSAsyncHook
except ImportError:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException(
"Failed to import GCSHook. To use the GCSHook functionality, please install the "
"Failed to import GCSAsyncHook. To use the GCSAsyncHook functionality, please install the "
"apache-airflow-google-provider."
)

gcs_hook = GCSHook(gcp_conn_id=gcp_conn_id)
async_gcs_hook = GCSAsyncHook(gcp_conn_id=gcp_conn_id)
sync_gcs_hook = await async_gcs_hook.get_sync_hook()

loop = asyncio.get_running_loop()

# Running synchronous `enter_context()` method in a separate
# thread using the default executor `None`. The `run_in_executor()` function returns the
# file object, which is created using gcs function `provide_file()`, asynchronously.
# This means we can perform asynchronous operations with this file.
create_tmp_file_call = gcs_hook.provide_file(object_url=gcs_file)
create_tmp_file_call = sync_gcs_hook.provide_file(object_url=gcs_file)
tmp_gcs_file: IO[str] = await loop.run_in_executor(
None,
contextlib.ExitStack().enter_context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,12 @@ def test_exec_dataflow_runner(
}
gcs_provide_file.assert_any_call(object_url=PY_FILE)
gcs_provide_file.assert_any_call(object_url=REQURIEMENTS_FILE)
persist_link_mock.assert_called_once_with(context={}, region="us-central1")
persist_link_mock.assert_called_once_with(
context={},
region="us-central1",
job_id=None,
project_id=dataflow_hook_mock.return_value.project_id,
)
beam_hook_mock.return_value.start_python_pipeline.assert_called_once_with(
variables=expected_options,
py_file=gcs_provide_file.return_value.__enter__.return_value.name,
Expand Down Expand Up @@ -468,7 +473,12 @@ def test_exec_dataflow_runner(self, gcs_hook, dataflow_hook_mock, beam_hook_mock
"output": "gs://test/output",
"impersonateServiceAccount": TEST_IMPERSONATION_ACCOUNT,
}
persist_link_mock.assert_called_once_with(context={})
persist_link_mock.assert_called_once_with(
context={},
region="us-central1",
job_id=None,
project_id=dataflow_hook_mock.return_value.project_id,
)
beam_hook_mock.return_value.start_java_pipeline.assert_called_once_with(
variables=expected_options,
jar=gcs_provide_file.return_value.__enter__.return_value.name,
Expand Down
77 changes: 61 additions & 16 deletions providers/apache/beam/tests/unit/apache/beam/triggers/test_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import asyncio
from unittest import mock

import pytest
Expand Down Expand Up @@ -134,17 +135,41 @@ async def test_beam_trigger_exception_should_execute_successfully(
assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual

@pytest.mark.asyncio
async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, python_trigger):
async def test_beam_trigger_gcs_provide_file_should_execute_successfully(
self, python_trigger, monkeypatch
):
"""
Test that BeamPythonPipelineTrigger downloads GCS provide file correct.
Test that BeamPythonPipelineTrigger downloads GCS provide file correctly with GCSAsyncHook.
"""
TEST_GCS_PY_FILE = "gs://bucket/path/file.py"
python_trigger.py_file = TEST_GCS_PY_FILE
with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as mock_gcs_hook:
mock_gcs_hook.return_value.provide_file.return_value = "mocked_temp_file"
generator = python_trigger.run()
await generator.asend(None)
mock_gcs_hook.assert_called_once_with(gcp_conn_id=python_trigger.gcp_conn_id)
mock_gcs_hook.return_value.provide_file.assert_called_once_with(object_url=TEST_GCS_PY_FILE)

with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSAsyncHook") as MockAsyncHook:
async_hook_instance = MockAsyncHook.return_value

class DummyCM:
def __enter__(self):
return "mocked_temp_file"

def __exit__(self, exc_type, exc, tb):
return False

sync_hook = mock.Mock(name="SyncGCSHook")
sync_hook.provide_file.return_value = DummyCM()

async_hook_instance.get_sync_hook = mock.AsyncMock(return_value=sync_hook)

fake_loop = mock.Mock()
fake_loop.run_in_executor = mock.AsyncMock(return_value="mocked_temp_file")
monkeypatch.setattr(asyncio, "get_running_loop", lambda: fake_loop)

gen = python_trigger.run()
await gen.asend(None)

MockAsyncHook.assert_called_once_with(gcp_conn_id=python_trigger.gcp_conn_id)
async_hook_instance.get_sync_hook.assert_awaited_once()
sync_hook.provide_file.assert_called_once_with(object_url=TEST_GCS_PY_FILE)
fake_loop.run_in_executor.assert_awaited_once()


class TestBeamJavaPipelineTrigger:
Expand Down Expand Up @@ -211,15 +236,35 @@ async def test_beam_trigger_exception_should_execute_successfully(
assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual

@pytest.mark.asyncio
async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, java_trigger):
async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, java_trigger, monkeypatch):
"""
Test that BeamJavaPipelineTrigger downloads GCS provide file correct.
Test that BeamJavaPipelineTrigger downloads GCS provide file correctly with GCSAsyncHook.
"""
java_trigger.jar = TEST_GCS_JAR_FILE

with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as mock_gcs_hook:
mock_gcs_hook.return_value.provide_file.return_value = "mocked_temp_file"
generator = java_trigger.run()
await generator.asend(None)
mock_gcs_hook.assert_called_once_with(gcp_conn_id=java_trigger.gcp_conn_id)
mock_gcs_hook.return_value.provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE)
with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSAsyncHook") as MockAsyncHook:
async_hook_instance = MockAsyncHook.return_value

class DummyCM:
def __enter__(self):
return "mocked_temp_file"

def __exit__(self, exc_type, exc, tb):
return False

sync_hook = mock.Mock(name="SyncGCSHook")
sync_hook.provide_file.return_value = DummyCM()

async_hook_instance.get_sync_hook = mock.AsyncMock(return_value=sync_hook)

fake_loop = mock.Mock()
fake_loop.run_in_executor = mock.AsyncMock(return_value="mocked_temp_file")
monkeypatch.setattr(asyncio, "get_running_loop", lambda: fake_loop)

gen = java_trigger.run()
await gen.asend(None)

MockAsyncHook.assert_called_once_with(gcp_conn_id=java_trigger.gcp_conn_id)
async_hook_instance.get_sync_hook.assert_awaited_once()
sync_hook.provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE)
fake_loop.run_in_executor.assert_awaited_once()
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
# [START howto_operator_start_java_job_local_jar]
start_java_job_direct = BeamRunJavaPipelineOperator(
task_id="start_java_job_direct",
jar=LOCAL_JAR,
jar=GCS_JAR,
pipeline_options={
"output": GCS_OUTPUT,
},
Expand All @@ -102,7 +102,7 @@

start_java_job_direct_deferrable = BeamRunJavaPipelineOperator(
task_id="start_java_job_direct_deferrable",
jar=GCS_JAR,
jar=LOCAL_JAR,
pipeline_options={
"output": GCS_OUTPUT,
},
Expand Down
Loading