Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DLQ support in RunInference #26261

Merged
merged 5 commits into from
Apr 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

## New Features / Improvements

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
* Dead letter queue support added to RunInference in Python ([#24209](https://github.com/apache/beam/issues/24209)).

## Breaking Changes

Expand Down
76 changes: 64 additions & 12 deletions sdks/python/apache_beam/ml/inference/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def __init__(
self._metrics_namespace = metrics_namespace
self._model_metadata_pcoll = model_metadata_pcoll
self._enable_side_input_loading = self._model_metadata_pcoll is not None
self._with_exception_handling = False

# TODO(BEAM-14046): Add and link to help documentation.
@classmethod
Expand Down Expand Up @@ -368,20 +369,71 @@ def expand(
# batching DoFn APIs.
| beam.BatchElements(**self._model_handler.batch_elements_kwargs()))

run_inference_pardo = beam.ParDo(
_RunInferenceDoFn(
self._model_handler,
self._clock,
self._metrics_namespace,
self._enable_side_input_loading),
self._inference_args,
beam.pvalue.AsSingleton(
self._model_metadata_pcoll,
) if self._enable_side_input_loading else None).with_resource_hints(
**resource_hints)

if self._with_exception_handling:
run_inference_pardo = run_inference_pardo.with_exception_handling(
exc_class=self._exc_class,
use_subprocess=self._use_subprocess,
threshold=self._threshold)

return (
batched_elements_pcoll
| 'BeamML_RunInference' >> (
beam.ParDo(
_RunInferenceDoFn(
self._model_handler,
self._clock,
self._metrics_namespace,
self._enable_side_input_loading),
self._inference_args,
beam.pvalue.AsSingleton(
self._model_metadata_pcoll,
) if self._enable_side_input_loading else
None).with_resource_hints(**resource_hints)))
| 'BeamML_RunInference' >> run_inference_pardo)

def with_exception_handling(
self, *, exc_class=Exception, use_subprocess=False, threshold=1):
Copy link
Contributor

@AnandInguva AnandInguva Apr 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Do we want to provide main_tag and dead_letter_tag? since we already mentioned by default good tag and bad tag so not sure how useful it will be.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally had it, but intentionally omitted it since those are primarily useful in the context of having >2 outputs

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sg

"""Automatically provides a dead letter output for skipping bad records.
This can allow a pipeline to continue successfully rather than fail or
continuously throw errors on retry when bad elements are encountered.

This returns a tagged output with two PCollections, the first being the
results of successfully processing the input PCollection, and the second
being the set of bad batches of records (those which threw exceptions
during processing) along with information about the errors raised.

For example, one would write::

good, bad = RunInference(
maybe_error_raising_model_handler
).with_exception_handling()

and `good` will be a PCollection of PredictionResults and `bad` will
contain a tuple of all batches that raised exceptions, along with their
corresponding exception.


Args:
exc_class: An exception class, or tuple of exception classes, to catch.
Optional, defaults to 'Exception'.
use_subprocess: Whether to execute the DoFn logic in a subprocess. This
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JFYI, use_subprocess currently has some known issues / limitations, which may surface here as well.

allows one to recover from errors that can crash the calling process
(e.g. from an underlying library causing a segfault), but is
slower as elements and results must cross a process boundary. Note
that this starts up a long-running process that is used to handle
all the elements (until hard failure, which should be rare) rather
than a new process per element, so the overhead should be minimal
(and can be amortized if there's any per-process or per-bundle
initialization that needs to be done). Optional, defaults to False.
threshold: An upper bound on the ratio of records that can be bad before
aborting the entire pipeline. Optional, defaults to 1.0 (meaning
up to 100% of records can be bad and the pipeline will still succeed).
"""
self._with_exception_handling = True
self._exc_class = exc_class
self._use_subprocess = use_subprocess
self._threshold = threshold
return self


class _MetricsCollector:
Expand Down
28 changes: 27 additions & 1 deletion sdks/python/apache_beam/ml/inference/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ def predict(self, example: int) -> int:


class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
def __init__(self, clock=None):
def __init__(self, clock=None, min_batch_size=1, max_batch_size=9999):
self._fake_clock = clock
self._min_batch_size = min_batch_size
self._max_batch_size = max_batch_size

def load_model(self):
if self._fake_clock:
Expand All @@ -69,6 +71,12 @@ def run_inference(
def update_model_path(self, model_path: Optional[str] = None):
pass

def batch_elements_kwargs(self):
return {
'min_batch_size': self._min_batch_size,
'max_batch_size': self._max_batch_size
}


class FakeModelHandlerReturnsPredictionResult(
base.ModelHandler[int, base.PredictionResult, FakeModel]):
Expand Down Expand Up @@ -171,6 +179,24 @@ def test_run_inference_impl_with_maybe_keyed_examples(self):
model_handler)
assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')

def test_run_inference_impl_dlq(self):
with TestPipeline() as pipeline:
examples = [1, 'TEST', 3, 10, 'TEST2']
expected_good = [2, 4, 11]
expected_bad = ['TEST', 'TEST2']
pcoll = pipeline | 'start' >> beam.Create(examples)
good, bad = pcoll | base.RunInference(
FakeModelHandler(
min_batch_size=1,
max_batch_size=1
)).with_exception_handling()
assert_that(good, equal_to(expected_good), label='assert:inferences')

# bad will be in form [batch[elements], error]. Just pull out bad element.
bad_without_error = bad | beam.Map(lambda x: x[0][0])
assert_that(
bad_without_error, equal_to(expected_bad), label='assert:failures')

def test_run_inference_impl_inference_args(self):
with TestPipeline() as pipeline:
examples = [1, 5, 3, 10]
Expand Down