Allow one to bound the size of output shards when writing to files.
This fixes #22129.
robertwb committed Jul 1, 2022
1 parent 52e1b3f commit 4f909c9
Showing 4 changed files with 107 additions and 5 deletions.
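
The commit adds two keyword-only options, max_records_per_shard and max_bytes_per_shard, to the file-based sinks. A minimal usage sketch — the pipeline, data, and output path are illustrative, not taken from the commit:

  import apache_beam as beam
  from apache_beam.io.textio import WriteToText

  # Cap each output shard at 100 records; a new shard is opened once the
  # cap is reached. max_bytes_per_shard is the analogous byte-based bound.
  with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create([str(i) for i in range(1000)])
        | WriteToText('/tmp/numbers', max_records_per_shard=100))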
36 changes: 35 additions & 1 deletion sdks/python/apache_beam/io/filebasedsink.py
@@ -68,6 +68,9 @@ def __init__(
       shard_name_template=None,
       mime_type='application/octet-stream',
       compression_type=CompressionTypes.AUTO,
+      *,
+      max_records_per_shard=None,
+      max_bytes_per_shard=None,
       skip_if_empty=False):
     """
     Raises:
@@ -108,6 +111,8 @@ def __init__(
         shard_name_template)
     self.compression_type = compression_type
     self.mime_type = mime_type
+    self.max_records_per_shard = max_records_per_shard
+    self.max_bytes_per_shard = max_bytes_per_shard
     self.skip_if_empty = skip_if_empty

   def display_data(self):
@@ -130,7 +135,13 @@ def open(self, temp_path):
     The returned file handle is passed to ``write_[encoded_]record`` and
     ``close``.
     """
-    return FileSystems.create(temp_path, self.mime_type, self.compression_type)
+    writer = FileSystems.create(
+        temp_path, self.mime_type, self.compression_type)
+    if self.max_bytes_per_shard:
+      self.byte_counter = _ByteCountingWriter(writer)
+      return self.byte_counter
+    else:
+      return writer

   def write_record(self, file_handle, value):
     """Writes a single record to the file handle returned by ``open()``.
@@ -406,10 +417,33 @@ def __init__(self, sink, temp_shard_path):
     self.sink = sink
     self.temp_shard_path = temp_shard_path
     self.temp_handle = self.sink.open(temp_shard_path)
+    self.num_records_written = 0

   def write(self, value):
     self.sink.write_record(self.temp_handle, value)
+    if self.sink.max_records_per_shard:
+      self.num_records_written += 1
+      return self.num_records_written >= self.sink.max_records_per_shard
+    if self.sink.max_bytes_per_shard:
+      return (
+          self.sink.byte_counter.bytes_written >= self.sink.max_bytes_per_shard)

   def close(self):
     self.sink.close(self.temp_handle)
     return self.temp_shard_path
+
+
+class _ByteCountingWriter:
+  def __init__(self, writer):
+    self.writer = writer
+    self.bytes_written = 0
+
+  def write(self, bs):
+    self.bytes_written += len(bs)
+    self.writer.write(bs)
+
+  def flush(self):
+    self.writer.flush()
+
+  def close(self):
+    self.writer.close()
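
_ByteCountingWriter simply wraps the raw file handle and counts bytes before delegating, which is why the byte bound is measured against the uncompressed stream. A standalone sketch of the same pattern against an in-memory buffer — hypothetical, for illustration only, not from the commit:

  import io

  class CountingWriter:
    """Counts bytes passed to write() before delegating to the wrapped writer."""
    def __init__(self, writer):
      self.writer = writer
      self.bytes_written = 0

    def write(self, bs):
      self.bytes_written += len(bs)
      self.writer.write(bs)

    def close(self):
      self.writer.close()

  buf = io.BytesIO()
  w = CountingWriter(buf)
  w.write(b'hello ')
  w.write(b'world')
  assert w.bytes_written == 11  # counts bytes as written, pre-compression
  w.close()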
12 changes: 9 additions & 3 deletions sdks/python/apache_beam/io/iobase.py
@@ -848,8 +848,12 @@ class Writer(object):
   See ``iobase.Sink`` for more detailed documentation about the process of
   writing to a sink.
   """
-  def write(self, value):
-    """Writes a value to the sink using the current writer."""
+  def write(self, value) -> Optional[bool]:
+    """Writes a value to the sink using the current writer.
+
+    Returns True if this writer should be considered at capacity and a new one
+    should be created.
+    """
     raise NotImplementedError

   def close(self):
@@ -1184,7 +1188,9 @@ def process(self, element, init_result):
     if self.writer is None:
       # We ignore UUID collisions here since they are extremely rare.
       self.writer = self.sink.open_writer(init_result, str(uuid.uuid4()))
-    self.writer.write(element)
+    if self.writer.write(element):
+      yield self.writer.close()
+      self.writer = None

   def finish_bundle(self):
     if self.writer is not None:
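
The iobase change defines the rotation protocol: Writer.write() may now return True to signal that the writer is at capacity, and WriteBundleDoFn.process() responds by closing the writer (yielding its result) and starting a fresh one for subsequent elements. A hypothetical toy writer, not part of the commit, sketching how a custom sink could opt in:

  from apache_beam.io import iobase

  class CappedWriter(iobase.Writer):
    """Buffers records and reports itself full after a fixed capacity."""
    def __init__(self, capacity):
      self.capacity = capacity
      self.records = []

    def write(self, value):
      self.records.append(value)
      # Returning True asks the caller to close this writer and open a new one.
      return len(self.records) >= self.capacity

    def close(self):
      # A real file-based writer returns its shard path here.
      return self.records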
28 changes: 27 additions & 1 deletion sdks/python/apache_beam/io/textio.py
@@ -435,6 +435,9 @@ def __init__(self,
                compression_type=CompressionTypes.AUTO,
                header=None,
                footer=None,
+               *,
+               max_records_per_shard=None,
+               max_bytes_per_shard=None,
                skip_if_empty=False):
     """Initialize a _TextSink.

@@ -469,6 +472,14 @@ def __init__(self,
         append_trailing_newlines is set, '\n' will be added.
       footer: String to write at the end of file as a footer. If not None and
         append_trailing_newlines is set, '\n' will be added.
+      max_records_per_shard: Maximum number of records to write to any
+        individual shard.
+      max_bytes_per_shard: Target maximum number of bytes to write to any
+        individual shard. This may be exceeded slightly, as a new shard is
+        created once this limit is hit, but the remainder of a given record, a
+        subsequent newline, and a footer may cause the actual shard size to
+        exceed this value. Note that this bounds the uncompressed, rather than
+        the compressed, size of the shard.
       skip_if_empty: Don't write any shards if the PCollection is empty.

     Returns:
@@ -482,6 +493,8 @@ def __init__(self,
         coder=coder,
         mime_type='text/plain',
         compression_type=compression_type,
+        max_records_per_shard=max_records_per_shard,
+        max_bytes_per_shard=max_bytes_per_shard,
         skip_if_empty=skip_if_empty)
     self._append_trailing_newlines = append_trailing_newlines
     self._header = header
@@ -791,6 +804,9 @@ def __init__(
       compression_type=CompressionTypes.AUTO,
       header=None,
       footer=None,
+      *,
+      max_records_per_shard=None,
+      max_bytes_per_shard=None,
       skip_if_empty=False):
     r"""Initialize a :class:`WriteToText` transform.

@@ -830,6 +846,14 @@ def __init__(
       footer (str): String to write at the end of file as a footer.
         If not :data:`None` and **append_trailing_newlines** is set, ``\n``
         will be added.
+      max_records_per_shard: Maximum number of records to write to any
+        individual shard.
+      max_bytes_per_shard: Target maximum number of bytes to write to any
+        individual shard. This may be exceeded slightly, as a new shard is
+        created once this limit is hit, but the remainder of a given record, a
+        subsequent newline, and a footer may cause the actual shard size to
+        exceed this value. Note that this bounds the uncompressed, rather than
+        the compressed, size of the shard.
       skip_if_empty: Don't write any shards if the PCollection is empty.
     """

Expand All @@ -843,7 +867,9 @@ def __init__(
compression_type,
header,
footer,
skip_if_empty)
max_records_per_shard=max_records_per_shard,
max_bytes_per_shard=max_bytes_per_shard,
skip_if_empty=skip_if_empty)

def expand(self, pcoll):
return pcoll | Write(self._sink)
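
WriteToText simply forwards both options through _TextSink to the base FileBasedSink. A hedged sketch combining max_bytes_per_shard with a header and footer — the path and values are illustrative; as the docstring above notes, a shard may overshoot the target by up to one record plus a trailing newline and the footer, measured on uncompressed bytes:

  import apache_beam as beam
  from apache_beam.io.textio import WriteToText

  with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create(['x' * n for n in range(100)])
        | WriteToText(
            '/tmp/padded',
            header='begin',
            footer='end',
            max_bytes_per_shard=300))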
36 changes: 36 additions & 0 deletions sdks/python/apache_beam/io/textio_test.py
@@ -1668,6 +1668,42 @@ def test_write_empty_skipped(self):
     outputs = list(glob.glob(self.path + '*'))
     self.assertEqual(outputs, [])

+  def test_write_max_records_per_shard(self):
+    records_per_shard = 13
+    lines = [str(i).encode('utf-8') for i in range(100)]
+    with TestPipeline() as p:
+      # pylint: disable=expression-not-assigned
+      p | beam.core.Create(lines) | WriteToText(self.path, max_records_per_shard=records_per_shard)
+
+    read_result = []
+    for file_name in glob.glob(self.path + '*'):
+      with open(file_name, 'rb') as f:
+        shard_lines = list(f.read().splitlines())
+        self.assertLessEqual(len(shard_lines), records_per_shard)
+        read_result.extend(shard_lines)
+    self.assertEqual(sorted(read_result), sorted(lines))
+
+  def test_write_max_bytes_per_shard(self):
+    bytes_per_shard = 300
+    max_len = 100
+    lines = [b'x' * i for i in range(max_len)]
+    header = b'a' * 20
+    footer = b'b' * 30
+    with TestPipeline() as p:
+      # pylint: disable=expression-not-assigned
+      p | beam.core.Create(lines) | WriteToText(self.path, header=header, footer=footer, max_bytes_per_shard=bytes_per_shard)
+
+    read_result = []
+    for file_name in glob.glob(self.path + '*'):
+      with open(file_name, 'rb') as f:
+        contents = f.read()
+        self.assertLessEqual(len(contents), bytes_per_shard + max_len + len(footer) + 2)
+        shard_lines = list(contents.splitlines())
+        self.assertEqual(shard_lines[0], header)
+        self.assertEqual(shard_lines[-1], footer)
+        read_result.extend(shard_lines[1:-1])
+    self.assertEqual(sorted(read_result), sorted(lines))
+

 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
