Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementing Python Bounded Source Reader DoFn #13154

Merged
merged 9 commits into from
Nov 3, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
333 changes: 194 additions & 139 deletions sdks/python/apache_beam/io/iobase.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsIter
from apache_beam.pvalue import AsSingleton
from apache_beam.transforms import PTransform
from apache_beam.transforms import core
from apache_beam.transforms import ptransform
from apache_beam.transforms import window
Expand Down Expand Up @@ -1427,145 +1428,157 @@ def with_completed(self, completed):
fraction=self._fraction, remaining=self._remaining, completed=completed)


class _SDFBoundedSourceRestriction(object):
  """A restriction that wraps a SourceBundle and a RangeTracker.

  The range tracker is built lazily from the bundle's source and its
  start/stop positions, and it is left out when the restriction is pickled.
  """
  def __init__(self, source_bundle, range_tracker=None):
    self._source_bundle = source_bundle
    self._range_tracker = range_tracker

  def __reduce__(self):
    # The instance of RangeTracker shouldn't be serialized; only the bundle is
    # kept and the tracker is rebuilt on demand after deserialization.
    return (self.__class__, (self._source_bundle, ))

  def range_tracker(self):
    """Returns the range tracker, creating it from the bundle on first use."""
    # Compare with None explicitly: a tracker object that happens to evaluate
    # as falsy must not be discarded and rebuilt on every call.
    if self._range_tracker is None:
      self._range_tracker = self._source_bundle.source.get_range_tracker(
          self._source_bundle.start_position, self._source_bundle.stop_position)
    return self._range_tracker

  def weight(self):
    """Returns the weight recorded on the wrapped bundle (may be None)."""
    return self._source_bundle.weight

  def source(self):
    """Returns the source recorded on the wrapped bundle."""
    return self._source_bundle.source

  def try_split(self, fraction_of_remainder):
    """Tries to split this restriction at a fraction of the unconsumed part.

    Args:
      fraction_of_remainder: fraction of the not yet consumed range that the
        primary restriction should keep.

    Returns:
      ``(primary, residual)`` when the range tracker accepts the split, where
      ``primary`` is this (updated) restriction and ``residual`` is a new
      restriction covering the rest of the original range; ``None`` when the
      range tracker refuses the split.
    """
    tracker = self.range_tracker()
    consumed_fraction = tracker.fraction_consumed()
    fraction = (
        consumed_fraction + (1 - consumed_fraction) * fraction_of_remainder)
    position = tracker.position_at_fraction(fraction)
    # Need to stash current stop_pos before splitting since
    # range_tracker.split will update its stop_pos if splits
    # successfully.
    stop_pos = self._source_bundle.stop_position
    split_result = tracker.try_split(position)
    if not split_result:
      return None

    split_pos, split_fraction = split_result
    total_weight = self._source_bundle.weight
    if total_weight is None:
      # The weight is unknown (a bundle may be built with weight None), so
      # there is nothing to apportion. Failing here would be harmful because
      # the tracker has already been split at this point.
      primary_weight = None
      residual_weight = None
    else:
      primary_weight = total_weight * split_fraction
      residual_weight = total_weight - primary_weight
    source = self._source_bundle.source
    # Update self to primary weight and end position.
    self._source_bundle = SourceBundle(
        primary_weight, source, self._source_bundle.start_position, split_pos)
    return (
        self,
        _SDFBoundedSourceRestriction(
            SourceBundle(residual_weight, source, split_pos, stop_pos)))


class _SDFBoundedSourceRestrictionTracker(RestrictionTracker):
  """An `iobase.RestrictionTracker` implementation for wrapping BoundedSource
  with SDF.

  The tracked restriction is a _SDFBoundedSourceRestriction, which wraps a
  SourceBundle and a RangeTracker. Progress, claims and positions are answered
  by the restriction's range tracker; splitting is answered by the restriction.

  Delegated RangeTracker guarantees synchronization safety.
  """
  def __init__(self, restriction):
    is_supported = isinstance(restriction, _SDFBoundedSourceRestriction)
    if not is_supported:
      raise ValueError(
          'Initializing SDFBoundedSourceRestrictionTracker'
          ' requires a _SDFBoundedSourceRestriction')
    self.restriction = restriction

  def current_progress(self):
    # type: () -> RestrictionProgress
    tracker = self.restriction.range_tracker()
    return RestrictionProgress(fraction=tracker.fraction_consumed())

  def current_restriction(self):
    # Ask for the range tracker first so the restriction handed out already
    # has it attached.
    tracked = self.restriction
    tracked.range_tracker()
    return tracked

  def start_pos(self):
    tracker = self.restriction.range_tracker()
    return tracker.start_position()

  def stop_pos(self):
    tracker = self.restriction.range_tracker()
    return tracker.stop_position()

  def try_claim(self, position):
    tracker = self.restriction.range_tracker()
    return tracker.try_claim(position)

  def try_split(self, fraction_of_remainder):
    tracked = self.restriction
    return tracked.try_split(fraction_of_remainder)

  def check_done(self):
    # Reports completion as a boolean: done once the consumed fraction
    # reaches 1.0.
    consumed = self.restriction.range_tracker().fraction_consumed()
    return consumed >= 1.0

  def is_bounded(self):
    return True


