diff --git a/kafka/consumer/group.py b/kafka/consumer/group.py index b3e182c5d..3195b1b4a 100644 --- a/kafka/consumer/group.py +++ b/kafka/consumer/group.py @@ -552,11 +552,9 @@ def committed(self, partition): committed = None return committed - def topics(self): - """Get all topics the user is authorized to view. - - Returns: - set: topics + def _fetch_all_topic_metadata(self): + """A blocking call that fetches topic metadata for all topics in the + cluster that the user is authorized to view. """ cluster = self._client.cluster if self._client._metadata_refresh_in_progress and self._client._topics: @@ -567,10 +565,24 @@ def topics(self): future = cluster.request_update() self._client.poll(future=future) cluster.need_all_topic_metadata = stash - return cluster.topics() + + def topics(self): + """Get all topics the user is authorized to view. + This will always issue a remote call to the cluster to fetch the latest + information. + + Returns: + set: topics + """ + self._fetch_all_topic_metadata() + return self._client.cluster.topics() def partitions_for_topic(self, topic): - """Get metadata about the partitions for a given topic. + """This method first checks the local metadata cache for information + about the topic. If the topic is not found (either because the topic + does not exist, the user is not authorized to view the topic, or the + metadata cache is not populated), then it will issue a metadata update + call to the cluster. Arguments: topic (str): Topic to check. @@ -578,7 +590,12 @@ def partitions_for_topic(self, topic): Returns: set: Partition ids """ - return self._client.cluster.partitions_for_topic(topic) + cluster = self._client.cluster + partitions = cluster.partitions_for_topic(topic) + if partitions is None: + self._fetch_all_topic_metadata() + partitions = cluster.partitions_for_topic(topic) + return partitions def poll(self, timeout_ms=0, max_records=None): """Fetch data from assigned topics / partitions. diff --git a/test/test_consumer_group.py b/test/test_consumer_group.py index d7aaa8896..ec2685765 100644 --- a/test/test_consumer_group.py +++ b/test/test_consumer_group.py @@ -29,6 +29,15 @@ def test_consumer(kafka_broker, topic, version): assert consumer._client._conns[node_id].state is ConnectionStates.CONNECTED consumer.close() +@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set") +def test_consumer_topics(kafka_broker, topic, version): + consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker)) + # Necessary to drive the IO + consumer.poll(500) + consumer_topics = consumer.topics() + assert topic in consumer_topics + assert len(consumer.partitions_for_topic(topic)) > 0 + consumer.close() @pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version') @pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")