diff --git a/src/confluent_kafka/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/__init__.py index e4ad4be17..d6b9fd197 100644 --- a/src/confluent_kafka/schema_registry/__init__.py +++ b/src/confluent_kafka/schema_registry/__init__.py @@ -18,44 +18,44 @@ from typing import Optional from .schema_registry_client import ( - ConfigCompatibilityLevel, - Metadata, - MetadataProperties, - MetadataTags, - RegisteredSchema, - Rule, - RuleKind, - RuleMode, - RuleParams, - RuleSet, - Schema, - SchemaRegistryClient, - SchemaRegistryError, - SchemaReference, - ServerConfig + ConfigCompatibilityLevel, + Metadata, + MetadataProperties, + MetadataTags, + RegisteredSchema, + Rule, + RuleKind, + RuleMode, + RuleParams, + RuleSet, + Schema, + SchemaRegistryClient, + SchemaRegistryError, + SchemaReference, + ServerConfig ) _MAGIC_BYTE = 0 __all__ = [ - "ConfigCompatibilityLevel", - "Metadata", - "MetadataProperties", - "MetadataTags", - "RegisteredSchema", - "Rule", - "RuleKind", - "RuleMode", - "RuleParams", - "RuleSet", - "Schema", - "SchemaRegistryClient", - "SchemaRegistryError", - "SchemaReference", - "ServerConfig", - "topic_subject_name_strategy", - "topic_record_subject_name_strategy", - "record_subject_name_strategy" + "ConfigCompatibilityLevel", + "Metadata", + "MetadataProperties", + "MetadataTags", + "RegisteredSchema", + "Rule", + "RuleKind", + "RuleMode", + "RuleParams", + "RuleSet", + "Schema", + "SchemaRegistryClient", + "SchemaRegistryError", + "SchemaReference", + "ServerConfig", + "topic_subject_name_strategy", + "topic_record_subject_name_strategy", + "record_subject_name_strategy" ] diff --git a/tests/integration/schema_registry/__init__.py b/src/confluent_kafka/schema_registry/_sync/__init__.py similarity index 100% rename from tests/integration/schema_registry/__init__.py rename to src/confluent_kafka/schema_registry/_sync/__init__.py diff --git a/src/confluent_kafka/schema_registry/_sync/avro.py 
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright 2020 Confluent Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from json import loads
from struct import pack, unpack
from typing import Dict, Union, Optional, Callable

from fastavro import schemaless_reader, schemaless_writer

from confluent_kafka.schema_registry.common.avro import AvroSchema, _schema_loads, \
    get_inline_tags, parse_schema_with_repo, transform, _ContextStringIO

from confluent_kafka.schema_registry import (_MAGIC_BYTE,
                                             Schema,
                                             topic_subject_name_strategy,
                                             RuleMode,
                                             SchemaRegistryClient)
from confluent_kafka.serialization import (SerializationError,
                                           SerializationContext)
from confluent_kafka.schema_registry.rule_registry import RuleRegistry
from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, ParsedSchemaCache

__all__ = [
    '_resolve_named_schema',
    'AvroSerializer',
    'AvroDeserializer',
]


def _resolve_named_schema(
    schema: Schema, schema_registry_client: SchemaRegistryClient
) -> Dict[str, AvroSchema]:
    """
    Resolves named schemas referenced by the provided schema recursively.

    Each reference is fetched from the registry by subject and version, its
    own references are resolved first, and the parsed result is stored under
    the reference name.

    :param schema: Schema to resolve named schemas for.
    :param schema_registry_client: SchemaRegistryClient to use for retrieval.
    :return: named_schemas dict mapping reference names to parsed schemas.
             Empty when the schema has no references.
    """
    named_schemas = {}
    if schema.references is None:
        return named_schemas

    for ref in schema.references:
        referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True)
        # Depth first: the referenced schema can only be parsed once the
        # schemas it refers to are available.
        ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client)
        parsed_schema = parse_schema_with_repo(
            referenced_schema.schema.schema_str, named_schemas=ref_named_schemas)
        named_schemas.update(ref_named_schemas)
        named_schemas[ref.name] = parsed_schema
    return named_schemas


def _get_or_parse_schema(
    cache: ParsedSchemaCache, schema: Schema, schema_registry_client: SchemaRegistryClient
) -> AvroSchema:
    """
    Returns the parsed form of ``schema``, parsing and caching it on a miss.

    Shared by AvroSerializer and AvroDeserializer, whose lookups are identical.

    :param cache: ParsedSchemaCache consulted first and updated on a miss.
    :param schema: Schema to parse.
    :param schema_registry_client: SchemaRegistryClient used to resolve references.
    :return: The parsed schema.
    """
    parsed_schema = cache.get_parsed_schema(schema)
    if parsed_schema is not None:
        return parsed_schema

    named_schemas = _resolve_named_schema(schema, schema_registry_client)
    prepared_schema = _schema_loads(schema.schema_str)
    parsed_schema = parse_schema_with_repo(
        prepared_schema.schema_str, named_schemas=named_schemas)

    cache.set(schema, parsed_schema)
    return parsed_schema


class AvroSerializer(BaseSerializer):
    """
    Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing.

    Configuration properties:

    +-----------------------------+----------+--------------------------------------------------+
    | Property Name               | Type     | Description                                      |
    +=============================+==========+==================================================+
    |                             |          | If True, automatically register the configured   |
    | ``auto.register.schemas``   | bool     | schema with Confluent Schema Registry if it has  |
    |                             |          | not previously been associated with the relevant |
    |                             |          | subject (determined via subject.name.strategy).  |
    |                             |          |                                                  |
    |                             |          | Defaults to True.                                |
    +-----------------------------+----------+--------------------------------------------------+
    |                             |          | Whether to normalize schemas, which will         |
    | ``normalize.schemas``       | bool     | transform schemas to have a consistent format,   |
    |                             |          | including ordering properties and references.    |
    |                             |          |                                                  |
    |                             |          | Defaults to False.                               |
    +-----------------------------+----------+--------------------------------------------------+
    |                             |          | Whether to use the given schema ID for           |
    | ``use.schema.id``           | int      | serialization.                                   |
    |                             |          |                                                  |
    |                             |          | Defaults to None.                                |
    +-----------------------------+----------+--------------------------------------------------+
    |                             |          | Whether to use the latest subject version for    |
    | ``use.latest.version``      | bool     | serialization.                                   |
    |                             |          |                                                  |
    |                             |          | WARNING: There is no check that the latest       |
    |                             |          | schema is backwards compatible with the object   |
    |                             |          | being serialized.                                |
    |                             |          |                                                  |
    |                             |          | Defaults to False.                               |
    +-----------------------------+----------+--------------------------------------------------+
    |                             |          | Whether to use the latest subject version with   |
    | ``use.latest.with.metadata``| dict     | the given metadata.                              |
    |                             |          |                                                  |
    |                             |          | WARNING: There is no check that the latest       |
    |                             |          | schema is backwards compatible with the object   |
    |                             |          | being serialized.                                |
    |                             |          |                                                  |
    |                             |          | Defaults to None.                                |
    +-----------------------------+----------+--------------------------------------------------+
    |                             |          | Callable(SerializationContext, str) -> str       |
    |                             |          |                                                  |
    | ``subject.name.strategy``   | callable | Defines how Schema Registry subject names are    |
    |                             |          | constructed. Standard naming strategies are      |
    |                             |          | defined in the confluent_kafka.schema_registry   |
    |                             |          | namespace.                                       |
    |                             |          |                                                  |
    |                             |          | Defaults to topic_subject_name_strategy.         |
    +-----------------------------+----------+--------------------------------------------------+

    ``use.latest.version`` and ``auto.register.schemas`` cannot both be enabled;
    the constructor raises ValueError in that case. Because ``auto.register.schemas``
    defaults to True, it must be disabled explicitly to use ``use.latest.version``.

    Schemas are registered against subject names in Confluent Schema Registry that
    define a scope in which the schemas can be evolved. By default, the subject name
    is formed by concatenating the topic name with the message field (key or value)
    separated by a hyphen.

    i.e. {topic name}-{message field}

    Alternative naming strategies may be configured with the property
    ``subject.name.strategy``.

    Supported subject name strategies:

    +--------------------------------------+------------------------------+
    | Subject Name Strategy                | Output Format                |
    +======================================+==============================+
    | topic_subject_name_strategy(default) | {topic name}-{message field} |
    +--------------------------------------+------------------------------+
    | topic_record_subject_name_strategy   | {topic name}-{record name}   |
    +--------------------------------------+------------------------------+
    | record_subject_name_strategy         | {record name}                |
    +--------------------------------------+------------------------------+

    See the Confluent Schema Registry documentation on subject name strategies
    for additional details.

    Note:
        Prior to serialization, all values must first be converted to
        a dict instance. This may be handled manually prior to calling
        :py:func:`Producer.produce()` or by registering a `to_dict`
        callable with AvroSerializer.

        See ``avro_producer.py`` in the examples directory for example usage.

    Note:
        Tuple notation can be used to determine which branch of an ambiguous union to take.

        See the fastavro documentation on tuple notation for details.

    Args:
        schema_registry_client (SchemaRegistryClient): Schema Registry client instance.

        schema_str (str or Schema):
            Avro schema declaration.
            Accepts either a string or a :py:class:`Schema` instance. Note that string
            definitions cannot reference other schemas. For referencing other schemas,
            use a :py:class:`Schema` instance. Any other value is treated as
            "no schema configured".

        to_dict (callable, optional): Callable(object, SerializationContext) -> dict. Converts object to a dict.

        conf (dict): AvroSerializer configuration.

        rule_conf (dict, optional): Passed to ``configure`` of every rule executor
            supplied by the rule registry. An empty dict is passed when omitted.

        rule_registry (RuleRegistry, optional): Source of rule executors. Defaults to
            ``RuleRegistry.get_global_instance()``.

    Raises:
        ValueError: If ``to_dict`` is not callable, a configuration value has the
            wrong type, incompatible options are combined, or ``conf`` contains
            unrecognized properties.
    """  # noqa: E501
    __slots__ = ['_known_subjects', '_parsed_schema', '_schema',
                 '_schema_id', '_schema_name', '_to_dict', '_parsed_schemas']

    _default_conf = {'auto.register.schemas': True,
                     'normalize.schemas': False,
                     'use.schema.id': None,
                     'use.latest.version': False,
                     'use.latest.with.metadata': None,
                     'subject.name.strategy': topic_subject_name_strategy}

    def __init__(
        self,
        schema_registry_client: SchemaRegistryClient,
        schema_str: Union[str, Schema, None] = None,
        to_dict: Optional[Callable[[object, SerializationContext], dict]] = None,
        conf: Optional[dict] = None,
        rule_conf: Optional[dict] = None,
        rule_registry: Optional[RuleRegistry] = None
    ):
        super().__init__()
        if isinstance(schema_str, str):
            schema = _schema_loads(schema_str)
        elif isinstance(schema_str, Schema):
            schema = schema_str
        else:
            schema = None

        self._registry = schema_registry_client
        self._schema_id = None
        self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance()
        self._known_subjects = set()
        self._parsed_schemas = ParsedSchemaCache()

        if to_dict is not None and not callable(to_dict):
            raise ValueError("to_dict must be callable with the signature "
                             "to_dict(object, SerializationContext)->dict")

        self._to_dict = to_dict

        # Work on a copy: recognised keys are popped one by one so that
        # whatever remains at the end is by definition unrecognised.
        conf_copy = self._default_conf.copy()
        if conf is not None:
            conf_copy.update(conf)

        self._auto_register = conf_copy.pop('auto.register.schemas')
        if not isinstance(self._auto_register, bool):
            raise ValueError("auto.register.schemas must be a boolean value")

        self._normalize_schemas = conf_copy.pop('normalize.schemas')
        if not isinstance(self._normalize_schemas, bool):
            raise ValueError("normalize.schemas must be a boolean value")

        self._use_schema_id = conf_copy.pop('use.schema.id')
        if (self._use_schema_id is not None and
                not isinstance(self._use_schema_id, int)):
            raise ValueError("use.schema.id must be an int value")

        self._use_latest_version = conf_copy.pop('use.latest.version')
        if not isinstance(self._use_latest_version, bool):
            raise ValueError("use.latest.version must be a boolean value")
        if self._use_latest_version and self._auto_register:
            raise ValueError("cannot enable both use.latest.version and auto.register.schemas")

        self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata')
        if (self._use_latest_with_metadata is not None and
                not isinstance(self._use_latest_with_metadata, dict)):
            raise ValueError("use.latest.with.metadata must be a dict value")

        self._subject_name_func = conf_copy.pop('subject.name.strategy')
        if not callable(self._subject_name_func):
            raise ValueError("subject.name.strategy must be callable")

        if len(conf_copy) > 0:
            raise ValueError("Unrecognized properties: {}"
                             .format(", ".join(conf_copy.keys())))

        if schema:
            parsed_schema = self._get_parsed_schema(schema)

            if isinstance(parsed_schema, list):
                # if parsed_schema is a list, we have an Avro union and there
                # is no valid schema name. This is fine because the only use of
                # schema_name is for supplying the subject name to the registry
                # and union types should use topic_subject_name_strategy, which
                # just discards the schema name anyway
                schema_name = None
            else:
                # The Avro spec states primitives have a name equal to their type
                # i.e. {"type": "string"} has a name of string.
                # This function does not comply.
                # https://github.com/fastavro/fastavro/issues/415
                schema_dict = loads(schema.schema_str)
                schema_name = parsed_schema.get("name", schema_dict.get("type"))
        else:
            schema_name = None
            parsed_schema = None

        self._schema = schema
        self._schema_name = schema_name
        self._parsed_schema = parsed_schema

        for rule in self._rule_registry.get_executors():
            rule.configure(self._registry.config() if self._registry else {},
                           rule_conf if rule_conf else {})

    def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
        return self.__serialize(obj, ctx)

    def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]:
        """
        Serializes an object to Avro binary format, prepending it with Confluent
        Schema Registry framing.

        Args:
            obj (object): The object instance to serialize.

            ctx (SerializationContext): Metadata pertaining to the serialization operation.

        Raises:
            SchemaRegistryError: Presumably raised by the registry client when the
                                 schema cannot be registered, or when
                                 auto.register.schemas is false and the schema was
                                 not registered (raised outside this module).

            Nothing is caught here: errors from the subject name strategy, the
            registry client, the rule executors and fastavro propagate unchanged.

        Returns:
            bytes: Confluent Schema Registry encoded Avro bytes, or None if obj is None.
        """

        if obj is None:
            return None

        subject = self._subject_name_func(ctx, self._schema_name)
        # _get_reader_schema comes from the base class; a non-None result
        # takes precedence over the schema configured on this instance.
        latest_schema = self._get_reader_schema(subject)
        if latest_schema is not None:
            self._schema_id = latest_schema.schema_id
        elif subject not in self._known_subjects:
            # Check to ensure this schema has been registered under subject_name.
            if self._auto_register:
                # The schema name will always be the same. We can't however register
                # a schema without a subject so we set the schema_id here to handle
                # the initial registration.
                self._schema_id = self._registry.register_schema(
                    subject, self._schema, self._normalize_schemas)
            else:
                registered_schema = self._registry.lookup_schema(
                    subject, self._schema, self._normalize_schemas)
                self._schema_id = registered_schema.schema_id

            self._known_subjects.add(subject)

        value = self._to_dict(obj, ctx) if self._to_dict is not None else obj

        if latest_schema is not None:
            parsed_schema = self._get_parsed_schema(latest_schema.schema)

            def field_transformer(rule_ctx, field_transform, msg):
                return transform(rule_ctx, parsed_schema, msg, field_transform)

            value = self._execute_rules(ctx, subject, RuleMode.WRITE, None,
                                        latest_schema.schema, value, get_inline_tags(parsed_schema),
                                        field_transformer)
        else:
            parsed_schema = self._parsed_schema

        with _ContextStringIO() as fo:
            # Write the magic byte and schema ID in network byte order (big endian)
            fo.write(pack('>bI', _MAGIC_BYTE, self._schema_id))
            # write the record to the rest of the buffer
            schemaless_writer(fo, parsed_schema, value)

            return fo.getvalue()

    def _get_parsed_schema(self, schema: Schema) -> AvroSchema:
        """Returns the parsed form of ``schema``, using this instance's cache."""
        return _get_or_parse_schema(self._parsed_schemas, schema, self._registry)


class AvroDeserializer(BaseDeserializer):
    """
    Deserializer for Avro binary encoded data with Confluent Schema Registry
    framing.

    Configuration properties:

    +-----------------------------+----------+--------------------------------------------------+
    | Property Name               | Type     | Description                                      |
    +=============================+==========+==================================================+
    |                             |          | Whether to use the latest subject version for    |
    | ``use.latest.version``      | bool     | deserialization.                                 |
    |                             |          |                                                  |
    |                             |          | Defaults to False.                               |
    +-----------------------------+----------+--------------------------------------------------+
    |                             |          | Whether to use the latest subject version with   |
    | ``use.latest.with.metadata``| dict     | the given metadata.                              |
    |                             |          |                                                  |
    |                             |          | Defaults to None.                                |
    +-----------------------------+----------+--------------------------------------------------+
    |                             |          | Callable(SerializationContext, str) -> str       |
    |                             |          |                                                  |
    | ``subject.name.strategy``   | callable | Defines how Schema Registry subject names are    |
    |                             |          | constructed. Standard naming strategies are      |
    |                             |          | defined in the confluent_kafka.schema_registry   |
    |                             |          | namespace.                                       |
    |                             |          |                                                  |
    |                             |          | Defaults to topic_subject_name_strategy.         |
    +-----------------------------+----------+--------------------------------------------------+

    Note:
        By default, Avro complex types are returned as dicts. This behavior can
        be overridden by registering a callable ``from_dict`` with the deserializer to
        convert the dicts to the desired type.

        See ``avro_consumer.py`` in the examples directory for example usage.

    Args:
        schema_registry_client (SchemaRegistryClient): Confluent Schema Registry
            client instance.

        schema_str (str, Schema, optional): Avro reader schema declaration. Accepts
            either a string or a :py:class:`Schema` instance. If not provided, the
            writer schema will be used as the reader schema. Note that string
            definitions cannot reference other schemas. For referencing other schemas,
            use a :py:class:`Schema` instance.

        from_dict (callable, optional): Callable(dict, SerializationContext) -> object.
            Converts a dict to an instance of some object.

        return_record_name (bool): If True, when reading a union of records, the result will
            be a tuple where the first value is the name of the record and the second value is
            the record itself. Defaults to False.

        conf (dict, optional): AvroDeserializer configuration.

        rule_conf (dict, optional): Passed to ``configure`` of every rule executor
            supplied by the rule registry. An empty dict is passed when omitted.

        rule_registry (RuleRegistry, optional): Source of rule executors. Defaults to
            ``RuleRegistry.get_global_instance()``.

    Raises:
        TypeError: If ``schema_str`` is neither None, a str nor a Schema.
        ValueError: If ``from_dict`` is not callable, ``return_record_name`` is not
            a bool, a configuration value has the wrong type, or ``conf`` contains
            unrecognized properties.

    See Also:
        The Apache Avro specification, sections on schema declaration and
        schema resolution.
    """  # noqa: E501

    __slots__ = ['_reader_schema', '_from_dict', '_return_record_name',
                 '_schema', '_parsed_schemas']

    _default_conf = {'use.latest.version': False,
                     'use.latest.with.metadata': None,
                     'subject.name.strategy': topic_subject_name_strategy}

    def __init__(
        self,
        schema_registry_client: SchemaRegistryClient,
        schema_str: Union[str, Schema, None] = None,
        from_dict: Optional[Callable[[dict, SerializationContext], object]] = None,
        return_record_name: bool = False,
        conf: Optional[dict] = None,
        rule_conf: Optional[dict] = None,
        rule_registry: Optional[RuleRegistry] = None
    ):
        super().__init__()
        schema = None
        if schema_str is not None:
            if isinstance(schema_str, str):
                schema = _schema_loads(schema_str)
            elif isinstance(schema_str, Schema):
                schema = schema_str
            else:
                raise TypeError('You must pass either schema string or schema object')

        self._schema = schema
        self._registry = schema_registry_client
        self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance()
        self._parsed_schemas = ParsedSchemaCache()
        self._use_schema_id = None

        # Work on a copy: recognised keys are popped one by one so that
        # whatever remains at the end is by definition unrecognised.
        conf_copy = self._default_conf.copy()
        if conf is not None:
            conf_copy.update(conf)

        self._use_latest_version = conf_copy.pop('use.latest.version')
        if not isinstance(self._use_latest_version, bool):
            raise ValueError("use.latest.version must be a boolean value")

        self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata')
        if (self._use_latest_with_metadata is not None and
                not isinstance(self._use_latest_with_metadata, dict)):
            raise ValueError("use.latest.with.metadata must be a dict value")

        self._subject_name_func = conf_copy.pop('subject.name.strategy')
        if not callable(self._subject_name_func):
            raise ValueError("subject.name.strategy must be callable")

        if len(conf_copy) > 0:
            raise ValueError("Unrecognized properties: {}"
                             .format(", ".join(conf_copy.keys())))

        if schema:
            self._reader_schema = self._get_parsed_schema(self._schema)
        else:
            self._reader_schema = None

        if from_dict is not None and not callable(from_dict):
            # The message mirrors the real call, self._from_dict(obj_dict, ctx).
            raise ValueError("from_dict must be callable with the signature "
                             "from_dict(dict, SerializationContext) -> object")
        self._from_dict = from_dict

        self._return_record_name = return_record_name
        if not isinstance(self._return_record_name, bool):
            raise ValueError("return_record_name must be a boolean value")

        for rule in self._rule_registry.get_executors():
            rule.configure(self._registry.config() if self._registry else {},
                           rule_conf if rule_conf else {})

    def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]:
        return self.__deserialize(data, ctx)

    def __deserialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]:
        """
        Deserialize Avro binary encoded data with Confluent Schema Registry framing to
        a dict, or object instance according to from_dict, if specified.

        The reader schema is chosen in this order: the schema returned by
        ``_get_reader_schema`` for the subject (if any), then the schema passed to
        the constructor, then the writer schema itself.

        Arguments:
            data (bytes): bytes

            ctx (SerializationContext): Metadata relevant to the serialization
                operation.

        Raises:
            SerializationError: If data is 5 bytes or shorter, or does not start
                with the expected magic byte.

        Returns:
            object: If data is None, then None. Else, a dict, or object instance according
                    to from_dict, if specified.
        """  # noqa: E501

        if data is None:
            return None

        if len(data) <= 5:
            raise SerializationError("Expecting data framing of length 6 bytes or "
                                     "more but total data size is {} bytes. This "
                                     "message was not produced with a Confluent "
                                     "Schema Registry serializer".format(len(data)))

        subject = self._subject_name_func(ctx, None) if ctx else None
        latest_schema = None
        if subject is not None:
            latest_schema = self._get_reader_schema(subject)

        with _ContextStringIO(data) as payload:
            magic, schema_id = unpack('>bI', payload.read(5))
            if magic != _MAGIC_BYTE:
                raise SerializationError("Unexpected magic byte {}. This message "
                                         "was not produced with a Confluent "
                                         "Schema Registry serializer".format(magic))

            writer_schema_raw = self._registry.get_schema(schema_id)
            writer_schema = self._get_parsed_schema(writer_schema_raw)

            if subject is None:
                # The strategy needed a record name; retry with the writer's.
                # Only a dict carries a "name": a parsed union is a list and
                # has none, so it must not be asked for one.
                writer_name = writer_schema.get("name") if isinstance(writer_schema, dict) else None
                subject = self._subject_name_func(ctx, writer_name) if ctx else None
                if subject is not None:
                    latest_schema = self._get_reader_schema(subject)

            if latest_schema is not None:
                migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None)
                reader_schema_raw = latest_schema.schema
                reader_schema = self._get_parsed_schema(latest_schema.schema)
            elif self._schema is not None:
                migrations = None
                reader_schema_raw = self._schema
                reader_schema = self._reader_schema
            else:
                migrations = None
                reader_schema_raw = writer_schema_raw
                reader_schema = writer_schema

            if migrations:
                # Decode with the writer schema only; the migrations are then
                # applied to the decoded dict.
                obj_dict = schemaless_reader(payload,
                                             writer_schema,
                                             None,
                                             self._return_record_name)
                obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict)
            else:
                obj_dict = schemaless_reader(payload,
                                             writer_schema,
                                             reader_schema,
                                             self._return_record_name)

            def field_transformer(rule_ctx, field_transform, message):
                return transform(rule_ctx, reader_schema, message, field_transform)

            obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None,
                                           reader_schema_raw, obj_dict, get_inline_tags(reader_schema),
                                           field_transformer)

            if self._from_dict is not None:
                return self._from_dict(obj_dict, ctx)

            return obj_dict

    def _get_parsed_schema(self, schema: Schema) -> AvroSchema:
        """Returns the parsed form of ``schema``, using this instance's cache."""
        return _get_or_parse_schema(self._parsed_schemas, schema, self._registry)