class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
"""
A `RestrictionProvider` that is used by SDF for `BoundedSource`.

If source is provided, uses it for initializing restriction. Otherwise
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like we also need to update pydoc here as well.

initializes restriction based on input element that is expected to be of
BoundedSource type.
"""
def __init__(self, source: BoundedSource = None, desired_chunk_size=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to remove `source` here?

self._check_source(source)
self._source = source
self._desired_chunk_size = desired_chunk_size

def _check_source(self, src):
if src is not None and not isinstance(src, BoundedSource):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The src cannot be None, right?

raise RuntimeError(
'SDFBoundedSourceRestrictionProvider can only utilize BoundedSource')

def initial_restriction(self, element_source: BoundedSource):
  """Builds the restriction that covers the whole source.

  The source given at construction time wins; the input element is used only
  when no source was configured. The chosen source is validated with
  ``_check_source`` before use.
  """
  if self._source is None:
    src = element_source
  else:
    src = self._source
  self._check_source(src)
  # A tracker over (None, None) supplies the start and stop positions for the
  # bundle.
  full_range = src.get_range_tracker(None, None)
  whole_source = SourceBundle(
      None, src, full_range.start_position(), full_range.stop_position())
  return _SDFBoundedSourceRestriction(whole_source)

def create_tracker(self, restriction):
  """Wraps ``restriction`` in a new _SDFBoundedSourceRestrictionTracker."""
  tracker = _SDFBoundedSourceRestrictionTracker(restriction)
  return tracker

def split(self, element, restriction):
  """Yields the initial sub-restrictions of ``restriction``.

  Args:
    element: the input element; not used here.
    restriction: the restriction whose source is split.

  Yields:
    One _SDFBoundedSourceRestriction per bundle returned by the source's
    ``split``.
  """
  source = restriction.source()
  chunk_size = self._desired_chunk_size
  if chunk_size is None:
    # Work the chunk size out for this particular source and keep it local.
    # Storing it on the provider would make every later source element reuse
    # the size derived from the first one.
    try:
      estimated_size = source.estimate_size()
    except NotImplementedError:
      estimated_size = None
    chunk_size = Read.get_desired_chunk_size(estimated_size)

  # Invoke source.split to get initial splitting results.
  for source_bundle in source.split(chunk_size):
    yield _SDFBoundedSourceRestriction(source_bundle)

def restriction_size(self, element, restriction):
  """Returns the size of ``restriction``, taken from its ``weight()``."""
  size = restriction.weight()
  return size

def restriction_coder(self):
  """Returns a new ``coders.DillCoder`` for encoding restrictions."""
  coder = coders.DillCoder()
  return coder


class _SDFBoundedSourceWrapper(ptransform.PTransform):
"""A ``PTransform`` that uses SDF to read from a ``BoundedSource``.

NOTE: This transform can only be used with beam_fn_api enabled.
"""
class _SDFBoundedSourceRestriction(object):
  """A restriction that wraps a SourceBundle and a RangeTracker.

  The range tracker is built lazily from the bundle's source and its
  start/stop positions, and it is left out when the restriction is pickled.
  """
  def __init__(self, source_bundle, range_tracker=None):
    self._source_bundle = source_bundle
    self._range_tracker = range_tracker

  def __reduce__(self):
    # The instance of RangeTracker shouldn't be serialized; only the bundle is
    # kept and the tracker is rebuilt on demand after deserialization.
    return (self.__class__, (self._source_bundle, ))

  def range_tracker(self):
    """Returns the range tracker, creating it from the bundle on first use."""
    # Compare with None explicitly: a tracker object that happens to evaluate
    # as falsy must not be discarded and rebuilt on every call.
    if self._range_tracker is None:
      self._range_tracker = self._source_bundle.source.get_range_tracker(
          self._source_bundle.start_position,
          self._source_bundle.stop_position)
    return self._range_tracker

  def weight(self):
    """Returns the weight recorded on the wrapped bundle (may be None)."""
    return self._source_bundle.weight

  def source(self):
    """Returns the source recorded on the wrapped bundle."""
    return self._source_bundle.source

  def try_split(self, fraction_of_remainder):
    """Tries to split this restriction at a fraction of the unconsumed part.

    Args:
      fraction_of_remainder: fraction of the not yet consumed range that the
        primary restriction should keep.

    Returns:
      ``(primary, residual)`` when the range tracker accepts the split, where
      ``primary`` is this (updated) restriction and ``residual`` is a new
      restriction covering the rest of the original range; ``None`` when the
      range tracker refuses the split.
    """
    tracker = self.range_tracker()
    consumed_fraction = tracker.fraction_consumed()
    fraction = (
        consumed_fraction + (1 - consumed_fraction) * fraction_of_remainder)
    position = tracker.position_at_fraction(fraction)
    # Need to stash current stop_pos before splitting since
    # range_tracker.split will update its stop_pos if splits
    # successfully.
    stop_pos = self._source_bundle.stop_position
    split_result = tracker.try_split(position)
    if not split_result:
      return None

    split_pos, split_fraction = split_result
    total_weight = self._source_bundle.weight
    if total_weight is None:
      # The weight is unknown (a bundle may be built with weight None), so
      # there is nothing to apportion. Failing here would be harmful because
      # the tracker has already been split at this point.
      primary_weight = None
      residual_weight = None
    else:
      primary_weight = total_weight * split_fraction
      residual_weight = total_weight - primary_weight
    source = self._source_bundle.source
    # Update self to primary weight and end position.
    self._source_bundle = SourceBundle(
        primary_weight, source, self._source_bundle.start_position, split_pos)
    return (
        self,
        _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction(
            SourceBundle(residual_weight, source, split_pos, stop_pos)))

class _SDFBoundedSourceRestrictionTracker(RestrictionTracker):
  """An `iobase.RestrictionTracker` implementation for wrapping BoundedSource
  with SDF.

  The tracked restriction is a _SDFBoundedSourceRestriction, which wraps a
  SourceBundle and a RangeTracker. Progress, claims and positions are answered
  by the restriction's range tracker; splitting is answered by the restriction.

  Delegated RangeTracker guarantees synchronization safety.
  """
  def __init__(self, restriction):
    expected_type = _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction
    if not isinstance(restriction, expected_type):
      raise ValueError(
          'Initializing SDFBoundedSourceRestrictionTracker'
          ' requires a _SDFBoundedSourceRestriction')
    self.restriction = restriction

  def current_progress(self):
    # type: () -> RestrictionProgress
    tracker = self.restriction.range_tracker()
    return RestrictionProgress(fraction=tracker.fraction_consumed())

  def current_restriction(self):
    # Ask for the range tracker first so the restriction handed out already
    # has it attached.
    tracked = self.restriction
    tracked.range_tracker()
    return tracked

  def start_pos(self):
    tracker = self.restriction.range_tracker()
    return tracker.start_position()

  def stop_pos(self):
    tracker = self.restriction.range_tracker()
    return tracker.stop_position()

  def try_claim(self, position):
    tracker = self.restriction.range_tracker()
    return tracker.try_claim(position)

  def try_split(self, fraction_of_remainder):
    tracked = self.restriction
    return tracked.try_split(fraction_of_remainder)

  def check_done(self):
    # Reports completion as a boolean: done once the consumed fraction
    # reaches 1.0.
    consumed = self.restriction.range_tracker().fraction_consumed()
    return consumed >= 1.0

  def is_bounded(self):
    return True

class _SDFBoundedSourceRestrictionProvider(core.RestrictionProvider):
  """A `RestrictionProvider` that is used by SDF for `BoundedSource`.

  The source is fixed at construction time; the input element is ignored by
  every method.
  """
  def __init__(self, source, desired_chunk_size=None):
    self._desired_chunk_size = desired_chunk_size
    self._source = source

  def initial_restriction(self, element):
    # A tracker over (None, None) supplies the start and stop positions of
    # the bundle that covers the whole source.
    full_range = self._source.get_range_tracker(None, None)
    whole_source = SourceBundle(
        None,
        self._source,
        full_range.start_position(),
        full_range.stop_position())
    return _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction(whole_source)

  def create_tracker(self, restriction):
    tracker = _SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionTracker(
        restriction)
    return tracker

  def split(self, element, restriction):
    # When no chunk size was given, derive one from the source's estimated
    # size (or from None if the source cannot estimate it) and remember it.
    if self._desired_chunk_size is None:
      try:
        size_estimate = self._source.estimate_size()
      except NotImplementedError:
        size_estimate = None
      self._desired_chunk_size = Read.get_desired_chunk_size(size_estimate)

    # Invoke source.split to get initial splitting results.
    for bundle in self._source.split(self._desired_chunk_size):
      yield _SDFBoundedSourceWrapper._SDFBoundedSourceRestriction(bundle)

  def restriction_size(self, element, restriction):
    size = restriction.weight()
    return size

  def restriction_coder(self):
    coder = coders.DillCoder()
    return coder

def __init__(self, source):
if not isinstance(source, BoundedSource):
raise RuntimeError('SDFBoundedSourceWrapper can only wrap BoundedSource')
Expand All @@ -1590,12 +1603,9 @@ def process(
self,
element,
restriction_tracker=core.DoFn.RestrictionParam(
_SDFBoundedSourceWrapper._SDFBoundedSourceRestrictionProvider(
source))):
_SDFBoundedSourceRestrictionProvider(source=source))):
current_restriction = restriction_tracker.current_restriction()
assert isinstance(
current_restriction,
_SDFBoundedSourceWrapper._SDFBoundedSourceRestriction)
assert isinstance(current_restriction, _SDFBoundedSourceRestriction)
return current_restriction.source().read(
current_restriction.range_tracker())

