DLQ support in RunInference (#26261)
* DLQ support in RunInference

* Doc example

* Comment

* CHANGES.md
damccorm authored Apr 14, 2023
1 parent 82699be commit b9f27f9
Showing 3 changed files with 92 additions and 14 deletions.
2 changes: 1 addition & 1 deletion CHANGES.md
@@ -64,7 +64,7 @@

## New Features / Improvements

-* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
+* Dead letter queue support added to RunInference in Python ([#24209](https://github.com/apache/beam/issues/24209)).

## Breaking Changes

76 changes: 64 additions & 12 deletions sdks/python/apache_beam/ml/inference/base.py
@@ -339,6 +339,7 @@ def __init__(
    self._metrics_namespace = metrics_namespace
    self._model_metadata_pcoll = model_metadata_pcoll
    self._enable_side_input_loading = self._model_metadata_pcoll is not None
    self._with_exception_handling = False

  # TODO(BEAM-14046): Add and link to help documentation.
  @classmethod
@@ -368,20 +369,71 @@ def expand(
        # batching DoFn APIs.
        | beam.BatchElements(**self._model_handler.batch_elements_kwargs()))

    run_inference_pardo = beam.ParDo(
        _RunInferenceDoFn(
            self._model_handler,
            self._clock,
            self._metrics_namespace,
            self._enable_side_input_loading),
        self._inference_args,
        beam.pvalue.AsSingleton(
            self._model_metadata_pcoll,
        ) if self._enable_side_input_loading else None).with_resource_hints(
            **resource_hints)

    if self._with_exception_handling:
      run_inference_pardo = run_inference_pardo.with_exception_handling(
          exc_class=self._exc_class,
          use_subprocess=self._use_subprocess,
          threshold=self._threshold)

    return (
        batched_elements_pcoll
-        | 'BeamML_RunInference' >> (
-            beam.ParDo(
-                _RunInferenceDoFn(
-                    self._model_handler,
-                    self._clock,
-                    self._metrics_namespace,
-                    self._enable_side_input_loading),
-                self._inference_args,
-                beam.pvalue.AsSingleton(
-                    self._model_metadata_pcoll,
-                ) if self._enable_side_input_loading else
-                None).with_resource_hints(**resource_hints)))
+        | 'BeamML_RunInference' >> run_inference_pardo)

  def with_exception_handling(
      self, *, exc_class=Exception, use_subprocess=False, threshold=1):
    """Automatically provides a dead letter output for skipping bad records.
    This can allow a pipeline to continue successfully rather than fail or
    continuously throw errors on retry when bad elements are encountered.

    This returns a tagged output with two PCollections, the first being the
    results of successfully processing the input PCollection, and the second
    being the set of bad batches of records (those which threw exceptions
    during processing) along with information about the errors raised.

    For example, one would write::

        good, bad = RunInference(
          maybe_error_raising_model_handler
        ).with_exception_handling()

    and `good` will be a PCollection of PredictionResults and `bad` will
    contain a tuple of all batches that raised exceptions, along with their
    corresponding exception.

    Args:
      exc_class: An exception class, or tuple of exception classes, to catch.
          Optional, defaults to 'Exception'.
      use_subprocess: Whether to execute the DoFn logic in a subprocess. This
          allows one to recover from errors that can crash the calling process
          (e.g. from an underlying library causing a segfault), but is
          slower as elements and results must cross a process boundary. Note
          that this starts up a long-running process that is used to handle
          all the elements (until hard failure, which should be rare) rather
          than a new process per element, so the overhead should be minimal
          (and can be amortized if there's any per-process or per-bundle
          initialization that needs to be done). Optional, defaults to False.
      threshold: An upper bound on the ratio of records that can be bad before
          aborting the entire pipeline. Optional, defaults to 1.0 (meaning
          up to 100% of records can be bad and the pipeline will still succeed).
    """
    self._with_exception_handling = True
    self._exc_class = exc_class
    self._use_subprocess = use_subprocess
    self._threshold = threshold
    return self


class _MetricsCollector:
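For readers skimming the diff, the following is a minimal end-to-end sketch of how the new dead letter output is meant to be used. The `AddOneModelHandler` below is a toy handler invented for illustration (it is not part of this commit); the wiring of `with_exception_handling()` and the two resulting PCollections follows the docstring and the test added in this commit.

    import apache_beam as beam
    from apache_beam.ml.inference import base


    class AddOneModelHandler(base.ModelHandler[int, int, object]):
      """Toy handler: adds 1 to each int, so a string example raises TypeError."""
      def load_model(self):
        return object()  # no real model is needed for this sketch

      def run_inference(self, batch, model, inference_args=None):
        return [example + 1 for example in batch]

      def batch_elements_kwargs(self):
        # One element per batch, so a bad record only poisons its own batch.
        return {'max_batch_size': 1}


    with beam.Pipeline() as p:
      examples = p | 'Create' >> beam.Create([1, 2, 'oops', 4])

      # with_exception_handling() gives two outputs: `good` holds the
      # successful inference results, `bad` holds the failed batches together
      # with the exceptions they raised.
      good, bad = (
          examples
          | 'RunInference' >> base.RunInference(
              AddOneModelHandler()).with_exception_handling(
                  exc_class=TypeError))

      _ = good | 'UseResults' >> beam.Map(print)
      _ = bad | 'LogFailures' >> beam.Map(lambda failure: print('failed:', failure))

In the same spirit, `use_subprocess=True` could be passed when the underlying model library is prone to hard crashes, and a `threshold` below 1.0 when widespread failures should abort the pipeline instead of being silently routed to the dead letter output.
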
28 changes: 27 additions & 1 deletion sdks/python/apache_beam/ml/inference/base_test.py
@@ -48,8 +48,10 @@ def predict(self, example: int) -> int:


class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
-  def __init__(self, clock=None):
+  def __init__(self, clock=None, min_batch_size=1, max_batch_size=9999):
    self._fake_clock = clock
+    self._min_batch_size = min_batch_size
+    self._max_batch_size = max_batch_size

  def load_model(self):
    if self._fake_clock:
@@ -69,6 +71,12 @@ def run_inference(
  def update_model_path(self, model_path: Optional[str] = None):
    pass

  def batch_elements_kwargs(self):
    return {
        'min_batch_size': self._min_batch_size,
        'max_batch_size': self._max_batch_size
    }


class FakeModelHandlerReturnsPredictionResult(
    base.ModelHandler[int, base.PredictionResult, FakeModel]):
@@ -171,6 +179,24 @@ def test_run_inference_impl_with_maybe_keyed_examples(self):
          model_handler)
      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')

  def test_run_inference_impl_dlq(self):
    with TestPipeline() as pipeline:
      examples = [1, 'TEST', 3, 10, 'TEST2']
      expected_good = [2, 4, 11]
      expected_bad = ['TEST', 'TEST2']
      pcoll = pipeline | 'start' >> beam.Create(examples)
      good, bad = pcoll | base.RunInference(
          FakeModelHandler(
              min_batch_size=1,
              max_batch_size=1
          )).with_exception_handling()
      assert_that(good, equal_to(expected_good), label='assert:inferences')

      # bad will be in form [batch[elements], error]. Just pull out bad element.
      bad_without_error = bad | beam.Map(lambda x: x[0][0])
      assert_that(
          bad_without_error, equal_to(expected_bad), label='assert:failures')

  def test_run_inference_impl_inference_args(self):
    with TestPipeline() as pipeline:
      examples = [1, 5, 3, 10]
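Following the comment in the new test above, each element on the `bad` output arrives roughly as `(batch_of_examples, exception_info)`. A hypothetical formatter for logging both parts might look like the sketch below; the helper name and the exact tuple shape are assumptions for illustration, not part of this commit.

    def format_failure(failure):
      # Assumed shape, per the test comment above: (batch_of_examples, error).
      batch, error = failure[0], failure[1]
      return 'failed examples=%s error=%r' % (list(batch), error)

    # Continuing the earlier sketch, `bad` is the second output of RunInference:
    #   failures_as_text = bad | 'FormatFailures' >> beam.Map(format_failure)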
