Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 70 additions & 10 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1687,7 +1687,8 @@ def with_exception_handling(
error_handler,
on_failure_callback,
allow_unsafe_userstate_in_process,
self.get_resource_hints())
self.get_resource_hints(),
self.get_type_hints())

def with_error_handler(self, error_handler, **exception_handling_kwargs):
"""An alias for `with_exception_handling(error_handler=error_handler, ...)`
Expand Down Expand Up @@ -1979,6 +1980,15 @@ def expand(self, pcoll):
self._main_tag,
self._allow_unknown_tags)

def with_exception_handling(self, main_tag=None, **kwargs):
if main_tag is None:
main_tag = self._main_tag or 'good'
named = self._do_transform.with_exception_handling(
main_tag=main_tag, **kwargs)
# named is _NamedPTransform wrapping _ExceptionHandlingWrapper
named.transform._extra_tags = self._tags
return named


class DoFnInfo(object):
"""This class represents the state in the ParDoPayload's function spec,
Expand Down Expand Up @@ -2320,7 +2330,8 @@ def __init__(
error_handler,
on_failure_callback,
allow_unsafe_userstate_in_process,
resource_hints):
resource_hints,
pardo_type_hints=None):
if partial and use_subprocess:
raise ValueError('partial and use_subprocess are mutually incompatible.')
self._fn = fn
Expand All @@ -2338,8 +2349,17 @@ def __init__(
self._on_failure_callback = on_failure_callback
self._allow_unsafe_userstate_in_process = allow_unsafe_userstate_in_process
self._resource_hints = resource_hints
self._pardo_type_hints = pardo_type_hints
self._extra_tags = None

def expand(self, pcoll):
def with_outputs(self, *tags, main=None):
self._extra_tags = tags
if main is not None:
self._main_tag = main
return self

def _build_pardo(self, pcoll):
"""Build the inner ParDo with the exception-handling wrapper DoFn."""
if self._allow_unsafe_userstate_in_process:
if self._use_subprocess or self._timeout:
# TODO(https://github.com/apache/beam/issues/35976): Implement this
Expand All @@ -2366,15 +2386,11 @@ def expand(self, pcoll):
*self._args,
**self._kwargs,
)
# This is the fix: propagate hints.
pardo.get_resource_hints().update(self._resource_hints)
return pardo

result = pcoll | pardo.with_outputs(
self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
#TODO(BEAM-18957): Fix when type inference supports tagged outputs.
result[self._main_tag].element_type = self._fn.infer_output_type(
pcoll.element_type)

def _post_process_result(self, pcoll, result):
"""Apply threshold checking and error handler logic to the result."""
if self._threshold < 1.0:

class MaybeWindow(ptransform.PTransform):
Expand Down Expand Up @@ -2408,10 +2424,54 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam):

if self._error_handler:
self._error_handler.add_error_pcollection(result[self._dead_letter_tag])
if self._extra_tags:
return result
return result[self._main_tag]
else:
return result

