Skip to content

Commit

Permalink
Cleanup handling of KAFKA_VERSION env var in tests
Browse files Browse the repository at this point in the history
Now that we are using `pytest`, there is no need for a custom decorator
because we can use `pytest.mark.skipif()`.

This makes the code significantly simpler.
  • Loading branch information
jeffwidman committed Aug 22, 2019
1 parent 928cdda commit 9756d51
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 119 deletions.
13 changes: 4 additions & 9 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,8 @@

import pytest

from test.fixtures import KafkaFixture, ZookeeperFixture, random_string, version as kafka_version


@pytest.fixture(scope="module")
def version():
"""Return the Kafka version set in the OS environment"""
return kafka_version()
from test.fixtures import KafkaFixture, ZookeeperFixture, random_string
from test.testutil import env_kafka_version


@pytest.fixture(scope="module")
Expand All @@ -28,9 +23,9 @@ def kafka_broker(kafka_broker_factory):


@pytest.fixture(scope="module")
def kafka_broker_factory(version, zookeeper):
def kafka_broker_factory(zookeeper):
"""Return a Kafka broker fixture factory"""
assert version, 'KAFKA_VERSION must be specified to run integration tests'
assert env_kafka_version(), 'KAFKA_VERSION must be specified to run integration tests'

_brokers = []
def factory(**broker_params):
Expand Down
19 changes: 3 additions & 16 deletions test/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from kafka.protocol.admin import CreateTopicsRequest
from kafka.protocol.metadata import MetadataRequest
from test.service import ExternalService, SpawnedService
from test.testutil import env_kafka_version

log = logging.getLogger(__name__)

Expand All @@ -28,20 +29,6 @@ def random_string(length):
return "".join(random.choice(string.ascii_letters) for i in range(length))


def version_str_to_tuple(version_str):
"""Transform a version string into a tuple.
Example: '0.8.1.1' --> (0, 8, 1, 1)
"""
return tuple(map(int, version_str.split('.')))


def version():
if 'KAFKA_VERSION' not in os.environ:
return ()
return version_str_to_tuple(os.environ['KAFKA_VERSION'])


