Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial DaskRunner for Beam #22421

Merged
merged 69 commits into from
Oct 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
69 commits
Select commit Hold shift + click to select a range
79d4603
WIP: Created a skeleton dask runner implementation.
alxmrs Jun 22, 2022
248ec70
WIP: Idea for a translation evaluator.
alxmrs Jun 23, 2022
42452ca
Added overrides and a visitor that translates operations.
alxmrs Jul 2, 2022
1da2ddd
Fixed a dataclass typo.
alxmrs Jul 2, 2022
14885a3
Expanded translations.
alxmrs Jul 2, 2022
fca2420
Core idea seems to be kinda working...
alxmrs Jul 2, 2022
6dd1ada
First iteration on DaskRunnerResult (keep track of pipeline state).
alxmrs Jul 3, 2022
6675687
Added minimal set of DaskRunner options.
alxmrs Jul 4, 2022
88ed36b
WIP: Alllmost got asserts to work! The current status is:
alxmrs Jul 8, 2022
2e3a126
With a great 1-liner from @pabloem, groupby is fixed! Now, all three …
alxmrs Jul 8, 2022
6467b0e
Self-review: Cleaned up dask runner impl.
alxmrs Jul 8, 2022
793ba86
Self-review: Remove TODOs, delete commented out code, other cleanup.
alxmrs Jul 8, 2022
e535792
First pass at linting rules.
alxmrs Jul 9, 2022
8e32668
WIP, include dask dependencies + test setup.
alxmrs Jul 9, 2022
318afc2
WIP: maybe better dask deps?
alxmrs Jul 9, 2022
b01855f
Skip dask tests depending on successful import.
alxmrs Jul 10, 2022
2c2eb8d
Fixed setup.py (missing `,`).
alxmrs Jul 11, 2022
e64e9eb
Added an additional comma.
alxmrs Jul 11, 2022
69b118b
Moved skipping logic to be above dask import.
alxmrs Jul 11, 2022
9ffc8d8
Fix lint issues with dask runner tests.
alxmrs Sep 5, 2022
8a2afb7
Adding destination for client address.
alxmrs Sep 20, 2022
93f02f1
Changing to async produces a timeout error instead of stuck in infini…
alxmrs Sep 21, 2022
afdcf1b
Close client during `wait_until_finish`; rm async.
alxmrs Sep 22, 2022
41b5267
Supporting side-inputs for ParDo.
alxmrs Oct 2, 2022
e3ac3f8
Revert "Close client during `wait_until_finish`; rm async."
pabloem Sep 28, 2022
3fddc81
Revert "Changing to async produces a timeout error instead of stuck i…
pabloem Sep 28, 2022
9eeb9ea
Adding -dask tox targets onto the gradle build
pabloem Sep 28, 2022
b4d0999
wip - added print stmt.
alxmrs Oct 2, 2022
0319ffd
wip - prove side inputs is set.
alxmrs Oct 2, 2022
0b13bb0
wip - prove side inputs is set in Pardo.
alxmrs Oct 2, 2022
1e7052b
wip - rm asserts, add print
alxmrs Oct 2, 2022
292e023
wip - adding named inputs...
alxmrs Oct 2, 2022
31c1e2b
Experiments: non-named side inputs + del `None` in named inputs.
alxmrs Oct 2, 2022
f4ecf2f
None --> 'None'
alxmrs Oct 2, 2022
4d24ed9
No default side input.
alxmrs Oct 2, 2022
ee62a4a
Pass along args + kwargs.
alxmrs Oct 2, 2022
506c719
Applied yapf to dask sources.
alxmrs Oct 9, 2022
cd0ba8b
Dask sources passing pylint.
alxmrs Oct 9, 2022
d0a7c63
Added dask extra to docs gen tox env.
alxmrs Oct 9, 2022
775bd07
Applied yapf from tox.
alxmrs Oct 9, 2022
efba1c9
Include dask in mypy checks.
alxmrs Oct 9, 2022
741d961
Upgrading mypy support to python 3.8 since py37 support is deprecated…
alxmrs Oct 9, 2022
f66458a
Manually installing an old version of dask before 3.7 support was dro…
alxmrs Oct 9, 2022
5dcf969
fix lint: line too long.
alxmrs Oct 9, 2022
ec5f613
Fixed type errors with DaskRunnerResult. Disabled mypy type checking …
alxmrs Oct 9, 2022
04b1f1a
Fix pytype errors (in transform_evaluator).
alxmrs Oct 9, 2022
712944b
Ran isort.
alxmrs Oct 9, 2022
567b72b
Ran yapf again.
alxmrs Oct 9, 2022
f53c0a4
Fix imports (one per line)
alxmrs Oct 10, 2022
fb280ad
isort -- alphabetical.
alxmrs Oct 10, 2022
80ddfec
Added feature to CHANGES.md.
alxmrs Oct 10, 2022
40c4e35
ran yapf via tox on linux machine
alxmrs Oct 10, 2022
a70c5f3
Merge branch 'master' into dask-runner-mvp
alxmrs Oct 13, 2022
9fb52e5
Change an import to pass CI.
alxmrs Oct 13, 2022
26c6016
Skip isort error; needed to get CI to pass.
alxmrs Oct 13, 2022
aec19bf
Skip test logic may favor better with isort.
alxmrs Oct 13, 2022
0673235
(Maybe) the last isort fix.
alxmrs Oct 13, 2022
de03a32
Tested pipeline options (added one fix).
alxmrs Oct 14, 2022
7e0a2c7
Improve formatting of test.
alxmrs Oct 14, 2022
39b1e1c
Self-review: removing side inputs.
alxmrs Oct 14, 2022
6db49fa
add dask to coverage suite in tox.
alxmrs Oct 17, 2022
036561c
Merge branch 'master' into dask-runner-mvp
alxmrs Oct 18, 2022
191580d
Capture value error in assert.
alxmrs Oct 18, 2022
365fc87
Merge branch 'master' of github.com:apache/beam into dask-runner-mvp
alxmrs Oct 18, 2022
085447e
Change timeout value to 600 seconds.
alxmrs Oct 18, 2022
1a60a5e
ignoring broken test
pabloem Oct 21, 2022
c1037f8
Update CHANGES.md
pabloem Oct 21, 2022
9e79ffd
Using reflection to test the Dask client constructor.
alxmrs Oct 24, 2022
f9cf45a
Better method of inspecting the constructor parameters (thanks @TomAu…
alxmrs Oct 24, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
## Highlights

* Python 3.10 support in Apache Beam ([#21458](https://github.com/apache/beam/issues/21458)).
* An initial implementation of a runner that allows us to run Beam pipelines on Dask. Try it out and give us feedback! (Python) ([#18962](https://github.com/apache/beam/issues/18962)).


## I/Os
Expand All @@ -71,6 +72,7 @@
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Dataframe wrapper added in Go SDK via Cross-Language (Need to manually start python expansion service). (Go) ([#23384](https://github.com/apache/beam/issues/23384)).
* Name all Java threads to aid in debugging ([#23049](https://github.com/apache/beam/issues/23049)).
* An initial implementation of a runner that allows us to run Beam pipelines on Dask. (Python) ([#18962](https://github.com/apache/beam/issues/18962)).

## Breaking Changes

Expand Down
16 changes: 16 additions & 0 deletions sdks/python/apache_beam/runners/dask/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
182 changes: 182 additions & 0 deletions sdks/python/apache_beam/runners/dask/dask_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""DaskRunner, executing remote jobs on Dask.distributed.

The DaskRunner is a runner implementation that executes a graph of
transformations across processes and workers via Dask distributed's
scheduler.
"""
import argparse
import dataclasses
import typing as t

from apache_beam import pvalue
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import AppliedPTransform
from apache_beam.pipeline import PipelineVisitor
from apache_beam.runners.dask.overrides import dask_overrides
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
from apache_beam.runners.dask.transform_evaluator import NoOp
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineState
from apache_beam.utils.interactive_utils import is_in_notebook


class DaskOptions(PipelineOptions):
@staticmethod
def _parse_timeout(candidate):
try:
return int(candidate)
except (TypeError, ValueError):
import dask
return dask.config.no_default

@classmethod
def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
parser.add_argument(
'--dask_client_address',
dest='address',
type=str,
default=None,
help='Address of a dask Scheduler server. Will default to a '
'`dask.LocalCluster()`.')
parser.add_argument(
'--dask_connection_timeout',
dest='timeout',
type=DaskOptions._parse_timeout,
help='Timeout duration for initial connection to the scheduler.')
parser.add_argument(
'--dask_scheduler_file',
dest='scheduler_file',
type=str,
default=None,
help='Path to a file with scheduler information if available.')
# TODO(alxr): Add options for security.
parser.add_argument(
'--dask_client_name',
dest='name',
type=str,
default=None,
help='Gives the client a name that will be included in logs generated '
'on the scheduler for matters relating to this client.')
parser.add_argument(
'--dask_connection_limit',
dest='connection_limit',
type=int,
default=512,
help='The number of open comms to maintain at once in the connection '
'pool.')


@dataclasses.dataclass
class DaskRunnerResult(PipelineResult):
from dask import distributed

client: distributed.Client
futures: t.Sequence[distributed.Future]

def __post_init__(self):
super().__init__(PipelineState.RUNNING)

def wait_until_finish(self, duration=None) -> str:
try:
if duration is not None:
# Convert milliseconds to seconds
duration /= 1000
self.client.wait_for_workers(timeout=duration)
self.client.gather(self.futures, errors='raise')
self._state = PipelineState.DONE
except: # pylint: disable=broad-except
self._state = PipelineState.FAILED
raise
return self._state

def cancel(self) -> str:
self._state = PipelineState.CANCELLING
self.client.cancel(self.futures)
self._state = PipelineState.CANCELLED
return self._state

def metrics(self):
# TODO(alxr): Collect and return metrics...
raise NotImplementedError('collecting metrics will come later!')


class DaskRunner(BundleBasedDirectRunner):
"""Executes a pipeline on a Dask distributed client."""
@staticmethod
def to_dask_bag_visitor() -> PipelineVisitor:
from dask import bag as db

@dataclasses.dataclass
class DaskBagVisitor(PipelineVisitor):
bags: t.Dict[AppliedPTransform,
db.Bag] = dataclasses.field(default_factory=dict)

def visit_transform(self, transform_node: AppliedPTransform) -> None:
op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
op = op_class(transform_node)

inputs = list(transform_node.inputs)
if inputs:
bag_inputs = []
for input_value in inputs:
if isinstance(input_value, pvalue.PBegin):
bag_inputs.append(None)

prev_op = input_value.producer
if prev_op in self.bags:
bag_inputs.append(self.bags[prev_op])

if len(bag_inputs) == 1:
self.bags[transform_node] = op.apply(bag_inputs[0])
else:
self.bags[transform_node] = op.apply(bag_inputs)

else:
self.bags[transform_node] = op.apply(None)

return DaskBagVisitor()

@staticmethod
def is_fnapi_compatible():
return False

def run_pipeline(self, pipeline, options):
# TODO(alxr): Create interactive notebook support.
if is_in_notebook():
raise NotImplementedError('interactive support will come later!')

try:
import dask.distributed as ddist
except ImportError:
raise ImportError(
'DaskRunner is not available. Please install apache_beam[dask].')

dask_options = options.view_as(DaskOptions).get_all_options(
drop_default=True)
client = ddist.Client(**dask_options)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does Beam typically handle the lifetime of runners? In the tests, I see warnings about re-using port 8787 from Dask, since the client (and cluster) aren't being completely cleaned up between tests.

Is it more common for beam to create (and clean up) the runner? Or would users typically create it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is my first runner – @pabloem can probably weigh in better than I can wrt your question. However, what makes sense to me is that each Beam runner should clean up its environment between each run, including in tests.

This probably should happen in the DaskRunnerResult object. Do you have any recommendations on the best way to clean up dask (distributed)?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a single scope,

with distributed.Client(...) as client:
    ...

But in this case, as you say, you'll need to call it after the results are done. So I think that something like

client.close()
client.cluster.close()

should do the trick (assuming that beam is the one managing the lifetime of the client.

If you want to rely on the user having a client active, you can call dask.distributed.get_client(), which will raise a ValueError if one hasn't already been created.


pipeline.replace_all(dask_overrides())

dask_visitor = self.to_dask_bag_visitor()
pipeline.visit(dask_visitor)

futures = client.compute(list(dask_visitor.bags.values()))
return DaskRunnerResult(client, futures)
94 changes: 94 additions & 0 deletions sdks/python/apache_beam/runners/dask/dask_runner_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import unittest

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing import test_pipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

try:
from apache_beam.runners.dask.dask_runner import DaskOptions
from apache_beam.runners.dask.dask_runner import DaskRunner
import dask
import dask.distributed as ddist
except (ImportError, ModuleNotFoundError):
raise unittest.SkipTest('Dask must be installed to run tests.')


class DaskOptionsTest(unittest.TestCase):
def test_parses_connection_timeout__defaults_to_none(self):
default_options = PipelineOptions([])
default_dask_options = default_options.view_as(DaskOptions)
self.assertEqual(None, default_dask_options.timeout)

def test_parses_connection_timeout__parses_int(self):
conn_options = PipelineOptions('--dask_connection_timeout 12'.split())
dask_conn_options = conn_options.view_as(DaskOptions)
self.assertEqual(12, dask_conn_options.timeout)

def test_parses_connection_timeout__handles_bad_input(self):
err_options = PipelineOptions('--dask_connection_timeout foo'.split())
dask_err_options = err_options.view_as(DaskOptions)
self.assertEqual(dask.config.no_default, dask_err_options.timeout)

def test_parser_destinations__agree_with_dask_client(self):
options = PipelineOptions(
'--dask_client_address localhost:8080 --dask_connection_timeout 600 '
'--dask_scheduler_file foobar.cfg --dask_client_name charlie '
'--dask_connection_limit 1024'.split())
dask_options = options.view_as(DaskOptions)

# Get the argument names for the constructor.
client_args = list(inspect.signature(ddist.Client).parameters)

for opt_name in dask_options.get_all_options(drop_default=True).keys():
with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
self.assertIn(opt_name, client_args)


class DaskRunnerRunPipelineTest(unittest.TestCase):
"""Test class used to introspect the dask runner via a debugger."""
def setUp(self) -> None:
self.pipeline = test_pipeline.TestPipeline(runner=DaskRunner())

def test_create(self):
with self.pipeline as p:
pcoll = p | beam.Create([1])
assert_that(pcoll, equal_to([1]))

def test_create_and_map(self):
def double(x):
return x * 2

with self.pipeline as p:
pcoll = p | beam.Create([1]) | beam.Map(double)
assert_that(pcoll, equal_to([2]))

def test_create_map_and_groupby(self):
def double(x):
return x * 2, x

with self.pipeline as p:
pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey()
assert_that(pcoll, equal_to([(2, [1])]))


if __name__ == '__main__':
unittest.main()
Loading