-
Notifications
You must be signed in to change notification settings - Fork 4.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementing Python Bounded Source Reader DoFn #13154
Changes from 2 commits
0ab8ea2
c82f2d1
7297e0a
db7aafd
d0a1509
c16e3b3
4934da1
ac09ae0
050cd4c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,6 +54,7 @@ | |
from apache_beam.portability.api import beam_runner_api_pb2 | ||
from apache_beam.pvalue import AsIter | ||
from apache_beam.pvalue import AsSingleton | ||
from apache_beam.transforms import PTransform | ||
from apache_beam.transforms import core | ||
from apache_beam.transforms import ptransform | ||
from apache_beam.transforms import window | ||
|
@@ -1427,145 +1428,157 @@ def with_completed(self, completed): | |
fraction=self._fraction, remaining=self._remaining, completed=completed) | ||
|
||
|
||
class _SDFBoundedSourceRestriction(object): | ||
""" A restriction wraps SourceBundle and RangeTracker. """ | ||
def __init__(self, source_bundle, range_tracker=None): | ||
self._source_bundle = source_bundle | ||
self._range_tracker = range_tracker | ||
|
||
def __reduce__(self): | ||
# The instance of RangeTracker shouldn't be serialized. | ||
return (self.__class__, (self._source_bundle, )) | ||
|
||
def range_tracker(self): | ||
if not self._range_tracker: | ||
self._range_tracker = self._source_bundle.source.get_range_tracker( | ||
self._source_bundle.start_position, self._source_bundle.stop_position) | ||
return self._range_tracker | ||
|
||
def weight(self): | ||
return self._source_bundle.weight | ||
|
||
def source(self): | ||
return self._source_bundle.source | ||
|
||
def try_split(self, fraction_of_remainder): | ||
consumed_fraction = self.range_tracker().fraction_consumed() | ||
fraction = ( | ||
consumed_fraction + (1 - consumed_fraction) * fraction_of_remainder) | ||
position = self.range_tracker().position_at_fraction(fraction) | ||
# Need to stash current stop_pos before splitting since | ||
# range_tracker.split will update its stop_pos if splits | ||
# successfully. | ||
stop_pos = self._source_bundle.stop_position | ||
split_result = self.range_tracker().try_split(position) | ||
if split_result: | ||
split_pos, split_fraction = split_result | ||
primary_weight = self._source_bundle.weight * split_fraction | ||
residual_weight = self._source_bundle.weight - primary_weight | ||
# Update self to primary weight and end position. | ||
self._source_bundle = SourceBundle( | ||
primary_weight, | ||
self._source_bundle.source, | ||
self._source_bundle.start_position, | ||
split_pos) | ||
return ( | ||
self, | ||
_SDFBoundedSourceRestriction( | ||
SourceBundle( | ||
residual_weight, | ||
self._source_bundle.source, | ||
split_pos, | ||
stop_pos))) | ||
|
||
|
||
class _SDFBoundedSourceRestrictionTracker(RestrictionTracker): | ||
"""An `iobase.RestrictionTracker` implementations for wrapping BoundedSource | ||
with SDF. The tracked restriction is a _SDFBoundedSourceRestriction, which | ||
wraps SourceBundle and RangeTracker. | ||
|
||
Delegated RangeTracker guarantees synchronization safety. | ||
""" | ||
def __init__(self, restriction): | ||
if not isinstance(restriction, _SDFBoundedSourceRestriction): | ||
raise ValueError( | ||
'Initializing SDFBoundedSourceRestrictionTracker' | ||
' requires a _SDFBoundedSourceRestriction') | ||
self.restriction = restriction | ||
|
||
def current_progress(self): | ||
# type: () -> RestrictionProgress | ||
return RestrictionProgress( | ||
fraction=self.restriction.range_tracker().fraction_consumed()) | ||
|
||
def current_restriction(self): | ||
self.restriction.range_tracker() | ||
return self.restriction | ||
|
||
def start_pos(self): | ||
return self.restriction.range_tracker().start_position() | ||
|
||
def stop_pos(self): | ||
return self.restriction.range_tracker().stop_position() | ||
|
||
def try_claim(self, position): | ||
return self.restriction.range_tracker().try_claim(position) | ||
|
||
def try_split(self, fraction_of_remainder): | ||
return self.restriction.try_split(fraction_of_remainder) | ||
|
||
def check_done(self): | ||
return self.restriction.range_tracker().fraction_consumed() >= 1.0 | ||
|
||
def is_bounded(self): | ||
return True | ||
|
||
|
||
class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider): | ||
""" | ||
A `RestrictionProvider` that is used by SDF for `BoundedSource`. | ||
|
||
If source is provided, uses it for initializing restriction. Otherwise | ||
initializes restriction based on input element that is expected to be of | ||
BoundedSource type. | ||
""" | ||
def __init__(self, source: BoundedSource = None, desired_chunk_size=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should be able to remote source here? |
||
self._check_source(source) | ||
self._source = source | ||
self._desired_chunk_size = desired_chunk_size | ||
|
||
def _check_source(self, src): | ||
if src is not None and not isinstance(src, BoundedSource): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
raise RuntimeError( | ||
'SDFBoundedSourceRestrictionProvider can only utilize BoundedSource') | ||
|
||
def initial_restriction(self, element_source: BoundedSource): | ||
src = element_source if self._source is None else self._source | ||
self._check_source(src) | ||
range_tracker = src.get_range_tracker(None, None) | ||
return _SDFBoundedSourceRestriction( | ||
SourceBundle( | ||
None, | ||
src, | ||
range_tracker.start_position(), | ||
range_tracker.stop_position())) | ||
|
||
def create_tracker(self, restriction): | ||
return _SDFBoundedSourceRestrictionTracker(restriction) | ||
|
||
def split(self, element, restriction): | ||
if self._desired_chunk_size is None: | ||
try: | ||
estimated_size = restriction.source().estimate_size() | ||
except NotImplementedError: | ||
estimated_size = None | ||
self._desired_chunk_size = Read.get_desired_chunk_size(estimated_size) | ||
|
||
# Invoke source.split to get initial splitting results. | ||
source_bundles = restriction.source().split(self._desired_chunk_size) | ||
for source_bundle in source_bundles: | ||
yield _SDFBoundedSourceRestriction(source_bundle) | ||
|
||
def restriction_size(self, element, restriction): | ||
return restriction.weight() | ||
|
||
def restriction_coder(self): | ||
return coders.DillCoder() | ||
|
||
|
||
class _SDFBoundedSourceWrapper(ptransform.PTransform): | ||
"""A ``PTransform`` that uses SDF to read from a ``BoundedSource``. | ||
|
||
NOTE: This transform can only be used with beam_fn_api enabled. | ||
""" | ||
class _SDFBoundedSourceRestriction(object): | ||
""" A restriction wraps SourceBundle and RangeTracker. """ | ||
def __init__(self, source_bundle, range_tracker=None): | ||
self._source_bundle = source_bundle | ||
self._range_tracker = range_tracker | ||
|
||
def __reduce__(self): | ||
# The instance of RangeTracker shouldn't be serialized. | ||
return (self.__class__, (self._source_bundle, )) | ||
|
||
def range_tracker(self): | ||
if not self._range_tracker: | ||
self._range_tracker = self._source_bundle.source.get_range_tracker( | ||
self._source_bundle.start_position, | ||
self._source_bundle.stop_position) | ||
return self._range_tracker | ||
|
||
def weight(self): | ||
return self._source_bundle.weight | ||
|
||
def source(self): | ||
return self._source_bundle.source | ||
|
||
def try_split(self, fraction_of_remainder): | ||
consumed_fraction = self.range_tracker().fraction_consumed() | ||
fraction = ( | ||
consumed_fraction + (1 - consumed_fraction) * fraction_of_remainder) | ||
position = self.range_tracker().position_at_fraction(fraction) | ||
# Need to stash current stop_pos before splitting since | ||
# range_tracker.split will update its stop_pos if splits | ||
# successfully. | ||
stop_pos = self._source_bundle.stop_position | ||
split_result = self.range_tracker().try_split(position) | ||
if split_result: | ||
split_pos, split_fraction = split_result | ||
primary_weight = self._source_bundle.weight * split_fraction | ||
residual_weight = self._source_bundle.weight - primary_weight | ||
# Update self to primary weight and end position. | ||
self._source_bundle = SourceBundle( | ||
primary_weight, | ||
self._source_bundle.source, | ||
self._source_bundle.start_position, | ||
split_pos) | ||
return ( | ||
self, | ||
_SDFBoundedSourceWrapper._SDFBoundedSourceRestriction( | ||
SourceBundle( | ||
residual_weight, | ||
self._source_bundle.source, | ||
split_pos, | ||
stop_pos))) | ||
|
||
class _SDFBoundedSourceRestrictionTracker(RestrictionTracker): | ||
"""An `iobase.RestrictionTracker` implementations for wrapping BoundedSource | ||
with SDF. The tracked restriction is a _SDFBoundedSourceRestriction, which | ||
wraps SourceBundle and RangeTracker. | ||
|
||
Delegated RangeTracker guarantees synchronization safety. | ||
""" | ||
def __init__(self, restriction): | ||
if not isinstance(restriction, | ||
_SDFBoundedSourceWrapper._SDFBoundedSourceRestriction): | ||
raise ValueError( | ||
'Initializing SDFBoundedSourceRestrictionTracker' | ||
' requires a _SDFBoundedSourceRestriction') | ||
self.restriction = restriction | ||
|
||
def current_progress(self): | ||
# type: () -> RestrictionProgress | ||
return RestrictionProgress( | ||
fraction=self.restriction.range_tracker().fraction_consumed()) | ||
|
||
def current_restriction(self): | ||
self.restriction.range_tracker() | ||
return self.restriction | ||
|
||
def start_pos(self): | ||
return self.restriction.range_tracker().start_position() | ||
|
||
def stop_pos(self): | ||
return self.restriction.range_tracker().stop_position() | ||
|
||
def try_claim(self, position): | ||
return self.restriction.range_tracker().try_claim(position) | ||
|
||
def try_split(self, fraction_of_remainder): | ||
return self.restriction.try_split(fraction_of_remainder) | ||
|
||
def check_done(self): | ||
return self.restriction.range_tracker().fraction_consumed() >= 1.0 | ||
|
||
def is_bounded(self): | ||
return True | ||
|
||
class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider): | ||
"""A `RestrictionProvider` that is used by SDF for `BoundedSource`.""" | ||
def __init__(self, source, desired_chunk_size=None): | ||
self._source = source | ||
self._desired_chunk_size = desired_chunk_size | ||
|
||
def initial_restriction(self, element): | ||
# Get initial range_tracker from source | ||
range_tracker = self._source.get_range_tracker(None, None) | ||
return _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction( | ||
SourceBundle( | ||
None, | ||
self._source, | ||
range_tracker.start_position(), | ||
range_tracker.stop_position())) | ||
|
||
def create_tracker(self, restriction): | ||
return _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionTracker( | ||
restriction) | ||
|
||
def split(self, element, restriction): | ||
if self._desired_chunk_size is None: | ||
try: | ||
estimated_size = self._source.estimate_size() | ||
except NotImplementedError: | ||
estimated_size = None | ||
self._desired_chunk_size = Read.get_desired_chunk_size(estimated_size) | ||
|
||
# Invoke source.split to get initial splitting results. | ||
source_bundles = self._source.split(self._desired_chunk_size) | ||
for source_bundle in source_bundles: | ||
yield _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction( | ||
source_bundle) | ||
|
||
def restriction_size(self, element, restriction): | ||
return restriction.weight() | ||
|
||
def restriction_coder(self): | ||
return coders.DillCoder() | ||
|
||
def __init__(self, source): | ||
if not isinstance(source, BoundedSource): | ||
raise RuntimeError('SDFBoundedSourceWrapper can only wrap BoundedSource') | ||
|
@@ -1590,12 +1603,9 @@ def process( | |
self, | ||
element, | ||
restriction_tracker=core.DoFn.RestrictionParam( | ||
_SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionProvider( | ||
source))): | ||
_SDFBoundedSourceRestrictionProvider(source=source))): | ||
current_restriction = restriction_tracker.current_restriction() | ||
assert isinstance( | ||
current_restriction, | ||
_SDFBoundedSourceWrapper._SDFBoundedSourceRestriction) | ||
assert isinstance(current_restriction, _SDFBoundedSourceRestriction) | ||
return current_restriction.source().read( | ||
current_restriction.range_tracker()) | ||
|
||
|
@@ -1618,3 +1628,48 @@ def display_data(self): | |
'source': DisplayDataItem(self.source.__class__, label='Read Source'), | ||
'source_dd': self.source | ||
} | ||
|
||
|
||
class SDFBoundedSourceReader(PTransform): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like the major difference between There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've done this - but I've still allowed the source to come in via the constructor as well as as an input. The intention of doing this is to keep the display data for simple Read transforms where the source is known at construction time. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. I thought we still keep There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm thinking whether it would be better for |
||
"""A ``PTransform`` that uses SDF to read from each ``BoundedSource`` in a | ||
PCollection. | ||
|
||
NOTE: This transform can only be used with beam_fn_api enabled. | ||
""" | ||
def __init__(self): | ||
super(SDFBoundedSourceReader, self).__init__() | ||
|
||
def _create_sdf_bounded_source_dofn(self): | ||
class SDFBoundedSourceDoFn(core.DoFn): | ||
def __init__(self): | ||
pass | ||
|
||
def process( | ||
self, | ||
unused_element, | ||
restriction_tracker=core.DoFn.RestrictionParam( | ||
_SDFBoundedSourceRestrictionProvider())): | ||
current_restriction = restriction_tracker.current_restriction() | ||
assert isinstance(current_restriction, _SDFBoundedSourceRestriction) | ||
|
||
result = current_restriction.source().read( | ||
current_restriction.range_tracker()) | ||
return result | ||
|
||
return SDFBoundedSourceDoFn() | ||
|
||
def expand(self, pvalue): | ||
return pvalue | core.ParDo(self._create_sdf_bounded_source_dofn()) | ||
|
||
def get_windowing(self, unused_inputs): | ||
return core.Windowing(window.GlobalWindows()) | ||
|
||
def _infer_output_coder(self, input_type=None, input_coder=None): | ||
return self.source.default_output_coder() | ||
|
||
def display_data(self): | ||
return { | ||
'source': DisplayDataItem( | ||
self.source.__class__, label='Read Bounded Sources'), | ||
'source_dd': self.source | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems like we also need to update pydoc here as well.