diff --git a/CHANGES.md b/CHANGES.md index f1b8df5c5b258..520609504c809 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -64,6 +64,7 @@ ## Highlights * Python 3.10 support in Apache Beam ([#21458](https://github.com/apache/beam/issues/21458)). +* An initial implementation of a runner that allows us to run Beam pipelines on Dask. Try it out and give us feedback! (Python) ([#18962](https://github.com/apache/beam/issues/18962)). ## I/Os @@ -81,6 +82,7 @@ * X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * Dataframe wrapper added in Go SDK via Cross-Language (with automatic expansion service). (Go) ([#23384](https://github.com/apache/beam/issues/23384)). * Name all Java threads to aid in debugging ([#23049](https://github.com/apache/beam/issues/23049)). +* An initial implementation of a runner that allows us to run Beam pipelines on Dask. (Python) ([#18962](https://github.com/apache/beam/issues/18962)). ## Breaking Changes diff --git a/sdks/python/apache_beam/runners/dask/__init__.py b/sdks/python/apache_beam/runners/dask/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/sdks/python/apache_beam/runners/dask/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/sdks/python/apache_beam/runners/dask/dask_runner.py b/sdks/python/apache_beam/runners/dask/dask_runner.py new file mode 100644 index 0000000000000..109c4379b45df --- /dev/null +++ b/sdks/python/apache_beam/runners/dask/dask_runner.py @@ -0,0 +1,182 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""DaskRunner, executing remote jobs on Dask.distributed. + +The DaskRunner is a runner implementation that executes a graph of +transformations across processes and workers via Dask distributed's +scheduler. 
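+
+A minimal usage sketch (illustrative only; assumes the ``dask`` extra
+introduced by this change is installed)::
+
+  import apache_beam as beam
+  from apache_beam.runners.dask.dask_runner import DaskRunner
+
+  with beam.Pipeline(runner=DaskRunner()) as p:
+    p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x * 2)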
+""" +import argparse +import dataclasses +import typing as t + +from apache_beam import pvalue +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.pipeline import AppliedPTransform +from apache_beam.pipeline import PipelineVisitor +from apache_beam.runners.dask.overrides import dask_overrides +from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS +from apache_beam.runners.dask.transform_evaluator import NoOp +from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner +from apache_beam.runners.runner import PipelineResult +from apache_beam.runners.runner import PipelineState +from apache_beam.utils.interactive_utils import is_in_notebook + + +class DaskOptions(PipelineOptions): + @staticmethod + def _parse_timeout(candidate): + try: + return int(candidate) + except (TypeError, ValueError): + import dask + return dask.config.no_default + + @classmethod + def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None: + parser.add_argument( + '--dask_client_address', + dest='address', + type=str, + default=None, + help='Address of a dask Scheduler server. Will default to a ' + '`dask.LocalCluster()`.') + parser.add_argument( + '--dask_connection_timeout', + dest='timeout', + type=DaskOptions._parse_timeout, + help='Timeout duration for initial connection to the scheduler.') + parser.add_argument( + '--dask_scheduler_file', + dest='scheduler_file', + type=str, + default=None, + help='Path to a file with scheduler information if available.') + # TODO(alxr): Add options for security. + parser.add_argument( + '--dask_client_name', + dest='name', + type=str, + default=None, + help='Gives the client a name that will be included in logs generated ' + 'on the scheduler for matters relating to this client.') + parser.add_argument( + '--dask_connection_limit', + dest='connection_limit', + type=int, + default=512, + help='The number of open comms to maintain at once in the connection ' + 'pool.') + + +@dataclasses.dataclass +class DaskRunnerResult(PipelineResult): + from dask import distributed + + client: distributed.Client + futures: t.Sequence[distributed.Future] + + def __post_init__(self): + super().__init__(PipelineState.RUNNING) + + def wait_until_finish(self, duration=None) -> str: + try: + if duration is not None: + # Convert milliseconds to seconds + duration /= 1000 + self.client.wait_for_workers(timeout=duration) + self.client.gather(self.futures, errors='raise') + self._state = PipelineState.DONE + except: # pylint: disable=broad-except + self._state = PipelineState.FAILED + raise + return self._state + + def cancel(self) -> str: + self._state = PipelineState.CANCELLING + self.client.cancel(self.futures) + self._state = PipelineState.CANCELLED + return self._state + + def metrics(self): + # TODO(alxr): Collect and return metrics... 
+ raise NotImplementedError('collecting metrics will come later!') + + +class DaskRunner(BundleBasedDirectRunner): + """Executes a pipeline on a Dask distributed client.""" + @staticmethod + def to_dask_bag_visitor() -> PipelineVisitor: + from dask import bag as db + + @dataclasses.dataclass + class DaskBagVisitor(PipelineVisitor): + bags: t.Dict[AppliedPTransform, + db.Bag] = dataclasses.field(default_factory=dict) + + def visit_transform(self, transform_node: AppliedPTransform) -> None: + op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp) + op = op_class(transform_node) + + inputs = list(transform_node.inputs) + if inputs: + bag_inputs = [] + for input_value in inputs: + if isinstance(input_value, pvalue.PBegin): + bag_inputs.append(None) + + prev_op = input_value.producer + if prev_op in self.bags: + bag_inputs.append(self.bags[prev_op]) + + if len(bag_inputs) == 1: + self.bags[transform_node] = op.apply(bag_inputs[0]) + else: + self.bags[transform_node] = op.apply(bag_inputs) + + else: + self.bags[transform_node] = op.apply(None) + + return DaskBagVisitor() + + @staticmethod + def is_fnapi_compatible(): + return False + + def run_pipeline(self, pipeline, options): + # TODO(alxr): Create interactive notebook support. + if is_in_notebook(): + raise NotImplementedError('interactive support will come later!') + + try: + import dask.distributed as ddist + except ImportError: + raise ImportError( + 'DaskRunner is not available. Please install apache_beam[dask].') + + dask_options = options.view_as(DaskOptions).get_all_options( + drop_default=True) + client = ddist.Client(**dask_options) + + pipeline.replace_all(dask_overrides()) + + dask_visitor = self.to_dask_bag_visitor() + pipeline.visit(dask_visitor) + + futures = client.compute(list(dask_visitor.bags.values())) + return DaskRunnerResult(client, futures) diff --git a/sdks/python/apache_beam/runners/dask/dask_runner_test.py b/sdks/python/apache_beam/runners/dask/dask_runner_test.py new file mode 100644 index 0000000000000..d8b3e17d8a56b --- /dev/null +++ b/sdks/python/apache_beam/runners/dask/dask_runner_test.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import inspect +import unittest + +import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.testing import test_pipeline +from apache_beam.testing.util import assert_that +from apache_beam.testing.util import equal_to + +try: + from apache_beam.runners.dask.dask_runner import DaskOptions + from apache_beam.runners.dask.dask_runner import DaskRunner + import dask + import dask.distributed as ddist +except (ImportError, ModuleNotFoundError): + raise unittest.SkipTest('Dask must be installed to run tests.') + + +class DaskOptionsTest(unittest.TestCase): + def test_parses_connection_timeout__defaults_to_none(self): + default_options = PipelineOptions([]) + default_dask_options = default_options.view_as(DaskOptions) + self.assertEqual(None, default_dask_options.timeout) + + def test_parses_connection_timeout__parses_int(self): + conn_options = PipelineOptions('--dask_connection_timeout 12'.split()) + dask_conn_options = conn_options.view_as(DaskOptions) + self.assertEqual(12, dask_conn_options.timeout) + + def test_parses_connection_timeout__handles_bad_input(self): + err_options = PipelineOptions('--dask_connection_timeout foo'.split()) + dask_err_options = err_options.view_as(DaskOptions) + self.assertEqual(dask.config.no_default, dask_err_options.timeout) + + def test_parser_destinations__agree_with_dask_client(self): + options = PipelineOptions( + '--dask_client_address localhost:8080 --dask_connection_timeout 600 ' + '--dask_scheduler_file foobar.cfg --dask_client_name charlie ' + '--dask_connection_limit 1024'.split()) + dask_options = options.view_as(DaskOptions) + + # Get the argument names for the constructor. + client_args = list(inspect.signature(ddist.Client).parameters) + + for opt_name in dask_options.get_all_options(drop_default=True).keys(): + with self.subTest(f'{opt_name} in dask.distributed.Client constructor'): + self.assertIn(opt_name, client_args) + + +class DaskRunnerRunPipelineTest(unittest.TestCase): + """Test class used to introspect the dask runner via a debugger.""" + def setUp(self) -> None: + self.pipeline = test_pipeline.TestPipeline(runner=DaskRunner()) + + def test_create(self): + with self.pipeline as p: + pcoll = p | beam.Create([1]) + assert_that(pcoll, equal_to([1])) + + def test_create_and_map(self): + def double(x): + return x * 2 + + with self.pipeline as p: + pcoll = p | beam.Create([1]) | beam.Map(double) + assert_that(pcoll, equal_to([2])) + + def test_create_map_and_groupby(self): + def double(x): + return x * 2, x + + with self.pipeline as p: + pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey() + assert_that(pcoll, equal_to([(2, [1])])) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/runners/dask/overrides.py b/sdks/python/apache_beam/runners/dask/overrides.py new file mode 100644 index 0000000000000..d07c7cd518afb --- /dev/null +++ b/sdks/python/apache_beam/runners/dask/overrides.py @@ -0,0 +1,145 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import typing as t + +import apache_beam as beam +from apache_beam import typehints +from apache_beam.io.iobase import SourceBase +from apache_beam.pipeline import AppliedPTransform +from apache_beam.pipeline import PTransformOverride +from apache_beam.runners.direct.direct_runner import _GroupAlsoByWindowDoFn +from apache_beam.transforms import ptransform +from apache_beam.transforms.window import GlobalWindows + +K = t.TypeVar("K") +V = t.TypeVar("V") + + +@dataclasses.dataclass +class _Create(beam.PTransform): + values: t.Tuple[t.Any] + + def expand(self, input_or_inputs): + return beam.pvalue.PCollection.from_(input_or_inputs) + + def get_windowing(self, inputs: t.Any) -> beam.Windowing: + return beam.Windowing(GlobalWindows()) + + +@typehints.with_input_types(K) +@typehints.with_output_types(K) +class _Reshuffle(beam.PTransform): + def expand(self, input_or_inputs): + return beam.pvalue.PCollection.from_(input_or_inputs) + + +@dataclasses.dataclass +class _Read(beam.PTransform): + source: SourceBase + + def expand(self, input_or_inputs): + return beam.pvalue.PCollection.from_(input_or_inputs) + + +@typehints.with_input_types(t.Tuple[K, V]) +@typehints.with_output_types(t.Tuple[K, t.Iterable[V]]) +class _GroupByKeyOnly(beam.PTransform): + def expand(self, input_or_inputs): + return beam.pvalue.PCollection.from_(input_or_inputs) + + def infer_output_type(self, input_type): + + key_type, value_type = typehints.trivial_inference.key_value_types( + input_type + ) + return typehints.KV[key_type, typehints.Iterable[value_type]] + + +@typehints.with_input_types(t.Tuple[K, t.Iterable[V]]) +@typehints.with_output_types(t.Tuple[K, t.Iterable[V]]) +class _GroupAlsoByWindow(beam.ParDo): + """Not used yet...""" + def __init__(self, windowing): + super().__init__(_GroupAlsoByWindowDoFn(windowing)) + self.windowing = windowing + + def expand(self, input_or_inputs): + return beam.pvalue.PCollection.from_(input_or_inputs) + + +@typehints.with_input_types(t.Tuple[K, V]) +@typehints.with_output_types(t.Tuple[K, t.Iterable[V]]) +class _GroupByKey(beam.PTransform): + def expand(self, input_or_inputs): + return input_or_inputs | "GroupByKey" >> _GroupByKeyOnly() + + +class _Flatten(beam.PTransform): + def expand(self, input_or_inputs): + is_bounded = all(pcoll.is_bounded for pcoll in input_or_inputs) + return beam.pvalue.PCollection(self.pipeline, is_bounded=is_bounded) + + +def dask_overrides() -> t.List[PTransformOverride]: + class CreateOverride(PTransformOverride): + def matches(self, applied_ptransform: AppliedPTransform) -> bool: + return applied_ptransform.transform.__class__ == beam.Create + + def get_replacement_transform_for_applied_ptransform( + self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform: + return _Create(t.cast(beam.Create, applied_ptransform.transform).values) + + class ReshuffleOverride(PTransformOverride): + def matches(self, applied_ptransform: AppliedPTransform) -> bool: + return applied_ptransform.transform.__class__ == beam.Reshuffle + + def get_replacement_transform_for_applied_ptransform( + self, applied_ptransform: 
AppliedPTransform) -> ptransform.PTransform:
+      return _Reshuffle()
+
+  class ReadOverride(PTransformOverride):
+    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
+      return applied_ptransform.transform.__class__ == beam.io.Read
+
+    def get_replacement_transform_for_applied_ptransform(
+        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
+      return _Read(t.cast(beam.io.Read, applied_ptransform.transform).source)
+
+  class GroupByKeyOverride(PTransformOverride):
+    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
+      return applied_ptransform.transform.__class__ == beam.GroupByKey
+
+    def get_replacement_transform_for_applied_ptransform(
+        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
+      return _GroupByKey()
+
+  class FlattenOverride(PTransformOverride):
+    def matches(self, applied_ptransform: AppliedPTransform) -> bool:
+      return applied_ptransform.transform.__class__ == beam.Flatten
+
+    def get_replacement_transform_for_applied_ptransform(
+        self, applied_ptransform: AppliedPTransform) -> ptransform.PTransform:
+      return _Flatten()
+
+  return [
+      CreateOverride(),
+      ReshuffleOverride(),
+      ReadOverride(),
+      GroupByKeyOverride(),
+      FlattenOverride(),
+  ]
diff --git a/sdks/python/apache_beam/runners/dask/transform_evaluator.py b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
new file mode 100644
index 0000000000000..c4aac7f2111f8
--- /dev/null
+++ b/sdks/python/apache_beam/runners/dask/transform_evaluator.py
@@ -0,0 +1,103 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Transform Beam PTransforms into Dask Bag operations.
+
+A minimal set of operation substitutions that adapt Beam's PTransform model
+to Dask Bag functions.
+
+TODO(alxr): Translate ops from https://docs.dask.org/en/latest/bag-api.html.
+"""
+import abc
+import dataclasses
+import typing as t
+
+import apache_beam
+import dask.bag as db
+from apache_beam.pipeline import AppliedPTransform
+from apache_beam.runners.dask.overrides import _Create
+from apache_beam.runners.dask.overrides import _Flatten
+from apache_beam.runners.dask.overrides import _GroupByKeyOnly
+
+OpInput = t.Union[db.Bag, t.Sequence[db.Bag], None]
+
+
+@dataclasses.dataclass
+class DaskBagOp(abc.ABC):
+  applied: AppliedPTransform
+
+  @property
+  def transform(self):
+    return self.applied.transform
+
+  @abc.abstractmethod
+  def apply(self, input_bag: OpInput) -> db.Bag:
+    pass
+
+
+class NoOp(DaskBagOp):
+  def apply(self, input_bag: OpInput) -> db.Bag:
+    return input_bag
+
+
+class Create(DaskBagOp):
+  def apply(self, input_bag: OpInput) -> db.Bag:
+    assert input_bag is None, 'Create expects no input!'
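+    # The values tuple was captured from the user's beam.Create by the
+    # _Create override in overrides.py; db.from_sequence scatters it into
+    # a Dask Bag.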
+ original_transform = t.cast(_Create, self.transform) + items = original_transform.values + return db.from_sequence(items) + + +class ParDo(DaskBagOp): + def apply(self, input_bag: db.Bag) -> db.Bag: + transform = t.cast(apache_beam.ParDo, self.transform) + return input_bag.map( + transform.fn.process, *transform.args, **transform.kwargs).flatten() + + +class Map(DaskBagOp): + def apply(self, input_bag: db.Bag) -> db.Bag: + transform = t.cast(apache_beam.Map, self.transform) + return input_bag.map( + transform.fn.process, *transform.args, **transform.kwargs) + + +class GroupByKey(DaskBagOp): + def apply(self, input_bag: db.Bag) -> db.Bag: + def key(item): + return item[0] + + def value(item): + k, v = item + return k, [elm[1] for elm in v] + + return input_bag.groupby(key).map(value) + + +class Flatten(DaskBagOp): + def apply(self, input_bag: OpInput) -> db.Bag: + assert type(input_bag) is list, 'Must take a sequence of bags!' + return db.concat(input_bag) + + +TRANSLATIONS = { + _Create: Create, + apache_beam.ParDo: ParDo, + apache_beam.Map: Map, + _GroupByKeyOnly: GroupByKey, + _Flatten: Flatten, +} diff --git a/sdks/python/mypy.ini b/sdks/python/mypy.ini index 9309120a8cabe..a628036d6682f 100644 --- a/sdks/python/mypy.ini +++ b/sdks/python/mypy.ini @@ -89,6 +89,9 @@ ignore_errors = true [mypy-apache_beam.runners.direct.*] ignore_errors = true +[mypy-apache_beam.runners.dask.*] +ignore_errors = true + [mypy-apache_beam.runners.interactive.*] ignore_errors = true diff --git a/sdks/python/setup.py b/sdks/python/setup.py index 8451fc5964660..61858fa5d978a 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -350,7 +350,11 @@ def get_portability_package_data(): # This can be removed once dill is updated to version > 0.3.5.1 # Issue: https://github.com/apache/beam/issues/23566 'dataframe': ['pandas>=1.0,<1.5;python_version<"3.10"', - 'pandas>=1.4.3,<1.5;python_version>="3.10"'] + 'pandas>=1.4.3,<1.5;python_version>="3.10"'], + 'dask': [ + 'dask >= 2022.6', + 'distributed >= 2022.6', + ], }, zip_safe=False, # PyPI package information. diff --git a/sdks/python/test-suites/tox/common.gradle b/sdks/python/test-suites/tox/common.gradle index 99afc1d725579..61802ac9c45ec 100644 --- a/sdks/python/test-suites/tox/common.gradle +++ b/sdks/python/test-suites/tox/common.gradle @@ -24,6 +24,9 @@ test.dependsOn "testPython${pythonVersionSuffix}" toxTask "testPy${pythonVersionSuffix}Cloud", "py${pythonVersionSuffix}-cloud" test.dependsOn "testPy${pythonVersionSuffix}Cloud" +toxTask "testPy${pythonVersionSuffix}Dask", "py${pythonVersionSuffix}-dask" +test.dependsOn "testPy${pythonVersionSuffix}Dask" + toxTask "testPy${pythonVersionSuffix}Cython", "py${pythonVersionSuffix}-cython" test.dependsOn "testPy${pythonVersionSuffix}Cython" diff --git a/sdks/python/tox.ini b/sdks/python/tox.ini index 138a5410ead0d..11997b55c771a 100644 --- a/sdks/python/tox.ini +++ b/sdks/python/tox.ini @@ -17,7 +17,7 @@ [tox] # new environments will be excluded by default unless explicitly added to envlist. 
-envlist = py37,py38,py39,py310,py37-{cloud,cython,lint,mypy},py38-{cloud,cython,docs,cloudcoverage},py39-{cloud,cython},py310-{cloud,cython},whitespacelint +envlist = py37,py38,py39,py310,py37-{cloud,cython,lint,mypy,dask},py38-{cloud,cython,docs,cloudcoverage,dask},py39-{cloud,cython},py310-{cloud,cython,dask},whitespacelint toxworkdir = {toxinidir}/target/{env:ENV_NAME:.tox} [pycodestyle] @@ -92,12 +92,16 @@ extras = test,gcp,interactive,dataframe,aws,azure commands = {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}" +[testenv:py{37,38,39}-dask] +extras = test,dask +commands = + {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}" [testenv:py38-cloudcoverage] deps = codecov pytest-cov==3.0.0 passenv = GIT_* BUILD_* ghprb* CHANGE_ID BRANCH_NAME JENKINS_* CODECOV_* -extras = test,gcp,interactive,dataframe,aws +extras = test,gcp,interactive,dataframe,aws,dask commands = -rm .coverage {toxinidir}/scripts/run_pytest.sh {envname} "{posargs}" "--cov-report=xml --cov=. --cov-append" @@ -129,6 +133,8 @@ commands = deps = -r build-requirements.txt mypy==0.782 + dask==2022.01.0 + distributed==2022.01.0 # make extras available in case any of these libs are typed extras = gcp @@ -136,8 +142,9 @@ commands = mypy --version python setup.py mypy + [testenv:py38-docs] -extras = test,gcp,docs,interactive,dataframe +extras = test,gcp,docs,interactive,dataframe,dask deps = Sphinx==1.8.5 sphinx_rtd_theme==0.4.3
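
A quick way to try this out end to end. This is a hedged sketch, not part of
the diff above: the scheduler address and timeout are illustrative
assumptions, and the pipeline mirrors the shapes covered by
dask_runner_test.py.

```python
# Requires the new extra: pip install 'apache_beam[dask]'
# (pulls in dask>=2022.6 and distributed>=2022.6, per setup.py above).
import apache_beam as beam
from apache_beam.runners.dask.dask_runner import DaskRunner

# With no --dask_* flags, dask.distributed.Client() is constructed with no
# arguments and spins up a local cluster. To target a running scheduler,
# pass options such as (values are examples only):
#   --dask_client_address localhost:8786 --dask_connection_timeout 60
with beam.Pipeline(runner=DaskRunner()) as p:
  grouped = (
      p
      | beam.Create(['a', 'bb', 'ccc'])
      | beam.Map(lambda word: (len(word), word))
      | beam.GroupByKey())  # translated to dask.bag.groupby under the hood
```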