diff --git a/CHANGES.md b/CHANGES.md index 206ac3ba11ad..8538432a2ccc 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -107,8 +107,12 @@ * Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). * Debezium IO upgraded to 3.1.1 requires Java 17 (Java) ([#34747](https://github.com/apache/beam/issues/34747)). * Add support for streaming writes in IOBase (Python) +* Add IT test for streaming writes for IOBase (Python) * Implement support for streaming writes in FileBasedSink (Python) +* Expose support for streaming writes in AvroIO (Python) +* Expose support for streaming writes in ParquetIO (Python) * Expose support for streaming writes in TextIO (Python) +* Expose support for streaming writes in TFRecordsIO (Python) ## New Features / Improvements diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py index 553b6c741f3d..da904bf6fb55 100644 --- a/sdks/python/apache_beam/io/avroio.py +++ b/sdks/python/apache_beam/io/avroio.py @@ -354,8 +354,7 @@ def split_points_unclaimed(stop_position): while range_tracker.try_claim(next_block_start): block = next(blocks) next_block_start = block.offset + block.size - for record in block: - yield record + yield from block _create_avro_source = _FastAvroSource @@ -375,7 +374,8 @@ def __init__( num_shards=0, shard_name_template=None, mime_type='application/x-avro', - use_fastavro=True): + use_fastavro=True, + triggering_frequency=None): """Initialize a WriteToAvro transform. Args: @@ -393,17 +393,30 @@ def __init__( Constraining the number of shards is likely to reduce the performance of a pipeline. Setting this value is not recommended unless you require a specific number of output files. + In streaming if not set, the service will write a file per bundle. shard_name_template: A template string containing placeholders for - the shard number and shard count. When constructing a filename for a - particular shard number, the upper-case letters 'S' and 'N' are - replaced with the 0-padded shard number and shard count respectively. - This argument can be '' in which case it behaves as if num_shards was - set to 1 and only one file will be generated. The default pattern used - is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template. + the shard number and shard count. Currently only ``''``, + ``'-SSSSS-of-NNNNN'``, ``'-W-SSSSS-of-NNNNN'`` and + ``'-V-SSSSS-of-NNNNN'`` are patterns accepted by the service. + When constructing a filename for a particular shard number, the + upper-case letters ``S`` and ``N`` are replaced with the ``0``-padded + shard number and shard count respectively. This argument can be ``''`` + in which case it behaves as if num_shards was set to 1 and only one file + will be generated. The default pattern used is ``'-SSSSS-of-NNNNN'`` for + bounded PCollections and for ``'-W-SSSSS-of-NNNNN'`` unbounded + PCollections. + W is used for windowed shard naming and is replaced with + ``[window.start, window.end)`` + V is used for windowed shard naming and is replaced with + ``[window.start.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S"), + window.end.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S")`` mime_type: The MIME type to use for the produced files, if the filesystem supports specifying MIME types. use_fastavro (bool): This flag is left for API backwards compatibility and no longer has an effect. Do not use. + triggering_frequency: (int) Every triggering_frequency duration, a window + will be triggered and all bundles in the window will be written. 
+ If set it overrides user windowing. Mandatory for GlobalWindow. Returns: A WriteToAvro transform usable for writing. @@ -411,7 +424,7 @@ def __init__( self._schema = schema self._sink_provider = lambda avro_schema: _create_avro_sink( file_path_prefix, avro_schema, codec, file_name_suffix, num_shards, - shard_name_template, mime_type) + shard_name_template, mime_type, triggering_frequency) def expand(self, pcoll): if self._schema: @@ -428,6 +441,15 @@ def expand(self, pcoll): records = pcoll | beam.Map( beam_row_to_avro_dict(avro_schema, beam_schema)) self._sink = self._sink_provider(avro_schema) + if (not pcoll.is_bounded and self._sink.shard_name_template + == filebasedsink.DEFAULT_SHARD_NAME_TEMPLATE): + self._sink.shard_name_template = ( + filebasedsink.DEFAULT_WINDOW_SHARD_NAME_TEMPLATE) + self._sink.shard_name_format = self._sink._template_to_format( + self._sink.shard_name_template) + self._sink.shard_name_glob_format = self._sink._template_to_glob_format( + self._sink.shard_name_template) + return records | beam.io.iobase.Write(self._sink) def display_data(self): @@ -441,7 +463,8 @@ def _create_avro_sink( file_name_suffix, num_shards, shard_name_template, - mime_type): + mime_type, + triggering_frequency=60): if "class 'avro.schema" in str(type(schema)): raise ValueError( 'You are using Avro IO with fastavro (default with Beam on ' @@ -454,7 +477,8 @@ def _create_avro_sink( file_name_suffix, num_shards, shard_name_template, - mime_type) + mime_type, + triggering_frequency) class _BaseAvroSink(filebasedsink.FileBasedSink): @@ -467,7 +491,8 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type): + mime_type, + triggering_frequency): super().__init__( file_path_prefix, file_name_suffix=file_name_suffix, @@ -477,7 +502,8 @@ def __init__( mime_type=mime_type, # Compression happens at the block level using the supplied codec, and # not at the file level. 
- compression_type=CompressionTypes.UNCOMPRESSED) + compression_type=CompressionTypes.UNCOMPRESSED, + triggering_frequency=triggering_frequency) self._schema = schema self._codec = codec @@ -498,7 +524,8 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type): + mime_type, + triggering_frequency): super().__init__( file_path_prefix, schema, @@ -506,7 +533,8 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type) + mime_type, + triggering_frequency) self.file_handle = None def open(self, temp_path): diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index 6dd9e620c665..6669b6fb8abf 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -16,11 +16,15 @@ # # pytype: skip-file +import glob import json import logging import math import os +import pytz import pytest +import re +import shutil import tempfile import unittest from typing import List, Any @@ -47,14 +51,17 @@ from apache_beam.io.filesystems import FileSystems from apache_beam.options.pipeline_options import StandardOptions from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher from apache_beam.transforms.sql import SqlTransform from apache_beam.transforms.userstate import CombiningValueStateSpec +from apache_beam.transforms.util import LogElements from apache_beam.utils.timestamp import Timestamp from apache_beam.typehints import schemas +from datetime import datetime # Import snappy optionally; some tests will be skipped when import fails. try: @@ -673,6 +680,273 @@ def _write_data( return f.name +class GenerateEvent(beam.PTransform): + @staticmethod + def sample_data(): + return GenerateEvent() + + def expand(self, input): + elemlist = [{'age': 10}, {'age': 20}, {'age': 30}] + elem = elemlist + return ( + input + | TestStream().add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 1, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 2, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 3, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 4, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 6, + 0, tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 7, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 8, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 9, 0, + tzinfo=pytz.UTC).timestamp()). 
+ advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 11, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 12, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 13, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 14, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 16, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 17, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 18, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 19, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).advance_watermark_to( + datetime( + 2021, 3, 1, 0, 0, 25, 0, tzinfo=pytz.UTC). + timestamp()).advance_watermark_to_infinity()) + + +class WriteStreamingTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + if os.path.exists(self.tempdir): + shutil.rmtree(self.tempdir) + + def test_write_streaming_2_shards_default_shard_name_template( + self, num_shards=2): + with TestPipeline() as p: + output = ( + p + | GenerateEvent.sample_data() + | 'User windowing' >> beam.transforms.core.WindowInto( + beam.transforms.window.FixedWindows(60), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. 
+ DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + #AvroIO + avroschema = { + 'name': 'dummy', # your supposed to be file name with .avro extension + 'type': 'record', # type of avro serilazation, there are more (see + # above docs) + 'fields': [ # this defines actual keys & their types + {'name': 'age', 'type': 'int'}, + ], + } + output2 = output | 'WriteToAvro' >> beam.io.WriteToAvro( + file_path_prefix=self.tempdir + "/ouput_WriteToAvro", + file_name_suffix=".avro", + num_shards=num_shards, + schema=avroschema) + _ = output2 | 'LogElements after WriteToAvro' >> LogElements( + prefix='after WriteToAvro ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToAvro-[1614556800.0, 1614556805.0)-00000-of-00002.avro + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.avro$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToAvro*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template( + self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #AvroIO + avroschema = { + 'name': 'dummy', # your supposed to be file name with .avro extension + 'type': 'record', # type of avro serilazation + 'fields': [ # this defines actual keys & their types + {'name': 'age', 'type': 'int'}, + ], + } + output2 = output | 'WriteToAvro' >> beam.io.WriteToAvro( + file_path_prefix=self.tempdir + "/ouput_WriteToAvro", + file_name_suffix=".avro", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=60, + schema=avroschema) + _ = output2 | 'LogElements after WriteToAvro' >> LogElements( + prefix='after WriteToAvro ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToAvro-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.avro + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.avro$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToAvro*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template_5s_window( + self, + num_shards=2, + shard_name_template='-V-SSSSS-of-NNNNN', + triggering_frequency=5): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #AvroIO + avroschema = { + 'name': 'dummy', # your supposed to be file name with .avro extension + 'type': 'record', # type of avro serilazation 
+ 'fields': [ # this defines actual keys & their types + {'name': 'age', 'type': 'int'}, + ], + } + output2 = output | 'WriteToAvro' >> beam.io.WriteToAvro( + file_path_prefix=self.tempdir + "/ouput_WriteToAvro", + file_name_suffix=".txt", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=triggering_frequency, + schema=avroschema) + _ = output2 | 'LogElements after WriteToAvro' >> LogElements( + prefix='after WriteToAvro ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToAvro-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.avro + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.txt$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToAvro*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + # for 5s window size, the input should be processed by 5 windows with + # 2 shards per window + self.assertEqual( + len(file_names), + 10, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/io/iobase_it_test.py b/sdks/python/apache_beam/io/iobase_it_test.py new file mode 100644 index 000000000000..acb44f4085bc --- /dev/null +++ b/sdks/python/apache_beam/io/iobase_it_test.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# pytype: skip-file + +import logging +import unittest +import uuid + +import apache_beam as beam +from apache_beam.io.textio import WriteToText +from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.transforms.window import FixedWindows + +# End-to-End tests for iobase +# Usage: +# cd sdks/python +# pip install build && python -m build --sdist +# DataflowRunner: +# python -m pytest -o log_cli=True -o log_level=Info \ +# apache_beam/io/iobase_it_test.py::IOBaseITTest \ +# --test-pipeline-options="--runner=TestDataflowRunner \ +# --project=apache-beam-testing --region=us-central1 \ +# --temp_location=gs://apache-beam-testing-temp/temp \ +# --sdk_location=dist/apache_beam-2.65.0.dev0.tar.gz" + + +class IOBaseITTest(unittest.TestCase): + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self.runner_name = type(self.test_pipeline.runner).__name__ + + def test_unbounded_pcoll_without_global_window(self): + # https://github.com/apache/beam/issues/25598 + + args = self.test_pipeline.get_full_options_as_args(streaming=True) + + topic = 'projects/pubsub-public-data/topics/taxirides-realtime' + unique_id = str(uuid.uuid4()) + output_file = f'gs://apache-beam-testing-integration-testing/iobase/test-{unique_id}' # pylint: disable=line-too-long + + p = beam.Pipeline(argv=args) + # Read from Pub/Sub with fixed windowing + lines = ( + p + | "ReadFromPubSub" >> beam.io.ReadFromPubSub(topic=topic) + | "WindowInto" >> beam.WindowInto(FixedWindows(10))) + + # Write to text file + _ = lines | 'WriteToText' >> WriteToText(output_file) + + result = p.run() + result.wait_until_finish(duration=60 * 1000) + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + unittest.main() diff --git a/sdks/python/apache_beam/io/parquetio.py b/sdks/python/apache_beam/io/parquetio.py index 48c51428c17d..fa8b56f916dc 100644 --- a/sdks/python/apache_beam/io/parquetio.py +++ b/sdks/python/apache_beam/io/parquetio.py @@ -48,6 +48,7 @@ from apache_beam.transforms import PTransform from apache_beam.transforms import window from apache_beam.typehints import schemas +from apache_beam.utils.windowed_value import WindowedValue try: import pyarrow as pa @@ -105,8 +106,10 @@ def __init__( self._buffer_size = record_batch_size self._record_batches = [] self._record_batches_byte_size = 0 + self._window = None - def process(self, row): + def process(self, row, w=DoFn.WindowParam, pane=DoFn.PaneInfoParam): + self._window = w if len(self._buffer[0]) >= self._buffer_size: self._flush_buffer() @@ -123,7 +126,17 @@ def finish_bundle(self): self._flush_buffer() if self._record_batches_byte_size > 0: table = self._create_table() - yield window.GlobalWindows.windowed_value_at_end_of_window(table) + if self._window is None or isinstance(self._window, window.GlobalWindow): + # bounded input + yield window.GlobalWindows.windowed_value_at_end_of_window(table) + else: + # unbounded input + yield WindowedValue( + table, + timestamp=self._window. + end, #or it could be max of timestamp of the rows processed + windows=[self._window] # TODO(pabloem) HOW DO WE GET THE PANE + ) def display_data(self): res = super().display_data() @@ -476,7 +489,9 @@ def __init__( file_name_suffix='', num_shards=0, shard_name_template=None, - mime_type='application/x-parquet'): + mime_type='application/x-parquet', + triggering_frequency=None, + ): """Initialize a WriteToParquet transform. 
Writes parquet files from a :class:`~apache_beam.pvalue.PCollection` of @@ -540,14 +555,26 @@ def __init__( the performance of a pipeline. Setting this value is not recommended unless you require a specific number of output files. shard_name_template: A template string containing placeholders for - the shard number and shard count. When constructing a filename for a - particular shard number, the upper-case letters 'S' and 'N' are - replaced with the 0-padded shard number and shard count respectively. - This argument can be '' in which case it behaves as if num_shards was - set to 1 and only one file will be generated. The default pattern used - is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template. + the shard number and shard count. Currently only ``''``, + ``'-SSSSS-of-NNNNN'``, ``'-W-SSSSS-of-NNNNN'`` and + ``'-V-SSSSS-of-NNNNN'`` are patterns accepted by the service. + When constructing a filename for a particular shard number, the + upper-case letters ``S`` and ``N`` are replaced with the ``0``-padded + shard number and shard count respectively. This argument can be ``''`` + in which case it behaves as if num_shards was set to 1 and only one file + will be generated. The default pattern used is ``'-SSSSS-of-NNNNN'`` for + bounded PCollections and for ``'-W-SSSSS-of-NNNNN'`` unbounded + PCollections. + W is used for windowed shard naming and is replaced with + ``[window.start, window.end)`` + V is used for windowed shard naming and is replaced with + ``[window.start.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S"), + window.end.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S")`` mime_type: The MIME type to use for the produced files, if the filesystem supports specifying MIME types. + triggering_frequency: (int) Every triggering_frequency duration, a window + will be triggered and all bundles in the window will be written. + If set it overrides user windowing. Mandatory for GlobalWindow. Returns: A WriteToParquet transform usable for writing. @@ -567,10 +594,20 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type + mime_type, + triggering_frequency ) def expand(self, pcoll): + if (not pcoll.is_bounded and self._sink.shard_name_template + == filebasedsink.DEFAULT_SHARD_NAME_TEMPLATE): + self._sink.shard_name_template = ( + filebasedsink.DEFAULT_WINDOW_SHARD_NAME_TEMPLATE) + self._sink.shard_name_format = self._sink._template_to_format( + self._sink.shard_name_template) + self._sink.shard_name_glob_format = self._sink._template_to_glob_format( + self._sink.shard_name_template) + if self._schema is None: try: beam_schema = schemas.schema_from_element_type(pcoll.element_type) @@ -583,7 +620,11 @@ def expand(self, pcoll): else: convert_fn = _RowDictionariesToArrowTable( self._schema, self._row_group_buffer_size, self._record_batch_size) - return pcoll | ParDo(convert_fn) | Write(self._sink) + if pcoll.is_bounded: + return pcoll | ParDo(convert_fn) | Write(self._sink) + else: + self._sink.convert_fn = convert_fn + return pcoll | Write(self._sink) def display_data(self): return { @@ -610,7 +651,7 @@ def __init__( num_shards=0, shard_name_template=None, mime_type='application/x-parquet', - ): + triggering_frequency=None): """Initialize a WriteToParquetBatched transform. Writes parquet files from a :class:`~apache_beam.pvalue.PCollection` of @@ -668,11 +709,21 @@ def __init__( the shard number and shard count. 
When constructing a filename for a particular shard number, the upper-case letters 'S' and 'N' are replaced with the 0-padded shard number and shard count respectively. + W is used for windowed shard naming and is replaced with + ``[window.start, window.end)`` + V is used for windowed shard naming and is replaced with + ``[window.start.to_utc_datetime().isoformat(), + window.end.to_utc_datetime().isoformat()`` This argument can be '' in which case it behaves as if num_shards was - set to 1 and only one file will be generated. The default pattern used - is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template. + set to 1 and only one file will be generated. + The default pattern used is '-SSSSS-of-NNNNN' if None is passed as the + shard_name_template and the PCollection is bounded. + The default pattern used is '-W-SSSSS-of-NNNNN' if None is passed as the + shard_name_template and the PCollection is unbounded. mime_type: The MIME type to use for the produced files, if the filesystem supports specifying MIME types. + triggering_frequency: (int) Every triggering_frequency duration, a window + will be triggered and all bundles in the window will be written. Returns: A WriteToParquetBatched transform usable for writing. @@ -688,10 +739,19 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type + mime_type, + triggering_frequency ) def expand(self, pcoll): + if (not pcoll.is_bounded and self._sink.shard_name_template + == filebasedsink.DEFAULT_SHARD_NAME_TEMPLATE): + self._sink.shard_name_template = ( + filebasedsink.DEFAULT_WINDOW_SHARD_NAME_TEMPLATE) + self._sink.shard_name_format = self._sink._template_to_format( + self._sink.shard_name_template) + self._sink.shard_name_glob_format = self._sink._template_to_glob_format( + self._sink.shard_name_template) return pcoll | Write(self._sink) def display_data(self): @@ -707,7 +767,8 @@ def _create_parquet_sink( file_name_suffix, num_shards, shard_name_template, - mime_type): + mime_type, + triggering_frequency=60): return \ _ParquetSink( file_path_prefix, @@ -718,7 +779,8 @@ def _create_parquet_sink( file_name_suffix, num_shards, shard_name_template, - mime_type + mime_type, + triggering_frequency ) @@ -734,7 +796,8 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - mime_type): + mime_type, + triggering_frequency): super().__init__( file_path_prefix, file_name_suffix=file_name_suffix, @@ -744,7 +807,8 @@ def __init__( mime_type=mime_type, # Compression happens at the block level using the supplied codec, and # not at the file level. 
- compression_type=CompressionTypes.UNCOMPRESSED) + compression_type=CompressionTypes.UNCOMPRESSED, + triggering_frequency=triggering_frequency) self._schema = schema self._codec = codec if ARROW_MAJOR_VERSION == 1 and self._codec.lower() == "lz4": diff --git a/sdks/python/apache_beam/io/parquetio_it_test.py b/sdks/python/apache_beam/io/parquetio_it_test.py index 052b54f3ebfb..b06e7268fec4 100644 --- a/sdks/python/apache_beam/io/parquetio_it_test.py +++ b/sdks/python/apache_beam/io/parquetio_it_test.py @@ -19,10 +19,14 @@ import logging import string import unittest +import uuid from collections import Counter +from datetime import datetime import pytest +import pytz +import apache_beam as beam from apache_beam import Create from apache_beam import DoFn from apache_beam import FlatMap @@ -37,6 +41,7 @@ from apache_beam.testing.util import BeamAssertException from apache_beam.transforms import CombineGlobally from apache_beam.transforms.combiners import Count +from apache_beam.transforms.periodicsequence import PeriodicImpulse try: import pyarrow as pa @@ -142,6 +147,42 @@ def get_int(self): return i +@unittest.skipIf(pa is None, "PyArrow is not installed.") +class WriteStreamingIT(unittest.TestCase): + def setUp(self): + self.test_pipeline = TestPipeline(is_integration_test=True) + self.runner_name = type(self.test_pipeline.runner).__name__ + super().setUp() + + def test_write_streaming_2_shards_default_shard_name_template( + self, num_shards=2): + + args = self.test_pipeline.get_full_options_as_args(streaming=True) + + unique_id = str(uuid.uuid4()) + output_file = f'gs://apache-beam-testing-integration-testing/iobase/test-{unique_id}' # pylint: disable=line-too-long + p = beam.Pipeline(argv=args) + pyschema = pa.schema([('age', pa.int64())]) + + _ = ( + p + | "generate impulse" >> PeriodicImpulse( + start_timestamp=datetime(2021, 3, 1, 0, 0, 1, 0, + tzinfo=pytz.UTC).timestamp(), + stop_timestamp=datetime(2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp(), + fire_interval=1) + | "generate data" >> beam.Map(lambda t: {'age': t * 10}) + | 'WriteToParquet' >> beam.io.WriteToParquet( + file_path_prefix=output_file, + file_name_suffix=".parquet", + num_shards=num_shards, + triggering_frequency=60, + schema=pyschema)) + result = p.run() + result.wait_until_finish(duration=600 * 1000) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/io/parquetio_test.py b/sdks/python/apache_beam/io/parquetio_test.py index c602f4cc801b..9371705a1fa3 100644 --- a/sdks/python/apache_beam/io/parquetio_test.py +++ b/sdks/python/apache_beam/io/parquetio_test.py @@ -16,17 +16,21 @@ # # pytype: skip-file +import glob import json import logging import os +import re import shutil import tempfile import unittest +from datetime import datetime from tempfile import TemporaryDirectory import hamcrest as hc import pandas import pytest +import pytz from parameterized import param from parameterized import parameterized @@ -45,10 +49,12 @@ from apache_beam.io.parquetio import _create_parquet_sink from apache_beam.io.parquetio import _create_parquet_source from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display_test import DisplayDataItemMatcher +from apache_beam.transforms.util import LogElements try: import 
pyarrow as pa @@ -655,6 +661,290 @@ def test_read_all_from_parquet_with_filename(self): equal_to(result)) +class GenerateEvent(beam.PTransform): + @staticmethod + def sample_data(): + return GenerateEvent() + + def expand(self, input): + elemlist = [{'age': 10}, {'age': 20}, {'age': 30}] + elem = elemlist + return ( + input + | TestStream().add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 1, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 2, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 3, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 4, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 6, + 0, tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 7, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 8, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 9, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 11, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 12, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 13, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 14, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 16, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 17, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 18, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 19, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).advance_watermark_to( + datetime( + 2021, 3, 1, 0, 0, 25, 0, tzinfo=pytz.UTC). 
+ timestamp()).advance_watermark_to_infinity()) + + +class WriteStreamingTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + if os.path.exists(self.tempdir): + shutil.rmtree(self.tempdir) + + def test_write_streaming_2_shards_default_shard_name_template( + self, num_shards=2): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #ParquetIO + pyschema = pa.schema([('age', pa.int64())]) + output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet( + file_path_prefix=self.tempdir + "/ouput_WriteToParquet", + file_name_suffix=".parquet", + num_shards=num_shards, + triggering_frequency=60, + schema=pyschema) + _ = output2 | 'LogElements after WriteToParquet' >> LogElements( + prefix='after WriteToParquet ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToParquet-[1614556800.0, 1614556805.0)-00000-of-00002.parquet + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.parquet$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template( + self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #ParquetIO + pyschema = pa.schema([('age', pa.int64())]) + output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet( + file_path_prefix=self.tempdir + "/ouput_WriteToParquet", + file_name_suffix=".parquet", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=60, + schema=pyschema) + _ = output2 | 'LogElements after WriteToParquet' >> LogElements( + prefix='after WriteToParquet ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToParquet-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.parquet + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.parquet$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template_5s_window( + self, + num_shards=2, + shard_name_template='-V-SSSSS-of-NNNNN', + triggering_frequency=5): + with TestPipeline() as p: + output = (p | GenerateEvent.sample_data()) + #ParquetIO + pyschema = pa.schema([('age', pa.int64())]) + output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet( + 
file_path_prefix=self.tempdir + "/ouput_WriteToParquet", + file_name_suffix=".parquet", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=triggering_frequency, + schema=pyschema) + _ = output2 | 'LogElements after WriteToParquet' >> LogElements( + prefix='after WriteToParquet ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToParquet-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.parquet + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.parquet$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + # for 5s window size, the input should be processed by 5 windows with + # 2 shards per window + self.assertEqual( + len(file_names), + 10, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_undef_shards_default_shard_name_template_windowed_pcoll( # pylint: disable=line-too-long + self): + with TestPipeline() as p: + output = ( + p | GenerateEvent.sample_data() + | 'User windowing' >> beam.transforms.core.WindowInto( + beam.transforms.window.FixedWindows(10), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. + DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0))) + #ParquetIO + pyschema = pa.schema([('age', pa.int64())]) + output2 = output | 'WriteToParquet' >> beam.io.WriteToParquet( + file_path_prefix=self.tempdir + "/ouput_WriteToParquet", + file_name_suffix=".parquet", + num_shards=0, + schema=pyschema) + _ = output2 | 'LogElements after WriteToParquet' >> LogElements( + prefix='after WriteToParquet ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToParquet-[1614556800.0, 1614556805.0)-00000-of-00002.parquet + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.parquet$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToParquet*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertGreaterEqual( + len(file_names), + 1 * 3, #25s of data covered by 3 10s windows + "expected %d files, but got: %d" % (1 * 3, len(file_names))) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/io/tfrecordio.py b/sdks/python/apache_beam/io/tfrecordio.py index b911c64a1348..e27ea5070b06 100644 --- a/sdks/python/apache_beam/io/tfrecordio.py +++ b/sdks/python/apache_beam/io/tfrecordio.py @@ -290,7 +290,8 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - compression_type): + compression_type, + triggering_frequency=60): """Initialize a TFRecordSink. 
See WriteToTFRecord for details.""" super().__init__( @@ -300,7 +301,8 @@ def __init__( num_shards=num_shards, shard_name_template=shard_name_template, mime_type='application/octet-stream', - compression_type=compression_type) + compression_type=compression_type, + triggering_frequency=triggering_frequency) def write_encoded_record(self, file_handle, value): _TFRecordUtil.write_record(file_handle, value) @@ -315,7 +317,8 @@ def __init__( file_name_suffix='', num_shards=0, shard_name_template=None, - compression_type=CompressionTypes.AUTO): + compression_type=CompressionTypes.AUTO, + triggering_frequency=None): """Initialize WriteToTFRecord transform. Args: @@ -326,16 +329,29 @@ def __init__( file_name_suffix: Suffix for the files written. num_shards: The number of files (shards) used for output. If not set, the default value will be used. + In streaming if not set, the service will write a file per bundle. shard_name_template: A template string containing placeholders for - the shard number and shard count. When constructing a filename for a - particular shard number, the upper-case letters 'S' and 'N' are - replaced with the 0-padded shard number and shard count respectively. - This argument can be '' in which case it behaves as if num_shards was - set to 1 and only one file will be generated. The default pattern used - is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template. + the shard number and shard count. Currently only ``''``, + ``'-SSSSS-of-NNNNN'``, ``'-W-SSSSS-of-NNNNN'`` and + ``'-V-SSSSS-of-NNNNN'`` are patterns accepted by the service. + When constructing a filename for a particular shard number, the + upper-case letters ``S`` and ``N`` are replaced with the ``0``-padded + shard number and shard count respectively. This argument can be ``''`` + in which case it behaves as if num_shards was set to 1 and only one file + will be generated. The default pattern used is ``'-SSSSS-of-NNNNN'`` for + bounded PCollections and for ``'-W-SSSSS-of-NNNNN'`` unbounded + PCollections. + W is used for windowed shard naming and is replaced with + ``[window.start, window.end)`` + V is used for windowed shard naming and is replaced with + ``[window.start.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S"), + window.end.to_utc_datetime().strftime("%Y-%m-%dT%H-%M-%S")`` compression_type: Used to handle compressed output files. Typical value is CompressionTypes.AUTO, in which case the file_path's extension will be used to detect the compression. + triggering_frequency: (int) Every triggering_frequency duration, a window + will be triggered and all bundles in the window will be written. + If set it overrides user windowing. Mandatory for GlobalWindow. Returns: A WriteToTFRecord transform object. 
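The docstring above describes the new streaming behaviour of WriteToTFRecord. As a minimal usage sketch (not part of the patch; the Pub/Sub topic, bucket path, and 60-second frequency are placeholder assumptions), an unbounded pipeline whose elements sit in the GlobalWindow must pass triggering_frequency and then picks up the '-W-SSSSS-of-NNNNN' shard template by default:

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

options = PipelineOptions(streaming=True)
with beam.Pipeline(options=options) as p:
    _ = (
        p
        # Placeholder topic; any unbounded source of bytes works here.
        | beam.io.ReadFromPubSub(topic='projects/my-project/topics/my-topic')
        # Elements stay in the GlobalWindow, so triggering_frequency is
        # mandatory: every 60 seconds the accumulated bundles are flushed and
        # files are named with the default '-W-SSSSS-of-NNNNN' template.
        | beam.io.WriteToTFRecord(
            file_path_prefix='gs://my-bucket/output/records',
            file_name_suffix='.tfrecord',
            num_shards=2,
            triggering_frequency=60))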
@@ -347,7 +363,17 @@ def __init__( file_name_suffix, num_shards, shard_name_template, - compression_type) + compression_type, + triggering_frequency) def expand(self, pcoll): + if (not pcoll.is_bounded and self._sink.shard_name_template + == filebasedsink.DEFAULT_SHARD_NAME_TEMPLATE): + self._sink.shard_name_template = ( + filebasedsink.DEFAULT_WINDOW_SHARD_NAME_TEMPLATE) + self._sink.shard_name_format = self._sink._template_to_format( + self._sink.shard_name_template) + self._sink.shard_name_glob_format = self._sink._template_to_glob_format( + self._sink.shard_name_template) + return pcoll | Write(self._sink) diff --git a/sdks/python/apache_beam/io/tfrecordio_test.py b/sdks/python/apache_beam/io/tfrecordio_test.py index a867c0212ad3..6522ade36d80 100644 --- a/sdks/python/apache_beam/io/tfrecordio_test.py +++ b/sdks/python/apache_beam/io/tfrecordio_test.py @@ -21,15 +21,20 @@ import glob import gzip import io +import json import logging import os import pickle import random import re +import shutil +import tempfile import unittest import zlib +from datetime import datetime import crcmod +import pytz import apache_beam as beam from apache_beam import Create @@ -41,9 +46,11 @@ from apache_beam.io.tfrecordio import _TFRecordSink from apache_beam.io.tfrecordio import _TFRecordUtil from apache_beam.testing.test_pipeline import TestPipeline +from apache_beam.testing.test_stream import TestStream from apache_beam.testing.test_utils import TempDir from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.transforms.util import LogElements try: import tensorflow.compat.v1 as tf # pylint: disable=import-error @@ -558,6 +565,258 @@ def test_end2end_read_write_read(self): assert_that(actual_data, equal_to(expected_data)) +class GenerateEvent(beam.PTransform): + @staticmethod + def sample_data(): + return GenerateEvent() + + def expand(self, input): + elemlist = [{'age': 10}, {'age': 20}, {'age': 30}] + elem = elemlist + return ( + input + | TestStream().add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 1, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 2, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 3, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 4, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 5, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 6, + 0, tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 7, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 8, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 9, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 10, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 11, 0, + tzinfo=pytz.UTC).timestamp()). 
+ add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 12, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 13, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 14, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 15, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 16, 0, + tzinfo=pytz.UTC).timestamp()). + add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 17, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 18, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 19, 0, + tzinfo=pytz.UTC).timestamp()). + advance_watermark_to( + datetime(2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).add_elements( + elements=elem, + event_timestamp=datetime( + 2021, 3, 1, 0, 0, 20, 0, + tzinfo=pytz.UTC).timestamp()).advance_watermark_to( + datetime( + 2021, 3, 1, 0, 0, 25, 0, tzinfo=pytz.UTC). + timestamp()).advance_watermark_to_infinity()) + + +class WriteStreamingTest(unittest.TestCase): + def setUp(self): + super().setUp() + self.tempdir = tempfile.mkdtemp() + + def tearDown(self): + if os.path.exists(self.tempdir): + shutil.rmtree(self.tempdir) + + def test_write_streaming_2_shards_default_shard_name_template( + self, num_shards=2): + with TestPipeline() as p: + output = ( + p + | GenerateEvent.sample_data() + | 'User windowing' >> beam.transforms.core.WindowInto( + beam.transforms.window.FixedWindows(60), + trigger=beam.transforms.trigger.AfterWatermark(), + accumulation_mode=beam.transforms.trigger.AccumulationMode. 
+ DISCARDING, + allowed_lateness=beam.utils.timestamp.Duration(seconds=0)) + | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8'))) + #TFrecordIO + output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord( + file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord", + file_name_suffix=".tfrecord", + num_shards=num_shards, + ) + _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements( + prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToTFRecord-[1614556800.0, 1614556805.0)-00000-of-00002.tfrecord + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P[\d\.]+), ' + r'(?P[\d\.]+|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.tfrecord$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToTFRecord*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template( + self, num_shards=2, shard_name_template='-V-SSSSS-of-NNNNN'): + with TestPipeline() as p: + output = ( + p + | GenerateEvent.sample_data() + | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8'))) + #TFrecordIO + output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord( + file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord", + file_name_suffix=".tfrecord", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=60, + ) + _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements( + prefix='after WriteToTFRecord ', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToTFRecord-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.tfrecord + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P\d{5})-of-(?P\d{5})\.tfrecord$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToTFRecord*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + self.assertEqual( + len(file_names), + num_shards, + "expected %d files, but got: %d" % (num_shards, len(file_names))) + + def test_write_streaming_2_shards_custom_shard_name_template_5s_window( + self, + num_shards=2, + shard_name_template='-V-SSSSS-of-NNNNN', + triggering_frequency=5): + with TestPipeline() as p: + output = ( + p + | GenerateEvent.sample_data() + | "encode" >> beam.Map(lambda s: json.dumps(s).encode('utf-8'))) + #TFrecordIO + output2 = output | 'WriteToTFRecord' >> beam.io.WriteToTFRecord( + file_path_prefix=self.tempdir + "/ouput_WriteToTFRecord", + file_name_suffix=".tfrecord", + shard_name_template=shard_name_template, + num_shards=num_shards, + triggering_frequency=triggering_frequency, + ) + _ = output2 | 'LogElements after WriteToTFRecord' >> LogElements( + prefix='after WriteToTFRecord 
', with_window=True, level=logging.INFO) + + # Regex to match the expected windowed file pattern + # Example: + # ouput_WriteToTFRecord-[2021-03-01T00-00-00, 2021-03-01T00-01-00)- + # 00000-of-00002.tfrecord + # It captures: window_interval, shard_num, total_shards + pattern_string = ( + r'.*-\[(?P<window_start>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}), ' + r'(?P<window_end>\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}|Infinity)\)-' + r'(?P<shard_num>\d{5})-of-(?P<total_shards>\d{5})\.tfrecord$') + pattern = re.compile(pattern_string) + file_names = [] + for file_name in glob.glob(self.tempdir + '/ouput_WriteToTFRecord*'): + match = pattern.match(file_name) + self.assertIsNotNone( + match, f"File name {file_name} did not match expected pattern.") + if match: + file_names.append(file_name) + print("Found files matching expected pattern:", file_names) + # for a 5s window size, the input should be processed by 5 windows with + # 2 shards per window + self.assertEqual( + len(file_names), + 10, + "expected %d files, but got: %d" % (10, len(file_names))) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()
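For reference, the windowed shard naming asserted by the unit tests above can be reproduced with a small standalone pipeline. This is an illustrative sketch only, assuming the patch is applied; the /tmp path, the single event timestamp, and the 60-second triggering frequency are arbitrary illustrative choices:

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_stream import TestStream

avro_schema = {
    'name': 'dummy',
    'type': 'record',
    'fields': [{'name': 'age', 'type': 'int'}],
}

options = PipelineOptions(streaming=True)
with beam.Pipeline(options=options) as p:
    _ = (
        p
        | TestStream()
            .add_elements(
                [{'age': 10}, {'age': 20}, {'age': 30}],
                event_timestamp=1614556801)  # 2021-03-01T00:00:01 UTC
            .advance_watermark_to_infinity()
        # With triggering_frequency=60 the GlobalWindow input is re-windowed
        # into 60s windows, and the '-V-SSSSS-of-NNNNN' template produces
        # names like
        # /tmp/output-[2021-03-01T00-00-00, 2021-03-01T00-01-00)-00000-of-00002.avro
        | beam.io.WriteToAvro(
            file_path_prefix='/tmp/output',
            file_name_suffix='.avro',
            num_shards=2,
            shard_name_template='-V-SSSSS-of-NNNNN',
            triggering_frequency=60,
            schema=avro_schema))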