Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions sdks/python/apache_beam/io/gcp/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,27 +69,43 @@ class PubsubMessage(object):
attributes: (dict) Key-value map of str to str, containing both user-defined
and service generated attributes (such as id_label and
timestamp_attribute). May be None.

message_id: (str) ID of the message, as generated by Pub/Sub.
publish_time: (google.protobuf.timestamp_pb2.Timestamp) Time at which the
message was published, as generated by Pub/Sub.

TODO(BEAM-7819): message_id and publish_time are not populated
on Dataflow runner
"""
def __init__(self, data, attributes, message_id=None, publish_time=None):
  """Initializes the message.

  Args:
    data: Message payload. May be None only if ``attributes`` is non-empty.
    attributes: Key-value map of message attributes. May be None or empty
      only if ``data`` is set.
    message_id: Optional message ID; stored as given.
    publish_time: Optional publish time; stored as given.

  Raises:
    ValueError: If ``data`` is None and ``attributes`` is None or empty.
  """
  if data is None and not attributes:
    # Interpolate explicitly: ValueError does not apply %-formatting to
    # extra positional arguments.
    raise ValueError(
        'Either data (%r) or attributes (%r) must be set.' %
        (data, attributes))
  self.data = data
  self.attributes = attributes
  self.message_id = message_id
  self.publish_time = publish_time

def __hash__(self):
  """Returns a hash of the message.

  The hash always covers ``data`` and ``attributes``. When ``publish_time``
  is set it also covers ``message_id`` and the seconds/nanos of
  ``publish_time``.
  """
  # attributes may legitimately be None (data-only message); hash it like an
  # empty mapping instead of failing on None.items().
  attributes = frozenset((self.attributes or {}).items())
  if self.publish_time is not None:
    return hash((self.data, attributes, self.message_id,
                 self.publish_time.seconds, self.publish_time.nanos))
  return hash((self.data, attributes))

def __eq__(self, other):
  """Returns True iff ``other`` is a PubsubMessage with equal fields.

  Compares ``data``, ``attributes``, ``message_id`` and ``publish_time``.
  """
  return isinstance(other, PubsubMessage) and (
      self.data == other.data and
      self.attributes == other.attributes and
      self.message_id == other.message_id and
      self.publish_time == other.publish_time)

def __ne__(self, other):
  """Returns the negation of ``self == other``."""
  # TODO(BEAM-5949): Needed for Python 2 compatibility.
  is_equal = self == other
  return not is_equal

def __repr__(self):
  """Returns a string showing data, attributes, message_id, publish_time."""
  return 'PubsubMessage(%s, %s, %s, %s)' % (
      self.data, self.attributes, self.message_id, self.publish_time)

@staticmethod
def _from_proto_str(proto_msg):
Expand All @@ -108,7 +124,7 @@ def _from_proto_str(proto_msg):
msg.ParseFromString(proto_msg)
# Convert ScalarMapContainer to dict.
attributes = dict((key, msg.attributes[key]) for key in msg.attributes)
return PubsubMessage(msg.data, attributes)
return PubsubMessage(msg.data, attributes, msg.message_id, msg.publish_time)

def _to_proto_str(self):
"""Get serialized form of ``PubsubMessage``.
Expand All @@ -125,6 +141,8 @@ def _to_proto_str(self):
msg.data = self.data
for key, value in iteritems(self.attributes):
msg.attributes[key] = value
msg.publish_time = self.publish_time
msg.message_id = self.message_id
return msg.SerializeToString()

