From 0b25489fe6d70179fde8ec0c6073c6be1ec6e47f Mon Sep 17 00:00:00 2001 From: Riccardo Magliocchetti Date: Mon, 1 Jul 2024 11:01:38 +0200 Subject: [PATCH] instrumentation/kafka: fix handling consumer iteration if transaction not sampled Handle the case where if the transaction is not sampled capture_span will return None instead of span. While at it fix handling of checking for KAFKA_HOST in tests. Fix #2073 --- elasticapm/instrumentation/packages/kafka.py | 3 ++- tests/instrumentation/kafka_tests.py | 24 +++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/elasticapm/instrumentation/packages/kafka.py b/elasticapm/instrumentation/packages/kafka.py index c3bc2d64d..ab9ebd1a4 100644 --- a/elasticapm/instrumentation/packages/kafka.py +++ b/elasticapm/instrumentation/packages/kafka.py @@ -143,7 +143,8 @@ def call(self, module, method, wrapped, instance, args, kwargs): try: result = wrapped(*args, **kwargs) except StopIteration: - span.cancel() + if span: + span.cancel() raise if span and not isinstance(span, DroppedSpan): topic = result[0] diff --git a/tests/instrumentation/kafka_tests.py b/tests/instrumentation/kafka_tests.py index 71416c130..0bfc5c496 100644 --- a/tests/instrumentation/kafka_tests.py +++ b/tests/instrumentation/kafka_tests.py @@ -45,11 +45,10 @@ pytestmark = [pytest.mark.kafka] -if "KAFKA_HOST" not in os.environ: +KAFKA_HOST = os.environ.get("KAFKA_HOST") +if not KAFKA_HOST: pytestmark.append(pytest.mark.skip("Skipping kafka tests, no KAFKA_HOST environment variable set")) -KAFKA_HOST = os.environ["KAFKA_HOST"] - @pytest.fixture(scope="function") def topics(): @@ -233,3 +232,22 @@ def test_kafka_poll_unsampled_transaction(instrument, elasticapm_client, consume elasticapm_client.end_transaction("foo") spans = elasticapm_client.events[SPAN] assert len(spans) == 0 + + +def test_kafka_consumer_unsampled_transaction_handles_stop_iteration( + instrument, elasticapm_client, producer, consumer, topics +): + def delayed_send(): + time.sleep(0.2) + producer.send("test", key=b"foo", value=b"bar") + + thread = threading.Thread(target=delayed_send) + thread.start() + transaction = elasticapm_client.begin_transaction("foo") + transaction.is_sampled = False + for item in consumer: + pass + thread.join() + elasticapm_client.end_transaction("foo") + spans = elasticapm_client.events[SPAN] + assert len(spans) == 0