From b1f2420c15749e3ebeeb27460a74147a99b785c8 Mon Sep 17 00:00:00 2001 From: Zachary Charles Date: Wed, 2 Oct 2024 08:58:39 -0700 Subject: [PATCH] Fix tests to use TestPipeline() for assertions. PiperOrigin-RevId: 681472544 --- dataset_grouper/count_utils_test.py | 35 ++++++++++++++--------------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/dataset_grouper/count_utils_test.py b/dataset_grouper/count_utils_test.py index 592d793..f44f87f 100644 --- a/dataset_grouper/count_utils_test.py +++ b/dataset_grouper/count_utils_test.py @@ -11,12 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for count_utils.py.""" from absl.testing import absltest import apache_beam as beam +# pylint: disable=g-importing-member +from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +# pylint: enable=g-importing-member from dataset_grouper import count_utils @@ -46,16 +48,15 @@ def test_merge_client_records(self): ], ), ] - with beam.Pipeline() as root: - keyed_counts = root | beam.Create(test_counts) - merged_counts = keyed_counts | beam.ParDo(count_utils.MergeGroupCounts()) - expected_counts = [ (b'a', count_utils.GroupCount(6, 12, 17)), (b'b', count_utils.GroupCount(5, 3, 1)), (b'c', count_utils.GroupCount(7, 1, 2)), ] - assert_that(merged_counts, equal_to(expected_counts)) + with TestPipeline() as p: + keyed_counts = p | beam.Create(test_counts) + merged_counts = keyed_counts | beam.ParDo(count_utils.MergeGroupCounts()) + assert_that(merged_counts, equal_to(expected_counts)) def test_format_group_counts(self): test_records = [ @@ -63,16 +64,15 @@ def test_format_group_counts(self): (b'b', count_utils.GroupCount(2, 3, 5)), (b'c', count_utils.GroupCount(3, 7, 10)), ] - with beam.Pipeline() as root: - keyed_records = root | beam.Create(test_records) - actual_result = keyed_records | beam.ParDo(count_utils.FormatGroupCount()) - expected_result = [ 'a,1,2,3', 'b,2,3,5', 'c,3,7,10', ] - assert_that(actual_result, equal_to(expected_result)) + with TestPipeline() as p: + keyed_records = p | beam.Create(test_records) + actual_result = keyed_records | beam.ParDo(count_utils.FormatGroupCount()) + assert_that(actual_result, equal_to(expected_result)) def test_format_group_counts_with_delimiter(self): test_records = [ @@ -80,18 +80,17 @@ def test_format_group_counts_with_delimiter(self): (b'b', count_utils.GroupCount(2, 3, 5)), (b'c', count_utils.GroupCount(3, 7, 10)), ] - with beam.Pipeline() as root: - keyed_records = root | beam.Create(test_records) - actual_result = keyed_records | beam.ParDo( - count_utils.FormatGroupCount(), delimiter='+' - ) - expected_result = [ 'a+1+2+3', 'b+2+3+5', 'c+3+7+10', ] - assert_that(actual_result, equal_to(expected_result)) + with TestPipeline() as p: + keyed_records = p | beam.Create(test_records) + actual_result = keyed_records | beam.ParDo( + count_utils.FormatGroupCount(), delimiter='+' + ) + assert_that(actual_result, equal_to(expected_result)) if __name__ == '__main__':