From a5af280e7bf4898f96bca2878dc8cd2ff886bdd5 Mon Sep 17 00:00:00 2001 From: Nflrijal Date: Sat, 20 Dec 2025 16:19:42 +0530 Subject: [PATCH 1/3] [FLINK-38189][python] Add RowFieldExtractorSchema for Row field serialization This commit introduces RowFieldExtractorSchema, a new SerializationSchema that extracts and serializes a specific field from a Row object. This is particularly useful for Kafka scenarios where keys and values need separate serialization. Changes: - Add RowFieldExtractorSchema.java with field extraction logic - Add comprehensive unit tests for Java implementation - Add Python bindings in pyflink.common.serialization - Add Python unit tests that exercise the Java schema through the Java gateway - Add Javadoc and docstring documentation with Kafka usage examples This closes apache/flink#38189 --- .../RowFieldExtractorSchema.java | 122 +++++++++++++++++ .../RowFieldExtractorSchemaTest.java | 119 ++++++++++++++++ flink-python/pyflink/common/serialization.py | 47 ++++++- .../tests/test_serialization_schemas.py | 127 +++++++++++++++++- 4 files changed, 413 insertions(+), 2 deletions(-) create mode 100644 flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java create mode 100644 flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java diff --git a/flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java b/flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java new file mode 100644 index 0000000000000..9174b58650d29 --- /dev/null +++ b/flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.serialization; + +import org.apache.flink.annotation.PublicEvolving; +import org.apache.flink.annotation.VisibleForTesting; +import org.apache.flink.types.Row; + +import javax.annotation.Nullable; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; + +import static org.apache.flink.util.Preconditions.checkArgument; + +/** + * Serialization schema that extracts a specific field from a {@link Row} and serializes it as a + * UTF-8 encoded byte array. + * + *

This schema is particularly useful when using Flink with Kafka, where you may want to use a + * specific field as the message key for partition routing. + * + *

By default, the serializer uses "UTF-8" for string/byte conversion. + * + *

Example usage with Kafka: + * + *

