diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py
index 6d2552a2a6a1..35cefdb10175 100644
--- a/sdks/python/apache_beam/transforms/core.py
+++ b/sdks/python/apache_beam/transforms/core.py
@@ -1687,7 +1687,8 @@ def with_exception_handling(
         error_handler,
         on_failure_callback,
         allow_unsafe_userstate_in_process,
-        self.get_resource_hints())
+        self.get_resource_hints(),
+        self.get_type_hints())
 
   def with_error_handler(self, error_handler, **exception_handling_kwargs):
     """An alias for `with_exception_handling(error_handler=error_handler, ...)`
@@ -1979,6 +1980,31 @@ def expand(self, pcoll):
         self._main_tag,
         self._allow_unknown_tags)
 
+  def with_exception_handling(self, main_tag=None, **kwargs):
+    """Adds exception handling while keeping the declared tagged outputs.
+
+    Args:
+      main_tag: tag for the main output; defaults to the main tag given to
+        with_outputs(), or 'good' if there is none.
+      **kwargs: forwarded to the wrapped transform's
+        with_exception_handling().
+
+    Returns:
+      The exception-handling transform, configured with the tags, main tag
+      and allow_unknown_tags setting of this transform.
+    """
+    if main_tag is None:
+      main_tag = self._main_tag or 'good'
+    named = self._do_transform.with_exception_handling(
+        main_tag=main_tag, **kwargs)
+    # named is a _NamedPTransform wrapping an _ExceptionHandlingWrapper, so
+    # configure the wrapped transform directly.
+    named.transform.with_outputs(
+        *self._tags,
+        main=main_tag,
+        allow_unknown_tags=self._allow_unknown_tags)
+    return named
+
 
 class DoFnInfo(object):
   """This class represents the state in the ParDoPayload's function spec,
@@ -2320,7 +2346,8 @@ def __init__(
       error_handler,
       on_failure_callback,
       allow_unsafe_userstate_in_process,
-      resource_hints):
+      resource_hints,
+      pardo_type_hints=None):
     if partial and use_subprocess:
       raise ValueError('partial and use_subprocess are mutually incompatible.')
     self._fn = fn
@@ -2338,8 +2365,35 @@ def __init__(
     self._on_failure_callback = on_failure_callback
     self._allow_unsafe_userstate_in_process = allow_unsafe_userstate_in_process
     self._resource_hints = resource_hints
+    self._pardo_type_hints = pardo_type_hints
+    # None means with_outputs() was never called; an empty tuple means it was
+    # called without extra tags.
+    self._extra_tags = None
+    self._allow_unknown_tags = None
 
-  def expand(self, pcoll):
+  def with_outputs(self, *tags, main=None, allow_unknown_tags=None):
+    """Declares the extra tagged outputs emitted by the wrapped function.
+
+    After this is called, expanding the transform always returns the full
+    tagged result, including when an error handler is configured (which
+    otherwise yields only the main output).
+
+    Args:
+      *tags: tags of the additional outputs.
+      main: optional replacement for the main output tag.
+      allow_unknown_tags: forwarded to the inner ParDo's with_outputs().
+
+    Returns:
+      This transform, to allow chaining.
+    """
+    self._extra_tags = tags
+    self._allow_unknown_tags = allow_unknown_tags
+    if main is not None:
+      self._main_tag = main
+    return self
+
+  def _build_pardo(self, pcoll):
+    """Build the inner ParDo with the exception-handling wrapper DoFn."""
     if self._allow_unsafe_userstate_in_process:
       if self._use_subprocess or self._timeout:
         # TODO(https://github.com/apache/beam/issues/35976): Implement this
@@ -2366,15 +2420,11 @@ def expand(self, pcoll):
         *self._args,
         **self._kwargs,
     )
-    # This is the fix: propagate hints.
     pardo.get_resource_hints().update(self._resource_hints)
+    return pardo
 