def get_open_port():
sock = socket.socket()
sock.bind(("", 0))
Expand Down Expand Up @@ -477,7 +464,7 @@ def _create_topic(self, topic_name, num_partitions, replication_factor, timeout_
num_partitions == self.partitions and \
replication_factor == self.replicas:
self._send_request(MetadataRequest[0]([topic_name]))
elif version() >= (0, 10, 1, 0):
elif env_kafka_version() >= (0, 10, 1, 0):
request = CreateTopicsRequest[0]([(topic_name, num_partitions,
replication_factor, [], [])], timeout_ms)
result = self._send_request(request, timeout=timeout_ms)
Expand All @@ -497,7 +484,7 @@ def _create_topic(self, topic_name, num_partitions, replication_factor, timeout_
'--replication-factor', self.replicas \
if replication_factor is None \
else replication_factor)
if version() >= (0, 10):
if env_kafka_version() >= (0, 10):
args.append('--if-not-exists')
env = self.kafka_run_class_env()
proc = subprocess.Popen(args, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
Expand Down
19 changes: 9 additions & 10 deletions test/test_consumer_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
from kafka.coordinator.base import MemberState
from kafka.structs import TopicPartition

from test.fixtures import random_string, version
from test.fixtures import random_string
from test.testutil import env_kafka_version


def get_connect_str(kafka_broker):
return kafka_broker.host + ':' + str(kafka_broker.port)


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_consumer(kafka_broker, topic, version):
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_consumer(kafka_broker, topic):
# The `topic` fixture is included because
# 0.8.2 brokers need a topic to function well
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
Expand All @@ -29,17 +30,16 @@ def test_consumer(kafka_broker, topic, version):
assert consumer._client._conns[node_id].state is ConnectionStates.CONNECTED
consumer.close()

@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
def test_consumer_topics(kafka_broker, topic, version):
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_consumer_topics(kafka_broker, topic):
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
# Necessary to drive the IO
consumer.poll(500)
assert topic in consumer.topics()
assert len(consumer.partitions_for_topic(topic)) > 0
consumer.close()

@pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version')
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(env_kafka_version() < (0, 9), reason='Unsupported Kafka Version')
def test_group(kafka_broker, topic):
num_partitions = 4
connect_str = get_connect_str(kafka_broker)
Expand Down Expand Up @@ -129,7 +129,7 @@ def consumer_thread(i):
threads[c] = None


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_paused(kafka_broker, topic):
consumer = KafkaConsumer(bootstrap_servers=get_connect_str(kafka_broker))
topics = [TopicPartition(topic, 1)]
Expand All @@ -148,8 +148,7 @@ def test_paused(kafka_broker, topic):
consumer.close()


@pytest.mark.skipif(version() < (0, 9), reason='Unsupported Kafka Version')
@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(env_kafka_version() < (0, 9), reason='Unsupported Kafka Version')
def test_heartbeat_thread(kafka_broker, topic):
group_id = 'test-group-' + random_string(6)
consumer = KafkaConsumer(topic,
Expand Down
24 changes: 12 additions & 12 deletions test/test_consumer_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
)
from kafka.structs import TopicPartition, OffsetAndTimestamp

from test.fixtures import random_string, version
from test.testutil import kafka_versions, Timer, assert_message_count
from test.fixtures import random_string
from test.testutil import env_kafka_version, Timer, assert_message_count


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_kafka_consumer(kafka_consumer_factory, send_messages):
"""Test KafkaConsumer"""
consumer = kafka_consumer_factory(auto_offset_reset='earliest')
Expand All @@ -35,7 +35,7 @@ def test_kafka_consumer(kafka_consumer_factory, send_messages):
assert len(messages[1]) == 100


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_kafka_consumer_unsupported_encoding(
topic, kafka_producer_factory, kafka_consumer_factory):
# Send a compressed message
Expand All @@ -53,7 +53,7 @@ def test_kafka_consumer_unsupported_encoding(
consumer.poll(timeout_ms=2000)


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages):
TIMEOUT_MS = 500
consumer = kafka_consumer_factory(auto_offset_reset='earliest',
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_kafka_consumer__blocking(kafka_consumer_factory, topic, send_messages):
assert t.interval >= (TIMEOUT_MS / 1000.0)


@kafka_versions('>=0.8.1')
@pytest.mark.skipif(env_kafka_version() < (0, 8, 1), reason="Requires KAFKA_VERSION >= 0.8.1")
def test_kafka_consumer__offset_commit_resume(kafka_consumer_factory, send_messages):
GROUP_ID = random_string(10)

Expand Down Expand Up @@ -131,7 +131,7 @@ def test_kafka_consumer__offset_commit_resume(kafka_consumer_factory, send_messa
assert len(set(output_msgs1) | set(output_msgs2)) == 200


@kafka_versions('>=0.10.1')
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
def test_kafka_consumer_max_bytes_simple(kafka_consumer_factory, topic, send_messages):
send_messages(range(100, 200), partition=0)
send_messages(range(200, 300), partition=1)
Expand All @@ -150,7 +150,7 @@ def test_kafka_consumer_max_bytes_simple(kafka_consumer_factory, topic, send_mes
assert seen_partitions == {TopicPartition(topic, 0), TopicPartition(topic, 1)}


@kafka_versions('>=0.10.1')
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
def test_kafka_consumer_max_bytes_one_msg(kafka_consumer_factory, send_messages):
# We send to only 1 partition so we don't have parallel requests to 2
# nodes for data.
Expand All @@ -176,7 +176,7 @@ def test_kafka_consumer_max_bytes_one_msg(kafka_consumer_factory, send_messages)
assert len(fetched_msgs) == 10


@kafka_versions('>=0.10.1')
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
def test_kafka_consumer_offsets_for_time(topic, kafka_consumer, kafka_producer):
late_time = int(time.time()) * 1000
middle_time = late_time - 1000
Expand Down Expand Up @@ -225,7 +225,7 @@ def test_kafka_consumer_offsets_for_time(topic, kafka_consumer, kafka_producer):
assert offsets == {tp: late_msg.offset + 1}


@kafka_versions('>=0.10.1')
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
def test_kafka_consumer_offsets_search_many_partitions(kafka_consumer, kafka_producer, topic):
tp0 = TopicPartition(topic, 0)
tp1 = TopicPartition(topic, 1)
Expand Down Expand Up @@ -263,7 +263,7 @@ def test_kafka_consumer_offsets_search_many_partitions(kafka_consumer, kafka_pro
}


@kafka_versions('<0.10.1')
@pytest.mark.skipif(env_kafka_version() >= (0, 10, 1), reason="Requires KAFKA_VERSION < 0.10.1")
def test_kafka_consumer_offsets_for_time_old(kafka_consumer, topic):
consumer = kafka_consumer
tp = TopicPartition(topic, 0)
Expand All @@ -272,7 +272,7 @@ def test_kafka_consumer_offsets_for_time_old(kafka_consumer, topic):
consumer.offsets_for_times({tp: int(time.time())})


@kafka_versions('>=0.10.1')
@pytest.mark.skipif(env_kafka_version() < (0, 10, 1), reason="Requires KAFKA_VERSION >= 0.10.1")
def test_kafka_consumer_offsets_for_times_errors(kafka_consumer_factory, topic):
consumer = kafka_consumer_factory(fetch_max_wait_ms=200,
request_timeout_ms=500)
Expand Down
11 changes: 6 additions & 5 deletions test/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@

from kafka import KafkaConsumer, KafkaProducer, TopicPartition
from kafka.producer.buffer import SimpleBufferPool
from test.fixtures import random_string, version
from test.fixtures import random_string
from test.testutil import env_kafka_version


def test_buffer_pool():
Expand All @@ -22,13 +23,13 @@ def test_buffer_pool():
assert buf2.read() == b''


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4'])
def test_end_to_end(kafka_broker, compression):

if compression == 'lz4':
# LZ4 requires 0.8.2
if version() < (0, 8, 2):
if env_kafka_version() < (0, 8, 2):
return
# python-lz4 crashes on older versions of pypy
elif platform.python_implementation() == 'PyPy':
Expand Down Expand Up @@ -80,7 +81,7 @@ def test_kafka_producer_gc_cleanup():
assert threading.active_count() == threads


@pytest.mark.skipif(not version(), reason="No KAFKA_VERSION set")
@pytest.mark.skipif(not env_kafka_version(), reason="No KAFKA_VERSION set")
@pytest.mark.parametrize("compression", [None, 'gzip', 'snappy', 'lz4'])
def test_kafka_producer_proper_record_metadata(kafka_broker, compression):
connect_str = ':'.join([kafka_broker.host, str(kafka_broker.port)])
Expand All @@ -91,7 +92,7 @@ def test_kafka_producer_proper_record_metadata(kafka_broker, compression):
magic = producer._max_usable_produce_magic()

# record headers are supported in 0.11.0
if version() < (0, 11, 0):
if env_kafka_version() < (0, 11, 0):
headers = None
else:
headers = [("Header Key", b"Header Value")]
Expand Down
74 changes: 7 additions & 67 deletions test/testutil.py
Original file line number Diff line number Diff line change
@@ -1,77 +1,17 @@
from __future__ import absolute_import

import functools
import operator
import os
import time

import pytest

from test.fixtures import version_str_to_tuple, version as kafka_version
def env_kafka_version():
"""Return the Kafka version set in the OS environment as a tuple.

def kafka_versions(*versions):
Example: '0.8.1.1' --> (0, 8, 1, 1)
"""
Describe the Kafka versions this test is relevant to.
The versions are passed in as strings, for example:
'0.11.0'
'>=0.10.1.0'
'>0.9', '<1.0' # since this accepts multiple versions args
The current KAFKA_VERSION will be evaluated against this version. If the
result is False, then the test is skipped. Similarly, if KAFKA_VERSION is
not set the test is skipped.
Note: For simplicity, this decorator accepts Kafka versions as strings even
though the similarly functioning `api_version` only accepts tuples. Trying
to convert it to tuples quickly gets ugly due to mixing operator strings
alongside version tuples. While doable when one version is passed in, it
isn't pretty when multiple versions are passed in.
"""

def construct_lambda(s):
if s[0].isdigit():
op_str = '='
v_str = s
elif s[1].isdigit():
op_str = s[0] # ! < > =
v_str = s[1:]
elif s[2].isdigit():
op_str = s[0:2] # >= <=
v_str = s[2:]
else:
raise ValueError('Unrecognized kafka version / operator: %s' % (s,))

op_map = {
'=': operator.eq,
'!': operator.ne,
'>': operator.gt,
'<': operator.lt,
'>=': operator.ge,
'<=': operator.le
}
op = op_map[op_str]
version = version_str_to_tuple(v_str)
return lambda a: op(a, version)

validators = map(construct_lambda, versions)

def real_kafka_versions(func):
@functools.wraps(func)
def wrapper(func, *args, **kwargs):
version = kafka_version()

if not version:
pytest.skip("no kafka version set in KAFKA_VERSION env var")

for f in validators:
if not f(version):
pytest.skip("unsupported kafka version")

return func(*args, **kwargs)
return wrapper

return real_kafka_versions
if 'KAFKA_VERSION' not in os.environ:
return ()
return tuple(map(int, os.environ['KAFKA_VERSION'].split('.')))


def assert_message_count(messages, num_messages):
Expand Down

0 comments on commit 9756d51

Please sign in to comment.