+ +import json +import struct +from typing import Union, Optional, Tuple, Callable + +from cachetools import LRUCache +from jsonschema import ValidationError +from jsonschema.protocols import Validator +from jsonschema.validators import validator_for +from referencing import Registry, Resource + +from confluent_kafka.schema_registry import (_MAGIC_BYTE, + Schema, + topic_subject_name_strategy, + RuleMode, SchemaRegistryClient) + +from confluent_kafka.schema_registry.common.json_schema import ( + DEFAULT_SPEC, JsonSchema, _retrieve_via_httpx, transform, _ContextStringIO +) +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry.serde import BaseSerializer, BaseDeserializer, \ + ParsedSchemaCache +from confluent_kafka.serialization import (SerializationError, + SerializationContext) + +__all__ = [ + '_resolve_named_schema', + 'JSONSerializer', + 'JSONDeserializer' +] + + +def _resolve_named_schema( + schema: Schema, schema_registry_client: SchemaRegistryClient, + ref_registry: Optional[Registry] = None +) -> Registry: + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. + :param schema_registry_client: SchemaRegistryClient to use for retrieval. + :param ref_registry: Registry of named schemas resolved recursively. 
+ :return: Registry + """ + if ref_registry is None: + # Retrieve external schemas for backward compatibility + ref_registry = Registry(retrieve=_retrieve_via_httpx) + if schema.references is not None: + for ref in schema.references: + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) + ref_registry = _resolve_named_schema(referenced_schema.schema, schema_registry_client, ref_registry) + referenced_schema_dict = json.loads(referenced_schema.schema.schema_str) + resource = Resource.from_contents( + referenced_schema_dict, default_specification=DEFAULT_SPEC) + ref_registry = ref_registry.with_resource(ref.name, resource) + return ref_registry + + +class JSONSerializer(BaseSerializer): + """ + Serializer that outputs JSON encoded data with Confluent Schema Registry framing. + + Configuration properties: + + +-----------------------------+----------+----------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+====================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. | + | | | | + | | | Raises SchemaRegistryError if the schema was not | + | | | registered against the subject, or could not be | + | | | successfully registered. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. 
| + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. | + | | | | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to validate the payload against | + | ``validate`` | bool | the given schema. | + | | | | + +-----------------------------+----------+----------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved.
By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. {topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. + + Supported subject name strategies: + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Notes: + The ``title`` annotation, referred to elsewhere as a record name + is not strictly required by the JSON Schema specification. It is + however required by this serializer in order to register the schema + with Confluent Schema Registry. + + Prior to serialization, all objects must first be converted to + a dict instance. This may be handled manually prior to calling + :py:func:`Producer.produce()` or by registering a `to_dict` + callable with JSONSerializer. + + Args: + schema_str (str, Schema): + `JSON Schema definition. `_ + Accepts schema as either a string or a :py:class:`Schema` instance. + Note that string definitions cannot reference other schemas. For + referencing other schemas, use a :py:class:`Schema` instance. + + schema_registry_client (SchemaRegistryClient): Schema Registry + client instance. + + to_dict (callable, optional): Callable(object, SerializationContext) -> dict. + Converts object to a dict. + + conf (dict): JsonSerializer configuration. 
+ """ # noqa: E501 + __slots__ = ['_known_subjects', '_parsed_schema', '_ref_registry', + '_schema', '_schema_id', '_schema_name', '_to_dict', + '_parsed_schemas', '_validators', '_validate', '_json_encode'] + + _default_conf = {'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'validate': True} + + def __init__( + self, + schema_str: Union[str, Schema, None], + schema_registry_client: SchemaRegistryClient, + to_dict: Optional[Callable[[object, SerializationContext], dict]] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None, + json_encode: Optional[Callable] = None, + ): + super().__init__() + if isinstance(schema_str, str): + self._schema = Schema(schema_str, schema_type="JSON") + elif isinstance(schema_str, Schema): + self._schema = schema_str + else: + self._schema = None + + self._json_encode = json_encode or json.dumps + self._registry = schema_registry_client + self._rule_registry = ( + rule_registry if rule_registry else RuleRegistry.get_global_instance() + ) + self._schema_id = None + self._known_subjects = set() + self._parsed_schemas = ParsedSchemaCache() + self._validators = LRUCache(1000) + + if to_dict is not None and not callable(to_dict): + raise ValueError("to_dict must be callable with the signature " + "to_dict(object, SerializationContext)->dict") + + self._to_dict = to_dict + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") 
+ + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._validate = conf_copy.pop('validate') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("validate must be a boolean value") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + schema_dict, ref_registry = self._get_parsed_schema(self._schema) + if schema_dict: + schema_name = schema_dict.get('title', None) + else: + schema_name = None + + self._schema_name = schema_name + self._parsed_schema = schema_dict + self._ref_registry = ref_registry + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(obj, ctx) + + def __serialize(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Serializes an object to JSON, prepending it with Confluent Schema Registry + framing. 
+ + Args: + obj (object): The object instance to serialize. + + ctx (SerializationContext): Metadata relevant to the serialization + operation. + + Raises: + SerializerError if any error occurs serializing obj. + + Returns: + bytes: None if obj is None, else a byte array containing the JSON + serialized data with Confluent Schema Registry framing. + """ + + if obj is None: + return None + + subject = self._subject_name_func(ctx, self._schema_name) + latest_schema = self._get_reader_schema(subject) + if latest_schema is not None: + self._schema_id = latest_schema.schema_id + elif subject not in self._known_subjects: + # Check to ensure this schema has been registered under subject_name. + if self._auto_register: + # The schema name will always be the same. We can't however register + # a schema without a subject so we set the schema_id here to handle + # the initial registration. + self._schema_id = self._registry.register_schema(subject, + self._schema, + self._normalize_schemas) + else: + registered_schema = self._registry.lookup_schema(subject, + self._schema, + self._normalize_schemas) + self._schema_id = registered_schema.schema_id + + self._known_subjects.add(subject) + + if self._to_dict is not None: + value = self._to_dict(obj, ctx) + else: + value = obj + + if latest_schema is not None: + schema = latest_schema.schema + parsed_schema, ref_registry = self._get_parsed_schema(latest_schema.schema) + root_resource = Resource.from_contents( + parsed_schema, default_specification=DEFAULT_SPEC) + ref_resolver = ref_registry.resolver_with_root(root_resource) + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 + transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) + value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, value, None, + field_transformer) + else: + schema = self._schema + parsed_schema, ref_registry = self._parsed_schema, self._ref_registry + + if 
self._validate: + try: + validator = self._get_validator(schema, parsed_schema, ref_registry) + validator.validate(value) + except ValidationError as ve: + raise SerializationError(ve.message) + + with _ContextStringIO() as fo: + # Write the magic byte and schema ID in network byte order (big endian) + fo.write(struct.pack(">bI", _MAGIC_BYTE, self._schema_id)) + # JSON dump always writes a str never bytes + # https://docs.python.org/3/library/json.html + encoded_value = self._json_encode(value) + if isinstance(encoded_value, str): + encoded_value = encoded_value.encode("utf8") + fo.write(encoded_value) + + return fo.getvalue() + + def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: + if schema is None: + return None, None + + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + ref_registry = _resolve_named_schema(schema, self._registry) + parsed_schema = json.loads(schema.schema_str) + + self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) + return parsed_schema, ref_registry + + def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: + validator = self._validators.get(schema, None) + if validator is not None: + return validator + + cls = validator_for(parsed_schema) + cls.check_schema(parsed_schema) + validator = cls(parsed_schema, registry=registry) + + self._validators[schema] = validator + return validator + + +class JSONDeserializer(BaseDeserializer): + """ + Deserializer for JSON encoded data with Confluent Schema Registry + framing. 
+ + Configuration properties: + + +-----------------------------+----------+----------------------------------------------------+ + | Property Name | Type | Description | + +=============================+==========+====================================================+ + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | deserialization. | + | | | | + | | | Defaults to False. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata``| dict | the given metadata. | + | | | | + | | | Defaults to None. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-----------------------------+----------+----------------------------------------------------+ + | | | Whether to validate the payload against the | + | ``validate`` | bool | the given schema. | + | | | | + +-----------------------------+----------+----------------------------------------------------+ + + Args: + schema_str (str, Schema, optional): + `JSON schema definition `_ + Accepts schema as either a string or a :py:class:`Schema` instance. + Note that string definitions cannot reference other schemas. For referencing other schemas, + use a :py:class:`Schema` instance. If not provided, schemas will be + retrieved from schema_registry_client based on the schema ID in the + wire header of each message. 
+ + from_dict (callable, optional): Callable(dict, SerializationContext) -> object. + Converts a dict to a Python object instance. + + schema_registry_client (SchemaRegistryClient, optional): Schema Registry client instance. Needed if ``schema_str`` is a schema referencing other schemas or is not provided. + """ # noqa: E501 + + __slots__ = ['_reader_schema', '_ref_registry', '_from_dict', '_schema', + '_parsed_schemas', '_validators', '_validate', '_json_decode'] + + _default_conf = {'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'validate': True} + + def __init__( + self, + schema_str: Union[str, Schema, None], + from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, + schema_registry_client: Optional[SchemaRegistryClient] = None, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None, + json_decode: Optional[Callable] = None, + ): + super().__init__() + if isinstance(schema_str, str): + schema = Schema(schema_str, schema_type="JSON") + elif isinstance(schema_str, Schema): + schema = schema_str + if bool(schema.references) and schema_registry_client is None: + raise ValueError( + """schema_registry_client must be provided if "schema_str" is a Schema instance with references""") + elif schema_str is None: + if schema_registry_client is None: + raise ValueError( + """schema_registry_client must be provided if "schema_str" is not provided""" + ) + schema = schema_str + else: + raise TypeError('You must pass either str or Schema') + + self._schema = schema + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._validators = LRUCache(1000) + self._json_decode = json_decode or json.loads + self._use_schema_id = None + + conf_copy = self._default_conf.copy() + if conf is not None: + 
conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._validate = conf_copy.pop('validate') + if not isinstance(self._validate, bool): + raise ValueError("validate must be a boolean value") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + if schema: + self._reader_schema, self._ref_registry = self._get_parsed_schema(self._schema) + else: + self._reader_schema, self._ref_registry = None, None + + if from_dict is not None and not callable(from_dict): + raise ValueError("from_dict must be callable with the signature" + " from_dict(dict, SerializationContext) -> object") + + self._from_dict = from_dict + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(data, ctx) + + def __serialize(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Deserialize a JSON encoded record with Confluent Schema Registry framing to + a dict, or object instance according to from_dict if from_dict is specified. + + Args: + data (bytes): A JSON serialized record with Confluent Schema Registry framing. 
+ + ctx (SerializationContext): Metadata relevant to the serialization operation. + + Returns: + A dict, or object instance according to from_dict if from_dict is specified. + + Raises: + SerializerError: If there was an error reading the Confluent framing data, or + if ``data`` was not successfully validated with the configured schema. + """ + + if data is None: + return None + + if len(data) <= 5: + raise SerializationError("Expecting data framing of length 6 bytes or " + "more but total data size is {} bytes. This " + "message was not produced with a Confluent " + "Schema Registry serializer".format(len(data))) + + subject = self._subject_name_func(ctx, None) + latest_schema = None + if subject is not None and self._registry is not None: + latest_schema = self._get_reader_schema(subject) + + with _ContextStringIO(data) as payload: + magic, schema_id = struct.unpack('>bI', payload.read(5)) + if magic != _MAGIC_BYTE: + raise SerializationError("Unexpected magic byte {}. This message " + "was not produced with a Confluent " + "Schema Registry serializer".format(magic)) + + # JSON documents are self-describing; no need to query schema + obj_dict = self._json_decode(payload.read()) + + if self._registry is not None: + writer_schema_raw = self._registry.get_schema(schema_id) + writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw) + if subject is None: + subject = self._subject_name_func(ctx, writer_schema.get("title")) + if subject is not None: + latest_schema = self._get_reader_schema(subject) + else: + writer_schema_raw = None + writer_schema, writer_ref_registry = None, None + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + reader_schema, reader_ref_registry = self._get_parsed_schema(latest_schema.schema) + elif self._schema is not None: + migrations = None + reader_schema_raw = self._schema + reader_schema, reader_ref_registry 
= self._reader_schema, self._ref_registry + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema, reader_ref_registry = writer_schema, writer_ref_registry + + if migrations: + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + + reader_root_resource = Resource.from_contents( + reader_schema, default_specification=DEFAULT_SPEC) + reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) + + def field_transformer(rule_ctx, field_transform, message): return ( # noqa: E731 + transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, + "$", message, field_transform)) + obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, + reader_schema_raw, obj_dict, None, + field_transformer) + + if self._validate: + try: + validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) + validator.validate(obj_dict) + except ValidationError as ve: + raise SerializationError(ve.message) + + if self._from_dict is not None: + return self._from_dict(obj_dict, ctx) + + return obj_dict + + def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: + if schema is None: + return None, None + + result = self._parsed_schemas.get_parsed_schema(schema) + if result is not None: + return result + + ref_registry = _resolve_named_schema(schema, self._registry) + parsed_schema = json.loads(schema.schema_str) + + self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) + return parsed_schema, ref_registry + + def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: + validator = self._validators.get(schema, None) + if validator is not None: + return validator + + cls = validator_for(parsed_schema) + cls.check_schema(parsed_schema) + validator = cls(parsed_schema, registry=registry) + + self._validators[schema] = validator + return validator diff --git 
a/src/confluent_kafka/schema_registry/_sync/protobuf.py b/src/confluent_kafka/schema_registry/_sync/protobuf.py new file mode 100644 index 000000000..83e4324aa --- /dev/null +++ b/src/confluent_kafka/schema_registry/_sync/protobuf.py @@ -0,0 +1,807 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2020-2022 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import struct +import warnings +from typing import Set, List, Union, Optional, Tuple + +from google.protobuf import json_format, descriptor_pb2 +from google.protobuf.descriptor_pool import DescriptorPool +from google.protobuf.descriptor import Descriptor, FileDescriptor +from google.protobuf.message import DecodeError, Message +from google.protobuf.message_factory import GetMessageClass + +from confluent_kafka.schema_registry import (_MAGIC_BYTE, + reference_subject_name_strategy, + topic_subject_name_strategy) +from confluent_kafka.schema_registry.schema_registry_client import SchemaRegistryClient +from confluent_kafka.schema_registry.common.protobuf import _bytes, _create_index_array, \ + _init_pool, _is_builtin, _schema_to_str, _str_to_proto, transform, _ContextStringIO +from confluent_kafka.schema_registry.rule_registry import RuleRegistry +from confluent_kafka.schema_registry import (Schema, + SchemaReference, + RuleMode) +from confluent_kafka.serialization import SerializationError, \ + SerializationContext + +from confluent_kafka.schema_registry.serde import 
BaseSerializer, BaseDeserializer, ParsedSchemaCache + +__all__ = [ + '_resolve_named_schema', + 'ProtobufSerializer', + 'ProtobufDeserializer', +] + + +def _resolve_named_schema( + schema: Schema, + schema_registry_client: SchemaRegistryClient, + pool: DescriptorPool, + visited: Optional[Set[str]] = None +): + """ + Resolves named schemas referenced by the provided schema recursively. + :param schema: Schema to resolve named schemas for. + :param schema_registry_client: SchemaRegistryClient to use for retrieval. + :param pool: DescriptorPool to add resolved schemas to. + :return: DescriptorPool + """ + if visited is None: + visited = set() + if schema.references is not None: + for ref in schema.references: + if _is_builtin(ref.name) or ref.name in visited: + continue + visited.add(ref.name) + referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') + _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) + file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) + pool.Add(file_descriptor_proto) + + +class ProtobufSerializer(BaseSerializer): + """ + Serializer for Protobuf Message derived classes. Serialization format is Protobuf, + with Confluent Schema Registry framing. + + Configuration properties: + + +-------------------------------------+----------+------------------------------------------------------+ + | Property Name | Type | Description | + +=====================================+==========+======================================================+ + | | | If True, automatically register the configured | + | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | + | | | not previously been associated with the relevant | + | | | subject (determined via subject.name.strategy). | + | | | | + | | | Defaults to True. 
| + | | | | + | | | Raises SchemaRegistryError if the schema was not | + | | | registered against the subject, or could not be | + | | | successfully registered. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to normalize schemas, which will | + | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | + | | | including ordering properties and references. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the given schema ID for | + | ``use.schema.id`` | int | serialization. | + | | | | + +-----------------------------------------+----------+--------------------------------------------------+ + | | | Whether to use the latest subject version for | + | ``use.latest.version`` | bool | serialization. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to False. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether to use the latest subject version with | + | ``use.latest.with.metadata`` | dict | the given metadata. | + | | | | + | | | WARNING: There is no check that the latest | + | | | schema is backwards compatible with the object | + | | | being serialized. | + | | | | + | | | Defaults to None. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Whether or not to skip known types when resolving | + | ``skip.known.types`` | bool | schema dependencies. | + | | | | + | | | Defaults to True. 
| + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | + | | | constructed. Standard naming strategies are | + | | | defined in the confluent_kafka.schema_registry | + | | | namespace. | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-------------------------------------+----------+------------------------------------------------------+ + | | | Callable(SerializationContext, str) -> str | + | | | | + | ``reference.subject.name.strategy`` | callable | Defines how Schema Registry subject names for schema | + | | | references are constructed. | + | | | | + | | | Defaults to reference_subject_name_strategy | + +-------------------------------------+----------+------------------------------------------------------+ + | ``use.deprecated.format`` | bool | Specifies whether the Protobuf serializer should | + | | | serialize message indexes without zig-zag encoding. | + | | | This option must be explicitly configured as older | + | | | and newer Protobuf producers are incompatible. | + | | | If the consumers of the topic being produced to are | + | | | using confluent-kafka-python <1.8 then this property | + | | | must be set to True until all old consumers have | + | | | have been upgraded. | + | | | | + | | | Warning: This configuration property will be removed | + | | | in a future version of the client. | + +-------------------------------------+----------+------------------------------------------------------+ + + Schemas are registered against subject names in Confluent Schema Registry that + define a scope in which the schemas can be evolved. By default, the subject name + is formed by concatenating the topic name with the message field (key or value) + separated by a hyphen. + + i.e. 
{topic name}-{message field} + + Alternative naming strategies may be configured with the property + ``subject.name.strategy``. + + Supported subject name strategies + + +--------------------------------------+------------------------------+ + | Subject Name Strategy | Output Format | + +======================================+==============================+ + | topic_subject_name_strategy(default) | {topic name}-{message field} | + +--------------------------------------+------------------------------+ + | topic_record_subject_name_strategy | {topic name}-{record name} | + +--------------------------------------+------------------------------+ + | record_subject_name_strategy | {record name} | + +--------------------------------------+------------------------------+ + + See `Subject name strategy `_ for additional details. + + Args: + msg_type (Message): Protobuf Message type. + + schema_registry_client (SchemaRegistryClient): Schema Registry + client instance. + + conf (dict): ProtobufSerializer configuration. 
+ + See Also: + `Protobuf API reference `_ + """ # noqa: E501 + __slots__ = ['_skip_known_types', '_known_subjects', '_msg_class', '_index_array', + '_schema', '_schema_id', '_ref_reference_subject_func', + '_use_deprecated_format', '_parsed_schemas'] + + _default_conf = { + 'auto.register.schemas': True, + 'normalize.schemas': False, + 'use.schema.id': None, + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'skip.known.types': True, + 'subject.name.strategy': topic_subject_name_strategy, + 'reference.subject.name.strategy': reference_subject_name_strategy, + 'use.deprecated.format': False, + } + + def __init__( + self, + msg_type: Message, + schema_registry_client: SchemaRegistryClient, + conf: Optional[dict] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + + if conf is None or 'use.deprecated.format' not in conf: + raise RuntimeError( + "ProtobufSerializer: the 'use.deprecated.format' configuration " + "property must be explicitly set due to backward incompatibility " + "with older confluent-kafka-python Protobuf producers and consumers. 
" + "See the release notes for more details") + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._auto_register = conf_copy.pop('auto.register.schemas') + if not isinstance(self._auto_register, bool): + raise ValueError("auto.register.schemas must be a boolean value") + + self._normalize_schemas = conf_copy.pop('normalize.schemas') + if not isinstance(self._normalize_schemas, bool): + raise ValueError("normalize.schemas must be a boolean value") + + self._use_schema_id = conf_copy.pop('use.schema.id') + if (self._use_schema_id is not None and + not isinstance(self._use_schema_id, int)): + raise ValueError("use.schema.id must be an int value") + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + if self._use_latest_version and self._auto_register: + raise ValueError("cannot enable both use.latest.version and auto.register.schemas") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._skip_known_types = conf_copy.pop('skip.known.types') + if not isinstance(self._skip_known_types, bool): + raise ValueError("skip.known.types must be a boolean value") + + self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + if not isinstance(self._use_deprecated_format, bool): + raise ValueError("use.deprecated.format must be a boolean value") + if self._use_deprecated_format: + warnings.warn("ProtobufSerializer: the 'use.deprecated.format' " + "configuration property, and the ability to use the " + "old incorrect Protobuf serializer heading format " + "introduced in confluent-kafka-python v1.4.0, " + "will be removed in an upcoming release in 2021 Q2. 
" + "Please migrate your Python Protobuf producers and " + "consumers to 'use.deprecated.format':False as " + "soon as possible") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._ref_reference_subject_func = conf_copy.pop( + 'reference.subject.name.strategy') + if not callable(self._ref_reference_subject_func): + raise ValueError("subject.name.strategy must be callable") + + if len(conf_copy) > 0: + raise ValueError("Unrecognized properties: {}" + .format(", ".join(conf_copy.keys()))) + + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._schema_id = None + self._known_subjects = set() + self._msg_class = msg_type + self._parsed_schemas = ParsedSchemaCache() + + descriptor = msg_type.DESCRIPTOR + self._index_array = _create_index_array(descriptor) + self._schema = Schema(_schema_to_str(descriptor.file), + schema_type='PROTOBUF') + + for rule in self._rule_registry.get_executors(): + rule.configure(self._registry.config() if self._registry else {}, + rule_conf if rule_conf else {}) + + @staticmethod + def _write_varint(buf: io.BytesIO, val: int, zigzag: bool = True): + """ + Writes val to buf, either using zigzag or uvarint encoding. + + Args: + buf (BytesIO): buffer to write to. + val (int): integer to be encoded. + zigzag (bool): whether to encode in zigzag or uvarint encoding + """ + + if zigzag: + val = (val << 1) ^ (val >> 63) + + while (val & ~0x7f) != 0: + buf.write(_bytes((val & 0x7f) | 0x80)) + val >>= 7 + buf.write(_bytes(val)) + + @staticmethod + def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True): + """ + Encodes each int as a uvarint onto buf + + Args: + buf (BytesIO): buffer to write to. + ints ([int]): ints to be encoded. 
+ zigzag (bool): whether to encode in zigzag or uvarint encoding + """ + + assert len(ints) > 0 + # The root element at the 0 position does not need a length prefix. + if ints == [0]: + buf.write(_bytes(0x00)) + return + + ProtobufSerializer._write_varint(buf, len(ints), zigzag=zigzag) + + for value in ints: + ProtobufSerializer._write_varint(buf, value, zigzag=zigzag) + + def _resolve_dependencies( + self, ctx: SerializationContext, + file_desc: FileDescriptor + ) -> List[SchemaReference]: + """ + Resolves and optionally registers schema references recursively. + + Args: + ctx (SerializationContext): Serialization context. + + file_desc (FileDescriptor): file descriptor to traverse. + """ + + schema_refs = [] + for dep in file_desc.dependencies: + if self._skip_known_types and _is_builtin(dep.name): + continue + dep_refs = self._resolve_dependencies(ctx, dep) + subject = self._ref_reference_subject_func(ctx, dep) + schema = Schema(_schema_to_str(dep), + references=dep_refs, + schema_type='PROTOBUF') + if self._auto_register: + self._registry.register_schema(subject, schema) + + reference = self._registry.lookup_schema(subject, schema) + # schema_refs are per file descriptor + schema_refs.append(SchemaReference(dep.name, + subject, + reference.version)) + return schema_refs + + def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + return self.__serialize(message, ctx) + + def __serialize(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: + """ + Serializes an instance of a class derived from Protobuf Message, and prepends + it with Confluent Schema Registry framing. + + Args: + message (Message): An instance of a class derived from Protobuf Message. + + ctx (SerializationContext): Metadata relevant to the serialization. + operation. + + Raises: + SerializerError if any error occurs during serialization. 
+ + Returns: + None if messages is None, else a byte array containing the Protobuf + serialized message with Confluent Schema Registry framing. + """ + + if message is None: + return None + + if not isinstance(message, self._msg_class): + raise ValueError("message must be of type {} not {}" + .format(self._msg_class, type(message))) + + subject = self._subject_name_func(ctx, message.DESCRIPTOR.full_name) if ctx else None + latest_schema = None + if subject is not None: + latest_schema = self._get_reader_schema(subject, fmt='serialized') + + if latest_schema is not None: + self._schema_id = latest_schema.schema_id + elif subject not in self._known_subjects and ctx is not None: + references = self._resolve_dependencies(ctx, message.DESCRIPTOR.file) + self._schema = Schema( + self._schema.schema_str, + self._schema.schema_type, + references + ) + + if self._auto_register: + self._schema_id = self._registry.register_schema(subject, + self._schema, + self._normalize_schemas) + else: + self._schema_id = self._registry.lookup_schema( + subject, self._schema, self._normalize_schemas).schema_id + + self._known_subjects.add(subject) + + if latest_schema is not None: + fd_proto, pool = self._get_parsed_schema(latest_schema.schema) + fd = pool.FindFileByName(fd_proto.name) + desc = fd.message_types_by_name[message.DESCRIPTOR.name] + def field_transformer(rule_ctx, field_transform, msg): return ( # noqa: E731 + transform(rule_ctx, desc, msg, field_transform)) + message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, + latest_schema.schema, message, None, + field_transformer) + + with _ContextStringIO() as fo: + # Write the magic byte and schema ID in network byte order + # (big endian) + fo.write(struct.pack('>bI', _MAGIC_BYTE, self._schema_id)) + # write the index array that specifies the message descriptor + # of the serialized data. 
def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]:
    """
    Returns the parsed form of ``schema``, building and caching it on a
    cache miss.

    Args:
        schema (Schema): schema whose ``schema_str`` is parsed.

    Returns:
        tuple: the file descriptor proto and the descriptor pool it was
        added to.
    """
    cached = self._parsed_schemas.get_parsed_schema(schema)
    if cached is None:
        # Cache miss: build a fresh pool, load the referenced schemas into
        # it, then add the schema's own file descriptor.
        pool = DescriptorPool()
        _init_pool(pool)
        _resolve_named_schema(schema, self._registry, pool)
        fd_proto = _str_to_proto("default", schema.schema_str)
        pool.Add(fd_proto)
        cached = (fd_proto, pool)
        self._parsed_schemas.set(schema, cached)
    return cached
schema_registry | + | | | namespace . | + | | | | + | | | Defaults to topic_subject_name_strategy. | + +-------------------------------------+----------+------------------------------------------------------+ + | ``use.deprecated.format`` | bool | Specifies whether the Protobuf deserializer should | + | | | deserialize message indexes without zig-zag encoding.| + | | | This option must be explicitly configured as older | + | | | and newer Protobuf producers are incompatible. | + | | | If Protobuf messages in the topic to consume were | + | | | produced with confluent-kafka-python <1.8 then this | + | | | property must be set to True until all old messages | + | | | have been processed and producers have been upgraded.| + | | | Warning: This configuration property will be removed | + | | | in a future version of the client. | + +-------------------------------------+----------+------------------------------------------------------+ + + + See Also: + `Protobuf API reference `_ + """ + + __slots__ = ['_msg_class', '_use_deprecated_format', '_parsed_schemas'] + + _default_conf = { + 'use.latest.version': False, + 'use.latest.with.metadata': None, + 'subject.name.strategy': topic_subject_name_strategy, + 'use.deprecated.format': False, + } + + def __init__( + self, + message_type: Message, + conf: Optional[dict] = None, + schema_registry_client: Optional[SchemaRegistryClient] = None, + rule_conf: Optional[dict] = None, + rule_registry: Optional[RuleRegistry] = None + ): + super().__init__() + + self._registry = schema_registry_client + self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() + self._parsed_schemas = ParsedSchemaCache() + self._use_schema_id = None + + # Require use.deprecated.format to be explicitly configured + # during a transitionary period since old/new format are + # incompatible. 
+ if conf is None or 'use.deprecated.format' not in conf: + raise RuntimeError( + "ProtobufDeserializer: the 'use.deprecated.format' configuration " + "property must be explicitly set due to backward incompatibility " + "with older confluent-kafka-python Protobuf producers and consumers. " + "See the release notes for more details") + + conf_copy = self._default_conf.copy() + if conf is not None: + conf_copy.update(conf) + + self._use_latest_version = conf_copy.pop('use.latest.version') + if not isinstance(self._use_latest_version, bool): + raise ValueError("use.latest.version must be a boolean value") + + self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') + if (self._use_latest_with_metadata is not None and + not isinstance(self._use_latest_with_metadata, dict)): + raise ValueError("use.latest.with.metadata must be a dict value") + + self._subject_name_func = conf_copy.pop('subject.name.strategy') + if not callable(self._subject_name_func): + raise ValueError("subject.name.strategy must be callable") + + self._use_deprecated_format = conf_copy.pop('use.deprecated.format') + if not isinstance(self._use_deprecated_format, bool): + raise ValueError("use.deprecated.format must be a boolean value") + if self._use_deprecated_format: + warnings.warn("ProtobufDeserializer: the 'use.deprecated.format' " + "configuration property, and the ability to use the " + "old incorrect Protobuf serializer heading format " + "introduced in confluent-kafka-python v1.4.0, " + "will be removed in an upcoming release in 2022 Q2. 
@staticmethod
def _decode_varint(buf: io.BytesIO, zigzag: bool = True) -> int:
    """
    Reads one variable-length integer from buf.

    Args:
        buf (BytesIO): buffer to read from
        zigzag (bool): undo zigzag mapping when True, otherwise treat the
            value as a plain uvarint

    Returns:
        int: decoded varint

    Raises:
        EOFError: if the buffer runs out before the varint is complete
    """

    decoded = 0
    bit_offset = 0
    has_more = True
    try:
        while has_more:
            byte = ProtobufDeserializer._read_byte(buf)
            # The low seven bits are payload, least-significant group first.
            decoded |= (byte & 0x7f) << bit_offset
            bit_offset += 7
            # A set high bit means another byte follows.
            has_more = bool(byte & 0x80)
    except EOFError:
        raise EOFError("Unexpected EOF while reading index")

    if not zigzag:
        return decoded
    # Zigzag keeps the sign in bit 0 and the magnitude in the rest.
    return (decoded >> 1) ^ -(decoded & 1)
def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[Message]:
    """
    Deserializes ``data`` by delegating to the private ``__serialize``
    method of this deserializer.

    Args:
        data (bytes): Serialized protobuf message with Confluent Schema
            Registry framing.

        ctx (SerializationContext): Metadata relevant to the serialization
            operation.

    Returns:
        Whatever ``__serialize`` returns: None when ``data`` is None,
        otherwise the deserialized message.
    """
    # NOTE(review): despite its name, the delegate on this class performs
    # deserialization.
    return self.__serialize(data, ctx)
This message was " + "not produced with a Confluent " + "Schema Registry serializer") + + msg_index = self._read_index_array(payload, zigzag=not self._use_deprecated_format) + + if self._registry is not None: + writer_schema_raw = self._registry.get_schema(schema_id, fmt='serialized') + fd_proto, pool = self._get_parsed_schema(writer_schema_raw) + writer_schema = pool.FindFileByName(fd_proto.name) + writer_desc = self._get_message_desc(pool, writer_schema, msg_index) + if subject is None: + subject = self._subject_name_func(ctx, writer_desc.full_name) + if subject is not None: + latest_schema = self._get_reader_schema(subject, fmt='serialized') + else: + writer_schema_raw = None + writer_schema = None + + if latest_schema is not None: + migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) + reader_schema_raw = latest_schema.schema + fd_proto, pool = self._get_parsed_schema(latest_schema.schema) + reader_schema = pool.FindFileByName(fd_proto.name) + else: + migrations = None + reader_schema_raw = writer_schema_raw + reader_schema = writer_schema + + if reader_schema is not None: + # Initialize reader desc to first message in file + reader_desc = self._get_message_desc(pool, reader_schema, [0]) + # Attempt to find a reader desc with the same name as the writer + reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) + + if migrations: + msg = GetMessageClass(writer_desc)() + try: + msg.ParseFromString(payload.read()) + except DecodeError as e: + raise SerializationError(str(e)) + + obj_dict = json_format.MessageToDict(msg, True) + obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) + msg = GetMessageClass(reader_desc)() + msg = json_format.ParseDict(obj_dict, msg) + else: + # Protobuf Messages are self-describing; no need to query schema + msg = self._msg_class() + try: + msg.ParseFromString(payload.read()) + except DecodeError as e: + raise SerializationError(str(e)) + + def 
def _get_message_desc_proto(
    self,
    path: str,
    desc: Union[descriptor_pb2.FileDescriptorProto, descriptor_pb2.DescriptorProto],
    msg_index: List[int]
) -> Tuple[str, descriptor_pb2.DescriptorProto]:
    """
    Follows ``msg_index`` down from ``desc`` and returns the message
    descriptor proto it points at together with its dotted name.

    Args:
        path (str): dotted name accumulated so far; may be empty.
        desc: file or message descriptor proto to start from.
        msg_index (list of int): one index per nesting level.

    Returns:
        tuple: the dotted name (``path`` extended by each visited message
        name) and the descriptor proto found at the end of the index path.
    """
    node = desc
    remaining = msg_index
    while True:
        position = remaining[0]
        # A file lists its top-level messages in message_type; a message
        # lists its children in nested_type.
        if isinstance(node, descriptor_pb2.FileDescriptorProto):
            node = node.message_type[position]
        else:
            node = node.nested_type[position]
        path = path + "." + node.name if path else node.name
        if len(remaining) == 1:
            return path, node
        remaining = remaining[1:]
+# + +import json +import logging +import time +import urllib +from urllib.parse import unquote, urlparse + +import httpx +from typing import List, Dict, Optional, Union, Any, Tuple, Callable + +from cachetools import TTLCache, LRUCache +from httpx import Response + +from authlib.integrations.httpx_client import OAuth2Client + +from confluent_kafka.schema_registry.error import SchemaRegistryError, OAuthTokenError +from confluent_kafka.schema_registry.common.schema_registry_client import ( + RegisteredSchema, + ServerConfig, + is_success, + is_retriable, + _BearerFieldProvider, + full_jitter, + _SchemaCache, + Schema, + _StaticFieldProvider, +) + +__all__ = [ + '_urlencode', + '_CustomOAuthClient', + '_OAuthClient', + '_BaseRestClient', + '_RestClient', + 'SchemaRegistryClient', +] + +# TODO: consider adding `six` dependency or employing a compat file +# Python 2.7 is officially EOL so compatibility issue will be come more the norm. +# We need a better way to handle these issues. +# Six is one possibility but the compat file pattern used by requests +# is also quite nice. 
class _OAuthClient(_BearerFieldProvider):
    """
    Bearer field provider that obtains its token through the OAuth2
    client-credentials grant and caches it until it is considered expired.
    """

    def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str,
                 identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int):
        # No token is fetched until one is first requested.
        self.token = None
        self.logical_cluster = logical_cluster
        self.identity_pool = identity_pool
        self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope)
        self.token_endpoint = token_endpoint
        # Retry/backoff settings used when fetching a token fails.
        self.max_retries = max_retries
        self.retries_wait_ms = retries_wait_ms
        self.retries_max_wait_ms = retries_max_wait_ms
        # Fraction of 'expires_in' used as the refresh window.
        self.token_expiry_threshold = 0.8

    def get_bearer_fields(self) -> dict:
        """Returns the bearer fields, fetching a token if necessary."""
        access_token = self.get_access_token()
        return {
            'bearer.auth.token': access_token,
            'bearer.auth.logical.cluster': self.logical_cluster,
            'bearer.auth.identity.pool.id': self.identity_pool
        }

    def token_expired(self) -> bool:
        """
        Reports whether the cached token should be replaced: True once
        'expires_at' falls within the refresh window measured from now.
        """
        refresh_window = self.token['expires_in'] * self.token_expiry_threshold
        refresh_deadline = time.time() + refresh_window
        return self.token['expires_at'] < refresh_deadline

    def get_access_token(self) -> str:
        """Returns the access token, generating a new one when needed."""
        needs_refresh = not self.token or self.token_expired()
        if needs_refresh:
            self.generate_access_token()

        return self.token['access_token']

    def generate_access_token(self) -> None:
        """
        Fetches a token from the token endpoint and stores it, sleeping
        with jittered backoff between failed attempts.

        Raises:
            OAuthTokenError: if the final attempt also fails.
        """
        for attempt in range(self.max_retries + 1):
            try:
                fetched = self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials')
            except Exception as err:
                if attempt < self.max_retries:
                    # Back off before the next attempt.
                    time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, attempt) / 1000)
                    continue
                raise OAuthTokenError(f"Failed to retrieve token after {self.max_retries} "
                                      f"attempts due to error: {str(err)}")
            else:
                self.token = fetched
                return
in the URL." + " Remove userinfo credentials from the url or" + " remove basic.auth.user.info from the" + " configuration") + + userinfo = tuple(conf_copy.pop('basic.auth.user.info', '').split(':', 1)) + + if len(userinfo) != 2: + raise ValueError("basic.auth.user.info must be in the form" + " of {username}:{password}") + + self.auth = userinfo if userinfo != ('', '') else None + + # The following adds support for proxy config + # If specified: it uses the specified proxy details when making requests + self.proxy = None + proxy = conf_copy.pop('proxy', None) + if proxy is not None: + self.proxy = proxy + + self.timeout = None + timeout = conf_copy.pop('timeout', None) + if timeout is not None: + self.timeout = timeout + + self.cache_capacity = 1000 + cache_capacity = conf_copy.pop('cache.capacity', None) + if cache_capacity is not None: + if not isinstance(cache_capacity, (int, float)): + raise TypeError("cache.capacity must be a number, not " + str(type(cache_capacity))) + self.cache_capacity = cache_capacity + + self.cache_latest_ttl_sec = None + cache_latest_ttl_sec = conf_copy.pop('cache.latest.ttl.sec', None) + if cache_latest_ttl_sec is not None: + if not isinstance(cache_latest_ttl_sec, (int, float)): + raise TypeError("cache.latest.ttl.sec must be a number, not " + str(type(cache_latest_ttl_sec))) + self.cache_latest_ttl_sec = cache_latest_ttl_sec + + self.max_retries = 3 + max_retries = conf_copy.pop('max.retries', None) + if max_retries is not None: + if not isinstance(max_retries, (int, float)): + raise TypeError("max.retries must be a number, not " + str(type(max_retries))) + self.max_retries = max_retries + + self.retries_wait_ms = 1000 + retries_wait_ms = conf_copy.pop('retries.wait.ms', None) + if retries_wait_ms is not None: + if not isinstance(retries_wait_ms, (int, float)): + raise TypeError("retries.wait.ms must be a number, not " + + str(type(retries_wait_ms))) + self.retries_wait_ms = retries_wait_ms + + self.retries_max_wait_ms = 20000 + 
retries_max_wait_ms = conf_copy.pop('retries.max.wait.ms', None) + if retries_max_wait_ms is not None: + if not isinstance(retries_max_wait_ms, (int, float)): + raise TypeError("retries.max.wait.ms must be a number, not " + + str(type(retries_max_wait_ms))) + self.retries_max_wait_ms = retries_max_wait_ms + + self.bearer_field_provider = None + logical_cluster = None + identity_pool = None + self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) + if self.bearer_auth_credentials_source is not None: + self.auth = None + + if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: + headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] + missing_headers = [header for header in headers if header not in conf_copy] + if missing_headers: + raise ValueError("Missing required bearer configuration properties: {}" + .format(", ".join(missing_headers))) + + logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') + if not isinstance(logical_cluster, str): + raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) + + identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') + if not isinstance(identity_pool, str): + raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) + + if self.bearer_auth_credentials_source == 'OAUTHBEARER': + properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', + 'bearer.auth.issuer.endpoint.url'] + missing_properties = [prop for prop in properties_list if prop not in conf_copy] + if missing_properties: + raise ValueError("Missing required OAuth configuration properties: {}". 
+ format(", ".join(missing_properties))) + + self.client_id = conf_copy.pop('bearer.auth.client.id') + if not isinstance(self.client_id, string_type): + raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) + + self.client_secret = conf_copy.pop('bearer.auth.client.secret') + if not isinstance(self.client_secret, string_type): + raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) + + self.scope = conf_copy.pop('bearer.auth.scope') + if not isinstance(self.scope, string_type): + raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) + + self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') + if not isinstance(self.token_endpoint, string_type): + raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " + + str(type(self.token_endpoint))) + + self.bearer_field_provider = _OAuthClient(self.client_id, self.client_secret, self.scope, + self.token_endpoint, logical_cluster, identity_pool, + self.max_retries, self.retries_wait_ms, + self.retries_max_wait_ms) + elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': + if 'bearer.auth.token' not in conf_copy: + raise ValueError("Missing bearer.auth.token") + static_token = conf_copy.pop('bearer.auth.token') + self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) + if not isinstance(static_token, string_type): + raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) + elif self.bearer_auth_credentials_source == 'CUSTOM': + custom_bearer_properties = ['bearer.auth.custom.provider.function', + 'bearer.auth.custom.provider.config'] + missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy] + if missing_custom_properties: + raise ValueError("Missing required custom OAuth configuration properties: {}". 
class _RestClient(_BaseRestClient):
    """
    HTTP client for Confluent Schema Registry.

    See SchemaRegistryClient for configuration details.

    Args:
        conf (dict): Dictionary containing _RestClient configuration
    """

    def __init__(self, conf: dict):
        super().__init__(conf)

        # One shared httpx session carries the TLS, auth, proxy and timeout
        # settings parsed by the base class.
        self.session = httpx.Client(
            verify=self.verify,
            cert=self.cert,
            auth=self.auth,
            proxy=self.proxy,
            timeout=self.timeout
        )

    def handle_bearer_auth(self, headers: dict) -> None:
        """
        Adds the bearer authentication headers to ``headers`` in place.

        Args:
            headers (dict): Request headers to extend.

        Raises:
            ValueError: if the bearer field provider did not supply every
                required field.
        """
        bearer_fields = self.bearer_field_provider.get_bearer_fields()
        required_fields = ['bearer.auth.token', 'bearer.auth.identity.pool.id', 'bearer.auth.logical.cluster']

        missing_fields = [field for field in required_fields if field not in bearer_fields]
        if missing_fields:
            raise ValueError("Missing required bearer auth fields, needs to be set in config or custom function: {}"
                             .format(", ".join(missing_fields)))

        headers["Authorization"] = "Bearer {}".format(bearer_fields['bearer.auth.token'])
        headers['Confluent-Identity-Pool-Id'] = bearer_fields['bearer.auth.identity.pool.id']
        headers['target-sr-cluster'] = bearer_fields['bearer.auth.logical.cluster']

    def get(self, url: str, query: Optional[dict] = None) -> Any:
        """Sends a GET request for ``url`` with optional query params."""
        return self.send_request(url, method='GET', query=query)

    def post(self, url: str, body: Optional[dict], **kwargs) -> Any:
        """Sends a POST request for ``url``; extra keyword arguments are accepted but not used."""
        return self.send_request(url, method='POST', body=body)

    def delete(self, url: str) -> Any:
        """Sends a DELETE request for ``url``."""
        return self.send_request(url, method='DELETE')

    def put(self, url: str, body: Optional[dict] = None) -> Any:
        """Sends a PUT request for ``url`` with an optional body."""
        return self.send_request(url, method='PUT', body=body)

    def send_request(
        self, url: str, method: str, body: Optional[dict] = None,
        query: Optional[dict] = None
    ) -> Any:
        """
        Sends HTTP request to the SchemaRegistry, trying each base URL in turn.

        A successful response is returned as decoded JSON. Otherwise a
        SchemaRegistryError is raised with the response contents. In most
        cases this will be accompanied by a Schema Registry supplied error
        code.

        In the event the response is malformed an error_code of -1 will be used.

        Args:
            url (str): Request path

            method (str): HTTP method

            body (dict): Request content; serialized to JSON before sending

            query (dict): Query params to attach to the URL

        Returns:
            dict: Schema Registry response content.

        Raises:
            SchemaRegistryError: if no base URL produced a successful
                response.
            Exception: whatever the transport raised for the last base URL.
        """

        headers = {'Accept': "application/vnd.schemaregistry.v1+json,"
                             " application/vnd.schemaregistry+json,"
                             " application/json"}

        if body is not None:
            body = json.dumps(body)
            # Merge into the existing headers; rebinding the dict here would
            # drop the Accept header from every request that has a body.
            # json.dumps escapes non-ASCII by default, so the character
            # count equals the encoded byte count.
            headers.update({'Content-Length': str(len(body)),
                            'Content-Type': "application/vnd.schemaregistry.v1+json"})

        if self.bearer_auth_credentials_source:
            self.handle_bearer_auth(headers)

        response = None
        last_index = len(self.base_urls) - 1
        for i, base_url in enumerate(self.base_urls):
            try:
                response = self.send_http_request(
                    base_url, url, method, headers, body, query)

                if is_success(response.status_code):
                    return response.json()

                if not is_retriable(response.status_code) or i == last_index:
                    break
            except Exception:
                if i == last_index:
                    # Re-raise with the original traceback since we have no
                    # more urls to try
                    raise

        try:
            error = response.json()
            error_code = error.get('error_code')
            message = error.get('message')
        # Schema Registry may return malformed output when it hits unexpected errors
        except (ValueError, KeyError, AttributeError):
            raise SchemaRegistryError(response.status_code,
                                      -1,
                                      "Unknown Schema Registry Error: "
                                      + str(response.content))
        raise SchemaRegistryError(response.status_code, error_code, message)

    def send_http_request(
        self, base_url: str, url: str, method: str, headers: Optional[dict],
        body: Optional[str] = None, query: Optional[dict] = None
    ) -> Response:
        """
        Sends HTTP request to a single SchemaRegistry base URL, retrying
        retriable failures with jittered backoff.

        This method does not raise for unsuccessful status codes; the last
        response received is returned and the caller decides how to report
        it.

        Args:
            base_url (str): Schema Registry base URL

            url (str): Request path

            method (str): HTTP method

            headers (dict): Headers

            body (str): Request content

            query (dict): Query params to attach to the URL

        Returns:
            Response: The first successful or non-retriable response, or
            the last response once retries are exhausted.
        """
        response = None
        for i in range(self.max_retries + 1):
            response = self.session.request(
                method, url="/".join([base_url, url]),
                headers=headers, content=body, params=query)

            if is_success(response.status_code):
                return response

            if not is_retriable(response.status_code) or i >= self.max_retries:
                return response

            time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000)
        return response
