2 changes: 1 addition & 1 deletion providers/apache/spark/docs/decorators/pyspark.rst
@@ -42,7 +42,7 @@ Example
-------

The following example shows how to use the ``@task.pyspark`` decorator. Note
that the ``spark`` and ``sc`` objects are injected into the function.
that the ``spark`` object is injected into the function.

.. exampleinclude:: /../tests/system/apache/spark/example_pyspark.py
:language: python
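
The note above now says only the ``spark`` session is injected. As a hedged sketch of the resulting call style (the task name and connection id are illustrative, and the ``task`` import path can differ between Airflow versions):

.. code-block:: python

    from airflow.decorators import task  # in newer Airflow this may live under airflow.sdk


    @task.pyspark(conn_id="spark_connect_default")  # illustrative connection id
    def count_even_numbers(spark):
        # ``spark`` is the SparkSession built by the operator; ``sc`` is no longer injected.
        return spark.range(100).filter("id % 2 = 0").count()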
25 changes: 25 additions & 0 deletions providers/apache/spark/docs/operators.rst
@@ -29,6 +29,8 @@ Prerequisite
and :doc:`JDBC connection <apache-airflow-providers-jdbc:connections/jdbc>`.
* :class:`~airflow.providers.apache.spark.operators.spark_sql.SparkSqlOperator`
gets all the configurations from operator parameters.
* To use :class:`~airflow.providers.apache.spark.operators.spark_pyspark.PySparkOperator`,
you can configure a :doc:`Spark Connect connection <connections/spark-connect>`.

.. _howto/operator:SparkJDBCOperator:

@@ -56,6 +58,29 @@ Reference

For further information, look at `Apache Spark DataFrameWriter documentation <https://spark.apache.org/docs/2.4.5/api/scala/index.html#org.apache.spark.sql.DataFrameWriter>`_.

.. _howto/operator:PySparkOperator:

PySparkOperator
----------------

Launches applications on an Apache Spark Connect server, or runs them locally in standalone mode.

For parameter definition take a look at :class:`~airflow.providers.apache.spark.operators.spark_pyspark.PySparkOperator`.

Using the operator
""""""""""""""""""

.. exampleinclude:: /../tests/system/apache/spark/example_spark_dag.py
:language: python
:dedent: 4
:start-after: [START howto_operator_spark_pyspark]
:end-before: [END howto_operator_spark_pyspark]

Reference
"""""""""

For further information, look at the `Spark Connect Python quickstart <https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_connect.html>`_.
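
As a hedged inline sketch of the same pattern pulled in by the example include above (the connection id mirrors the system test; adjust it to your environment):

.. code-block:: python

    def my_pyspark_job(spark):
        # ``spark`` is the SparkSession (or Spark Connect session) created by the operator
        df = spark.range(100).filter("id % 2 = 0")
        print(df.count())


    spark_pyspark_job = PySparkOperator(
        task_id="spark_pyspark_job",
        python_callable=my_pyspark_job,
        conn_id="spark_connect",  # omit to fall back to local[*]
    )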

.. _howto/operator:SparkSqlOperator:

SparkSqlOperator
1 change: 1 addition & 0 deletions providers/apache/spark/provider.yaml
@@ -94,6 +94,7 @@ operators:
- airflow.providers.apache.spark.operators.spark_jdbc
- airflow.providers.apache.spark.operators.spark_sql
- airflow.providers.apache.spark.operators.spark_submit
- airflow.providers.apache.spark.operators.spark_pyspark

hooks:
- integration-name: Apache Spark
@@ -19,38 +19,33 @@

import inspect
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any

from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
from airflow.providers.apache.spark.operators.spark_pyspark import SPARK_CONTEXT_KEYS, PySparkOperator
from airflow.providers.common.compat.sdk import (
BaseHook,
DecoratedOperator,
TaskDecorator,
task_decorator_factory,
)
from airflow.providers.common.compat.standard.operators import PythonOperator

if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Context
SPARK_CONTEXT_KEYS = ["spark", "sc"]


class _PySparkDecoratedOperator(DecoratedOperator, PythonOperator):
class _PySparkDecoratedOperator(DecoratedOperator, PySparkOperator):
custom_operator_name = "@task.pyspark"

template_fields: Sequence[str] = ("op_args", "op_kwargs")

def __init__(
self,
*,
python_callable: Callable,
op_args: Sequence | None = None,
op_kwargs: dict | None = None,
conn_id: str | None = None,
config_kwargs: dict | None = None,
op_args: Sequence | None = None,
op_kwargs: dict | None = None,
**kwargs,
):
self.conn_id = conn_id
self.config_kwargs = config_kwargs or {}
) -> None:
kwargs_to_upstream = {
"python_callable": python_callable,
"op_args": op_args,
"op_kwargs": op_kwargs,
}

signature = inspect.signature(python_callable)
parameters = [
@@ -61,65 +56,16 @@ def __init__(
# see https://github.com/python/mypy/issues/12472
python_callable.__signature__ = signature.replace(parameters=parameters) # type: ignore[attr-defined]

kwargs_to_upstream = {
"python_callable": python_callable,
"op_args": op_args,
"op_kwargs": op_kwargs,
}
super().__init__(
kwargs_to_upstream=kwargs_to_upstream,
python_callable=python_callable,
config_kwargs=config_kwargs,
conn_id=conn_id,
op_args=op_args,
op_kwargs=op_kwargs,
**kwargs,
)

def execute(self, context: Context):
from pyspark import SparkConf
from pyspark.sql import SparkSession

conf = SparkConf()
conf.set("spark.app.name", f"{self.dag_id}-{self.task_id}")

url = "local[*]"
if self.conn_id:
# we handle both spark connect and spark standalone
conn = BaseHook.get_connection(self.conn_id)
if conn.conn_type == SparkConnectHook.conn_type:
url = SparkConnectHook(self.conn_id).get_connection_url()
elif conn.port:
url = f"{conn.host}:{conn.port}"
elif conn.host:
url = conn.host

for key, value in conn.extra_dejson.items():
conf.set(key, value)

# you cannot have both remote and master
if url.startswith("sc://"):
conf.set("spark.remote", url)

# task can override connection config
for key, value in self.config_kwargs.items():
conf.set(key, value)

if not conf.get("spark.remote") and not conf.get("spark.master"):
conf.set("spark.master", url)

spark = SparkSession.builder.config(conf=conf).getOrCreate()

if not self.op_kwargs:
self.op_kwargs = {}

op_kwargs: dict[str, Any] = dict(self.op_kwargs)
op_kwargs["spark"] = spark

# spark context is not available when using spark connect
op_kwargs["sc"] = spark.sparkContext if not conf.get("spark.remote") else None

self.op_kwargs = op_kwargs
return super().execute(context)


def pyspark_task(
python_callable: Callable | None = None,
@@ -42,6 +42,7 @@ def get_provider_info():
"airflow.providers.apache.spark.operators.spark_jdbc",
"airflow.providers.apache.spark.operators.spark_sql",
"airflow.providers.apache.spark.operators.spark_submit",
"airflow.providers.apache.spark.operators.spark_pyspark",
],
}
],
@@ -0,0 +1,97 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import inspect
from collections.abc import Callable, Sequence

from airflow.providers.apache.spark.hooks.spark_connect import SparkConnectHook
from airflow.providers.common.compat.sdk import BaseHook
from airflow.providers.common.compat.standard.operators import PythonOperator

SPARK_CONTEXT_KEYS = ["spark", "sc"]


class PySparkOperator(PythonOperator):
"""Submit the run of a pyspark job to an external spark-connect service or directly run the pyspark job in a standalone mode."""

template_fields: Sequence[str] = ("conn_id", "config_kwargs", *PythonOperator.template_fields)

def __init__(
self,
python_callable: Callable,
conn_id: str | None = None,
config_kwargs: dict | None = None,
**kwargs,
):
self.conn_id = conn_id
self.config_kwargs = config_kwargs or {}

signature = inspect.signature(python_callable)
parameters = [
param.replace(default=None) if param.name in SPARK_CONTEXT_KEYS else param
for param in signature.parameters.values()
]
# mypy does not understand __signature__ attribute
# see https://github.com/python/mypy/issues/12472
python_callable.__signature__ = signature.replace(parameters=parameters) # type: ignore[attr-defined]

super().__init__(
python_callable=python_callable,
**kwargs,
)

def execute_callable(self):
from pyspark import SparkConf
from pyspark.sql import SparkSession

conf = SparkConf()
conf.set("spark.app.name", f"{self.dag_id}-{self.task_id}")

url = "local[*]"
if self.conn_id:
# we handle both spark connect and spark standalone
conn = BaseHook.get_connection(self.conn_id)
if conn.conn_type == SparkConnectHook.conn_type:
url = SparkConnectHook(self.conn_id).get_connection_url()
elif conn.port:
url = f"{conn.host}:{conn.port}"
elif conn.host:
url = conn.host

for key, value in conn.extra_dejson.items():
conf.set(key, value)

# you cannot have both remote and master
if url.startswith("sc://"):
conf.set("spark.remote", url)

# task can override connection config
for key, value in self.config_kwargs.items():
conf.set(key, value)

if not conf.get("spark.remote") and not conf.get("spark.master"):
conf.set("spark.master", url)

spark_session = SparkSession.builder.config(conf=conf).getOrCreate()

try:
self.op_kwargs = {**self.op_kwargs, "spark": spark_session}
return super().execute_callable()
finally:
spark_session.stop()
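
Pulling together the resolution rules in ``execute_callable`` above, a hedged usage sketch (task id and memory setting are illustrative; the behaviour noted in the comments follows the code in this file):

from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator


def my_pyspark_job(spark):
    # The operator injects the SparkSession it builds from the connection/config.
    return spark.range(10).count()


# Inside a DAG definition:
run_pyspark = PySparkOperator(
    task_id="pyspark_local",  # illustrative task id
    python_callable=my_pyspark_job,
    # With no conn_id the operator falls back to spark.master = "local[*]".
    # A Spark Connect connection (sc:// URL) sets spark.remote instead, and
    # config_kwargs override anything taken from the connection extras.
    config_kwargs={"spark.executor.memory": "2g"},
)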
@@ -27,6 +27,7 @@

from airflow.models import DAG
from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator
from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator
from airflow.providers.apache.spark.operators.spark_sql import SparkSqlOperator
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

@@ -75,6 +76,16 @@
)
# [END howto_operator_spark_sql]

# [START howto_operator_spark_pyspark]
def my_pyspark_job(spark):
df = spark.range(100).filter("id % 2 = 0")
print(df.count())

spark_pyspark_job = PySparkOperator(
python_callable=my_pyspark_job, conn_id="spark_connect", task_id="spark_pyspark_job"
)
# [END howto_operator_spark_pyspark]

from tests_common.test_utils.system_tests import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
@@ -100,11 +100,10 @@ def test_pyspark_decorator_with_connection(self, spark_mock, conf_mock, dag_make
conf_mock.return_value = config

@task.pyspark(conn_id="pyspark_local", config_kwargs={"spark.executor.memory": "2g"})
def f(spark, sc):
def f(spark):
import random

assert spark is not None
assert sc is not None
return [random.random() for _ in range(100)]

with dag_maker():
@@ -129,7 +128,7 @@ def test_simple_pyspark_decorator(self, spark_mock, conf_mock, dag_maker):
e = 2

@task.pyspark
def f():
def f(spark):
return e

with dag_maker():
@@ -148,9 +147,8 @@ def test_spark_connect(self, spark_mock, conf_mock, dag_maker):
conf_mock.return_value = config

@task.pyspark(conn_id="spark-connect")
def f(spark, sc):
def f(spark):
assert spark is not None
assert sc is None

return True

@@ -172,9 +170,8 @@ def test_spark_connect_auth(self, spark_mock, conf_mock, dag_maker):
conf_mock.return_value = config

@task.pyspark(conn_id="spark-connect-auth")
def f(spark, sc):
def f(spark):
assert spark is not None
assert sc is None

return True

@@ -0,0 +1,44 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.models.dag import DAG
from airflow.providers.apache.spark.operators.spark_pyspark import PySparkOperator
from airflow.utils import timezone

DEFAULT_DATE = timezone.datetime(2024, 2, 1, tzinfo=timezone.utc)


class TestSparkPySparkOperator:
_config = {
"conn_id": "spark_special_conn_id",
}

def setup_method(self):
args = {"owner": "airflow", "start_date": DEFAULT_DATE}
self.dag = DAG("test_dag_id", schedule=None, default_args=args)

def test_execute(self):
def my_spark_fn(spark):
pass

operator = PySparkOperator(
task_id="spark_pyspark_job", python_callable=my_spark_fn, dag=self.dag, **self._config
)

assert self._config["conn_id"] == operator.conn_id