@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.common.serialization;

import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.types.Row;

import javax.annotation.Nullable;

import static org.apache.flink.util.Preconditions.checkArgument;

/**
* Serialization schema that extracts a specific field from a {@link Row} and returns it as a byte
* array.
*
* <p>The field is required to be of type {@code byte[]}. This schema is particularly useful when
* using Flink with Kafka, where you may want to use one Row field as the message key and another as
* the value and perform the conversion to bytes explicitly in user code.
*
* <p>Example usage with Kafka:
*
* <pre>{@code
 * KafkaSink<Row> sink = KafkaSink.<Row>builder()
 *         .setBootstrapServers(bootstrapServers)
 *         .setRecordSerializer(
 *                 KafkaRecordSerializationSchema.builder()
 *                         .setTopic("my-topic")
 *                         .setKeySerializationSchema(new RowFieldExtractorSchema(0)) // field 0 as key
 *                         .setValueSerializationSchema(new RowFieldExtractorSchema(1)) // field 1 as value
 *                         .build())
 *         .build();
* }</pre>
*/
@PublicEvolving
public class RowFieldExtractorSchema implements SerializationSchema<Row> {
@dianfu (Contributor) — Jan 5, 2026

What about moving this class to module flink-python? This class should be more useful for Python users.

@Nflrijal (Contributor, Author) — Jan 7, 2026

I thought of keeping it in flink-core for a few reasons:

  1. General-purpose functionality: While Python users will benefit from this, Java/Scala users can also use it for Kafka or other sinks that need field extraction.

  2. Consistency: Other serialization schemas like SimpleStringSchema and ByteArraySchema are in flink-core, so this fits the existing pattern.

  3. Python wrapper already exists: The PyFlink wrapper in flink-python/pyflink/common/serialization.py already makes it easily accessible to Python users, and they can import it naturally via from pyflink.common.serialization import RowFieldExtractorSchema.

  4. Separation of concerns: flink-core handles the Java serialization logic, flink-python handles the Python-Java bridging.

However, I'm happy to move it to flink-python if you think that better serves the community! What do you think, @dianfu?


    private static final long serialVersionUID = 1L;

    /** The index of the field to extract from the Row. */
    private final int fieldIndex;

    /**
     * Creates a new RowFieldExtractorSchema that extracts the field at the specified index.
     *
     * @param fieldIndex the zero-based index of the field to extract
     * @throws IllegalArgumentException if fieldIndex is negative
     */
    public RowFieldExtractorSchema(int fieldIndex) {
        checkArgument(fieldIndex >= 0, "Field index must be non-negative, got: %s", fieldIndex);
        this.fieldIndex = fieldIndex;
    }

    /**
     * Gets the field index being extracted.
     *
     * @return the field index
     */
    @VisibleForTesting
    public int getFieldIndex() {
        return fieldIndex;
    }

    @Override
    public byte[] serialize(@Nullable Row element) {
        if (element == null) {
            return new byte[0];
        }

        checkArgument(
                fieldIndex < element.getArity(),
                "Cannot access field %s in Row with arity %s",
                fieldIndex,
                element.getArity());

        Object field = element.getField(fieldIndex);
        if (field == null) {
            return new byte[0];
        }

        if (!(field instanceof byte[])) {
            throw new IllegalArgumentException(
                    String.format(
                            "Field at index %s must be of type byte[], but was %s",
                            fieldIndex, field.getClass().getName()));
        }

        return (byte[]) field;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        RowFieldExtractorSchema that = (RowFieldExtractorSchema) o;
        return fieldIndex == that.fieldIndex;
    }

    @Override
    public int hashCode() {
        return fieldIndex;
    }
}
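
The Javadoc above notes that the conversion to bytes happens explicitly in user code before a Row reaches this schema. A minimal sketch of that upstream step, assuming a hypothetical ClickEvent domain type (its name and getters are illustrative, not part of this PR):

import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.types.Row;

import java.nio.charset.StandardCharsets;

/** Hypothetical upstream step: turn a domain object into a two-field Row of byte[]. */
public class ClickEventToRow implements MapFunction<ClickEvent, Row> {

    @Override
    public Row map(ClickEvent event) {
        // Field 0 becomes the Kafka message key and field 1 the value, lining up with
        // RowFieldExtractorSchema(0) and RowFieldExtractorSchema(1) in the Javadoc example.
        byte[] key = event.getUserId().getBytes(StandardCharsets.UTF_8);
        byte[] value = event.getPayloadJson().getBytes(StandardCharsets.UTF_8);
        return Row.of(key, value);
    }
}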
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.common.serialization;

import org.apache.flink.types.Row;
import org.apache.flink.util.InstantiationUtil;

import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.nio.charset.StandardCharsets;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/** Tests for {@link RowFieldExtractorSchema}. */
class RowFieldExtractorSchemaTest {

    @Test
    void testSerializeByteArrayField() {
        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(0);
        byte[] value = "test-value".getBytes(StandardCharsets.UTF_8);
        Row row = Row.of(value, 123);

        byte[] result = schema.serialize(row);

        assertThat(result).isEqualTo(value);
    }

    @Test
    void testSerializeNonByteArrayFieldThrowsException() {
        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(1);
        Row row = Row.of("key", 42); // field 1 is Integer, not byte[]

        assertThatThrownBy(() -> schema.serialize(row))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessageContaining("must be of type byte[]");
    }

    @Test
    void testSerializeNullRow() {
        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(0);

        byte[] result = schema.serialize(null);

        assertThat(result).isEmpty();
    }

    @Test
    void testSerializeNullField() {
        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(0);
        Row row = Row.of(null, "value");

        byte[] result = schema.serialize(row);

        assertThat(result).isEmpty();
    }

    @Test
    void testSerializeOutOfBoundsIndex() {
        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(5);
        Row row = Row.of("field0", "field1");

        assertThatThrownBy(() -> schema.serialize(row))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessageContaining("Cannot access field 5 in Row with arity 2");
    }

    @Test
    void testNegativeFieldIndexThrowsException() {
        assertThatThrownBy(() -> new RowFieldExtractorSchema(-1))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessageContaining("Field index must be non-negative");
    }

    @Test
    void testSerializability() throws IOException, ClassNotFoundException {
        RowFieldExtractorSchema schema = new RowFieldExtractorSchema(3);

        RowFieldExtractorSchema deserialized =
                InstantiationUtil.deserializeObject(
                        InstantiationUtil.serializeObject(schema), getClass().getClassLoader());

        assertThat(deserialized.getFieldIndex()).isEqualTo(3);
    }

    @Test
    void testEquals() {
        RowFieldExtractorSchema schema1 = new RowFieldExtractorSchema(1);
        RowFieldExtractorSchema schema2 = new RowFieldExtractorSchema(1);
        RowFieldExtractorSchema schema3 = new RowFieldExtractorSchema(2);

        assertThat(schema1).isEqualTo(schema2);
        assertThat(schema1).isNotEqualTo(schema3);
    }

    @Test
    void testHashCode() {
        RowFieldExtractorSchema schema1 = new RowFieldExtractorSchema(1);
        RowFieldExtractorSchema schema2 = new RowFieldExtractorSchema(1);

        assertThat(schema1.hashCode()).isEqualTo(schema2.hashCode());
    }
}
55 changes: 54 additions & 1 deletion flink-python/pyflink/common/serialization.py
@@ -25,7 +25,8 @@
    'SimpleStringSchema',
    'ByteArraySchema',
    'Encoder',
    'BulkWriterFactory',
    'RowFieldExtractorSchema',
]


@@ -35,6 +36,7 @@ class SerializationSchema(object):
    into a different serialized representation. Most data sinks (for example Apache Kafka) require
    the data to be handed to them in a specific format (for example as byte strings).
    """

    def __init__(self, j_serialization_schema=None):
        self._j_serialization_schema = j_serialization_schema
@@ -48,6 +50,7 @@ class DeserializationSchema(object):
    In addition, the DeserializationSchema describes the produced type which lets Flink create
    internal serializers and structures to handle the type.
    """

    def __init__(self, j_deserialization_schema=None):
        self._j_deserialization_schema = j_deserialization_schema
@@ -126,3 +129,53 @@ def __init__(self, j_bulk_writer_factory, row_type):

    def get_row_type(self):
        return self._row_type


class RowFieldExtractorSchema(SerializationSchema):
    """
    Serialization schema that extracts a specific field from a Row and returns it as a
    byte array. The field at the specified index MUST be of type bytes (byte array).

    This schema is particularly useful when using Flink with Kafka, where you may want to
    use a specific field as the message key for partition routing. The field being
    extracted must already be a byte array; users are responsible for converting their
    data to bytes before passing it to this schema.

    Example usage with Kafka:

    >>> from pyflink.common.serialization import RowFieldExtractorSchema
    >>> from pyflink.datastream.connectors.kafka import KafkaSink, \\
    ...     KafkaRecordSerializationSchema
    >>>
    >>> # User must convert data to bytes beforehand,
    >>> # for example: Row.of(b"key-bytes", b"value-bytes")
    >>>
    >>> sink = KafkaSink.builder() \\
    ...     .set_bootstrap_servers("localhost:9092") \\
    ...     .set_record_serializer(
    ...         KafkaRecordSerializationSchema.builder()
    ...             .set_topic("my-topic")
    ...             .set_key_serialization_schema(
    ...                 RowFieldExtractorSchema(0))  # field 0 (must be bytes) as key
    ...             .set_value_serialization_schema(
    ...                 RowFieldExtractorSchema(1))  # field 1 (must be bytes) as value
    ...             .build()) \\
    ...     .build()

    :param field_index: The zero-based index of the field to extract from the Row.
                        The field at this index must be of type bytes.
    """

    def __init__(self, field_index: int):
        """
        Creates a new RowFieldExtractorSchema that extracts the field at the specified index.

        :param field_index: The zero-based index of the field to extract (must be non-negative).
        :raises ValueError: If field_index is negative.
        """
        if field_index < 0:
            raise ValueError(f"Field index must be non-negative, got: {field_index}")
        gateway = get_gateway()
        j_row_field_extractor_schema = gateway.jvm.org.apache.flink.api.common.serialization \
            .RowFieldExtractorSchema(field_index)
        super(RowFieldExtractorSchema, self).__init__(
            j_serialization_schema=j_row_field_extractor_schema)
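
For the PyFlink side, a similar minimal sketch (not part of this PR; the environment setup, sample records, and type declarations are illustrative) of producing Rows of bytes that this schema can consume:

from pyflink.common import Row, Types
from pyflink.datastream import StreamExecutionEnvironment

env = StreamExecutionEnvironment.get_execution_environment()

# A Row of two byte[] fields, matching what RowFieldExtractorSchema expects.
row_type = Types.ROW([Types.PRIMITIVE_ARRAY(Types.BYTE()),
                      Types.PRIMITIVE_ARRAY(Types.BYTE())])

records = env.from_collection([("user-1", "clicked"), ("user-2", "scrolled")])

# Explicit conversion to bytes, as the docstring above requires.
rows = records.map(
    lambda kv: Row(kv[0].encode("utf-8"), kv[1].encode("utf-8")),
    output_type=row_type)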