diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index b35bbaf404cc5..06e159172ab51 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -254,6 +254,16 @@ def __init__(self, topic, partition): def _jTopicAndPartition(self, helper): return helper.createTopicAndPartition(self._topic, self._partition) + def __eq__(self, other): + if isinstance(other, self.__class__): + return (self._topic == other._topic + and self._partition == other._partition) + else: + return False + + def __ne__(self, other): + return not self.__eq__(other) + class Broker(object): """ diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 2c908daa8b214..f7fa481d50235 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -898,6 +898,16 @@ def transformWithOffsetRanges(rdd): self.assertEqual(offsetRanges, [OffsetRange(topic, 0, long(0), long(6))]) + def test_topic_and_partition_equality(self): + topic_and_partition_a = TopicAndPartition("foo", 0) + topic_and_partition_b = TopicAndPartition("foo", 0) + topic_and_partition_c = TopicAndPartition("bar", 0) + topic_and_partition_d = TopicAndPartition("foo", 1) + + self.assertEqual(topic_and_partition_a, topic_and_partition_b) + self.assertNotEqual(topic_and_partition_a, topic_and_partition_c) + self.assertNotEqual(topic_and_partition_a, topic_and_partition_d) + class FlumeStreamTests(PySparkStreamingTestCase): timeout = 20 # seconds