Initial DaskRunner for Beam #22421
Status: Merged

Changes from all commits (69 commits)
79d4603 WIP: Created a skeleton dask runner implementation. (alxmrs)
248ec70 WIP: Idea for a translation evaluator. (alxmrs)
42452ca Added overrides and a visitor that translates operations. (alxmrs)
1da2ddd Fixed a dataclass typo. (alxmrs)
14885a3 Expanded translations. (alxmrs)
fca2420 Core idea seems to be kinda working... (alxmrs)
6dd1ada First iteration on DaskRunnerResult (keep track of pipeline state). (alxmrs)
6675687 Added minimal set of DaskRunner options. (alxmrs)
88ed36b WIP: Alllmost got asserts to work! The current status is: (alxmrs)
2e3a126 With a great 1-liner from @pabloem, groupby is fixed! Now, all three … (alxmrs)
6467b0e Self-review: Cleaned up dask runner impl. (alxmrs)
793ba86 Self-review: Remove TODOs, delete commented out code, other cleanup. (alxmrs)
e535792 First pass at linting rules. (alxmrs)
8e32668 WIP, include dask dependencies + test setup. (alxmrs)
318afc2 WIP: maybe better dask deps? (alxmrs)
b01855f Skip dask tests depending on successful import. (alxmrs)
2c2eb8d Fixed setup.py (missing `,`). (alxmrs)
e64e9eb Added an additional comma. (alxmrs)
69b118b Moved skipping logic to be above dask import. (alxmrs)
9ffc8d8 Fix lint issues with dask runner tests. (alxmrs)
8a2afb7 Adding destination for client address. (alxmrs)
93f02f1 Changing to async produces a timeout error instead of stuck in infini… (alxmrs)
afdcf1b Close client during `wait_until_finish`; rm async. (alxmrs)
41b5267 Supporting side-inputs for ParDo. (alxmrs)
e3ac3f8 Revert "Close client during `wait_until_finish`; rm async." (pabloem)
3fddc81 Revert "Changing to async produces a timeout error instead of stuck i… (pabloem)
9eeb9ea Adding -dask tox targets onto the gradle build (pabloem)
b4d0999 wip - added print stmt. (alxmrs)
0319ffd wip - prove side inputs is set. (alxmrs)
0b13bb0 wip - prove side inputs is set in Pardo. (alxmrs)
1e7052b wip - rm asserts, add print (alxmrs)
292e023 wip - adding named inputs... (alxmrs)
31c1e2b Experiments: non-named side inputs + del `None` in named inputs. (alxmrs)
f4ecf2f None --> 'None' (alxmrs)
4d24ed9 No default side input. (alxmrs)
ee62a4a Pass along args + kwargs. (alxmrs)
506c719 Applied yapf to dask sources. (alxmrs)
cd0ba8b Dask sources passing pylint. (alxmrs)
d0a7c63 Added dask extra to docs gen tox env. (alxmrs)
775bd07 Applied yapf from tox. (alxmrs)
efba1c9 Include dask in mypy checks. (alxmrs)
741d961 Upgrading mypy support to python 3.8 since py37 support is deprecated… (alxmrs)
f66458a Manually installing an old version of dask before 3.7 support was dro… (alxmrs)
5dcf969 fix lint: line too long. (alxmrs)
ec5f613 Fixed type errors with DaskRunnerResult. Disabled mypy type checking … (alxmrs)
04b1f1a Fix pytype errors (in transform_evaluator). (alxmrs)
712944b Ran isort. (alxmrs)
567b72b Ran yapf again. (alxmrs)
f53c0a4 Fix imports (one per line) (alxmrs)
fb280ad isort -- alphabetical. (alxmrs)
80ddfec Added feature to CHANGES.md. (alxmrs)
40c4e35 ran yapf via tox on linux machine (alxmrs)
a70c5f3 Merge branch 'master' into dask-runner-mvp (alxmrs)
9fb52e5 Change an import to pass CI. (alxmrs)
26c6016 Skip isort error; needed to get CI to pass. (alxmrs)
aec19bf Skip test logic may favor better with isort. (alxmrs)
0673235 (Maybe) the last isort fix. (alxmrs)
de03a32 Tested pipeline options (added one fix). (alxmrs)
7e0a2c7 Improve formatting of test. (alxmrs)
39b1e1c Self-review: removing side inputs. (alxmrs)
6db49fa add dask to coverage suite in tox. (alxmrs)
036561c Merge branch 'master' into dask-runner-mvp (alxmrs)
191580d Capture value error in assert. (alxmrs)
365fc87 Merge branch 'master' of github.com:apache/beam into dask-runner-mvp (alxmrs)
085447e Change timeout value to 600 seconds. (alxmrs)
1a60a5e ignoring broken test (pabloem)
c1037f8 Update CHANGES.md (pabloem)
9e79ffd Using reflection to test the Dask client constructor. (alxmrs)
f9cf45a Better method of inspecting the constructor parameters (thanks @TomAu… (alxmrs)
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
@@ -0,0 +1,182 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""DaskRunner, executing remote jobs on Dask.distributed.

The DaskRunner is a runner implementation that executes a graph of
transformations across processes and workers via Dask distributed's
scheduler.
"""
import argparse
import dataclasses
import typing as t

from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
from apache_beam.utils.interactive_utils import is_in_notebook


class DaskOptions(PipelineOptions):
  @staticmethod
  def _parse_timeout(candidate):
    try:
      return int(candidate)
    except (TypeError, ValueError):
      import dask
      return dask.config.no_default

  @classmethod
  def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
    parser.add_argument(
        '--dask_client_address',
        dest='address',
        type=str,
        default=None,
        help='Address of a dask Scheduler server. Will default to a '
        '`dask.LocalCluster()`.')
    parser.add_argument(
        '--dask_connection_timeout',
        dest='timeout',
        type=DaskOptions._parse_timeout,
        help='Timeout duration for initial connection to the scheduler.')
    parser.add_argument(
        '--dask_scheduler_file',
        dest='scheduler_file',
        type=str,
        default=None,
        help='Path to a file with scheduler information if available.')
    # TODO(alxr): Add options for security.
    parser.add_argument(
        '--dask_client_name',
        dest='name',
        type=str,
        default=None,
        help='Gives the client a name that will be included in logs generated '
        'on the scheduler for matters relating to this client.')
    parser.add_argument(
        '--dask_connection_limit',
        dest='connection_limit',
        type=int,
        default=512,
        help='The number of open comms to maintain at once in the connection '
        'pool.')


@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
  from dask import distributed

  client: distributed.Client
  futures: t.Sequence[distributed.Future]

  def __post_init__(self):
    super().__init__(PipelineState.RUNNING)

  def wait_until_finish(self, duration=None) -> str:
    try:
      if duration is not None:
        # Convert milliseconds to seconds
        duration /= 1000
      self.client.wait_for_workers(timeout=duration)
      self.client.gather(self.futures, errors='raise')
      self._state = PipelineState.DONE
    except:  # pylint: disable=broad-except
      self._state = PipelineState.FAILED
      raise
    return self._state

  def cancel(self) -> str:
    self._state = PipelineState.CANCELLING
    self.client.cancel(self.futures)
    self._state = PipelineState.CANCELLED
    return self._state

  def metrics(self):
    # TODO(alxr): Collect and return metrics...
    raise NotImplementedError('collecting metrics will come later!')


class DaskRunner(BundleBasedDirectRunner):
  """Executes a pipeline on a Dask distributed client."""
  @staticmethod
  def to_dask_bag_visitor() -> PipelineVisitor:
    from dask import bag as db

    @dataclasses.dataclass
    class DaskBagVisitor(PipelineVisitor):
      bags: t.Dict[AppliedPTransform,
                   db.Bag] = dataclasses.field(default_factory=dict)

      def visit_transform(self, transform_node: AppliedPTransform) -> None:
        op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
        op = op_class(transform_node)

        inputs = list(transform_node.inputs)
        if inputs:
          bag_inputs = []
          for input_value in inputs:
            if isinstance(input_value, pvalue.PBegin):
              bag_inputs.append(None)

            prev_op = input_value.producer
            if prev_op in self.bags:
              bag_inputs.append(self.bags[prev_op])

          if len(bag_inputs) == 1:
            self.bags[transform_node] = op.apply(bag_inputs[0])
          else:
            self.bags[transform_node] = op.apply(bag_inputs)

        else:
          self.bags[transform_node] = op.apply(None)

    return DaskBagVisitor()

  @staticmethod
  def is_fnapi_compatible():
    return False

  def run_pipeline(self, pipeline, options):
    # TODO(alxr): Create interactive notebook support.
    if is_in_notebook():
      raise NotImplementedError('interactive support will come later!')

    try:
      import dask.distributed as ddist
    except ImportError:
      raise ImportError(
          'DaskRunner is not available. Please install apache_beam[dask].')

    dask_options = options.view_as(DaskOptions).get_all_options(
        drop_default=True)
    client = ddist.Client(**dask_options)

    pipeline.replace_all(dask_overrides())

    dask_visitor = self.to_dask_bag_visitor()
    pipeline.visit(dask_visitor)

    futures = client.compute(list(dask_visitor.bags.values()))
    return DaskRunnerResult(client, futures)
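The `dest` of each `DaskOptions` flag deliberately matches a keyword argument of the `dask.distributed.Client` constructor, so the parsed options can be splatted straight into it. A standalone sketch of that pattern, using an illustrative `fake_client` stand-in (and a `NO_DEFAULT` sentinel standing in for `dask.config.no_default`) rather than the real Client:

```python
import argparse
import inspect

# Sentinel standing in for dask.config.no_default (an assumption here).
NO_DEFAULT = object()


def parse_timeout(candidate):
    # Mirrors DaskOptions._parse_timeout: fall back to a sentinel on bad input.
    try:
        return int(candidate)
    except (TypeError, ValueError):
        return NO_DEFAULT


def fake_client(address=None, timeout=NO_DEFAULT, name=None):
    """Illustrative stand-in for dask.distributed.Client's constructor."""
    return {'address': address, 'timeout': timeout, 'name': name}


parser = argparse.ArgumentParser()
# Each dest matches a fake_client parameter, as in DaskOptions.
parser.add_argument('--dask_client_address', dest='address')
parser.add_argument('--dask_connection_timeout', dest='timeout', type=parse_timeout)
parser.add_argument('--dask_client_name', dest='name')

args = vars(parser.parse_args(['--dask_client_address', 'localhost:8080']))
# Every dest is a constructor parameter, so the dict can be splatted.
assert set(args) <= set(inspect.signature(fake_client).parameters)
client = fake_client(**{k: v for k, v in args.items() if v is not None})
print(client['address'])  # localhost:8080
```

The reflection check (`inspect.signature(...).parameters`) is the same trick the PR's test file uses to keep the option names and the Client constructor in agreement.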
@@ -0,0 +1,94 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import unittest

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing import test_pipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

try:
  from apache_beam.runners.dask.dask_runner import DaskOptions
  from apache_beam.runners.dask.dask_runner import DaskRunner
  import dask
  import dask.distributed as ddist
except (ImportError, ModuleNotFoundError):
  raise unittest.SkipTest('Dask must be installed to run tests.')


class DaskOptionsTest(unittest.TestCase):
  def test_parses_connection_timeout__defaults_to_none(self):
    default_options = PipelineOptions([])
    default_dask_options = default_options.view_as(DaskOptions)
    self.assertEqual(None, default_dask_options.timeout)

  def test_parses_connection_timeout__parses_int(self):
    conn_options = PipelineOptions('--dask_connection_timeout 12'.split())
    dask_conn_options = conn_options.view_as(DaskOptions)
    self.assertEqual(12, dask_conn_options.timeout)

  def test_parses_connection_timeout__handles_bad_input(self):
    err_options = PipelineOptions('--dask_connection_timeout foo'.split())
    dask_err_options = err_options.view_as(DaskOptions)
    self.assertEqual(dask.config.no_default, dask_err_options.timeout)

  def test_parser_destinations__agree_with_dask_client(self):
    options = PipelineOptions(
        '--dask_client_address localhost:8080 --dask_connection_timeout 600 '
        '--dask_scheduler_file foobar.cfg --dask_client_name charlie '
        '--dask_connection_limit 1024'.split())
    dask_options = options.view_as(DaskOptions)

    # Get the argument names for the constructor.
    client_args = list(inspect.signature(ddist.Client).parameters)

    for opt_name in dask_options.get_all_options(drop_default=True).keys():
      with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
        self.assertIn(opt_name, client_args)


class DaskRunnerRunPipelineTest(unittest.TestCase):
  """Test class used to introspect the dask runner via a debugger."""
  def setUp(self) -> None:
    self.pipeline = test_pipeline.TestPipeline(runner=DaskRunner())

  def test_create(self):
    with self.pipeline as p:
      pcoll = p | beam.Create([1])
      assert_that(pcoll, equal_to([1]))

  def test_create_and_map(self):
    def double(x):
      return x * 2

    with self.pipeline as p:
      pcoll = p | beam.Create([1]) | beam.Map(double)
      assert_that(pcoll, equal_to([2]))

  def test_create_map_and_groupby(self):
    def double(x):
      return x * 2, x

    with self.pipeline as p:
      pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey()
      assert_that(pcoll, equal_to([(2, [1])]))


if __name__ == '__main__':
  unittest.main()
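For intuition, the pipeline in `test_create_map_and_groupby` (Create → Map → GroupByKey) can be mimicked in plain Python. This is a conceptual sketch with an illustrative `group_by_key` helper, not Beam or Dask code:

```python
from collections import defaultdict


def group_by_key(pairs):
    # Plain-Python analogue of beam.GroupByKey over an in-memory list of
    # (key, value) pairs.
    grouped = defaultdict(list)
    for key, value in pairs:
        grouped[key].append(value)
    return sorted(grouped.items())


elements = [1]                            # beam.Create([1])
pairs = [(x * 2, x) for x in elements]    # beam.Map(double), double(x) = (x * 2, x)
result = group_by_key(pairs)              # beam.GroupByKey()
print(result)  # [(2, [1])]
```

The runner's job is to express each of these stages as an operation on a `dask.bag.Bag` instead of an in-memory list.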
How does Beam typically handle the lifetime of runners? In the tests, I see warnings about re-using port 8787 from Dask, since the `client` (and cluster) aren't being completely cleaned up between tests. Is it more common for Beam to create (and clean up) the runner? Or would users typically create it?
This is my first runner – @pabloem can probably weigh in better than I can wrt your question. However, what makes sense to me is that each Beam runner should clean up its environment between each run, including in tests. This probably should happen in the `DaskRunnerResult` object. Do you have any recommendations on the best way to clean up dask (distributed)?
In a single scope, … [code snippet did not load]

But in this case, as you say, you'll need to call it after the results are done. So I think that something like … [code snippet did not load] should do the trick (assuming that Beam is the one managing the lifetime of the client).

If you want to rely on the user having a client active, you can call `dask.distributed.get_client()`, which will raise a `ValueError` if one hasn't already been created.
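The cleanup being discussed could be sketched as follows. This is a hypothetical illustration, not the merged code: a `StubClient` stands in for `dask.distributed.Client` so the pattern runs anywhere, while the real Client exposes `gather()` and `close()` with the same shape:

```python
class StubClient:
    """Stand-in for dask.distributed.Client so this sketch is self-contained."""
    def __init__(self):
        self.closed = False

    def gather(self, futures, errors='raise'):
        # Resolve each "future" (here, plain callables) to its result.
        return [f() for f in futures]

    def close(self):
        self.closed = True


def wait_until_finish(client, futures):
    # Gather results, then release the client even if gathering raised,
    # so ports and cluster resources are freed between runs.
    try:
        return client.gather(futures, errors='raise')
    finally:
        client.close()


client = StubClient()
print(wait_until_finish(client, [lambda: 42]))  # [42]
print(client.closed)  # True
```

With a real `dask.distributed.Client`, the same `try`/`finally` around `gather` inside `DaskRunnerResult.wait_until_finish` would address the port-8787 reuse warnings seen in the tests, assuming Beam owns the client's lifetime.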