Initial DaskRunner for Beam (#22421)
* WIP: Created a skeleton dask runner implementation.
* WIP: Idea for a translation evaluator.
* Added overrides and a visitor that translates operations.
* Fixed a dataclass typo.
* Expanded translations.
* Core idea seems to be kinda working...
* First iteration on DaskRunnerResult (keeps track of pipeline state).
* Added a minimal set of DaskRunner options.
* WIP: Almost got asserts to work! The current status is:
  - CoGroupByKey is broken due to how tags are used with GroupByKey.
  - GroupByKey should output `[('0', None), ('1', 1)]`; however, it actually outputs `[(None, ('1', 1)), (None, ('0', None))]`.
  - Once that is fixed, we may have test pipelines working on Dask.
* With a great one-liner from @pabloem, groupby is fixed! Now all three initial tests pass.
* Self-review: Cleaned up the dask runner implementation.
* Self-review: Removed TODOs, deleted commented-out code, other cleanup.
* First pass at linting rules.
* WIP: Include dask dependencies + test setup.
* WIP: Maybe better dask deps?
* Skip dask tests depending on successful import.
* Fixed setup.py (missing `,`).
* Added an additional comma.
* Moved the skipping logic above the dask import.
* Fixed lint issues with the dask runner tests.
* Added a destination for the client address.
* Changing to async produces a timeout error instead of getting stuck in an infinite loop.
* Close the client during `wait_until_finish`; remove async.
* Supporting side inputs for ParDo.
* Revert "Close client during `wait_until_finish`; rm async." (reverts commit 09365f6)
* Revert "Changing to async produces a timeout error instead of stuck in infinite loop." (reverts commit 676d752)
* Added -dask tox targets to the Gradle build.
* WIP: Added a print statement.
* WIP: Prove side inputs are set.
* WIP: Prove side inputs are set in ParDo.
* WIP: Removed asserts, added prints.
* WIP: Adding named inputs...
* Experiments: non-named side inputs + deleting `None` in named inputs.
* None --> 'None'.
* No default side input.
* Pass along args + kwargs.
* Applied yapf to the dask sources.
* Dask sources passing pylint.
* Added the dask extra to the docs-gen tox env.
* Applied yapf from tox.
* Included dask in mypy checks.
* Upgraded mypy support to Python 3.8, since py37 support is deprecated in dask.
* Manually installed an old version of dask from before 3.7 support was dropped.
* Fixed lint: line too long.
* Fixed type errors with DaskRunnerResult; disabled mypy type checking in dask.
* Fixed pytype errors (in transform_evaluator).
* Ran isort.
* Ran yapf again.
* Fixed imports (one per line).
* isort -- alphabetical.
* Added the feature to CHANGES.md.
* Ran yapf via tox on a Linux machine.
* Changed an import to pass CI.
* Skipped an isort error; needed to get CI to pass.
* Skip-test logic may fare better with isort.
* (Maybe) the last isort fix.
* Tested pipeline options (added one fix).
* Improved formatting of a test.
* Self-review: Removed side inputs. In addition, added a more helpful property (`transform`) to the base DaskBagOp.
* Added dask to the coverage suite in tox.
* Capture ValueError in an assert.
* Changed the timeout value to 600 seconds.
* Ignoring a broken test.
* Updated CHANGES.md.
* Using reflection to test the Dask client constructor.
* Better method of inspecting the constructor parameters (thanks @TomAugspurger!).

Co-authored-by: Pablo E <pabloem@apache.org>
Co-authored-by: Pablo <pabloem@users.noreply.github.com>
1 parent 220902c, commit 76761db
Showing 10 changed files with 563 additions and 4 deletions.
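Taken together, the changes below add a DaskRunner that translates a Beam pipeline graph into Dask Bag operations and submits them to a dask.distributed scheduler. As a minimal sketch of what this enables (assuming the dask extra is installed, per the ImportError message in run_pipeline; the sample elements and keying function are illustrative, not part of the commit):

import apache_beam as beam
from apache_beam.runners.dask.dask_runner import DaskRunner

# Key each element by parity, group, and print the grouped output.
# With no --dask_client_address set, dask.distributed.Client defaults
# to a local cluster.
with beam.Pipeline(runner=DaskRunner()) as p:
  (
      p
      | beam.Create([1, 2, 3])
      | beam.Map(lambda x: (x % 2, x))
      | beam.GroupByKey()
      | beam.Map(print))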
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@@ -0,0 +1,182 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""DaskRunner, executing remote jobs on Dask.distributed.

The DaskRunner is a runner implementation that executes a graph of
transformations across processes and workers via Dask distributed's
scheduler.
"""
import argparse
import dataclasses
import typing as t

from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
from apache_beam.utils.interactive_utils import is_in_notebook


class DaskOptions(PipelineOptions):
  @staticmethod
  def _parse_timeout(candidate):
    try:
      return int(candidate)
    except (TypeError, ValueError):
      import dask
      return dask.config.no_default

  @classmethod
  def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        '--dask_client_address',
        dest='address',
        type=str,
        default=None,
        help='Address of a dask Scheduler server. Will default to a '
        '`dask.LocalCluster()`.')
    parser.add_argument(
        '--dask_connection_timeout',
        dest='timeout',
        type=DaskOptions._parse_timeout,
        help='Timeout duration for initial connection to the scheduler.')
    parser.add_argument(
        '--dask_scheduler_file',
        dest='scheduler_file',
        type=str,
        default=None,
        help='Path to a file with scheduler information if available.')
    # TODO(alxr): Add options for security.
    parser.add_argument(
        '--dask_client_name',
        dest='name',
        type=str,
        default=None,
        help='Gives the client a name that will be included in logs generated '
        'on the scheduler for matters relating to this client.')
    parser.add_argument(
        '--dask_connection_limit',
        dest='connection_limit',
        type=int,
        default=512,
        help='The number of open comms to maintain at once in the connection '
        'pool.')


@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
  from dask import distributed

  client: distributed.Client
  futures: t.Sequence[distributed.Future]

  def __post_init__(self):
    super().__init__(PipelineState.RUNNING)

  def wait_until_finish(self, duration=None) -> str:
    try:
      if duration is not None:
        # Convert milliseconds to seconds
        duration /= 1000
      self.client.wait_for_workers(timeout=duration)
      self.client.gather(self.futures, errors='raise')
      self._state = PipelineState.DONE
    except:  # pylint: disable=broad-except
      self._state = PipelineState.FAILED
      raise
    return self._state

  def cancel(self) -> str:
    self._state = PipelineState.CANCELLING
    self.client.cancel(self.futures)
    self._state = PipelineState.CANCELLED
    return self._state

  def metrics(self):
    # TODO(alxr): Collect and return metrics...
    raise NotImplementedError('collecting metrics will come later!')


class DaskRunner(BundleBasedDirectRunner):
  """Executes a pipeline on a Dask distributed client."""
  @staticmethod
  def to_dask_bag_visitor() -> PipelineVisitor:
    from dask import bag as db

    @dataclasses.dataclass
    class DaskBagVisitor(PipelineVisitor):
      bags: t.Dict[AppliedPTransform,
                   db.Bag] = dataclasses.field(default_factory=dict)

      def visit_transform(self, transform_node: AppliedPTransform) -> None:
        op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
        op = op_class(transform_node)

        inputs = list(transform_node.inputs)
        if inputs:
          bag_inputs = []
          for input_value in inputs:
            if isinstance(input_value, pvalue.PBegin):
              bag_inputs.append(None)

            prev_op = input_value.producer
            if prev_op in self.bags:
              bag_inputs.append(self.bags[prev_op])

          if len(bag_inputs) == 1:
            self.bags[transform_node] = op.apply(bag_inputs[0])
          else:
            self.bags[transform_node] = op.apply(bag_inputs)

        else:
          self.bags[transform_node] = op.apply(None)

    return DaskBagVisitor()

  @staticmethod
  def is_fnapi_compatible():
    return False

  def run_pipeline(self, pipeline, options):
    # TODO(alxr): Create interactive notebook support.
    if is_in_notebook():
      raise NotImplementedError('interactive support will come later!')

    try:
      import dask.distributed as ddist
    except ImportError:
      raise ImportError(
          'DaskRunner is not available. Please install apache_beam[dask].')

    dask_options = options.view_as(DaskOptions).get_all_options(
        drop_default=True)
    client = ddist.Client(**dask_options)

    pipeline.replace_all(dask_overrides())

    dask_visitor = self.to_dask_bag_visitor()
    pipeline.visit(dask_visitor)

    futures = client.compute(list(dask_visitor.bags.values()))
    return DaskRunnerResult(client, futures)
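Note that the `dest` names in DaskOptions (`address`, `timeout`, `scheduler_file`, `name`, `connection_limit`) mirror the `dask.distributed.Client` constructor parameters, which is what lets `run_pipeline` splat the parsed options straight into `ddist.Client(**dask_options)`; the last test in the next file checks exactly this. A sketch of pointing the runner at an existing scheduler (the address and timeout values here are placeholders):

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.dask.dask_runner import DaskRunner

# These flags become Client(address=..., timeout=...) keyword arguments.
options = PipelineOptions([
    '--dask_client_address', 'tcp://127.0.0.1:8786',
    '--dask_connection_timeout', '60',
])
with beam.Pipeline(runner=DaskRunner(), options=options) as p:
  p | beam.Create([1, 2, 3]) | beam.Map(print)

Also note that `DaskRunnerResult.wait_until_finish` interprets `duration` in milliseconds, as other Beam runners do, and converts it to seconds before handing it to Dask.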
@@ -0,0 +1,94 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import unittest

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing import test_pipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

try:
  from apache_beam.runners.dask.dask_runner import DaskOptions
  from apache_beam.runners.dask.dask_runner import DaskRunner
  import dask
  import dask.distributed as ddist
except (ImportError, ModuleNotFoundError):
  raise unittest.SkipTest('Dask must be installed to run tests.')


class DaskOptionsTest(unittest.TestCase):
  def test_parses_connection_timeout__defaults_to_none(self):
    default_options = PipelineOptions([])
    default_dask_options = default_options.view_as(DaskOptions)
    self.assertEqual(None, default_dask_options.timeout)

  def test_parses_connection_timeout__parses_int(self):
    conn_options = PipelineOptions('--dask_connection_timeout 12'.split())
    dask_conn_options = conn_options.view_as(DaskOptions)
    self.assertEqual(12, dask_conn_options.timeout)

  def test_parses_connection_timeout__handles_bad_input(self):
    err_options = PipelineOptions('--dask_connection_timeout foo'.split())
    dask_err_options = err_options.view_as(DaskOptions)
    self.assertEqual(dask.config.no_default, dask_err_options.timeout)

  def test_parser_destinations__agree_with_dask_client(self):
    options = PipelineOptions(
        '--dask_client_address localhost:8080 --dask_connection_timeout 600 '
        '--dask_scheduler_file foobar.cfg --dask_client_name charlie '
        '--dask_connection_limit 1024'.split())
    dask_options = options.view_as(DaskOptions)

    # Get the argument names for the constructor.
    client_args = list(inspect.signature(ddist.Client).parameters)

    for opt_name in dask_options.get_all_options(drop_default=True).keys():
      with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
        self.assertIn(opt_name, client_args)


class DaskRunnerRunPipelineTest(unittest.TestCase):
  """Test class used to introspect the dask runner via a debugger."""
  def setUp(self) -> None:
    self.pipeline = test_pipeline.TestPipeline(runner=DaskRunner())

  def test_create(self):
    with self.pipeline as p:
      pcoll = p | beam.Create([1])
      assert_that(pcoll, equal_to([1]))

  def test_create_and_map(self):
    def double(x):
      return x * 2

    with self.pipeline as p:
      pcoll = p | beam.Create([1]) | beam.Map(double)
      assert_that(pcoll, equal_to([2]))

  def test_create_map_and_groupby(self):
    def double(x):
      return x * 2, x

    with self.pipeline as p:
      pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey()
      assert_that(pcoll, equal_to([(2, [1])]))


if __name__ == '__main__':
  unittest.main()
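For completeness, the timeout fallback exercised by the option tests above can be seen in isolation. A small sketch (it calls the private `_parse_timeout` helper directly, so it is for illustration only and requires dask to be installed):

import dask

from apache_beam.runners.dask.dask_runner import DaskOptions

# Integer strings parse normally; anything else falls back to
# dask.config.no_default, letting dask.distributed.Client apply its
# own default timeout.
assert DaskOptions._parse_timeout('12') == 12
assert DaskOptions._parse_timeout('foo') == dask.config.no_default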