def expand_2_71_0(self, pcoll):
"""Pre-2.72.0 behavior: manual element_type override, no with_output_types.
"""
pardo = self._build_pardo(pcoll)
result = pcoll | pardo.with_outputs(
self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
#TODO(BEAM-18957): Fix when type inference supports tagged outputs.
result[self._main_tag].element_type = self._fn.infer_output_type(
pcoll.element_type)

return self._post_process_result(pcoll, result)

def expand(self, pcoll):
if pcoll.pipeline.options.is_compat_version_prior_to("2.72.0"):
return self.expand_2_71_0(pcoll)

pardo = self._build_pardo(pcoll)

if (self._pardo_type_hints and self._pardo_type_hints._has_output_types()):
main_output_type = self._pardo_type_hints.simple_output_type(self.label)
tagged_type_hints = dict(self._pardo_type_hints.tagged_output_types())
else:
main_output_type = self._fn.infer_output_type(pcoll.element_type)
tagged_type_hints = dict(self._fn.get_type_hints().tagged_output_types())

# Dead letter format: Tuple[element, Tuple[exception_type, repr, traceback]]
dead_letter_type = typehints.Tuple[pcoll.element_type,
typehints.Tuple[type,
str,
typehints.List[str]]]

tagged_type_hints[self._dead_letter_tag] = dead_letter_type
pardo = pardo.with_output_types(main_output_type, **tagged_type_hints)

all_tags = tuple(set(self._extra_tags or ()) | {self._dead_letter_tag})
result = pcoll | pardo.with_outputs(
*all_tags,
main=self._main_tag,
allow_unknown_tags=True if self._extra_tags is None else None)

return self._post_process_result(pcoll, result)


class _ExceptionHandlingWrapperDoFn(DoFn):
def __init__(
Expand Down
219 changes: 219 additions & 0 deletions sdks/python/apache_beam/transforms/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@
import os
import tempfile
import unittest
from typing import Iterable
from typing import Literal
from typing import TypeVar

import pytest

import apache_beam as beam
from apache_beam.coders import coders
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.resources import ResourceHint
Expand Down Expand Up @@ -515,6 +518,222 @@ class TagHint(ResourceHint):
)


class ExceptionHandlingWithOutputsTest(unittest.TestCase):
"""Tests for combining with_exception_handling() and with_outputs()."""
def _create_dofn_with_tagged_outputs(self):
"""A DoFn that yields tagged outputs and can raise on even numbers."""
class DoWithFailures(beam.DoFn):
def process(
self, element: int
) -> Iterable[int
| beam.pvalue.TaggedOutput[Literal['threes'], int]
| beam.pvalue.TaggedOutput[Literal['fives'], str]]:
if element % 2 == 0:
raise ValueError(f'Even numbers not allowed {element}')
if element % 3 == 0:
yield beam.pvalue.TaggedOutput('threes', element) # type: ignore[misc]
elif element % 5 == 0:
yield beam.pvalue.TaggedOutput('fives', str(element)) # type: ignore[misc]
else:
yield element

return DoWithFailures()

def test_with_exception_handling_then_with_outputs(self):
"""Direction 1: .with_exception_handling().with_outputs()"""

with beam.Pipeline() as p:
results = (
p
| beam.Create([1, 2, 3, 4, 5, 6, 7])
| beam.ParDo(self._create_dofn_with_tagged_outputs()).
with_exception_handling().with_outputs(
'threes', 'fives', main='main'))

assert_that(results.main, equal_to([1, 7]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
assert_that(results.fives, equal_to(['5']), 'fives')
bad_elements = results.bad | beam.Keys()
assert_that(bad_elements, equal_to([2, 4, 6]), 'bad')
# Verify type hints from annotations are propagated
self.assertEqual(results.main.element_type, int)
self.assertEqual(results.threes.element_type, int)
self.assertEqual(results.fives.element_type, str)
self.assertEqual(
results.bad.element_type,
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])

def test_with_outputs_then_with_exception_handling(self):
"""Direction 2: .with_outputs().with_exception_handling()"""

with beam.Pipeline() as p:
results = (
p
| beam.Create([1, 2, 3, 4, 5, 6, 7])
| beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
'threes', 'fives', main='main').with_exception_handling())

assert_that(results.main, equal_to([1, 7]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
assert_that(results.fives, equal_to(['5']), 'fives')
bad_elements = results.bad | beam.Keys()
assert_that(bad_elements, equal_to([2, 4, 6]), 'bad')
# Verify type hints from annotations are propagated
self.assertEqual(results.main.element_type, int)
self.assertEqual(results.threes.element_type, int)
self.assertEqual(results.fives.element_type, str)
self.assertEqual(
results.bad.element_type,
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])

def test_with_outputs_then_with_exception_handling_custom_dead_letter_tag(
self):
"""Direction 2 with custom dead_letter_tag."""

with beam.Pipeline() as p:
results = (
p
| beam.Create([1, 2, 3])
| beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
'threes',
main='main').with_exception_handling(dead_letter_tag='errors'))

assert_that(results.main, equal_to([1]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
bad_elements = results.errors | beam.Keys()
assert_that(bad_elements, equal_to([2]), 'errors')
self.assertEqual(results.threes.element_type, int)
self.assertEqual(
results.errors.element_type,
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])

def test_with_exception_handling_then_with_outputs_custom_dead_letter_tag(
self):
"""Direction 1 with custom dead_letter_tag."""

with beam.Pipeline() as p:
results = (
p
| beam.Create([1, 2, 3])
| beam.ParDo(
self._create_dofn_with_tagged_outputs()).with_exception_handling(
dead_letter_tag='errors').with_outputs('threes', main='main'))

assert_that(results.main, equal_to([1]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
bad_elements = results.errors | beam.Keys()
assert_that(bad_elements, equal_to([2]), 'errors')
self.assertEqual(results.threes.element_type, int)
self.assertEqual(
results.errors.element_type,
typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])

def test_exception_handling_no_with_outputs_backward_compat(self):
"""Without with_outputs(), behavior is unchanged."""

with beam.Pipeline() as p:
good, bad = (
p
| beam.Create([1, 2, 7])
| beam.ParDo(self._create_dofn_with_tagged_outputs())
.with_exception_handling())

assert_that(good, equal_to([1, 7]), 'good')
bad_elements = bad | beam.Keys()
assert_that(bad_elements, equal_to([2]), 'bad')

def test_exception_handling_compat_version_uses_old_behavior(self):
"""With compat version < 2.72.0, old expand path is used."""
options = PipelineOptions(update_compatibility_version="2.71.0")
with beam.Pipeline(options=options) as p:
good, bad = (
p
| beam.Create([1, 2, 7])
| beam.ParDo(self._create_dofn_with_tagged_outputs())
.with_exception_handling())

assert_that(good, equal_to([1, 7]), 'good')
bad_elements = bad | beam.Keys()
assert_that(bad_elements, equal_to([2]), 'bad')

def test_exception_handling_compat_version_element_type_set_manually(self):
"""With compat version < 2.72.0, element_type is set via manual override
(the old behavior) rather than via with_output_types."""

options = PipelineOptions(update_compatibility_version="2.71.0")
with beam.Pipeline(options=options) as p:
results = (
p
| beam.Create([1, 2, 3])
| beam.ParDo(self._create_dofn_with_tagged_outputs()).
with_exception_handling().with_outputs('threes', main='main'))

# In old path, dead letter type is Any (no with_output_types call)
self.assertEqual(results.bad.element_type, typehints.Any)
# Tagged outputs still get types from DoFn Literal annotations
# (via DoOutputsTuple.__getitem__ reading tagged_output_types)
self.assertEqual(results.threes.element_type, int)
# Main output type should still be inferred via manual override
assert_that(results.main, equal_to([1]), 'main')

def test_with_outputs_then_exception_handling_with_map(self):
"""with_outputs().with_exception_handling() also works on Map."""
with beam.Pipeline() as p:
results = (
p
| beam.Create([1, 2, 3, 4, 5])
| beam.Map(lambda x: x if x % 2 != 0 else 1 / 0).with_outputs(
main='main').with_exception_handling())
assert_that(results.main, equal_to([1, 3, 5]), 'main')
bad_elements = results.bad | beam.Keys()
assert_that(bad_elements, equal_to([2, 4]), 'bad')

def test_with_output_types_chained_on_pardo(self):
"""When type hints are chained on the ParDo (not annotations on the DoFn),
tagged output types should still be propagated through
with_exception_handling().with_outputs()."""
class DoWithFailuresNoAnnotations(beam.DoFn):
def process(self, element):
if element % 2 == 0:
raise ValueError(f'Even numbers not allowed {element}')
if element % 3 == 0:
yield beam.pvalue.TaggedOutput('threes', element)
else:
yield element

with beam.Pipeline() as p:
results = (
p
| beam.Create([1, 2, 3, 7])
| beam.ParDo(DoWithFailuresNoAnnotations()).with_output_types(
int, threes=int).with_exception_handling().with_outputs(
'threes', main='main'))

assert_that(results.main, equal_to([1, 7]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
bad_elements = results.bad | beam.Keys()
assert_that(bad_elements, equal_to([2]), 'bad')
self.assertEqual(results.main.element_type, int)
self.assertEqual(results.threes.element_type, int)

def test_with_outputs_and_error_handler(self):
"""with_outputs() + error_handler should return DoOutputsTuple, not a
bare PCollection."""
from apache_beam.transforms.error_handling import ErrorHandler
with beam.Pipeline() as p:
with ErrorHandler(beam.Map(lambda x: x)) as handler:
results = (
p
| beam.Create([1, 2, 3, 4, 5, 6, 7])
| beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
'threes', 'fives',
main='main').with_exception_handling(error_handler=handler))

assert_that(results.main, equal_to([1, 7]), 'main')
assert_that(results.threes, equal_to([3]), 'threes')
assert_that(results.fives, equal_to(['5']), 'fives')


def test_callablewrapper_typehint():
T = TypeVar("T")

Expand Down
Loading