Commit
Try deepcopy combine_fn and fallback to pickling if TypeError. (#32645)
* Try deepcopy combine_fn and fallback to pickling if TypeError.

* Remove logging, add unit test

* Linter fixes

---------

Co-authored-by: Claude <cvandermerwe@google.com>
claudevdm and Claude authored Oct 9, 2024
1 parent 2ee6100 commit 3516645
Showing 3 changed files with 80 additions and 22 deletions.
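The core of the change is a copy strategy for the lifted CombineFn: try copy.deepcopy first and fall back to a round trip through Beam's pickler if the copy fails. A minimal sketch of that pattern in isolation (the helper name make_stage_copy is illustrative, not part of the Beam API):

import copy

from apache_beam.internal import pickler


def make_stage_copy(combine_fn):
  # Prefer a plain deepcopy; fall back to serializing with the configured
  # pickle library (dill or cloudpickle) when the CombineFn carries
  # attributes that deepcopy cannot handle, e.g. a module reference.
  try:
    return copy.deepcopy(combine_fn)
  except Exception:
    return pickler.loads(pickler.dumps(combine_fn))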
33 changes: 33 additions & 0 deletions sdks/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py
@@ -17,6 +17,7 @@
 
 # pytype: skip-file
 
+import math
 from typing import Set
 from typing import Tuple
 
@@ -124,6 +125,38 @@ def run_combine(pipeline, input_elements=5, lift_combiners=True):
     assert_that(pcoll, equal_to([(expected_result, expected_result)]))
 
 
+def run_combine_uncopyable_attr(
+    pipeline, input_elements=5, lift_combiners=True):
+  # Calculate the expected result, which is the sum of an arithmetic sequence.
+  # By default, this is equal to: 0 + 1 + 2 + 3 + 4 = 10
+  expected_result = input_elements * (input_elements - 1) / 2
+
+  # Enable runtime type checking in order to cover TypeCheckCombineFn by
+  # the test.
+  pipeline.get_pipeline_options().view_as(TypeOptions).runtime_type_check = True
+  pipeline.get_pipeline_options().view_as(
+      TypeOptions).allow_unsafe_triggers = True
+
+  with pipeline as p:
+    pcoll = p | 'Start' >> beam.Create(range(input_elements))
+
+    # Certain triggers, such as AfterCount, are incompatible with combiner
+    # lifting. We can use that fact to prevent combiners from being lifted.
+    if not lift_combiners:
+      pcoll |= beam.WindowInto(
+          window.GlobalWindows(),
+          trigger=trigger.AfterCount(input_elements),
+          accumulation_mode=trigger.AccumulationMode.DISCARDING)
+
+    combine_fn = CallSequenceEnforcingCombineFn()
+    # Modules are not deep copyable. Ensure fanout falls back to pickling for
+    # copying combine_fn.
+    combine_fn.module_attribute = math
+    pcoll |= 'Do' >> beam.CombineGlobally(combine_fn).with_fanout(fanout=1)
+
+    assert_that(pcoll, equal_to([expected_result]))
+
+
 def run_pardo(pipeline, input_elements=10):
   with pipeline as p:
     _ = (
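For context on the new test case above: a module object such as math cannot be deep-copied, and the copy fails with a TypeError, which is exactly what forces the fallback path named in the commit title. A standalone illustration using only the standard library (the _Holder class is a stand-in for the CombineFn, not Beam code):

import copy
import math


class _Holder:
  # Stand-in for a CombineFn instance that carries a module-valued attribute.
  pass


holder = _Holder()
holder.module_attribute = math

try:
  copy.deepcopy(holder)
except TypeError as exc:
  # Modules are not deep-copyable, so core.py must fall back to the
  # pipeline's pickler; the test above exercises that path with both
  # dill and cloudpickle.
  print(f"deepcopy failed: {exc}")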
18 changes: 14 additions & 4 deletions sdks/python/apache_beam/transforms/combinefn_lifecycle_test.py
@@ -31,6 +31,7 @@
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.transforms.combinefn_lifecycle_pipeline import CallSequenceEnforcingCombineFn
 from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine
+from apache_beam.transforms.combinefn_lifecycle_pipeline import run_combine_uncopyable_attr
 from apache_beam.transforms.combinefn_lifecycle_pipeline import run_pardo
 
 
@@ -53,15 +54,24 @@ def test_combining_value_state(self):
 
 
 @parameterized_class([
-    {'runner': direct_runner.BundleBasedDirectRunner},
-    {'runner': fn_api_runner.FnApiRunner},
-]) # yapf: disable
+    {'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'dill'},
+    {'runner': direct_runner.BundleBasedDirectRunner, 'pickler': 'cloudpickle'},
+    {'runner': fn_api_runner.FnApiRunner, 'pickler': 'dill'},
+    {'runner': fn_api_runner.FnApiRunner, 'pickler': 'cloudpickle'},
+]) # yapf: disable
 class LocalCombineFnLifecycleTest(unittest.TestCase):
   def tearDown(self):
     CallSequenceEnforcingCombineFn.instances.clear()
 
   def test_combine(self):
-    run_combine(TestPipeline(runner=self.runner()))
+    test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"])
+    run_combine(TestPipeline(runner=self.runner(), options=test_options))
     self._assert_teardown_called()
 
+  def test_combine_deepcopy_fails(self):
+    test_options = PipelineOptions(flags=[f"--pickle_library={self.pickler}"])
+    run_combine_uncopyable_attr(
+        TestPipeline(runner=self.runner(), options=test_options))
+    self._assert_teardown_called()
+
   def test_non_liftable_combine(self):
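The pickler dimension added to the parameterized tests above maps onto Beam's --pickle_library pipeline option. A brief standalone usage of the same flag outside the test harness, mirroring what the tests construct:

from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_pipeline import TestPipeline

# Select the pickle library the same way the parameterized tests do;
# 'cloudpickle' is one of the two values exercised above.
options = PipelineOptions(flags=["--pickle_library=cloudpickle"])
pipeline = TestPipeline(options=options)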
51 changes: 33 additions & 18 deletions sdks/python/apache_beam/transforms/core.py
@@ -3170,33 +3170,48 @@ def process(self, element):
           yield pvalue.TaggedOutput('hot', ((self._nonce % fanout, key), value))
 
     class PreCombineFn(CombineFn):
+      def __init__(self):
+        # Deepcopy of the combine_fn to avoid sharing state between lifted
+        # stages when using cloudpickle.
+        try:
+          self._combine_fn_copy = copy.deepcopy(combine_fn)
+        except Exception:
+          self._combine_fn_copy = pickler.loads(pickler.dumps(combine_fn))
+
+        self.setup = self._combine_fn_copy.setup
+        self.create_accumulator = self._combine_fn_copy.create_accumulator
+        self.add_input = self._combine_fn_copy.add_input
+        self.merge_accumulators = self._combine_fn_copy.merge_accumulators
+        self.compact = self._combine_fn_copy.compact
+        self.teardown = self._combine_fn_copy.teardown
+
       @staticmethod
       def extract_output(accumulator):
         # Boolean indicates this is an accumulator.
         return (True, accumulator)
 
-      setup = combine_fn.setup
-      create_accumulator = combine_fn.create_accumulator
-      add_input = combine_fn.add_input
-      merge_accumulators = combine_fn.merge_accumulators
-      compact = combine_fn.compact
-      teardown = combine_fn.teardown
-
     class PostCombineFn(CombineFn):
-      @staticmethod
-      def add_input(accumulator, element):
+      def __init__(self):
+        # Deepcopy of the combine_fn to avoid sharing state between lifted
+        # stages when using cloudpickle.
+        try:
+          self._combine_fn_copy = copy.deepcopy(combine_fn)
+        except Exception:
+          self._combine_fn_copy = pickler.loads(pickler.dumps(combine_fn))
+
+        self.setup = self._combine_fn_copy.setup
+        self.create_accumulator = self._combine_fn_copy.create_accumulator
+        self.merge_accumulators = self._combine_fn_copy.merge_accumulators
+        self.compact = self._combine_fn_copy.compact
+        self.extract_output = self._combine_fn_copy.extract_output
+        self.teardown = self._combine_fn_copy.teardown
+
+      def add_input(self, accumulator, element):
         is_accumulator, value = element
         if is_accumulator:
-          return combine_fn.merge_accumulators([accumulator, value])
+          return self._combine_fn_copy.merge_accumulators([accumulator, value])
         else:
-          return combine_fn.add_input(accumulator, value)
-
-      setup = combine_fn.setup
-      create_accumulator = combine_fn.create_accumulator
-      merge_accumulators = combine_fn.merge_accumulators
-      compact = combine_fn.compact
-      extract_output = combine_fn.extract_output
-      teardown = combine_fn.teardown
+          return self._combine_fn_copy.add_input(accumulator, value)
 
     def StripNonce(nonce_key_value):
       (_, key), value = nonce_key_value
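Why core.py now copies combine_fn per wrapper instead of delegating through class attributes: with the old pattern, PreCombineFn and PostCombineFn both call lifecycle methods on the one shared CombineFn instance captured in the closure, which the added comments note was problematic when using cloudpickle. A toy, Beam-free illustration of the difference (StatefulFn is hypothetical):

import copy


class StatefulFn:
  # Toy stand-in for a CombineFn whose lifecycle methods mutate state.
  def __init__(self):
    self.setup_calls = 0

  def setup(self):
    self.setup_calls += 1


shared = StatefulFn()


class OldStylePre:
  # Class-level delegation: every wrapper calls into the one shared instance.
  setup = shared.setup


class OldStylePost:
  setup = shared.setup


OldStylePre.setup()
OldStylePost.setup()
assert shared.setup_calls == 2  # both "stages" observed each other's calls


class NewStylePre:
  # Per-instance copy: each wrapper binds methods of its own private copy.
  def __init__(self):
    self._fn_copy = copy.deepcopy(shared)
    self.setup = self._fn_copy.setup


pre = NewStylePre()
pre.setup()
assert shared.setup_calls == 2  # the copy's calls no longer touch `shared`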
