Skip to content

Commit

Permalink
Merge pull request #50 from tilakraj94/serializeenum
Browse files Browse the repository at this point in the history
Add Support for NATS Message Serialization/De-Serialization
  • Loading branch information
scottf authored Jan 29, 2025
2 parents f20aae7 + 8621e65 commit f4ece06
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
@Internal
public class NatsSubjectCheckpointSerializer implements SimpleVersionedSerializer<Collection<NatsSubjectSplit>> {

public static final int CURRENT_VERSION = 1;
public static final int CURRENT_VERSION = 2;

@Override
public int getVersion() {
Expand All @@ -35,22 +35,33 @@ public byte[] serialize(Collection<NatsSubjectSplit> splits) throws IOException
final DataOutputSerializer out = new DataOutputSerializer(startSize);
out.writeInt(splits.size());
for (NatsSubjectSplit split : splits) {
NatsSubjectSplitSerializer.serializeV1(out, split);
NatsSubjectSplitSerializer.serializeV2(out, split);
}
return out.getCopyOfBuffer();
}

@Override
public Collection<NatsSubjectSplit> deserialize(int version, byte[] serialized) throws IOException {
if (version != CURRENT_VERSION) {
throw new IOException("Unrecognized version: " + version);
}
final DataInputDeserializer in = new DataInputDeserializer(serialized);
final int num = in.readInt();
final ArrayList<NatsSubjectSplit> result = new ArrayList<>(num);

if (version > 2 ) {
throw new IOException("Unrecognized version or corrupt state: " + version);
}

if (version == 1) {
for (int x = 0; x < num; x++) {
result.add(NatsSubjectSplitSerializer.deserializeV1(in));
}

return result;
}

for (int x = 0; x < num; x++) {
result.add(NatsSubjectSplitSerializer.deserializeV1(in));
result.add(NatsSubjectSplitSerializer.deserializeV2(in));
}

return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

package io.synadia.flink.v0.source.split;

import io.nats.client.Message;
import io.nats.client.impl.Headers;
import io.nats.client.impl.NatsMessage;
import org.apache.flink.annotation.Internal;
import org.apache.flink.core.io.SimpleVersionedSerializer;
import org.apache.flink.core.memory.DataInputDeserializer;
Expand All @@ -11,6 +14,9 @@
import org.apache.flink.core.memory.DataOutputView;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;

/**
* Serializes and deserializes the {@link NatsSubjectSplit}. This class needs to handle
Expand All @@ -19,7 +25,7 @@
@Internal
public class NatsSubjectSplitSerializer implements SimpleVersionedSerializer<NatsSubjectSplit> {

public static final int CURRENT_VERSION = 1;
public static final int CURRENT_VERSION = 2;

@Override
public int getVersion() {
Expand All @@ -30,24 +36,167 @@ public int getVersion() {
public byte[] serialize(NatsSubjectSplit split) throws IOException {
final DataOutputSerializer out =
new DataOutputSerializer(split.splitId().length());
serializeV1(out, split);
serializeV2(out, split);
return out.getCopyOfBuffer();
}

public static void serializeV1(DataOutputView out, NatsSubjectSplit split) throws IOException {
out.writeUTF(split.splitId());
}

public static void serializeV2(DataOutputView out, NatsSubjectSplit split) throws IOException {
if (split.splitId() == null) {
throw new IOException("Split ID cannot be null");
}

out.writeUTF(split.splitId());
out.writeInt(split.getCurrentMessages().size());
for (Message message : split.getCurrentMessages()) {
serializeNatsMessage(out, message);
}
}

@Override
public NatsSubjectSplit deserialize(int version, byte[] serialized) throws IOException {
if (version != CURRENT_VERSION) {
throw new IOException("Unrecognized version: " + version);
}
final DataInputDeserializer in = new DataInputDeserializer(serialized);
return deserializeV1(in);

// check version
// handle older versions
if (version == 1) {
return deserializeV1(in);
} else if (version == 2) {
return deserializeV2(in);
} else {
throw new IOException("Unrecognized version or corrupted state: " + version);
}
}

static NatsSubjectSplit deserializeV1(DataInputView in) throws IOException {
return new NatsSubjectSplit(in.readUTF());
}

static NatsSubjectSplit deserializeV2(DataInputView in) throws IOException {
String subject = in.readUTF();
List<Message> messages = new ArrayList<>();
int numOfMessages = in.readInt();
for (int i = 0; i < numOfMessages; i++) {
messages.add(deserializeNatsMessage(in));
}

return new NatsSubjectSplit(subject, messages);
}

// Deserialize individual NATS Message
private static Message deserializeNatsMessage(DataInputView in) throws IOException {
// Deserialize subject
String subject = in.readBoolean() ? in.readUTF() : null;

Headers headers = in.readBoolean()? new Headers() : null;
if (headers != null) {
deserializeHeaders(in, headers);
}

// Deserialize replyTo
String replyTo = in.readBoolean() ? in.readUTF() : null;

// Deserialize data
int dataLength = in.readInt();
byte[] data = null;
if (dataLength != -1) {
data = new byte[dataLength];
in.readFully(data);
}

NatsMessage.Builder builder = NatsMessage.builder();
builder.subject(subject);

if (data != null) {
builder.data(data);
}

if (replyTo != null) {
builder.replyTo(replyTo);
}

if (headers != null) {
builder.headers(headers);
}

return builder.build();
}

// Serialize individual NATS Message
private static void serializeNatsMessage(DataOutputView out, Message message) throws IOException {
// Serialize subject
if (message.getSubject() == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeUTF(message.getSubject());
}

// serialize headers
if (message.getHeaders() == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
serializeHeaders(out, message);
}

// Serialize replyTo
if (message.getReplyTo() == null) {
out.writeBoolean(false);
} else {
out.writeBoolean(true);
out.writeUTF(message.getReplyTo());
}

// Serialize data (payload)
if (message.getData() == null) {
out.writeInt(-1);
} else {
out.writeInt(message.getData().length);
out.write(message.getData());
}
}

private static void serializeHeaders(DataOutputView out, Message message) throws IOException {
Set<String> keys = message.getHeaders().keySet();
out.writeInt(keys.size());

// serialize headers
for (String key : keys) {
out.writeUTF(key);

// serialize header value
List<String> values = message.getHeaders().get(key);
out.writeInt(values.size());

for (String value : values) {
out.writeUTF(value);
}
}
}

private static void deserializeHeaders(DataInputView in, Headers headers) throws IOException {
// Deserialize headers
int numOfKeys = in.readInt();

for (int i = 0; i < numOfKeys; i++) {
String key = in.readUTF();
List<String> values = new ArrayList<>();

int numOfValues = in.readInt();
for (int j = 0; j < numOfValues; j++) {

String value = in.readUTF();
values.add(value);
}

// add back
headers.add(key, values);
}
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

package io.synadia.io.synadia.flink.v0;

import io.nats.client.Message;
import io.nats.client.impl.Headers;
import io.nats.client.impl.NatsMessage;
import io.synadia.flink.v0.enumerator.NatsSourceEnumeratorStateSerializer;
import io.synadia.flink.v0.enumerator.NatsSubjectSourceEnumeratorState;
Expand All @@ -16,34 +18,119 @@
import io.synadia.io.synadia.flink.WordCount;
import io.synadia.io.synadia.flink.WordCountDeserializer;
import io.synadia.io.synadia.flink.WordCountSerializer;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.*;
import java.util.stream.Stream;

import static io.synadia.flink.v0.source.split.NatsSubjectSplitSerializer.CURRENT_VERSION;
import static org.junit.jupiter.api.Assertions.*;

public class SerializersDeserializersTests extends TestBase {

@Test
public void testSourceSideSerialization() throws Exception {
@DisplayName("Test Serialization for subject")
@ParameterizedTest(name = "{2} | Subjects: {1}")
@MethodSource("provideSplitTestData")
void testSourceSideSerialization(int version, List<NatsSubjectSplit> splits, String description) throws Exception {
NatsSubjectSplitSerializer splitSerializer = new NatsSubjectSplitSerializer();
NatsSubjectCheckpointSerializer checkpointSerializer = new NatsSubjectCheckpointSerializer();

for (NatsSubjectSplit split : splits) {
byte[] serialized = splitSerializer.serialize(split);
NatsSubjectSplit deserializedSplit = splitSerializer.deserialize(version, serialized);
assertEquals(split.splitId(), deserializedSplit.splitId());

if (version == CURRENT_VERSION) {
for (int i = 0; i < split.getCurrentMessages().size(); i++) {
Message expectedMessage = split.getCurrentMessages().get(i);
Message actualMessage = deserializedSplit.getCurrentMessages().get(i);

assertEquals(expectedMessage.getSubject(), actualMessage.getSubject());
assertArrayEquals(expectedMessage.getData(), actualMessage.getData());

if (expectedMessage.getReplyTo() == null) {
assertNull(actualMessage.getReplyTo());
} else {
assertEquals(expectedMessage.getReplyTo(), actualMessage.getReplyTo());
}

if (expectedMessage.getHeaders() == null) {
assertNull(actualMessage.getHeaders());
} else {
assertEquals(expectedMessage.getHeaders().get("key1"), actualMessage.getHeaders().get("key1"));
assertEquals(expectedMessage.getHeaders().get("key2"), actualMessage.getHeaders().get("key2"));
}
}
}
}

byte[] serializedCheckpoint = checkpointSerializer.serialize(splits);
Collection<NatsSubjectSplit> deserializedCheckpoint = checkpointSerializer.deserialize(version, serializedCheckpoint);

assertEquals(splits, deserializedCheckpoint, "Checkpoint serialization failed");
}

private static Stream<Arguments> provideSplitTestData() {
return Stream.of(
// Standard cases with headers and replyTo
Arguments.of(CURRENT_VERSION, generateSplits(List.of("three", "four", "five"), false, false),
String.format("Version %d | Three splits", CURRENT_VERSION)),
Arguments.of(CURRENT_VERSION, generateSplits(List.of("six", "seven", "eight", "nine"), false, false),
String.format("Version %d | Four splits", CURRENT_VERSION)),
Arguments.of(CURRENT_VERSION, generateSplits(List.of("ten"), false, false),
String.format("Version %d | Single split", CURRENT_VERSION)),

// Cases without headers
Arguments.of(CURRENT_VERSION, generateSplits(List.of("three", "four", "five"), true, false),
String.format("Version %d | Three splits without headers", CURRENT_VERSION)),
Arguments.of(CURRENT_VERSION, generateSplits(List.of("six", "seven", "eight", "nine"), true, false),
String.format("Version %d | Four splits without headers", CURRENT_VERSION)),
Arguments.of(CURRENT_VERSION, generateSplits(List.of("ten"), true, false),
String.format("Version %d | Single split without headers", CURRENT_VERSION)),

// Cases without replyTo
Arguments.of(CURRENT_VERSION, generateSplits(List.of("three", "four", "five"), false, true),
String.format("Version %d | Three splits without replyTo", CURRENT_VERSION)),
Arguments.of(CURRENT_VERSION, generateSplits(List.of("six", "seven", "eight", "nine"), false, true),
String.format("Version %d | Four splits without replyTo", CURRENT_VERSION)),
Arguments.of(CURRENT_VERSION, generateSplits(List.of("ten"), false, true),
String.format("Version %d | Single split without replyTo", CURRENT_VERSION))
);
}

private static List<NatsSubjectSplit> generateSplits(List<String> subjects, boolean headersNull, boolean replyToNull) {
List<NatsSubjectSplit> splits = new ArrayList<>();
String[] subjects = new String[]{"one", "two", "three", "four", "five"};
for (String subject : subjects) {
NatsSubjectSplit split = new NatsSubjectSplit(subject);
byte[] serialized = splitSerializer.serialize(split);
NatsSubjectSplit de = splitSerializer.deserialize(NatsSubjectSplitSerializer.CURRENT_VERSION, serialized);
assertEquals(subject, de.splitId());
splits.add(split);
List<Message> messages = generateMessages(subject, headersNull, replyToNull);
splits.add(new NatsSubjectSplit(subject, messages));
}
return splits;
}

private static List<Message> generateMessages(String subject, boolean headersNull, boolean replyToNull) {
List<Message> messages = new ArrayList<>();

NatsMessage.Builder builder = new NatsMessage.Builder();
if (!headersNull) {
Headers headers = new Headers();
headers.put("key1", "value1");
headers.put("key2", "value2");

builder.headers(headers);
}

if (!replyToNull) {
builder.replyTo("_inbox." + subject);
}

byte[] serialized = checkpointSerializer.serialize(splits);
Collection<NatsSubjectSplit> deserialized = checkpointSerializer.deserialize(NatsSubjectSplitSerializer.CURRENT_VERSION, serialized);
assertEquals(splits, deserialized);
messages.add(builder.subject(subject).data(subject.getBytes()).build());
return messages;
}

@Test
Expand Down

0 comments on commit f4ece06

Please sign in to comment.