| + +------------------------------+------+-------------------------------------------------+ + | | | Client HTTP credentials in the form of | + | | | ``username:password``. | + | ``basic.auth.user.info`` | str | | + | | | By default userinfo is extracted from | + | | | the URL if present. | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``proxy`` | str | Proxy such as http://localhost:8030. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``timeout`` | int | Request timeout. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``cache.capacity`` | int | Cache capacity. Defaults to 1000. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``cache.latest.ttl.sec`` | int | TTL in seconds for caching the latest schema. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | | + | ``max.retries`` | int | Maximum retries for a request. Defaults to 2. | + | | | | + +------------------------------+------+-------------------------------------------------+ + | | | Maximum time to wait for the first retry. | + | | | When jitter is applied, the actual wait may | + | ``retries.wait.ms`` | int | be less. | + | | | | + | | | Defaults to 1000. | + +------------------------------+------+-------------------------------------------------+ + + Args: + conf (dict): Schema Registry client configuration. 
+ + See Also: + `Confluent Schema Registry documentation `_ + """ # noqa: E501 + + def __init__(self, conf: dict): + self._conf = conf + self._rest_client = _RestClient(conf) + self._cache = _SchemaCache() + cache_capacity = self._rest_client.cache_capacity + cache_ttl = self._rest_client.cache_latest_ttl_sec + if cache_ttl is not None: + self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) + self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) + else: + self._latest_version_cache = LRUCache(cache_capacity) + self._latest_with_metadata_cache = LRUCache(cache_capacity) + + def __enter__(self): + return self + + def __exit__(self, *args): + if self._rest_client is not None: + self._rest_client.session.close() + + def config(self): + return self._conf + + def register_schema( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False + ) -> int: + """ + Registers a schema under ``subject_name``. + + Args: + subject_name (str): subject to register a schema under + schema (Schema): Schema instance to register + normalize_schemas (bool): Normalize schema before registering + + Returns: + int: Schema id + + Raises: + SchemaRegistryError: if Schema violates this subject's + Compatibility policy or is otherwise invalid. + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + registered_schema = self.register_schema_full_response(subject_name, schema, normalize_schemas) + return registered_schema.schema_id + + def register_schema_full_response( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False + ) -> 'RegisteredSchema': + """ + Registers a schema under ``subject_name``. 
+ + Args: + subject_name (str): subject to register a schema under + schema (Schema): Schema instance to register + normalize_schemas (bool): Normalize schema before registering + + Returns: + int: Schema id + + Raises: + SchemaRegistryError: if Schema violates this subject's + Compatibility policy or is otherwise invalid. + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + schema_id = self._cache.get_id_by_schema(subject_name, schema) + if schema_id is not None: + return RegisteredSchema(schema_id, schema, subject_name, None) + + request = schema.to_dict() + + response = self._rest_client.post( + 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), + body=request) + + registered_schema = RegisteredSchema.from_dict(response) + + # The registered schema may not be fully populated + self._cache.set_schema(subject_name, registered_schema.schema_id, schema) + + return registered_schema + + def get_schema( + self, schema_id: int, subject_name: Optional[str] = None, fmt: Optional[str] = None + ) -> 'Schema': + """ + Fetches the schema associated with ``schema_id`` from the + Schema Registry. The result is cached so subsequent attempts will not + require an additional round-trip to the Schema Registry. + + Args: + schema_id (int): Schema id + subject_name (str): Subject name the schema is registered under + fmt (str): Format of the schema + + Returns: + Schema: Schema instance identified by the ``schema_id`` + + Raises: + SchemaRegistryError: If schema can't be found. 
+ + See Also: + `GET Schema API Reference `_ + """ # noqa: E501 + + schema = self._cache.get_schema_by_id(subject_name, schema_id) + if schema is not None: + return schema + + query = {'subject': subject_name} if subject_name is not None else None + if fmt is not None: + if query is not None: + query['format'] = fmt + else: + query = {'format': fmt} + response = self._rest_client.get('schemas/ids/{}'.format(schema_id), query) + + schema = Schema.from_dict(response) + + self._cache.set_schema(subject_name, schema_id, schema) + + return schema + + def lookup_schema( + self, subject_name: str, schema: 'Schema', + normalize_schemas: bool = False, deleted: bool = False + ) -> 'RegisteredSchema': + """ + Returns ``schema`` registration information for ``subject``. + + Args: + subject_name (str): Subject name the schema is registered under + schema (Schema): Schema instance. + normalize_schemas (bool): Normalize schema before registering + deleted (bool): Whether to include deleted schemas. + + Returns: + RegisteredSchema: Subject registration information for this schema. 
+ + Raises: + SchemaRegistryError: If schema or subject can't be found + + See Also: + `POST Subject API Reference `_ + """ # noqa: E501 + + registered_schema = self._cache.get_registered_by_subject_schema(subject_name, schema) + if registered_schema is not None: + return registered_schema + + request = schema.to_dict() + + response = self._rest_client.post('subjects/{}?normalize={}&deleted={}' + .format(_urlencode(subject_name), normalize_schemas, deleted), + body=request) + + result = RegisteredSchema.from_dict(response) + + # Ensure the schema matches the input + registered_schema = RegisteredSchema( + schema_id=result.schema_id, + subject=result.subject, + version=result.version, + schema=schema, + ) + + self._cache.set_registered_schema(schema, registered_schema) + + return registered_schema + + def get_subjects(self) -> List[str]: + """ + List all subjects registered with the Schema Registry + + Returns: + list(str): Registered subject names + + Raises: + SchemaRegistryError: if subjects can't be found + + See Also: + `GET subjects API Reference `_ + """ # noqa: E501 + + return self._rest_client.get('subjects') + + def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: + """ + Deletes the specified subject and its associated compatibility level if + registered. It is recommended to use this API only when a topic needs + to be recycled or in development environments. + + Args: + subject_name (str): subject name + permanent (bool): True for a hard delete, False (default) for a soft delete + + Returns: + list(int): Versions deleted under this subject + + Raises: + SchemaRegistryError: if the request was unsuccessful. 
+ + See Also: + `DELETE Subject API Reference `_ + """ # noqa: E501 + + if permanent: + versions = self._rest_client.delete('subjects/{}?permanent=true' + .format(_urlencode(subject_name))) + self._cache.remove_by_subject(subject_name) + else: + versions = self._rest_client.delete('subjects/{}' + .format(_urlencode(subject_name))) + + return versions + + def get_latest_version( + self, subject_name: str, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves latest registered version for subject + + Args: + subject_name (str): Subject name. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. + + Raises: + SchemaRegistryError: if the version can't be found or is invalid. + + See Also: + `GET Subject Version API Reference `_ + """ # noqa: E501 + + registered_schema = self._latest_version_cache.get(subject_name, None) + if registered_schema is not None: + return registered_schema + + query = {'format': fmt} if fmt is not None else None + response = self._rest_client.get('subjects/{}/versions/{}' + .format(_urlencode(subject_name), + 'latest'), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._latest_version_cache[subject_name] = registered_schema + + return registered_schema + + def get_latest_with_metadata( + self, subject_name: str, metadata: Dict[str, str], + deleted: bool = False, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves latest registered version for subject with the given metadata + + Args: + subject_name (str): Subject name. + metadata (dict): The key-value pairs for the metadata. + deleted (bool): Whether to include deleted schemas. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. + + Raises: + SchemaRegistryError: if the version can't be found or is invalid. 
+ """ # noqa: E501 + + cache_key = (subject_name, frozenset(metadata.items()), deleted) + registered_schema = self._latest_with_metadata_cache.get(cache_key, None) + if registered_schema is not None: + return registered_schema + + query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + keys = metadata.keys() + if keys: + query['key'] = [_urlencode(key) for key in keys] + query['value'] = [_urlencode(metadata[key]) for key in keys] + + response = self._rest_client.get('subjects/{}/metadata' + .format(_urlencode(subject_name)), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._latest_with_metadata_cache[cache_key] = registered_schema + + return registered_schema + + def get_version( + self, subject_name: str, version: int, + deleted: bool = False, fmt: Optional[str] = None + ) -> 'RegisteredSchema': + """ + Retrieves a specific schema registered under ``subject_name``. + + Args: + subject_name (str): Subject name. + version (int): version number. Defaults to latest version. + deleted (bool): Whether to include deleted schemas. + fmt (str): Format of the schema + + Returns: + RegisteredSchema: Registration information for this version. + + Raises: + SchemaRegistryError: if the version can't be found or is invalid. 
+ + See Also: + `GET Subject Version API Reference `_ + """ # noqa: E501 + + registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) + if registered_schema is not None: + return registered_schema + + query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} + response = self._rest_client.get('subjects/{}/versions/{}' + .format(_urlencode(subject_name), + version), query) + + registered_schema = RegisteredSchema.from_dict(response) + + self._cache.set_registered_schema(registered_schema.schema, registered_schema) + + return registered_schema + + def get_versions(self, subject_name: str) -> List[int]: + """ + Get a list of all versions registered with this subject. + + Args: + subject_name (str): Subject name. + + Returns: + list(int): Registered versions + + Raises: + SchemaRegistryError: If subject can't be found + + See Also: + `GET Subject Versions API Reference `_ + """ # noqa: E501 + + return self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name))) + + def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: + """ + Deletes a specific version registered to ``subject_name``. + + Args: + subject_name (str) Subject name + + version (int): Version number + + permanent (bool): True for a hard delete, False (default) for a soft delete + + Returns: + int: Version number which was deleted + + Raises: + SchemaRegistryError: if the subject or version cannot be found. 
+ + See Also: + `Delete Subject Version API Reference `_ + """ # noqa: E501 + + if permanent: + response = self._rest_client.delete('subjects/{}/versions/{}?permanent=true' + .format(_urlencode(subject_name), + version)) + self._cache.remove_by_subject_version(subject_name, version) + else: + response = self._rest_client.delete('subjects/{}/versions/{}' + .format(_urlencode(subject_name), + version)) + + return response + + def set_compatibility(self, subject_name: Optional[str] = None, level: Optional[str] = None) -> str: + """ + Update global or subject level compatibility level. + + Args: + level (str): Compatibility level. See API reference for a list of + valid values. + + subject_name (str, optional): Subject to update. Sets global compatibility + level policy if not set. + + Returns: + str: The newly configured compatibility level. + + Raises: + SchemaRegistryError: If the compatibility level is invalid. + + See Also: + `PUT Subject Compatibility API Reference `_ + """ # noqa: E501 + + if level is None: + raise ValueError("level must be set") + + if subject_name is None: + return self._rest_client.put('config', + body={'compatibility': level.upper()}) + + return self._rest_client.put('config/{}' + .format(_urlencode(subject_name)), + body={'compatibility': level.upper()}) + + def get_compatibility(self, subject_name: Optional[str] = None) -> str: + """ + Get the current compatibility level. + + Args: + subject_name (str, optional): Subject name. Returns global policy + if left unset. + + Returns: + str: Compatibility level for the subject if set, otherwise the global compatibility level. + + Raises: + SchemaRegistryError: if the request was unsuccessful. 
+ + See Also: + `GET Subject Compatibility API Reference `_ + """ # noqa: E501 + + if subject_name is not None: + url = 'config/{}'.format(_urlencode(subject_name)) + else: + url = 'config' + + result = self._rest_client.get(url) + return result['compatibilityLevel'] + + def test_compatibility( + self, subject_name: str, schema: 'Schema', + version: Union[int, str] = "latest" + ) -> bool: + """Test the compatibility of a candidate schema for a given subject and version + + Args: + subject_name (str): Subject name the schema is registered under + schema (Schema): Schema instance. + version (int or str, optional): Version number, or the string "latest". Defaults to "latest". + + Returns: + bool: True if the schema is compatible with the specified version + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `POST Test Compatibility API Reference `_ + """ # noqa: E501 + + request = schema.to_dict() + + response = self._rest_client.post( + 'compatibility/subjects/{}/versions/{}'.format(_urlencode(subject_name), version), body=request + ) + + return response['is_compatible'] + + def set_config( + self, subject_name: Optional[str] = None, + config: Optional['ServerConfig'] = None + ) -> 'ServerConfig': + """ + Update global or subject config. + + Args: + config (ServerConfig): Config. See API reference for a list of + valid values. + + subject_name (str, optional): Subject to update. Sets global config + if not set. + + Returns: + str: The newly configured config. + + Raises: + SchemaRegistryError: If the config is invalid. 
+ + See Also: + `PUT Subject Config API Reference `_ + """ # noqa: E501 + + if config is None: + raise ValueError("config must be set") + + if subject_name is None: + return self._rest_client.put('config', + body=config.to_dict()) + + return self._rest_client.put('config/{}' + .format(_urlencode(subject_name)), + body=config.to_dict()) + + def get_config(self, subject_name: Optional[str] = None) -> 'ServerConfig': + """ + Get the current config. + + Args: + subject_name (str, optional): Subject name. Returns global config + if left unset. + + Returns: + ServerConfig: Config for the subject if set, otherwise the global config. + + Raises: + SchemaRegistryError: if the request was unsuccessful. + + See Also: + `GET Subject Config API Reference `_ + """ # noqa: E501 + + if subject_name is not None: + url = 'config/{}'.format(_urlencode(subject_name)) + else: + url = 'config' + + result = self._rest_client.get(url) + return ServerConfig.from_dict(result) + + def clear_latest_caches(self): + self._latest_version_cache.clear() + self._latest_with_metadata_cache.clear() + + def clear_caches(self): + self._latest_version_cache.clear() + self._latest_with_metadata_cache.clear() + self._cache.clear() + + @staticmethod + def new_client(conf: dict) -> 'SchemaRegistryClient': + from confluent_kafka.schema_registry.mock_schema_registry_client import MockSchemaRegistryClient + url = conf.get("url") + if url.startswith("mock://"): + return MockSchemaRegistryClient(conf) + return SchemaRegistryClient(conf) diff --git a/src/confluent_kafka/schema_registry/_sync/serde.py b/src/confluent_kafka/schema_registry/_sync/serde.py new file mode 100644 index 000000000..a0481f9b0 --- /dev/null +++ b/src/confluent_kafka/schema_registry/_sync/serde.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2024 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +from typing import List, Optional, Set, Dict, Any + +from confluent_kafka.schema_registry import RegisteredSchema +from confluent_kafka.schema_registry.common.serde import ErrorAction, \ + FieldTransformer, Migration, NoneAction, RuleAction, \ + RuleConditionError, RuleContext, RuleError +from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ + Rule, RuleKind, Schema, RuleSet +from confluent_kafka.serialization import Serializer, Deserializer, \ + SerializationContext, SerializationError + +__all__ = [ + 'BaseSerde', + 'BaseSerializer', + 'BaseDeserializer', +] + +log = logging.getLogger(__name__) + + +class BaseSerde(object): + __slots__ = ['_use_schema_id', '_use_latest_version', '_use_latest_with_metadata', + '_registry', '_rule_registry', '_subject_name_func', + '_field_transformer'] + + def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: + if self._use_schema_id is not None: + schema = self._registry.get_schema(self._use_schema_id, subject, fmt) + return self._registry.lookup_schema(subject, schema, False, True) + if self._use_latest_with_metadata is not None: + return self._registry.get_latest_with_metadata( + subject, self._use_latest_with_metadata, True, fmt) + if self._use_latest_version: + return self._registry.get_latest_version(subject, fmt) + return None + + def _execute_rules( + self, ser_ctx: SerializationContext, subject: str, + rule_mode: RuleMode, + source: Optional[Schema], target: Optional[Schema], + message: Any, inline_tags: 
Optional[Dict[str, Set[str]]], + field_transformer: Optional[FieldTransformer] + ) -> Any: + if message is None or target is None: + return message + rules: Optional[List[Rule]] = None + if rule_mode == RuleMode.UPGRADE: + if target is not None and target.rule_set is not None: + rules = target.rule_set.migration_rules + elif rule_mode == RuleMode.DOWNGRADE: + if source is not None and source.rule_set is not None: + rules = source.rule_set.migration_rules + rules = rules[:] if rules else [] + rules.reverse() + else: + if target is not None and target.rule_set is not None: + rules = target.rule_set.domain_rules + if rule_mode == RuleMode.READ: + # Execute read rules in reverse order for symmetry + rules = rules[:] if rules else [] + rules.reverse() + + if not rules: + return message + + for index in range(len(rules)): + rule = rules[index] + if self._is_disabled(rule): + continue + if rule.mode == RuleMode.WRITEREAD: + if rule_mode != RuleMode.READ and rule_mode != RuleMode.WRITE: + continue + elif rule.mode == RuleMode.UPDOWN: + if rule_mode != RuleMode.UPGRADE and rule_mode != RuleMode.DOWNGRADE: + continue + elif rule.mode != rule_mode: + continue + + ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule, + index, rules, inline_tags, field_transformer) + rule_executor = self._rule_registry.get_executor(rule.type.upper()) + if rule_executor is None: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, + RuleError(f"Could not find rule executor of type {rule.type}"), + 'ERROR') + return message + try: + result = rule_executor.transform(ctx, message) + if rule.kind == RuleKind.CONDITION: + if not result: + raise RuleConditionError(rule) + elif rule.kind == RuleKind.TRANSFORM: + message = result + self._run_action( + ctx, rule_mode, rule, + self._get_on_failure(rule) if message is None else self._get_on_success(rule), + message, None, + 'ERROR' if message is None else 'NONE') + except SerializationError: + raise + except 
Exception as e: + self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), + message, e, 'ERROR') + return message + + def _get_on_success(self, rule: Rule) -> Optional[str]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.on_success is not None: + return override.on_success + return rule.on_success + + def _get_on_failure(self, rule: Rule) -> Optional[str]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.on_failure is not None: + return override.on_failure + return rule.on_failure + + def _is_disabled(self, rule: Rule) -> Optional[bool]: + override = self._rule_registry.get_override(rule.type) + if override is not None and override.disabled is not None: + return override.disabled + return rule.disabled + + def _run_action( + self, ctx: RuleContext, rule_mode: RuleMode, rule: Rule, + action: Optional[str], message: Any, + ex: Optional[Exception], default_action: str + ): + action_name = self._get_rule_action_name(rule, rule_mode, action) + if action_name is None: + action_name = default_action + rule_action = self._get_rule_action(ctx, action_name) + if rule_action is None: + log.error("Could not find rule action of type %s", action_name) + raise RuleError(f"Could not find rule action of type {action_name}") + try: + rule_action.run(ctx, message, ex) + except SerializationError: + raise + except Exception as e: + log.warning("Could not run post-rule action %s: %s", action_name, e) + + def _get_rule_action_name( + self, rule: Rule, rule_mode: RuleMode, action_name: Optional[str] + ) -> Optional[str]: + if action_name is None or action_name == "": + return None + if rule.mode in (RuleMode.WRITEREAD, RuleMode.UPDOWN) and ',' in action_name: + parts = action_name.split(',') + if rule_mode in (RuleMode.WRITE, RuleMode.UPGRADE): + return parts[0] + elif rule_mode in (RuleMode.READ, RuleMode.DOWNGRADE): + return parts[1] + return action_name + + def 
_get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleAction]: + if action_name == 'ERROR': + return ErrorAction() + elif action_name == 'NONE': + return NoneAction() + return self._rule_registry.get_action(action_name) + + +class BaseSerializer(BaseSerde, Serializer): + __slots__ = ['_auto_register', '_normalize_schemas'] + + +class BaseDeserializer(BaseSerde, Deserializer): + __slots__ = [] + + def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool: + if rule_set is None: + return False + if mode in (RuleMode.UPGRADE, RuleMode.DOWNGRADE): + return any(rule.mode == mode or rule.mode == RuleMode.UPDOWN + for rule in rule_set.migration_rules or []) + elif mode == RuleMode.UPDOWN: + return any(rule.mode == mode for rule in rule_set.migration_rules or []) + elif mode in (RuleMode.WRITE, RuleMode.READ): + return any(rule.mode == mode or rule.mode == RuleMode.WRITEREAD + for rule in rule_set.domain_rules or []) + elif mode == RuleMode.WRITEREAD: + return any(rule.mode == mode for rule in rule_set.migration_rules or []) + return False + + def _get_migrations( + self, subject: str, source_info: Schema, + target: RegisteredSchema, fmt: Optional[str] + ) -> List[Migration]: + source = self._registry.lookup_schema(subject, source_info, False, True) + migrations = [] + if source.version < target.version: + migration_mode = RuleMode.UPGRADE + first = source + last = target + elif source.version > target.version: + migration_mode = RuleMode.DOWNGRADE + first = target + last = source + else: + return migrations + previous: Optional[RegisteredSchema] = None + versions = self._get_schemas_between(subject, first, last, fmt) + for i in range(len(versions)): + version = versions[i] + if i == 0: + previous = version + continue + if version.schema.rule_set is not None and self._has_rules(version.schema.rule_set, migration_mode): + if migration_mode == RuleMode.UPGRADE: + migration = Migration(migration_mode, previous, version) + else: + migration = 
Migration(migration_mode, version, previous) + migrations.append(migration) + previous = version + if migration_mode == RuleMode.DOWNGRADE: + migrations.reverse() + return migrations + + def _get_schemas_between( + self, subject: str, first: RegisteredSchema, + last: RegisteredSchema, fmt: Optional[str] = None + ) -> List[RegisteredSchema]: + if last.version - first.version <= 1: + return [first, last] + version1 = first.version + version2 = last.version + result = [first] + for i in range(version1 + 1, version2): + result.append(self._registry.get_version(subject, i, True, fmt)) + result.append(last) + return result + + def _execute_migrations( + self, ser_ctx: SerializationContext, subject: str, + migrations: List[Migration], message: Any + ) -> Any: + for migration in migrations: + message = self._execute_rules(ser_ctx, subject, migration.rule_mode, + migration.source.schema, migration.target.schema, + message, None, None) + return message diff --git a/src/confluent_kafka/schema_registry/avro.py b/src/confluent_kafka/schema_registry/avro.py index 368507752..0da66230e 100644 --- a/src/confluent_kafka/schema_registry/avro.py +++ b/src/confluent_kafka/schema_registry/avro.py @@ -15,796 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -import decimal -import re -from collections import defaultdict -from copy import deepcopy -from io import BytesIO -from json import loads -from struct import pack, unpack -from typing import Dict, Union, Optional, Set, Callable - -from fastavro import (schemaless_reader, - schemaless_writer, - repository, - validate) -from fastavro.schema import load_schema - -from . 
import (_MAGIC_BYTE, - Schema, - topic_subject_name_strategy, - RuleMode, - RuleKind, SchemaRegistryClient) -from confluent_kafka.serialization import (SerializationError, - SerializationContext) -from .rule_registry import RuleRegistry -from .serde import BaseSerializer, BaseDeserializer, RuleContext, FieldType, \ - FieldTransform, RuleConditionError, ParsedSchemaCache - - -AvroMessage = Union[ - None, # 'null' Avro type - str, # 'string' and 'enum' - float, # 'float' and 'double' - int, # 'int' and 'long' - decimal.Decimal, # 'fixed' - bool, # 'boolean' - bytes, # 'bytes' - list, # 'array' - dict, # 'map' and 'record' -] -AvroSchema = Union[str, list, dict] - - -class _ContextStringIO(BytesIO): - """ - Wrapper to allow use of StringIO via 'with' constructs. - """ - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - return False - - -def _schema_loads(schema_str: str) -> Schema: - """ - Instantiate a Schema instance from a declaration string. - - Args: - schema_str (str): Avro Schema declaration. - - .. _Schema declaration: - https://avro.apache.org/docs/current/spec.html#schemas - - Returns: - Schema: A Schema instance. - """ - - schema_str = schema_str.strip() - - # canonical form primitive declarations are not supported - if schema_str[0] != "{" and schema_str[0] != "[": - schema_str = '{"type":' + schema_str + '}' - - return Schema(schema_str, schema_type='AVRO') - - -def _resolve_named_schema( - schema: Schema, schema_registry_client: SchemaRegistryClient -) -> Dict[str, AvroSchema]: - """ - Resolves named schemas referenced by the provided schema recursively. - :param schema: Schema to resolve named schemas for. - :param schema_registry_client: SchemaRegistryClient to use for retrieval. - :return: named_schemas dict. 
- """ - named_schemas = {} - if schema.references is not None: - for ref in schema.references: - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True) - ref_named_schemas = _resolve_named_schema(referenced_schema.schema, schema_registry_client) - parsed_schema = parse_schema_with_repo( - referenced_schema.schema.schema_str, named_schemas=ref_named_schemas) - named_schemas.update(ref_named_schemas) - named_schemas[ref.name] = parsed_schema - return named_schemas - - -class AvroSerializer(BaseSerializer): - """ - Serializer that outputs Avro binary encoded data with Confluent Schema Registry framing. - - Configuration properties: - - +-----------------------------+----------+--------------------------------------------------+ - | Property Name | Type | Description | - +=============================+==========+==================================================+ - | | | If True, automatically register the configured | - | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | - | | | not previously been associated with the relevant | - | | | subject (determined via subject.name.strategy). | - | | | | - | | | Defaults to True. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to normalize schemas, which will | - | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | - | | | including ordering properties and references. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the given schema ID for | - | ``use.schema.id`` | int | serialization. | - | | | | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | serialization. 
| - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to False. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata``| dict | the given metadata. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to None. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka.schema_registry | - | | | namespace. | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-----------------------------+----------+--------------------------------------------------+ - - Schemas are registered against subject names in Confluent Schema Registry that - define a scope in which the schemas can be evolved. By default, the subject name - is formed by concatenating the topic name with the message field (key or value) - separated by a hyphen. - - i.e. {topic name}-{message field} - - Alternative naming strategies may be configured with the property - ``subject.name.strategy``. 
- - Supported subject name strategies: - - +--------------------------------------+------------------------------+ - | Subject Name Strategy | Output Format | - +======================================+==============================+ - | topic_subject_name_strategy(default) | {topic name}-{message field} | - +--------------------------------------+------------------------------+ - | topic_record_subject_name_strategy | {topic name}-{record name} | - +--------------------------------------+------------------------------+ - | record_subject_name_strategy | {record name} | - +--------------------------------------+------------------------------+ - - See `Subject name strategy `_ for additional details. - - Note: - Prior to serialization, all values must first be converted to - a dict instance. This may handled manually prior to calling - :py:func:`Producer.produce()` or by registering a `to_dict` - callable with AvroSerializer. - - See ``avro_producer.py`` in the examples directory for example usage. - - Note: - Tuple notation can be used to determine which branch of an ambiguous union to take. - - See `fastavro notation `_ - - Args: - schema_registry_client (SchemaRegistryClient): Schema Registry client instance. - - schema_str (str or Schema): - Avro `Schema Declaration. `_ - Accepts either a string or a :py:class:`Schema` instance. Note that string - definitions cannot reference other schemas. For referencing other schemas, - use a :py:class:`Schema` instance. - - to_dict (callable, optional): Callable(object, SerializationContext) -> dict. Converts object to a dict. - - conf (dict): AvroSerializer configuration. 
- """ # noqa: E501 - __slots__ = ['_known_subjects', '_parsed_schema', '_schema', - '_schema_id', '_schema_name', '_to_dict', '_parsed_schemas'] - - _default_conf = {'auto.register.schemas': True, - 'normalize.schemas': False, - 'use.schema.id': None, - 'use.latest.version': False, - 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy} - - def __init__( - self, - schema_registry_client: SchemaRegistryClient, - schema_str: Union[str, Schema, None] = None, - to_dict: Optional[Callable[[object, SerializationContext], dict]] = None, - conf: Optional[dict] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None - ): - super().__init__() - if isinstance(schema_str, str): - schema = _schema_loads(schema_str) - elif isinstance(schema_str, Schema): - schema = schema_str - else: - schema = None - - self._registry = schema_registry_client - self._schema_id = None - self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._known_subjects = set() - self._parsed_schemas = ParsedSchemaCache() - - if to_dict is not None and not callable(to_dict): - raise ValueError("to_dict must be callable with the signature " - "to_dict(object, SerializationContext)->dict") - - self._to_dict = to_dict - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._auto_register = conf_copy.pop('auto.register.schemas') - if not isinstance(self._auto_register, bool): - raise ValueError("auto.register.schemas must be a boolean value") - - self._normalize_schemas = conf_copy.pop('normalize.schemas') - if not isinstance(self._normalize_schemas, bool): - raise ValueError("normalize.schemas must be a boolean value") - - self._use_schema_id = conf_copy.pop('use.schema.id') - if (self._use_schema_id is not None and - not isinstance(self._use_schema_id, int)): - raise ValueError("use.schema.id must be an int value") - - self._use_latest_version = 
conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - if self._use_latest_version and self._auto_register: - raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if (self._use_latest_with_metadata is not None and - not isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - if schema: - parsed_schema = self._get_parsed_schema(schema) - - if isinstance(parsed_schema, list): - # if parsed_schema is a list, we have an Avro union and there - # is no valid schema name. This is fine because the only use of - # schema_name is for supplying the subject name to the registry - # and union types should use topic_subject_name_strategy, which - # just discards the schema name anyway - schema_name = None - else: - # The Avro spec states primitives have a name equal to their type - # i.e. {"type": "string"} has a name of string. - # This function does not comply. 
- # https://github.com/fastavro/fastavro/issues/415 - schema_dict = loads(schema.schema_str) - schema_name = parsed_schema.get("name", schema_dict.get("type")) - else: - schema_name = None - parsed_schema = None - - self._schema = schema - self._schema_name = schema_name - self._parsed_schema = parsed_schema - - for rule in self._rule_registry.get_executors(): - rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - def __call__(self, obj: object, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: - """ - Serializes an object to Avro binary format, prepending it with Confluent - Schema Registry framing. - - Args: - obj (object): The object instance to serialize. - - ctx (SerializationContext): Metadata pertaining to the serialization operation. - - Raises: - SerializerError: If any error occurs serializing obj. - SchemaRegistryError: If there was an error registering the schema with - Schema Registry, or auto.register.schemas is - false and the schema was not registered. - - Returns: - bytes: Confluent Schema Registry encoded Avro bytes - """ - - if obj is None: - return None - - subject = self._subject_name_func(ctx, self._schema_name) - latest_schema = self._get_reader_schema(subject) - if latest_schema is not None: - self._schema_id = latest_schema.schema_id - elif subject not in self._known_subjects: - # Check to ensure this schema has been registered under subject_name. - if self._auto_register: - # The schema name will always be the same. We can't however register - # a schema without a subject so we set the schema_id here to handle - # the initial registration. 
- self._schema_id = self._registry.register_schema( - subject, self._schema, self._normalize_schemas) - else: - registered_schema = self._registry.lookup_schema( - subject, self._schema, self._normalize_schemas) - self._schema_id = registered_schema.schema_id - - self._known_subjects.add(subject) - - if self._to_dict is not None: - value = self._to_dict(obj, ctx) - else: - value = obj - - if latest_schema is not None: - parsed_schema = self._get_parsed_schema(latest_schema.schema) - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 - transform(rule_ctx, parsed_schema, msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, value, get_inline_tags(parsed_schema), - field_transformer) - else: - parsed_schema = self._parsed_schema - - with _ContextStringIO() as fo: - # Write the magic byte and schema ID in network byte order (big endian) - fo.write(pack('>bI', _MAGIC_BYTE, self._schema_id)) - # write the record to the rest of the buffer - schemaless_writer(fo, parsed_schema, value) - - return fo.getvalue() - - def _get_parsed_schema(self, schema: Schema) -> AvroSchema: - parsed_schema = self._parsed_schemas.get_parsed_schema(schema) - if parsed_schema is not None: - return parsed_schema - - named_schemas = _resolve_named_schema(schema, self._registry) - prepared_schema = _schema_loads(schema.schema_str) - parsed_schema = parse_schema_with_repo( - prepared_schema.schema_str, named_schemas=named_schemas) - - self._parsed_schemas.set(schema, parsed_schema) - return parsed_schema - - -class AvroDeserializer(BaseDeserializer): - """ - Deserializer for Avro binary encoded data with Confluent Schema Registry - framing. 
- - +-----------------------------+----------+--------------------------------------------------+ - | Property Name | Type | Description | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | deserialization. | - | | | | - | | | Defaults to False. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata``| dict | the given metadata. | - | | | | - | | | Defaults to None. | - +-----------------------------+----------+--------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka.schema_registry | - | | | namespace. | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-----------------------------+----------+--------------------------------------------------+ - - Note: - By default, Avro complex types are returned as dicts. This behavior can - be overridden by registering a callable ``from_dict`` with the deserializer to - convert the dicts to the desired type. - - See ``avro_consumer.py`` in the examples directory in the examples - directory for example usage. - - Args: - schema_registry_client (SchemaRegistryClient): Confluent Schema Registry - client instance. - - schema_str (str, Schema, optional): Avro reader schema declaration Accepts - either a string or a :py:class:`Schema` instance. If not provided, the - writer schema will be used as the reader schema. Note that string - definitions cannot reference other schemas. For referencing other schemas, - use a :py:class:`Schema` instance. 
- - from_dict (callable, optional): Callable(dict, SerializationContext) -> object. - Converts a dict to an instance of some object. - - return_record_name (bool): If True, when reading a union of records, the result will - be a tuple where the first value is the name of the record and the second value is - the record itself. Defaults to False. - - See Also: - `Apache Avro Schema Declaration `_ - - `Apache Avro Schema Resolution `_ - """ - - __slots__ = ['_reader_schema', '_from_dict', '_return_record_name', - '_schema', '_parsed_schemas'] - - _default_conf = {'use.latest.version': False, - 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy} - - def __init__( - self, - schema_registry_client: SchemaRegistryClient, - schema_str: Union[str, Schema, None] = None, - from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, - return_record_name: bool = False, - conf: Optional[dict] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None - ): - super().__init__() - schema = None - if schema_str is not None: - if isinstance(schema_str, str): - schema = _schema_loads(schema_str) - elif isinstance(schema_str, Schema): - schema = schema_str - else: - raise TypeError('You must pass either schema string or schema object') - - self._schema = schema - self._registry = schema_registry_client - self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._parsed_schemas = ParsedSchemaCache() - self._use_schema_id = None - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._use_latest_version = conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if (self._use_latest_with_metadata is not None and - not 
isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - if schema: - self._reader_schema = self._get_parsed_schema(self._schema) - else: - self._reader_schema = None - - if from_dict is not None and not callable(from_dict): - raise ValueError("from_dict must be callable with the signature " - "from_dict(SerializationContext, dict) -> object") - self._from_dict = from_dict - - self._return_record_name = return_record_name - if not isinstance(self._return_record_name, bool): - raise ValueError("return_record_name must be a boolean value") - - for rule in self._rule_registry.get_executors(): - rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: - """ - Deserialize Avro binary encoded data with Confluent Schema Registry framing to - a dict, or object instance according to from_dict, if specified. - - Arguments: - data (bytes): bytes - - ctx (SerializationContext): Metadata relevant to the serialization - operation. - - Raises: - SerializerError: if an error occurs parsing data. - - Returns: - object: If data is None, then None. Else, a dict, or object instance according - to from_dict, if specified. - """ # noqa: E501 - - if data is None: - return None - - if len(data) <= 5: - raise SerializationError("Expecting data framing of length 6 bytes or " - "more but total data size is {} bytes. 
This " - "message was not produced with a Confluent " - "Schema Registry serializer".format(len(data))) - - subject = self._subject_name_func(ctx, None) if ctx else None - latest_schema = None - if subject is not None: - latest_schema = self._get_reader_schema(subject) - - with _ContextStringIO(data) as payload: - magic, schema_id = unpack('>bI', payload.read(5)) - if magic != _MAGIC_BYTE: - raise SerializationError("Unexpected magic byte {}. This message " - "was not produced with a Confluent " - "Schema Registry serializer".format(magic)) - - writer_schema_raw = self._registry.get_schema(schema_id) - writer_schema = self._get_parsed_schema(writer_schema_raw) - - if subject is None: - subject = self._subject_name_func(ctx, writer_schema.get("name")) if ctx else None - if subject is not None: - latest_schema = self._get_reader_schema(subject) - - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) - reader_schema_raw = latest_schema.schema - reader_schema = self._get_parsed_schema(latest_schema.schema) - elif self._schema is not None: - migrations = None - reader_schema_raw = self._schema - reader_schema = self._reader_schema - else: - migrations = None - reader_schema_raw = writer_schema_raw - reader_schema = writer_schema - - if migrations: - obj_dict = schemaless_reader(payload, - writer_schema, - None, - self._return_record_name) - obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - else: - obj_dict = schemaless_reader(payload, - writer_schema, - reader_schema, - self._return_record_name) - - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 - transform(rule_ctx, reader_schema, message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, get_inline_tags(reader_schema), - field_transformer) - - if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) - - return obj_dict - 
- def _get_parsed_schema(self, schema: Schema) -> AvroSchema: - parsed_schema = self._parsed_schemas.get_parsed_schema(schema) - if parsed_schema is not None: - return parsed_schema - - named_schemas = _resolve_named_schema(schema, self._registry) - prepared_schema = _schema_loads(schema.schema_str) - parsed_schema = parse_schema_with_repo( - prepared_schema.schema_str, named_schemas=named_schemas) - - self._parsed_schemas.set(schema, parsed_schema) - return parsed_schema - - -class LocalSchemaRepository(repository.AbstractSchemaRepository): - def __init__(self, schemas): - self.schemas = schemas - - def load(self, subject): - return self.schemas.get(subject) - - -def parse_schema_with_repo(schema_str: str, named_schemas: Dict[str, AvroSchema]) -> AvroSchema: - copy = deepcopy(named_schemas) - copy["$root"] = loads(schema_str) - repo = LocalSchemaRepository(copy) - return load_schema("$root", repo=repo) - - -def transform( - ctx: RuleContext, schema: AvroSchema, message: AvroMessage, - field_transform: FieldTransform -) -> AvroMessage: - if message is None or schema is None: - return message - field_ctx = ctx.current_field() - if field_ctx is not None: - field_ctx.field_type = get_type(schema) - if isinstance(schema, list): - subschema = _resolve_union(schema, message) - if subschema is None: - return message - return transform(ctx, subschema, message, field_transform) - elif isinstance(schema, dict): - schema_type = schema.get("type") - if schema_type == 'array': - return [transform(ctx, schema["items"], item, field_transform) - for item in message] - elif schema_type == 'map': - return {key: transform(ctx, schema["values"], value, field_transform) - for key, value in message.items()} - elif schema_type == 'record': - fields = schema["fields"] - for field in fields: - _transform_field(ctx, schema, field, message, field_transform) - return message - - if field_ctx is not None: - rule_tags = ctx.rule.tags - if not rule_tags or not _disjoint(set(rule_tags), 
field_ctx.tags): - return field_transform(ctx, field_ctx, message) - return message - - -def _transform_field( - ctx: RuleContext, schema: AvroSchema, field: dict, - message: AvroMessage, field_transform: FieldTransform -): - field_type = field["type"] - name = field["name"] - full_name = schema["name"] + "." + name - try: - ctx.enter_field( - message, - full_name, - name, - get_type(field_type), - None - ) - value = message[name] - new_value = transform(ctx, field_type, value, field_transform) - if ctx.rule.kind == RuleKind.CONDITION: - if new_value is False: - raise RuleConditionError(ctx.rule) - else: - message[name] = new_value - finally: - ctx.exit_field() - - -def get_type(schema: AvroSchema) -> FieldType: - if isinstance(schema, list): - return FieldType.COMBINED - elif isinstance(schema, dict): - schema_type = schema.get("type") - else: - # string schemas; this could be either a named schema or a primitive type - schema_type = schema - - if schema_type == 'record': - return FieldType.RECORD - elif schema_type == 'enum': - return FieldType.ENUM - elif schema_type == 'array': - return FieldType.ARRAY - elif schema_type == 'map': - return FieldType.MAP - elif schema_type == 'union': - return FieldType.COMBINED - elif schema_type == 'fixed': - return FieldType.FIXED - elif schema_type == 'string': - return FieldType.STRING - elif schema_type == 'bytes': - return FieldType.BYTES - elif schema_type == 'int': - return FieldType.INT - elif schema_type == 'long': - return FieldType.LONG - elif schema_type == 'float': - return FieldType.FLOAT - elif schema_type == 'double': - return FieldType.DOUBLE - elif schema_type == 'boolean': - return FieldType.BOOLEAN - elif schema_type == 'null': - return FieldType.NULL - else: - return FieldType.NULL - - -def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: - for tag in tags1: - if tag in tags2: - return False - return True - - -def _resolve_union(schema: AvroSchema, message: AvroMessage) -> Optional[AvroSchema]: - for 
subschema in schema: - try: - validate(message, subschema) - except: # noqa: E722 - continue - return subschema - return None - - -def get_inline_tags(schema: AvroSchema) -> Dict[str, Set[str]]: - inline_tags = defaultdict(set) - _get_inline_tags_recursively('', '', schema, inline_tags) - return inline_tags - - -def _get_inline_tags_recursively( - ns: str, name: str, schema: Optional[AvroSchema], - tags: Dict[str, Set[str]] -): - if schema is None: - return - if isinstance(schema, list): - for subschema in schema: - _get_inline_tags_recursively(ns, name, subschema, tags) - elif not isinstance(schema, dict): - # string schemas; this could be either a named schema or a primitive type - return - else: - schema_type = schema.get("type") - if schema_type == 'array': - _get_inline_tags_recursively(ns, name, schema.get("items"), tags) - elif schema_type == 'map': - _get_inline_tags_recursively(ns, name, schema.get("values"), tags) - elif schema_type == 'record': - record_ns = schema.get("namespace") - record_name = schema.get("name") - if record_ns is None: - record_ns = _implied_namespace(name) - if record_ns is None: - record_ns = ns - if record_ns != '' and not record_name.startswith(record_ns): - record_name = f"{record_ns}.{record_name}" - fields = schema["fields"] - for field in fields: - field_tags = field.get("confluent:tags") - field_name = field.get("name") - field_type = field.get("type") - if field_tags is not None and field_name is not None: - tags[record_name + '.' 
+ field_name].update(field_tags) - if field_type is not None: - _get_inline_tags_recursively(record_ns, record_name, field_type, tags) - - -def _implied_namespace(name: str) -> Optional[str]: - match = re.match(r"^(.*)\.[^.]+$", name) - return match.group(1) if match else None +from .common.avro import * # noqa +from ._sync.avro import * # noqa diff --git a/src/confluent_kafka/schema_registry/common/__init__.py b/src/confluent_kafka/schema_registry/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/confluent_kafka/schema_registry/common/avro.py b/src/confluent_kafka/schema_registry/common/avro.py new file mode 100644 index 000000000..7e038f6c7 --- /dev/null +++ b/src/confluent_kafka/schema_registry/common/avro.py @@ -0,0 +1,262 @@ +import decimal +import re +from collections import defaultdict +from copy import deepcopy +from io import BytesIO +from json import loads +from typing import Dict, Union, Optional, Set + +from fastavro import repository, validate +from fastavro.schema import load_schema + +from .schema_registry_client import Schema, RuleKind +from confluent_kafka.schema_registry.serde import RuleContext, FieldType, \ + FieldTransform, RuleConditionError + +__all__ = [ + 'AvroMessage', + 'AvroSchema', + '_schema_loads', + 'LocalSchemaRepository', + 'parse_schema_with_repo', + 'transform', + '_transform_field', + 'get_type', + '_disjoint', + '_resolve_union', + 'get_inline_tags', + '_get_inline_tags_recursively', + '_implied_namespace', +] + +AvroMessage = Union[ + None, # 'null' Avro type + str, # 'string' and 'enum' + float, # 'float' and 'double' + int, # 'int' and 'long' + decimal.Decimal, # 'fixed' + bool, # 'boolean' + bytes, # 'bytes' + list, # 'array' + dict, # 'map' and 'record' +] +AvroSchema = Union[str, list, dict] + + +class _ContextStringIO(BytesIO): + """ + Wrapper to allow use of StringIO via 'with' constructs. 
+ """ + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + return False + + +def _schema_loads(schema_str: str) -> Schema: + """ + Instantiate a Schema instance from a declaration string. + + Args: + schema_str (str): Avro Schema declaration. + + .. _Schema declaration: + https://avro.apache.org/docs/current/spec.html#schemas + + Returns: + Schema: A Schema instance. + """ + + schema_str = schema_str.strip() + + # canonical form primitive declarations are not supported + if schema_str[0] != "{" and schema_str[0] != "[": + schema_str = '{"type":' + schema_str + '}' + + return Schema(schema_str, schema_type='AVRO') + + +class LocalSchemaRepository(repository.AbstractSchemaRepository): + def __init__(self, schemas): + self.schemas = schemas + + def load(self, subject): + return self.schemas.get(subject) + + +def parse_schema_with_repo(schema_str: str, named_schemas: Dict[str, AvroSchema]) -> AvroSchema: + copy = deepcopy(named_schemas) + copy["$root"] = loads(schema_str) + repo = LocalSchemaRepository(copy) + return load_schema("$root", repo=repo) + + +def transform( + ctx: RuleContext, schema: AvroSchema, message: AvroMessage, + field_transform: FieldTransform +) -> AvroMessage: + if message is None or schema is None: + return message + field_ctx = ctx.current_field() + if field_ctx is not None: + field_ctx.field_type = get_type(schema) + if isinstance(schema, list): + subschema = _resolve_union(schema, message) + if subschema is None: + return message + return transform(ctx, subschema, message, field_transform) + elif isinstance(schema, dict): + schema_type = schema.get("type") + if schema_type == 'array': + return [transform(ctx, schema["items"], item, field_transform) + for item in message] + elif schema_type == 'map': + return {key: transform(ctx, schema["values"], value, field_transform) + for key, value in message.items()} + elif schema_type == 'record': + fields = schema["fields"] + for field in fields: + _transform_field(ctx, 
schema, field, message, field_transform) + return message + + if field_ctx is not None: + rule_tags = ctx.rule.tags + if not rule_tags or not _disjoint(set(rule_tags), field_ctx.tags): + return field_transform(ctx, field_ctx, message) + return message + + +def _transform_field( + ctx: RuleContext, schema: AvroSchema, field: dict, + message: AvroMessage, field_transform: FieldTransform +): + field_type = field["type"] + name = field["name"] + full_name = schema["name"] + "." + name + try: + ctx.enter_field( + message, + full_name, + name, + get_type(field_type), + None + ) + value = message[name] + new_value = transform(ctx, field_type, value, field_transform) + if ctx.rule.kind == RuleKind.CONDITION: + if new_value is False: + raise RuleConditionError(ctx.rule) + else: + message[name] = new_value + finally: + ctx.exit_field() + + +def get_type(schema: AvroSchema) -> FieldType: + if isinstance(schema, list): + return FieldType.COMBINED + elif isinstance(schema, dict): + schema_type = schema.get("type") + else: + # string schemas; this could be either a named schema or a primitive type + schema_type = schema + + if schema_type == 'record': + return FieldType.RECORD + elif schema_type == 'enum': + return FieldType.ENUM + elif schema_type == 'array': + return FieldType.ARRAY + elif schema_type == 'map': + return FieldType.MAP + elif schema_type == 'union': + return FieldType.COMBINED + elif schema_type == 'fixed': + return FieldType.FIXED + elif schema_type == 'string': + return FieldType.STRING + elif schema_type == 'bytes': + return FieldType.BYTES + elif schema_type == 'int': + return FieldType.INT + elif schema_type == 'long': + return FieldType.LONG + elif schema_type == 'float': + return FieldType.FLOAT + elif schema_type == 'double': + return FieldType.DOUBLE + elif schema_type == 'boolean': + return FieldType.BOOLEAN + elif schema_type == 'null': + return FieldType.NULL + else: + return FieldType.NULL + + +def _disjoint(tags1: Set[str], tags2: Set[str]) -> 
def _resolve_union(schema: AvroSchema, message: AvroMessage) -> Optional[AvroSchema]:
    """
    Pick the first branch of a union that accepts the message.

    Each branch is tried in declaration order by calling
    ``validate(message, subschema)``; a branch counts as a match when that
    call does not raise.

    Args:
        schema (AvroSchema): The union, as an iterable of branch schemas.

        message (AvroMessage): The value to match against the branches.

    Returns:
        AvroSchema: The first branch that validates; None if none does.
    """
    for subschema in schema:
        try:
            validate(message, subschema)
        except Exception:
            # Best effort: any validation failure just means this branch
            # does not match. Exception (rather than a bare except) lets
            # KeyboardInterrupt and SystemExit propagate.
            continue
        return subschema
    return None
