diff --git a/pulsar/__init__.py b/pulsar/__init__.py index b1b8a9d..2a2b32c 100644 --- a/pulsar/__init__.py +++ b/pulsar/__init__.py @@ -127,7 +127,7 @@ def value(self): """ Returns object with the de-serialized version of the message content """ - return self._schema.decode(self._message.data()) + return self._schema.decode_message(self._message) def properties(self): """ @@ -812,6 +812,7 @@ def my_listener(consumer, message): c._client = self c._schema = schema + c._schema.attach_client(self._client) self._consumers.append(c) return c @@ -913,6 +914,7 @@ def my_listener(reader, message): c._reader = self._client.create_reader(topic, start_message_id, conf) c._client = self c._schema = schema + c._schema.attach_client(self._client) self._consumers.append(c) return c diff --git a/pulsar/schema/schema.py b/pulsar/schema/schema.py index f062c2e..b50a1fe 100644 --- a/pulsar/schema/schema.py +++ b/pulsar/schema/schema.py @@ -38,9 +38,15 @@ def encode(self, obj): def decode(self, data): pass + def decode_message(self, msg: _pulsar.Message): + return self.decode(msg.data()) + def schema_info(self): return self._schema_info + def attach_client(self, client: _pulsar.Client): + self._client = client + def _validate_object_type(self, obj): if not isinstance(obj, self._record_cls): raise TypeError('Invalid record obj of type ' + str(type(obj)) diff --git a/pulsar/schema/schema_avro.py b/pulsar/schema/schema_avro.py index 3e629fb..70fda98 100644 --- a/pulsar/schema/schema_avro.py +++ b/pulsar/schema/schema_avro.py @@ -19,6 +19,8 @@ import _pulsar import io +import json +import logging import enum from . import Record @@ -40,6 +42,8 @@ def __init__(self, record_cls, schema_definition=None): self._schema = record_cls.schema() else: self._schema = schema_definition + self._writer_schemas = dict() + self._logger = logging.getLogger() super(AvroSchema, self).__init__(record_cls, _pulsar.SchemaType.AVRO, self._schema, 'AVRO') def _get_serialized_value(self, x): @@ -76,8 +80,47 @@ def encode_dict(self, d): return obj def decode(self, data): + return self._decode_bytes(data, self._schema) + + def decode_message(self, msg: _pulsar.Message): + if self._client is None: + return self.decode(msg.data()) + topic = msg.topic_name() + version = msg.int_schema_version() + try: + writer_schema = self._get_writer_schema(topic, version) + return self._decode_bytes(msg.data(), writer_schema) + except Exception as e: + self._logger.error('Failed to get schema info of {topic} version {version}: {e}') + return self._decode_bytes(msg.data(), self._schema) + + def _get_writer_schema(self, topic: str, version: int) -> 'dict': + if self._writer_schemas.get(topic) is None: + self._writer_schemas[topic] = dict() + writer_schema = self._writer_schemas[topic].get(version) + if writer_schema is not None: + return writer_schema + if self._client is None: + return self._schema + + self._logger.info('Downloading schema of %s version %d...', topic, version) + info = self._client.get_schema_info(topic, version) + self._logger.info('Downloaded schema of %s version %d', topic, version) + if info.schema_type() != _pulsar.SchemaType.AVRO: + raise RuntimeError(f'The schema type of topic "{topic}" and version {version}' + f' is {info.schema_type()}') + writer_schema = json.loads(info.schema()) + self._writer_schemas[topic][version] = writer_schema + return writer_schema + + def _decode_bytes(self, data: bytes, writer_schema: dict): buffer = io.BytesIO(data) - d = fastavro.schemaless_reader(buffer, self._schema) + # If the record names are different between the writer schema and the reader schema, + # schemaless_reader will fail with fastavro._read_common.SchemaResolutionError. + # So we make the record name fields consistent here. + reader_schema: dict = self._schema + writer_schema['name'] = reader_schema['name'] + d = fastavro.schemaless_reader(buffer, writer_schema, reader_schema) if self._record_cls is not None: return self._record_cls(**d) else: diff --git a/src/client.cc b/src/client.cc index 0103309..626ff9f 100644 --- a/src/client.cc +++ b/src/client.cc @@ -58,6 +58,12 @@ std::vector Client_getTopicPartitions(Client& client, const std::st [&](GetPartitionsCallback callback) { client.getPartitionsForTopicAsync(topic, callback); }); } +SchemaInfo Client_getSchemaInfo(Client& client, const std::string& topic, int64_t version) { + return waitForAsyncValue([&](std::function callback) { + client.getSchemaInfoAsync(topic, version, callback); + }); +} + void Client_close(Client& client) { waitForAsyncResult([&](ResultCallback callback) { client.closeAsync(callback); }); } @@ -71,6 +77,7 @@ void export_client(py::module_& m) { .def("subscribe_pattern", &Client_subscribe_pattern) .def("create_reader", &Client_createReader) .def("get_topic_partitions", &Client_getTopicPartitions) + .def("get_schema_info", &Client_getSchemaInfo) .def("close", &Client_close) .def("shutdown", &Client::shutdown); } diff --git a/src/message.cc b/src/message.cc index 6e8dd3f..895209f 100644 --- a/src/message.cc +++ b/src/message.cc @@ -98,6 +98,7 @@ void export_message(py::module_& m) { }) .def("topic_name", &Message::getTopicName, return_value_policy::copy) .def("redelivery_count", &Message::getRedeliveryCount) + .def("int_schema_version", &Message::getLongSchemaVersion) .def("schema_version", &Message::getSchemaVersion, return_value_policy::copy); MessageBatch& (MessageBatch::*MessageBatchParseFromString)(const std::string& payload, diff --git a/tests/schema_test.py b/tests/schema_test.py index 47acc30..3e6e9c6 100755 --- a/tests/schema_test.py +++ b/tests/schema_test.py @@ -18,6 +18,10 @@ # under the License. # +import math +import logging +import requests +from typing import List from unittest import TestCase, main import fastavro @@ -27,6 +31,9 @@ import json from fastavro.schema import load_schema +logging.basicConfig(level=logging.INFO, + format='%(asctime)s %(levelname)-5s %(message)s') + class SchemaTest(TestCase): @@ -1287,5 +1294,57 @@ class SomeSchema(Record): with self.assertRaises(TypeError) as e: SomeSchema(some_field=["not", "integer"]) self.assertEqual(str(e.exception), "Array field some_field items should all be of type int") + + def test_schema_evolve(self): + class User1(Record): + name = String() + age = Integer() + + class User2(Record): + _sorted_fields = True + name = String() + age = Integer(required=True) + + response = requests.put('http://localhost:8080/admin/v2/namespaces/' + 'public/default/schemaCompatibilityStrategy', + data='"FORWARD"'.encode(), + headers={'Content-Type': 'application/json'}) + self.assertEqual(response.status_code, 204) + + topic = 'schema-test-schema-evolve-2' + client = pulsar.Client(self.serviceUrl) + producer1 = client.create_producer(topic, schema=AvroSchema(User1)) + consumer = client.subscribe(topic, 'sub', schema=AvroSchema(User1)) + reader = client.create_reader(topic, + schema=AvroSchema(User1), + start_message_id=pulsar.MessageId.earliest) + producer2 = client.create_producer(topic, schema=AvroSchema(User2)) + + num_messages = 10 * 2 + for i in range(int(num_messages / 2)): + producer1.send(User1(age=i+100, name=f'User1 {i}')) + producer2.send(User2(age=i+200, name=f'User2 {i}')) + + def verify_messages(msgs: List[pulsar.Message]): + for i, msg in enumerate(msgs): + value = msg.value() + index = math.floor(i / 2) + if i % 2 == 0: + self.assertEqual(value.age, index + 100) + self.assertEqual(value.name, f'User1 {index}') + else: + self.assertEqual(value.age, index + 200) + self.assertEqual(value.name, f'User2 {index}') + + msgs1 = [] + msgs2 = [] + for i in range(num_messages): + msgs1.append(consumer.receive()) + msgs2.append(reader.read_next(1000)) + verify_messages(msgs1) + verify_messages(msgs2) + + client.close() + if __name__ == '__main__': main()