From 092540a7f90d7667bfff367506abd17b9556b569 Mon Sep 17 00:00:00 2001 From: Yunze Xu Date: Mon, 22 May 2023 17:47:52 +0800 Subject: [PATCH] Fetch writer schema to decode Avro messages Fixes https://github.com/apache/pulsar-client-python/issues/108 ### Motivation Currently the Python client uses the reader schema, which is the schema of the consumer, to decode Avro messages. However, when the writer schema is different from the reader schema, decoding will fail. ### Modifications Add an `attach_client` method to `Schema` and call it when creating consumers and readers. This method stores a reference to a `_pulsar.Client` instance, which leverages the C++ APIs added in https://github.com/apache/pulsar-client-cpp/pull/257 to fetch schema info. The `AvroSchema` class fetches and caches the writer schema if it is not cached, then uses both the writer schema and the reader schema to decode messages. Add `test_schema_evolve` to test that consumers and readers can decode any message whose writer schema is different from the reader schema. 
--- pulsar/__init__.py | 4 ++- pulsar/schema/schema.py | 6 ++++ pulsar/schema/schema_avro.py | 45 ++++++++++++++++++++++++++- src/client.cc | 7 +++++ src/message.cc | 1 + tests/schema_test.py | 59 ++++++++++++++++++++++++++++++++++++ 6 files changed, 120 insertions(+), 2 deletions(-) diff --git a/pulsar/__init__.py b/pulsar/__init__.py index b1b8a9d..2a2b32c 100644 --- a/pulsar/__init__.py +++ b/pulsar/__init__.py @@ -127,7 +127,7 @@ def value(self): """ Returns object with the de-serialized version of the message content """ - return self._schema.decode(self._message.data()) + return self._schema.decode_message(self._message) def properties(self): """ @@ -812,6 +812,7 @@ def my_listener(consumer, message): c._client = self c._schema = schema + c._schema.attach_client(self._client) self._consumers.append(c) return c @@ -913,6 +914,7 @@ def my_listener(reader, message): c._reader = self._client.create_reader(topic, start_message_id, conf) c._client = self c._schema = schema + c._schema.attach_client(self._client) self._consumers.append(c) return c diff --git a/pulsar/schema/schema.py b/pulsar/schema/schema.py index f062c2e..b50a1fe 100644 --- a/pulsar/schema/schema.py +++ b/pulsar/schema/schema.py @@ -38,9 +38,15 @@ def encode(self, obj): def decode(self, data): pass + def decode_message(self, msg: _pulsar.Message): + return self.decode(msg.data()) + def schema_info(self): return self._schema_info + def attach_client(self, client: _pulsar.Client): + self._client = client + def _validate_object_type(self, obj): if not isinstance(obj, self._record_cls): raise TypeError('Invalid record obj of type ' + str(type(obj)) diff --git a/pulsar/schema/schema_avro.py b/pulsar/schema/schema_avro.py index 3e629fb..70fda98 100644 --- a/pulsar/schema/schema_avro.py +++ b/pulsar/schema/schema_avro.py @@ -19,6 +19,8 @@ import _pulsar import io +import json +import logging import enum from . 
import Record @@ -40,6 +42,8 @@ def __init__(self, record_cls, schema_definition=None): self._schema = record_cls.schema() else: self._schema = schema_definition + self._writer_schemas = dict() + self._logger = logging.getLogger() super(AvroSchema, self).__init__(record_cls, _pulsar.SchemaType.AVRO, self._schema, 'AVRO') def _get_serialized_value(self, x): @@ -76,8 +80,47 @@ def encode_dict(self, d): return obj def decode(self, data): + return self._decode_bytes(data, self._schema) + + def decode_message(self, msg: _pulsar.Message): + if getattr(self, '_client', None) is None: + return self.decode(msg.data()) + topic = msg.topic_name() + version = msg.int_schema_version() + try: + writer_schema = self._get_writer_schema(topic, version) + return self._decode_bytes(msg.data(), writer_schema) + except Exception as e: + self._logger.error(f'Failed to get schema info of {topic} version {version}: {e}') + return self._decode_bytes(msg.data(), self._schema) + + def _get_writer_schema(self, topic: str, version: int) -> 'dict': + if self._writer_schemas.get(topic) is None: + self._writer_schemas[topic] = dict() + writer_schema = self._writer_schemas[topic].get(version) + if writer_schema is not None: + return writer_schema + if getattr(self, '_client', None) is None: + return self._schema + + self._logger.info('Downloading schema of %s version %d...', topic, version) + info = self._client.get_schema_info(topic, version) + self._logger.info('Downloaded schema of %s version %d', topic, version) + if info.schema_type() != _pulsar.SchemaType.AVRO: + raise RuntimeError(f'The schema type of topic "{topic}" and version {version}' + f' is {info.schema_type()}') + writer_schema = json.loads(info.schema()) + self._writer_schemas[topic][version] = writer_schema + return writer_schema + + def _decode_bytes(self, data: bytes, writer_schema: dict): buffer = io.BytesIO(data) - d = fastavro.schemaless_reader(buffer, self._schema) + # If the record names are different between the writer schema and the reader schema, + # 
schemaless_reader will fail with fastavro._read_common.SchemaResolutionError. + # So we make the record name fields consistent here. + reader_schema: dict = self._schema + writer_schema['name'] = reader_schema['name'] + d = fastavro.schemaless_reader(buffer, writer_schema, reader_schema) if self._record_cls is not None: return self._record_cls(**d) else: diff --git a/src/client.cc b/src/client.cc index 0103309..626ff9f 100644 --- a/src/client.cc +++ b/src/client.cc @@ -58,6 +58,12 @@ std::vector Client_getTopicPartitions(Client& client, const std::st [&](GetPartitionsCallback callback) { client.getPartitionsForTopicAsync(topic, callback); }); } +SchemaInfo Client_getSchemaInfo(Client& client, const std::string& topic, int64_t version) { + return waitForAsyncValue([&](std::function callback) { + client.getSchemaInfoAsync(topic, version, callback); + }); +} + void Client_close(Client& client) { waitForAsyncResult([&](ResultCallback callback) { client.closeAsync(callback); }); } @@ -71,6 +77,7 @@ void export_client(py::module_& m) { .def("subscribe_pattern", &Client_subscribe_pattern) .def("create_reader", &Client_createReader) .def("get_topic_partitions", &Client_getTopicPartitions) + .def("get_schema_info", &Client_getSchemaInfo) .def("close", &Client_close) .def("shutdown", &Client::shutdown); } diff --git a/src/message.cc b/src/message.cc index 6e8dd3f..895209f 100644 --- a/src/message.cc +++ b/src/message.cc @@ -98,6 +98,7 @@ void export_message(py::module_& m) { }) .def("topic_name", &Message::getTopicName, return_value_policy::copy) .def("redelivery_count", &Message::getRedeliveryCount) + .def("int_schema_version", &Message::getLongSchemaVersion) .def("schema_version", &Message::getSchemaVersion, return_value_policy::copy); MessageBatch& (MessageBatch::*MessageBatchParseFromString)(const std::string& payload, diff --git a/tests/schema_test.py b/tests/schema_test.py index 47acc30..3e6e9c6 100755 --- a/tests/schema_test.py +++ b/tests/schema_test.py @@ -18,6 
+18,10 @@ # under the License. # +import math +import logging +import requests +from typing import List from unittest import TestCase, main import fastavro @@ -27,6 +31,9 @@ import json from fastavro.schema import load_schema +logging.basicConfig(level=logging.INFO, + format='%(asctime)s %(levelname)-5s %(message)s') + class SchemaTest(TestCase): @@ -1287,5 +1294,57 @@ class SomeSchema(Record): with self.assertRaises(TypeError) as e: SomeSchema(some_field=["not", "integer"]) self.assertEqual(str(e.exception), "Array field some_field items should all be of type int") + + def test_schema_evolve(self): + class User1(Record): + name = String() + age = Integer() + + class User2(Record): + _sorted_fields = True + name = String() + age = Integer(required=True) + + response = requests.put('http://localhost:8080/admin/v2/namespaces/' + 'public/default/schemaCompatibilityStrategy', + data='"FORWARD"'.encode(), + headers={'Content-Type': 'application/json'}) + self.assertEqual(response.status_code, 204) + + topic = 'schema-test-schema-evolve-2' + client = pulsar.Client(self.serviceUrl) + producer1 = client.create_producer(topic, schema=AvroSchema(User1)) + consumer = client.subscribe(topic, 'sub', schema=AvroSchema(User1)) + reader = client.create_reader(topic, + schema=AvroSchema(User1), + start_message_id=pulsar.MessageId.earliest) + producer2 = client.create_producer(topic, schema=AvroSchema(User2)) + + num_messages = 10 * 2 + for i in range(int(num_messages / 2)): + producer1.send(User1(age=i+100, name=f'User1 {i}')) + producer2.send(User2(age=i+200, name=f'User2 {i}')) + + def verify_messages(msgs: List[pulsar.Message]): + for i, msg in enumerate(msgs): + value = msg.value() + index = math.floor(i / 2) + if i % 2 == 0: + self.assertEqual(value.age, index + 100) + self.assertEqual(value.name, f'User1 {index}') + else: + self.assertEqual(value.age, index + 200) + self.assertEqual(value.name, f'User2 {index}') + + msgs1 = [] + msgs2 = [] + for i in range(num_messages): + 
msgs1.append(consumer.receive()) + msgs2.append(reader.read_next(1000)) + verify_messages(msgs1) + verify_messages(msgs2) + + client.close() + if __name__ == '__main__': main()