+ field_name].update(field_tags) + if field_type is not None: + _get_inline_tags_recursively(record_ns, record_name, field_type, tags) + + +def _implied_namespace(name: str) -> Optional[str]: + match = re.match(r"^(.*)\.[^.]+$", name) + return match.group(1) if match else None diff --git a/src/confluent_kafka/schema_registry/common/json_schema.py b/src/confluent_kafka/schema_registry/common/json_schema.py new file mode 100644 index 000000000..ef789fe69 --- /dev/null +++ b/src/confluent_kafka/schema_registry/common/json_schema.py @@ -0,0 +1,194 @@ + +import decimal +from io import BytesIO + +from typing import Union, Optional, List, Set + +import httpx +import referencing +from jsonschema import validate, ValidationError +from referencing import Registry, Resource +from referencing._core import Resolver + +from confluent_kafka.schema_registry import RuleKind +from confluent_kafka.schema_registry.serde import RuleContext, FieldTransform, FieldType, \ + RuleConditionError + +__all__ = [ + 'JsonMessage', + 'JsonSchema', + 'DEFAULT_SPEC', + '_retrieve_via_httpx', + 'transform', + '_transform_field', + '_validate_subschemas', + 'get_type', + '_disjoint', + 'get_inline_tags', +] + +JsonMessage = Union[ + None, # 'null' Avro type + str, # 'string' and 'enum' + float, # 'float' and 'double' + int, # 'int' and 'long' + decimal.Decimal, # 'fixed' + bool, # 'boolean' + list, # 'array' + dict, # 'map' and 'record' +] + +JsonSchema = Union[bool, dict] + +DEFAULT_SPEC = referencing.jsonschema.DRAFT7 + + +class _ContextStringIO(BytesIO): + """ + Wrapper to allow use of StringIO via 'with' constructs. 
def _retrieve_via_httpx(uri: str):
    """
    Download a JSON document and wrap it as a ``referencing`` Resource.

    Args:
        uri (str): Location fetched with ``httpx.get``.

    Returns:
        Resource: Built from the decoded JSON body, with DEFAULT_SPEC passed
        as the default specification.
    """
    document = httpx.get(uri).json()
    return Resource.from_contents(document, default_specification=DEFAULT_SPEC)
def _validate_subschemas(
    subschemas: List[JsonSchema],
    message: JsonMessage,
    registry: Registry
) -> Optional[JsonSchema]:
    """
    Find the first subschema that the message validates against.

    Args:
        subschemas (list): Candidate schemas, tried in order.

        message (JsonMessage): The instance to validate.

        registry (Registry): Reference registry handed to ``validate``.

    Returns:
        JsonSchema: The first candidate for which ``validate`` raises no
        ValidationError; None if every candidate is rejected.
    """
    for candidate in subschemas:
        try:
            validate(instance=message, schema=candidate, registry=registry)
        except ValidationError:
            # This candidate rejects the message; try the next one.
            continue
        return candidate
    return None
def get_inline_tags(schema: JsonSchema) -> Set[str]:
    """
    Collect the tags declared inline on a JSON schema node.

    Args:
        schema (JsonSchema): A schema node; either a dict or a boolean
            schema (``True`` / ``False``).

    Returns:
        set of str: The entries of the node's ``"confluent:tags"`` key, or
        an empty set when the key is absent or the node is not a dict.
    """
    # Boolean schemas are valid JsonSchema values but have no keys to read.
    if not isinstance(schema, dict):
        return set()
    tags = schema.get("confluent:tags")
    return set() if tags is None else set(tags)
'_proto_to_str', + '_str_to_proto', + '_init_pool', + 'transform', + '_transform_field', + '_set_field', + 'get_type', + 'is_map_field', + 'get_inline_tags', + '_disjoint', + '_is_builtin', + 'decimalToProtobuf', + 'protobufToDecimal' +] + +# Convert an int to bytes (inverse of ord()) +# Python3.chr() -> Unicode +# Python2.chr() -> str(alias for bytes) +if sys.version > '3': + def _bytes(v: int) -> bytes: + """ + Convert int to bytes + + Args: + v (int): The int to convert to bytes. + """ + return bytes((v,)) +else: + def _bytes(v: int) -> str: + """ + Convert int to bytes + + Args: + v (int): The int to convert to bytes. + """ + return chr(v) + + +class _ContextStringIO(io.BytesIO): + """ + Wrapper to allow use of StringIO via 'with' constructs. + """ + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + return False + + +def _create_index_array(msg_desc: Descriptor) -> List[int]: + """ + Creates an index array specifying the location of msg_desc in + the referenced FileDescriptor. + + Args: + msg_desc (MessageDescriptor): Protobuf MessageDescriptor + + Returns: + list of int: Protobuf MessageDescriptor index array. + + Raises: + ValueError: If the message descriptor is malformed. + """ + + msg_idx = deque() + + # Walk the nested MessageDescriptor tree up to the root. + current = msg_desc + found = False + while current.containing_type is not None: + previous = current + current = previous.containing_type + # find child's position + for idx, node in enumerate(current.nested_types): + if node == previous: + msg_idx.appendleft(idx) + found = True + break + if not found: + raise ValueError("Nested MessageDescriptor not found") + + # Add the index of the root MessageDescriptor in the FileDescriptor. 
+ found = False + for idx, msg_type_name in enumerate(msg_desc.file.message_types_by_name): + if msg_type_name == current.name: + msg_idx.appendleft(idx) + found = True + break + if not found: + raise ValueError("MessageDescriptor not found in file") + + return list(msg_idx) + + +def _schema_to_str(file_descriptor: FileDescriptor) -> str: + """ + Base64 encode a FileDescriptor + + Args: + file_descriptor (FileDescriptor): FileDescriptor to encode. + + Returns: + str: Base64 encoded FileDescriptor + """ + + return base64.standard_b64encode(file_descriptor.serialized_pb).decode('ascii') + + +def _proto_to_str(file_descriptor_proto: descriptor_pb2.FileDescriptorProto) -> str: + """ + Base64 encode a FileDescriptorProto + + Args: + file_descriptor_proto (FileDescriptorProto): FileDescriptorProto to encode. + + Returns: + str: Base64 encoded FileDescriptorProto + """ + + return base64.standard_b64encode(file_descriptor_proto.SerializeToString()).decode('ascii') + + +def _str_to_proto(name: str, schema_str: str) -> descriptor_pb2.FileDescriptorProto: + """ + Base64 decode a FileDescriptor + + Args: + schema_str (str): Base64 encoded FileDescriptorProto + + Returns: + FileDescriptorProto: schema. 
def _init_pool(pool: DescriptorPool):
    """
    Register the serialized built-in file descriptors with a pool.

    The files are added one at a time, in a fixed order, via
    ``pool.AddSerializedFile``: the google.protobuf modules first, then the
    google.type modules, then the Confluent meta and decimal modules.

    Args:
        pool (DescriptorPool): The pool to populate.
    """
    protobuf_modules = (
        any_pb2,
        # source_context needed by api
        source_context_pb2,
        # type needed by api
        type_pb2,
        api_pb2,
        descriptor_pb2,
        duration_pb2,
        empty_pb2,
        field_mask_pb2,
        struct_pb2,
        timestamp_pb2,
        wrappers_pb2,
    )
    google_type_modules = (
        calendar_period_pb2,
        color_pb2,
        date_pb2,
        datetime_pb2,
        dayofweek_pb2,
        expr_pb2,
        fraction_pb2,
        latlng_pb2,
        money_pb2,
        month_pb2,
        postal_address_pb2,
        quaternion_pb2,
        timeofday_pb2,
    )
    confluent_modules = (
        meta_pb2,
        decimal_pb2,
    )

    for module in (*protobuf_modules, *google_type_modules, *confluent_modules):
        pool.AddSerializedFile(module.DESCRIPTOR.serialized_pb)
old_value.update(value) + else: + setattr(message, fd.name, value) + + +def get_type(fd: FieldDescriptor) -> FieldType: + if is_map_field(fd): + return FieldType.MAP + if fd.type == FieldDescriptor.TYPE_MESSAGE: + return FieldType.RECORD + if fd.type == FieldDescriptor.TYPE_ENUM: + return FieldType.ENUM + if fd.type == FieldDescriptor.TYPE_STRING: + return FieldType.STRING + if fd.type == FieldDescriptor.TYPE_BYTES: + return FieldType.BYTES + if fd.type in (FieldDescriptor.TYPE_INT32, FieldDescriptor.TYPE_SINT32, + FieldDescriptor.TYPE_UINT32, FieldDescriptor.TYPE_FIXED32, + FieldDescriptor.TYPE_SFIXED32): + return FieldType.INT + if fd.type in (FieldDescriptor.TYPE_INT64, FieldDescriptor.TYPE_SINT64, + FieldDescriptor.TYPE_UINT64, FieldDescriptor.TYPE_FIXED64, + FieldDescriptor.TYPE_SFIXED64): + return FieldType.LONG + if fd.type == FieldDescriptor.TYPE_FLOAT: + return FieldType.FLOAT + if fd.type == FieldDescriptor.TYPE_DOUBLE: + return FieldType.DOUBLE + if fd.type == FieldDescriptor.TYPE_BOOL: + return FieldType.BOOLEAN + return FieldType.NULL + + +def is_map_field(fd: FieldDescriptor): + return (fd.type == FieldDescriptor.TYPE_MESSAGE + and hasattr(fd.message_type, 'options') + and fd.message_type.options.map_entry) + + +def get_inline_tags(fd: FieldDescriptor) -> Set[str]: + meta = fd.GetOptions().Extensions[meta_pb2.field_meta] + if meta is None: + return set() + else: + return set(meta.tags) + + +def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: + for tag in tags1: + if tag in tags2: + return False + return True + + +def _is_builtin(name: str) -> bool: + return name.startswith('confluent/') or \ + name.startswith('google/protobuf/') or \ + name.startswith('google/type/') + + +def decimalToProtobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: + """ + Converts a Decimal to a Protobuf value. + + Args: + value (Decimal): The Decimal value to convert. + + Returns: + The Protobuf value. 
def protobufToDecimal(value: decimal_pb2.Decimal) -> Decimal:
    """
    Converts a Protobuf value to Decimal.

    ``value.value`` is read as a big-endian signed integer and then shifted
    by ``value.scale`` decimal places. A positive ``value.precision`` is
    used as the context precision; otherwise MAX_PREC is used.

    Args:
        value (decimal_pb2.Decimal): The Protobuf value to convert.

    Returns:
        The Decimal value.
    """
    unscaled_datum = int.from_bytes(value.value, byteorder="big", signed=True)

    # Use a context private to this call: changing ``prec`` on a shared
    # module-level Context lets concurrent callers overwrite each other's
    # precision between the assignment and the conversion.
    precision = value.precision if value.precision > 0 else MAX_PREC
    context = Context(prec=precision)
    return context.create_decimal(unscaled_datum).scaleb(-value.scale, context)
def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) -> float:
    """
    Compute a randomised retry delay.

    The upper bound doubles with every attempt already made
    (``base_delay_ms * 2 ** retries_attempted``) and is capped at
    ``max_delay_ms``; the result is that bound scaled by ``random.random()``.

    Args:
        base_delay_ms (int): Delay bound for the first retry.

        max_delay_ms (int): Cap applied to the bound.

        retries_attempted (int): Number of retries already attempted.

    Returns:
        float: The delay, between 0 and the capped bound.
    """
    ceiling = min(base_delay_ms * (2.0 ** retries_attempted), max_delay_ms)
    return random.random() * ceiling
def set_registered_schema(self, schema: 'Schema', registered_schema: 'RegisteredSchema'):
    """
    Add a RegisteredSchema, and the Schema it belongs to, to the cache.

    The subject, schema id and version are read from registered_schema;
    all five indexes are then updated while holding the cache lock.

    Args:
        schema (Schema): Schema instance stored in, and used as key of,
            the schema indexes

        registered_schema (RegisteredSchema): RegisteredSchema instance
    """

    subject = registered_schema.subject
    schema_id = registered_schema.schema_id
    version = registered_schema.version

    # (index, key, value) triples, applied in this order under the lock.
    updates = (
        (self.schema_id_index, schema_id, schema),
        (self.schema_index, schema, schema_id),
        (self.rs_id_index, schema_id, registered_schema),
        (self.rs_version_index, version, registered_schema),
        (self.rs_schema_index, schema, registered_schema),
    )
    with self.lock:
        for index, key, entry in updates:
            index[subject][key] = entry
def get_registered_by_subject_id(self, subject: str, schema_id: int) -> Optional['RegisteredSchema']:
    """
    Look up the registered schema stored for a schema id under a subject.

    The lookup is done while holding the cache lock and never adds
    entries to the index.

    Args:
        subject (str): The subject this schema is associated with

        schema_id (int): The schema id associated with this schema

    Returns:
        RegisteredSchema: The registered schema if known; else None
    """

    with self.lock:
        registered_for_subject = self.rs_id_index.get(subject)
        if registered_for_subject is None:
            return None
        return registered_for_subject.get(schema_id)
def remove_by_subject_version(self, subject: str, version: int):
    """
    Remove schemas with the given subject and version.

    Every entry recorded for that version is dropped from the registered
    schema indexes (by id, by schema, by version). The plain schema
    indexes are then pruned using the schema id and schema of the
    registered schema that was stored under the version. Unknown subjects
    or versions are ignored.

    Args:
        subject (str): The subject

        version (int): The version
    """

    with self.lock:
        if subject in self.rs_id_index:
            by_id = self.rs_id_index[subject]
            # Collect the keys first: deleting while iterating a dict
            # raises RuntimeError.
            stale_ids = [schema_id for schema_id, registered_schema in by_id.items()
                         if registered_schema.version == version]
            for schema_id in stale_ids:
                del by_id[schema_id]
        if subject in self.rs_schema_index:
            by_schema = self.rs_schema_index[subject]
            stale_schemas = [schema for schema, registered_schema in by_schema.items()
                             if registered_schema.version == version]
            for schema in stale_schemas:
                del by_schema[schema]
        rs = None
        if subject in self.rs_version_index:
            rs = self.rs_version_index[subject].pop(version, None)
        if rs is not None:
            if subject in self.schema_id_index:
                self.schema_id_index[subject].pop(rs.schema_id, None)
            # Membership test first: indexing a defaultdict would create
            # an empty entry for the subject.
            if subject in self.schema_index:
                self.schema_index[subject].pop(rs.schema, None)
@_attrs_define(frozen=True)
class Rule:
    """
    Immutable description of a single rule.

    ``to_dict`` and ``from_dict`` convert between this object and a plain
    dict keyed by ``name``, ``doc``, ``kind``, ``mode``, ``type``, ``tags``,
    ``params``, ``expr``, ``onSuccess``, ``onFailure`` and ``disabled``.
    ``tags`` is excluded from the generated hash.
    """
    name: Optional[str]
    doc: Optional[str]
    kind: Optional[RuleKind]
    mode: Optional[RuleMode]
    type: Optional[str]
    tags: Optional[List[str]] = _attrs_field(hash=False)
    params: Optional[RuleParams]
    expr: Optional[str]
    on_success: Optional[str]
    on_failure: Optional[str]
    disabled: Optional[bool]

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert the rule to a plain dict.

        Enum members are written as their string value and ``params`` as
        the dict produced by ``RuleParams.to_dict``.

        Returns:
            dict: One entry per attribute that is not None.
        """
        kind_str: Optional[str] = None
        if self.kind is not None:
            kind_str = self.kind.value

        mode_str: Optional[str] = None
        if self.mode is not None:
            mode_str = self.mode.value

        _params: Optional[Dict[str, Any]] = None
        if self.params is not None:
            _params = self.params.to_dict()

        # The rule's own ``type`` attribute is what gets tested here; the
        # builtin ``type`` is never None and would always emit the key.
        candidates = (
            ("name", self.name),
            ("doc", self.doc),
            ("kind", kind_str),
            ("mode", mode_str),
            ("type", self.type),
            ("tags", self.tags),
            ("params", _params),
            ("expr", self.expr),
            ("onSuccess", self.on_success),
            ("onFailure", self.on_failure),
            ("disabled", self.disabled),
        )
        return {key: value for key, value in candidates if value is not None}

    @classmethod
    def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
        """
        Build a rule from a plain dict.

        Args:
            src_dict (dict): Dict using the keys written by ``to_dict``;
                missing keys become None. The dict is not modified.

        Returns:
            Rule: The rule described by src_dict.
        """
        d = src_dict.copy()
        name = d.pop("name", None)

        doc = d.pop("doc", None)

        _kind = d.pop("kind", None)
        kind: Optional[RuleKind] = None
        if _kind is not None:
            kind = RuleKind(_kind)

        _mode = d.pop("mode", None)
        mode: Optional[RuleMode] = None
        if _mode is not None:
            mode = RuleMode(_mode)

        rule_type = d.pop("type", None)

        tags = cast(List[str], d.pop("tags", None))

        _params: Optional[Dict[str, Any]] = d.pop("params", None)
        params: Optional[RuleParams] = None
        if _params is not None:
            params = RuleParams.from_dict(_params)

        expr = d.pop("expr", None)

        on_success = d.pop("onSuccess", None)

        on_failure = d.pop("onFailure", None)

        disabled = d.pop("disabled", None)

        rule = cls(
            name=name,
            doc=doc,
            kind=kind,
            mode=mode,
            type=rule_type,
            tags=tags,
            params=params,
            expr=expr,
            on_success=on_success,
            on_failure=on_failure,
            disabled=disabled,
        )

        return rule
@_attrs_define
class MetadataTags:
    """
    Mapping from a name to the list of tags attached to it.

    The ``tags`` field is excluded from the generated hash; ``__hash__`` is
    defined explicitly below.
    """
    tags: Dict[str, List[str]] = _attrs_field(factory=dict, hash=False)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to a plain dict.

        Returns:
            dict: A new dict with the same keys and values as ``tags``.
        """
        field_dict: Dict[str, Any] = {}
        for prop_name, prop in self.tags.items():
            field_dict[prop_name] = prop

        return field_dict

    @classmethod
    def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
        """
        Build an instance from a plain dict.

        Args:
            src_dict (dict): Mapping of name to list of tags. The dict is
                copied, not modified.

        Returns:
            MetadataTags: Instance whose ``tags`` holds the copied entries.
        """
        d = src_dict.copy()

        tags = {}
        for prop_name, prop_dict in d.items():
            tag = cast(List[str], prop_dict)

            tags[prop_name] = tag

        metadata_tags = cls(tags=tags)

        return metadata_tags

    def __hash__(self):
        # The values are lists, which are unhashable; freeze each one into
        # a tuple so that (name, tags) pairs can go into a frozenset.
        return hash(frozenset(
            (prop_name, tuple(prop)) for prop_name, prop in self.tags.items()
        ))
