Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion pulsar/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def value(self):
"""
Returns object with the de-serialized version of the message content
"""
return self._schema.decode(self._message.data())
return self._schema.decode_message(self._message)

def properties(self):
"""
Expand Down Expand Up @@ -812,6 +812,7 @@ def my_listener(consumer, message):

c._client = self
c._schema = schema
c._schema.attach_client(self._client)
self._consumers.append(c)
return c

Expand Down Expand Up @@ -913,6 +914,7 @@ def my_listener(reader, message):
c._reader = self._client.create_reader(topic, start_message_id, conf)
c._client = self
c._schema = schema
c._schema.attach_client(self._client)
self._consumers.append(c)
return c

Expand Down
6 changes: 6 additions & 0 deletions pulsar/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,15 @@ def encode(self, obj):
def decode(self, data):
pass

def decode_message(self, msg: _pulsar.Message):
return self.decode(msg.data())

def schema_info(self):
return self._schema_info

def attach_client(self, client: _pulsar.Client):
self._client = client

def _validate_object_type(self, obj):
if not isinstance(obj, self._record_cls):
raise TypeError('Invalid record obj of type ' + str(type(obj))
Expand Down
45 changes: 44 additions & 1 deletion pulsar/schema/schema_avro.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import _pulsar
import io
import json
import logging
import enum

from . import Record
Expand All @@ -40,6 +42,8 @@ def __init__(self, record_cls, schema_definition=None):
self._schema = record_cls.schema()
else:
self._schema = schema_definition
self._writer_schemas = dict()
self._logger = logging.getLogger()
super(AvroSchema, self).__init__(record_cls, _pulsar.SchemaType.AVRO, self._schema, 'AVRO')

def _get_serialized_value(self, x):
Expand Down Expand Up @@ -76,8 +80,47 @@ def encode_dict(self, d):
return obj

def decode(self, data):
return self._decode_bytes(data, self._schema)

def decode_message(self, msg: _pulsar.Message):
if self._client is None:
return self.decode(msg.data())
topic = msg.topic_name()
version = msg.int_schema_version()
try:
writer_schema = self._get_writer_schema(topic, version)
return self._decode_bytes(msg.data(), writer_schema)
except Exception as e:
self._logger.error('Failed to get schema info of {topic} version {version}: {e}')
return self._decode_bytes(msg.data(), self._schema)

def _get_writer_schema(self, topic: str, version: int) -> 'dict':
if self._writer_schemas.get(topic) is None:
self._writer_schemas[topic] = dict()
writer_schema = self._writer_schemas[topic].get(version)
if writer_schema is not None:
return writer_schema
if self._client is None:
return self._schema

self._logger.info('Downloading schema of %s version %d...', topic, version)
info = self._client.get_schema_info(topic, version)
self._logger.info('Downloaded schema of %s version %d', topic, version)
if info.schema_type() != _pulsar.SchemaType.AVRO:
raise RuntimeError(f'The schema type of topic "{topic}" and version {version}'
f' is {info.schema_type()}')
writer_schema = json.loads(info.schema())
self._writer_schemas[topic][version] = writer_schema
return writer_schema

def _decode_bytes(self, data: bytes, writer_schema: dict):
buffer = io.BytesIO(data)
d = fastavro.schemaless_reader(buffer, self._schema)
# If the record names are different between the writer schema and the reader schema,
# schemaless_reader will fail with fastavro._read_common.SchemaResolutionError.
# So we make the record name fields consistent here.
reader_schema: dict = self._schema
writer_schema['name'] = reader_schema['name']
d = fastavro.schemaless_reader(buffer, writer_schema, reader_schema)
if self._record_cls is not None:
return self._record_cls(**d)
else:
Expand Down
7 changes: 7 additions & 0 deletions src/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ std::vector<std::string> Client_getTopicPartitions(Client& client, const std::st
[&](GetPartitionsCallback callback) { client.getPartitionsForTopicAsync(topic, callback); });
}

SchemaInfo Client_getSchemaInfo(Client& client, const std::string& topic, int64_t version) {
return waitForAsyncValue<SchemaInfo>([&](std::function<void(Result, const SchemaInfo&)> callback) {
client.getSchemaInfoAsync(topic, version, callback);
});
}

void Client_close(Client& client) {
waitForAsyncResult([&](ResultCallback callback) { client.closeAsync(callback); });
}
Expand All @@ -71,6 +77,7 @@ void export_client(py::module_& m) {
.def("subscribe_pattern", &Client_subscribe_pattern)
.def("create_reader", &Client_createReader)
.def("get_topic_partitions", &Client_getTopicPartitions)
.def("get_schema_info", &Client_getSchemaInfo)
.def("close", &Client_close)
.def("shutdown", &Client::shutdown);
}
1 change: 1 addition & 0 deletions src/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ void export_message(py::module_& m) {
})
.def("topic_name", &Message::getTopicName, return_value_policy::copy)
.def("redelivery_count", &Message::getRedeliveryCount)
.def("int_schema_version", &Message::getLongSchemaVersion)
.def("schema_version", &Message::getSchemaVersion, return_value_policy::copy);

MessageBatch& (MessageBatch::*MessageBatchParseFromString)(const std::string& payload,
Expand Down
59 changes: 59 additions & 0 deletions tests/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
# under the License.
#

import math
import logging
import requests
from typing import List
from unittest import TestCase, main

import fastavro
Expand All @@ -27,6 +31,9 @@
import json
from fastavro.schema import load_schema

logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)-5s %(message)s')


class SchemaTest(TestCase):

Expand Down Expand Up @@ -1287,5 +1294,57 @@ class SomeSchema(Record):
with self.assertRaises(TypeError) as e:
SomeSchema(some_field=["not", "integer"])
self.assertEqual(str(e.exception), "Array field some_field items should all be of type int")

def test_schema_evolve(self):
class User1(Record):
name = String()
age = Integer()

class User2(Record):
_sorted_fields = True
name = String()
age = Integer(required=True)

response = requests.put('http://localhost:8080/admin/v2/namespaces/'
'public/default/schemaCompatibilityStrategy',
data='"FORWARD"'.encode(),
headers={'Content-Type': 'application/json'})
self.assertEqual(response.status_code, 204)

topic = 'schema-test-schema-evolve-2'
client = pulsar.Client(self.serviceUrl)
producer1 = client.create_producer(topic, schema=AvroSchema(User1))
consumer = client.subscribe(topic, 'sub', schema=AvroSchema(User1))
reader = client.create_reader(topic,
schema=AvroSchema(User1),
start_message_id=pulsar.MessageId.earliest)
producer2 = client.create_producer(topic, schema=AvroSchema(User2))

num_messages = 10 * 2
for i in range(int(num_messages / 2)):
producer1.send(User1(age=i+100, name=f'User1 {i}'))
producer2.send(User2(age=i+200, name=f'User2 {i}'))

def verify_messages(msgs: List[pulsar.Message]):
for i, msg in enumerate(msgs):
value = msg.value()
index = math.floor(i / 2)
if i % 2 == 0:
self.assertEqual(value.age, index + 100)
self.assertEqual(value.name, f'User1 {index}')
else:
self.assertEqual(value.age, index + 200)
self.assertEqual(value.name, f'User2 {index}')

msgs1 = []
msgs2 = []
for i in range(num_messages):
msgs1.append(consumer.receive())
msgs2.append(reader.read_next(1000))
verify_messages(msgs1)
verify_messages(msgs2)

client.close()

if __name__ == '__main__':
main()