diff --git a/flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java b/flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java
new file mode 100644
index 0000000000000..a514380bee46e
--- /dev/null
+++ b/flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.serialization;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.types.Row;
+
+import javax.annotation.Nullable;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+
+/**
+ * Serialization schema that extracts a specific field from a {@link Row} and returns it as a byte
+ * array.
+ *
+ * <p>The field is required to be of type {@code byte[]}. This schema is particularly useful when
+ * using Flink with Kafka, where you may want to use one Row field as the message key and another
+ * as the value, performing the conversion to bytes explicitly in user code.
+ *
+ * <p>Example usage with Kafka:
+ *
+ * <pre>{@code
+ * KafkaSink<Row> sink = KafkaSink.<Row>builder()
+ *     .setBootstrapServers(bootstrapServers)
+ *     .setRecordSerializer(
+ *         KafkaRecordSerializationSchema.<Row>builder()
+ *             .setTopic("my-topic")
+ *             .setKeySerializationSchema(new RowFieldExtractorSchema(0))    // field 0 as key
+ *             .setValueSerializationSchema(new RowFieldExtractorSchema(1))  // field 1 as value
+ *             .build())
+ *     .build();
+ * }</pre>
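+ *
+ * <p>For illustration, rows whose fields are already {@code byte[]} can be produced upstream with
+ * a plain map function. This is only a sketch; the {@code Event} type and its accessors below are
+ * assumptions made for the example, not part of this class:
+ *
+ * <pre>{@code
+ * DataStream<Row> rows =
+ *     events.map(e -> Row.of(
+ *                     e.getUserId().getBytes(StandardCharsets.UTF_8),    // field 0: key bytes
+ *                     e.getPayload().getBytes(StandardCharsets.UTF_8)))  // field 1: value bytes
+ *         .returns(Types.ROW(Types.PRIMITIVE_ARRAY(Types.BYTE), Types.PRIMITIVE_ARRAY(Types.BYTE)));
+ * }</pre>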
+ */
+@PublicEvolving
+public class RowFieldExtractorSchema implements SerializationSchema<Row> {
+
+    private static final long serialVersionUID = 1L;
+
+    /** The index of the field to extract from the Row. */
+    private final int fieldIndex;
+
+    /**
+     * Creates a new RowFieldExtractorSchema that extracts the field at the specified index.
+     *
+     * @param fieldIndex the zero-based index of the field to extract
+     * @throws IllegalArgumentException if fieldIndex is negative
+     */
+    public RowFieldExtractorSchema(int fieldIndex) {
+        checkArgument(fieldIndex >= 0, "Field index must be non-negative, got: %s", fieldIndex);
+        this.fieldIndex = fieldIndex;
+    }
+
+    /**
+     * Gets the field index being extracted.
+     *
+     * @return the field index
+     */
+    @VisibleForTesting
+    public int getFieldIndex() {
+        return fieldIndex;
+    }
+
+    @Override
+    public byte[] serialize(@Nullable Row element) {
+        if (element == null) {
+            return new byte[0];
+        }
+
+        checkArgument(
+                fieldIndex < element.getArity(),
+                "Cannot access field %s in Row with arity %s",
+                fieldIndex,
+                element.getArity());
+
+        Object field = element.getField(fieldIndex);
+        if (field == null) {
+            return new byte[0];
+        }
+
+        if (!(field instanceof byte[])) {
+            throw new IllegalArgumentException(
+                    String.format(
+                            "Field at index %s must be of type byte[], but was %s",
+                            fieldIndex, field.getClass().getName()));
+        }
+
+        return (byte[]) field;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        RowFieldExtractorSchema that = (RowFieldExtractorSchema) o;
+        return fieldIndex == that.fieldIndex;
+    }
+
+    @Override
+    public int hashCode() {
+        return fieldIndex;
+    }
+}
diff --git a/flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java b/flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java
new file mode 100644
index 0000000000000..8b3e36aae71db
--- /dev/null
+++ b/flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common.serialization;
+
+import org.apache.flink.types.Row;
+import org.apache.flink.util.InstantiationUtil;
+
+import org.junit.jupiter.api.Test;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link RowFieldExtractorSchema}. */
+class RowFieldExtractorSchemaTest {
+
+    @Test
+    void testSerializeByteArrayField() {
+        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(0);
+        byte[] value = "test-value".getBytes(StandardCharsets.UTF_8);
+        Row row = Row.of(value, 123);
+
+        byte[] result = schema.serialize(row);
+
+        assertThat(result).isEqualTo(value);
+    }
+
+    @Test
+    void testSerializeNonByteArrayFieldThrowsException() {
+        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(1);
+        Row row = Row.of("key", 42); // field 1 is Integer, not byte[]
+
+        assertThatThrownBy(() -> schema.serialize(row))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("must be of type byte[]");
+    }
+
+    @Test
+    void testSerializeNullRow() {
+        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(0);
+
+        byte[] result = schema.serialize(null);
+
+        assertThat(result).isEmpty();
+    }
+
+    @Test
+    void testSerializeNullField() {
+        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(0);
+        Row row = Row.of(null, "value");
+
+        byte[] result = schema.serialize(row);
+
+        assertThat(result).isEmpty();
+    }
+
+    @Test
+    void testSerializeOutOfBoundsIndex() {
+        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(5);
+        Row row = Row.of("field0", "field1");
+
+        assertThatThrownBy(() -> schema.serialize(row))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("Cannot access field 5 in Row with arity 2");
+    }
+
+    @Test
+    void testNegativeFieldIndexThrowsException() {
+        assertThatThrownBy(() -> new RowFieldExtractorSchema(-1))
+                .isInstanceOf(IllegalArgumentException.class)
+                .hasMessageContaining("Field index must be non-negative");
+    }
+
+    @Test
+    void testSerializability() throws IOException, ClassNotFoundException {
+        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(3);
+
+        RowFieldExtractorSchema deserialized =
+                InstantiationUtil.deserializeObject(
+                        InstantiationUtil.serializeObject(schema), getClass().getClassLoader());
+
+        assertThat(deserialized.getFieldIndex()).isEqualTo(3);
+    }
+
+    @Test
+    void testEquals() {
+        RowFieldExtractorSchema schema1 = new RowFieldExtractorSchema(1);
+        RowFieldExtractorSchema schema2 = new RowFieldExtractorSchema(1);
+        RowFieldExtractorSchema schema3 = new RowFieldExtractorSchema(2);
+
+        assertThat(schema1).isEqualTo(schema2);
+        assertThat(schema1).isNotEqualTo(schema3);
+    }
+
+    @Test
+    void testHashCode() {
+        RowFieldExtractorSchema schema1 = new RowFieldExtractorSchema(1);
+        RowFieldExtractorSchema schema2 = new RowFieldExtractorSchema(1);
+
+        assertThat(schema1.hashCode()).isEqualTo(schema2.hashCode());
+    }
+}
diff --git a/flink-python/pyflink/common/serialization.py b/flink-python/pyflink/common/serialization.py
index 99dc540c47a60..6025bd76c9a24 100644
--- a/flink-python/pyflink/common/serialization.py
+++ b/flink-python/pyflink/common/serialization.py
@@ -25,7 +25,8 @@
     'SimpleStringSchema',
     'ByteArraySchema',
     'Encoder',
-    'BulkWriterFactory'
+    'BulkWriterFactory',
+    'RowFieldExtractorSchema',
 ]
 
 
@@ -35,6 +36,7 @@ class SerializationSchema(object):
     into a different serialized representation. Most data sinks (for example Apache Kafka)
     require the data to be handed to them in a specific format (for example as byte strings).
     """
+
     def __init__(self, j_serialization_schema=None):
         self._j_serialization_schema = j_serialization_schema
 
@@ -48,6 +50,7 @@ class DeserializationSchema(object):
     In addition, the DeserializationSchema describes the produced type which lets Flink create
     internal serializers and structures to handle the type.
""" + def __init__(self, j_deserialization_schema=None): self._j_deserialization_schema = j_deserialization_schema @@ -126,3 +129,53 @@ def __init__(self, j_bulk_writer_factory, row_type): def get_row_type(self): return self._row_type + + +class RowFieldExtractorSchema(SerializationSchema): + """ + Serialization schema that extracts a specific field from a Row and returns it as a + byte array. The field at the specified index MUST be of type bytes (byte array). + This schema is particularly useful when using Flink with Kafka, where you may want to use a + specific field as the message key for partition routing. + The field being extracted must already be a byte array. Users are responsible for + converting their data to bytes before passing it to this schema. + + Example usage with Kafka: + >>> from pyflink.common.serialization import RowFieldExtractorSchema + >>> from pyflink.datastream.connectors.kafka import KafkaSink, \ + KafkaRecordSerializationSchema + >>> + >>> # User must convert data to bytes beforehand + >>> # For example: Row.of(b"key-bytes", b"value-bytes") + >>> + >>> sink = KafkaSink.builder() \\ + ... .set_bootstrap_servers("localhost:9092") \\ + ... .set_record_serializer( + ... KafkaRecordSerializationSchema.builder() + ... .set_topic("my-topic") + ... .set_key_serialization_schema(RowFieldExtractorSchema(0)) + # Field 0 (must be bytes) as key + ... .set_value_serialization_schema(RowFieldExtractorSchema(1)) + # Field 1 (must be bytes) as value + ... .build() + ... ) \\ + ... .build() + + :param field_index: The zero-based index of the field to extract from the Row. + The field at this index must be of type bytes. + """ + + def __init__(self, field_index: int): + """ + Creates a new RowFieldExtractorSchema that extracts the field at the specified index. + + :param field_index: The zero-based index of the field to extract (must be non-negative). + :raises ValueError: If field_index is negative. + """ + if field_index < 0: + raise ValueError(f"Field index must be non-negative, got: {field_index}") + gateway = get_gateway() + j_row_field_extractor_schema = gateway.jvm.org.apache.flink.api.common.serialization \ + .RowFieldExtractorSchema(field_index) + super(RowFieldExtractorSchema, self).__init__( + j_serialization_schema=j_row_field_extractor_schema) diff --git a/flink-python/pyflink/common/tests/test_serialization_schemas.py b/flink-python/pyflink/common/tests/test_serialization_schemas.py index bbbe8b8f123af..7f10da4d8415b 100644 --- a/flink-python/pyflink/common/tests/test_serialization_schemas.py +++ b/flink-python/pyflink/common/tests/test_serialization_schemas.py @@ -15,8 +15,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
 ################################################################################
-from pyflink.common.serialization import SimpleStringSchema, ByteArraySchema
+from pyflink.common.serialization import (
+    SimpleStringSchema,
+    ByteArraySchema,
+    RowFieldExtractorSchema,
+)
 from pyflink.testing.test_case_utils import PyFlinkTestCase
+from pyflink.java_gateway import get_gateway
 
 
 class SimpleStringSchemaTests(PyFlinkTestCase):
@@ -39,3 +44,110 @@ def test_simple_byte_schema(self):
             simple_byte_schema._j_serialization_schema.serialize(expected_bytes))
         self.assertEqual(expected_bytes, simple_byte_schema._j_deserialization_schema
                          .deserialize(expected_bytes))
+
+
+class RowFieldExtractorSchemaTests(PyFlinkTestCase):
+    """Tests for RowFieldExtractorSchema."""
+
+    def test_row_field_extractor_schema_creation(self):
+        """Test RowFieldExtractorSchema can be created with a valid index."""
+        schema = RowFieldExtractorSchema(0)
+        self.assertIsNotNone(schema._j_serialization_schema)
+
+    def test_serialize_byte_array_field(self):
+        """Test serializing a byte array field from a Row."""
+        schema = RowFieldExtractorSchema(0)
+        gateway = get_gateway()
+        j_row = gateway.jvm.org.apache.flink.types.Row(2)
+
+        # Set byte array field
+        test_bytes = "test-value".encode('utf-8')
+        j_row.setField(0, test_bytes)
+        j_row.setField(1, "other-data".encode('utf-8'))
+
+        result = schema._j_serialization_schema.serialize(j_row)
+        self.assertEqual(test_bytes, bytes(result))
+
+    def test_serialize_second_field(self):
+        """Test serializing a byte array from the second field of a Row."""
+        schema = RowFieldExtractorSchema(1)
+        gateway = get_gateway()
+        j_row = gateway.jvm.org.apache.flink.types.Row(2)
+
+        test_bytes = "field-1-value".encode('utf-8')
+        j_row.setField(0, "field-0".encode('utf-8'))
+        j_row.setField(1, test_bytes)
+
+        result = schema._j_serialization_schema.serialize(j_row)
+        self.assertEqual(test_bytes, bytes(result))
+
+    def test_serialize_null_row(self):
+        """Test serializing a null Row returns an empty byte array."""
+        schema = RowFieldExtractorSchema(0)
+        result = schema._j_serialization_schema.serialize(None)
+        self.assertEqual(0, len(result))
+
+    def test_serialize_null_field(self):
+        """Test serializing a Row with a null field returns an empty byte array."""
+        schema = RowFieldExtractorSchema(0)
+        gateway = get_gateway()
+        j_row = gateway.jvm.org.apache.flink.types.Row(2)
+        j_row.setField(0, None)  # null field
+        j_row.setField(1, "value".encode('utf-8'))
+
+        result = schema._j_serialization_schema.serialize(j_row)
+        self.assertEqual(0, len(result))
+
+    def test_serialize_non_byte_array_raises_error(self):
+        """Test that a non-byte-array field raises IllegalArgumentException."""
+        schema = RowFieldExtractorSchema(0)
+        gateway = get_gateway()
+        j_row = gateway.jvm.org.apache.flink.types.Row(2)
+
+        # set a string instead of a byte array
+        j_row.setField(0, "not-bytes")
+        j_row.setField(1, "other")
+
+        # Should get IllegalArgumentException from Java
+        with self.assertRaises(Exception):
+            schema._j_serialization_schema.serialize(j_row)
+
+    def test_negative_field_index_raises_error(self):
+        """Test that a negative field index raises ValueError."""
+        with self.assertRaises(ValueError) as context:
+            RowFieldExtractorSchema(-1)
+        self.assertIn("Field index must be non-negative", str(context.exception))
+
+    def test_get_field_index(self):
+        """Test that getFieldIndex returns the correct value."""
+        schema = RowFieldExtractorSchema(3)
+        field_index = schema._j_serialization_schema.getFieldIndex()
+        self.assertEqual(3, field_index)
+
+    def test_multiple_schemas_with_different_indices(self):
+        """Test creating multiple schemas with different field indices."""
+        schema0 = RowFieldExtractorSchema(0)
+        schema1 = RowFieldExtractorSchema(1)
+        schema2 = RowFieldExtractorSchema(2)
+
+        self.assertEqual(0, schema0._j_serialization_schema.getFieldIndex())
+        self.assertEqual(1, schema1._j_serialization_schema.getFieldIndex())
+        self.assertEqual(2, schema2._j_serialization_schema.getFieldIndex())
+
+    def test_schema_equals(self):
+        """Test that schemas with the same field index are considered equal."""
+        schema1 = RowFieldExtractorSchema(1)
+        schema2 = RowFieldExtractorSchema(1)
+        schema3 = RowFieldExtractorSchema(2)
+
+        self.assertTrue(schema1._j_serialization_schema.equals(schema2._j_serialization_schema))
+        self.assertFalse(schema1._j_serialization_schema.equals(schema3._j_serialization_schema))
+
+    def test_schema_hash_code(self):
+        """Test that schemas with the same field index have the same hash code."""
+        schema1 = RowFieldExtractorSchema(1)
+        schema2 = RowFieldExtractorSchema(1)
+
+        hash1 = schema1._j_serialization_schema.hashCode()
+        hash2 = schema2._j_serialization_schema.hashCode()
+        self.assertEqual(hash1, hash2)
diff --git a/flink-python/pyflink/examples/datastream/connectors/kafka_row_field_extractor_example.py b/flink-python/pyflink/examples/datastream/connectors/kafka_row_field_extractor_example.py
new file mode 100644
index 0000000000000..53a74f3edbb67
--- /dev/null
+++ b/flink-python/pyflink/examples/datastream/connectors/kafka_row_field_extractor_example.py
@@ -0,0 +1,176 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+"""
+Example demonstrating RowFieldExtractorSchema usage with Kafka.
+
+This example shows how to use RowFieldExtractorSchema to serialize specific
+Row fields as Kafka message keys and values. The schema requires fields to
+be byte arrays, giving you full control over serialization.
+
+Requirements:
+    - Kafka running on localhost:9092
+    - Topic 'row-extractor-example' created
+    - Kafka connector JAR in the classpath
+
+Usage:
+    python kafka_row_field_extractor_example.py
+"""
+
+import json
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.datastream.connectors.kafka import (
+    KafkaSink,
+    KafkaRecordSerializationSchema,
+    DeliveryGuarantee
+)
+from pyflink.common.serialization import RowFieldExtractorSchema
+from pyflink.common.typeinfo import Types
+from pyflink.common import Row
+
+
+def serialize_to_json_bytes(data):
+    """Helper function to serialize data to a JSON byte array."""
+    return json.dumps(data).encode('utf-8')
+
+
+def serialize_key_bytes(key):
+    """Helper function to serialize a key to a byte array."""
+    return key.encode('utf-8')
+
+
+def create_sample_data():
+    """
+    Create sample e-commerce events.
+    Each event has: user_id (key) and event_data (value).
+    """
+    events = [
+        {
+            "user_id": "user-001",
+            "event": {"type": "purchase", "item": "laptop", "price": 999.99}
+        },
+        {
+            "user_id": "user-002",
+            "event": {"type": "view", "item": "phone", "timestamp": "2024-01-07T10:30:00"}
+        },
+        {
+            "user_id": "user-001",
+            "event": {"type": "add_to_cart", "item": "mouse", "quantity": 2}
+        },
+        {
+            "user_id": "user-003",
+            "event": {"type": "purchase", "item": "keyboard", "price": 79.99}
+        },
+    ]
+
+    # Convert to Rows with byte array fields
+    rows = []
+    for event in events:
+        row = Row(
+            serialize_key_bytes(event["user_id"]),   # Field 0: user_id as bytes (Kafka key)
+            serialize_to_json_bytes(event["event"])  # Field 1: event as JSON bytes (Kafka value)
+        )
+        rows.append(row)
+
+    return rows
+
+
+def kafka_row_field_extractor_example():
+    """
+    Demonstrate RowFieldExtractorSchema with a Kafka sink.
+
+    This example:
+    1. Creates sample e-commerce events
+    2. Converts user_id to bytes (for the Kafka key)
+    3. Converts event data to JSON bytes (for the Kafka value)
+    4. Uses RowFieldExtractorSchema to extract both fields and send them to Kafka
+    """
+
+    # Create execution environment
+    env = StreamExecutionEnvironment.get_execution_environment()
+    env.set_parallelism(1)
+
+    # Generate sample data
+    data = create_sample_data()
+
+    # Create DataStream with proper type information
+    # Row has 2 fields: both are byte arrays
+    ds = env.from_collection(
+        data,
+        type_info=Types.ROW([
+            Types.PRIMITIVE_ARRAY(Types.BYTE()),  # Field 0: user_id (key)
+            Types.PRIMITIVE_ARRAY(Types.BYTE())   # Field 1: event_data (value)
+        ])
+    )
+
+    # Optional: Print what we're sending (for debugging)
+    def print_event(row):
+        user_id = row[0].decode('utf-8')
+        event_data = json.loads(row[1].decode('utf-8'))
+        print(f"Sending: User={user_id}, Event={event_data}")
+        return row
+
+    ds = ds.map(
+        print_event,
+        output_type=Types.ROW([
+            Types.PRIMITIVE_ARRAY(Types.BYTE()),
+            Types.PRIMITIVE_ARRAY(Types.BYTE())
+        ])
+    )
+
+    # Create Kafka sink with RowFieldExtractorSchema
+    kafka_sink = KafkaSink.builder() \
+        .set_bootstrap_servers("localhost:9092") \
+        .set_record_serializer(
+            KafkaRecordSerializationSchema.builder()
+            .set_topic("row-extractor-example")
+            # Extract field 0 (user_id) as the Kafka key for partitioning
+            .set_key_serialization_schema(RowFieldExtractorSchema(0))
+            # Extract field 1 (event_data) as the Kafka value
+            .set_value_serialization_schema(RowFieldExtractorSchema(1))
+            .build()
+        ) \
+        .set_delivery_guarantee(DeliveryGuarantee.AT_LEAST_ONCE) \
+        .build()
+
+    # Send to Kafka
+    ds.sink_to(kafka_sink)
+
+    # Execute
+    env.execute("Kafka RowFieldExtractorSchema Example")
+
+
+if __name__ == '__main__':
+    print("=" * 70)
+    print("Kafka RowFieldExtractorSchema Example")
+    print("=" * 70)
+    print("\nMake sure:")
+    print("  1. Kafka is running on localhost:9092")
+    print("  2. Topic 'row-extractor-example' exists")
+    print("  3. Kafka connector JAR is in the classpath")
+    print("\nTo create the topic:")
+    print("  kafka-topics.sh --create --topic row-extractor-example \\")
+    print("    --bootstrap-server localhost:9092 --partitions 3")
+    print("\nTo consume messages:")
+    print("  kafka-console-consumer.sh --bootstrap-server localhost:9092 \\")
+    print("    --topic row-extractor-example --from-beginning \\")
+    print("    --property print.key=true --property key.separator=' => '")
+    print("\n" + "=" * 70 + "\n")
+
+    # Run the example
+    kafka_row_field_extractor_example()
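
A quick way to verify what the example above wrote is a value-only read-back job. The following is
a sketch, not part of this change; it assumes the broker address and topic from the example and
uses the standard PyFlink KafkaSource API:

from pyflink.common import WatermarkStrategy
from pyflink.common.serialization import SimpleStringSchema
from pyflink.datastream import StreamExecutionEnvironment
from pyflink.datastream.connectors.kafka import KafkaSource, KafkaOffsetsInitializer

env = StreamExecutionEnvironment.get_execution_environment()

# Read back only the values (the JSON event bytes) written by the example above.
source = KafkaSource.builder() \
    .set_bootstrap_servers("localhost:9092") \
    .set_topics("row-extractor-example") \
    .set_group_id("row-extractor-verify") \
    .set_starting_offsets(KafkaOffsetsInitializer.earliest()) \
    .set_value_only_deserializer(SimpleStringSchema()) \
    .build()

env.from_source(source, WatermarkStrategy.no_watermarks(), "verify-source").print()
env.execute("Verify RowFieldExtractorSchema output")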