Skip to content

Commit

Permalink
Fix counter metrics for ParDo#with_exception_handling(timeout). (#32571)
Browse files Browse the repository at this point in the history
Co-authored-by: Claude <cvandermerwe@google.com>
  • Loading branch information
claudevdm and Claude authored Oct 3, 2024
1 parent 001ac59 commit 0a71499
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
18 changes: 15 additions & 3 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2611,11 +2611,23 @@ def __getattribute__(self, name):
def process(self, *args, **kwargs):
if self._pool is None:
self._pool = concurrent.futures.ThreadPoolExecutor(10)

# Import here to avoid circular dependency
from apache_beam.runners.worker.statesampler import get_current_tracker, set_current_tracker

# State sampler/tracker is stored as a thread local variable, and is used
# when incrementing counter metrics.
dispatching_thread_state_sampler = get_current_tracker()

def wrapped_process():
"""Makes the dispatching thread local state sampler available to child
thread"""
set_current_tracker(dispatching_thread_state_sampler)
return list(self._fn.process(*args, **kwargs))

# Ensure we iterate over the entire output list in the given amount of time.
try:
return self._pool.submit(
lambda: list(self._fn.process(*args, **kwargs))).result(
self._timeout)
return self._pool.submit(wrapped_process).result(self._timeout)
except TimeoutError:
self._pool.shutdown(wait=False)
self._pool = None
Expand Down
26 changes: 26 additions & 0 deletions sdks/python/apache_beam/transforms/ptransform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2780,6 +2780,32 @@ def test_timeout(self):
('slow', 'TimeoutError()')]),
label='CheckBad')

def test_increment_counter(self):
# Counters are not currently supported for
# ParDo#with_exception_handling(use_subprocess=True).
if (self.use_subprocess):
return

class CounterDoFn(beam.DoFn):
def __init__(self):
self.records_counter = Metrics.counter(self.__class__, 'recordsCounter')

def process(self, element):
self.records_counter.inc()

with TestPipeline() as p:
_, _ = (
(p | beam.Create([1,2,3])) | beam.ParDo(CounterDoFn())
.with_exception_handling(
use_subprocess=self.use_subprocess, timeout=1))
results = p.result
metric_results = results.metrics().query(
MetricsFilter().with_name("recordsCounter"))
records_counter = metric_results['counters'][0]

self.assertEqual(records_counter.key.metric.name, 'recordsCounter')
self.assertEqual(records_counter.result, 3)

def test_lifecycle(self):
die = type(self).die

Expand Down

0 comments on commit 0a71499

Please sign in to comment.