{@code
+ * KafkaSink sink = KafkaSink.builder()
+ *     .setBootstrapServers(bootstrapServers)
+ *     .setRecordSerializer(
+ *         KafkaRecordSerializationSchema.builder()
+ *             .setTopic("my-topic")
+ *             .setKeySerializationSchema(new RowFieldExtractorSchema(0))    // Use field 0 as key
+ *             .setValueSerializationSchema(new RowFieldExtractorSchema(1))  // Use field 1 as value
+ *             .build())
+ *     .build();
+ * }
+ */ +@PublicEvolving +public class RowFieldExtractorSchema implements SerializationSchema { + + private static final long serialVersionUID = 1L; + + /** The charset to use for string/byte conversion. */ + private static final Charset CHARSET = StandardCharsets.UTF_8; + + /** The index of the field to extract from the Row. */ + private final int fieldIndex; + + /** + * Creates a new RowFieldExtractorSchema that extracts the field at the specified index. + * + * @param fieldIndex the zero-based index of the field to extract + * @throws IllegalArgumentException if fieldIndex is negative + */ + public RowFieldExtractorSchema(int fieldIndex) { + checkArgument(fieldIndex >= 0, "Field index must be non-negative, got: %s", fieldIndex); + this.fieldIndex = fieldIndex; + } + + /** + * Gets the field index being extracted. + * + * @return the field index + */ + @VisibleForTesting + public int getFieldIndex() { + return fieldIndex; + } + + @Override + public byte[] serialize(@Nullable Row element) { + if (element == null) { + return new byte[0]; + } + + checkArgument( + fieldIndex < element.getArity(), + "Cannot access field %s in Row with arity %s", + fieldIndex, + element.getArity()); + + Object field = element.getField(fieldIndex); + if (field == null) { + return new byte[0]; + } + + return field.toString().getBytes(CHARSET); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + RowFieldExtractorSchema that = (RowFieldExtractorSchema) o; + return fieldIndex == that.fieldIndex; + } + + @Override + public int hashCode() { + return fieldIndex; + } +} diff --git a/flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java b/flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java new file mode 100644 index 0000000000000..63241a1d3543f --- /dev/null +++ 
b/flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.common.serialization; + +import org.apache.flink.types.Row; +import org.apache.flink.util.InstantiationUtil; + +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link RowFieldExtractorSchema}. 
*/ +class RowFieldExtractorSchemaTest { + + @Test + void testSerializeStringField() { + RowFieldExtractorSchema schema = new RowFieldExtractorSchema(0); + Row row = Row.of("test-value", 123); + + byte[] result = schema.serialize(row); + + assertThat(new String(result, StandardCharsets.UTF_8)).isEqualTo("test-value"); + } + + @Test + void testSerializeIntegerField() { + RowFieldExtractorSchema schema = new RowFieldExtractorSchema(1); + Row row = Row.of("key", 42); + + byte[] result = schema.serialize(row); + + assertThat(new String(result, StandardCharsets.UTF_8)).isEqualTo("42"); + } + + @Test + void testSerializeNullRow() { + RowFieldExtractorSchema schema = new RowFieldExtractorSchema(0); + + byte[] result = schema.serialize(null); + + assertThat(result).isEmpty(); + } + + @Test + void testSerializeNullField() { + RowFieldExtractorSchema schema = new RowFieldExtractorSchema(0); + Row row = Row.of(null, "value"); + + byte[] result = schema.serialize(row); + + assertThat(result).isEmpty(); + } + + @Test + void testSerializeOutOfBoundsIndex() { + RowFieldExtractorSchema schema = new RowFieldExtractorSchema(5); + Row row = Row.of("field0", "field1"); + + assertThatThrownBy(() -> schema.serialize(row)) + .isInstanceOf( + IllegalArgumentException.class) // Changed from IndexOutOfBoundsException + .hasMessageContaining("Cannot access field 5 in Row with arity 2"); + } + + @Test + void testNegativeFieldIndexThrowsException() { + assertThatThrownBy(() -> new RowFieldExtractorSchema(-1)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Field index must be non-negative"); + } + + @Test + void testSerializability() throws IOException, ClassNotFoundException { + RowFieldExtractorSchema schema = new RowFieldExtractorSchema(3); + + RowFieldExtractorSchema deserialized = + InstantiationUtil.deserializeObject( + InstantiationUtil.serializeObject(schema), getClass().getClassLoader()); + + assertThat(deserialized.getFieldIndex()).isEqualTo(3); + } + + @Test 
+ void testEquals() { + RowFieldExtractorSchema schema1 = new RowFieldExtractorSchema(1); + RowFieldExtractorSchema schema2 = new RowFieldExtractorSchema(1); + RowFieldExtractorSchema schema3 = new RowFieldExtractorSchema(2); + + assertThat(schema1).isEqualTo(schema2); + assertThat(schema1).isNotEqualTo(schema3); + } + + @Test + void testHashCode() { + RowFieldExtractorSchema schema1 = new RowFieldExtractorSchema(1); + RowFieldExtractorSchema schema2 = new RowFieldExtractorSchema(1); + + assertThat(schema1.hashCode()).isEqualTo(schema2.hashCode()); + } +} diff --git a/flink-python/pyflink/common/serialization.py b/flink-python/pyflink/common/serialization.py index 99dc540c47a60..0ab86b83d1786 100644 --- a/flink-python/pyflink/common/serialization.py +++ b/flink-python/pyflink/common/serialization.py @@ -25,7 +25,8 @@ 'SimpleStringSchema', 'ByteArraySchema', 'Encoder', - 'BulkWriterFactory' + 'BulkWriterFactory', + 'RowFieldExtractorSchema', ] @@ -126,3 +127,47 @@ def __init__(self, j_bulk_writer_factory, row_type): def get_row_type(self): return self._row_type + + +class RowFieldExtractorSchema(SerializationSchema): + """ + Serialization schema that extracts a specific field from a Row and serializes it as a + UTF-8 encoded byte array. + + This schema is particularly useful when using Flink with Kafka, where you may want to use a + specific field as the message key for partition routing. + + Example usage with Kafka: + >>> from pyflink.common.serialization import RowFieldExtractorSchema + >>> from pyflink.datastream.connectors.kafka import KafkaSink, \ + KafkaRecordSerializationSchema + >>> sink = KafkaSink.builder() \\ + ... .set_bootstrap_servers("localhost:9092") \\ + ... .set_record_serializer( + ... KafkaRecordSerializationSchema.builder() + ... .set_topic("my-topic") + ... .set_key_serialization_schema(RowFieldExtractorSchema(0)) + # Field 0 as key + ... .set_value_serialization_schema(RowFieldExtractorSchema(1)) + # Field 1 as value + ... .build() + ... 
) \\ + ... .build() + + :param field_index: The zero-based index of the field to extract from the Row. + """ + + def __init__(self, field_index: int): + """ + Creates a new RowFieldExtractorSchema that extracts the field at the specified index. + + :param field_index: The zero-based index of the field to extract (must be non-negative). + :raises ValueError: If field_index is negative. + """ + if field_index < 0: + raise ValueError(f"Field index must be non-negative, got: {field_index}") + gateway = get_gateway() + j_row_field_extractor_schema = gateway.jvm.org.apache.flink.api.common.serialization \ + .RowFieldExtractorSchema(field_index) + super(RowFieldExtractorSchema, self).__init__( + j_serialization_schema=j_row_field_extractor_schema) diff --git a/flink-python/pyflink/common/tests/test_serialization_schemas.py b/flink-python/pyflink/common/tests/test_serialization_schemas.py index bbbe8b8f123af..bb6e65f8b0478 100644 --- a/flink-python/pyflink/common/tests/test_serialization_schemas.py +++ b/flink-python/pyflink/common/tests/test_serialization_schemas.py @@ -15,8 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
################################################################################ -from pyflink.common.serialization import SimpleStringSchema, ByteArraySchema +from pyflink.common.serialization import SimpleStringSchema, ByteArraySchema, RowFieldExtractorSchema from pyflink.testing.test_case_utils import PyFlinkTestCase +from pyflink.java_gateway import get_gateway class SimpleStringSchemaTests(PyFlinkTestCase): @@ -39,3 +40,127 @@ def test_simple_byte_schema(self): simple_byte_schema._j_serialization_schema.serialize(expected_bytes)) self.assertEqual(expected_bytes, simple_byte_schema._j_deserialization_schema .deserialize(expected_bytes)) + + +class RowFieldExtractorSchemaTests(PyFlinkTestCase): + """Tests for RowFieldExtractorSchema.""" + + def test_row_field_extractor_schema_creation(self): + """Test RowFieldExtractorSchema can be created with valid index.""" + schema = RowFieldExtractorSchema(0) + self.assertIsNotNone(schema._j_serialization_schema) + + def test_serialize_string_field(self): + """Test serializing a string field from a Row.""" + schema = RowFieldExtractorSchema(0) + + # Create a Java Row using constructor and setField + gateway = get_gateway() + j_row = gateway.jvm.org.apache.flink.types.Row(2) # 2 fields + j_row.setField(0, "test-value") + j_row.setField(1, 123) + + result = schema._j_serialization_schema.serialize(j_row) + expected = "test-value".encode('utf-8') + + self.assertEqual(expected, bytes(result)) + + def test_serialize_integer_field(self): + """Test serializing an integer field from a Row.""" + schema = RowFieldExtractorSchema(1) + + # Create a Java Row using constructor and setField + gateway = get_gateway() + j_row = gateway.jvm.org.apache.flink.types.Row(2) + j_row.setField(0, "key") + j_row.setField(1, 42) + + result = schema._j_serialization_schema.serialize(j_row) + expected = "42".encode('utf-8') + + self.assertEqual(expected, bytes(result)) + + def test_serialize_null_row(self): + """Test serializing null Row returns empty 
byte array.""" + schema = RowFieldExtractorSchema(0) + result = schema._j_serialization_schema.serialize(None) + + self.assertEqual(0, len(result)) + + def test_serialize_null_field(self): + """Test serializing Row with null field returns empty byte array.""" + schema = RowFieldExtractorSchema(0) + + # Create a Java Row with null first field + gateway = get_gateway() + j_row = gateway.jvm.org.apache.flink.types.Row(2) + j_row.setField(0, None) # null field + j_row.setField(1, "value") + + result = schema._j_serialization_schema.serialize(j_row) + + self.assertEqual(0, len(result)) + + def test_negative_field_index_raises_error(self): + """Test that negative field index raises ValueError.""" + with self.assertRaises(ValueError) as context: + RowFieldExtractorSchema(-1) + + self.assertIn("Field index must be non-negative", str(context.exception)) + + def test_get_field_index(self): + """Test that getFieldIndex returns correct value.""" + schema = RowFieldExtractorSchema(3) + field_index = schema._j_serialization_schema.getFieldIndex() + + self.assertEqual(3, field_index) + + def test_multiple_schemas_with_different_indices(self): + """Test creating multiple schemas with different field indices.""" + schema0 = RowFieldExtractorSchema(0) + schema1 = RowFieldExtractorSchema(1) + schema2 = RowFieldExtractorSchema(2) + + self.assertEqual(0, schema0._j_serialization_schema.getFieldIndex()) + self.assertEqual(1, schema1._j_serialization_schema.getFieldIndex()) + self.assertEqual(2, schema2._j_serialization_schema.getFieldIndex()) + + def test_serialize_different_field_types(self): + """Test serializing different data types from Row fields.""" + gateway = get_gateway() + + # Test with string + schema_str = RowFieldExtractorSchema(0) + j_row_str = gateway.jvm.org.apache.flink.types.Row(2) + j_row_str.setField(0, "hello") + j_row_str.setField(1, 100) + result_str = schema_str._j_serialization_schema.serialize(j_row_str) + self.assertEqual("hello".encode('utf-8'), 
bytes(result_str)) + + # Test with integer + schema_int = RowFieldExtractorSchema(1) + j_row_int = gateway.jvm.org.apache.flink.types.Row(2) + j_row_int.setField(0, "key") + j_row_int.setField(1, 999) + result_int = schema_int._j_serialization_schema.serialize(j_row_int) + self.assertEqual("999".encode('utf-8'), bytes(result_int)) + + def test_schema_equals(self): + """Test that schemas with same field index are considered equal.""" + schema1 = RowFieldExtractorSchema(1) + schema2 = RowFieldExtractorSchema(1) + schema3 = RowFieldExtractorSchema(2) + + # Test via Java equals method + self.assertTrue(schema1._j_serialization_schema.equals(schema2._j_serialization_schema)) + self.assertFalse(schema1._j_serialization_schema.equals(schema3._j_serialization_schema)) + + def test_schema_hash_code(self): + """Test that schemas with same field index have same hash code.""" + schema1 = RowFieldExtractorSchema(1) + schema2 = RowFieldExtractorSchema(1) + + hash1 = schema1._j_serialization_schema.hashCode() + hash2 = schema2._j_serialization_schema.hashCode() + + self.assertEqual(hash1, hash2) From b5248e406db22ffa27e45c7385f6162b143c72e1 Mon Sep 17 00:00:00 2001 From: Nflrijal Date: Sun, 21 Dec 2025 07:57:11 +0530 Subject: [PATCH 2/3] Formatting the test_serialization_schema --- .../tests/test_serialization_schemas.py | 24 ++++--------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/flink-python/pyflink/common/tests/test_serialization_schemas.py b/flink-python/pyflink/common/tests/test_serialization_schemas.py index bb6e65f8b0478..43a4ab62a997e 100644 --- a/flink-python/pyflink/common/tests/test_serialization_schemas.py +++ b/flink-python/pyflink/common/tests/test_serialization_schemas.py @@ -15,7 +15,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
################################################################################ -from pyflink.common.serialization import SimpleStringSchema, ByteArraySchema, RowFieldExtractorSchema +from pyflink.common.serialization import ( + SimpleStringSchema, + ByteArraySchema, + RowFieldExtractorSchema, +) from pyflink.testing.test_case_utils import PyFlinkTestCase from pyflink.java_gateway import get_gateway @@ -53,66 +57,54 @@ def test_row_field_extractor_schema_creation(self): def test_serialize_string_field(self): """Test serializing a string field from a Row.""" schema = RowFieldExtractorSchema(0) - # Create a Java Row using constructor and setField gateway = get_gateway() j_row = gateway.jvm.org.apache.flink.types.Row(2) # 2 fields j_row.setField(0, "test-value") j_row.setField(1, 123) - result = schema._j_serialization_schema.serialize(j_row) expected = "test-value".encode('utf-8') - self.assertEqual(expected, bytes(result)) def test_serialize_integer_field(self): """Test serializing an integer field from a Row.""" schema = RowFieldExtractorSchema(1) - # Create a Java Row using constructor and setField gateway = get_gateway() j_row = gateway.jvm.org.apache.flink.types.Row(2) j_row.setField(0, "key") j_row.setField(1, 42) - result = schema._j_serialization_schema.serialize(j_row) expected = "42".encode('utf-8') - self.assertEqual(expected, bytes(result)) def test_serialize_null_row(self): """Test serializing null Row returns empty byte array.""" schema = RowFieldExtractorSchema(0) result = schema._j_serialization_schema.serialize(None) - self.assertEqual(0, len(result)) def test_serialize_null_field(self): """Test serializing Row with null field returns empty byte array.""" schema = RowFieldExtractorSchema(0) - # Create a Java Row with null first field gateway = get_gateway() j_row = gateway.jvm.org.apache.flink.types.Row(2) j_row.setField(0, None) # null field j_row.setField(1, "value") - result = schema._j_serialization_schema.serialize(j_row) - self.assertEqual(0, 
len(result)) def test_negative_field_index_raises_error(self): """Test that negative field index raises ValueError.""" with self.assertRaises(ValueError) as context: RowFieldExtractorSchema(-1) - self.assertIn("Field index must be non-negative", str(context.exception)) def test_get_field_index(self): """Test that getFieldIndex returns correct value.""" schema = RowFieldExtractorSchema(3) field_index = schema._j_serialization_schema.getFieldIndex() - self.assertEqual(3, field_index) def test_multiple_schemas_with_different_indices(self): @@ -120,7 +112,6 @@ def test_multiple_schemas_with_different_indices(self): schema0 = RowFieldExtractorSchema(0) schema1 = RowFieldExtractorSchema(1) schema2 = RowFieldExtractorSchema(2) - self.assertEqual(0, schema0._j_serialization_schema.getFieldIndex()) self.assertEqual(1, schema1._j_serialization_schema.getFieldIndex()) self.assertEqual(2, schema2._j_serialization_schema.getFieldIndex()) @@ -128,7 +119,6 @@ def test_multiple_schemas_with_different_indices(self): def test_serialize_different_field_types(self): """Test serializing different data types from Row fields.""" gateway = get_gateway() - # Test with string schema_str = RowFieldExtractorSchema(0) j_row_str = gateway.jvm.org.apache.flink.types.Row(2) @@ -136,7 +126,6 @@ def test_serialize_different_field_types(self): j_row_str.setField(1, 100) result_str = schema_str._j_serialization_schema.serialize(j_row_str) self.assertEqual("hello".encode('utf-8'), bytes(result_str)) - # Test with integer schema_int = RowFieldExtractorSchema(1) j_row_int = gateway.jvm.org.apache.flink.types.Row(2) @@ -150,7 +139,6 @@ def test_schema_equals(self): schema1 = RowFieldExtractorSchema(1) schema2 = RowFieldExtractorSchema(1) schema3 = RowFieldExtractorSchema(2) - # Test via Java equals method self.assertTrue(schema1._j_serialization_schema.equals(schema2._j_serialization_schema)) self.assertFalse(schema1._j_serialization_schema.equals(schema3._j_serialization_schema)) @@ -159,8 +147,6 @@ def 
test_schema_hash_code(self): """Test that schemas with same field index have same hash code.""" schema1 = RowFieldExtractorSchema(1) schema2 = RowFieldExtractorSchema(1) - hash1 = schema1._j_serialization_schema.hashCode() hash2 = schema2._j_serialization_schema.hashCode() - self.assertEqual(hash1, hash2) From 39bca1940037ae9ec52da31d22f2c35701f57377 Mon Sep 17 00:00:00 2001 From: Noufal Rijal Date: Thu, 8 Jan 2026 00:03:25 +0530 Subject: [PATCH 3/3] [FLINK-38189][core][python] Add RowFieldExtractorSchema for Row field serialization Changes: - Add RowFieldExtractorSchema class with byte[] requirement - Add comprehensive tests for Java and PyFlink - Add PyFlink wrapper and documentation - Add example demonstrating usage with Kafka --- .../RowFieldExtractorSchema.java | 30 +-- .../RowFieldExtractorSchemaTest.java | 20 +- flink-python/pyflink/common/serialization.py | 18 +- .../tests/test_serialization_schemas.py | 73 ++++---- .../kafka_row_field_extractor_example.py | 176 ++++++++++++++++++ 5 files changed, 251 insertions(+), 66 deletions(-) create mode 100644 flink-python/pyflink/examples/datastream/connectors/kafka_row_field_extractor_example.py diff --git a/flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java b/flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java index 9174b58650d29..a514380bee46e 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchema.java @@ -23,19 +23,15 @@ import javax.annotation.Nullable; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; - import static org.apache.flink.util.Preconditions.checkArgument; /** - * Serialization schema that extracts a specific field from a {@link Row} and serializes it as a - * UTF-8 encoded byte array. - * - *

This schema is particularly useful when using Flink with Kafka, where you may want to use a - * specific field as the message key for partition routing. + * Serialization schema that extracts a specific field from a {@link Row} and returns it as a byte + * array. * - *

By default, the serializer uses "UTF-8" for string/byte conversion. + *

The field is required to be of type {@code byte[]}. This schema is particularly useful when + * using Flink with Kafka, where you may want to use one Row field as the message key and another as + * the value and perform the conversion to bytes explicitly in user code. * *

Example usage with Kafka: * @@ -45,8 +41,8 @@ * .setRecordSerializer( * KafkaRecordSerializationSchema.builder() * .setTopic("my-topic") - * .setKeySerializationSchema(new RowFieldExtractorSchema(0)) // Use field 0 as key - * .setValueSerializationSchema(new RowFieldExtractorSchema(1)) // Use field 1 as value + * .setKeySerializationSchema(new RowFieldExtractorSchema(0)) // field 0 as key + * .setValueSerializationSchema(new RowFieldExtractorSchema(1)) // field 1 as value * .build()) * .build(); * } @@ -56,9 +52,6 @@ public class RowFieldExtractorSchema implements SerializationSchema { private static final long serialVersionUID = 1L; - /** The charset to use for string/byte conversion. */ - private static final Charset CHARSET = StandardCharsets.UTF_8; - /** The index of the field to extract from the Row. */ private final int fieldIndex; @@ -100,7 +93,14 @@ public byte[] serialize(@Nullable Row element) { return new byte[0]; } - return field.toString().getBytes(CHARSET); + if (!(field instanceof byte[])) { + throw new IllegalArgumentException( + String.format( + "Field at index %s must be of type byte[], but was %s", + fieldIndex, field.getClass().getName())); + } + + return (byte[]) field; } @Override diff --git a/flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java b/flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java index 63241a1d3543f..8b3e36aae71db 100644 --- a/flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java +++ b/flink-core/src/test/java/org/apache/flink/api/common/serialization/RowFieldExtractorSchemaTest.java @@ -32,23 +32,24 @@ class RowFieldExtractorSchemaTest { @Test - void testSerializeStringField() { + void testSerializeByteArrayField() { RowFieldExtractorSchema schema = new RowFieldExtractorSchema(0); - Row row = Row.of("test-value", 123); + byte[] value = "test-value".getBytes(StandardCharsets.UTF_8); + 
Row row = Row.of(value, 123); byte[] result = schema.serialize(row); - assertThat(new String(result, StandardCharsets.UTF_8)).isEqualTo("test-value"); + assertThat(result).isEqualTo(value); } @Test - void testSerializeIntegerField() { + void testSerializeNonByteArrayFieldThrowsException() { RowFieldExtractorSchema schema = new RowFieldExtractorSchema(1); - Row row = Row.of("key", 42); + Row row = Row.of("key", 42); // field 1 is Integer, not byte[] - byte[] result = schema.serialize(row); - - assertThat(new String(result, StandardCharsets.UTF_8)).isEqualTo("42"); + assertThatThrownBy(() -> schema.serialize(row)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("must be of type byte[]"); } @Test @@ -76,8 +77,7 @@ void testSerializeOutOfBoundsIndex() { Row row = Row.of("field0", "field1"); assertThatThrownBy(() -> schema.serialize(row)) - .isInstanceOf( - IllegalArgumentException.class) // Changed from IndexOutOfBoundsException + .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Cannot access field 5 in Row with arity 2"); } diff --git a/flink-python/pyflink/common/serialization.py b/flink-python/pyflink/common/serialization.py index 0ab86b83d1786..6025bd76c9a24 100644 --- a/flink-python/pyflink/common/serialization.py +++ b/flink-python/pyflink/common/serialization.py @@ -36,6 +36,7 @@ class SerializationSchema(object): into a different serialized representation. Most data sinks (for example Apache Kafka) require the data to be handed to them in a specific format (for example as byte strings). """ + def __init__(self, j_serialization_schema=None): self._j_serialization_schema = j_serialization_schema @@ -49,6 +50,7 @@ class DeserializationSchema(object): In addition, the DeserializationSchema describes the produced type which lets Flink create internal serializers and structures to handle the type. 
""" + def __init__(self, j_deserialization_schema=None): self._j_deserialization_schema = j_deserialization_schema @@ -131,30 +133,36 @@ def get_row_type(self): class RowFieldExtractorSchema(SerializationSchema): """ - Serialization schema that extracts a specific field from a Row and serializes it as a - UTF-8 encoded byte array. - + Serialization schema that extracts a specific field from a Row and returns it as a + byte array. The field at the specified index MUST be of type bytes (byte array). This schema is particularly useful when using Flink with Kafka, where you may want to use a specific field as the message key for partition routing. + The field being extracted must already be a byte array. Users are responsible for + converting their data to bytes before passing it to this schema. Example usage with Kafka: >>> from pyflink.common.serialization import RowFieldExtractorSchema >>> from pyflink.datastream.connectors.kafka import KafkaSink, \ KafkaRecordSerializationSchema + >>> + >>> # User must convert data to bytes beforehand + >>> # For example: Row.of(b"key-bytes", b"value-bytes") + >>> >>> sink = KafkaSink.builder() \\ ... .set_bootstrap_servers("localhost:9092") \\ ... .set_record_serializer( ... KafkaRecordSerializationSchema.builder() ... .set_topic("my-topic") ... .set_key_serialization_schema(RowFieldExtractorSchema(0)) - # Field 0 as key + # Field 0 (must be bytes) as key ... .set_value_serialization_schema(RowFieldExtractorSchema(1)) - # Field 1 as value + # Field 1 (must be bytes) as value ... .build() ... ) \\ ... .build() :param field_index: The zero-based index of the field to extract from the Row. + The field at this index must be of type bytes. 
""" def __init__(self, field_index: int): diff --git a/flink-python/pyflink/common/tests/test_serialization_schemas.py b/flink-python/pyflink/common/tests/test_serialization_schemas.py index 43a4ab62a997e..7f10da4d8415b 100644 --- a/flink-python/pyflink/common/tests/test_serialization_schemas.py +++ b/flink-python/pyflink/common/tests/test_serialization_schemas.py @@ -54,29 +54,32 @@ def test_row_field_extractor_schema_creation(self): schema = RowFieldExtractorSchema(0) self.assertIsNotNone(schema._j_serialization_schema) - def test_serialize_string_field(self): - """Test serializing a string field from a Row.""" + def test_serialize_byte_array_field(self): + """Test serializing a byte array field from a Row.""" schema = RowFieldExtractorSchema(0) - # Create a Java Row using constructor and setField gateway = get_gateway() - j_row = gateway.jvm.org.apache.flink.types.Row(2) # 2 fields - j_row.setField(0, "test-value") - j_row.setField(1, 123) + j_row = gateway.jvm.org.apache.flink.types.Row(2) + + # Set byte array field + test_bytes = "test-value".encode('utf-8') + j_row.setField(0, test_bytes) + j_row.setField(1, "other-data".encode('utf-8')) + result = schema._j_serialization_schema.serialize(j_row) - expected = "test-value".encode('utf-8') - self.assertEqual(expected, bytes(result)) + self.assertEqual(test_bytes, bytes(result)) - def test_serialize_integer_field(self): - """Test serializing an integer field from a Row.""" + def test_serialize_second_field(self): + """Test serializing byte array from second field of a Row.""" schema = RowFieldExtractorSchema(1) - # Create a Java Row using constructor and setField gateway = get_gateway() j_row = gateway.jvm.org.apache.flink.types.Row(2) - j_row.setField(0, "key") - j_row.setField(1, 42) + + test_bytes = "field-1-value".encode('utf-8') + j_row.setField(0, "field-0".encode('utf-8')) + j_row.setField(1, test_bytes) + result = schema._j_serialization_schema.serialize(j_row) - expected = "42".encode('utf-8') - 
self.assertEqual(expected, bytes(result)) + self.assertEqual(test_bytes, bytes(result)) def test_serialize_null_row(self): """Test serializing null Row returns empty byte array.""" @@ -87,14 +90,28 @@ def test_serialize_null_row(self): def test_serialize_null_field(self): """Test serializing Row with null field returns empty byte array.""" schema = RowFieldExtractorSchema(0) - # Create a Java Row with null first field gateway = get_gateway() j_row = gateway.jvm.org.apache.flink.types.Row(2) j_row.setField(0, None) # null field - j_row.setField(1, "value") + j_row.setField(1, "value".encode('utf-8')) + result = schema._j_serialization_schema.serialize(j_row) self.assertEqual(0, len(result)) + def test_serialize_non_byte_array_raises_error(self): + """Test that non-byte-array field raises IllegalArgumentException.""" + schema = RowFieldExtractorSchema(0) + gateway = get_gateway() + j_row = gateway.jvm.org.apache.flink.types.Row(2) + + # set a string instead of byte array + j_row.setField(0, "not-bytes") + j_row.setField(1, "other") + + with self.assertRaises(Exception): + schema._j_serialization_schema.serialize(j_row) + # Should get IllegalArgumentException from Java + def test_negative_field_index_raises_error(self): """Test that negative field index raises ValueError.""" with self.assertRaises(ValueError) as context: @@ -112,34 +129,17 @@ def test_multiple_schemas_with_different_indices(self): schema0 = RowFieldExtractorSchema(0) schema1 = RowFieldExtractorSchema(1) schema2 = RowFieldExtractorSchema(2) + self.assertEqual(0, schema0._j_serialization_schema.getFieldIndex()) self.assertEqual(1, schema1._j_serialization_schema.getFieldIndex()) self.assertEqual(2, schema2._j_serialization_schema.getFieldIndex()) - def test_serialize_different_field_types(self): - """Test serializing different data types from Row fields.""" - gateway = get_gateway() - # Test with string - schema_str = RowFieldExtractorSchema(0) - j_row_str = gateway.jvm.org.apache.flink.types.Row(2) - 
j_row_str.setField(0, "hello") - j_row_str.setField(1, 100) - result_str = schema_str._j_serialization_schema.serialize(j_row_str) - self.assertEqual("hello".encode('utf-8'), bytes(result_str)) - # Test with integer - schema_int = RowFieldExtractorSchema(1) - j_row_int = gateway.jvm.org.apache.flink.types.Row(2) - j_row_int.setField(0, "key") - j_row_int.setField(1, 999) - result_int = schema_int._j_serialization_schema.serialize(j_row_int) - self.assertEqual("999".encode('utf-8'), bytes(result_int)) - def test_schema_equals(self): """Test that schemas with same field index are considered equal.""" schema1 = RowFieldExtractorSchema(1) schema2 = RowFieldExtractorSchema(1) schema3 = RowFieldExtractorSchema(2) - # Test via Java equals method + self.assertTrue(schema1._j_serialization_schema.equals(schema2._j_serialization_schema)) self.assertFalse(schema1._j_serialization_schema.equals(schema3._j_serialization_schema)) @@ -147,6 +147,7 @@ def test_schema_hash_code(self): """Test that schemas with same field index have same hash code.""" schema1 = RowFieldExtractorSchema(1) schema2 = RowFieldExtractorSchema(1) + hash1 = schema1._j_serialization_schema.hashCode() hash2 = schema2._j_serialization_schema.hashCode() self.assertEqual(hash1, hash2) diff --git a/flink-python/pyflink/examples/datastream/connectors/kafka_row_field_extractor_example.py b/flink-python/pyflink/examples/datastream/connectors/kafka_row_field_extractor_example.py new file mode 100644 index 0000000000000..53a74f3edbb67 --- /dev/null +++ b/flink-python/pyflink/examples/datastream/connectors/kafka_row_field_extractor_example.py @@ -0,0 +1,176 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +""" +Example demonstrating RowFieldExtractorSchema usage with Kafka. + +This example shows how to use RowFieldExtractorSchema to serialize specific +Row fields as Kafka message keys and values. The schema requires fields to +be byte arrays, giving you full control over serialization. + +Requirements: + - Kafka running on localhost:9092 + - Topic 'row-extractor-example' created + - Kafka connector JAR in classpath + +Usage: + python kafka_row_field_extractor_example.py +""" + +import json +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.datastream.connectors.kafka import ( + KafkaSink, + KafkaRecordSerializationSchema, + DeliveryGuarantee +) +from pyflink.common.serialization import RowFieldExtractorSchema +from pyflink.common.typeinfo import Types +from pyflink.common import Row + + +def serialize_to_json_bytes(data): + """Helper function to serialize data to JSON byte array.""" + return json.dumps(data).encode('utf-8') + + +def serialize_key_bytes(key): + """Helper function to serialize a key to byte array.""" + return key.encode('utf-8') + + +def create_sample_data(): + """ + Create sample e-commerce events. + Each event has: user_id (key) and event_data (value). 
+ """ + events = [ + { + "user_id": "user-001", + "event": {"type": "purchase", "item": "laptop", "price": 999.99} + }, + { + "user_id": "user-002", + "event": {"type": "view", "item": "phone", "timestamp": "2024-01-07T10:30:00"} + }, + { + "user_id": "user-001", + "event": {"type": "add_to_cart", "item": "mouse", "quantity": 2} + }, + { + "user_id": "user-003", + "event": {"type": "purchase", "item": "keyboard", "price": 79.99} + }, + ] + + # Convert to Rows with byte array fields + rows = [] + for event in events: + row = Row( + serialize_key_bytes(event["user_id"]), # Field 0: user_id as bytes (Kafka key) + serialize_to_json_bytes(event["event"]) # Field 1: event as JSON bytes (Kafka value) + ) + rows.append(row) + + return rows + + +def kafka_row_field_extractor_example(): + """ + Demonstrate RowFieldExtractorSchema with Kafka Sink. + + This example: + 1. Creates sample e-commerce events + 2. Converts user_id to bytes (for Kafka key) + 3. Converts event data to JSON bytes (for Kafka value) + 4. 
Uses RowFieldExtractorSchema to extract and send to Kafka + """ + + # Create execution environment + env = StreamExecutionEnvironment.get_execution_environment() + env.set_parallelism(1) + + # Generate sample data + data = create_sample_data() + + # Create DataStream with proper type information + # Row has 2 fields: both are byte arrays + ds = env.from_collection( + data, + type_info=Types.ROW([ + Types.PRIMITIVE_ARRAY(Types.BYTE()), # Field 0: user_id (key) + Types.PRIMITIVE_ARRAY(Types.BYTE()) # Field 1: event_data (value) + ]) + ) + + # Optional: Print what we're sending (for debugging) + def print_event(row): + user_id = row[0].decode('utf-8') + event_data = json.loads(row[1].decode('utf-8')) + print(f"Sending: User={user_id}, Event={event_data}") + return row + + ds = ds.map( + print_event, + output_type=Types.ROW([ + Types.PRIMITIVE_ARRAY(Types.BYTE()), + Types.PRIMITIVE_ARRAY(Types.BYTE()) + ]) + ) + + # Create Kafka Sink with RowFieldExtractorSchema + kafka_sink = KafkaSink.builder() \ + .set_bootstrap_servers("localhost:9092") \ + .set_record_serializer( + KafkaRecordSerializationSchema.builder() + .set_topic("row-extractor-example") + # Extract field 0 (user_id) as Kafka key for partitioning + .set_key_serialization_schema(RowFieldExtractorSchema(0)) + # Extract field 1 (event_data) as Kafka value + .set_value_serialization_schema(RowFieldExtractorSchema(1)) + .build() + ) \ + .set_delivery_guarantee(DeliveryGuarantee.AT_LEAST_ONCE) \ + .build() + + # Send to Kafka + ds.sink_to(kafka_sink) + + # Execute + env.execute("Kafka RowFieldExtractorSchema Example") + + +if __name__ == '__main__': + print("=" * 70) + print("Kafka RowFieldExtractorSchema Example") + print("=" * 70) + print("\nMake sure:") + print(" 1. Kafka is running on localhost:9092") + print(" 2. Topic 'row-extractor-example' exists") + print(" 3. 
Kafka connector JAR is in classpath") + print("\nTo create topic:") + print(" kafka-topics.sh --create --topic row-extractor-example \\") + print(" --bootstrap-server localhost:9092 --partitions 3") + print("\nTo consume messages:") + print(" kafka-console-consumer.sh --bootstrap-server localhost:9092 \\") + print(" --topic row-extractor-example --from-beginning \\") + print(" --property print.key=true --property key.separator=' => '") + print("\n" + "=" * 70 + "\n") + + # Run the example + kafka_row_field_extractor_example()