Allow one to bound the size of output shards when writing to files. #22130

Merged · 2 commits · Jul 11, 2022

Changes from 1 commit
36 changes: 35 additions & 1 deletion sdks/python/apache_beam/io/filebasedsink.py
@@ -68,6 +68,9 @@ def __init__(
shard_name_template=None,
mime_type='application/octet-stream',
compression_type=CompressionTypes.AUTO,
*,
Contributor: Do we have a code style guide about the usage of asterisks in function parameters?

Contributor Author: I don't think we have any guidance here (other than what is generic to Python, which would suggest most of these arguments should be passed by keyword).
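For readers unfamiliar with the bare `*`: every parameter after it becomes keyword-only. A minimal standalone sketch (illustrative only, not Beam code):

```python
def write(path, *, max_records_per_shard=None):
    # Toy function: max_records_per_shard cannot be passed positionally.
    return max_records_per_shard

write('out.txt', max_records_per_shard=100)  # OK: passed by keyword
# write('out.txt', 100)  # would raise TypeError: too many positional arguments
```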

max_records_per_shard=None,
max_bytes_per_shard=None,
skip_if_empty=False):
"""
Raises:
@@ -108,6 +111,8 @@ def __init__(
shard_name_template)
self.compression_type = compression_type
self.mime_type = mime_type
self.max_records_per_shard = max_records_per_shard
self.max_bytes_per_shard = max_bytes_per_shard
Contributor: From the implementation of write below, only one of them takes effect. Do we need to raise a warning (or info) to flag possible misuse when neither is None? This also needs to be documented.

Contributor Author: Nice catch. Fixed so that both take effect.

self.skip_if_empty = skip_if_empty

def display_data(self):
@@ -130,7 +135,13 @@ def open(self, temp_path):
The returned file handle is passed to ``write_[encoded_]record`` and
``close``.
"""
return FileSystems.create(temp_path, self.mime_type, self.compression_type)
writer = FileSystems.create(
temp_path, self.mime_type, self.compression_type)
if self.max_bytes_per_shard:
self.byte_counter = _ByteCountingWriter(writer)
return self.byte_counter
else:
return writer

def write_record(self, file_handle, value):
"""Writes a single record go the file handle returned by ``open()``.
@@ -406,10 +417,33 @@ def __init__(self, sink, temp_shard_path):
self.sink = sink
self.temp_shard_path = temp_shard_path
self.temp_handle = self.sink.open(temp_shard_path)
self.num_records_written = 0

def write(self, value):
self.sink.write_record(self.temp_handle, value)
if self.sink.max_records_per_shard:
Contributor: Alternatively, if write is to keep returning nothing, you could add a method like at_capacity that returns true once the writer has reached capacity. That way max_bytes_per_shard and max_records_per_shard can also take effect at the same time.

Contributor Author: Done.

self.num_records_written += 1
return self.num_records_written >= self.sink.max_records_per_shard
if self.sink.max_bytes_per_shard:
return (
self.sink.byte_counter.bytes_written >= self.sink.max_bytes_per_shard)

def close(self):
self.sink.close(self.temp_handle)
return self.temp_shard_path
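To make the at_capacity suggestion from the thread above concrete, here is a hedged sketch (a hypothetical method for this writer class, not the code merged in this commit) under which both limits take effect at once:

```python
def at_capacity(self):
    # Hypothetical sketch: either limit, when set, can trip independently,
    # so max_records_per_shard and max_bytes_per_shard both take effect.
    return (
        (self.sink.max_records_per_shard and
         self.num_records_written >= self.sink.max_records_per_shard) or
        (self.sink.max_bytes_per_shard and
         self.sink.byte_counter.bytes_written >=
         self.sink.max_bytes_per_shard))
```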


class _ByteCountingWriter:
Contributor: Could bytes_written also be tracked in the write function, as num_records_written is, so that the wrapper class is unnecessary? FileBasedSink.open used to always return an instance of BufferedWriter, but with this wrapper class it may not.

Contributor: I see: if the writer is compressed, the record handed to the writer may have a different length from what is actually written, which is why a wrapper class is needed here.

Contributor: io.BufferedWriter.write returns the number of bytes written (https://docs.python.org/3.7/library/io.html#io.BufferedWriter.write), so bytes_written could be tracked directly in FileBasedSinkWriter.write.

Contributor Author: Unfortunately io.BufferedWriter.write returns the number of bytes written for that call, not a running total.
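That distinction is easy to verify in isolation (standalone snippet, unrelated to Beam; the path is arbitrary):

```python
import io

w = io.BufferedWriter(io.FileIO('/tmp/demo.bin', 'wb'))
print(w.write(b'abc'))  # 3: bytes written by this call only
print(w.write(b'de'))   # 2: again a per-call count, not a running total of 5
w.close()
```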

def __init__(self, writer):
self.writer = writer
self.bytes_written = 0

def write(self, bs):
self.bytes_written += len(bs)
self.writer.write(bs)

def flush(self):
self.writer.flush()

def close(self):
self.writer.close()
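A hedged usage sketch of the wrapper above, with io.BytesIO standing in for the real file handle: bytes_written accumulates the length of each chunk as handed to write(), i.e. the pre-compression size.

```python
import io

counting = _ByteCountingWriter(io.BytesIO())
counting.write(b'hello')
counting.write(b'world')
assert counting.bytes_written == 10  # counts what was passed in, not what hit disk
counting.close()
```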
12 changes: 9 additions & 3 deletions sdks/python/apache_beam/io/iobase.py
@@ -848,8 +848,12 @@ class Writer(object):
See ``iobase.Sink`` for more detailed documentation about the process of
writing to a sink.
"""
def write(self, value):
"""Writes a value to the sink using the current writer."""
def write(self, value) -> Optional[bool]:
Contributor: Better not to change the signature of the base class.

Contributor Author: It's backwards compatible, which is why I made it Optional. But I've moved to using at_capacity as suggested instead.

"""Writes a value to the sink using the current writer.

Returns True if this writer should be considered at capacity and a new one
should be created.
"""
raise NotImplementedError

def close(self):
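To illustrate the new (backwards-compatible) contract in isolation, a hypothetical subclass; CappedWriter and its limit parameter are invented for this sketch and are not part of Beam:

```python
class CappedWriter(Writer):
    # Hypothetical writer that reports itself full after `limit` records.
    def __init__(self, limit):
        self.limit = limit
        self.count = 0

    def write(self, value):
        self.count += 1
        # Returning True tells the caller to close this shard and open a new one.
        return self.count >= self.limit

    def close(self):
        return 'path/to/finished/shard'
```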
@@ -1184,7 +1188,9 @@ def process(self, element, init_result):
if self.writer is None:
# We ignore UUID collisions here since they are extremely rare.
self.writer = self.sink.open_writer(init_result, str(uuid.uuid4()))
self.writer.write(element)
if self.writer.write(element):
Contributor: Always call self.writer.write and then test self.writer.at_capacity.

Contributor Author: Done.

yield self.writer.close()
self.writer = None

def finish_bundle(self):
if self.writer is not None:
28 changes: 27 additions & 1 deletion sdks/python/apache_beam/io/textio.py
@@ -435,6 +435,9 @@ def __init__(self,
compression_type=CompressionTypes.AUTO,
header=None,
footer=None,
*,
max_records_per_shard=None,
max_bytes_per_shard=None,
skip_if_empty=False):
"""Initialize a _TextSink.
@@ -469,6 +472,14 @@ def __init__(self,
append_trailing_newlines is set, '\n' will be added.
footer: String to write at the end of file as a footer. If not None and
append_trailing_newlines is set, '\n' will be added.
max_records_per_shard: Maximum number of records to write to any
individual shard.
max_bytes_per_shard: Target maximum number of bytes to write to any
individual shard. This may be exceeded slightly, as a new shard is
only created once this limit is hit, so the remainder of a given
record, a subsequent newline, and a footer may push the actual shard
size past this value. Note also that this limit applies to the
uncompressed, not the compressed, size of the shard.
skip_if_empty: Don't write any shards if the PCollection is empty.
Returns:
@@ -482,6 +493,8 @@
coder=coder,
mime_type='text/plain',
compression_type=compression_type,
max_records_per_shard=max_records_per_shard,
max_bytes_per_shard=max_bytes_per_shard,
skip_if_empty=skip_if_empty)
self._append_trailing_newlines = append_trailing_newlines
self._header = header
@@ -791,6 +804,9 @@ def __init__(
compression_type=CompressionTypes.AUTO,
header=None,
footer=None,
*,
max_records_per_shard=None,
max_bytes_per_shard=None,
skip_if_empty=False):
r"""Initialize a :class:`WriteToText` transform.
@@ -830,6 +846,14 @@ def __init__(
footer (str): String to write at the end of file as a footer.
If not :data:`None` and **append_trailing_newlines** is set, ``\n`` will
be added.
max_records_per_shard: Maximum number of records to write to any
individual shard.
max_bytes_per_shard: Target maximum number of bytes to write to any
individual shard. This may be exceeded slightly, as a new shard is
only created once this limit is hit, so the remainder of a given
record, a subsequent newline, and a footer may push the actual shard
size past this value. Note also that this limit applies to the
uncompressed, not the compressed, size of the shard.
skip_if_empty: Don't write any shards if the PCollection is empty.
"""

@@ -843,7 +867,9 @@ def __init__(
compression_type,
header,
footer,
skip_if_empty)
max_records_per_shard=max_records_per_shard,
max_bytes_per_shard=max_bytes_per_shard,
skip_if_empty=skip_if_empty)

def expand(self, pcoll):
return pcoll | Write(self._sink)
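Taken together, a hedged end-to-end sketch of how the new parameters might be used (output prefix and limits are arbitrary; per the review thread above, in the final version of this PR both limits apply simultaneously):

```python
import apache_beam as beam
from apache_beam.io.textio import WriteToText

with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([str(i) for i in range(1000)])
        # Aim for at most 100 records and roughly 10 KiB per output shard.
        | WriteToText(
            '/tmp/out',
            max_records_per_shard=100,
            max_bytes_per_shard=10 << 10))
```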
42 changes: 42 additions & 0 deletions sdks/python/apache_beam/io/textio_test.py
@@ -1668,6 +1668,48 @@ def test_write_empty_skipped(self):
outputs = list(glob.glob(self.path + '*'))
self.assertEqual(outputs, [])

def test_write_max_records_per_shard(self):
records_per_shard = 13
lines = [str(i).encode('utf-8') for i in range(100)]
with TestPipeline() as p:
# pylint: disable=expression-not-assigned
p | beam.core.Create(lines) | WriteToText(
self.path, max_records_per_shard=records_per_shard)

read_result = []
for file_name in glob.glob(self.path + '*'):
with open(file_name, 'rb') as f:
shard_lines = list(f.read().splitlines())
self.assertLessEqual(len(shard_lines), records_per_shard)
read_result.extend(shard_lines)
self.assertEqual(sorted(read_result), sorted(lines))

def test_write_max_bytes_per_shard(self):
bytes_per_shard = 300
max_len = 100
lines = [b'x' * i for i in range(max_len)]
header = b'a' * 20
footer = b'b' * 30
with TestPipeline() as p:
# pylint: disable=expression-not-assigned
p | beam.core.Create(lines) | WriteToText(
self.path,
header=header,
footer=footer,
max_bytes_per_shard=bytes_per_shard)

read_result = []
for file_name in glob.glob(self.path + '*'):
with open(file_name, 'rb') as f:
contents = f.read()
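# The slack in the bound below covers the remainder of the record that
# crossed the limit (up to max_len bytes), the footer, and two newlines.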
self.assertLessEqual(
len(contents), bytes_per_shard + max_len + len(footer) + 2)
shard_lines = list(contents.splitlines())
self.assertEqual(shard_lines[0], header)
self.assertEqual(shard_lines[-1], footer)
read_result.extend(shard_lines[1:-1])
self.assertEqual(sorted(read_result), sorted(lines))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)