class ConfigCompatibilityLevel(str, Enum):
    """
    Compatibility levels usable in a ``ServerConfig``.

    Each member is also a ``str`` equal to its own name, so members
    compare equal to the plain strings found in dictionary payloads.
    """

    BACKWARD = "BACKWARD"
    BACKWARD_TRANSITIVE = "BACKWARD_TRANSITIVE"
    FORWARD = "FORWARD"
    FORWARD_TRANSITIVE = "FORWARD_TRANSITIVE"
    FULL = "FULL"
    FULL_TRANSITIVE = "FULL_TRANSITIVE"
    NONE = "NONE"

    def __str__(self) -> str:
        # Render as the bare value instead of "ConfigCompatibilityLevel.X".
        text: str = self.value
        return text
@classmethod
def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
    """
    Build a server configuration from its dictionary representation.

    All keys are optional and default to ``None`` when absent.
    ``src_dict`` itself is not modified.

    Args:
        src_dict (dict): Mapping that may contain ``compatibility``,
            ``compatibilityLevel``, ``compatibilityGroup``,
            ``defaultMetadata``, ``overrideMetadata``, ``defaultRuleSet``
            and ``overrideRuleSet``.

    Returns:
        A new instance of ``cls``.

    Raises:
        TypeError: If a metadata or rule-set entry is present but is
            not a dict.
    """

    def _to_level(raw: Any) -> Optional[ConfigCompatibilityLevel]:
        return None if raw is None else ConfigCompatibilityLevel(raw)

    def _to_nested(raw: object, nested_cls: Any) -> Any:
        # Nested objects must arrive as dicts; anything else is rejected.
        if raw is None:
            return None
        if not isinstance(raw, dict):
            raise TypeError()
        return nested_cls.from_dict(raw)

    # Keyword arguments are evaluated top to bottom, which keeps the
    # conversions (and any error they raise) in the original order.
    return cls(
        compatibility=_to_level(src_dict.get("compatibility")),
        compatibility_level=_to_level(src_dict.get("compatibilityLevel")),
        compatibility_group=src_dict.get("compatibilityGroup"),
        default_metadata=_to_nested(src_dict.get("defaultMetadata"), Metadata),
        override_metadata=_to_nested(src_dict.get("overrideMetadata"), Metadata),
        default_rule_set=_to_nested(src_dict.get("defaultRuleSet"), RuleSet),
        override_rule_set=_to_nested(src_dict.get("overrideRuleSet"), RuleSet),
    )
