@@ -38,7 +38,6 @@
from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.common.compat.standard.utils import prepare_virtualenv
from airflow.providers.google.go_module_utils import init_module, install_dependencies

if TYPE_CHECKING:
import logging
@@ -377,6 +376,16 @@ def start_go_pipeline(
"'https://airflow.apache.org/docs/docker-stack/recipes.html'."
)

try:
from airflow.providers.google.go_module_utils import init_module, install_dependencies
except ImportError:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException(
"Failed to import apache-airflow-google-provider. To start a go pipeline, please install the"
" google provider."
)

if "labels" in variables:
variables["labels"] = json.dumps(variables["labels"], separators=(",", ":"))

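A minimal, self-contained sketch of the lazy-import guard introduced above, extracted into a hypothetical helper (_require_go_module_utils is illustrative and not part of the diff); keeping the import inside the Go code path lets the beam provider be imported without the google provider installed:

def _require_go_module_utils():
    """Resolve the optional google dependency only when a Go pipeline is started."""
    try:
        from airflow.providers.google.go_module_utils import init_module, install_dependencies
    except ImportError:
        from airflow.exceptions import AirflowOptionalProviderFeatureException

        raise AirflowOptionalProviderFeatureException(
            "apache-airflow-providers-google is required to start a Go pipeline."
        )
    return init_module, install_dependencies
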
@@ -30,30 +30,41 @@
from functools import partial
from typing import TYPE_CHECKING, Any, Callable

from packaging.version import parse as parse_version

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowOptionalProviderFeatureException
from airflow.models import BaseOperator
from airflow.providers.apache.beam.hooks.beam import BeamHook, BeamRunnerType
from airflow.providers.apache.beam.triggers.beam import BeamJavaPipelineTrigger, BeamPythonPipelineTrigger
from airflow.providers.google.cloud.hooks.dataflow import (
DEFAULT_DATAFLOW_LOCATION,
DataflowHook,
DataflowJobStatus,
process_line_and_extract_dataflow_job_id_callback,
)
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration
from airflow.providers.google.cloud.triggers.dataflow import (
DataflowJobStatusTrigger,
)
from airflow.providers_manager import ProvidersManager
from airflow.utils.helpers import convert_camel_to_snake, exactly_one
from airflow.version import version

if TYPE_CHECKING:
from airflow.utils.context import Context


try:
from airflow.providers.google.cloud.hooks.dataflow import (
DEFAULT_DATAFLOW_LOCATION,
DataflowHook,
process_line_and_extract_dataflow_job_id_callback,
)
from airflow.providers.google.cloud.hooks.gcs import GCSHook, _parse_gcs_url
from airflow.providers.google.cloud.links.dataflow import DataflowJobLink
from airflow.providers.google.cloud.operators.dataflow import CheckJobRunning, DataflowConfiguration
from airflow.providers.google.cloud.triggers.dataflow import (
DataflowJobStateCompleteTrigger,
DataflowJobStatus,
DataflowJobStatusTrigger,
)

GOOGLE_PROVIDER_VERSION = ProvidersManager().providers["apache-airflow-providers-google"].version
except ImportError:
GOOGLE_PROVIDER_VERSION = ""


class BeamDataflowMixin(metaclass=ABCMeta):
"""
Helper class to store common, Dataflow specific logic for both.
@@ -68,6 +79,13 @@ class BeamDataflowMixin(metaclass=ABCMeta):
gcp_conn_id: str
dataflow_support_impersonation: bool = True

def __init__(self):
if not GOOGLE_PROVIDER_VERSION:
raise AirflowOptionalProviderFeatureException(
"Failed to import apache-airflow-google-provider. To use the dataflow service please install "
"the appropriate version of the google provider."
)

def _set_dataflow(
self,
pipeline_options: dict,
@@ -319,7 +337,7 @@ class BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
"dataflow_config",
)
template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"}
operator_extra_links = (DataflowJobLink(),)
operator_extra_links = (DataflowJobLink(),) if GOOGLE_PROVIDER_VERSION else ()

def __init__(
self,
@@ -423,22 +441,37 @@ def execute_on_dataflow(self, context: Context):
process_line_callback=self.process_line_callback,
is_dataflow_job_id_exist_callback=self.is_dataflow_job_id_exist_callback,
)

location = self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION
DataflowJobLink.persist(
self,
context,
self.dataflow_config.project_id,
self.dataflow_config.location,
location,
self.dataflow_job_id,
)

if self.deferrable:
self.defer(
trigger=DataflowJobStatusTrigger(
job_id=self.dataflow_job_id,
trigger_args = {
"job_id": self.dataflow_job_id,
"project_id": self.dataflow_config.project_id,
"location": location,
"gcp_conn_id": self.gcp_conn_id,
}
trigger: DataflowJobStatusTrigger | DataflowJobStateCompleteTrigger
if parse_version(GOOGLE_PROVIDER_VERSION) < parse_version("16.0.0"):
trigger = DataflowJobStatusTrigger(
expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
project_id=self.dataflow_config.project_id,
location=self.dataflow_config.location or DEFAULT_DATAFLOW_LOCATION,
gcp_conn_id=self.gcp_conn_id,
),
**trigger_args,
)
else:
trigger = DataflowJobStateCompleteTrigger(
wait_until_finished=self.dataflow_config.wait_until_finished,
**trigger_args,
)

self.defer(
trigger=trigger,
method_name="execute_complete",
)
self.dataflow_hook.wait_for_done(
@@ -498,7 +531,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"}
ui_color = "#0273d4"

operator_extra_links = (DataflowJobLink(),)
operator_extra_links = (DataflowJobLink(),) if GOOGLE_PROVIDER_VERSION else ()

def __init__(
self,
@@ -601,16 +634,29 @@ def execute_on_dataflow(self, context: Context):
self.dataflow_job_id,
)
if self.deferrable:
self.defer(
trigger=DataflowJobStatusTrigger(
job_id=self.dataflow_job_id,
trigger_args = {
"job_id": self.dataflow_job_id,
"project_id": self.dataflow_config.project_id,
"location": self.dataflow_config.location,
"gcp_conn_id": self.gcp_conn_id,
}
trigger: DataflowJobStatusTrigger | DataflowJobStateCompleteTrigger
if parse_version(GOOGLE_PROVIDER_VERSION) < parse_version("16.0.0"):
trigger = DataflowJobStatusTrigger(
expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
project_id=self.dataflow_config.project_id,
location=self.dataflow_config.location,
gcp_conn_id=self.gcp_conn_id,
),
**trigger_args,
)
else:
trigger = DataflowJobStateCompleteTrigger(
wait_until_finished=self.dataflow_config.wait_until_finished,
**trigger_args,
)

self.defer(
trigger=trigger,
method_name="execute_complete",
)

multiple_jobs = self.dataflow_config.multiple_jobs or False
self.dataflow_hook.wait_for_done(
job_name=self.dataflow_job_name,
@@ -676,7 +722,7 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
"dataflow_config",
]
template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"}
operator_extra_links = (DataflowJobLink(),)
operator_extra_links = (DataflowJobLink(),) if GOOGLE_PROVIDER_VERSION else ()

def __init__(
self,
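Tying this file together: the google imports are attempted once at module import time, GOOGLE_PROVIDER_VERSION records the installed google provider version (or "" when the import fails), and the deferrable branches above choose between DataflowJobStatusTrigger and DataflowJobStateCompleteTrigger based on that version. A self-contained sketch of the version gate (the helper name is illustrative, not part of the diff):

from packaging.version import parse as parse_version


def use_state_complete_trigger(google_provider_version: str) -> bool:
    """Return True when DataflowJobStateCompleteTrigger (google provider >= 16.0.0) is available."""
    if not google_provider_version:
        # Empty string means the google provider could not be imported at all.
        return False
    return parse_version(google_provider_version) >= parse_version("16.0.0")


assert use_state_complete_trigger("16.1.0") is True
assert use_state_complete_trigger("15.1.0") is False
assert use_state_complete_trigger("") is False

Both operators then pass the shared trigger_args mapping to whichever trigger class was selected and defer with method_name="execute_complete", so only the trigger class and its extra keyword argument differ between the two branches.
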
@@ -22,7 +22,6 @@
from typing import IO, Any

from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


@@ -33,6 +32,37 @@ class BeamPipelineBaseTrigger(BaseTrigger):
def _get_async_hook(*args, **kwargs) -> BeamAsyncHook:
return BeamAsyncHook(*args, **kwargs)

@staticmethod
def file_has_gcs_path(file_path: str):
return file_path.lower().startswith("gs://")

@staticmethod
async def provide_gcs_tempfile(gcs_file, gcp_conn_id):
try:
from airflow.providers.google.cloud.hooks.gcs import GCSHook
except ImportError:
from airflow.exceptions import AirflowOptionalProviderFeatureException

raise AirflowOptionalProviderFeatureException(
"Failed to import GCSHook. To use the GCSHook functionality, please install the "
"apache-airflow-google-provider."
)

gcs_hook = GCSHook(gcp_conn_id=gcp_conn_id)
loop = asyncio.get_running_loop()

# Running synchronous `enter_context()` method in a separate
# thread using the default executor `None`. The `run_in_executor()` function returns the
# file object, which is created using gcs function `provide_file()`, asynchronously.
# This means we can perform asynchronous operations with this file.
create_tmp_file_call = gcs_hook.provide_file(object_url=gcs_file)
tmp_gcs_file: IO[str] = await loop.run_in_executor(
None,
contextlib.ExitStack().enter_context, # type: ignore[arg-type]
create_tmp_file_call,
)
return tmp_gcs_file


class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger):
"""
@@ -101,20 +131,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
hook = self._get_async_hook(runner=self.runner)

try:
# Get the current running event loop to manage I/O operations asynchronously
loop = asyncio.get_running_loop()
if self.py_file.lower().startswith("gs://"):
gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
# Running synchronous `enter_context()` method in a separate
# thread using the default executor `None`. The `run_in_executor()` function returns the
# file object, which is created using gcs function `provide_file()`, asynchronously.
# This means we can perform asynchronous operations with this file.
create_tmp_file_call = gcs_hook.provide_file(object_url=self.py_file)
tmp_gcs_file: IO[str] = await loop.run_in_executor(
None,
contextlib.ExitStack().enter_context, # type: ignore[arg-type]
create_tmp_file_call,
)
if self.file_has_gcs_path(self.py_file):
tmp_gcs_file = await self.provide_gcs_tempfile(self.py_file, self.gcp_conn_id)
self.py_file = tmp_gcs_file.name

return_code = await hook.start_python_pipeline_async(
@@ -188,20 +206,8 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
hook = self._get_async_hook(runner=self.runner)
return_code = 0
try:
# Get the current running event loop to manage I/O operations asynchronously
loop = asyncio.get_running_loop()
if self.jar.lower().startswith("gs://"):
gcs_hook = GCSHook(self.gcp_conn_id)
# Running synchronous `enter_context()` method in a separate
# thread using the default executor `None`. The `run_in_executor()` function returns the
# file object, which is created using gcs function `provide_file()`, asynchronously.
# This means we can perform asynchronous operations with this file.
create_tmp_file_call = gcs_hook.provide_file(object_url=self.jar)
tmp_gcs_file: IO[str] = await loop.run_in_executor(
None,
contextlib.ExitStack().enter_context, # type: ignore[arg-type]
create_tmp_file_call,
)
if self.file_has_gcs_path(self.jar):
tmp_gcs_file = await self.provide_gcs_tempfile(self.jar, self.gcp_conn_id)
self.jar = tmp_gcs_file.name

return_code = await hook.start_java_pipeline_async(
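The provide_gcs_tempfile helper added above runs the blocking enter_context() call on the default thread-pool executor so the event loop is never blocked while the GCS object is staged to a local file. A standalone illustration of that pattern, using tempfile.NamedTemporaryFile in place of gcs_hook.provide_file() (illustrative only, not Airflow code):

import asyncio
import contextlib
import tempfile


async def open_tempfile_async() -> str:
    loop = asyncio.get_running_loop()
    stack = contextlib.ExitStack()
    # enter_context() is synchronous, so it is handed to the default executor (None);
    # run_in_executor returns whatever entering the context manager produces.
    tmp_file = await loop.run_in_executor(
        None, stack.enter_context, tempfile.NamedTemporaryFile(mode="w")
    )
    try:
        return tmp_file.name
    finally:
        # Close the stack (and the temp file) without blocking the loop either.
        await loop.run_in_executor(None, stack.close)


print(asyncio.run(open_tempfile_async()))
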
27 changes: 15 additions & 12 deletions providers/apache/beam/tests/unit/apache/beam/triggers/test_beam.py
@@ -134,16 +134,17 @@ async def test_beam_trigger_exception_should_execute_successfully(
assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.beam.triggers.beam.GCSHook")
async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, gcs_hook, python_trigger):
async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, python_trigger):
"""
Test that BeamPythonPipelineTrigger downloads the GCS-provided file correctly.
"""
gcs_provide_file = gcs_hook.return_value.provide_file
python_trigger.py_file = TEST_GCS_PY_FILE
generator = python_trigger.run()
await generator.asend(None)
gcs_provide_file.assert_called_once_with(object_url=TEST_GCS_PY_FILE)
with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as mock_gcs_hook:
mock_gcs_hook.return_value.provide_file.return_value = "mocked_temp_file"
generator = python_trigger.run()
await generator.asend(None)
mock_gcs_hook.assert_called_once_with(gcp_conn_id=python_trigger.gcp_conn_id)
mock_gcs_hook.return_value.provide_file.assert_called_once_with(object_url=TEST_GCS_PY_FILE)


class TestBeamJavaPipelineTrigger:
@@ -210,13 +211,15 @@ async def test_beam_trigger_exception_should_execute_successfully(
assert TriggerEvent({"status": "error", "message": "Test exception"}) == actual

@pytest.mark.asyncio
@mock.patch("airflow.providers.apache.beam.triggers.beam.GCSHook")
async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, gcs_hook, java_trigger):
async def test_beam_trigger_gcs_provide_file_should_execute_successfully(self, java_trigger):
"""
Test that BeamJavaPipelineTrigger downloads the GCS-provided file correctly.
"""
gcs_provide_file = gcs_hook.return_value.provide_file
java_trigger.jar = TEST_GCS_JAR_FILE
generator = java_trigger.run()
await generator.asend(None)
gcs_provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE)

with mock.patch("airflow.providers.google.cloud.hooks.gcs.GCSHook") as mock_gcs_hook:
mock_gcs_hook.return_value.provide_file.return_value = "mocked_temp_file"
generator = java_trigger.run()
await generator.asend(None)
mock_gcs_hook.assert_called_once_with(gcp_conn_id=java_trigger.gcp_conn_id)
mock_gcs_hook.return_value.provide_file.assert_called_once_with(object_url=TEST_GCS_JAR_FILE)
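
Note the changed patch target in these tests: because GCSHook is now imported lazily inside provide_gcs_tempfile(), the name is never bound in the trigger module, so the mock must target the defining module. A small runnable illustration of that rule using only the standard library:

from unittest import mock


def make_tmp():
    # Lazy import, mirroring provide_gcs_tempfile(): the name is resolved at call time.
    from tempfile import NamedTemporaryFile

    return NamedTemporaryFile()


# Patching the defining module ("tempfile") intercepts the call; patching the calling
# module would not, since NamedTemporaryFile is never bound there.
with mock.patch("tempfile.NamedTemporaryFile") as mocked:
    make_tmp()
    mocked.assert_called_once()
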
18 changes: 18 additions & 0 deletions providers/google/docs/operators/cloud/dataflow.rst
@@ -146,6 +146,15 @@ Here is an example of creating and running a streaming pipeline in Java with jar
:start-after: [START howto_operator_start_java_streaming]
:end-before: [END howto_operator_start_java_streaming]

Here is a Java Dataflow streaming pipeline example in deferrable mode:

.. exampleinclude:: /../../google/tests/system/google/cloud/dataflow/example_dataflow_java_streaming.py
:language: python
:dedent: 4
:start-after: [START howto_operator_start_java_streaming_deferrable]
:end-before: [END howto_operator_start_java_streaming_deferrable]
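
The included example file is not reproduced in this diff; as a rough sketch (argument
values below are illustrative and not taken from ``example_dataflow_java_streaming.py``),
a deferrable Java streaming task might look like:

.. code-block:: python

    start_java_streaming_deferrable = BeamRunJavaPipelineOperator(
        task_id="start_java_streaming_deferrable",
        runner=BeamRunnerType.DataflowRunner,
        jar="gs://my-bucket/my-streaming-pipeline.jar",  # illustrative location
        pipeline_options={"tempLocation": "gs://my-bucket/tmp/"},
        dataflow_config=DataflowConfiguration(job_name="java_streaming", location="us-central1"),
        deferrable=True,
    )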


.. _howto/operator:PythonSDKPipelines:

Python SDK pipelines
@@ -232,6 +241,15 @@ source, such as Pub/Sub, in your pipeline (for Java).
:start-after: [START howto_operator_start_streaming_python_job]
:end-before: [END howto_operator_start_streaming_python_job]

Deferrable mode:

.. exampleinclude:: /../../google/tests/system/google/cloud/dataflow/example_dataflow_streaming_python.py
:language: python
:dedent: 4
:start-after: [START howto_operator_start_streaming_python_job_deferrable]
:end-before: [END howto_operator_start_streaming_python_job_deferrable]
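
The example file is likewise not shown here; a comparable illustrative sketch for a
deferrable Python streaming task (values are placeholders, not taken from the example):

.. code-block:: python

    start_streaming_python_job_deferrable = BeamRunPythonPipelineOperator(
        task_id="start_streaming_python_job_deferrable",
        runner=BeamRunnerType.DataflowRunner,
        py_file="gs://my-bucket/streaming_wordcount.py",  # illustrative location
        pipeline_options={"temp_location": "gs://my-bucket/tmp/", "streaming": True},
        dataflow_config=DataflowConfiguration(job_name="python_streaming", location="us-central1"),
        deferrable=True,
    )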


Setting the ``drain_pipeline`` argument to ``True`` allows stopping a streaming job by draining it
instead of canceling it when the task instance is killed.
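
For the Beam operators this flag is typically carried in ``DataflowConfiguration``
(a hedged sketch, not taken from the example files):

.. code-block:: python

    dataflow_config = DataflowConfiguration(
        job_name="python_streaming",
        location="us-central1",
        drain_pipeline=True,  # drain instead of cancel when the task instance is killed
    )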