Expand All @@ -1618,3 +1628,48 @@ def display_data(self):
'source': DisplayDataItem(self.source.__class__, label='Read Source'),
'source_dd': self.source
}


class SDFBoundedSourceReader(PTransform):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like the major difference between SDFBoundedSourceWrapper and SDFBoundedSourceReader is that SDFBoundedSourceWrapper takes the source as construction param where SDFBoundedSourceReader takes the source as input element. We could change the implementation of SDFBoundedSourceWrapper as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've done this, but I've still allowed the source to come in via the constructor as well as via an input element. The intention is to keep the display data for simple Read transforms where the source is known at construction time.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I thought we still keep _SDFBoundedSourceWrapper . Thanks for the clarification!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether it would be better for SDFBoundedSourceReader to take data_to_display as a constructor argument, instead of taking the source directly (if one is given at all). What do you think?

"""A ``PTransform`` that uses SDF to read from each ``BoundedSource`` in a
PCollection.

NOTE: This transform can only be used with beam_fn_api enabled.
"""
def __init__(self):
  """Initializes the transform; it takes no arguments of its own."""
  super().__init__()

def _create_sdf_bounded_source_dofn(self):
  """Returns a new DoFn whose ``process`` reads the restriction's source.

  The DoFn's restriction parameter is backed by a
  _SDFBoundedSourceRestrictionProvider constructed without arguments.
  """
  class SDFBoundedSourceDoFn(core.DoFn):
    def __init__(self):
      pass

    def process(
        self,
        unused_element,
        restriction_tracker=core.DoFn.RestrictionParam(
            _SDFBoundedSourceRestrictionProvider())):
      restriction = restriction_tracker.current_restriction()
      assert isinstance(restriction, _SDFBoundedSourceRestriction)

      # Read the restriction's source through the restriction's own range
      # tracker.
      source = restriction.source()
      return source.read(restriction.range_tracker())

  dofn = SDFBoundedSourceDoFn()
  return dofn