+ """ + + schema_str: Optional[str] + schema_type: Optional[str] = "AVRO" + references: Optional[List[SchemaReference]] = _attrs_field(factory=list, hash=False) + metadata: Optional[Metadata] = None + rule_set: Optional[RuleSet] = None + + def to_dict(self) -> Dict[str, Any]: + schema = self.schema_str + + schema_type = self.schema_type + + _references: Optional[List[Dict[str, Any]]] = [] + if self.references is not None: + for references_item_data in self.references: + references_item = references_item_data.to_dict() + _references.append(references_item) + + _metadata: Optional[Dict[str, Any]] = None + if isinstance(self.metadata, Metadata): + _metadata = self.metadata.to_dict() + + _rule_set: Optional[Dict[str, Any]] = None + if isinstance(self.rule_set, RuleSet): + _rule_set = self.rule_set.to_dict() + + field_dict: Dict[str, Any] = {} + if schema is not None: + field_dict["schema"] = schema + if schema_type is not None: + field_dict["schemaType"] = schema_type + if _references is not None: + field_dict["references"] = _references + if _metadata is not None: + field_dict["metadata"] = _metadata + if _rule_set is not None: + field_dict["ruleSet"] = _rule_set + + return field_dict + + @classmethod + def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: + d = src_dict.copy() + + schema = d.pop("schema", None) + + schema_type = d.pop("schemaType", "AVRO") + + references = [] + _references = d.pop("references", None) + for references_item_data in _references or []: + references_item = SchemaReference.from_dict(references_item_data) + + references.append(references_item) + + def _parse_metadata(data: object) -> Optional[Metadata]: + if data is None: + return data + if not isinstance(data, dict): + raise TypeError() + return Metadata.from_dict(data) + + metadata = _parse_metadata(d.pop("metadata", None)) + + def _parse_rule_set(data: object) -> Optional[RuleSet]: + if data is None: + return data + if not isinstance(data, dict): + raise TypeError() + return 
@_attrs_define(frozen=True, cache_hash=True)
class RegisteredSchema:
    """
    A registered schema: a ``Schema`` together with its id, subject and
    version.
    """

    schema_id: Optional[int]
    schema: Optional[Schema]
    subject: Optional[str]
    version: Optional[int]

    def to_dict(self) -> Dict[str, Any]:
        """
        Return the flat dictionary form of this registered schema.

        The result starts from ``schema.to_dict()`` (or an empty dict when
        there is no schema) and then gains ``id``, ``subject`` and
        ``version`` for each of those fields that is not ``None``.
        """
        result: Dict[str, Any] = {} if self.schema is None else self.schema.to_dict()
        extras = (
            ("id", self.schema_id),
            ("subject", self.subject),
            ("version", self.version),
        )
        for key, value in extras:
            if value is not None:
                result[key] = value
        return result

    @classmethod
    def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T:
        """
        Build a registered schema from its flat dictionary form.

        The schema fields and the ``id``/``subject``/``version`` keys are
        all read from the same mapping; ``src_dict`` is not modified.

        Args:
            src_dict (dict): Flat mapping as produced by ``to_dict``.

        Returns:
            A new instance of ``cls``.
        """
        data = src_dict.copy()
        inner_schema = Schema.from_dict(data)
        return cls(
            schema_id=data.get("id"),
            schema=inner_schema,
            subject=data.get("subject"),
            version=data.get("version"),
        )
class FieldType(str, Enum):
    """
    Kinds of field a ``FieldContext`` can describe.

    Each member is also a ``str`` equal to its own name.
    """

    # Not reported as primitive by FieldContext.is_primitive().
    RECORD = "RECORD"
    ENUM = "ENUM"
    ARRAY = "ARRAY"
    MAP = "MAP"
    COMBINED = "COMBINED"
    FIXED = "FIXED"

    # Reported as primitive by FieldContext.is_primitive().
    STRING = "STRING"
    BYTES = "BYTES"
    INT = "INT"
    LONG = "LONG"
    FLOAT = "FLOAT"
    DOUBLE = "DOUBLE"
    BOOLEAN = "BOOLEAN"
    NULL = "NULL"
class RuleContext(object):
    """
    State handed to a rule while it runs.

    Holds the serialization context, the source and target schemas, the
    subject, the rule being executed together with its position in the
    rule list, optional inline tags, the field transformer, and a stack
    of the fields currently being visited.
    """

    __slots__ = ['ser_ctx', 'source', 'target', 'subject', 'rule_mode', 'rule',
                 'index', 'rules', 'inline_tags', 'field_transformer', '_field_contexts']

    def __init__(
        self, ser_ctx: SerializationContext, source: Optional[Schema],
        target: Optional[Schema], subject: str, rule_mode: RuleMode, rule: Rule,
        index: int, rules: List[Rule], inline_tags: Optional[Dict[str, Set[str]]], field_transformer
    ):
        self.ser_ctx = ser_ctx
        self.source = source
        self.target = target
        self.subject = subject
        self.rule_mode = rule_mode
        self.rule = rule
        self.index = index
        self.rules = rules
        self.inline_tags = inline_tags
        self.field_transformer = field_transformer
        # Stack of entered fields; the last element is the current field.
        self._field_contexts: List[FieldContext] = []

    def get_parameter(self, name: str) -> Optional[str]:
        """
        Look up a parameter by name.

        The rule's own params are consulted first; if they do not supply a
        non-None value, the properties in the target schema's metadata are
        used.

        Returns:
            The value found, or None.
        """
        rule_params = self.rule.params
        if rule_params is not None:
            found = rule_params.params.get(name)
            if found is not None:
                return found

        target = self.target
        if target is None or target.metadata is None or target.metadata.properties is None:
            return None
        return target.metadata.properties.properties.get(name)

    def _get_inline_tags(self, name: str) -> Set[str]:
        """Return the inline tags recorded under ``name``, or an empty set."""
        return set() if self.inline_tags is None else self.inline_tags.get(name, set())

    def current_field(self) -> Optional[FieldContext]:
        """Return the most recently entered field, or None if there is none."""
        return self._field_contexts[-1] if self._field_contexts else None

    def enter_field(
        self, containing_message: Any, full_name: str, name: str,
        field_type: FieldType, tags: Optional[Set[str]]
    ) -> FieldContext:
        """
        Push a new field onto the field stack and return its context.

        The field's tags are the given ``tags`` (or, when ``tags`` is None,
        the inline tags recorded for ``full_name``) combined with the
        result of ``get_tags(full_name)``.
        """
        base_tags = self._get_inline_tags(full_name) if tags is None else tags
        merged = set(base_tags)
        merged.update(self.get_tags(full_name))
        entered = FieldContext(containing_message, full_name, name, field_type, merged)
        self._field_contexts.append(entered)
        return entered

    def get_tags(self, full_name: str) -> Set[str]:
        """
        Collect tags for ``full_name`` from the target schema's metadata.

        Returns:
            The union of every tag list whose key satisfies
            ``wildcard_match(full_name, key)``; empty when the target has
            no metadata tags.
        """
        matched = set()
        metadata = None if self.target is None else self.target.metadata
        if metadata is None or metadata.tags is None:
            return matched
        for pattern, values in metadata.tags.tags.items():
            if wildcard_match(full_name, pattern):
                matched.update(values)
        return matched

    def exit_field(self):
        """Pop the most recently entered field; do nothing if there is none."""
        if not self._field_contexts:
            return
        self._field_contexts.pop()
class ParsedSchemaCache(object):
    """
    Cache of parsed schemas keyed by ``Schema``.

    Every access to the underlying dict is made while holding ``lock``.
    """

    def __init__(self):
        self.lock = Lock()
        self.parsed_schemas = {}

    def set(self, schema: Schema, parsed_schema: T):
        """
        Store the parsed form of a schema, replacing any previous entry.

        Args:
            schema (Schema): The schema, used as the cache key

            parsed_schema (Any): The parsed schema
        """
        with self.lock:
            self.parsed_schemas[schema] = parsed_schema

    def get_parsed_schema(self, schema: Schema) -> Optional[T]:
        """
        Look up the parsed form of a schema.

        Args:
            schema (Schema): The schema

        Returns:
            The parsed schema if one was stored; else None
        """
        with self.lock:
            cached = self.parsed_schemas.get(schema)
        return cached

    def clear(self):
        """
        Remove every entry from the cache.
        """
        with self.lock:
            self.parsed_schemas.clear()
def _resolve_named_schema(
    schema: Schema, schema_registry_client: SchemaRegistryClient,
    ref_registry: Optional[Registry] = None
) -> Registry:
    """
    Recursively collect the schemas referenced by ``schema`` into a registry.

    Each reference is fetched with ``get_version``, its own references are
    resolved first, and the referenced schema is then added to the
    registry under the reference's name.

    :param schema: Schema to resolve named schemas for.
    :param schema_registry_client: SchemaRegistryClient to use for retrieval.
    :param ref_registry: Registry accumulated so far; a new one is created
        when omitted.
    :return: Registry
    """
    # Retrieve external schemas for backward compatibility
    registry = Registry(retrieve=_retrieve_via_httpx) if ref_registry is None else ref_registry

    for ref in schema.references or []:
        fetched = schema_registry_client.get_version(ref.subject, ref.version, True)
        registry = _resolve_named_schema(fetched.schema, schema_registry_client, registry)
        contents = json.loads(fetched.schema.schema_str)
        registry = registry.with_resource(
            ref.name,
            Resource.from_contents(contents, default_specification=DEFAULT_SPEC),
        )
    return registry
- - Configuration properties: - - +-----------------------------+----------+----------------------------------------------------+ - | Property Name | Type | Description | - +=============================+==========+====================================================+ - | | | If True, automatically register the configured | - | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | - | | | not previously been associated with the relevant | - | | | subject (determined via subject.name.strategy). | - | | | | - | | | Defaults to True. | - | | | | - | | | Raises SchemaRegistryError if the schema was not | - | | | registered against the subject, or could not be | - | | | successfully registered. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to normalize schemas, which will | - | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | - | | | including ordering properties and references. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to use the given schema ID for | - | ``use.schema.id`` | int | serialization. | - | | | | - +-----------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | serialization. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to False. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata``| dict | the given metadata. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. 
| - | | | | - | | | Defaults to None. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka.schema_registry | - | | | namespace. | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to validate the payload against the | - | ``validate`` | bool | the given schema. | - | | | | - +-----------------------------+----------+----------------------------------------------------+ - - Schemas are registered against subject names in Confluent Schema Registry that - define a scope in which the schemas can be evolved. By default, the subject name - is formed by concatenating the topic name with the message field (key or value) - separated by a hyphen. - - i.e. {topic name}-{message field} - - Alternative naming strategies may be configured with the property - ``subject.name.strategy``. - - Supported subject name strategies: - - +--------------------------------------+------------------------------+ - | Subject Name Strategy | Output Format | - +======================================+==============================+ - | topic_subject_name_strategy(default) | {topic name}-{message field} | - +--------------------------------------+------------------------------+ - | topic_record_subject_name_strategy | {topic name}-{record name} | - +--------------------------------------+------------------------------+ - | record_subject_name_strategy | {record name} | - +--------------------------------------+------------------------------+ - - See `Subject name strategy `_ for additional details. 
def __init__(
    self,
    schema_str: Union[str, Schema, None],
    schema_registry_client: SchemaRegistryClient,
    to_dict: Optional[Callable[[object, SerializationContext], dict]] = None,
    conf: Optional[dict] = None,
    rule_conf: Optional[dict] = None,
    rule_registry: Optional[RuleRegistry] = None,
    json_encode: Optional[Callable] = None,
):
    """
    Validate the configuration and prepare the serializer.

    Args:
        schema_str (str, Schema, None): JSON Schema definition. A ``str``
            is wrapped in a ``Schema`` of type ``JSON``; a value that is
            neither ``str`` nor ``Schema`` leaves the serializer without
            a configured schema.

        schema_registry_client (SchemaRegistryClient): Schema Registry
            client instance.

        to_dict (callable, optional): Callable(object, SerializationContext) -> dict.

        conf (dict, optional): Overrides for the entries of ``_default_conf``.

        rule_conf (dict, optional): Passed to ``configure`` of every rule
            executor; an empty dict is used when omitted.

        rule_registry (RuleRegistry, optional): Defaults to the global
            RuleRegistry instance.

        json_encode (callable, optional): Encoder for the payload.
            Defaults to ``json.dumps``.

    Raises:
        ValueError: If ``to_dict`` is not callable, a property has the
            wrong type, ``use.latest.version`` and
            ``auto.register.schemas`` are both enabled, or ``conf``
            contains an unrecognized property.
    """
    super().__init__()
    if isinstance(schema_str, str):
        self._schema = Schema(schema_str, schema_type="JSON")
    elif isinstance(schema_str, Schema):
        self._schema = schema_str
    else:
        self._schema = None

    self._json_encode = json_encode or json.dumps
    self._registry = schema_registry_client
    self._rule_registry = (
        rule_registry if rule_registry else RuleRegistry.get_global_instance()
    )
    self._schema_id = None
    self._known_subjects = set()
    self._parsed_schemas = ParsedSchemaCache()
    self._validators = LRUCache(1000)

    if to_dict is not None and not callable(to_dict):
        raise ValueError("to_dict must be callable with the signature "
                         "to_dict(object, SerializationContext)->dict")

    self._to_dict = to_dict

    # Start from the defaults and pop each known property; whatever is
    # left over at the end is an unrecognized property.
    conf_copy = self._default_conf.copy()
    if conf is not None:
        conf_copy.update(conf)

    self._auto_register = conf_copy.pop('auto.register.schemas')
    if not isinstance(self._auto_register, bool):
        raise ValueError("auto.register.schemas must be a boolean value")

    self._normalize_schemas = conf_copy.pop('normalize.schemas')
    if not isinstance(self._normalize_schemas, bool):
        raise ValueError("normalize.schemas must be a boolean value")

    self._use_schema_id = conf_copy.pop('use.schema.id')
    if (self._use_schema_id is not None and
            not isinstance(self._use_schema_id, int)):
        raise ValueError("use.schema.id must be an int value")

    self._use_latest_version = conf_copy.pop('use.latest.version')
    if not isinstance(self._use_latest_version, bool):
        raise ValueError("use.latest.version must be a boolean value")
    if self._use_latest_version and self._auto_register:
        raise ValueError("cannot enable both use.latest.version and auto.register.schemas")

    self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata')
    if (self._use_latest_with_metadata is not None and
            not isinstance(self._use_latest_with_metadata, dict)):
        raise ValueError("use.latest.with.metadata must be a dict value")

    self._subject_name_func = conf_copy.pop('subject.name.strategy')
    if not callable(self._subject_name_func):
        raise ValueError("subject.name.strategy must be callable")

    self._validate = conf_copy.pop('validate')
    # Check the 'validate' value itself; checking normalize.schemas here
    # would let a non-boolean 'validate' through unnoticed.
    if not isinstance(self._validate, bool):
        raise ValueError("validate must be a boolean value")

    if len(conf_copy) > 0:
        raise ValueError("Unrecognized properties: {}"
                         .format(", ".join(conf_copy.keys())))

    schema_dict, ref_registry = self._get_parsed_schema(self._schema)
    if schema_dict:
        schema_name = schema_dict.get('title', None)
    else:
        schema_name = None

    self._schema_name = schema_name
    self._parsed_schema = schema_dict
    self._ref_registry = ref_registry

    for rule in self._rule_registry.get_executors():
        rule.configure(self._registry.config() if self._registry else {},
                       rule_conf if rule_conf else {})
- self._schema_id = self._registry.register_schema(subject, - self._schema, - self._normalize_schemas) - else: - registered_schema = self._registry.lookup_schema(subject, - self._schema, - self._normalize_schemas) - self._schema_id = registered_schema.schema_id - - self._known_subjects.add(subject) - - if self._to_dict is not None: - value = self._to_dict(obj, ctx) - else: - value = obj - - if latest_schema is not None: - schema = latest_schema.schema - parsed_schema, ref_registry = self._get_parsed_schema(latest_schema.schema) - root_resource = Resource.from_contents( - parsed_schema, default_specification=DEFAULT_SPEC) - ref_resolver = ref_registry.resolver_with_root(root_resource) - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 - transform(rule_ctx, parsed_schema, ref_registry, ref_resolver, "$", msg, field_transform)) - value = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, value, None, - field_transformer) - else: - schema = self._schema - parsed_schema, ref_registry = self._parsed_schema, self._ref_registry - - if self._validate: - try: - validator = self._get_validator(schema, parsed_schema, ref_registry) - validator.validate(value) - except ValidationError as ve: - raise SerializationError(ve.message) - - with _ContextStringIO() as fo: - # Write the magic byte and schema ID in network byte order (big endian) - fo.write(struct.pack(">bI", _MAGIC_BYTE, self._schema_id)) - # JSON dump always writes a str never bytes - # https://docs.python.org/3/library/json.html - encoded_value = self._json_encode(value) - if isinstance(encoded_value, str): - encoded_value = encoded_value.encode("utf8") - fo.write(encoded_value) - - return fo.getvalue() - - def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: - if schema is None: - return None, None - - result = self._parsed_schemas.get_parsed_schema(schema) - if result is not None: - return result - - ref_registry = 
_resolve_named_schema(schema, self._registry) - parsed_schema = json.loads(schema.schema_str) - - self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) - return parsed_schema, ref_registry - - def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: - validator = self._validators.get(schema, None) - if validator is not None: - return validator - - cls = validator_for(parsed_schema) - cls.check_schema(parsed_schema) - validator = cls(parsed_schema, registry=registry) - - self._validators[schema] = validator - return validator - - -class JSONDeserializer(BaseDeserializer): - """ - Deserializer for JSON encoded data with Confluent Schema Registry - framing. - - Configuration properties: - - +-----------------------------+----------+----------------------------------------------------+ - | Property Name | Type | Description | - +=============================+==========+====================================================+ - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | deserialization. | - | | | | - | | | Defaults to False. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata``| dict | the given metadata. | - | | | | - | | | Defaults to None. | - +-----------------------------+----------+----------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka.schema_registry | - | | | namespace. | - | | | | - | | | Defaults to topic_subject_name_strategy. 
| - +-----------------------------+----------+----------------------------------------------------+ - | | | Whether to validate the payload against the | - | ``validate`` | bool | the given schema. | - | | | | - +-----------------------------+----------+----------------------------------------------------+ - - Args: - schema_str (str, Schema, optional): - `JSON schema definition `_ - Accepts schema as either a string or a :py:class:`Schema` instance. - Note that string definitions cannot reference other schemas. For referencing other schemas, - use a :py:class:`Schema` instance. If not provided, schemas will be - retrieved from schema_registry_client based on the schema ID in the - wire header of each message. - - from_dict (callable, optional): Callable(dict, SerializationContext) -> object. - Converts a dict to a Python object instance. - - schema_registry_client (SchemaRegistryClient, optional): Schema Registry client instance. Needed if ``schema_str`` is a schema referencing other schemas or is not provided. 
- """ # noqa: E501 - - __slots__ = ['_reader_schema', '_ref_registry', '_from_dict', '_schema', - '_parsed_schemas', '_validators', '_validate', '_json_decode'] - - _default_conf = {'use.latest.version': False, - 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy, - 'validate': True} - - def __init__( - self, - schema_str: Union[str, Schema, None], - from_dict: Optional[Callable[[dict, SerializationContext], object]] = None, - schema_registry_client: Optional[SchemaRegistryClient] = None, - conf: Optional[dict] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None, - json_decode: Optional[Callable] = None, - ): - super().__init__() - if isinstance(schema_str, str): - schema = Schema(schema_str, schema_type="JSON") - elif isinstance(schema_str, Schema): - schema = schema_str - if bool(schema.references) and schema_registry_client is None: - raise ValueError( - """schema_registry_client must be provided if "schema_str" is a Schema instance with references""") - elif schema_str is None: - if schema_registry_client is None: - raise ValueError( - """schema_registry_client must be provided if "schema_str" is not provided""" - ) - schema = schema_str - else: - raise TypeError('You must pass either str or Schema') - - self._schema = schema - self._registry = schema_registry_client - self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._parsed_schemas = ParsedSchemaCache() - self._validators = LRUCache(1000) - self._json_decode = json_decode or json.loads - self._use_schema_id = None - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._use_latest_version = conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if 
(self._use_latest_with_metadata is not None and - not isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - self._validate = conf_copy.pop('validate') - if not isinstance(self._validate, bool): - raise ValueError("validate must be a boolean value") - - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - if schema: - self._reader_schema, self._ref_registry = self._get_parsed_schema(self._schema) - else: - self._reader_schema, self._ref_registry = None, None - - if from_dict is not None and not callable(from_dict): - raise ValueError("from_dict must be callable with the signature" - " from_dict(dict, SerializationContext) -> object") - - self._from_dict = from_dict - - for rule in self._rule_registry.get_executors(): - rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Union[dict, object, None]: - """ - Deserialize a JSON encoded record with Confluent Schema Registry framing to - a dict, or object instance according to from_dict if from_dict is specified. - - Args: - data (bytes): A JSON serialized record with Confluent Schema Registry framing. - - ctx (SerializationContext): Metadata relevant to the serialization operation. - - Returns: - A dict, or object instance according to from_dict if from_dict is specified. - - Raises: - SerializerError: If there was an error reading the Confluent framing data, or - if ``data`` was not successfully validated with the configured schema. 
- """ - - if data is None: - return None - - if len(data) <= 5: - raise SerializationError("Expecting data framing of length 6 bytes or " - "more but total data size is {} bytes. This " - "message was not produced with a Confluent " - "Schema Registry serializer".format(len(data))) - - subject = self._subject_name_func(ctx, None) - latest_schema = None - if subject is not None and self._registry is not None: - latest_schema = self._get_reader_schema(subject) - - with _ContextStringIO(data) as payload: - magic, schema_id = struct.unpack('>bI', payload.read(5)) - if magic != _MAGIC_BYTE: - raise SerializationError("Unexpected magic byte {}. This message " - "was not produced with a Confluent " - "Schema Registry serializer".format(magic)) - - # JSON documents are self-describing; no need to query schema - obj_dict = self._json_decode(payload.read()) - - if self._registry is not None: - writer_schema_raw = self._registry.get_schema(schema_id) - writer_schema, writer_ref_registry = self._get_parsed_schema(writer_schema_raw) - if subject is None: - subject = self._subject_name_func(ctx, writer_schema.get("title")) - if subject is not None: - latest_schema = self._get_reader_schema(subject) - else: - writer_schema_raw = None - writer_schema, writer_ref_registry = None, None - - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) - reader_schema_raw = latest_schema.schema - reader_schema, reader_ref_registry = self._get_parsed_schema(latest_schema.schema) - elif self._schema is not None: - migrations = None - reader_schema_raw = self._schema - reader_schema, reader_ref_registry = self._reader_schema, self._ref_registry - else: - migrations = None - reader_schema_raw = writer_schema_raw - reader_schema, reader_ref_registry = writer_schema, writer_ref_registry - - if migrations: - obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - - reader_root_resource = Resource.from_contents( - 
reader_schema, default_specification=DEFAULT_SPEC) - reader_ref_resolver = reader_ref_registry.resolver_with_root(reader_root_resource) - field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 - transform(rule_ctx, reader_schema, reader_ref_registry, reader_ref_resolver, - "$", message, field_transform)) - obj_dict = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, obj_dict, None, - field_transformer) - - if self._validate: - try: - validator = self._get_validator(reader_schema_raw, reader_schema, reader_ref_registry) - validator.validate(obj_dict) - except ValidationError as ve: - raise SerializationError(ve.message) - - if self._from_dict is not None: - return self._from_dict(obj_dict, ctx) - - return obj_dict - - def _get_parsed_schema(self, schema: Schema) -> Tuple[Optional[JsonSchema], Optional[Registry]]: - if schema is None: - return None, None - - result = self._parsed_schemas.get_parsed_schema(schema) - if result is not None: - return result - - ref_registry = _resolve_named_schema(schema, self._registry) - parsed_schema = json.loads(schema.schema_str) - - self._parsed_schemas.set(schema, (parsed_schema, ref_registry)) - return parsed_schema, ref_registry - - def _get_validator(self, schema: Schema, parsed_schema: JsonSchema, registry: Registry) -> Validator: - validator = self._validators.get(schema, None) - if validator is not None: - return validator - - cls = validator_for(parsed_schema) - cls.check_schema(parsed_schema) - validator = cls(parsed_schema, registry=registry) - - self._validators[schema] = validator - return validator - - -def transform( - ctx: RuleContext, schema: JsonSchema, ref_registry: Registry, ref_resolver: Resolver, - path: str, message: JsonMessage, field_transform: FieldTransform -) -> Optional[JsonMessage]: - if message is None or schema is None or isinstance(schema, bool): - return message - field_ctx = ctx.current_field() - if field_ctx is not None: - field_ctx.field_type = 
get_type(schema) - all_of = schema.get("allOf") - if all_of is not None: - subschema = _validate_subschemas(all_of, message, ref_registry) - if subschema is not None: - return transform(ctx, subschema, ref_registry, ref_resolver, path, message, field_transform) - any_of = schema.get("anyOf") - if any_of is not None: - subschema = _validate_subschemas(any_of, message, ref_registry) - if subschema is not None: - return transform(ctx, subschema, ref_registry, ref_resolver, path, message, field_transform) - one_of = schema.get("oneOf") - if one_of is not None: - subschema = _validate_subschemas(one_of, message, ref_registry) - if subschema is not None: - return transform(ctx, subschema, ref_registry, ref_resolver, path, message, field_transform) - items = schema.get("items") - if items is not None: - if isinstance(message, list): - return [transform(ctx, items, ref_registry, ref_resolver, path, item, field_transform) for item in message] - ref = schema.get("$ref") - if ref is not None: - ref_schema = ref_resolver.lookup(ref) - return transform(ctx, ref_schema.contents, ref_registry, ref_resolver, path, message, field_transform) - schema_type = get_type(schema) - if schema_type == FieldType.RECORD: - props = schema.get("properties") - if props is not None: - for prop_name, prop_schema in props.items(): - _transform_field(ctx, path, prop_name, message, - prop_schema, ref_registry, ref_resolver, field_transform) - return message - if schema_type in (FieldType.ENUM, FieldType.STRING, FieldType.INT, FieldType.DOUBLE, FieldType.BOOLEAN): - if field_ctx is not None: - rule_tags = ctx.rule.tags - if not rule_tags or not _disjoint(set(rule_tags), field_ctx.tags): - return field_transform(ctx, field_ctx, message) - return message - - -def _transform_field( - ctx: RuleContext, path: str, prop_name: str, message: JsonMessage, - prop_schema: JsonSchema, ref_registry: Registry, ref_resolver: Resolver, field_transform: FieldTransform -): - full_name = path + "." 
+ prop_name - try: - ctx.enter_field( - message, - full_name, - prop_name, - get_type(prop_schema), - get_inline_tags(prop_schema) - ) - value = message[prop_name] - new_value = transform(ctx, prop_schema, ref_registry, ref_resolver, full_name, value, field_transform) - if ctx.rule.kind == RuleKind.CONDITION: - if new_value is False: - raise RuleConditionError(ctx.rule) - else: - message[prop_name] = new_value - finally: - ctx.exit_field() - - -def _validate_subschemas( - subschemas: List[JsonSchema], - message: JsonMessage, - registry: Registry -) -> Optional[JsonSchema]: - for subschema in subschemas: - try: - validate(instance=message, schema=subschema, registry=registry) - return subschema - except ValidationError: - pass - return None - - -def get_type(schema: JsonSchema) -> FieldType: - if isinstance(schema, list): - return FieldType.COMBINED - elif isinstance(schema, dict): - schema_type = schema.get("type") - else: - # string schemas; this could be either a named schema or a primitive type - schema_type = schema - - if schema.get("const") is not None or schema.get("enum") is not None: - return FieldType.ENUM - if schema_type == "object": - props = schema.get("properties") - if not props: - return FieldType.MAP - return FieldType.RECORD - if schema_type == "array": - return FieldType.ARRAY - if schema_type == "string": - return FieldType.STRING - if schema_type == "integer": - return FieldType.INT - if schema_type == "number": - return FieldType.DOUBLE - if schema_type == "boolean": - return FieldType.BOOLEAN - if schema_type == "null": - return FieldType.NULL - return FieldType.NULL - - -def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: - for tag in tags1: - if tag in tags2: - return False - return True - - -def get_inline_tags(schema: JsonSchema) -> Set[str]: - tags = schema.get("confluent:tags") - if tags is None: - return set() - else: - return set(tags) +from .common.json_schema import * # noqa +from ._sync.json_schema import * # noqa diff --git 
a/src/confluent_kafka/schema_registry/protobuf.py b/src/confluent_kafka/schema_registry/protobuf.py index dfa5c1ffe..ba0f03a90 100644 --- a/src/confluent_kafka/schema_registry/protobuf.py +++ b/src/confluent_kafka/schema_registry/protobuf.py @@ -15,1139 +15,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -import io -import sys -import base64 -import struct -import warnings -from collections import deque -from decimal import Context, Decimal, MAX_PREC -from typing import Set, List, Union, Optional, Any, Tuple - -from google.protobuf import descriptor_pb2, any_pb2, api_pb2, empty_pb2, \ - duration_pb2, field_mask_pb2, source_context_pb2, struct_pb2, timestamp_pb2, \ - type_pb2, wrappers_pb2 -from google.protobuf import json_format -from google.protobuf.descriptor_pool import DescriptorPool -from google.type import calendar_period_pb2, color_pb2, date_pb2, datetime_pb2, \ - dayofweek_pb2, expr_pb2, fraction_pb2, latlng_pb2, money_pb2, month_pb2, \ - postal_address_pb2, timeofday_pb2, quaternion_pb2 - -import confluent_kafka.schema_registry.confluent.meta_pb2 as meta_pb2 - -from google.protobuf.descriptor import Descriptor, FieldDescriptor, \ - FileDescriptor -from google.protobuf.message import DecodeError, Message -from google.protobuf.message_factory import GetMessageClass - -from . 
import (_MAGIC_BYTE, - reference_subject_name_strategy, - topic_subject_name_strategy, SchemaRegistryClient) -from .confluent.types import decimal_pb2 -from .rule_registry import RuleRegistry -from .schema_registry_client import (Schema, - SchemaReference, - RuleKind, - RuleMode) -from confluent_kafka.serialization import SerializationError, \ - SerializationContext -from .serde import BaseSerializer, BaseDeserializer, RuleContext, \ - FieldTransform, FieldType, RuleConditionError, ParsedSchemaCache - -# Convert an int to bytes (inverse of ord()) -# Python3.chr() -> Unicode -# Python2.chr() -> str(alias for bytes) -if sys.version > '3': - def _bytes(v: int) -> bytes: - """ - Convert int to bytes - - Args: - v (int): The int to convert to bytes. - """ - return bytes((v,)) -else: - def _bytes(v: int) -> str: - """ - Convert int to bytes - - Args: - v (int): The int to convert to bytes. - """ - return chr(v) - - -class _ContextStringIO(io.BytesIO): - """ - Wrapper to allow use of StringIO via 'with' constructs. - """ - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - return False - - -def _create_index_array(msg_desc: Descriptor) -> List[int]: - """ - Creates an index array specifying the location of msg_desc in - the referenced FileDescriptor. - - Args: - msg_desc (MessageDescriptor): Protobuf MessageDescriptor - - Returns: - list of int: Protobuf MessageDescriptor index array. - - Raises: - ValueError: If the message descriptor is malformed. - """ - - msg_idx = deque() - - # Walk the nested MessageDescriptor tree up to the root. 
- current = msg_desc - found = False - while current.containing_type is not None: - previous = current - current = previous.containing_type - # find child's position - for idx, node in enumerate(current.nested_types): - if node == previous: - msg_idx.appendleft(idx) - found = True - break - if not found: - raise ValueError("Nested MessageDescriptor not found") - - # Add the index of the root MessageDescriptor in the FileDescriptor. - found = False - for idx, msg_type_name in enumerate(msg_desc.file.message_types_by_name): - if msg_type_name == current.name: - msg_idx.appendleft(idx) - found = True - break - if not found: - raise ValueError("MessageDescriptor not found in file") - - return list(msg_idx) - - -def _schema_to_str(file_descriptor: FileDescriptor) -> str: - """ - Base64 encode a FileDescriptor - - Args: - file_descriptor (FileDescriptor): FileDescriptor to encode. - - Returns: - str: Base64 encoded FileDescriptor - """ - - return base64.standard_b64encode(file_descriptor.serialized_pb).decode('ascii') - - -def _proto_to_str(file_descriptor_proto: descriptor_pb2.FileDescriptorProto) -> str: - """ - Base64 encode a FileDescriptorProto - - Args: - file_descriptor_proto (FileDescriptorProto): FileDescriptorProto to encode. - - Returns: - str: Base64 encoded FileDescriptorProto - """ - - return base64.standard_b64encode(file_descriptor_proto.SerializeToString()).decode('ascii') - - -def _str_to_proto(name: str, schema_str: str) -> descriptor_pb2.FileDescriptorProto: - """ - Base64 decode a FileDescriptor - - Args: - schema_str (str): Base64 encoded FileDescriptorProto - - Returns: - FileDescriptorProto: schema. 
- """ - - serialized_pb = base64.standard_b64decode(schema_str.encode('ascii')) - file_descriptor_proto = descriptor_pb2.FileDescriptorProto() - try: - file_descriptor_proto.ParseFromString(serialized_pb) - file_descriptor_proto.name = name - except DecodeError as e: - raise SerializationError(str(e)) - return file_descriptor_proto - - -def _resolve_named_schema( - schema: Schema, - schema_registry_client: SchemaRegistryClient, - pool: DescriptorPool, - visited: Optional[Set[str]] = None -): - """ - Resolves named schemas referenced by the provided schema recursively. - :param schema: Schema to resolve named schemas for. - :param schema_registry_client: SchemaRegistryClient to use for retrieval. - :param pool: DescriptorPool to add resolved schemas to. - :return: DescriptorPool - """ - if visited is None: - visited = set() - if schema.references is not None: - for ref in schema.references: - if _is_builtin(ref.name) or ref.name in visited: - continue - visited.add(ref.name) - referenced_schema = schema_registry_client.get_version(ref.subject, ref.version, True, 'serialized') - _resolve_named_schema(referenced_schema.schema, schema_registry_client, pool, visited) - file_descriptor_proto = _str_to_proto(ref.name, referenced_schema.schema.schema_str) - pool.Add(file_descriptor_proto) - - -def _init_pool(pool: DescriptorPool): - pool.AddSerializedFile(any_pb2.DESCRIPTOR.serialized_pb) - # source_context needed by api - pool.AddSerializedFile(source_context_pb2.DESCRIPTOR.serialized_pb) - # type needed by api - pool.AddSerializedFile(type_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(api_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(descriptor_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(duration_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(empty_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(field_mask_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(struct_pb2.DESCRIPTOR.serialized_pb) - 
pool.AddSerializedFile(timestamp_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(wrappers_pb2.DESCRIPTOR.serialized_pb) - - pool.AddSerializedFile(calendar_period_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(color_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(date_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(datetime_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(dayofweek_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(expr_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(fraction_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(latlng_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(money_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(month_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(postal_address_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(quaternion_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(timeofday_pb2.DESCRIPTOR.serialized_pb) - - pool.AddSerializedFile(meta_pb2.DESCRIPTOR.serialized_pb) - pool.AddSerializedFile(decimal_pb2.DESCRIPTOR.serialized_pb) - - -class ProtobufSerializer(BaseSerializer): - """ - Serializer for Protobuf Message derived classes. Serialization format is Protobuf, - with Confluent Schema Registry framing. - - Configuration properties: - - +-------------------------------------+----------+------------------------------------------------------+ - | Property Name | Type | Description | - +=====================================+==========+======================================================+ - | | | If True, automatically register the configured | - | ``auto.register.schemas`` | bool | schema with Confluent Schema Registry if it has | - | | | not previously been associated with the relevant | - | | | subject (determined via subject.name.strategy). | - | | | | - | | | Defaults to True. 
| - | | | | - | | | Raises SchemaRegistryError if the schema was not | - | | | registered against the subject, or could not be | - | | | successfully registered. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether to normalize schemas, which will | - | ``normalize.schemas`` | bool | transform schemas to have a consistent format, | - | | | including ordering properties and references. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether to use the given schema ID for | - | ``use.schema.id`` | int | serialization. | - | | | | - +-----------------------------------------+----------+--------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | serialization. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to False. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata`` | dict | the given metadata. | - | | | | - | | | WARNING: There is no check that the latest | - | | | schema is backwards compatible with the object | - | | | being serialized. | - | | | | - | | | Defaults to None. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether or not to skip known types when resolving | - | ``skip.known.types`` | bool | schema dependencies. | - | | | | - | | | Defaults to True. 
| - +-------------------------------------+----------+------------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka.schema_registry | - | | | namespace. | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``reference.subject.name.strategy`` | callable | Defines how Schema Registry subject names for schema | - | | | references are constructed. | - | | | | - | | | Defaults to reference_subject_name_strategy | - +-------------------------------------+----------+------------------------------------------------------+ - | ``use.deprecated.format`` | bool | Specifies whether the Protobuf serializer should | - | | | serialize message indexes without zig-zag encoding. | - | | | This option must be explicitly configured as older | - | | | and newer Protobuf producers are incompatible. | - | | | If the consumers of the topic being produced to are | - | | | using confluent-kafka-python <1.8 then this property | - | | | must be set to True until all old consumers have | - | | | have been upgraded. | - | | | | - | | | Warning: This configuration property will be removed | - | | | in a future version of the client. | - +-------------------------------------+----------+------------------------------------------------------+ - - Schemas are registered against subject names in Confluent Schema Registry that - define a scope in which the schemas can be evolved. By default, the subject name - is formed by concatenating the topic name with the message field (key or value) - separated by a hyphen. - - i.e. 
{topic name}-{message field} - - Alternative naming strategies may be configured with the property - ``subject.name.strategy``. - - Supported subject name strategies - - +--------------------------------------+------------------------------+ - | Subject Name Strategy | Output Format | - +======================================+==============================+ - | topic_subject_name_strategy(default) | {topic name}-{message field} | - +--------------------------------------+------------------------------+ - | topic_record_subject_name_strategy | {topic name}-{record name} | - +--------------------------------------+------------------------------+ - | record_subject_name_strategy | {record name} | - +--------------------------------------+------------------------------+ - - See `Subject name strategy `_ for additional details. - - Args: - msg_type (Message): Protobuf Message type. - - schema_registry_client (SchemaRegistryClient): Schema Registry - client instance. - - conf (dict): ProtobufSerializer configuration. 
- - See Also: - `Protobuf API reference `_ - """ # noqa: E501 - __slots__ = ['_skip_known_types', '_known_subjects', '_msg_class', '_index_array', - '_schema', '_schema_id', '_ref_reference_subject_func', - '_use_deprecated_format', '_parsed_schemas'] - - _default_conf = { - 'auto.register.schemas': True, - 'normalize.schemas': False, - 'use.schema.id': None, - 'use.latest.version': False, - 'use.latest.with.metadata': None, - 'skip.known.types': True, - 'subject.name.strategy': topic_subject_name_strategy, - 'reference.subject.name.strategy': reference_subject_name_strategy, - 'use.deprecated.format': False, - } - - def __init__( - self, - msg_type: Message, - schema_registry_client: SchemaRegistryClient, - conf: Optional[dict] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None - ): - super().__init__() - - if conf is None or 'use.deprecated.format' not in conf: - raise RuntimeError( - "ProtobufSerializer: the 'use.deprecated.format' configuration " - "property must be explicitly set due to backward incompatibility " - "with older confluent-kafka-python Protobuf producers and consumers. 
" - "See the release notes for more details") - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._auto_register = conf_copy.pop('auto.register.schemas') - if not isinstance(self._auto_register, bool): - raise ValueError("auto.register.schemas must be a boolean value") - - self._normalize_schemas = conf_copy.pop('normalize.schemas') - if not isinstance(self._normalize_schemas, bool): - raise ValueError("normalize.schemas must be a boolean value") - - self._use_schema_id = conf_copy.pop('use.schema.id') - if (self._use_schema_id is not None and - not isinstance(self._use_schema_id, int)): - raise ValueError("use.schema.id must be an int value") - - self._use_latest_version = conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - if self._use_latest_version and self._auto_register: - raise ValueError("cannot enable both use.latest.version and auto.register.schemas") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if (self._use_latest_with_metadata is not None and - not isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._skip_known_types = conf_copy.pop('skip.known.types') - if not isinstance(self._skip_known_types, bool): - raise ValueError("skip.known.types must be a boolean value") - - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') - if not isinstance(self._use_deprecated_format, bool): - raise ValueError("use.deprecated.format must be a boolean value") - if self._use_deprecated_format: - warnings.warn("ProtobufSerializer: the 'use.deprecated.format' " - "configuration property, and the ability to use the " - "old incorrect Protobuf serializer heading format " - "introduced in confluent-kafka-python v1.4.0, " - "will be removed in an upcoming release in 2021 Q2. 
" - "Please migrate your Python Protobuf producers and " - "consumers to 'use.deprecated.format':False as " - "soon as possible") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - self._ref_reference_subject_func = conf_copy.pop( - 'reference.subject.name.strategy') - if not callable(self._ref_reference_subject_func): - raise ValueError("subject.name.strategy must be callable") - - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - self._registry = schema_registry_client - self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._schema_id = None - self._known_subjects = set() - self._msg_class = msg_type - self._parsed_schemas = ParsedSchemaCache() - - descriptor = msg_type.DESCRIPTOR - self._index_array = _create_index_array(descriptor) - self._schema = Schema(_schema_to_str(descriptor.file), - schema_type='PROTOBUF') - - for rule in self._rule_registry.get_executors(): - rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - @staticmethod - def _write_varint(buf: io.BytesIO, val: int, zigzag: bool = True): - """ - Writes val to buf, either using zigzag or uvarint encoding. - - Args: - buf (BytesIO): buffer to write to. - val (int): integer to be encoded. - zigzag (bool): whether to encode in zigzag or uvarint encoding - """ - - if zigzag: - val = (val << 1) ^ (val >> 63) - - while (val & ~0x7f) != 0: - buf.write(_bytes((val & 0x7f) | 0x80)) - val >>= 7 - buf.write(_bytes(val)) - - @staticmethod - def _encode_varints(buf: io.BytesIO, ints: List[int], zigzag: bool = True): - """ - Encodes each int as a uvarint onto buf - - Args: - buf (BytesIO): buffer to write to. - ints ([int]): ints to be encoded. 
- zigzag (bool): whether to encode in zigzag or uvarint encoding - """ - - assert len(ints) > 0 - # The root element at the 0 position does not need a length prefix. - if ints == [0]: - buf.write(_bytes(0x00)) - return - - ProtobufSerializer._write_varint(buf, len(ints), zigzag=zigzag) - - for value in ints: - ProtobufSerializer._write_varint(buf, value, zigzag=zigzag) - - def _resolve_dependencies( - self, ctx: SerializationContext, - file_desc: FileDescriptor - ) -> List[SchemaReference]: - """ - Resolves and optionally registers schema references recursively. - - Args: - ctx (SerializationContext): Serialization context. - - file_desc (FileDescriptor): file descriptor to traverse. - """ - - schema_refs = [] - for dep in file_desc.dependencies: - if self._skip_known_types and _is_builtin(dep.name): - continue - dep_refs = self._resolve_dependencies(ctx, dep) - subject = self._ref_reference_subject_func(ctx, dep) - schema = Schema(_schema_to_str(dep), - references=dep_refs, - schema_type='PROTOBUF') - if self._auto_register: - self._registry.register_schema(subject, schema) - - reference = self._registry.lookup_schema(subject, schema) - # schema_refs are per file descriptor - schema_refs.append(SchemaReference(dep.name, - subject, - reference.version)) - return schema_refs - - def __call__(self, message: Message, ctx: Optional[SerializationContext] = None) -> Optional[bytes]: - """ - Serializes an instance of a class derived from Protobuf Message, and prepends - it with Confluent Schema Registry framing. - - Args: - message (Message): An instance of a class derived from Protobuf Message. - - ctx (SerializationContext): Metadata relevant to the serialization. - operation. - - Raises: - SerializerError if any error occurs during serialization. - - Returns: - None if messages is None, else a byte array containing the Protobuf - serialized message with Confluent Schema Registry framing. 
- """ - - if message is None: - return None - - if not isinstance(message, self._msg_class): - raise ValueError("message must be of type {} not {}" - .format(self._msg_class, type(message))) - - subject = self._subject_name_func(ctx, message.DESCRIPTOR.full_name) if ctx else None - latest_schema = None - if subject is not None: - latest_schema = self._get_reader_schema(subject, fmt='serialized') - - if latest_schema is not None: - self._schema_id = latest_schema.schema_id - elif subject not in self._known_subjects and ctx is not None: - references = self._resolve_dependencies(ctx, message.DESCRIPTOR.file) - self._schema = Schema( - self._schema.schema_str, - self._schema.schema_type, - references - ) - - if self._auto_register: - self._schema_id = self._registry.register_schema(subject, - self._schema, - self._normalize_schemas) - else: - self._schema_id = self._registry.lookup_schema( - subject, self._schema, self._normalize_schemas).schema_id - - self._known_subjects.add(subject) - - if latest_schema is not None: - fd_proto, pool = self._get_parsed_schema(latest_schema.schema) - fd = pool.FindFileByName(fd_proto.name) - desc = fd.message_types_by_name[message.DESCRIPTOR.name] - field_transformer = lambda rule_ctx, field_transform, msg: ( # noqa: E731 - transform(rule_ctx, desc, msg, field_transform)) - message = self._execute_rules(ctx, subject, RuleMode.WRITE, None, - latest_schema.schema, message, None, - field_transformer) - - with _ContextStringIO() as fo: - # Write the magic byte and schema ID in network byte order - # (big endian) - fo.write(struct.pack('>bI', _MAGIC_BYTE, self._schema_id)) - # write the index array that specifies the message descriptor - # of the serialized data. 
- self._encode_varints(fo, self._index_array, - zigzag=not self._use_deprecated_format) - # write the serialized data itself - fo.write(message.SerializeToString()) - return fo.getvalue() - - def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: - result = self._parsed_schemas.get_parsed_schema(schema) - if result is not None: - return result - - pool = DescriptorPool() - _init_pool(pool) - _resolve_named_schema(schema, self._registry, pool) - fd_proto = _str_to_proto("default", schema.schema_str) - pool.Add(fd_proto) - self._parsed_schemas.set(schema, (fd_proto, pool)) - return fd_proto, pool - - -class ProtobufDeserializer(BaseDeserializer): - """ - Deserializer for Protobuf serialized data with Confluent Schema Registry framing. - - Args: - message_type (Message derived type): Protobuf Message type. - conf (dict): Configuration dictionary. - - ProtobufDeserializer configuration properties: - - +-------------------------------------+----------+------------------------------------------------------+ - | Property Name | Type | Description | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether to use the latest subject version for | - | ``use.latest.version`` | bool | deserialization. | - | | | | - | | | Defaults to False. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Whether to use the latest subject version with | - | ``use.latest.with.metadata`` | dict | the given metadata. | - | | | | - | | | Defaults to None. | - +-------------------------------------+----------+------------------------------------------------------+ - | | | Callable(SerializationContext, str) -> str | - | | | | - | ``subject.name.strategy`` | callable | Defines how Schema Registry subject names are | - | | | constructed. Standard naming strategies are | - | | | defined in the confluent_kafka. 
schema_registry | - | | | namespace . | - | | | | - | | | Defaults to topic_subject_name_strategy. | - +-------------------------------------+----------+------------------------------------------------------+ - | ``use.deprecated.format`` | bool | Specifies whether the Protobuf deserializer should | - | | | deserialize message indexes without zig-zag encoding.| - | | | This option must be explicitly configured as older | - | | | and newer Protobuf producers are incompatible. | - | | | If Protobuf messages in the topic to consume were | - | | | produced with confluent-kafka-python <1.8 then this | - | | | property must be set to True until all old messages | - | | | have been processed and producers have been upgraded.| - | | | Warning: This configuration property will be removed | - | | | in a future version of the client. | - +-------------------------------------+----------+------------------------------------------------------+ - - - See Also: - `Protobuf API reference `_ - """ - - __slots__ = ['_msg_class', '_use_deprecated_format', '_parsed_schemas'] - - _default_conf = { - 'use.latest.version': False, - 'use.latest.with.metadata': None, - 'subject.name.strategy': topic_subject_name_strategy, - 'use.deprecated.format': False, - } - - def __init__( - self, - message_type: Message, - conf: Optional[dict] = None, - schema_registry_client: Optional[SchemaRegistryClient] = None, - rule_conf: Optional[dict] = None, - rule_registry: Optional[RuleRegistry] = None - ): - super().__init__() - - self._registry = schema_registry_client - self._rule_registry = rule_registry if rule_registry else RuleRegistry.get_global_instance() - self._parsed_schemas = ParsedSchemaCache() - self._use_schema_id = None - - # Require use.deprecated.format to be explicitly configured - # during a transitionary period since old/new format are - # incompatible. 
- if conf is None or 'use.deprecated.format' not in conf: - raise RuntimeError( - "ProtobufDeserializer: the 'use.deprecated.format' configuration " - "property must be explicitly set due to backward incompatibility " - "with older confluent-kafka-python Protobuf producers and consumers. " - "See the release notes for more details") - - conf_copy = self._default_conf.copy() - if conf is not None: - conf_copy.update(conf) - - self._use_latest_version = conf_copy.pop('use.latest.version') - if not isinstance(self._use_latest_version, bool): - raise ValueError("use.latest.version must be a boolean value") - - self._use_latest_with_metadata = conf_copy.pop('use.latest.with.metadata') - if (self._use_latest_with_metadata is not None and - not isinstance(self._use_latest_with_metadata, dict)): - raise ValueError("use.latest.with.metadata must be a dict value") - - self._subject_name_func = conf_copy.pop('subject.name.strategy') - if not callable(self._subject_name_func): - raise ValueError("subject.name.strategy must be callable") - - self._use_deprecated_format = conf_copy.pop('use.deprecated.format') - if not isinstance(self._use_deprecated_format, bool): - raise ValueError("use.deprecated.format must be a boolean value") - if self._use_deprecated_format: - warnings.warn("ProtobufDeserializer: the 'use.deprecated.format' " - "configuration property, and the ability to use the " - "old incorrect Protobuf serializer heading format " - "introduced in confluent-kafka-python v1.4.0, " - "will be removed in an upcoming release in 2022 Q2. 
" - "Please migrate your Python Protobuf producers and " - "consumers to 'use.deprecated.format':False as " - "soon as possible") - - descriptor = message_type.DESCRIPTOR - self._msg_class = GetMessageClass(descriptor) - - for rule in self._rule_registry.get_executors(): - rule.configure(self._registry.config() if self._registry else {}, - rule_conf if rule_conf else {}) - - @staticmethod - def _decode_varint(buf: io.BytesIO, zigzag: bool = True) -> int: - """ - Decodes a single varint from a buffer. - - Args: - buf (BytesIO): buffer to read from - zigzag (bool): decode as zigzag or uvarint - - Returns: - int: decoded varint - - Raises: - EOFError: if buffer is empty - """ - - value = 0 - shift = 0 - try: - while True: - i = ProtobufDeserializer._read_byte(buf) - - value |= (i & 0x7f) << shift - shift += 7 - if not (i & 0x80): - break - - if zigzag: - value = (value >> 1) ^ -(value & 1) - - return value - - except EOFError: - raise EOFError("Unexpected EOF while reading index") - - @staticmethod - def _read_byte(buf: io.BytesIO) -> int: - """ - Read one byte from buf as an int. - - Args: - buf (BytesIO): The buffer to read from. - - .. _ord: - https://docs.python.org/2/library/functions.html#ord - """ - - i = buf.read(1) - if i == b'': - raise EOFError("Unexpected EOF encountered") - return ord(i) - - @staticmethod - def _read_index_array(buf: io.BytesIO, zigzag: bool = True) -> List[int]: - """ - Read an index array from buf that specifies the message - descriptor of interest in the file descriptor. - - Args: - buf (BytesIO): The buffer to read from. - - Returns: - list of int: The index array. 
- """ - - size = ProtobufDeserializer._decode_varint(buf, zigzag=zigzag) - if size < 0 or size > 100000: - raise DecodeError("Invalid Protobuf msgidx array length") - - if size == 0: - return [0] - - msg_index = [] - for _ in range(size): - msg_index.append(ProtobufDeserializer._decode_varint(buf, - zigzag=zigzag)) - - return msg_index - - def __call__(self, data: bytes, ctx: Optional[SerializationContext] = None) -> Optional[Message]: - """ - Deserialize a serialized protobuf message with Confluent Schema Registry - framing. - - Args: - data (bytes): Serialized protobuf message with Confluent Schema - Registry framing. - - ctx (SerializationContext): Metadata relevant to the serialization - operation. - - Returns: - Message: Protobuf Message instance. - - Raises: - SerializerError: If there was an error reading the Confluent framing - data, or parsing the protobuf serialized message. - """ - - if data is None: - return None - - # SR wire protocol + msg_index length - if len(data) < 6: - raise SerializationError("Expecting data framing of length 6 bytes or " - "more but total data size is {} bytes. This " - "message was not produced with a Confluent " - "Schema Registry serializer".format(len(data))) - - subject = self._subject_name_func(ctx, None) - latest_schema = None - if subject is not None and self._registry is not None: - latest_schema = self._get_reader_schema(subject, fmt='serialized') - - with _ContextStringIO(data) as payload: - magic, schema_id = struct.unpack('>bI', payload.read(5)) - if magic != _MAGIC_BYTE: - raise SerializationError("Unknown magic byte. 
This message was " - "not produced with a Confluent " - "Schema Registry serializer") - - msg_index = self._read_index_array(payload, zigzag=not self._use_deprecated_format) - - if self._registry is not None: - writer_schema_raw = self._registry.get_schema(schema_id, fmt='serialized') - fd_proto, pool = self._get_parsed_schema(writer_schema_raw) - writer_schema = pool.FindFileByName(fd_proto.name) - writer_desc = self._get_message_desc(pool, writer_schema, msg_index) - if subject is None: - subject = self._subject_name_func(ctx, writer_desc.full_name) - if subject is not None: - latest_schema = self._get_reader_schema(subject, fmt='serialized') - else: - writer_schema_raw = None - writer_schema = None - - if latest_schema is not None: - migrations = self._get_migrations(subject, writer_schema_raw, latest_schema, None) - reader_schema_raw = latest_schema.schema - fd_proto, pool = self._get_parsed_schema(latest_schema.schema) - reader_schema = pool.FindFileByName(fd_proto.name) - else: - migrations = None - reader_schema_raw = writer_schema_raw - reader_schema = writer_schema - - if reader_schema is not None: - # Initialize reader desc to first message in file - reader_desc = self._get_message_desc(pool, reader_schema, [0]) - # Attempt to find a reader desc with the same name as the writer - reader_desc = reader_schema.message_types_by_name.get(writer_desc.name, reader_desc) - - if migrations: - msg = GetMessageClass(writer_desc)() - try: - msg.ParseFromString(payload.read()) - except DecodeError as e: - raise SerializationError(str(e)) - - obj_dict = json_format.MessageToDict(msg, True) - obj_dict = self._execute_migrations(ctx, subject, migrations, obj_dict) - msg = GetMessageClass(reader_desc)() - msg = json_format.ParseDict(obj_dict, msg) - else: - # Protobuf Messages are self-describing; no need to query schema - msg = self._msg_class() - try: - msg.ParseFromString(payload.read()) - except DecodeError as e: - raise SerializationError(str(e)) - - 
field_transformer = lambda rule_ctx, field_transform, message: ( # noqa: E731 - transform(rule_ctx, reader_desc, message, field_transform)) - msg = self._execute_rules(ctx, subject, RuleMode.READ, None, - reader_schema_raw, msg, None, - field_transformer) - - return msg - - def _get_parsed_schema(self, schema: Schema) -> Tuple[descriptor_pb2.FileDescriptorProto, DescriptorPool]: - result = self._parsed_schemas.get_parsed_schema(schema) - if result is not None: - return result - - pool = DescriptorPool() - _init_pool(pool) - _resolve_named_schema(schema, self._registry, pool) - fd_proto = _str_to_proto("default", schema.schema_str) - pool.Add(fd_proto) - self._parsed_schemas.set(schema, (fd_proto, pool)) - return fd_proto, pool - - def _get_message_desc( - self, pool: DescriptorPool, fd: FileDescriptor, - msg_index: List[int] - ) -> Descriptor: - file_desc_proto = descriptor_pb2.FileDescriptorProto() - fd.CopyToProto(file_desc_proto) - (full_name, desc_proto) = self._get_message_desc_proto("", file_desc_proto, msg_index) - package = file_desc_proto.package - qualified_name = package + "." + full_name if package else full_name - return pool.FindMessageTypeByName(qualified_name) - - def _get_message_desc_proto( - self, - path: str, - desc: Union[descriptor_pb2.FileDescriptorProto, descriptor_pb2.DescriptorProto], - msg_index: List[int] - ) -> Tuple[str, descriptor_pb2.DescriptorProto]: - index = msg_index[0] - if isinstance(desc, descriptor_pb2.FileDescriptorProto): - msg = desc.message_type[index] - path = path + "." + msg.name if path else msg.name - if len(msg_index) == 1: - return path, msg - return self._get_message_desc_proto(path, msg, msg_index[1:]) - else: - msg = desc.nested_type[index] - path = path + "." 
+ msg.name if path else msg.name - if len(msg_index) == 1: - return path, msg - return self._get_message_desc_proto(path, msg, msg_index[1:]) - - -def transform( - ctx: RuleContext, descriptor: Descriptor, message: Any, - field_transform: FieldTransform -) -> Any: - if message is None or descriptor is None: - return message - if isinstance(message, list): - return [transform(ctx, descriptor, item, field_transform) - for item in message] - if isinstance(message, dict): - return {key: transform(ctx, descriptor, value, field_transform) - for key, value in message.items()} - if isinstance(message, Message): - for fd in descriptor.fields: - _transform_field(ctx, fd, descriptor, message, field_transform) - return message - field_ctx = ctx.current_field() - if field_ctx is not None: - rule_tags = ctx.rule.tags - if not rule_tags or not _disjoint(set(rule_tags), field_ctx.tags): - return field_transform(ctx, field_ctx, message) - return message - - -def _transform_field( - ctx: RuleContext, fd: FieldDescriptor, desc: Descriptor, - message: Message, field_transform: FieldTransform -): - try: - ctx.enter_field( - message, - fd.full_name, - fd.name, - get_type(fd), - get_inline_tags(fd) - ) - if fd.containing_oneof is not None and not message.HasField(fd.name): - return - value = getattr(message, fd.name) - if is_map_field(fd): - value = {key: value[key] for key in value} - elif fd.label == FieldDescriptor.LABEL_REPEATED: - value = [item for item in value] - new_value = transform(ctx, desc, value, field_transform) - if ctx.rule.kind == RuleKind.CONDITION: - if new_value is False: - raise RuleConditionError(ctx.rule) - else: - _set_field(fd, message, new_value) - finally: - ctx.exit_field() - - -def _set_field(fd: FieldDescriptor, message: Message, value: Any): - if isinstance(value, list): - message.ClearField(fd.name) - old_value = getattr(message, fd.name) - old_value.extend(value) - elif isinstance(value, dict): - message.ClearField(fd.name) - old_value = getattr(message, 
fd.name) - old_value.update(value) - else: - setattr(message, fd.name, value) - - -def get_type(fd: FieldDescriptor) -> FieldType: - if is_map_field(fd): - return FieldType.MAP - if fd.type == FieldDescriptor.TYPE_MESSAGE: - return FieldType.RECORD - if fd.type == FieldDescriptor.TYPE_ENUM: - return FieldType.ENUM - if fd.type == FieldDescriptor.TYPE_STRING: - return FieldType.STRING - if fd.type == FieldDescriptor.TYPE_BYTES: - return FieldType.BYTES - if fd.type in (FieldDescriptor.TYPE_INT32, FieldDescriptor.TYPE_SINT32, - FieldDescriptor.TYPE_UINT32, FieldDescriptor.TYPE_FIXED32, - FieldDescriptor.TYPE_SFIXED32): - return FieldType.INT - if fd.type in (FieldDescriptor.TYPE_INT64, FieldDescriptor.TYPE_SINT64, - FieldDescriptor.TYPE_UINT64, FieldDescriptor.TYPE_FIXED64, - FieldDescriptor.TYPE_SFIXED64): - return FieldType.LONG - if fd.type == FieldDescriptor.TYPE_FLOAT: - return FieldType.FLOAT - if fd.type == FieldDescriptor.TYPE_DOUBLE: - return FieldType.DOUBLE - if fd.type == FieldDescriptor.TYPE_BOOL: - return FieldType.BOOLEAN - return FieldType.NULL - - -def is_map_field(fd: FieldDescriptor): - return (fd.type == FieldDescriptor.TYPE_MESSAGE - and hasattr(fd.message_type, 'options') - and fd.message_type.options.map_entry) - - -def get_inline_tags(fd: FieldDescriptor) -> Set[str]: - meta = fd.GetOptions().Extensions[meta_pb2.field_meta] - if meta is None: - return set() - else: - return set(meta.tags) - - -def _disjoint(tags1: Set[str], tags2: Set[str]) -> bool: - for tag in tags1: - if tag in tags2: - return False - return True - - -def _is_builtin(name: str) -> bool: - return name.startswith('confluent/') or \ - name.startswith('google/protobuf/') or \ - name.startswith('google/type/') - - -def decimalToProtobuf(value: Decimal, scale: int) -> decimal_pb2.Decimal: - """ - Converts a Decimal to a Protobuf value. - - Args: - value (Decimal): The Decimal value to convert. - - Returns: - The Protobuf value. 
- """ - sign, digits, exp = value.as_tuple() - - delta = exp + scale - - if delta < 0: - raise ValueError( - "Scale provided does not match the decimal") - - unscaled_datum = 0 - for digit in digits: - unscaled_datum = (unscaled_datum * 10) + digit - - unscaled_datum = 10**delta * unscaled_datum - - bytes_req = (unscaled_datum.bit_length() + 8) // 8 - - if sign: - unscaled_datum = -unscaled_datum - - bytes = unscaled_datum.to_bytes(bytes_req, byteorder="big", signed=True) - - result = decimal_pb2.Decimal() - result.value = bytes - result.precision = 0 - result.scale = scale - return result - - -decimal_context = Context() - - -def protobufToDecimal(value: decimal_pb2.Decimal) -> Decimal: - """ - Converts a Protobuf value to Decimal. - - Args: - value (decimal_pb2.Decimal): The Protobuf value to convert. - - Returns: - The Decimal value. - """ - unscaled_datum = int.from_bytes(value.value, byteorder="big", signed=True) - - if value.precision > 0: - decimal_context.prec = value.precision - else: - decimal_context.prec = MAX_PREC - return decimal_context.create_decimal(unscaled_datum).scaleb( - -value.scale, decimal_context - ) +from .common.protobuf import * # noqa +from ._sync.protobuf import * # noqa diff --git a/src/confluent_kafka/schema_registry/schema_registry_client.py b/src/confluent_kafka/schema_registry/schema_registry_client.py index 4cadf8bfd..e9a0eb3e2 100644 --- a/src/confluent_kafka/schema_registry/schema_registry_client.py +++ b/src/confluent_kafka/schema_registry/schema_registry_client.py @@ -14,1968 +14,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# -import abc -import json -import logging -import random -import time -import urllib -from urllib.parse import unquote, urlparse - -import httpx -from attrs import define as _attrs_define -from attrs import field as _attrs_field -from collections import defaultdict -from enum import Enum -from threading import Lock -from typing import List, Dict, Type, TypeVar, \ - cast, Optional, Union, Any, Tuple, Callable - -from cachetools import TTLCache, LRUCache -from httpx import Response - -from authlib.integrations.httpx_client import OAuth2Client - -from .error import SchemaRegistryError, OAuthTokenError - -# TODO: consider adding `six` dependency or employing a compat file -# Python 2.7 is officially EOL so compatibility issue will be come more the norm. -# We need a better way to handle these issues. -# Six is one possibility but the compat file pattern used by requests -# is also quite nice. -# -# six: https://pypi.org/project/six/ -# compat file : https://github.com/psf/requests/blob/master/requests/compat.py -try: - string_type = basestring # noqa - - def _urlencode(value: str) -> str: - return urllib.quote(value, safe='') -except NameError: - string_type = str - - def _urlencode(value: str) -> str: - return urllib.parse.quote(value, safe='') - -log = logging.getLogger(__name__) -VALID_AUTH_PROVIDERS = ['URL', 'USER_INFO'] - - -class _BearerFieldProvider(metaclass=abc.ABCMeta): - @abc.abstractmethod - def get_bearer_fields(self) -> dict: - raise NotImplementedError - - -class _StaticFieldProvider(_BearerFieldProvider): - def __init__(self, token: str, logical_cluster: str, identity_pool: str): - self.token = token - self.logical_cluster = logical_cluster - self.identity_pool = identity_pool - - def get_bearer_fields(self) -> dict: - return {'bearer.auth.token': self.token, 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool} - - -class _CustomOAuthClient(_BearerFieldProvider): - def __init__(self, 
custom_function: Callable[[Dict], Dict], custom_config: dict): - self.custom_function = custom_function - self.custom_config = custom_config - - def get_bearer_fields(self) -> dict: - return self.custom_function(self.custom_config) - - -class _OAuthClient(_BearerFieldProvider): - def __init__(self, client_id: str, client_secret: str, scope: str, token_endpoint: str, logical_cluster: str, - identity_pool: str, max_retries: int, retries_wait_ms: int, retries_max_wait_ms: int): - self.token = None - self.logical_cluster = logical_cluster - self.identity_pool = identity_pool - self.client = OAuth2Client(client_id=client_id, client_secret=client_secret, scope=scope) - self.token_endpoint = token_endpoint - self.max_retries = max_retries - self.retries_wait_ms = retries_wait_ms - self.retries_max_wait_ms = retries_max_wait_ms - self.token_expiry_threshold = 0.8 - - def get_bearer_fields(self) -> dict: - return {'bearer.auth.token': self.get_access_token(), 'bearer.auth.logical.cluster': self.logical_cluster, - 'bearer.auth.identity.pool.id': self.identity_pool} - - def token_expired(self) -> bool: - expiry_window = self.token['expires_in'] * self.token_expiry_threshold - - return self.token['expires_at'] < time.time() + expiry_window - - def get_access_token(self) -> str: - if not self.token or self.token_expired(): - self.generate_access_token() - - return self.token['access_token'] - - def generate_access_token(self) -> None: - for i in range(self.max_retries + 1): - try: - self.token = self.client.fetch_token(url=self.token_endpoint, grant_type='client_credentials') - return - except Exception as e: - if i >= self.max_retries: - raise OAuthTokenError(f"Failed to retrieve token after {self.max_retries} " - f"attempts due to error: {str(e)}") - time.sleep(full_jitter(self.retries_wait_ms, self.retries_max_wait_ms, i) / 1000) - - -class _BaseRestClient(object): - - def __init__(self, conf: dict): - # copy dict to avoid mutating the original - conf_copy = conf.copy() - - 
base_url = conf_copy.pop('url', None) - if base_url is None: - raise ValueError("Missing required configuration property url") - if not isinstance(base_url, string_type): - raise TypeError("url must be a str, not " + str(type(base_url))) - base_urls = [] - for url in base_url.split(','): - url = url.strip().rstrip('/') - if not url.startswith('http') and not url.startswith('mock'): - raise ValueError("Invalid url {}".format(url)) - base_urls.append(url) - if not base_urls: - raise ValueError("Missing required configuration property url") - self.base_urls = base_urls - - self.verify = True - ca = conf_copy.pop('ssl.ca.location', None) - if ca is not None: - self.verify = ca - - key: Optional[str] = conf_copy.pop('ssl.key.location', None) - client_cert: Optional[str] = conf_copy.pop('ssl.certificate.location', None) - self.cert: Union[str, Tuple[str, str], None] = None - - if client_cert is not None and key is not None: - self.cert = (client_cert, key) - - if client_cert is not None and key is None: - self.cert = client_cert - - if key is not None and client_cert is None: - raise ValueError("ssl.certificate.location required when" - " configuring ssl.key.location") - - parsed = urlparse(self.base_urls[0]) - try: - userinfo = (unquote(parsed.username), unquote(parsed.password)) - except (AttributeError, TypeError): - userinfo = ("", "") - if 'basic.auth.user.info' in conf_copy: - if userinfo != ('', ''): - raise ValueError("basic.auth.user.info configured with" - " userinfo credentials in the URL." 
- " Remove userinfo credentials from the url or" - " remove basic.auth.user.info from the" - " configuration") - - userinfo = tuple(conf_copy.pop('basic.auth.user.info', '').split(':', 1)) - - if len(userinfo) != 2: - raise ValueError("basic.auth.user.info must be in the form" - " of {username}:{password}") - - self.auth = userinfo if userinfo != ('', '') else None - - # The following adds support for proxy config - # If specified: it uses the specified proxy details when making requests - self.proxy = None - proxy = conf_copy.pop('proxy', None) - if proxy is not None: - self.proxy = proxy - - self.timeout = None - timeout = conf_copy.pop('timeout', None) - if timeout is not None: - self.timeout = timeout - - self.cache_capacity = 1000 - cache_capacity = conf_copy.pop('cache.capacity', None) - if cache_capacity is not None: - if not isinstance(cache_capacity, (int, float)): - raise TypeError("cache.capacity must be a number, not " + str(type(cache_capacity))) - self.cache_capacity = cache_capacity - - self.cache_latest_ttl_sec = None - cache_latest_ttl_sec = conf_copy.pop('cache.latest.ttl.sec', None) - if cache_latest_ttl_sec is not None: - if not isinstance(cache_latest_ttl_sec, (int, float)): - raise TypeError("cache.latest.ttl.sec must be a number, not " + str(type(cache_latest_ttl_sec))) - self.cache_latest_ttl_sec = cache_latest_ttl_sec - - self.max_retries = 3 - max_retries = conf_copy.pop('max.retries', None) - if max_retries is not None: - if not isinstance(max_retries, (int, float)): - raise TypeError("max.retries must be a number, not " + str(type(max_retries))) - self.max_retries = max_retries - - self.retries_wait_ms = 1000 - retries_wait_ms = conf_copy.pop('retries.wait.ms', None) - if retries_wait_ms is not None: - if not isinstance(retries_wait_ms, (int, float)): - raise TypeError("retries.wait.ms must be a number, not " - + str(type(retries_wait_ms))) - self.retries_wait_ms = retries_wait_ms - - self.retries_max_wait_ms = 20000 - 
retries_max_wait_ms = conf_copy.pop('retries.max.wait.ms', None) - if retries_max_wait_ms is not None: - if not isinstance(retries_max_wait_ms, (int, float)): - raise TypeError("retries.max.wait.ms must be a number, not " - + str(type(retries_max_wait_ms))) - self.retries_max_wait_ms = retries_max_wait_ms - - self.bearer_field_provider = None - logical_cluster = None - identity_pool = None - self.bearer_auth_credentials_source = conf_copy.pop('bearer.auth.credentials.source', None) - if self.bearer_auth_credentials_source is not None: - self.auth = None - - if self.bearer_auth_credentials_source in {'OAUTHBEARER', 'STATIC_TOKEN'}: - headers = ['bearer.auth.logical.cluster', 'bearer.auth.identity.pool.id'] - missing_headers = [header for header in headers if header not in conf_copy] - if missing_headers: - raise ValueError("Missing required bearer configuration properties: {}" - .format(", ".join(missing_headers))) - - logical_cluster = conf_copy.pop('bearer.auth.logical.cluster') - if not isinstance(logical_cluster, str): - raise TypeError("logical cluster must be a str, not " + str(type(logical_cluster))) - - identity_pool = conf_copy.pop('bearer.auth.identity.pool.id') - if not isinstance(identity_pool, str): - raise TypeError("identity pool id must be a str, not " + str(type(identity_pool))) - - if self.bearer_auth_credentials_source == 'OAUTHBEARER': - properties_list = ['bearer.auth.client.id', 'bearer.auth.client.secret', 'bearer.auth.scope', - 'bearer.auth.issuer.endpoint.url'] - missing_properties = [prop for prop in properties_list if prop not in conf_copy] - if missing_properties: - raise ValueError("Missing required OAuth configuration properties: {}". 
- format(", ".join(missing_properties))) - - self.client_id = conf_copy.pop('bearer.auth.client.id') - if not isinstance(self.client_id, string_type): - raise TypeError("bearer.auth.client.id must be a str, not " + str(type(self.client_id))) - - self.client_secret = conf_copy.pop('bearer.auth.client.secret') - if not isinstance(self.client_secret, string_type): - raise TypeError("bearer.auth.client.secret must be a str, not " + str(type(self.client_secret))) - - self.scope = conf_copy.pop('bearer.auth.scope') - if not isinstance(self.scope, string_type): - raise TypeError("bearer.auth.scope must be a str, not " + str(type(self.scope))) - - self.token_endpoint = conf_copy.pop('bearer.auth.issuer.endpoint.url') - if not isinstance(self.token_endpoint, string_type): - raise TypeError("bearer.auth.issuer.endpoint.url must be a str, not " - + str(type(self.token_endpoint))) - - self.bearer_field_provider = _OAuthClient(self.client_id, self.client_secret, self.scope, - self.token_endpoint, logical_cluster, identity_pool, - self.max_retries, self.retries_wait_ms, - self.retries_max_wait_ms) - elif self.bearer_auth_credentials_source == 'STATIC_TOKEN': - if 'bearer.auth.token' not in conf_copy: - raise ValueError("Missing bearer.auth.token") - static_token = conf_copy.pop('bearer.auth.token') - self.bearer_field_provider = _StaticFieldProvider(static_token, logical_cluster, identity_pool) - if not isinstance(static_token, string_type): - raise TypeError("bearer.auth.token must be a str, not " + str(type(static_token))) - elif self.bearer_auth_credentials_source == 'CUSTOM': - custom_bearer_properties = ['bearer.auth.custom.provider.function', - 'bearer.auth.custom.provider.config'] - missing_custom_properties = [prop for prop in custom_bearer_properties if prop not in conf_copy] - if missing_custom_properties: - raise ValueError("Missing required custom OAuth configuration properties: {}". 
- format(", ".join(missing_custom_properties))) - - custom_function = conf_copy.pop('bearer.auth.custom.provider.function') - if not callable(custom_function): - raise TypeError("bearer.auth.custom.provider.function must be a callable, not " - + str(type(custom_function))) - - custom_config = conf_copy.pop('bearer.auth.custom.provider.config') - if not isinstance(custom_config, dict): - raise TypeError("bearer.auth.custom.provider.config must be a dict, not " - + str(type(custom_config))) - - self.bearer_field_provider = _CustomOAuthClient(custom_function, custom_config) - else: - raise ValueError('Unrecognized bearer.auth.credentials.source') - - # Any leftover keys are unknown to _RestClient - if len(conf_copy) > 0: - raise ValueError("Unrecognized properties: {}" - .format(", ".join(conf_copy.keys()))) - - def get(self, url: str, query: Optional[dict] = None) -> Any: - raise NotImplementedError() - - def post(self, url: str, body: Optional[dict], **kwargs) -> Any: - raise NotImplementedError() - - def delete(self, url: str) -> Any: - raise NotImplementedError() - - def put(self, url: str, body: Optional[dict] = None) -> Any: - raise NotImplementedError() - - -class _RestClient(_BaseRestClient): - """ - HTTP client for Confluent Schema Registry. - - See SchemaRegistryClient for configuration details. 
def send_request(
    self, url: str, method: str, body: Optional[dict] = None,
    query: Optional[dict] = None
) -> Any:
    """
    Sends HTTP request to the SchemaRegistry, trying each base URL in turn.

    All unsuccessful attempts will raise a SchemaRegistryError with the
    response contents. In most cases this will be accompanied by a
    Schema Registry supplied error code.

    In the event the response is malformed an error_code of -1 will be used.

    Args:
        url (str): Request path

        method (str): HTTP method

        body (dict): Request content; serialized to JSON before sending

        query (dict): Query params to attach to the URL

    Returns:
        dict: Schema Registry response content.

    Raises:
        SchemaRegistryError: if the last attempted base URL answered with
            a non-success status.
        Exception: whatever the request to the last base URL raised.
    """

    headers = {'Accept': "application/vnd.schemaregistry.v1+json,"
                         " application/vnd.schemaregistry+json,"
                         " application/json"}

    if body is not None:
        body = json.dumps(body)
        # Add the content headers to the existing dict; rebinding `headers`
        # here would silently drop the Accept header built above.
        headers['Content-Length'] = str(len(body))
        headers['Content-Type'] = "application/vnd.schemaregistry.v1+json"

    if self.bearer_auth_credentials_source:
        self.handle_bearer_auth(headers)

    response = None
    last_index = len(self.base_urls) - 1
    for i, base_url in enumerate(self.base_urls):
        try:
            response = self.send_http_request(
                base_url, url, method, headers, body, query)

            if is_success(response.status_code):
                return response.json()

            # Only move on to the next base URL for retriable statuses
            if not is_retriable(response.status_code) or i == last_index:
                break
        except Exception:
            if i == last_index:
                # Re-raise since we have no more urls to try
                raise

    try:
        # Parse the error payload once and reuse it for both fields
        error = response.json()
        error_code = error.get('error_code')
        message = error.get('message')
    # Schema Registry may return malformed output when it hits unexpected errors
    except (ValueError, KeyError, AttributeError):
        raise SchemaRegistryError(response.status_code,
                                  -1,
                                  "Unknown Schema Registry Error: "
                                  + str(response.content))

    raise SchemaRegistryError(response.status_code, error_code, message)
