diff --git a/providers/apache/spark/docs/decorators/pyspark.rst b/providers/apache/spark/docs/decorators/pyspark.rst
index 9306f42fce37b..c83b5980cbd6e 100644
--- a/providers/apache/spark/docs/decorators/pyspark.rst
+++ b/providers/apache/spark/docs/decorators/pyspark.rst
@@ -42,7 +42,7 @@ Example
 -------
 
 The following example shows how to use the ``@task.pyspark`` decorator. Note
-that the ``spark`` and ``sc`` objects are injected into the function.
+that the ``spark`` object is injected into the function.
 
 .. exampleinclude:: /../tests/system/apache/spark/example_pyspark.py
     :language: python
diff --git a/providers/apache/spark/docs/operators.rst b/providers/apache/spark/docs/operators.rst
index d60967e8f130e..2e1aad8fb3842 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -29,6 +29,8 @@ Prerequisite
   and :doc:`JDBC connection `.
 * :class:`~airflow.providers.apache.spark.operators.spark_sql.SparkSqlOperator` gets all the configurations from operator parameters.
+* To use :class:`~airflow.providers.apache.spark.operators.spark_pyspark.PySparkOperator`,
+  you can configure a :doc:`SparkConnect Connection `.
 
 .. _howto/operator:SparkJDBCOperator:
 
@@ -56,6 +58,29 @@ Reference
 
 For further information, look at `Apache Spark DataFrameWriter documentation `_.
 
+.. _howto/operator:PySparkOperator:
+
+PySparkOperator
+---------------
+
+Launches applications on an Apache Spark Connect server, or runs them directly in standalone mode.
+
+For parameter definition take a look at :class:`~airflow.providers.apache.spark.operators.spark_pyspark.PySparkOperator`.
+
+Using the operator
+""""""""""""""""""
+
+.. exampleinclude:: /../tests/system/apache/spark/example_spark_dag.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_spark_pyspark]
+    :end-before: [END howto_operator_spark_pyspark]
+
+Reference
+"""""""""
+
+For further information, look at `Running the Spark Connect Python `_.
+
 .. _howto/operator:SparkSqlOperator:
 
 SparkSqlOperator
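As a quick illustration of the documentation change above, here is a minimal sketch of a DAG using the updated decorator. The DAG id, connection id, and schedule are illustrative (not part of this change); the point is simply that only ``spark`` is injected now:

```python
import pendulum

from airflow.decorators import dag, task


@dag(schedule=None, start_date=pendulum.datetime(2024, 1, 1, tz="UTC"), catchup=False)
def pyspark_decorator_example():
    # "spark_connect" is a hypothetical Spark Connect connection; omit conn_id to run on local[*]
    @task.pyspark(conn_id="spark_connect")
    def count_even_ids(spark):  # note: no `sc` parameter anymore
        return spark.range(100).filter("id % 2 = 0").count()

    count_even_ids()


pyspark_decorator_example()
```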
diff --git a/providers/apache/spark/provider.yaml b/providers/apache/spark/provider.yaml
index c1eed1b11bd3d..4d4de5c3473a8 100644
--- a/providers/apache/spark/provider.yaml
+++ b/providers/apache/spark/provider.yaml
@@ -94,6 +94,7 @@ operators:
       - airflow.providers.apache.spark.operators.spark_jdbc
       - airflow.providers.apache.spark.operators.spark_sql
       - airflow.providers.apache.spark.operators.spark_submit
+      - airflow.providers.apache.spark.operators.spark_pyspark
 
 hooks:
   - integration-name: Apache Spark
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/decorators/pyspark.py b/providers/apache/spark/src/airflow/providers/apache/spark/decorators/pyspark.py
index e5080c151b549..0aac4fc94f699 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/decorators/pyspark.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/decorators/pyspark.py
@@ -19,38 +19,33 @@
 import inspect
 from collections.abc import Callable, Sequence
-from typing import TYPE_CHECKING, Any
 
-from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
+from airflow.providers.apache.spark.operators.spark_pyspark import SPARK_CONTEXT_KEYS, PySparkOperator
 from airflow.providers.common.compat.sdk import (
-    BaseHook,
     DecoratedOperator,
     TaskDecorator,
     task_decorator_factory,
 )
-from airflow.providers.common.compat.standard.operators import PythonOperator
-
-if TYPE_CHECKING:
-    from airflow.providers.common.compat.sdk import Context
-
-SPARK_CONTEXT_KEYS = ["spark", "sc"]
 
 
-class _PySparkDecoratedOperator(DecoratedOperator, PythonOperator):
+class _PySparkDecoratedOperator(DecoratedOperator, PySparkOperator):
     custom_operator_name = "@task.pyspark"
 
-    template_fields: Sequence[str] = ("op_args", "op_kwargs")
-
     def __init__(
         self,
+        *,
         python_callable: Callable,
-        op_args: Sequence | None = None,
-        op_kwargs: dict | None = None,
         conn_id: str | None = None,
         config_kwargs: dict | None = None,
+        op_args: Sequence | None = None,
+        op_kwargs: dict | None = None,
         **kwargs,
-    ):
-        self.conn_id = conn_id
-        self.config_kwargs = config_kwargs or {}
+    ) -> None:
+        kwargs_to_upstream = {
+            "python_callable": python_callable,
+            "op_args": op_args,
+            "op_kwargs": op_kwargs,
+        }
 
         signature = inspect.signature(python_callable)
         parameters = [
@@ -61,65 +56,16 @@ def __init__(
         # see https://github.com/python/mypy/issues/12472
         python_callable.__signature__ = signature.replace(parameters=parameters)  # type: ignore[attr-defined]
 
-        kwargs_to_upstream = {
-            "python_callable": python_callable,
-            "op_args": op_args,
-            "op_kwargs": op_kwargs,
-        }
         super().__init__(
             kwargs_to_upstream=kwargs_to_upstream,
             python_callable=python_callable,
+            config_kwargs=config_kwargs,
+            conn_id=conn_id,
             op_args=op_args,
             op_kwargs=op_kwargs,
             **kwargs,
         )
 
-    def execute(self, context: Context):
-        from pyspark import SparkConf
-        from pyspark.sql import SparkSession
-
-        conf = SparkConf()
-        conf.set("spark.app.name", f"{self.dag_id}-{self.task_id}")
-
-        url = "local[*]"
-        if self.conn_id:
-            # we handle both spark connect and spark standalone
-            conn = BaseHook.get_connection(self.conn_id)
-            if conn.conn_type == SparkConnectHook.conn_type:
-                url = SparkConnectHook(self.conn_id).get_connection_url()
-            elif conn.port:
-                url = f"{conn.host}:{conn.port}"
-            elif conn.host:
-                url = conn.host
-
-            for key, value in conn.extra_dejson.items():
-                conf.set(key, value)
-
-        # you cannot have both remote and master
-        if url.startswith("sc://"):
-            conf.set("spark.remote", url)
-
-        # task can override connection config
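With this refactor the decorated operator only wires ``conn_id`` and ``config_kwargs`` through to ``PySparkOperator``. As a rough sketch (task ids, connection id, and config values are illustrative), the decorator form and the classic operator form below should behave the same way:

```python
from airflow.decorators import task
from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator


# decorator form
@task.pyspark(conn_id="spark_connect", config_kwargs={"spark.executor.memory": "2g"})
def count_rows(spark):
    return spark.range(10).count()


# classic operator form with the same knobs
def count_rows_callable(spark):
    return spark.range(10).count()


count_rows_task = PySparkOperator(
    task_id="count_rows",
    python_callable=count_rows_callable,
    conn_id="spark_connect",
    config_kwargs={"spark.executor.memory": "2g"},
)
```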
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index 55fe6d725910d..4d96c86f923a6 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -42,6 +42,7 @@ def get_provider_info():
                 "airflow.providers.apache.spark.operators.spark_jdbc",
                 "airflow.providers.apache.spark.operators.spark_sql",
                 "airflow.providers.apache.spark.operators.spark_submit",
+                "airflow.providers.apache.spark.operators.spark_pyspark",
             ],
         }
     ],
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_pyspark.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_pyspark.py
new file mode 100644
index 0000000000000..c295cf5b54abb
--- /dev/null
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_pyspark.py
@@ -0,0 +1,97 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import inspect
+from collections.abc import Callable, Sequence
+
+from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
+from airflow.providers.common.compat.sdk import BaseHook
+from airflow.providers.common.compat.standard.operators import PythonOperator
+
+SPARK_CONTEXT_KEYS = ["spark", "sc"]
+
+
+class PySparkOperator(PythonOperator):
+    """
+    Run a PySpark job, either on an external Spark Connect server or directly in standalone mode.
+
+    :param python_callable: the Python callable containing the PySpark job; a ``spark`` session
+        is injected into it as a keyword argument.
+    :param conn_id: optional Spark Connect or Spark standalone connection id; when omitted,
+        the job runs against a ``local[*]`` master.
+    :param config_kwargs: extra Spark configuration options; these override any options coming
+        from the connection.
+    """
+
+    template_fields: Sequence[str] = ("conn_id", "config_kwargs", *PythonOperator.template_fields)
+
+    def __init__(
+        self,
+        python_callable: Callable,
+        conn_id: str | None = None,
+        config_kwargs: dict | None = None,
+        **kwargs,
+    ):
+        self.conn_id = conn_id
+        self.config_kwargs = config_kwargs or {}
+
+        signature = inspect.signature(python_callable)
+        parameters = [
+            param.replace(default=None) if param.name in SPARK_CONTEXT_KEYS else param
+            for param in signature.parameters.values()
+        ]
+        # mypy does not understand __signature__ attribute
+        # see https://github.com/python/mypy/issues/12472
+        python_callable.__signature__ = signature.replace(parameters=parameters)  # type: ignore[attr-defined]
+
+        super().__init__(
+            python_callable=python_callable,
+            **kwargs,
+        )
+
+    def execute_callable(self):
+        from pyspark import SparkConf
+        from pyspark.sql import SparkSession
+
+        conf = SparkConf()
+        conf.set("spark.app.name", f"{self.dag_id}-{self.task_id}")
+
+        url = "local[*]"
+        if self.conn_id:
+            # we handle both spark connect and spark standalone
+            conn = BaseHook.get_connection(self.conn_id)
+            if conn.conn_type == SparkConnectHook.conn_type:
+                url = SparkConnectHook(self.conn_id).get_connection_url()
+            elif conn.port:
+                url = f"{conn.host}:{conn.port}"
+            elif conn.host:
+                url = conn.host
+
+            for key, value in conn.extra_dejson.items():
+                conf.set(key, value)
+
+        # you cannot have both remote and master
+        if url.startswith("sc://"):
+            conf.set("spark.remote", url)
+
+        # task can override connection config
+        for key, value in self.config_kwargs.items():
+            conf.set(key, value)
+
+        if not conf.get("spark.remote") and not conf.get("spark.master"):
+            conf.set("spark.master", url)
+
+        spark_session = SparkSession.builder.config(conf=conf).getOrCreate()
+
+        try:
+            self.op_kwargs = {**self.op_kwargs, "spark": spark_session}
+            return super().execute_callable()
+        finally:
+            spark_session.stop()
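To make the resolution order in ``execute_callable`` easier to follow, here is a standalone sketch (not part of the provider) of how the final ``SparkConf`` is assembled; ``resolve_conf`` is a hypothetical helper used only for illustration:

```python
from pyspark import SparkConf


def resolve_conf(url: str, connection_extras: dict, config_kwargs: dict) -> SparkConf:
    conf = SparkConf()
    # connection extras are applied first
    for key, value in connection_extras.items():
        conf.set(key, value)
    # a Spark Connect URL becomes spark.remote (remote and master are mutually exclusive)
    if url.startswith("sc://"):
        conf.set("spark.remote", url)
    # task-level config_kwargs override anything coming from the connection
    for key, value in config_kwargs.items():
        conf.set(key, value)
    # fall back to spark.master when neither remote nor master has been configured
    if not conf.get("spark.remote") and not conf.get("spark.master"):
        conf.set("spark.master", url)
    return conf


# e.g. a standalone master URL with a per-task executor-memory override:
print(resolve_conf("spark://spark-master:7077", {}, {"spark.executor.memory": "2g"}).getAll())
```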
diff --git a/providers/apache/spark/tests/system/apache/spark/example_spark_dag.py b/providers/apache/spark/tests/system/apache/spark/example_spark_dag.py
index 6680450ae987a..8e8b0f959aa71 100644
--- a/providers/apache/spark/tests/system/apache/spark/example_spark_dag.py
+++ b/providers/apache/spark/tests/system/apache/spark/example_spark_dag.py
@@ -27,6 +27,7 @@
 from airflow.models import DAG
 from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator
+from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator
 from airflow.providers.apache.spark.operators.spark_sql import SparkSqlOperator
 from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator
 
@@ -75,6 +76,16 @@
     )
     # [END howto_operator_spark_sql]
 
+    # [START howto_operator_spark_pyspark]
+    def my_pyspark_job(spark):
+        df = spark.range(100).filter("id % 2 = 0")
+        print(df.count())
+
+    spark_pyspark_job = PySparkOperator(
+        python_callable=my_pyspark_job, conn_id="spark_connect", task_id="spark_pyspark_job"
+    )
+    # [END howto_operator_spark_pyspark]
+
 from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
 
 # Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
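A possible companion to the example above: if ``conn_id`` is left out, the operator falls back to a ``local[*]`` master, which is convenient for trying the job without a Spark Connect server. The task id and function name below are illustrative and not part of this change:

```python
from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator


def my_local_pyspark_job(spark):
    # same job as in the example DAG, but run against a local SparkSession
    print(spark.range(100).filter("id % 2 = 0").count())


spark_pyspark_local_job = PySparkOperator(
    task_id="spark_pyspark_local_job",
    python_callable=my_local_pyspark_job,
)
```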
diff --git a/providers/apache/spark/tests/unit/apache/spark/decorators/test_pyspark.py b/providers/apache/spark/tests/unit/apache/spark/decorators/test_pyspark.py
index b056c14b824c0..c662fde33f886 100644
--- a/providers/apache/spark/tests/unit/apache/spark/decorators/test_pyspark.py
+++ b/providers/apache/spark/tests/unit/apache/spark/decorators/test_pyspark.py
@@ -100,11 +100,10 @@ def test_pyspark_decorator_with_connection(self, spark_mock, conf_mock, dag_make
         conf_mock.return_value = config
 
         @task.pyspark(conn_id="pyspark_local", config_kwargs={"spark.executor.memory": "2g"})
-        def f(spark, sc):
+        def f(spark):
             import random
 
             assert spark is not None
-            assert sc is not None
             return [random.random() for _ in range(100)]
 
         with dag_maker():
@@ -129,7 +128,7 @@ def test_simple_pyspark_decorator(self, spark_mock, conf_mock, dag_maker):
         e = 2
 
         @task.pyspark
-        def f():
+        def f(spark):
             return e
 
         with dag_maker():
@@ -148,9 +147,8 @@ def test_spark_connect(self, spark_mock, conf_mock, dag_maker):
         conf_mock.return_value = config
 
         @task.pyspark(conn_id="spark-connect")
-        def f(spark, sc):
+        def f(spark):
             assert spark is not None
-            assert sc is None
 
             return True
 
@@ -172,9 +170,8 @@ def test_spark_connect_auth(self, spark_mock, conf_mock, dag_maker):
         conf_mock.return_value = config
 
         @task.pyspark(conn_id="spark-connect-auth")
-        def f(spark, sc):
+        def f(spark):
             assert spark is not None
-            assert sc is None
 
             return True
 
diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_pyspark.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_pyspark.py
new file mode 100644
index 0000000000000..ef7b961754424
--- /dev/null
+++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_pyspark.py
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.models.dag import DAG
+from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator
+from airflow.utils import timezone
+
+DEFAULT_DATE = timezone.datetime(2024, 2, 1, tzinfo=timezone.utc)
+
+
+class TestSparkPySparkOperator:
+    _config = {
+        "conn_id": "spark_special_conn_id",
+    }
+
+    def setup_method(self):
+        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+        self.dag = DAG("test_dag_id", schedule=None, default_args=args)
+
+    def test_execute(self):
+        def my_spark_fn(spark):
+            pass
+
+        operator = PySparkOperator(
+            task_id="spark_pyspark_job", python_callable=my_spark_fn, dag=self.dag, **self._config
+        )
+
+        assert self._config["conn_id"] == operator.conn_id
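Beyond the constructor test above, a small follow-up test could pin down the fields this operator adds to ``template_fields``; this is only a suggestion and is not part of the diff:

```python
from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator


def test_template_fields():
    # conn_id and config_kwargs are templated in addition to the inherited PythonOperator fields
    assert "conn_id" in PySparkOperator.template_fields
    assert "config_kwargs" in PySparkOperator.template_fields
```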