def expand(self, pvalue):
  """Applies a ParDo of the SDF source-reading DoFn to ``pvalue``."""
  dofn = self._create_sdf_bounded_source_dofn()
  return pvalue | core.ParDo(dofn)

def get_windowing(self, unused_inputs):
  """Returns a Windowing built on GlobalWindows, whatever the inputs are."""
  global_windows = window.GlobalWindows()
  return core.Windowing(global_windows)

def _infer_output_coder(self, input_type=None, input_coder=None):
  """Returns the source's default output coder, if a source is known.

  Args:
    input_type: not used here.
    input_coder: not used here.

  Returns:
    ``self.source.default_output_coder()`` when the transform has a
    ``source`` attribute that is not None, otherwise ``None``.
  """
  # This transform's constructor does not set ``source`` (the sources arrive
  # as input elements), so reading the attribute unconditionally raises
  # AttributeError.
  # NOTE(review): None is meant as "no coder inferred" -- confirm this is what
  # the caller of _infer_output_coder expects.
  source = getattr(self, 'source', None)
  if source is None:
    return None
  return source.default_output_coder()

def display_data(self):
  """Returns display data describing the source, if a source is known.

  Returns:
    A dict with a ``'source'`` item (the source's class, labelled
    'Read Bounded Sources') and a ``'source_dd'`` item (the source itself)
    when the transform has a ``source`` attribute that is not None; an empty
    dict otherwise.
  """
  # This transform's constructor does not set ``source`` (the sources arrive
  # as input elements), so reading the attribute unconditionally raises
  # AttributeError.
  source = getattr(self, 'source', None)
  if source is None:
    return {}
  return {
      'source': DisplayDataItem(
          source.__class__, label='Read Bounded Sources'),
      'source_dd': source
  }
Loading