Skip to content

Commit

Permalink
Convert messageId from String to byte[] (see spec PR: libp2p/specs#285)
Browse files Browse the repository at this point in the history
  • Loading branch information
Nashatyrev committed Nov 9, 2020
1 parent cbb56d5 commit 7d1d08e
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 49 deletions.
8 changes: 3 additions & 5 deletions src/main/kotlin/io/libp2p/pubsub/AbstractRouter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import io.libp2p.etc.types.copy
import io.libp2p.etc.types.createLRUMap
import io.libp2p.etc.types.forward
import io.libp2p.etc.types.lazyVarInit
import io.libp2p.etc.types.toHex
import io.libp2p.etc.util.P2PServiceSemiDuplex
import io.netty.channel.ChannelHandler
import io.netty.handler.codec.protobuf.ProtobufDecoder
Expand All @@ -28,8 +27,7 @@ import java.util.function.BiConsumer
import java.util.function.Consumer

class DefaultPubsubMessage(override val protobufMessage: Rpc.Message) : PubsubMessage {
override val messageId: MessageId =
protobufMessage.from.toByteArray().toHex() + protobufMessage.seqno.toByteArray().toHex()
override val messageId: MessageId = protobufMessage.from.toByteArray() + protobufMessage.seqno.toByteArray()

override fun equals(other: Any?) = protobufMessage == (other as? PubsubMessage)?.protobufMessage
override fun hashCode() = protobufMessage.hashCode()
Expand All @@ -48,8 +46,8 @@ abstract class AbstractRouter : P2PServiceSemiDuplex(), PubsubRouter, PubsubRout

override var messageFactory: PubsubMessageFactory = { DefaultPubsubMessage(it) }
var maxSeenMessagesLimit = 10000
protected open val seenMessages by lazy {
createLRUMap<PubsubMessage, Optional<ValidationResult>>(maxSeenMessagesLimit)
protected open val seenMessages: MutableMap<PubsubMessage, Optional<ValidationResult>> by lazy {
createLRUMap(maxSeenMessagesLimit)
}

private val peerTopics = MultiSet<PeerHandler, String>()
Expand Down
2 changes: 1 addition & 1 deletion src/main/kotlin/io/libp2p/pubsub/PubsubRouter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import java.util.concurrent.CompletableFuture
import java.util.concurrent.ScheduledExecutorService

typealias Topic = String
typealias MessageId = String
typealias MessageId = ByteArray
typealias PubsubMessageFactory = (Rpc.Message) -> PubsubMessage

interface PubsubMessage {
Expand Down
10 changes: 5 additions & 5 deletions src/main/kotlin/io/libp2p/pubsub/gossip/GossipRouter.kt
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ open class GossipRouter @JvmOverloads constructor(
TimeUnit.MILLISECONDS
)
}
override val seenMessages by lazy {
override val seenMessages: MutableMap<PubsubMessage, Optional<ValidationResult>> by lazy {
val t: Ticker = object : Ticker() {
// Ticker operates with nanos and handles overflows correctly
override fun read() = curTimeMillis() * 1_000_000
Expand Down Expand Up @@ -224,7 +224,7 @@ open class GossipRouter @JvmOverloads constructor(
return
}

val iWant = msg.messageIDsList - seenMessages.keys.map { it.messageId }
val iWant = msg.messageIDsList.map { it.toByteArray()} - seenMessages.keys.map { it.messageId }
val maxToAsk = min(iWant.size, params.maxIHaveLength - asked.get())
asked.addAndGet(maxToAsk)
iWant(peer, iWant.shuffled(random).subList(0, maxToAsk))
Expand All @@ -234,7 +234,7 @@ open class GossipRouter @JvmOverloads constructor(
val peerScore = score.score(peer)
if (peerScore < score.params.gossipThreshold) return
msg.messageIDsList
.mapNotNull { mCache.getMessageForPeer(peer.peerId, it) }
.mapNotNull { mCache.getMessageForPeer(peer.peerId, it.toByteArray()) }
.filter { it.sentCount < params.gossipRetransmission }
.map { it.msg }
.forEach { submitPublishMessage(peer, it) }
Expand Down Expand Up @@ -493,7 +493,7 @@ open class GossipRouter @JvmOverloads constructor(
peer,
Rpc.RPC.newBuilder().setControl(
Rpc.ControlMessage.newBuilder().addIwant(
Rpc.ControlIWant.newBuilder().addAllMessageIDs(messageIds)
Rpc.ControlIWant.newBuilder().addAllMessageIDs(messageIds.map { it.toProtobuf() })
)
).build()
)
Expand All @@ -505,7 +505,7 @@ open class GossipRouter @JvmOverloads constructor(
peer,
Rpc.RPC.newBuilder().setControl(
Rpc.ControlMessage.newBuilder().addIhave(
Rpc.ControlIHave.newBuilder().addAllMessageIDs(messageIds)
Rpc.ControlIHave.newBuilder().addAllMessageIDs(messageIds.map { it.toProtobuf() })
)
).build()
)
Expand Down
4 changes: 2 additions & 2 deletions src/main/proto/rpc.proto
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ message ControlMessage {

message ControlIHave {
optional string topicID = 1;
repeated string messageIDs = 2;
repeated bytes messageIDs = 2;
}

message ControlIWant {
repeated string messageIDs = 1;
repeated bytes messageIDs = 1;
}

message ControlGraft {
Expand Down
149 changes: 126 additions & 23 deletions src/test/java/io/libp2p/pubsub/GossipApiTest.java
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
package io.libp2p.pubsub;

import com.google.protobuf.ByteString;
import io.libp2p.core.pubsub.ValidationResult;
import io.libp2p.etc.util.P2PService;
import io.libp2p.pubsub.gossip.GossipParams;
import io.libp2p.pubsub.gossip.GossipParamsKt;
import io.libp2p.pubsub.gossip.GossipRouter;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import pubsub.pb.Rpc;

import java.nio.charset.StandardCharsets;
import java.util.AbstractMap;
import java.util.HashSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import static io.libp2p.tools.StubsKt.peerHandlerStub;
import static org.assertj.core.api.Assertions.assertThat;

public class GossipApiTest {

@Test
Expand All @@ -38,16 +51,68 @@ public void createGossipTest() {
// });
//
// Assertions.assertEquals("Hey!", router.(Rpc.Message.getDefaultInstance()));
Assertions.assertEquals(10, router.getParams().getD());
Assertions.assertEquals(20, router.getParams().getDHigh());
Assertions.assertEquals(GossipParamsKt.defaultDScore(10), router.getParams().getDScore());
assertThat(router.getParams().getD()).isEqualTo(10);
assertThat(router.getParams().getDHigh()).isEqualTo(20);
assertThat(router.getParams().getDScore()).isEqualTo(GossipParamsKt.defaultDScore(10));
}

@Test
public void testFastMessageId() throws Exception {
GossipRouter router = new GossipRouter() {
private final MessageMap<Optional<ValidationResult>> seenMessages = new MessageMap<>();

@NotNull
@Override
protected Map<PubsubMessage, Optional<ValidationResult>> getSeenMessages() {
return seenMessages;
}
};
List<CustomMessage> createdMessages = new ArrayList<>();
router.setMessageFactory(m -> {
CustomMessage message = new CustomMessage(m);
createdMessages.add(message);
return message;
});
router.subscribe("topic");

BlockingQueue<PubsubMessage> messages = new LinkedBlockingQueue<>();
router.initHandler(m -> {
messages.add(m);
return CompletableFuture.completedFuture(ValidationResult.Valid);
});

P2PService.PeerHandler peerHandler = peerHandlerStub(router);

router.onInbound(peerHandler, newMessage("Hello-1"));
CustomMessage message1 = (CustomMessage) messages.poll(1, TimeUnit.SECONDS);

assertThat(message1).isNotNull();
assertThat(message1.canonicalId).isNotNull();
assertThat(createdMessages.size()).isEqualTo(1);
createdMessages.clear();

router.onInbound(peerHandler, newMessage("Hello-1"));
CustomMessage message2 = (CustomMessage) messages.poll(100, TimeUnit.MILLISECONDS);

assertThat(message2).isNull();
assertThat(createdMessages.size()).isEqualTo(1);
// assert that 'slow' canonicalId was not calculated and the message was filtered as seen by fastId
assertThat(createdMessages.get(0).canonicalId).isNull();
createdMessages.clear();
}

Rpc.RPC newMessage(String msg) {
return Rpc.RPC.newBuilder().addPublish(
Rpc.Message.newBuilder()
.addTopicIDs("topic")
.setData(ByteString.copyFrom("Hello-1", StandardCharsets.US_ASCII))
).build();
}

class CustomMessage implements PubsubMessage {
static class CustomMessage implements PubsubMessage {
final Rpc.Message message;
Function<Rpc.Message, Object> fastIdCalculator;
Function<Rpc.Message, String> canonicalIdCalculator;
String canonicalId = null;
Function<Rpc.Message, byte[]> canonicalIdCalculator = m -> ("canon-" + m.getData().toString()).getBytes();
byte[] canonicalId = null;

public CustomMessage(Rpc.Message message) {
this.message = message;
Expand All @@ -65,7 +130,7 @@ public Object fastMessageId() {

@NotNull
@Override
public String getMessageId() {
public byte[] getMessageId() {
if (canonicalId == null) {
canonicalId = canonicalIdCalculator.apply(getProtobufMessage());
}
Expand All @@ -74,36 +139,74 @@ public String getMessageId() {

@Override
public boolean equals(Object o) {
throw new UnsupportedOperationException();
// if (this == o) return true;
// if (o == null || getClass() != o.getClass()) return false;
// CustomMessage that = (CustomMessage) o;
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
CustomMessage that = (CustomMessage) o;
return message.equals(that.message);
}

@Override
public int hashCode() {
throw new UnsupportedOperationException();
// return Objects.hash(message, fastIdCalculator, canonicalIdCalculator, canonicalId);
return message.hashCode();
}
}

class MessageMap<V> extends AbstractMap<CustomMessage, V> {
Map<Object, String> fastToCanonicalId;
Map<String, Entry<CustomMessage, V>> canonicalIdToMsg;
static class MessageMap<V> extends AbstractMap<PubsubMessage, V> {
Map<Object, String> fastToCanonicalId = new HashMap<>();
Map<String, Entry<PubsubMessage, V>> canonicalIdToMsg = new HashMap<>();

@NotNull
@Override
public Set<Entry<CustomMessage, V>> entrySet() {
return new HashSet<>(canonicalIdToMsg.values());
public Set<Entry<PubsubMessage, V>> entrySet() {
return Set.copyOf(canonicalIdToMsg.values());
}

@Override
public V get(Object key) {
if (key instanceof CustomMessage) {
return get((CustomMessage) key);
} else {
throw new IllegalArgumentException();
}
}

public V get(CustomMessage key) {
String canonicalId = fastToCanonicalId.get(key.fastMessageId());
Entry<PubsubMessage, V> entry = canonicalIdToMsg.get(canonicalId != null ? canonicalId : key.getMessageId());
return entry == null ? null : entry.getValue();
}

@Override
public V put(PubsubMessage key, V value) {
if (key instanceof CustomMessage) {
return put((CustomMessage) key, value);
} else {
throw new IllegalArgumentException();
}
}
public V put(CustomMessage key, V value) {
fastToCanonicalId.put(key.fastMessageId(), key.getMessageId());
Entry<CustomMessage, V> oldVal = canonicalIdToMsg.put(key.getMessageId(), new SimpleEntry<>(key, value));
fastToCanonicalId.put(key.fastMessageId(), new String(key.getMessageId()));
Entry<PubsubMessage, V> oldVal =
canonicalIdToMsg.put(new String(key.getMessageId()), new SimpleEntry<>(key, value));
return oldVal == null ? null : oldVal.getValue();
}

@Override
public V remove(Object key) {
if (key instanceof CustomMessage) {
return remove((CustomMessage) key);
} else {
throw new IllegalArgumentException();
}
}

public V remove(CustomMessage key) {
String canonicalId = fastToCanonicalId.remove(key.fastMessageId());
Entry<PubsubMessage, V> entry =
canonicalIdToMsg.remove(canonicalId != null ? canonicalId : key.getMessageId());
return entry == null ? null : entry.getValue();
}

public boolean contains(CustomMessage msg) {
if (fastToCanonicalId.containsKey(msg.fastMessageId())) {
return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.libp2p.pubsub.gossip

import io.libp2p.etc.types.seconds
import io.libp2p.etc.types.toProtobuf
import io.libp2p.pubsub.DeterministicFuzz
import io.libp2p.pubsub.MockRouter
import io.libp2p.pubsub.PubsubRouterTest
Expand Down Expand Up @@ -123,7 +124,7 @@ class GossipPubsubRouterTest : PubsubRouterTest({
val msg1 = Rpc.RPC.newBuilder()
.setControl(
Rpc.ControlMessage.newBuilder().addIhave(
Rpc.ControlIHave.newBuilder().addMessageIDs("messageId")
Rpc.ControlIHave.newBuilder().addMessageIDs("messageId".toByteArray().toProtobuf())
)
).build()

Expand Down
Loading

0 comments on commit 7d1d08e

Please sign in to comment.