def is_success(status_code: int) -> bool:
    """Return True when ``status_code`` is in the 200-299 range."""
    return 200 <= status_code <= 299


def is_retriable(status_code: int) -> bool:
    """Return True when ``status_code`` is one of the statuses worth retrying."""
    return status_code in (408, 429, 500, 502, 503, 504)


def full_jitter(base_delay_ms: int, max_delay_ms: int, retries_attempted: int) -> float:
    """
    Compute a randomized exponential backoff delay.

    The un-jittered delay is ``base_delay_ms * 2 ** retries_attempted``,
    capped at ``max_delay_ms``; the result is a uniformly random fraction
    of that cap.

    Args:
        base_delay_ms (int): Delay before the first retry, in milliseconds

        max_delay_ms (int): Upper bound for the delay, in milliseconds

        retries_attempted (int): Number of retries already attempted

    Returns:
        float: Delay in milliseconds, between 0 and the capped delay.
    """
    # 2.0 ** 1024 raises OverflowError; 1023 is the largest exponent a float
    # can hold. Clamping keeps very large retry counts saturating at
    # max_delay_ms (the multiplication overflows to inf, then min() caps it).
    exponent = min(retries_attempted, 1023)
    no_jitter_delay = base_delay_ms * (2.0 ** exponent)
    return random.random() * min(no_jitter_delay, max_delay_ms)
- - Args: - subject (str): The subject this schema is associated with - - schema_id (int): Schema's registration id - - schema (Schema): Schema instance - """ - - with self.lock: - self.schema_id_index[subject][schema_id] = schema - self.schema_index[subject][schema] = schema_id - - def set_registered_schema(self, schema: 'Schema', registered_schema: 'RegisteredSchema'): - """ - Add a RegisteredSchema to the cache. - - Args: - registered_schema (RegisteredSchema): RegisteredSchema instance - """ - - subject = registered_schema.subject - schema_id = registered_schema.schema_id - version = registered_schema.version - with self.lock: - self.schema_id_index[subject][schema_id] = schema - self.schema_index[subject][schema] = schema_id - self.rs_id_index[subject][schema_id] = registered_schema - self.rs_version_index[subject][version] = registered_schema - self.rs_schema_index[subject][schema] = registered_schema - - def get_schema_by_id(self, subject: str, schema_id: int) -> Optional['Schema']: - """ - Get the schema instance associated with schema id from the cache. - - Args: - subject (str): The subject this schema is associated with - - schema_id (int): Id used to identify a schema - - Returns: - Schema: The schema if known; else None - """ - - with self.lock: - return self.schema_id_index.get(subject, {}).get(schema_id, None) - - def get_id_by_schema(self, subject: str, schema: 'Schema') -> Optional[int]: - """ - Get the schema id associated with schema instance from the cache. - - Args: - subject (str): The subject this schema is associated with - - schema (Schema): The schema - - Returns: - int: The schema id if known; else None - """ - - with self.lock: - return self.schema_index.get(subject, {}).get(schema, None) - - def get_registered_by_subject_schema(self, subject: str, schema: 'Schema') -> Optional['RegisteredSchema']: - """ - Get the schema associated with this schema registered under subject. 
def remove_by_subject_version(self, subject: str, version: int):
    """
    Remove the cache entries for one version of the given subject.

    Entries are dropped from every index: the registered-schema indexes
    keyed by id, by schema and by version, plus the plain id/schema
    indexes for the schema that was cached under that version.

    Args:
        subject (str): The subject

        version (int): The version
    """

    with self.lock:
        # Collect the matching keys first: deleting from a dict while
        # iterating over it raises RuntimeError.
        id_index = self.rs_id_index.get(subject)
        if id_index is not None:
            stale_ids = [schema_id for schema_id, registered_schema in id_index.items()
                         if registered_schema.version == version]
            for schema_id in stale_ids:
                # Keys here are schema ids, so this is the index to clean
                del id_index[schema_id]

        schema_index = self.rs_schema_index.get(subject)
        if schema_index is not None:
            stale_schemas = [schema for schema, registered_schema in schema_index.items()
                             if registered_schema.version == version]
            for schema in stale_schemas:
                del schema_index[schema]

        # .get() is used throughout so the defaultdicts do not grow empty
        # entries for unknown subjects.
        rs = self.rs_version_index.get(subject, {}).pop(version, None)
        if rs is not None:
            self.schema_id_index.get(subject, {}).pop(rs.schema_id, None)
            self.schema_index.get(subject, {}).pop(rs.schema, None)
| - | ``ssl.key.location`` | str | | - | | | ``ssl.certificate.location`` must also be set. | - +------------------------------+------+-------------------------------------------------+ - | | | Path to client's public key (PEM) used for | - | | | authentication. | - | ``ssl.certificate.location`` | str | | - | | | May be set without ssl.key.location if the | - | | | private key is stored within the PEM as well. | - +------------------------------+------+-------------------------------------------------+ - | | | Client HTTP credentials in the form of | - | | | ``username:password``. | - | ``basic.auth.user.info`` | str | | - | | | By default userinfo is extracted from | - | | | the URL if present. | - +------------------------------+------+-------------------------------------------------+ - | | | | - | ``proxy`` | str | Proxy such as http://localhost:8030. | - | | | | - +------------------------------+------+-------------------------------------------------+ - | | | | - | ``timeout`` | int | Request timeout. | - | | | | - +------------------------------+------+-------------------------------------------------+ - | | | | - | ``cache.capacity`` | int | Cache capacity. Defaults to 1000. | - | | | | - +------------------------------+------+-------------------------------------------------+ - | | | | - | ``cache.latest.ttl.sec`` | int | TTL in seconds for caching the latest schema. | - | | | | - +------------------------------+------+-------------------------------------------------+ - | | | | - | ``max.retries`` | int | Maximum retries for a request. Defaults to 2. | - | | | | - +------------------------------+------+-------------------------------------------------+ - | | | Maximum time to wait for the first retry. | - | | | When jitter is applied, the actual wait may | - | ``retries.wait.ms`` | int | be less. | - | | | | - | | | Defaults to 1000. 
| - +------------------------------+------+-------------------------------------------------+ - - Args: - conf (dict): Schema Registry client configuration. - - See Also: - `Confluent Schema Registry documentation `_ - """ # noqa: E501 - - def __init__(self, conf: dict): - self._conf = conf - self._rest_client = _RestClient(conf) - self._cache = _SchemaCache() - cache_capacity = self._rest_client.cache_capacity - cache_ttl = self._rest_client.cache_latest_ttl_sec - if cache_ttl is not None: - self._latest_version_cache = TTLCache(cache_capacity, cache_ttl) - self._latest_with_metadata_cache = TTLCache(cache_capacity, cache_ttl) - else: - self._latest_version_cache = LRUCache(cache_capacity) - self._latest_with_metadata_cache = LRUCache(cache_capacity) - - def __enter__(self): - return self - - def __exit__(self, *args): - if self._rest_client is not None: - self._rest_client.session.close() - - def config(self): - return self._conf - - def register_schema( - self, subject_name: str, schema: 'Schema', - normalize_schemas: bool = False - ) -> int: - """ - Registers a schema under ``subject_name``. - - Args: - subject_name (str): subject to register a schema under - schema (Schema): Schema instance to register - normalize_schemas (bool): Normalize schema before registering - - Returns: - int: Schema id - - Raises: - SchemaRegistryError: if Schema violates this subject's - Compatibility policy or is otherwise invalid. - - See Also: - `POST Subject API Reference `_ - """ # noqa: E501 - - registered_schema = self.register_schema_full_response(subject_name, schema, normalize_schemas) - return registered_schema.schema_id - - def register_schema_full_response( - self, subject_name: str, schema: 'Schema', - normalize_schemas: bool = False - ) -> 'RegisteredSchema': - """ - Registers a schema under ``subject_name``. 
- - Args: - subject_name (str): subject to register a schema under - schema (Schema): Schema instance to register - normalize_schemas (bool): Normalize schema before registering - - Returns: - int: Schema id - - Raises: - SchemaRegistryError: if Schema violates this subject's - Compatibility policy or is otherwise invalid. - - See Also: - `POST Subject API Reference `_ - """ # noqa: E501 - - schema_id = self._cache.get_id_by_schema(subject_name, schema) - if schema_id is not None: - return RegisteredSchema(schema_id, schema, subject_name, None) - - request = schema.to_dict() - - response = self._rest_client.post( - 'subjects/{}/versions?normalize={}'.format(_urlencode(subject_name), normalize_schemas), - body=request) - - registered_schema = RegisteredSchema.from_dict(response) - - # The registered schema may not be fully populated - self._cache.set_schema(subject_name, registered_schema.schema_id, schema) - - return registered_schema - - def get_schema( - self, schema_id: int, subject_name: Optional[str] = None, fmt: Optional[str] = None - ) -> 'Schema': - """ - Fetches the schema associated with ``schema_id`` from the - Schema Registry. The result is cached so subsequent attempts will not - require an additional round-trip to the Schema Registry. - - Args: - schema_id (int): Schema id - subject_name (str): Subject name the schema is registered under - fmt (str): Format of the schema - - Returns: - Schema: Schema instance identified by the ``schema_id`` - - Raises: - SchemaRegistryError: If schema can't be found. 
- - See Also: - `GET Schema API Reference `_ - """ # noqa: E501 - - schema = self._cache.get_schema_by_id(subject_name, schema_id) - if schema is not None: - return schema - - query = {'subject': subject_name} if subject_name is not None else None - if fmt is not None: - if query is not None: - query['format'] = fmt - else: - query = {'format': fmt} - response = self._rest_client.get('schemas/ids/{}'.format(schema_id), query) - - schema = Schema.from_dict(response) - - self._cache.set_schema(subject_name, schema_id, schema) - - return schema - - def lookup_schema( - self, subject_name: str, schema: 'Schema', - normalize_schemas: bool = False, deleted: bool = False - ) -> 'RegisteredSchema': - """ - Returns ``schema`` registration information for ``subject``. - - Args: - subject_name (str): Subject name the schema is registered under - schema (Schema): Schema instance. - normalize_schemas (bool): Normalize schema before registering - deleted (bool): Whether to include deleted schemas. - - Returns: - RegisteredSchema: Subject registration information for this schema. 
- - Raises: - SchemaRegistryError: If schema or subject can't be found - - See Also: - `POST Subject API Reference `_ - """ # noqa: E501 - - registered_schema = self._cache.get_registered_by_subject_schema(subject_name, schema) - if registered_schema is not None: - return registered_schema - - request = schema.to_dict() - - response = self._rest_client.post('subjects/{}?normalize={}&deleted={}' - .format(_urlencode(subject_name), normalize_schemas, deleted), - body=request) - - result = RegisteredSchema.from_dict(response) - - # Ensure the schema matches the input - registered_schema = RegisteredSchema( - schema_id=result.schema_id, - subject=result.subject, - version=result.version, - schema=schema, - ) - - self._cache.set_registered_schema(schema, registered_schema) - - return registered_schema - - def get_subjects(self) -> List[str]: - """ - List all subjects registered with the Schema Registry - - Returns: - list(str): Registered subject names - - Raises: - SchemaRegistryError: if subjects can't be found - - See Also: - `GET subjects API Reference `_ - """ # noqa: E501 - - return self._rest_client.get('subjects') - - def delete_subject(self, subject_name: str, permanent: bool = False) -> List[int]: - """ - Deletes the specified subject and its associated compatibility level if - registered. It is recommended to use this API only when a topic needs - to be recycled or in development environments. - - Args: - subject_name (str): subject name - permanent (bool): True for a hard delete, False (default) for a soft delete - - Returns: - list(int): Versions deleted under this subject - - Raises: - SchemaRegistryError: if the request was unsuccessful. 
- - See Also: - `DELETE Subject API Reference `_ - """ # noqa: E501 - - if permanent: - versions = self._rest_client.delete('subjects/{}?permanent=true' - .format(_urlencode(subject_name))) - self._cache.remove_by_subject(subject_name) - else: - versions = self._rest_client.delete('subjects/{}' - .format(_urlencode(subject_name))) - - return versions - - def get_latest_version( - self, subject_name: str, fmt: Optional[str] = None - ) -> 'RegisteredSchema': - """ - Retrieves latest registered version for subject - - Args: - subject_name (str): Subject name. - fmt (str): Format of the schema - - Returns: - RegisteredSchema: Registration information for this version. - - Raises: - SchemaRegistryError: if the version can't be found or is invalid. - - See Also: - `GET Subject Version API Reference `_ - """ # noqa: E501 - - registered_schema = self._latest_version_cache.get(subject_name, None) - if registered_schema is not None: - return registered_schema - - query = {'format': fmt} if fmt is not None else None - response = self._rest_client.get('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - 'latest'), query) - - registered_schema = RegisteredSchema.from_dict(response) - - self._latest_version_cache[subject_name] = registered_schema - - return registered_schema - - def get_latest_with_metadata( - self, subject_name: str, metadata: Dict[str, str], - deleted: bool = False, fmt: Optional[str] = None - ) -> 'RegisteredSchema': - """ - Retrieves latest registered version for subject with the given metadata - - Args: - subject_name (str): Subject name. - metadata (dict): The key-value pairs for the metadata. - deleted (bool): Whether to include deleted schemas. - fmt (str): Format of the schema - - Returns: - RegisteredSchema: Registration information for this version. - - Raises: - SchemaRegistryError: if the version can't be found or is invalid. 
- """ # noqa: E501 - - cache_key = (subject_name, frozenset(metadata.items()), deleted) - registered_schema = self._latest_with_metadata_cache.get(cache_key, None) - if registered_schema is not None: - return registered_schema - - query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - keys = metadata.keys() - if keys: - query['key'] = [_urlencode(key) for key in keys] - query['value'] = [_urlencode(metadata[key]) for key in keys] - - response = self._rest_client.get('subjects/{}/metadata' - .format(_urlencode(subject_name)), query) - - registered_schema = RegisteredSchema.from_dict(response) - - self._latest_with_metadata_cache[cache_key] = registered_schema - - return registered_schema - - def get_version( - self, subject_name: str, version: int, - deleted: bool = False, fmt: Optional[str] = None - ) -> 'RegisteredSchema': - """ - Retrieves a specific schema registered under ``subject_name``. - - Args: - subject_name (str): Subject name. - version (int): version number. Defaults to latest version. - deleted (bool): Whether to include deleted schemas. - fmt (str): Format of the schema - - Returns: - RegisteredSchema: Registration information for this version. - - Raises: - SchemaRegistryError: if the version can't be found or is invalid. 
- - See Also: - `GET Subject Version API Reference `_ - """ # noqa: E501 - - registered_schema = self._cache.get_registered_by_subject_version(subject_name, version) - if registered_schema is not None: - return registered_schema - - query = {'deleted': deleted, 'format': fmt} if fmt is not None else {'deleted': deleted} - response = self._rest_client.get('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - version), query) - - registered_schema = RegisteredSchema.from_dict(response) - - self._cache.set_registered_schema(registered_schema.schema, registered_schema) - - return registered_schema - - def get_versions(self, subject_name: str) -> List[int]: - """ - Get a list of all versions registered with this subject. - - Args: - subject_name (str): Subject name. - - Returns: - list(int): Registered versions - - Raises: - SchemaRegistryError: If subject can't be found - - See Also: - `GET Subject Versions API Reference `_ - """ # noqa: E501 - - return self._rest_client.get('subjects/{}/versions'.format(_urlencode(subject_name))) - - def delete_version(self, subject_name: str, version: int, permanent: bool = False) -> int: - """ - Deletes a specific version registered to ``subject_name``. - - Args: - subject_name (str) Subject name - - version (int): Version number - - permanent (bool): True for a hard delete, False (default) for a soft delete - - Returns: - int: Version number which was deleted - - Raises: - SchemaRegistryError: if the subject or version cannot be found. 
- - See Also: - `Delete Subject Version API Reference `_ - """ # noqa: E501 - - if permanent: - response = self._rest_client.delete('subjects/{}/versions/{}?permanent=true' - .format(_urlencode(subject_name), - version)) - self._cache.remove_by_subject_version(subject_name, version) - else: - response = self._rest_client.delete('subjects/{}/versions/{}' - .format(_urlencode(subject_name), - version)) - - return response - - def set_compatibility(self, subject_name: Optional[str] = None, level: Optional[str] = None) -> str: - """ - Update global or subject level compatibility level. - - Args: - level (str): Compatibility level. See API reference for a list of - valid values. - - subject_name (str, optional): Subject to update. Sets global compatibility - level policy if not set. - - Returns: - str: The newly configured compatibility level. - - Raises: - SchemaRegistryError: If the compatibility level is invalid. - - See Also: - `PUT Subject Compatibility API Reference `_ - """ # noqa: E501 - - if level is None: - raise ValueError("level must be set") - - if subject_name is None: - return self._rest_client.put('config', - body={'compatibility': level.upper()}) - - return self._rest_client.put('config/{}' - .format(_urlencode(subject_name)), - body={'compatibility': level.upper()}) - - def get_compatibility(self, subject_name: Optional[str] = None) -> str: - """ - Get the current compatibility level. - - Args: - subject_name (str, optional): Subject name. Returns global policy - if left unset. - - Returns: - str: Compatibility level for the subject if set, otherwise the global compatibility level. - - Raises: - SchemaRegistryError: if the request was unsuccessful. 
- - See Also: - `GET Subject Compatibility API Reference `_ - """ # noqa: E501 - - if subject_name is not None: - url = 'config/{}'.format(_urlencode(subject_name)) - else: - url = 'config' - - result = self._rest_client.get(url) - return result['compatibilityLevel'] - - def test_compatibility( - self, subject_name: str, schema: 'Schema', - version: Union[int, str] = "latest" - ) -> bool: - """Test the compatibility of a candidate schema for a given subject and version - - Args: - subject_name (str): Subject name the schema is registered under - schema (Schema): Schema instance. - version (int or str, optional): Version number, or the string "latest". Defaults to "latest". - - Returns: - bool: True if the schema is compatible with the specified version - - Raises: - SchemaRegistryError: if the request was unsuccessful. - - See Also: - `POST Test Compatibility API Reference `_ - """ # noqa: E501 - - request = schema.to_dict() - - response = self._rest_client.post( - 'compatibility/subjects/{}/versions/{}'.format(_urlencode(subject_name), version), body=request - ) - - return response['is_compatible'] - - def set_config( - self, subject_name: Optional[str] = None, - config: Optional['ServerConfig'] = None - ) -> 'ServerConfig': - """ - Update global or subject config. - - Args: - config (ServerConfig): Config. See API reference for a list of - valid values. - - subject_name (str, optional): Subject to update. Sets global config - if not set. - - Returns: - str: The newly configured config. - - Raises: - SchemaRegistryError: If the config is invalid. 
@staticmethod
def new_client(conf: dict) -> 'SchemaRegistryClient':
    """
    Build a client for ``conf``, choosing the mock client for mock URLs.

    Args:
        conf (dict): Schema Registry client configuration.

    Returns:
        SchemaRegistryClient: a MockSchemaRegistryClient when ``url``
        starts with ``mock://``, otherwise a SchemaRegistryClient.
    """
    from .mock_schema_registry_client import MockSchemaRegistryClient
    url = conf.get("url")
    # A missing or non-string url must not fail on .startswith here; it is
    # handed to the regular constructor, which receives conf unchanged.
    if isinstance(url, str) and url.startswith("mock://"):
        return MockSchemaRegistryClient(conf)
    return SchemaRegistryClient(conf)
