Skip to content

Commit

Permalink
feat(docker): Replace use_dill with serializer (apache#41356)
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday authored and Artuz37 committed Aug 19, 2024
1 parent 706e515 commit 54ea557
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 15 deletions.
14 changes: 12 additions & 2 deletions airflow/decorators/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,9 @@ class TaskDecoratorCollection:
self,
*,
multiple_outputs: bool | None = None,
use_dill: bool = False, # Added by _DockerDecoratedOperator.
python_command: str = "python3",
serializer: Literal["pickle", "cloudpickle", "dill"] | None = None,
use_dill: bool = False, # Added by _DockerDecoratedOperator.
# 'command', 'retrieve_output', and 'retrieve_output_path' are filled by
# _DockerDecoratedOperator.
image: str,
Expand Down Expand Up @@ -432,8 +433,17 @@ class TaskDecoratorCollection:
:param multiple_outputs: If set, function return value will be unrolled to multiple XCom values.
Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
:param use_dill: Whether to use dill or pickle for serialization
:param python_command: Python command for executing functions, Default: python3
:param serializer: Which serializer use to serialize the args and result. It can be one of the following:
- ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library.
- ``"cloudpickle"``: Use cloudpickle for serialize more complex types,
this requires to include cloudpickle in your requirements.
- ``"dill"``: Use dill for serialize more complex types,
this requires to include dill in your requirements.
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
:param image: Docker image from which to create the container.
If image tag is omitted, "latest" will be used.
:param api_version: Remote API version. Set to ``auto`` to automatically
Expand Down
88 changes: 79 additions & 9 deletions airflow/providers/docker/decorators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,60 @@

import base64
import os
import pickle
import warnings
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Callable, Sequence

import dill
from typing import TYPE_CHECKING, Any, Callable, Literal, Sequence

from airflow.decorators.base import DecoratedOperator, task_decorator_factory
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.docker.operators.docker import DockerOperator
from airflow.utils.python_virtualenv import write_python_script

if TYPE_CHECKING:
from airflow.decorators.base import TaskDecorator
from airflow.utils.context import Context

Serializer = Literal["pickle", "dill", "cloudpickle"]

try:
from airflow.operators.python import _SERIALIZERS
except ImportError:
import logging

import lazy_object_proxy

log = logging.getLogger(__name__)

def _load_pickle():
import pickle

return pickle

def _load_dill():
try:
import dill
except ModuleNotFoundError:
log.error("Unable to import `dill` module. Please please make sure that it installed.")
raise
return dill

def _load_cloudpickle():
try:
import cloudpickle
except ModuleNotFoundError:
log.error(
"Unable to import `cloudpickle` module. "
"Please install it with: pip install 'apache-airflow[cloudpickle]'"
)
raise
return cloudpickle

_SERIALIZERS: dict[Serializer, Any] = { # type: ignore[no-redef]
"pickle": lazy_object_proxy.Proxy(_load_pickle),
"dill": lazy_object_proxy.Proxy(_load_dill),
"cloudpickle": lazy_object_proxy.Proxy(_load_cloudpickle),
}


def _generate_decode_command(env_var, file, python_command):
# We don't need `f.close()` as the interpreter is about to exit anyway
Expand All @@ -53,7 +93,6 @@ class _DockerDecoratedOperator(DecoratedOperator, DockerOperator):
:param python_callable: A reference to an object that is callable
:param python: Python binary name to use
:param use_dill: Whether dill should be used to serialize the callable
:param expect_airflow: whether to expect airflow to be installed in the docker environment. if this
one is specified, the script to run callable will attempt to load Airflow macros.
:param op_kwargs: a dictionary of keyword arguments that will get unpacked
Expand All @@ -63,6 +102,16 @@ class _DockerDecoratedOperator(DecoratedOperator, DockerOperator):
:param multiple_outputs: if set, function return value will be
unrolled to multiple XCom values. Dict will unroll to xcom values with keys as keys.
Defaults to False.
:param serializer: Which serializer use to serialize the args and result. It can be one of the following:
- ``"pickle"``: (default) Use pickle for serialization. Included in the Python Standard Library.
- ``"cloudpickle"``: Use cloudpickle for serialize more complex types,
this requires to include cloudpickle in your requirements.
- ``"dill"``: Use dill for serialize more complex types,
this requires to include dill in your requirements.
:param use_dill: Deprecated, use ``serializer`` instead. Whether to use dill to serialize
the args and result (pickle is default). This allows more complex types
but requires you to include dill in your requirements.
"""

custom_operator_name = "@task.docker"
Expand All @@ -74,12 +123,35 @@ def __init__(
use_dill=False,
python_command="python3",
expect_airflow: bool = True,
serializer: Serializer | None = None,
**kwargs,
) -> None:
if use_dill:
warnings.warn(
"`use_dill` is deprecated and will be removed in a future version. "
"Please provide serializer='dill' instead.",
AirflowProviderDeprecationWarning,
stacklevel=3,
)
if serializer:
raise AirflowException(
"Both 'use_dill' and 'serializer' parameters are set. Please set only one of them"
)
serializer = "dill"
serializer = serializer or "pickle"
if serializer not in _SERIALIZERS:
msg = (
f"Unsupported serializer {serializer!r}. "
f"Expected one of {', '.join(map(repr, _SERIALIZERS))}"
)
raise AirflowException(msg)

command = "placeholder command"
self.python_command = python_command
self.expect_airflow = expect_airflow
self.use_dill = use_dill
self.use_dill = serializer == "dill"
self.serializer: Serializer = serializer

super().__init__(
command=command, retrieve_output=True, retrieve_output_path="/tmp/script.out", **kwargs
)
Expand Down Expand Up @@ -128,9 +200,7 @@ def execute(self, context: Context):

@property
def pickling_library(self):
if self.use_dill:
return dill
return pickle
return _SERIALIZERS[self.serializer]


def docker_task(
Expand Down
117 changes: 113 additions & 4 deletions tests/providers/docker/decorators/test_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
from __future__ import annotations

import logging
from importlib.util import find_spec
from io import StringIO as StringBuffer

import pytest

from airflow.decorators import setup, task, teardown
from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import TaskInstance
from airflow.models.dag import DAG
from airflow.utils import timezone
Expand All @@ -32,6 +33,10 @@


DEFAULT_DATE = timezone.datetime(2021, 9, 1)
DILL_INSTALLED = find_spec("dill") is not None
DILL_MARKER = pytest.mark.skipif(not DILL_INSTALLED, reason="`dill` is not installed")
CLOUDPICKLE_INSTALLED = find_spec("cloudpickle") is not None
CLOUDPICKLE_MARKER = pytest.mark.skipif(not CLOUDPICKLE_INSTALLED, reason="`cloudpickle` is not installed")


class TestDockerDecorator:
Expand Down Expand Up @@ -207,13 +212,21 @@ def f():
assert teardown_task.is_teardown
assert teardown_task.on_failure_fail_dagrun is on_failure_fail_dagrun

@pytest.mark.parametrize("use_dill", [True, False])
def test_deepcopy_with_python_operator(self, dag_maker, use_dill):
@pytest.mark.parametrize(
"serializer",
[
pytest.param("pickle", id="pickle"),
pytest.param("dill", marks=DILL_MARKER, id="dill"),
pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"),
pytest.param(None, id="default"),
],
)
def test_deepcopy_with_python_operator(self, dag_maker, serializer):
import copy

from airflow.providers.docker.decorators.docker import _DockerDecoratedOperator

@task.docker(image="python:3.9-slim", auto_remove="force", use_dill=use_dill)
@task.docker(image="python:3.9-slim", auto_remove="force", serializer=serializer)
def f():
import logging

Expand Down Expand Up @@ -247,6 +260,7 @@ def g():
assert isinstance(clone_of_docker_operator, _DockerDecoratedOperator)
assert some_task.command == clone_of_docker_operator.command
assert some_task.expect_airflow == clone_of_docker_operator.expect_airflow
assert some_task.serializer == clone_of_docker_operator.serializer
assert some_task.use_dill == clone_of_docker_operator.use_dill
assert some_task.pickling_library is clone_of_docker_operator.pickling_library

Expand Down Expand Up @@ -317,3 +331,98 @@ def f():
assert 'with open(sys.argv[4], "w") as file:' not in log_content
last_line_of_docker_operator_log = log_content.splitlines()[-1]
assert "ValueError: This task is expected to fail" in last_line_of_docker_operator_log

@pytest.mark.parametrize(
"serializer",
[
pytest.param("pickle", id="pickle"),
pytest.param("dill", marks=DILL_MARKER, id="dill"),
pytest.param("cloudpickle", marks=CLOUDPICKLE_MARKER, id="cloudpickle"),
],
)
def test_ambiguous_serializer(self, dag_maker, serializer):
@task.docker(image="python:3.9-slim", auto_remove="force", use_dill=True, serializer=serializer)
def f():
pass

with dag_maker():
with pytest.warns(
AirflowProviderDeprecationWarning, match="`use_dill` is deprecated and will be removed"
):
with pytest.raises(
AirflowException, match="Both 'use_dill' and 'serializer' parameters are set"
):
f()

def test_invalid_serializer(self, dag_maker):
@task.docker(image="python:3.9-slim", auto_remove="force", serializer="airflow")
def f():
"""Ensure dill is correctly installed."""
import dill # noqa: F401

with dag_maker():
with pytest.raises(AirflowException, match="Unsupported serializer 'airflow'"):
f()

@pytest.mark.parametrize(
"serializer",
[
pytest.param(
"dill",
marks=pytest.mark.skipif(
DILL_INSTALLED, reason="For this test case `dill` shouldn't be installed"
),
id="dill",
),
pytest.param(
"cloudpickle",
marks=pytest.mark.skipif(
CLOUDPICKLE_INSTALLED, reason="For this test case `cloudpickle` shouldn't be installed"
),
id="cloudpickle",
),
],
)
def test_advanced_serializer_not_installed(self, dag_maker, serializer, caplog):
"""Test case for check raising an error if dill/cloudpickle is not installed."""

@task.docker(image="python:3.9-slim", auto_remove="force", serializer=serializer)
def f(): ...

with dag_maker():
with pytest.raises(ModuleNotFoundError):
f()
assert f"Unable to import `{serializer}` module." in caplog.text

@CLOUDPICKLE_MARKER
def test_add_cloudpickle(self, dag_maker):
@task.docker(image="python:3.9-slim", auto_remove="force", serializer="cloudpickle")
def f():
"""Ensure cloudpickle is correctly installed."""
import cloudpickle # noqa: F401

with dag_maker():
f()

@DILL_MARKER
def test_add_dill(self, dag_maker):
@task.docker(image="python:3.9-slim", auto_remove="force", serializer="dill")
def f():
"""Ensure dill is correctly installed."""
import dill # noqa: F401

with dag_maker():
f()

@DILL_MARKER
def test_add_dill_use_dill(self, dag_maker):
@task.docker(image="python:3.9-slim", auto_remove="force", use_dill=True)
def f():
"""Ensure dill is correctly installed."""
import dill # noqa: F401

with dag_maker():
with pytest.warns(
AirflowProviderDeprecationWarning, match="`use_dill` is deprecated and will be removed"
):
f()

0 comments on commit 54ea557

Please sign in to comment.