@staticmethod
Expand All @@ -137,7 +155,7 @@ def _from_message(msg):
"""
# Convert ScalarMapContainer to dict.
attributes = dict((key, msg.attributes[key]) for key in msg.attributes)
return PubsubMessage(msg.data, attributes)
return PubsubMessage(msg.data, attributes, msg.message_id, msg.publish_time)


class ReadFromPubSub(PTransform):
Expand Down
2 changes: 1 addition & 1 deletion sdks/python/apache_beam/io/gcp/pubsub_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,4 +674,4 @@ def test_write_messages_unsupported_features(self, mock_pubsub):

if __name__ == '__main__':
  # Show INFO-level log output while the tests run.
  logging.getLogger().setLevel(logging.INFO)
  # unittest.main() exits the interpreter by default, so one call suffices.
  unittest.main()
146 changes: 57 additions & 89 deletions sdks/python/apache_beam/io/gcp/tests/pubsub_matcher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,12 @@

"""Unit test for PubSub verifier."""

# pytype: skip-file

from __future__ import absolute_import

import logging
import sys
import unittest

# patches unittest.TestCase to be python3 compatible
import future.tests.base # pylint: disable=unused-import
import mock
from hamcrest import assert_that as hc_assert_that

Expand All @@ -35,16 +31,22 @@
from apache_beam.testing.test_utils import PullResponseMessage
from apache_beam.testing.test_utils import create_pull_response

# Protect against environments where pubsub library is not available.
try:
from google.cloud import pubsub
from google.protobuf.timestamp_pb2 import Timestamp
except ImportError:
pubsub = None
Timestamp = None


@unittest.skipIf(pubsub is None, 'PubSub dependencies are not installed.')
@unittest.skipIf(Timestamp is None,
'Google Protobuf dependencies are not installed.')
@mock.patch('time.sleep', return_value=None)
@mock.patch('google.cloud.pubsub.SubscriberClient')
class PubSubMatcherTest(unittest.TestCase):

@classmethod
def setUpClass(cls):
# Method has been renamed in Python 3
Expand All @@ -54,21 +56,14 @@ def setUpClass(cls):
# Per-test fixture: stores a fresh MagicMock on self.mock_presult, which the
# tests in this class pass to hc_assert_that as the value being matched.
def setUp(self):
self.mock_presult = mock.MagicMock()

# Helpers that build the PubSubMessageMatcher under test for 'mock_project' /
# 'mock_sub_name' and store it on self.pubsub_matcher.
# NOTE(review): two definitions of init_matcher (one taking expected_msg, one
# hard-coding ['mock_expected_msg']) and an init_counter_matcher header are
# interleaved below — looks like both sides of a diff; confirm which signature
# is current before relying on the expected_msg keyword.
def init_matcher(
self, expected_msg=None, with_attributes=False, strip_attributes=None):
self.pubsub_matcher = PubSubMessageMatcher(
'mock_project',
'mock_sub_name',
expected_msg,
with_attributes=with_attributes,
strip_attributes=strip_attributes)

def init_counter_matcher(self, expected_msg_len=1):
def init_matcher(self, with_attributes=False, strip_attributes=None):
self.pubsub_matcher = PubSubMessageMatcher(
'mock_project', 'mock_sub_name', expected_msg_len=expected_msg_len)
'mock_project', 'mock_sub_name', ['mock_expected_msg'],
with_attributes=with_attributes, strip_attributes=strip_attributes)

def test_message_matcher_success(self, mock_get_sub, unsued_mock):
self.init_matcher(expected_msg=[b'a', b'b'])
self.init_matcher()
self.pubsub_matcher.expected_msg = [b'a', b'b']
mock_sub = mock_get_sub.return_value
mock_sub.pull.side_effect = [
create_pull_response([PullResponseMessage(b'a', {})]),
Expand All @@ -79,127 +74,100 @@ def test_message_matcher_success(self, mock_get_sub, unsued_mock):
self.assertEqual(mock_sub.acknowledge.call_count, 2)

# With with_attributes=True, a pulled message whose data and attributes equal
# the expected PubsubMessage must satisfy the matcher after exactly one pull
# and one acknowledge.
# NOTE(review): the setup call and the stubbed pull response each appear in two
# variants (with and without the extra '0123456789' / Timestamp() arguments) —
# looks like both sides of a diff; confirm which is current.
def test_message_matcher_attributes_success(self, mock_get_sub, unsued_mock):
self.init_matcher(
expected_msg=[PubsubMessage(b'a', {'k': 'v'})], with_attributes=True)
self.init_matcher(with_attributes=True)
self.pubsub_matcher.expected_msg = [PubsubMessage(b'a', {'k': 'v'},
'0123456789',
Timestamp())]
# mock_get_sub is the patched SubscriberClient class; its return_value is the
# client instance the matcher pulls from.
mock_sub = mock_get_sub.return_value
mock_sub.pull.side_effect = [
create_pull_response([PullResponseMessage(b'a', {'k': 'v'})])
create_pull_response([PullResponseMessage(b'a', {'k': 'v'},
'0123456789',
Timestamp())])
]
hc_assert_that(self.mock_presult, self.pubsub_matcher)
self.assertEqual(mock_sub.pull.call_count, 1)
self.assertEqual(mock_sub.acknowledge.call_count, 1)

# With with_attributes=True, a pulled message carrying an attribute the
# expected message lacks must make the matcher fail with an AssertionError
# matching 'Unexpected'; the message is still pulled and acknowledged once.
# NOTE(review): setup, pull stub and the assertRaises line each appear in two
# variants — looks like both sides of a diff; confirm which is current.
# assertRaisesRegexp is the deprecated alias of assertRaisesRegex.
def test_message_matcher_attributes_fail(self, mock_get_sub, unsued_mock):
self.init_matcher(
expected_msg=[PubsubMessage(b'a', {})], with_attributes=True)
self.init_matcher(with_attributes=True)
self.pubsub_matcher.expected_msg = [PubsubMessage(b'a', {},
'0123456789',
Timestamp())]
mock_sub = mock_get_sub.return_value
# Unexpected attribute 'k'.
mock_sub.pull.side_effect = [
create_pull_response([PullResponseMessage(b'a', {'k': 'v'})])
create_pull_response([PullResponseMessage(b'a', {'k': 'v'},
'0123456789',
Timestamp())])
]
with self.assertRaisesRegex(AssertionError, r'Unexpected'):
with self.assertRaisesRegexp(AssertionError, r'Unexpected'):
hc_assert_that(self.mock_presult, self.pubsub_matcher)
self.assertEqual(mock_sub.pull.call_count, 1)
self.assertEqual(mock_sub.acknowledge.call_count, 1)

# With strip_attributes=['id', 'timestamp'], a pulled message that carries
# those two attributes in addition to the expected {'k': 'v'} must satisfy the
# matcher after exactly one pull and one acknowledge.
# NOTE(review): the setup call and the pull stub each appear in two variants
# (with and without the extra '0123456789' / Timestamp() arguments) — looks
# like both sides of a diff; confirm which is current.
def test_message_matcher_strip_success(self, mock_get_sub, unsued_mock):
self.init_matcher(
expected_msg=[PubsubMessage(b'a', {'k': 'v'})],
with_attributes=True,
strip_attributes=['id', 'timestamp'])
self.init_matcher(with_attributes=True,
strip_attributes=['id', 'timestamp'])
self.pubsub_matcher.expected_msg = [PubsubMessage(b'a', {'k': 'v'},
'0123456789',
Timestamp())]
mock_sub = mock_get_sub.return_value
mock_sub.pull.side_effect = [
create_pull_response([
PullResponseMessage(
b'a', {
'id': 'foo', 'timestamp': 'bar', 'k': 'v'
})
])
]
mock_sub.pull.side_effect = [create_pull_response([
PullResponseMessage(b'a', {'id': 'foo', 'timestamp': 'bar', 'k': 'v'},
'0123456789', Timestamp())
])]
hc_assert_that(self.mock_presult, self.pubsub_matcher)
self.assertEqual(mock_sub.pull.call_count, 1)
self.assertEqual(mock_sub.acknowledge.call_count, 1)

# With strip_attributes=['id', 'timestamp'], a pulled message that lacks one of
# the attributes to strip must make the matcher fail with an AssertionError
# matching 'Stripped attributes'; it is still pulled and acknowledged once.
# NOTE(review): setup, pull stub and the assertRaises line each appear in two
# variants — looks like both sides of a diff; confirm which is current.
# assertRaisesRegexp is the deprecated alias of assertRaisesRegex.
def test_message_matcher_strip_fail(self, mock_get_sub, unsued_mock):
self.init_matcher(
expected_msg=[PubsubMessage(b'a', {'k': 'v'})],
with_attributes=True,
strip_attributes=['id', 'timestamp'])
self.init_matcher(with_attributes=True,
strip_attributes=['id', 'timestamp'])
self.pubsub_matcher.expected_msg = [PubsubMessage(b'a', {'k': 'v'},
'0123456789',
Timestamp())]
mock_sub = mock_get_sub.return_value
# Message is missing attribute 'timestamp'.
mock_sub.pull.side_effect = [
create_pull_response(
[PullResponseMessage(b'a', {
'id': 'foo', 'k': 'v'
})])
]
with self.assertRaisesRegex(AssertionError, r'Stripped attributes'):
mock_sub.pull.side_effect = [create_pull_response([
PullResponseMessage(b'a', {'id': 'foo', 'k': 'v'},
'0123456789', Timestamp())
])]
with self.assertRaisesRegexp(AssertionError, r'Stripped attributes'):
hc_assert_that(self.mock_presult, self.pubsub_matcher)
self.assertEqual(mock_sub.pull.call_count, 1)
self.assertEqual(mock_sub.acknowledge.call_count, 1)

# One message (b'a') is expected but a single pull returns two (b'c', b'd'):
# the matcher must raise an AssertionError whose text contains the
# 'Expected 1 messages' / 'Got 2 messages' description.
# NOTE(review): setup, pull stub and the trailing assertions each appear in two
# variants (assertIn vs. assertTrue(... in ...), with and without extra
# PullResponseMessage arguments) — looks like both sides of a diff; confirm
# which is current.
def test_message_matcher_mismatch(self, mock_get_sub, unused_mock):
self.init_matcher(expected_msg=[b'a'])
self.init_matcher()
self.pubsub_matcher.expected_msg = [b'a']
mock_sub = mock_get_sub.return_value
mock_sub.pull.side_effect = [
create_pull_response(
[PullResponseMessage(b'c', {}), PullResponseMessage(b'd', {})]),
create_pull_response([PullResponseMessage(b'c', {}, '01',
Timestamp()),
PullResponseMessage(b'd', {}, '02',
Timestamp())]),
]
with self.assertRaises(AssertionError) as error:
hc_assert_that(self.mock_presult, self.pubsub_matcher)
self.assertEqual(mock_sub.pull.call_count, 1)
# The matcher keeps what it pulled in .messages; order is not significant.
self.assertCountEqual([b'c', b'd'], self.pubsub_matcher.messages)
self.assertIn(
'\nExpected: Expected 1 messages.\n but: Got 2 messages.',
str(error.exception.args[0]))
self.assertTrue(
'\nExpected: Expected 1 messages.\n but: Got 2 messages.'
in str(error.exception.args[0]))
self.assertEqual(mock_sub.pull.call_count, 1)
self.assertEqual(mock_sub.acknowledge.call_count, 1)

# No pull response is stubbed and the matcher timeout is shortened to 0.1, so
# the matcher must fail with 'Expected 1 ... Got 0' after attempting to pull,
# without acknowledging anything.
# NOTE(review): two init_matcher calls appear back to back (with and without
# expected_msg) — looks like both sides of a diff; confirm which is current.
def test_message_matcher_timeout(self, mock_get_sub, unused_mock):
self.init_matcher(expected_msg=[b'a'])
self.init_matcher()
mock_sub = mock_get_sub.return_value
mock_sub.return_value.full_name.return_value = 'mock_sub'
self.pubsub_matcher.timeout = 0.1
with self.assertRaisesRegex(AssertionError, r'Expected 1.*\n.*Got 0'):
hc_assert_that(self.mock_presult, self.pubsub_matcher)
self.assertTrue(mock_sub.pull.called)
self.assertEqual(mock_sub.acknowledge.call_count, 0)

# Count-only matcher (expected_msg_len=1): a single pull returning two messages
# must raise an AssertionError whose text contains the 'Expected 1 messages' /
# 'Got 2 messages' description.
def test_message_count_matcher_below_fail(self, mock_get_sub, unused_mock):
self.init_counter_matcher(expected_msg_len=1)
mock_sub = mock_get_sub.return_value
mock_sub.pull.side_effect = [
create_pull_response(
[PullResponseMessage(b'c', {}), PullResponseMessage(b'd', {})]),
]
with self.assertRaises(AssertionError) as error:
hc_assert_that(self.mock_presult, self.pubsub_matcher)
self.assertEqual(mock_sub.pull.call_count, 1)
self.assertIn(
'\nExpected: Expected 1 messages.\n but: Got 2 messages.',
str(error.exception.args[0]))

# Count-only matcher (expected_msg_len=1) with no stubbed pull response and a
# 0.1 timeout: the matcher must fail with 'Expected 1 ... Got 0' after
# attempting to pull, without acknowledging anything.
# NOTE(review): the assertRaises line appears in two variants — looks like both
# sides of a diff; assertRaisesRegexp is the deprecated alias of
# assertRaisesRegex.
def test_message_count_matcher_above_fail(self, mock_get_sub, unused_mock):
self.init_counter_matcher(expected_msg_len=1)
mock_sub = mock_get_sub.return_value
self.pubsub_matcher.timeout = 0.1
with self.assertRaisesRegex(AssertionError, r'Expected 1.*\n.*Got 0'):
with self.assertRaisesRegexp(AssertionError, r'Expected 1.*\n.*Got 0'):
hc_assert_that(self.mock_presult, self.pubsub_matcher)
self.assertTrue(mock_sub.pull.called)
self.assertEqual(mock_sub.acknowledge.call_count, 0)

# Count-only matcher (expected_msg_len=15): one pull returning 15 messages must
# satisfy the matcher after exactly one pull and one acknowledge.
def test_message_count_matcher_success(self, mock_get_sub, unused_mock):
self.init_counter_matcher(expected_msg_len=15)
mock_sub = mock_get_sub.return_value
mock_sub.pull.side_effect = [
create_pull_response(
[PullResponseMessage(b'a', {'foo': 'bar'}) for _ in range(15)])
]
hc_assert_that(self.mock_presult, self.pubsub_matcher)
self.assertEqual(mock_sub.pull.call_count, 1)
self.assertEqual(mock_sub.acknowledge.call_count, 1)


if __name__ == '__main__':
  # Show INFO-level log output while the tests run.
  logging.getLogger().setLevel(logging.INFO)
  # unittest.main() exits the interpreter by default, so one call suffices.
  unittest.main()
Loading