Any]: - field_dict: Dict[str, Any] = {} - field_dict.update(self.params) - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - - rule_params = cls(params=d) - - return rule_params - - def __hash__(self): - return hash(frozenset(self.params.items())) - - -@_attrs_define(frozen=True) -class Rule: - name: Optional[str] - doc: Optional[str] - kind: Optional[RuleKind] - mode: Optional[RuleMode] - type: Optional[str] - tags: Optional[List[str]] = _attrs_field(hash=False) - params: Optional[RuleParams] - expr: Optional[str] - on_success: Optional[str] - on_failure: Optional[str] - disabled: Optional[bool] - - def to_dict(self) -> Dict[str, Any]: - name = self.name - - doc = self.doc - - kind_str: Optional[str] = None - if self.kind is not None: - kind_str = self.kind.value - - mode_str: Optional[str] = None - if self.mode is not None: - mode_str = self.mode.value - - rule_type = self.type - - tags = self.tags - - _params: Optional[Dict[str, Any]] = None - if self.params is not None: - _params = self.params.to_dict() - - expr = self.expr - - on_success = self.on_success - - on_failure = self.on_failure - - disabled = self.disabled - - field_dict: Dict[str, Any] = {} - field_dict.update({}) - if name is not None: - field_dict["name"] = name - if doc is not None: - field_dict["doc"] = doc - if kind_str is not None: - field_dict["kind"] = kind_str - if mode_str is not None: - field_dict["mode"] = mode_str - if type is not None: - field_dict["type"] = rule_type - if tags is not None: - field_dict["tags"] = tags - if _params is not None: - field_dict["params"] = _params - if expr is not None: - field_dict["expr"] = expr - if on_success is not None: - field_dict["onSuccess"] = on_success - if on_failure is not None: - field_dict["onFailure"] = on_failure - if disabled is not None: - field_dict["disabled"] = disabled - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) 
-> T: - d = src_dict.copy() - name = d.pop("name", None) - - doc = d.pop("doc", None) - - _kind = d.pop("kind", None) - kind: Optional[RuleKind] = None - if _kind is not None: - kind = RuleKind(_kind) - - _mode = d.pop("mode", None) - mode: Optional[RuleMode] = None - if _mode is not None: - mode = RuleMode(_mode) - - rule_type = d.pop("type", None) - - tags = cast(List[str], d.pop("tags", None)) - - _params: Optional[Dict[str, Any]] = d.pop("params", None) - params: Optional[RuleParams] = None - if _params is not None: - params = RuleParams.from_dict(_params) - - expr = d.pop("expr", None) - - on_success = d.pop("onSuccess", None) - - on_failure = d.pop("onFailure", None) - - disabled = d.pop("disabled", None) - - rule = cls( - name=name, - doc=doc, - kind=kind, - mode=mode, - type=rule_type, - tags=tags, - params=params, - expr=expr, - on_success=on_success, - on_failure=on_failure, - disabled=disabled, - ) - - return rule - - -@_attrs_define -class RuleSet: - migration_rules: Optional[List["Rule"]] = _attrs_field(hash=False) - domain_rules: Optional[List["Rule"]] = _attrs_field(hash=False) - - def to_dict(self) -> Dict[str, Any]: - _migration_rules: Optional[List[Dict[str, Any]]] = None - if self.migration_rules is not None: - _migration_rules = [] - for migration_rules_item_data in self.migration_rules: - migration_rules_item = migration_rules_item_data.to_dict() - _migration_rules.append(migration_rules_item) - - _domain_rules: Optional[List[Dict[str, Any]]] = None - if self.domain_rules is not None: - _domain_rules = [] - for domain_rules_item_data in self.domain_rules: - domain_rules_item = domain_rules_item_data.to_dict() - _domain_rules.append(domain_rules_item) - - field_dict: Dict[str, Any] = {} - field_dict.update({}) - if _migration_rules is not None: - field_dict["migrationRules"] = _migration_rules - if _domain_rules is not None: - field_dict["domainRules"] = _domain_rules - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: 
Dict[str, Any]) -> T: - d = src_dict.copy() - migration_rules = [] - _migration_rules = d.pop("migrationRules", None) - for migration_rules_item_data in _migration_rules or []: - migration_rules_item = Rule.from_dict(migration_rules_item_data) - migration_rules.append(migration_rules_item) - - domain_rules = [] - _domain_rules = d.pop("domainRules", None) - for domain_rules_item_data in _domain_rules or []: - domain_rules_item = Rule.from_dict(domain_rules_item_data) - domain_rules.append(domain_rules_item) - - rule_set = cls( - migration_rules=migration_rules, - domain_rules=domain_rules, - ) - - return rule_set - - def __hash__(self): - return hash(frozenset((self.migration_rules or []) + (self.domain_rules or []))) - - -@_attrs_define -class MetadataTags: - tags: Dict[str, List[str]] = _attrs_field(factory=dict, hash=False) - - def to_dict(self) -> Dict[str, Any]: - field_dict: Dict[str, Any] = {} - for prop_name, prop in self.tags.items(): - field_dict[prop_name] = prop - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - - tags = {} - for prop_name, prop_dict in d.items(): - tag = cast(List[str], prop_dict) - - tags[prop_name] = tag - - metadata_tags = cls(tags=tags) - - return metadata_tags - - def __hash__(self): - return hash(frozenset(self.tags.items())) - - -@_attrs_define -class MetadataProperties: - properties: Dict[str, str] = _attrs_field(factory=dict, hash=False) - - def to_dict(self) -> Dict[str, Any]: - field_dict: Dict[str, Any] = {} - field_dict.update(self.properties) - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - - metadata_properties = cls(properties=d) - - return metadata_properties - - def __hash__(self): - return hash(frozenset(self.properties.items())) - - -@_attrs_define(frozen=True) -class Metadata: - tags: Optional[MetadataTags] - properties: Optional[MetadataProperties] - sensitive: 
Optional[List[str]] = _attrs_field(hash=False) - - def to_dict(self) -> Dict[str, Any]: - _tags: Optional[Dict[str, Any]] = None - if self.tags is not None: - _tags = self.tags.to_dict() - - _properties: Optional[Dict[str, Any]] = None - if self.properties is not None: - _properties = self.properties.to_dict() - - sensitive: Optional[List[str]] = None - if self.sensitive is not None: - sensitive = [] - for sensitive_item in self.sensitive: - sensitive.append(sensitive_item) - - field_dict: Dict[str, Any] = {} - if _tags is not None: - field_dict["tags"] = _tags - if _properties is not None: - field_dict["properties"] = _properties - if sensitive is not None: - field_dict["sensitive"] = sensitive - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - _tags: Optional[Dict[str, Any]] = d.pop("tags", None) - tags: Optional[MetadataTags] = None - if _tags is not None: - tags = MetadataTags.from_dict(_tags) - - _properties: Optional[Dict[str, Any]] = d.pop("properties", None) - properties: Optional[MetadataProperties] = None - if _properties is not None: - properties = MetadataProperties.from_dict(_properties) - - sensitive = [] - _sensitive = d.pop("sensitive", None) - for sensitive_item in _sensitive or []: - sensitive.append(sensitive_item) - - metadata = cls( - tags=tags, - properties=properties, - sensitive=sensitive, - ) - - return metadata - - -@_attrs_define(frozen=True) -class SchemaReference: - name: Optional[str] - subject: Optional[str] - version: Optional[int] - - def to_dict(self) -> Dict[str, Any]: - name = self.name - - subject = self.subject - - version = self.version - - field_dict: Dict[str, Any] = {} - if name is not None: - field_dict["name"] = name - if subject is not None: - field_dict["subject"] = subject - if version is not None: - field_dict["version"] = version - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = 
src_dict.copy() - name = d.pop("name", None) - - subject = d.pop("subject", None) - - version = d.pop("version", None) - - schema_reference = cls( - name=name, - subject=subject, - version=version, - ) - - return schema_reference - - -class ConfigCompatibilityLevel(str, Enum): - BACKWARD = "BACKWARD" - BACKWARD_TRANSITIVE = "BACKWARD_TRANSITIVE" - FORWARD = "FORWARD" - FORWARD_TRANSITIVE = "FORWARD_TRANSITIVE" - FULL = "FULL" - FULL_TRANSITIVE = "FULL_TRANSITIVE" - NONE = "NONE" - - def __str__(self) -> str: - return str(self.value) - - -@_attrs_define -class ServerConfig: - compatibility: Optional[ConfigCompatibilityLevel] = None - compatibility_level: Optional[ConfigCompatibilityLevel] = None - compatibility_group: Optional[str] = None - default_metadata: Optional[Metadata] = None - override_metadata: Optional[Metadata] = None - default_rule_set: Optional[RuleSet] = None - override_rule_set: Optional[RuleSet] = None - - def to_dict(self) -> Dict[str, Any]: - _compatibility: Optional[str] = None - if self.compatibility is not None: - _compatibility = self.compatibility.value - - _compatibility_level: Optional[str] = None - if self.compatibility_level is not None: - _compatibility_level = self.compatibility_level.value - - compatibility_group = self.compatibility_group - - _default_metadata: Optional[Dict[str, Any]] - if isinstance(self.default_metadata, Metadata): - _default_metadata = self.default_metadata.to_dict() - else: - _default_metadata = self.default_metadata - - _override_metadata: Optional[Dict[str, Any]] - if isinstance(self.override_metadata, Metadata): - _override_metadata = self.override_metadata.to_dict() - else: - _override_metadata = self.override_metadata - - _default_rule_set: Optional[Dict[str, Any]] - if isinstance(self.default_rule_set, RuleSet): - _default_rule_set = self.default_rule_set.to_dict() - else: - _default_rule_set = self.default_rule_set - - _override_rule_set: Optional[Dict[str, Any]] - if isinstance(self.override_rule_set, 
RuleSet): - _override_rule_set = self.override_rule_set.to_dict() - else: - _override_rule_set = self.override_rule_set - - field_dict: Dict[str, Any] = {} - if _compatibility is not None: - field_dict["compatibility"] = _compatibility - if _compatibility_level is not None: - field_dict["compatibilityLevel"] = _compatibility_level - if compatibility_group is not None: - field_dict["compatibilityGroup"] = compatibility_group - if _default_metadata is not None: - field_dict["defaultMetadata"] = _default_metadata - if _override_metadata is not None: - field_dict["overrideMetadata"] = _override_metadata - if _default_rule_set is not None: - field_dict["defaultRuleSet"] = _default_rule_set - if _override_rule_set is not None: - field_dict["overrideRuleSet"] = _override_rule_set - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - _compatibility = d.pop("compatibility", None) - compatibility: Optional[ConfigCompatibilityLevel] - if _compatibility is None: - compatibility = None - else: - compatibility = ConfigCompatibilityLevel(_compatibility) - - _compatibility_level = d.pop("compatibilityLevel", None) - compatibility_level: Optional[ConfigCompatibilityLevel] - if _compatibility_level is None: - compatibility_level = None - else: - compatibility_level = ConfigCompatibilityLevel(_compatibility_level) - - compatibility_group = d.pop("compatibilityGroup", None) - - def _parse_default_metadata(data: object) -> Optional[Metadata]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return Metadata.from_dict(data) - - default_metadata = _parse_default_metadata(d.pop("defaultMetadata", None)) - - def _parse_override_metadata(data: object) -> Optional[Metadata]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return Metadata.from_dict(data) - - override_metadata = _parse_override_metadata(d.pop("overrideMetadata", None)) - - def 
_parse_default_rule_set(data: object) -> Optional[RuleSet]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return RuleSet.from_dict(data) - - default_rule_set = _parse_default_rule_set(d.pop("defaultRuleSet", None)) - - def _parse_override_rule_set(data: object) -> Optional[RuleSet]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return RuleSet.from_dict(data) - - override_rule_set = _parse_override_rule_set(d.pop("overrideRuleSet", None)) - - config = cls( - compatibility=compatibility, - compatibility_level=compatibility_level, - compatibility_group=compatibility_group, - default_metadata=default_metadata, - override_metadata=override_metadata, - default_rule_set=default_rule_set, - override_rule_set=override_rule_set, - ) - - return config - - -@_attrs_define(frozen=True, cache_hash=True) -class Schema: - """ - An unregistered schema. - """ - - schema_str: Optional[str] - schema_type: Optional[str] = "AVRO" - references: Optional[List[SchemaReference]] = _attrs_field(factory=list, hash=False) - metadata: Optional[Metadata] = None - rule_set: Optional[RuleSet] = None - - def to_dict(self) -> Dict[str, Any]: - schema = self.schema_str - - schema_type = self.schema_type - - _references: Optional[List[Dict[str, Any]]] = [] - if self.references is not None: - for references_item_data in self.references: - references_item = references_item_data.to_dict() - _references.append(references_item) - - _metadata: Optional[Dict[str, Any]] = None - if isinstance(self.metadata, Metadata): - _metadata = self.metadata.to_dict() - - _rule_set: Optional[Dict[str, Any]] = None - if isinstance(self.rule_set, RuleSet): - _rule_set = self.rule_set.to_dict() - - field_dict: Dict[str, Any] = {} - if schema is not None: - field_dict["schema"] = schema - if schema_type is not None: - field_dict["schemaType"] = schema_type - if _references is not None: - field_dict["references"] = _references - if _metadata is 
not None: - field_dict["metadata"] = _metadata - if _rule_set is not None: - field_dict["ruleSet"] = _rule_set - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - - schema = d.pop("schema", None) - - schema_type = d.pop("schemaType", "AVRO") - - references = [] - _references = d.pop("references", None) - for references_item_data in _references or []: - references_item = SchemaReference.from_dict(references_item_data) - - references.append(references_item) - - def _parse_metadata(data: object) -> Optional[Metadata]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return Metadata.from_dict(data) - - metadata = _parse_metadata(d.pop("metadata", None)) - - def _parse_rule_set(data: object) -> Optional[RuleSet]: - if data is None: - return data - if not isinstance(data, dict): - raise TypeError() - return RuleSet.from_dict(data) - - rule_set = _parse_rule_set(d.pop("ruleSet", None)) - - schema = cls( - schema_str=schema, - schema_type=schema_type, - references=references, - metadata=metadata, - rule_set=rule_set, - ) - - return schema - - -@_attrs_define(frozen=True, cache_hash=True) -class RegisteredSchema: - """ - An registered schema. 
- """ - - schema_id: Optional[int] - schema: Optional[Schema] - subject: Optional[str] - version: Optional[int] - - def to_dict(self) -> Dict[str, Any]: - schema = self.schema - - schema_id = self.schema_id - - subject = self.subject - - version = self.version - - field_dict: Dict[str, Any] = {} - if schema is not None: - field_dict = schema.to_dict() - if schema_id is not None: - field_dict["id"] = schema_id - if subject is not None: - field_dict["subject"] = subject - if version is not None: - field_dict["version"] = version - - return field_dict - - @classmethod - def from_dict(cls: Type[T], src_dict: Dict[str, Any]) -> T: - d = src_dict.copy() - - schema = Schema.from_dict(d) - - schema_id = d.pop("id", None) - - subject = d.pop("subject", None) - version = d.pop("version", None) - schema = cls( - schema_id=schema_id, - schema=schema, - subject=subject, - version=version, - ) +from .common.schema_registry_client import * # noqa +from ._sync.schema_registry_client import * # noqa - return schema +from .error import SchemaRegistryError # noqa diff --git a/src/confluent_kafka/schema_registry/serde.py b/src/confluent_kafka/schema_registry/serde.py index 1cfb384e1..87037dc17 100644 --- a/src/confluent_kafka/schema_registry/serde.py +++ b/src/confluent_kafka/schema_registry/serde.py @@ -16,524 +16,5 @@ # limitations under the License. 
# -__all__ = ['BaseSerializer', - 'BaseDeserializer', - 'FieldContext', - 'FieldRuleExecutor', - 'FieldTransform', - 'FieldTransformer', - 'FieldType', - 'ParsedSchemaCache', - 'RuleAction', - 'RuleContext', - 'RuleConditionError', - 'RuleError', - 'RuleExecutor'] - -import abc -import logging -from enum import Enum -from threading import Lock -from typing import Callable, List, Optional, Set, Dict, Any, TypeVar - -from confluent_kafka.schema_registry import RegisteredSchema -from confluent_kafka.schema_registry.schema_registry_client import RuleMode, \ - Rule, RuleKind, Schema, RuleSet -from confluent_kafka.schema_registry.wildcard_matcher import wildcard_match -from confluent_kafka.serialization import Serializer, Deserializer, \ - SerializationContext, SerializationError - - -log = logging.getLogger(__name__) - - -class FieldType(str, Enum): - RECORD = "RECORD" - ENUM = "ENUM" - ARRAY = "ARRAY" - MAP = "MAP" - COMBINED = "COMBINED" - FIXED = "FIXED" - STRING = "STRING" - BYTES = "BYTES" - INT = "INT" - LONG = "LONG" - FLOAT = "FLOAT" - DOUBLE = "DOUBLE" - BOOLEAN = "BOOLEAN" - NULL = "NULL" - - -class FieldContext(object): - __slots__ = ['containing_message', 'full_name', 'name', 'field_type', 'tags'] - - def __init__( - self, containing_message: Any, full_name: str, name: str, - field_type: FieldType, tags: Set[str] - ): - self.containing_message = containing_message - self.full_name = full_name - self.name = name - self.field_type = field_type - self.tags = tags - - def is_primitive(self) -> bool: - return self.field_type in (FieldType.INT, FieldType.LONG, FieldType.FLOAT, - FieldType.DOUBLE, FieldType.BOOLEAN, FieldType.NULL, - FieldType.STRING, FieldType.BYTES) - - def type_name(self) -> str: - return self.field_type.name - - -class RuleContext(object): - __slots__ = ['ser_ctx', 'source', 'target', 'subject', 'rule_mode', 'rule', - 'index', 'rules', 'inline_tags', 'field_transformer', '_field_contexts'] - - def __init__( - self, ser_ctx: 
SerializationContext, source: Optional[Schema], - target: Optional[Schema], subject: str, rule_mode: RuleMode, rule: Rule, - index: int, rules: List[Rule], inline_tags: Optional[Dict[str, Set[str]]], field_transformer - ): - self.ser_ctx = ser_ctx - self.source = source - self.target = target - self.subject = subject - self.rule_mode = rule_mode - self.rule = rule - self.index = index - self.rules = rules - self.inline_tags = inline_tags - self.field_transformer = field_transformer - self._field_contexts: List[FieldContext] = [] - - def get_parameter(self, name: str) -> Optional[str]: - params = self.rule.params - if params is not None: - value = params.params.get(name) - if value is not None: - return value - if (self.target is not None - and self.target.metadata is not None - and self.target.metadata.properties is not None): - value = self.target.metadata.properties.properties.get(name) - if value is not None: - return value - return None - - def _get_inline_tags(self, name: str) -> Set[str]: - if self.inline_tags is None: - return set() - return self.inline_tags.get(name, set()) - - def current_field(self) -> Optional[FieldContext]: - if not self._field_contexts: - return None - return self._field_contexts[-1] - - def enter_field( - self, containing_message: Any, full_name: str, name: str, - field_type: FieldType, tags: Optional[Set[str]] - ) -> FieldContext: - all_tags = set(tags if tags is not None else self._get_inline_tags(full_name)) - all_tags.update(self.get_tags(full_name)) - field_context = FieldContext(containing_message, full_name, name, field_type, all_tags) - self._field_contexts.append(field_context) - return field_context - - def get_tags(self, full_name: str) -> Set[str]: - result = set() - if (self.target is not None - and self.target.metadata is not None - and self.target.metadata.tags is not None): - tags = self.target.metadata.tags.tags - for k, v in tags.items(): - if wildcard_match(full_name, k): - result.update(v) - return result - - def 
exit_field(self): - if self._field_contexts: - self._field_contexts.pop() - - -FieldTransform = Callable[[RuleContext, FieldContext, Any], Any] - - -FieldTransformer = Callable[[RuleContext, FieldTransform, Any], Any] - - -class RuleBase(metaclass=abc.ABCMeta): - def configure(self, client_conf: dict, rule_conf: dict): - pass - - @abc.abstractmethod - def type(self) -> str: - raise NotImplementedError() - - def close(self): - pass - - -class RuleExecutor(RuleBase): - @abc.abstractmethod - def transform(self, ctx: RuleContext, message: Any) -> Any: - raise NotImplementedError() - - -class FieldRuleExecutor(RuleExecutor): - @abc.abstractmethod - def new_transform(self, ctx: RuleContext) -> FieldTransform: - raise NotImplementedError() - - def transform(self, ctx: RuleContext, message: Any) -> Any: - # TODO preserve source - if ctx.rule_mode in (RuleMode.WRITE, RuleMode.UPGRADE): - for i in range(ctx.index): - other_rule = ctx.rules[i] - if FieldRuleExecutor.are_transforms_with_same_tag(ctx.rule, other_rule): - # ignore this transform if an earlier one has the same tag - return message - elif ctx.rule_mode == RuleMode.READ or ctx.rule_mode == RuleMode.DOWNGRADE: - for i in range(ctx.index + 1, len(ctx.rules)): - other_rule = ctx.rules[i] - if FieldRuleExecutor.are_transforms_with_same_tag(ctx.rule, other_rule): - # ignore this transform if a later one has the same tag - return message - return ctx.field_transformer(ctx, self.new_transform(ctx), message) - - @staticmethod - def are_transforms_with_same_tag(rule1: Rule, rule2: Rule) -> bool: - return (bool(rule1.tags) - and rule1.kind == RuleKind.TRANSFORM - and rule1.kind == rule2.kind - and rule1.mode == rule2.mode - and rule1.type == rule2.type - and rule1.tags == rule2.tags) - - -class RuleAction(RuleBase): - @abc.abstractmethod - def run(self, ctx: RuleContext, message: Any, ex: Optional[Exception]): - raise NotImplementedError() - - -class ErrorAction(RuleAction): - def type(self) -> str: - return 'ERROR' - - def 
run(self, ctx: RuleContext, message: Any, ex: Optional[Exception]): - if ex is None: - raise SerializationError() - else: - raise SerializationError() from ex - - -class NoneAction(RuleAction): - def type(self) -> str: - return 'NONE' - - def run(self, ctx: RuleContext, message: Any, ex: Optional[Exception]): - pass - - -class RuleError(Exception): - pass - - -class RuleConditionError(RuleError): - def __init__(self, rule: Rule): - super().__init__(RuleConditionError.error_message(rule)) - - @staticmethod - def error_message(rule: Rule) -> str: - if rule.doc: - return rule.doc - elif rule.expr: - return f"Rule expr failed: {rule.expr}" - else: - return f"Rule failed: {rule.name}" - - -class Migration(object): - __slots__ = ['rule_mode', 'source', 'target'] - - def __init__( - self, rule_mode: RuleMode, source: Optional[RegisteredSchema], - target: Optional[RegisteredSchema] - ): - self.rule_mode = rule_mode - self.source = source - self.target = target - - -class BaseSerde(object): - __slots__ = ['_use_schema_id', '_use_latest_version', '_use_latest_with_metadata', - '_registry', '_rule_registry', '_subject_name_func', - '_field_transformer'] - - def _get_reader_schema(self, subject: str, fmt: Optional[str] = None) -> Optional[RegisteredSchema]: - if self._use_schema_id is not None: - schema = self._registry.get_schema(self._use_schema_id, subject, fmt) - return self._registry.lookup_schema(subject, schema, False, True) - if self._use_latest_with_metadata is not None: - return self._registry.get_latest_with_metadata( - subject, self._use_latest_with_metadata, True, fmt) - if self._use_latest_version: - return self._registry.get_latest_version(subject, fmt) - return None - - def _execute_rules( - self, ser_ctx: SerializationContext, subject: str, - rule_mode: RuleMode, - source: Optional[Schema], target: Optional[Schema], - message: Any, inline_tags: Optional[Dict[str, Set[str]]], - field_transformer: Optional[FieldTransformer] - ) -> Any: - if message is None or 
target is None: - return message - rules: Optional[List[Rule]] = None - if rule_mode == RuleMode.UPGRADE: - if target is not None and target.rule_set is not None: - rules = target.rule_set.migration_rules - elif rule_mode == RuleMode.DOWNGRADE: - if source is not None and source.rule_set is not None: - rules = source.rule_set.migration_rules - rules = rules[:] if rules else [] - rules.reverse() - else: - if target is not None and target.rule_set is not None: - rules = target.rule_set.domain_rules - if rule_mode == RuleMode.READ: - # Execute read rules in reverse order for symmetry - rules = rules[:] if rules else [] - rules.reverse() - - if not rules: - return message - - for index in range(len(rules)): - rule = rules[index] - if self._is_disabled(rule): - continue - if rule.mode == RuleMode.WRITEREAD: - if rule_mode != RuleMode.READ and rule_mode != RuleMode.WRITE: - continue - elif rule.mode == RuleMode.UPDOWN: - if rule_mode != RuleMode.UPGRADE and rule_mode != RuleMode.DOWNGRADE: - continue - elif rule.mode != rule_mode: - continue - - ctx = RuleContext(ser_ctx, source, target, subject, rule_mode, rule, - index, rules, inline_tags, field_transformer) - rule_executor = self._rule_registry.get_executor(rule.type.upper()) - if rule_executor is None: - self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), message, - RuleError(f"Could not find rule executor of type {rule.type}"), - 'ERROR') - return message - try: - result = rule_executor.transform(ctx, message) - if rule.kind == RuleKind.CONDITION: - if not result: - raise RuleConditionError(rule) - elif rule.kind == RuleKind.TRANSFORM: - message = result - self._run_action( - ctx, rule_mode, rule, - self._get_on_failure(rule) if message is None else self._get_on_success(rule), - message, None, - 'ERROR' if message is None else 'NONE') - except SerializationError: - raise - except Exception as e: - self._run_action(ctx, rule_mode, rule, self._get_on_failure(rule), - message, e, 'ERROR') - return 
message - - def _get_on_success(self, rule: Rule) -> Optional[str]: - override = self._rule_registry.get_override(rule.type) - if override is not None and override.on_success is not None: - return override.on_success - return rule.on_success - - def _get_on_failure(self, rule: Rule) -> Optional[str]: - override = self._rule_registry.get_override(rule.type) - if override is not None and override.on_failure is not None: - return override.on_failure - return rule.on_failure - - def _is_disabled(self, rule: Rule) -> Optional[bool]: - override = self._rule_registry.get_override(rule.type) - if override is not None and override.disabled is not None: - return override.disabled - return rule.disabled - - def _run_action( - self, ctx: RuleContext, rule_mode: RuleMode, rule: Rule, - action: Optional[str], message: Any, - ex: Optional[Exception], default_action: str - ): - action_name = self._get_rule_action_name(rule, rule_mode, action) - if action_name is None: - action_name = default_action - rule_action = self._get_rule_action(ctx, action_name) - if rule_action is None: - log.error("Could not find rule action of type %s", action_name) - raise RuleError(f"Could not find rule action of type {action_name}") - try: - rule_action.run(ctx, message, ex) - except SerializationError: - raise - except Exception as e: - log.warning("Could not run post-rule action %s: %s", action_name, e) - - def _get_rule_action_name( - self, rule: Rule, rule_mode: RuleMode, action_name: Optional[str] - ) -> Optional[str]: - if action_name is None or action_name == "": - return None - if rule.mode in (RuleMode.WRITEREAD, RuleMode.UPDOWN) and ',' in action_name: - parts = action_name.split(',') - if rule_mode in (RuleMode.WRITE, RuleMode.UPGRADE): - return parts[0] - elif rule_mode in (RuleMode.READ, RuleMode.DOWNGRADE): - return parts[1] - return action_name - - def _get_rule_action(self, ctx: RuleContext, action_name: str) -> Optional[RuleAction]: - if action_name == 'ERROR': - return ErrorAction() 
- elif action_name == 'NONE': - return NoneAction() - return self._rule_registry.get_action(action_name) - - -class BaseSerializer(BaseSerde, Serializer): - __slots__ = ['_auto_register', '_normalize_schemas'] - - -class BaseDeserializer(BaseSerde, Deserializer): - __slots__ = [] - - def _has_rules(self, rule_set: RuleSet, mode: RuleMode) -> bool: - if rule_set is None: - return False - if mode in (RuleMode.UPGRADE, RuleMode.DOWNGRADE): - return any(rule.mode == mode or rule.mode == RuleMode.UPDOWN - for rule in rule_set.migration_rules or []) - elif mode == RuleMode.UPDOWN: - return any(rule.mode == mode for rule in rule_set.migration_rules or []) - elif mode in (RuleMode.WRITE, RuleMode.READ): - return any(rule.mode == mode or rule.mode == RuleMode.WRITEREAD - for rule in rule_set.domain_rules or []) - elif mode == RuleMode.WRITEREAD: - return any(rule.mode == mode for rule in rule_set.migration_rules or []) - return False - - def _get_migrations( - self, subject: str, source_info: Schema, - target: RegisteredSchema, fmt: Optional[str] - ) -> List[Migration]: - source = self._registry.lookup_schema(subject, source_info, False, True) - migrations = [] - if source.version < target.version: - migration_mode = RuleMode.UPGRADE - first = source - last = target - elif source.version > target.version: - migration_mode = RuleMode.DOWNGRADE - first = target - last = source - else: - return migrations - previous: Optional[RegisteredSchema] = None - versions = self._get_schemas_between(subject, first, last, fmt) - for i in range(len(versions)): - version = versions[i] - if i == 0: - previous = version - continue - if version.schema.rule_set is not None and self._has_rules(version.schema.rule_set, migration_mode): - if migration_mode == RuleMode.UPGRADE: - migration = Migration(migration_mode, previous, version) - else: - migration = Migration(migration_mode, version, previous) - migrations.append(migration) - previous = version - if migration_mode == RuleMode.DOWNGRADE: - 
migrations.reverse() - return migrations - - def _get_schemas_between( - self, subject: str, first: RegisteredSchema, - last: RegisteredSchema, fmt: Optional[str] = None - ) -> List[RegisteredSchema]: - if last.version - first.version <= 1: - return [first, last] - version1 = first.version - version2 = last.version - result = [first] - for i in range(version1 + 1, version2): - result.append(self._registry.get_version(subject, i, True, fmt)) - result.append(last) - return result - - def _execute_migrations( - self, ser_ctx: SerializationContext, subject: str, - migrations: List[Migration], message: Any - ) -> Any: - for migration in migrations: - message = self._execute_rules(ser_ctx, subject, migration.rule_mode, - migration.source.schema, migration.target.schema, - message, None, None) - return message - - -T = TypeVar("T") - - -class ParsedSchemaCache(object): - """ - Thread-safe cache for parsed schemas - """ - - def __init__(self): - self.lock = Lock() - self.parsed_schemas = {} - - def set(self, schema: Schema, parsed_schema: T): - """ - Add a Schema identified by schema_id to the cache. - - Args: - schema (Schema): The schema - - parsed_schema (Any): The parsed schema - """ - - with self.lock: - self.parsed_schemas[schema] = parsed_schema - - def get_parsed_schema(self, schema: Schema) -> Optional[T]: - """ - Get the parsed schema associated with the schema - - Args: - schema (Schema): The schema - - Returns: - The parsed schema if known; else None - """ - - with self.lock: - return self.parsed_schemas.get(schema, None) - - def clear(self): - """ - Clear the cache. 
- """ - - with self.lock: - self.parsed_schemas.clear() +from .common.serde import * # noqa +from ._sync.serde import * # noqa diff --git a/tests/integration/schema_registry/_sync/__init__.py b/tests/integration/schema_registry/_sync/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/schema_registry/test_api_client.py b/tests/integration/schema_registry/_sync/test_api_client.py similarity index 100% rename from tests/integration/schema_registry/test_api_client.py rename to tests/integration/schema_registry/_sync/test_api_client.py diff --git a/tests/integration/schema_registry/test_avro_serializers.py b/tests/integration/schema_registry/_sync/test_avro_serializers.py similarity index 99% rename from tests/integration/schema_registry/test_avro_serializers.py rename to tests/integration/schema_registry/_sync/test_avro_serializers.py index 4140ad600..322637ae7 100644 --- a/tests/integration/schema_registry/test_avro_serializers.py +++ b/tests/integration/schema_registry/_sync/test_avro_serializers.py @@ -179,9 +179,11 @@ def _references_test_common(kafka_cluster, awarded_user, serializer_schema, dese producer = kafka_cluster.producer(value_serializer=value_serializer) producer.produce(topic, value=awarded_user, partition=0) + producer.flush() consumer = kafka_cluster.consumer(value_deserializer=value_deserializer) + consumer.assign([TopicPartition(topic, 0)]) msg = consumer.poll() diff --git a/tests/integration/schema_registry/test_json_serializers.py b/tests/integration/schema_registry/_sync/test_json_serializers.py similarity index 99% rename from tests/integration/schema_registry/test_json_serializers.py rename to tests/integration/schema_registry/_sync/test_json_serializers.py index 5b6700438..ae67c30f2 100644 --- a/tests/integration/schema_registry/test_json_serializers.py +++ b/tests/integration/schema_registry/_sync/test_json_serializers.py @@ -19,7 +19,7 @@ from confluent_kafka import TopicPartition from 
confluent_kafka.error import ConsumeError, ValueSerializationError -from confluent_kafka.schema_registry import SchemaReference, Schema +from confluent_kafka.schema_registry import SchemaReference, Schema, SchemaRegistryClient from confluent_kafka.schema_registry.json_schema import (JSONSerializer, JSONDeserializer) @@ -404,7 +404,7 @@ def test_json_record_deserialization_mismatch(kafka_cluster, load_file): consumer.poll() -def _register_referenced_schemas(sr, load_file): +def _register_referenced_schemas(sr: SchemaRegistryClient, load_file): sr.register_schema("product", Schema(load_file("product.json"), 'JSON')) sr.register_schema("customer", Schema(load_file("customer.json"), 'JSON')) sr.register_schema("order_details", Schema(load_file("order_details.json"), 'JSON', [ diff --git a/tests/integration/schema_registry/test_proto_serializers.py b/tests/integration/schema_registry/_sync/test_proto_serializers.py similarity index 96% rename from tests/integration/schema_registry/test_proto_serializers.py rename to tests/integration/schema_registry/_sync/test_proto_serializers.py index 16de4ea6b..54e458152 100644 --- a/tests/integration/schema_registry/test_proto_serializers.py +++ b/tests/integration/schema_registry/_sync/test_proto_serializers.py @@ -19,7 +19,7 @@ from confluent_kafka import TopicPartition, KafkaException, KafkaError from confluent_kafka.error import ConsumeError from confluent_kafka.schema_registry.protobuf import ProtobufSerializer, ProtobufDeserializer -from .data.proto import metadata_proto_pb2, NestedTestProto_pb2, TestProto_pb2, \ +from tests.integration.schema_registry.data.proto import metadata_proto_pb2, NestedTestProto_pb2, TestProto_pb2, \ PublicTestProto_pb2 from tests.integration.schema_registry.data.proto.DependencyTestProto_pb2 import DependencyMessage from tests.integration.schema_registry.data.proto.exampleProtoCriteo_pb2 import ClickCas @@ -92,7 +92,7 @@ def test_protobuf_reference_registration(kafka_cluster, pb2, expected_refs): 
producer.produce(topic, key=pb2(), partition=0) producer.flush() - registered_refs = sr.get_schema(serializer._schema_id).references + registered_refs = (sr.get_schema(serializer._schema_id)).references assert expected_refs.sort() == [ref.name for ref in registered_refs].sort() diff --git a/tests/schema_registry/test_bearer_field_provider.py b/tests/schema_registry/test_bearer_field_provider.py index d67804a12..a6dfc8eb0 100644 --- a/tests/schema_registry/test_bearer_field_provider.py +++ b/tests/schema_registry/test_bearer_field_provider.py @@ -77,8 +77,8 @@ def update_token2(): def test_generate_token_retry_logic(): oauth_client = _OAuthClient('id', 'secret', 'scope', 'endpoint', TEST_CLUSTER, TEST_POOL, 5, 1000, 20000) - with (patch("confluent_kafka.schema_registry.schema_registry_client.time.sleep") as mock_sleep, - patch("confluent_kafka.schema_registry.schema_registry_client.full_jitter") as mock_jitter): + with (patch("confluent_kafka.schema_registry._sync.schema_registry_client.time.sleep") as mock_sleep, + patch("confluent_kafka.schema_registry._sync.schema_registry_client.full_jitter") as mock_jitter): with pytest.raises(OAuthTokenError): oauth_client.generate_access_token()