Skip to content

Commit

Permalink
Cleanup handling of KAFKA_VERSION env var in tests (#1887)
Browse files Browse the repository at this point in the history
Now that we are using `pytest`, there is no need for a custom decorator
because we can use `pytest.mark.skipif()`.

This makes the code significantly simpler. In particular, dropping the
custom `@kafka_versions()` decorator is necessary because it uses
`func.wraps()` which doesn't play nice with `pytest` fixtures:
- pytest-dev/pytest#677
- https://stackoverflow.com/a/19614807/770425

So this is a pre-requisite to migrating some of those tests to using
pytest fixtures.
  • Loading branch information
jeffwidman authored Aug 22, 2019
1 parent e49caeb commit 98c0058
Show file tree
Hide file tree
Showing 10 changed files with 65 additions and 142 deletions.
14 changes: 4 additions & 10 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,8 @@

import pytest

from test.fixtures import KafkaFixture, ZookeeperFixture, random_string, version as kafka_version


@pytest.fixture(scope="module")
def version():
"""Return the Kafka version set in the OS environment"""
return kafka_version()

from test.testutil import env_kafka_version, random_string
from test.fixtures import KafkaFixture, ZookeeperFixture

@pytest.fixture(scope="module")
def zookeeper():
Expand All @@ -26,9 +20,9 @@ def kafka_broker(kafka_broker_factory):


@pytest.fixture(scope="module")
def kafka_broker_factory(version, zookeeper):
def kafka_broker_factory(zookeeper):
"""Return a Kafka broker fixture factory"""
assert version, 'KAFKA_VERSION must be specified to run integration tests'
assert env_kafka_version(), 'KAFKA_VERSION must be specified to run integration tests'

_brokers = []
def factory(**broker_params):
Expand Down
25 changes: 3 additions & 22 deletions test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import logging
import os
import os.path
import random
import socket
import string
import subprocess
import time
import uuid
Expand All @@ -19,29 +17,12 @@
from kafka.client_async import KafkaClient
from kafka.protocol.admin import CreateTopicsRequest
from kafka.protocol.metadata import MetadataRequest
from test.testutil import env_kafka_version, random_string
from test.service import ExternalService, SpawnedService

log = logging.getLogger(__name__)


def random_string(length):
return "".join(random.choice(string.ascii_letters) for i in range(length))


def version_str_to_tuple(version_str):
"""Transform a version string into a tuple.
Example: '0.8.1.1' --> (0, 8, 1, 1)
"""
return tuple(map(int, version_str.split('.')))


def version():
if 'KAFKA_VERSION' not in os.environ:
return ()
return version_str_to_tuple(os.environ['KAFKA_VERSION'])


def get_open_port():
sock = socket.socket()
sock.bind(("", 0))
Expand Down Expand Up @@ -477,7 +458,7 @@ def _create_topic(self, topic_name, num_partitions, replication_factor, timeout_
num_partitions == self.partitions and \
replication_factor == self.replicas:
self._send_request(MetadataRequest[0]([topic_name]))
elif version() >= (0, 10, 1, 0):
elif env_kafka_version() >= (0, 10, 1, 0):
request = CreateTopicsRequest[0]([(topic_name, num_partitions,
replication_factor, [], [])], timeout_ms)
result = self._send_request(request, timeout=timeout_ms)
Expand All @@ -497,7 +478,7 @@ def _create_topic(self, topic_name, num_partitions, replication_factor, timeout_
'--replication-factor', self.replicas \
if replication_factor is None \
else replication_factor)
if version() >= (0, 10):
if env_kafka_version() >= (0, 10):
args.append('--if-not-exists')
env = self.kafka_run_class_env()
proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
Expand Down
6 changes: 4 additions & 2 deletions test/test_client_integration.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os

import pytest

from kafka.errors import KafkaTimeoutError
from kafka.protocol import create_message
from kafka.structs import (
FetchRequestPayload, OffsetCommitRequestPayload, OffsetFetchRequestPayload,
ProduceRequestPayload)

from test.fixtures import ZookeeperFixture, KafkaFixture
from test.testutil import KafkaIntegrationTestCase, kafka_versions
from test.testutil import KafkaIntegrationTestCase, env_kafka_version


class TestKafkaClientIntegration(KafkaIntegrationTestCase):
Expand Down Expand Up @@ -80,7 +82,7 @@ def test_send_produce_request_maintains_request_response_order(self):
# Offset Tests #
####################

@kafka_versions('>=0.8.1')
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_commit_fetch_offsets(self):
req = OffsetCommitRequestPayload(self.topic, 0, 42, 'metadata')
(resp,) = self.client.send_offset_commit_request('group', [req])
Expand Down
2 changes: 1 addition & 1 deletion test/test_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
lz4_encode_old_kafka, lz4_decode_old_kafka,
)

from test.fixtures import random_string
from test.testutil import random_string


def test_gzip():
Expand Down
18 changes: 8 additions & 10 deletions test/test_consumer_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
from kafka.coordinator.base import MemberState
from kafka.structs import TopicPartition

from test.fixtures import random_string, version
from test.testutil import env_kafka_version, random_string


def get_connect_str(kafka_broker):
return kafka_broker.host + ':' + str(kafka_broker.port)


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_consumer(kafka_broker, topic, version):
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_consumer(kafka_broker, topic):
# The `topic` fixture is included because
# 0.8.2 brokers need a topic to function well
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
Expand All @@ -29,17 +29,16 @@ def test_consumer(kafka_broker, topic, version):
assert consumer._client._conns[node_id].state is ConnectionStates.CONNECTED
consumer.close()

@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_consumer_topics(kafka_broker, topic, version):
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_consumer_topics(kafka_broker, topic):
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
# Necessary to drive the IO
consumer.poll(500)
assert topic in consumer.topics()
assert len(consumer.partitions_for_topic(topic)) > 0
consumer.close()

@pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version')
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(env_kafka_version() < (0, 9), reason='Unsupported Kafka Version')
def test_group(kafka_broker, topic):
num_partitions = 4
connect_str = get_connect_str(kafka_broker)
Expand Down Expand Up @@ -129,7 +128,7 @@ def consumer_thread(i):
threads[c] = None


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_paused(kafka_broker, topic):
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
topics = [TopicPartition(topic, 1)]
Expand All @@ -148,8 +147,7 @@ def test_paused(kafka_broker, topic):
consumer.close()


@pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version')
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(env_kafka_version() < (0, 9), reason='Unsupported Kafka Version')
def test_heartbeat_thread(kafka_broker, topic):
group_id = 'test-group-' + random_string(6)
consumer = KafkaConsumer(topic,
Expand Down
42 changes: 21 additions & 21 deletions test/test_consumer_integration.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import logging
import os
import time
from mock import patch
import pytest
import kafka.codec

from mock import patch
import pytest
from kafka.vendor.six.moves import range
from kafka.vendor import six
from kafka.vendor.six.moves import range

from . import unittest
from kafka import (
KafkaConsumer, MultiProcessConsumer, SimpleConsumer, create_message,
create_gzip_message, KafkaProducer
)
import kafka.codec
from kafka.consumer.base import MAX_FETCH_BUFFER_SIZE_BYTES
from kafka.errors import (
ConsumerFetchSizeTooSmall, OffsetOutOfRangeError, UnsupportedVersionError,
Expand All @@ -23,11 +22,11 @@
ProduceRequestPayload, TopicPartition, OffsetAndTimestamp
)

from test.fixtures import ZookeeperFixture, KafkaFixture, random_string, version
from test.testutil import KafkaIntegrationTestCase, kafka_versions, Timer
from test.fixtures import ZookeeperFixture, KafkaFixture
from test.testutil import KafkaIntegrationTestCase, Timer, env_kafka_version, random_string


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_kafka_consumer(kafka_producer, topic, kafka_consumer_factory):
"""Test KafkaConsumer"""
kafka_consumer = kafka_consumer_factory(auto_offset_reset='earliest')
Expand All @@ -54,7 +53,7 @@ def test_kafka_consumer(kafka_producer, topic, kafka_consumer_factory):
kafka_consumer.close()


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_kafka_consumer_unsupported_encoding(
topic, kafka_producer_factory, kafka_consumer_factory):
# Send a compressed message
Expand Down Expand Up @@ -211,7 +210,7 @@ def test_simple_consumer_no_reset(self):
with self.assertRaises(OffsetOutOfRangeError):
consumer.get_message()

@kafka_versions('>=0.8.1')
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_simple_consumer_load_initial_offsets(self):
self.send_messages(0, range(0, 100))
self.send_messages(1, range(100, 200))
Expand Down Expand Up @@ -388,7 +387,7 @@ def test_multi_proc_pending(self):
consumer.stop()

@unittest.skip('MultiProcessConsumer deprecated and these tests are flaky')
@kafka_versions('>=0.8.1')
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_multi_process_consumer_load_initial_offsets(self):
self.send_messages(0, range(0, 10))
self.send_messages(1, range(10, 20))
Expand Down Expand Up @@ -459,7 +458,7 @@ def test_huge_messages(self):

big_consumer.stop()

@kafka_versions('>=0.8.1')
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_offset_behavior__resuming_behavior(self):
self.send_messages(0, range(0, 100))
self.send_messages(1, range(100, 200))
Expand Down Expand Up @@ -491,7 +490,7 @@ def test_offset_behavior__resuming_behavior(self):
consumer2.stop()

@unittest.skip('MultiProcessConsumer deprecated and these tests are flaky')
@kafka_versions('>=0.8.1')
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_multi_process_offset_behavior__resuming_behavior(self):
self.send_messages(0, range(0, 100))
self.send_messages(1, range(100, 200))
Expand Down Expand Up @@ -548,6 +547,7 @@ def test_fetch_buffer_size(self):
messages = [ message for message in consumer ]
self.assertEqual(len(messages), 2)

@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_kafka_consumer__blocking(self):
TIMEOUT_MS = 500
consumer = self.kafka_consumer(auto_offset_reset='earliest',
Expand Down Expand Up @@ -586,7 +586,7 @@ def test_kafka_consumer__blocking(self):
self.assertGreaterEqual(t.interval, TIMEOUT_MS / 1000.0 )
consumer.close()

@kafka_versions('>=0.8.1')
@pytest.mark.skipif(env_kafka_version() < (0, 8, 1), reason="Requires KAFKA_VERSION >= 0.8.1")
def test_kafka_consumer__offset_commit_resume(self):
GROUP_ID = random_string(10)

Expand All @@ -605,7 +605,7 @@ def test_kafka_consumer__offset_commit_resume(self):
output_msgs1 = []
for _ in range(180):
m = next(consumer1)
output_msgs1.append(m)
output_msgs1.append((m.key, m.value))
self.assert_message_count(output_msgs1, 180)
consumer1.close()

Expand All @@ -621,12 +621,12 @@ def test_kafka_consumer__offset_commit_resume(self):
output_msgs2 = []
for _ in range(20):
m = next(consumer2)
output_msgs2.append(m)
output_msgs2.append((m.key, m.value))
self.assert_message_count(output_msgs2, 20)
self.assertEqual(len(set(output_msgs1) | set(output_msgs2)), 200)
consumer2.close()

@kafka_versions('>=0.10.1')
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
def test_kafka_consumer_max_bytes_simple(self):
self.send_messages(0, range(100, 200))
self.send_messages(1, range(200, 300))
Expand All @@ -647,7 +647,7 @@ def test_kafka_consumer_max_bytes_simple(self):
TopicPartition(self.topic, 0), TopicPartition(self.topic, 1)]))
consumer.close()

