Allow one to bound the size of output shards when writing to files. #22130

Merged
merged 2 commits on Jul 11, 2022
39 changes: 38 additions & 1 deletion sdks/python/apache_beam/io/filebasedsink.py
@@ -68,6 +68,9 @@ def __init__(
shard_name_template=None,
mime_type='application/octet-stream',
compression_type=CompressionTypes.AUTO,
*,

Contributor:
Do we have a code style guide about the usage of the asterisk in function parameters?

Contributor Author:
I don't think we have any guidance here (other than that which is generic to Python, which would indicate most of these arguments should be passed by keyword).
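For reference, a minimal sketch of what the bare asterisk does in a signature: every parameter after it becomes keyword-only. The function and values below are purely illustrative, not from this PR.

```python
def write(path, coder, *, max_records_per_shard=None, max_bytes_per_shard=None):
    # Parameters after the bare * must be passed by keyword.
    return (path, coder, max_records_per_shard, max_bytes_per_shard)

write('out.txt', 'utf-8', max_records_per_shard=100)   # OK
# write('out.txt', 'utf-8', 100)  # TypeError: too many positional arguments
```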

max_records_per_shard=None,
max_bytes_per_shard=None,
skip_if_empty=False):
"""
Raises:
@@ -108,6 +111,8 @@ def __init__(
shard_name_template)
self.compression_type = compression_type
self.mime_type = mime_type
self.max_records_per_shard = max_records_per_shard
self.max_bytes_per_shard = max_bytes_per_shard

Contributor:
From the implementation of write below, only one of them will take effect. Do we need to raise a warning (or info) to flag possible misuse when neither is None? This also needs to be documented.

Contributor Author:
Nice catch. Fixed so that both take effect.
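For illustration, a hedged sketch of the fixed behavior (names are made up; the real check is the at_capacity method further down this diff): each limit is evaluated independently, so either one can trigger a new shard.

```python
def at_capacity_sketch(num_records, num_bytes, max_records=None, max_bytes=None):
    # Each limit is checked on its own, so whichever is hit first takes effect.
    return bool(
        (max_records and num_records >= max_records) or
        (max_bytes and num_bytes >= max_bytes))

assert at_capacity_sketch(13, 10, max_records=13, max_bytes=300)      # record limit hit
assert at_capacity_sketch(5, 300, max_records=13, max_bytes=300)      # byte limit hit
assert not at_capacity_sketch(5, 10, max_records=13, max_bytes=300)   # neither limit hit
```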

self.skip_if_empty = skip_if_empty

def display_data(self):
@@ -130,7 +135,13 @@ def open(self, temp_path):
The returned file handle is passed to ``write_[encoded_]record`` and
``close``.
"""
return FileSystems.create(temp_path, self.mime_type, self.compression_type)
writer = FileSystems.create(
temp_path, self.mime_type, self.compression_type)
if self.max_bytes_per_shard:
self.byte_counter = _ByteCountingWriter(writer)
return self.byte_counter
else:
return writer

def write_record(self, file_handle, value):
"""Writes a single record go the file handle returned by ``open()``.
@@ -406,10 +417,36 @@ def __init__(self, sink, temp_shard_path):
self.sink = sink
self.temp_shard_path = temp_shard_path
self.temp_handle = self.sink.open(temp_shard_path)
self.num_records_written = 0

def write(self, value):
self.num_records_written += 1
self.sink.write_record(self.temp_handle, value)

def at_capacity(self):
return (
self.sink.max_records_per_shard and
self.num_records_written >= self.sink.max_records_per_shard
) or (
self.sink.max_bytes_per_shard and
self.sink.byte_counter.bytes_written >= self.sink.max_bytes_per_shard)

def close(self):
self.sink.close(self.temp_handle)
return self.temp_shard_path


class _ByteCountingWriter:

Contributor:
Could bytes_written also be handled in the write function, like num_records_written, so that the wrapper class isn't needed? FileBasedSink.open used to always return an instance of BufferedWriter, but with this wrapper class it may not anymore.

Contributor:
I see: if the writer is compressed, the record sent to FileBasedSinkWriter may have a different length than what is actually written, which is why a wrapper class is needed here.

Contributor:
Found that io.BufferedWriter.write returns the number of bytes written (https://docs.python.org/3.7/library/io.html#io.BufferedWriter.write), so bytes_written could be tracked directly in FileBasedSinkWriter.write.

Contributor Author:
Unfortunately io.BufferedWriter.write returns the number of bytes written for that call, not a running total.
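As an aside, a hedged sketch of the alternative discussed above: the caller could sum write()'s per-call return values itself, but that only works when the handle reports a count from write(), which a compressing wrapper may not do; hence the counting wrapper class below.

```python
import io

# Sketch: summing write()'s per-call return value in the caller.
buf = io.BytesIO()
handle = io.BufferedWriter(buf)
bytes_written = 0
for record in [b'one\n', b'two\n', b'three\n']:
    bytes_written += handle.write(record)  # count for this call only
handle.flush()
assert bytes_written == len(buf.getvalue())  # 14 bytes in total
```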

def __init__(self, writer):
self.writer = writer
self.bytes_written = 0

def write(self, bs):
self.bytes_written += len(bs)
self.writer.write(bs)

def flush(self):
self.writer.flush()

def close(self):
self.writer.close()
12 changes: 11 additions & 1 deletion sdks/python/apache_beam/io/iobase.py
@@ -849,7 +849,8 @@ class Writer(object):
writing to a sink.
"""
def write(self, value):
"""Writes a value to the sink using the current writer."""
"""Writes a value to the sink using the current writer.
"""
raise NotImplementedError

def close(self):
@@ -863,6 +864,12 @@ def close(self):
"""
raise NotImplementedError

def at_capacity(self) -> bool:
"""Returns whether this writer should be considered at capacity
and a new one should be created.
"""
return False


class Read(ptransform.PTransform):
"""A transform that reads a PCollection."""
@@ -1185,6 +1192,9 @@ def process(self, element, init_result):
# We ignore UUID collisions here since they are extremely rare.
self.writer = self.sink.open_writer(init_result, str(uuid.uuid4()))
self.writer.write(element)
if self.writer.at_capacity():
yield self.writer.close()
self.writer = None

def finish_bundle(self):
if self.writer is not None:
28 changes: 27 additions & 1 deletion sdks/python/apache_beam/io/textio.py
@@ -435,6 +435,9 @@ def __init__(self,
compression_type=CompressionTypes.AUTO,
header=None,
footer=None,
*,
max_records_per_shard=None,
max_bytes_per_shard=None,
skip_if_empty=False):
"""Initialize a _TextSink.

@@ -469,6 +472,14 @@ def __init__(self,
append_trailing_newlines is set, '\n' will be added.
footer: String to write at the end of file as a footer. If not None and
append_trailing_newlines is set, '\n' will be added.
max_records_per_shard: Maximum number of records to write to any
individual shard.
max_bytes_per_shard: Target maximum number of bytes to write to any
individual shard. This may be exceeded slightly, as a new shard is
created once this limit is hit, but the remainder of a given record, a
subsequent newline, and a footer may cause the actual shard size
to exceed this value. Note that this limit applies to the uncompressed,
not the compressed, size of the shard.
skip_if_empty: Don't write any shards if the PCollection is empty.

Returns:
@@ -482,6 +493,8 @@ def __init__(self,
coder=coder,
mime_type='text/plain',
compression_type=compression_type,
max_records_per_shard=max_records_per_shard,
max_bytes_per_shard=max_bytes_per_shard,
skip_if_empty=skip_if_empty)
self._append_trailing_newlines = append_trailing_newlines
self._header = header
@@ -791,6 +804,9 @@ def __init__(
compression_type=CompressionTypes.AUTO,
header=None,
footer=None,
*,
max_records_per_shard=None,
max_bytes_per_shard=None,
skip_if_empty=False):
r"""Initialize a :class:`WriteToText` transform.

@@ -830,6 +846,14 @@ def __init__(
footer (str): String to write at the end of file as a footer.
If not :data:`None` and **append_trailing_newlines** is set, ``\n`` will
be added.
max_records_per_shard: Maximum number of records to write to any
individual shard.
max_bytes_per_shard: Target maximum number of bytes to write to any
individual shard. This may be exceeded slightly, as a new shard is
created once this limit is hit, but the remainder of a given record, a
subsequent newline, and a footer may cause the actual shard size
to exceed this value. Note that this limit applies to the uncompressed,
not the compressed, size of the shard.
skip_if_empty: Don't write any shards if the PCollection is empty.
"""

@@ -843,7 +867,9 @@ def __init__(
compression_type,
header,
footer,
skip_if_empty)
max_records_per_shard=max_records_per_shard,
max_bytes_per_shard=max_bytes_per_shard,
skip_if_empty=skip_if_empty)

def expand(self, pcoll):
return pcoll | Write(self._sink)
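
For context, a minimal end-to-end sketch of how the new keyword-only options are used from WriteToText; the output prefix and limit values are illustrative.

```python
import apache_beam as beam
from apache_beam.io.textio import WriteToText

with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([str(i) for i in range(1000)])
        | WriteToText(
            '/tmp/bounded_shards',            # illustrative output prefix
            max_records_per_shard=100,        # start a new shard after 100 records
            max_bytes_per_shard=1_000_000))   # or after ~1 MB of uncompressed data
```
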
42 changes: 42 additions & 0 deletions sdks/python/apache_beam/io/textio_test.py
@@ -1668,6 +1668,48 @@ def test_write_empty_skipped(self):
outputs = list(glob.glob(self.path + '*'))
self.assertEqual(outputs, [])

def test_write_max_records_per_shard(self):
records_per_shard = 13
lines = [str(i).encode('utf-8') for i in range(100)]
with TestPipeline() as p:
# pylint: disable=expression-not-assigned
p | beam.core.Create(lines) | WriteToText(
self.path, max_records_per_shard=records_per_shard)

read_result = []
for file_name in glob.glob(self.path + '*'):
with open(file_name, 'rb') as f:
shard_lines = list(f.read().splitlines())
self.assertLessEqual(len(shard_lines), records_per_shard)
read_result.extend(shard_lines)
self.assertEqual(sorted(read_result), sorted(lines))

def test_write_max_bytes_per_shard(self):
bytes_per_shard = 300
max_len = 100
lines = [b'x' * i for i in range(max_len)]
header = b'a' * 20
footer = b'b' * 30
with TestPipeline() as p:
# pylint: disable=expression-not-assigned
p | beam.core.Create(lines) | WriteToText(
self.path,
header=header,
footer=footer,
max_bytes_per_shard=bytes_per_shard)

read_result = []
for file_name in glob.glob(self.path + '*'):
with open(file_name, 'rb') as f:
contents = f.read()
self.assertLessEqual(
len(contents), bytes_per_shard + max_len + len(footer) + 2)
shard_lines = list(contents.splitlines())
self.assertEqual(shard_lines[0], header)
self.assertEqual(shard_lines[-1], footer)
read_result.extend(shard_lines[1:-1])
self.assertEqual(sorted(read_result), sorted(lines))


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)