Skip to content

Commit

Permalink
Fix tests to use TestPipeline() for assertions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681462683
  • Loading branch information
zcharles8 authored and copybara-github committed Oct 2, 2024
1 parent b656a7a commit e3ff114
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions dataset_grouper/count_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for count_utils.py."""

from absl.testing import absltest
import apache_beam as beam
# pylint: disable=g-importing-member
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
# pylint: enable=g-importing-member
from dataset_grouper import count_utils


Expand Down Expand Up @@ -46,52 +48,49 @@ def test_merge_client_records(self):
],
),
]
with beam.Pipeline() as root:
keyed_counts = root | beam.Create(test_counts)
merged_counts = keyed_counts | beam.ParDo(count_utils.MergeGroupCounts())

expected_counts = [
(b'a', count_utils.GroupCount(6, 12, 17)),
(b'b', count_utils.GroupCount(5, 3, 1)),
(b'c', count_utils.GroupCount(7, 1, 2)),
]
assert_that(merged_counts, equal_to(expected_counts))
with TestPipeline() as p:
keyed_counts = p | beam.Create(test_counts)
merged_counts = keyed_counts | beam.ParDo(count_utils.MergeGroupCounts())
assert_that(merged_counts, equal_to(expected_counts))

# Test: applying beam.ParDo(count_utils.FormatGroupCount()) to (key, GroupCount)
# pairs is expected to yield one comma-joined string per pair, e.g.
# (b'a', GroupCount(1, 2, 3)) -> 'a,1,2,3'.
# NOTE(review): this listing is a diff rendered without +/- markers. The
# `with beam.Pipeline() as root:` block and the assert_that call that follows
# expected_result appear to be the removed pre-change lines, and the
# `with TestPipeline() as p:` block their replacement -- confirm against the
# commit before treating this as a single runnable method body.
def test_format_group_counts(self):
test_records = [
(b'a', count_utils.GroupCount(1, 2, 3)),
(b'b', count_utils.GroupCount(2, 3, 5)),
(b'c', count_utils.GroupCount(3, 7, 10)),
]
with beam.Pipeline() as root:
keyed_records = root | beam.Create(test_records)
actual_result = keyed_records | beam.ParDo(count_utils.FormatGroupCount())

expected_result = [
'a,1,2,3',
'b,2,3,5',
'c,3,7,10',
]
assert_that(actual_result, equal_to(expected_result))
# The assertion below is attached while the TestPipeline context is still open.
with TestPipeline() as p:
keyed_records = p | beam.Create(test_records)
actual_result = keyed_records | beam.ParDo(count_utils.FormatGroupCount())
assert_that(actual_result, equal_to(expected_result))

# Test: same records as the default-delimiter case, but delimiter='+' is passed
# as a keyword argument through beam.ParDo to count_utils.FormatGroupCount();
# the expected strings use '+' between fields, e.g. 'a+1+2+3'.
# NOTE(review): this listing is a diff rendered without +/- markers. The
# `with beam.Pipeline() as root:` block and the assert_that call that follows
# expected_result appear to be the removed pre-change lines, and the
# `with TestPipeline() as p:` block their replacement -- confirm against the
# commit before treating this as a single runnable method body.
def test_format_group_counts_with_delimiter(self):
test_records = [
(b'a', count_utils.GroupCount(1, 2, 3)),
(b'b', count_utils.GroupCount(2, 3, 5)),
(b'c', count_utils.GroupCount(3, 7, 10)),
]
with beam.Pipeline() as root:
keyed_records = root | beam.Create(test_records)
actual_result = keyed_records | beam.ParDo(
count_utils.FormatGroupCount(), delimiter='+'
)

expected_result = [
'a+1+2+3',
'b+2+3+5',
'c+3+7+10',
]
assert_that(actual_result, equal_to(expected_result))
# The assertion below is attached while the TestPipeline context is still open.
with TestPipeline() as p:
keyed_records = p | beam.Create(test_records)
actual_result = keyed_records | beam.ParDo(
count_utils.FormatGroupCount(), delimiter='+'
)
assert_that(actual_result, equal_to(expected_result))


if __name__ == '__main__':
Expand Down

0 comments on commit e3ff114

Please sign in to comment.