@kafka_versions('>=0.10.1')
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
def test_kafka_consumer_max_bytes_one_msg(self):
# We send to only 1 partition so we don't have parallel requests to 2
# nodes for data.
Expand All @@ -673,7 +673,7 @@ def test_kafka_consumer_max_bytes_one_msg(self):
self.assertEqual(len(fetched_msgs), 10)
consumer.close()

@kafka_versions('>=0.10.1')
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
def test_kafka_consumer_offsets_for_time(self):
late_time = int(time.time()) * 1000
middle_time = late_time - 1000
Expand Down Expand Up @@ -727,7 +727,7 @@ def test_kafka_consumer_offsets_for_time(self):
})
consumer.close()

@kafka_versions('>=0.10.1')
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
def test_kafka_consumer_offsets_search_many_partitions(self):
tp0 = TopicPartition(self.topic, 0)
tp1 = TopicPartition(self.topic, 1)
Expand Down Expand Up @@ -766,15 +766,15 @@ def test_kafka_consumer_offsets_search_many_partitions(self):
})
consumer.close()

@kafka_versions('<0.10.1')
@pytest.mark.skipif(env_kafka_version() >= (0, 10, 1), reason="Requires KAFKA_VERSION < 0.10.1")
def test_kafka_consumer_offsets_for_time_old(self):
consumer = self.kafka_consumer()
tp = TopicPartition(self.topic, 0)

with self.assertRaises(UnsupportedVersionError):
consumer.offsets_for_times({tp: int(time.time())})

@kafka_versions('>=0.10.1')
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
def test_kafka_consumer_offsets_for_times_errors(self):
consumer = self.kafka_consumer(fetch_max_wait_ms=200,
request_timeout_ms=500)
Expand Down
4 changes: 2 additions & 2 deletions test/test_failover_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from kafka.producer.base import Producer
from kafka.structs import TopicPartition

from test.fixtures import ZookeeperFixture, KafkaFixture, random_string
from test.testutil import KafkaIntegrationTestCase
from test.fixtures import ZookeeperFixture, KafkaFixture
from test.testutil import KafkaIntegrationTestCase, random_string


log = logging.getLogger(__name__)
Expand Down
Loading

0 comments on commit 98c0058

Please sign in to comment.