-    result = pcoll | pardo.with_outputs(
-        self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
-    #TODO(BEAM-18957): Fix when type inference supports tagged outputs.
-    result[self._main_tag].element_type = self._fn.infer_output_type(
-        pcoll.element_type)
-
+  def _post_process_result(self, pcoll, result):
+    """Apply threshold checking and error handler logic to the result."""
     if self._threshold < 1.0:
 
       class MaybeWindow(ptransform.PTransform):
@@ -2408,10 +2458,67 @@ def check_threshold(bad, total, threshold, window=DoFn.WindowParam):
 
     if self._error_handler:
       self._error_handler.add_error_pcollection(result[self._dead_letter_tag])
+      # with_outputs() may have been called without extra tags (an empty
+      # tuple), so compare against None rather than relying on truthiness.
+      if self._extra_tags is not None:
+        return result
       return result[self._main_tag]
     else:
       return result
 
+  def expand_2_71_0(self, pcoll):
+    """Pre-2.72.0 behavior: manual element_type override, no with_output_types.
+    """
+    pardo = self._build_pardo(pcoll)
+    result = pcoll | pardo.with_outputs(
+        self._dead_letter_tag, main=self._main_tag, allow_unknown_tags=True)
+    #TODO(BEAM-18957): Fix when type inference supports tagged outputs.
+    result[self._main_tag].element_type = self._fn.infer_output_type(
+        pcoll.element_type)
+
+    return self._post_process_result(pcoll, result)
+
+  def expand(self, pcoll):
+    """Expands into the inner ParDo with fully typed outputs.
+
+    Output types come from the type hints passed as pardo_type_hints when
+    those declare output types, and are otherwise inferred from the wrapped
+    function; the dead letter output is typed as well. Pipelines with a
+    compatibility version prior to 2.72.0 use expand_2_71_0() instead.
+    """
+    if pcoll.pipeline.options.is_compat_version_prior_to("2.72.0"):
+      return self.expand_2_71_0(pcoll)
+
+    pardo = self._build_pardo(pcoll)
+
+    if (self._pardo_type_hints and self._pardo_type_hints._has_output_types()):
+      main_output_type = self._pardo_type_hints.simple_output_type(self.label)
+      tagged_type_hints = dict(self._pardo_type_hints.tagged_output_types())
+    else:
+      main_output_type = self._fn.infer_output_type(pcoll.element_type)
+      tagged_type_hints = dict(self._fn.get_type_hints().tagged_output_types())
+
+    # Dead letter format: Tuple[element, Tuple[exception_type, repr, traceback]]
+    dead_letter_type = typehints.Tuple[pcoll.element_type,
+                                       typehints.Tuple[type,
+                                                       str,
+                                                       typehints.List[str]]]
+
+    tagged_type_hints[self._dead_letter_tag] = dead_letter_type
+    pardo = pardo.with_output_types(main_output_type, **tagged_type_hints)
+
+    # Deduplicate while keeping declaration order: a set would make the tag
+    # order depend on string hashing, which varies between processes.
+    all_tags = tuple(
+        dict.fromkeys([*(self._extra_tags or ()), self._dead_letter_tag]))
+    result = pcoll | pardo.with_outputs(
+        *all_tags,
+        main=self._main_tag,
+        allow_unknown_tags=(
+            True if self._extra_tags is None else self._allow_unknown_tags))
+
+    return self._post_process_result(pcoll, result)
+
 
 class _ExceptionHandlingWrapperDoFn(DoFn):
   def __init__(
diff --git a/sdks/python/apache_beam/transforms/core_test.py b/sdks/python/apache_beam/transforms/core_test.py
index 73f004c130c2..a38fc469b152 100644
--- a/sdks/python/apache_beam/transforms/core_test.py
+++ b/sdks/python/apache_beam/transforms/core_test.py
@@ -22,12 +22,15 @@
 import os
 import tempfile
 import unittest
+from typing import Iterable
+from typing import Literal
 from typing import TypeVar
 
 import pytest
 
 import apache_beam as beam
 from apache_beam.coders import coders
+from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 from apache_beam.transforms.resources import ResourceHint
@@ -515,6 +518,236 @@ class TagHint(ResourceHint):
     )
 
 
+class ExceptionHandlingWithOutputsTest(unittest.TestCase):
+  """Tests for combining with_exception_handling() and with_outputs()."""
+  def _create_dofn_with_tagged_outputs(self):
+    """A DoFn that yields tagged outputs and can raise on even numbers."""
+    class DoWithFailures(beam.DoFn):
+      def process(
+          self, element: int
+      ) -> Iterable[int
+                    | beam.pvalue.TaggedOutput[Literal['threes'], int]
+                    | beam.pvalue.TaggedOutput[Literal['fives'], str]]:
+        if element % 2 == 0:
+          raise ValueError(f'Even numbers not allowed {element}')
+        if element % 3 == 0:
+          yield beam.pvalue.TaggedOutput('threes', element)  # type: ignore[misc]
+        elif element % 5 == 0:
+          yield beam.pvalue.TaggedOutput('fives', str(element))  # type: ignore[misc]
+        else:
+          yield element
+
+    return DoWithFailures()
+
+  def test_with_exception_handling_then_with_outputs(self):
+    """Direction 1: .with_exception_handling().with_outputs()"""
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([1, 2, 3, 4, 5, 6, 7])
+          | beam.ParDo(self._create_dofn_with_tagged_outputs()).
+          with_exception_handling().with_outputs(
+              'threes', 'fives', main='main'))
+
+      assert_that(results.main, equal_to([1, 7]), 'main')
+      assert_that(results.threes, equal_to([3]), 'threes')
+      assert_that(results.fives, equal_to(['5']), 'fives')
+      bad_elements = results.bad | beam.Keys()
+      assert_that(bad_elements, equal_to([2, 4, 6]), 'bad')
+    # Verify type hints from annotations are propagated
+    self.assertEqual(results.main.element_type, int)
+    self.assertEqual(results.threes.element_type, int)
+    self.assertEqual(results.fives.element_type, str)
+    self.assertEqual(
+        results.bad.element_type,
+        typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
+
+  def test_with_outputs_then_with_exception_handling(self):
+    """Direction 2: .with_outputs().with_exception_handling()"""
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([1, 2, 3, 4, 5, 6, 7])
+          | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
+              'threes', 'fives', main='main').with_exception_handling())
+
+      assert_that(results.main, equal_to([1, 7]), 'main')
+      assert_that(results.threes, equal_to([3]), 'threes')
+      assert_that(results.fives, equal_to(['5']), 'fives')
+      bad_elements = results.bad | beam.Keys()
+      assert_that(bad_elements, equal_to([2, 4, 6]), 'bad')
+    # Verify type hints from annotations are propagated
+    self.assertEqual(results.main.element_type, int)
+    self.assertEqual(results.threes.element_type, int)
+    self.assertEqual(results.fives.element_type, str)
+    self.assertEqual(
+        results.bad.element_type,
+        typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
+
+  def test_with_outputs_then_with_exception_handling_custom_dead_letter_tag(
+      self):
+    """Direction 2 with custom dead_letter_tag."""
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([1, 2, 3])
+          | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
+              'threes',
+              main='main').with_exception_handling(dead_letter_tag='errors'))
+
+      assert_that(results.main, equal_to([1]), 'main')
+      assert_that(results.threes, equal_to([3]), 'threes')
+      bad_elements = results.errors | beam.Keys()
+      assert_that(bad_elements, equal_to([2]), 'errors')
+    self.assertEqual(results.threes.element_type, int)
+    self.assertEqual(
+        results.errors.element_type,
+        typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
+
+  def test_with_exception_handling_then_with_outputs_custom_dead_letter_tag(
+      self):
+    """Direction 1 with custom dead_letter_tag."""
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([1, 2, 3])
+          | beam.ParDo(
+              self._create_dofn_with_tagged_outputs()).with_exception_handling(
+                  dead_letter_tag='errors').with_outputs('threes', main='main'))
+
+      assert_that(results.main, equal_to([1]), 'main')
+      assert_that(results.threes, equal_to([3]), 'threes')
+      bad_elements = results.errors | beam.Keys()
+      assert_that(bad_elements, equal_to([2]), 'errors')
+    self.assertEqual(results.threes.element_type, int)
+    self.assertEqual(
+        results.errors.element_type,
+        typehints.Tuple[int, typehints.Tuple[type, str, typehints.List[str]]])
+
+  def test_exception_handling_no_with_outputs_backward_compat(self):
+    """Without with_outputs(), behavior is unchanged."""
+
+    with beam.Pipeline() as p:
+      good, bad = (
+          p
+          | beam.Create([1, 2, 7])
+          | beam.ParDo(self._create_dofn_with_tagged_outputs())
+          .with_exception_handling())
+
+      assert_that(good, equal_to([1, 7]), 'good')
+      bad_elements = bad | beam.Keys()
+      assert_that(bad_elements, equal_to([2]), 'bad')
+
+  def test_exception_handling_compat_version_uses_old_behavior(self):
+    """With compat version < 2.72.0, old expand path is used."""
+    options = PipelineOptions(update_compatibility_version="2.71.0")
+    with beam.Pipeline(options=options) as p:
+      good, bad = (
+          p
+          | beam.Create([1, 2, 7])
+          | beam.ParDo(self._create_dofn_with_tagged_outputs())
+          .with_exception_handling())
+
+      assert_that(good, equal_to([1, 7]), 'good')
+      bad_elements = bad | beam.Keys()
+      assert_that(bad_elements, equal_to([2]), 'bad')
+
+  def test_exception_handling_compat_version_element_type_set_manually(self):
+    """With compat version < 2.72.0, element_type is set via manual override
+    (the old behavior) rather than via with_output_types."""
+
+    options = PipelineOptions(update_compatibility_version="2.71.0")
+    with beam.Pipeline(options=options) as p:
+      results = (
+          p
+          | beam.Create([1, 2, 3])
+          | beam.ParDo(self._create_dofn_with_tagged_outputs()).
+          with_exception_handling().with_outputs('threes', main='main'))
+
+      # In old path, dead letter type is Any (no with_output_types call)
+      self.assertEqual(results.bad.element_type, typehints.Any)
+      # Tagged outputs still get types from DoFn Literal annotations
+      # (via DoOutputsTuple.__getitem__ reading tagged_output_types)
+      self.assertEqual(results.threes.element_type, int)
+      # Main output type should still be inferred via manual override
+      assert_that(results.main, equal_to([1]), 'main')
+
+  def test_with_outputs_then_exception_handling_with_map(self):
+    """with_outputs().with_exception_handling() also works on Map."""
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([1, 2, 3, 4, 5])
+          | beam.Map(lambda x: x if x % 2 != 0 else 1 / 0).with_outputs(
+              main='main').with_exception_handling())
+      assert_that(results.main, equal_to([1, 3, 5]), 'main')
+      bad_elements = results.bad | beam.Keys()
+      assert_that(bad_elements, equal_to([2, 4]), 'bad')
+
+  def test_with_output_types_chained_on_pardo(self):
+    """When type hints are chained on the ParDo (not annotations on the DoFn),
+    tagged output types should still be propagated through
+    with_exception_handling().with_outputs()."""
+    class DoWithFailuresNoAnnotations(beam.DoFn):
+      def process(self, element):
+        if element % 2 == 0:
+          raise ValueError(f'Even numbers not allowed {element}')
+        if element % 3 == 0:
+          yield beam.pvalue.TaggedOutput('threes', element)
+        else:
+          yield element
+
+    with beam.Pipeline() as p:
+      results = (
+          p
+          | beam.Create([1, 2, 3, 7])
+          | beam.ParDo(DoWithFailuresNoAnnotations()).with_output_types(
+              int, threes=int).with_exception_handling().with_outputs(
+                  'threes', main='main'))
+
+      assert_that(results.main, equal_to([1, 7]), 'main')
+      assert_that(results.threes, equal_to([3]), 'threes')
+      bad_elements = results.bad | beam.Keys()
+      assert_that(bad_elements, equal_to([2]), 'bad')
+    self.assertEqual(results.main.element_type, int)
+    self.assertEqual(results.threes.element_type, int)
+
+  def test_with_outputs_and_error_handler(self):
+    """with_outputs() + error_handler should return DoOutputsTuple, not a
+    bare PCollection."""
+    from apache_beam.transforms.error_handling import ErrorHandler
+    with beam.Pipeline() as p:
+      with ErrorHandler(beam.Map(lambda x: x)) as handler:
+        results = (
+            p
+            | beam.Create([1, 2, 3, 4, 5, 6, 7])
+            | beam.ParDo(self._create_dofn_with_tagged_outputs()).with_outputs(
+                'threes', 'fives',
+                main='main').with_exception_handling(error_handler=handler))
+
+      assert_that(results.main, equal_to([1, 7]), 'main')
+      assert_that(results.threes, equal_to([3]), 'threes')
+      assert_that(results.fives, equal_to(['5']), 'fives')
+
+  def test_with_outputs_main_only_and_error_handler(self):
+    """with_outputs(main=...) without extra tags + error_handler should also
+    return DoOutputsTuple, not a bare PCollection."""
+    from apache_beam.transforms.error_handling import ErrorHandler
+    with beam.Pipeline() as p:
+      with ErrorHandler(beam.Map(lambda x: x)) as handler:
+        results = (
+            p
+            | beam.Create([1, 2, 3, 4, 5])
+            | beam.Map(lambda x: x if x % 2 != 0 else 1 / 0).with_outputs(
+                main='main').with_exception_handling(error_handler=handler))
+
+      assert_that(results.main, equal_to([1, 3, 5]), 'main')
+
+
 def test_callablewrapper_typehint():
   T = TypeVar("T")
 