Commit fd3adb1

Allow one to bound the size of output shards when writing to files. (a…
robertwb authored and lostluck committed Aug 25, 2022
1 parent 5d765ea commit fd3adb1
Showing 4 changed files with 118 additions and 3 deletions.
39 changes: 38 additions & 1 deletion sdks/python/apache_beam/io/filebasedsink.py
@@ -68,6 +68,9 @@ def __init__(
shard_name_template=None,
mime_type='application/octet-stream',
compression_type=CompressionTypes.AUTO,
*,
max_records_per_shard=None,
max_bytes_per_shard=None,
skip_if_empty=False):
"""
Raises:
@@ -108,6 +111,8 @@ def __init__(
shard_name_template)
self.compression_type = compression_type
self.mime_type = mime_type
self.max_records_per_shard = max_records_per_shard
self.max_bytes_per_shard = max_bytes_per_shard
self.skip_if_empty = skip_if_empty

def display_data(self):
@@ -130,7 +135,13 @@ def open(self, temp_path):
The returned file handle is passed to ``write_[encoded_]record`` and
``close``.
"""
return FileSystems.create(temp_path, self.mime_type, self.compression_type)
writer = FileSystems.create(
temp_path, self.mime_type, self.compression_type)
if self.max_bytes_per_shard:
self.byte_counter = _ByteCountingWriter(writer)
return self.byte_counter
else:
return writer

def write_record(self, file_handle, value):
"""Writes a single record go the file handle returned by ``open()``.
@@ -406,10 +417,36 @@ def __init__(self, sink, temp_shard_path):
self.sink = sink
self.temp_shard_path = temp_shard_path
self.temp_handle = self.sink.open(temp_shard_path)
self.num_records_written = 0

def write(self, value):
self.num_records_written += 1
self.sink.write_record(self.temp_handle, value)

def at_capacity(self):
return (
self.sink.max_records_per_shard and
self.num_records_written >= self.sink.max_records_per_shard
) or (
self.sink.max_bytes_per_shard and
self.sink.byte_counter.bytes_written >= self.sink.max_bytes_per_shard)

def close(self):
self.sink.close(self.temp_handle)
return self.temp_shard_path


class _ByteCountingWriter:
def __init__(self, writer):
self.writer = writer
self.bytes_written = 0

def write(self, bs):
self.bytes_written += len(bs)
self.writer.write(bs)

def flush(self):
self.writer.flush()

def close(self):
self.writer.close()
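As a side note, the reason the byte limit applies to the uncompressed shard size is visible above: the counting wrapper sits on top of the (possibly compressing) handle returned by FileSystems.create(), so it sees each record's plaintext bytes. A standalone sketch of that idea, with io.BytesIO standing in for the real file handle (CountingWrapper is an illustrative name, not Beam API):

import io

class CountingWrapper:
    """Counts bytes passed to write() before any compression happens."""
    def __init__(self, writer):
        self.writer = writer
        self.bytes_written = 0

    def write(self, bs):
        self.bytes_written += len(bs)
        self.writer.write(bs)

underlying = io.BytesIO()  # stand-in for FileSystems.create(...)
w = CountingWrapper(underlying)
for record in (b'alpha\n', b'beta\n'):
    w.write(record)
print(w.bytes_written)  # 11 -- plaintext byte count, whatever the codec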
12 changes: 11 additions & 1 deletion sdks/python/apache_beam/io/iobase.py
@@ -849,7 +849,8 @@ class Writer(object):
writing to a sink.
"""
def write(self, value):
"""Writes a value to the sink using the current writer."""
"""Writes a value to the sink using the current writer.
"""
raise NotImplementedError

def close(self):
@@ -863,6 +864,12 @@ def close(self):
"""
raise NotImplementedError

def at_capacity(self) -> bool:
"""Returns whether this writer should be considered at capacity
and a new one should be created.
"""
return False
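Because at_capacity() is an ordinary hook on iobase.Writer, custom sinks can rotate shards on criteria other than record or byte counts. A hypothetical sketch (TimeBoundedWriter and its constructor are invented for illustration and are not part of this commit):

import time

from apache_beam.io import iobase

class TimeBoundedWriter(iobase.Writer):
    """Asks for a new shard once a wall-clock budget expires."""
    def __init__(self, handle, max_seconds=60.0):
        self._handle = handle
        self._deadline = time.time() + max_seconds

    def write(self, value):
        self._handle.write(value)

    def at_capacity(self):
        # Returning True tells the bundle processor to close this shard
        # and lazily open a fresh writer for subsequent elements.
        return time.time() >= self._deadline

    def close(self):
        # Real writers return a result (e.g. the temp shard path) that
        # the sink later finalizes.
        self._handle.close()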


class Read(ptransform.PTransform):
"""A transform that reads a PCollection."""
@@ -1185,6 +1192,9 @@ def process(self, element, init_result):
# We ignore UUID collisions here since they are extremely rare.
self.writer = self.sink.open_writer(init_result, str(uuid.uuid4()))
self.writer.write(element)
if self.writer.at_capacity():
yield self.writer.close()
self.writer = None

def finish_bundle(self):
if self.writer is not None:
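Stripped of the Beam plumbing, the per-element logic added to process() and finish_bundle() above amounts to the following loop (write_elements and open_writer are illustrative names for this sketch):

def write_elements(elements, open_writer):
    shard_results = []
    writer = None
    for element in elements:
        if writer is None:
            writer = open_writer()  # starts a new shard
        writer.write(element)
        if writer.at_capacity():
            shard_results.append(writer.close())  # seal the full shard
            writer = None
    if writer is not None:  # the finish_bundle() equivalent
        shard_results.append(writer.close())
    return shard_results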
28 changes: 27 additions & 1 deletion sdks/python/apache_beam/io/textio.py
@@ -435,6 +435,9 @@ def __init__(self,
compression_type=CompressionTypes.AUTO,
header=None,
footer=None,
*,
max_records_per_shard=None,
max_bytes_per_shard=None,
skip_if_empty=False):
"""Initialize a _TextSink.
@@ -469,6 +472,14 @@ def __init__(self,
append_trailing_newlines is set, '\n' will be added.
footer: String to write at the end of file as a footer. If not None and
append_trailing_newlines is set, '\n' will be added.
max_records_per_shard: Maximum number of records to write to any
individual shard.
max_bytes_per_shard: Target maximum number of bytes to write to any
individual shard. This limit may be exceeded slightly: a new shard is
only created once the limit is hit, so the remainder of the record
that crossed it, a subsequent newline, and a footer may push the
actual shard size past this value. The limit applies to the
uncompressed, not the compressed, size of the shard.
skip_if_empty: Don't write any shards if the PCollection is empty.
Returns:
@@ -482,6 +493,8 @@ def __init__(
coder=coder,
mime_type='text/plain',
compression_type=compression_type,
max_records_per_shard=max_records_per_shard,
max_bytes_per_shard=max_bytes_per_shard,
skip_if_empty=skip_if_empty)
self._append_trailing_newlines = append_trailing_newlines
self._header = header
@@ -791,6 +804,9 @@ def __init__(
compression_type=CompressionTypes.AUTO,
header=None,
footer=None,
*,
max_records_per_shard=None,
max_bytes_per_shard=None,
skip_if_empty=False):
r"""Initialize a :class:`WriteToText` transform.
@@ -830,6 +846,14 @@ def __init__(
footer (str): String to write at the end of file as a footer.
If not :data:`None` and **append_trailing_newlines** is set, ``\n`` will
be added.
max_records_per_shard: Maximum number of records to write to any
individual shard.
max_bytes_per_shard: Target maximum number of bytes to write to any
individual shard. This limit may be exceeded slightly: a new shard is
only created once the limit is hit, so the remainder of the record
that crossed it, a subsequent newline, and a footer may push the
actual shard size past this value. The limit applies to the
uncompressed, not the compressed, size of the shard.
skip_if_empty: Don't write any shards if the PCollection is empty.
"""

@@ -843,7 +867,9 @@ def __init__(
compression_type,
header,
footer,
skip_if_empty)
max_records_per_shard=max_records_per_shard,
max_bytes_per_shard=max_bytes_per_shard,
skip_if_empty=skip_if_empty)

def expand(self, pcoll):
return pcoll | Write(self._sink)
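Putting the new keyword arguments together, a minimal pipeline sketch (the output path and data are placeholders; the two limits may be combined, and whichever is hit first triggers a new shard):

import apache_beam as beam
from apache_beam.io import WriteToText

with beam.Pipeline() as p:
    (p
     | beam.Create([str(i) for i in range(1000)])
     | WriteToText(
         '/tmp/counts',
         max_records_per_shard=100,       # at most 100 records per shard
         max_bytes_per_shard=1_000_000))  # ~1 MB of uncompressed data each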
42 changes: 42 additions & 0 deletions sdks/python/apache_beam/io/textio_test.py
@@ -1668,6 +1668,48 @@ def test_write_empty_skipped(self):
outputs = list(glob.glob(self.path + '*'))
self.assertEqual(outputs, [])

def test_write_max_records_per_shard(self):
records_per_shard = 13
lines = [str(i).encode('utf-8') for i in range(100)]
with TestPipeline() as p:
# pylint: disable=expression-not-assigned
p | beam.core.Create(lines) | WriteToText(
self.path, max_records_per_shard=records_per_shard)

read_result = []
for file_name in glob.glob(self.path + '*'):
with open(file_name, 'rb') as f:
shard_lines = list(f.read().splitlines())
self.assertLessEqual(len(shard_lines), records_per_shard)
read_result.extend(shard_lines)
self.assertEqual(sorted(read_result), sorted(lines))

def test_write_max_bytes_per_shard(self):
bytes_per_shard = 300
max_len = 100
lines = [b'x' * i for i in range(max_len)]
header = b'a' * 20
footer = b'b' * 30
with TestPipeline() as p:
# pylint: disable=expression-not-assigned
p | beam.core.Create(lines) | WriteToText(
self.path,
header=header,
footer=footer,
max_bytes_per_shard=bytes_per_shard)

read_result = []
for file_name in glob.glob(self.path + '*'):
with open(file_name, 'rb') as f:
contents = f.read()
self.assertLessEqual(
len(contents), bytes_per_shard + max_len + len(footer) + 2)
shard_lines = list(contents.splitlines())
self.assertEqual(shard_lines[0], header)
self.assertEqual(shard_lines[-1], footer)
read_result.extend(shard_lines[1:-1])
self.assertEqual(sorted(read_result), sorted(lines))
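The slack term in the assertion above follows from the docstring's caveat: a shard rotates only after the record that crossed the limit has been written in full, so the overage is bounded by one record plus the footer and trailing newlines. Worked out for this test's constants (reading the +2 as two newlines is my interpretation of the bound, not stated in the test):

bytes_per_shard = 300
max_len = 100    # longest record the test writes
footer_len = 30  # len(b'b' * 30)
newlines = 2     # one after the crossing record, one after the footer
assert bytes_per_shard + max_len + footer_len + newlines == 432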


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
