diff --git a/tests/v2/test_waku_enr.nim b/tests/v2/test_waku_enr.nim index 5da6aab400..86c0144418 100644 --- a/tests/v2/test_waku_enr.nim +++ b/tests/v2/test_waku_enr.nim @@ -1,10 +1,11 @@ {.used.} import - std/options, + std/[options, sequtils], stew/results, testutils/unittests import + ../../waku/v2/protocol/waku_message, ../../waku/v2/protocol/waku_enr, ./testlib/wakucore @@ -251,3 +252,172 @@ suite "Waku ENR - Multiaddresses": multiaddrs.contains(expectedAddr1) multiaddrs.contains(addr2) + +suite "Waku ENR - Relay static sharding": + + test "new relay shards field with single invalid index": + ## Given + let + shardCluster: uint16 = 22 + shardIndex: uint16 = 1024 + + ## When + expect Defect: + discard RelayShards.init(shardCluster, shardIndex) + + test "new relay shards field with single invalid index in list": + ## Given + let + shardCluster: uint16 = 22 + shardIndices: seq[uint16] = @[1u16, 1u16, 2u16, 3u16, 5u16, 8u16, 1024u16] + + ## When + expect Defect: + discard RelayShards.init(shardCluster, shardIndices) + + test "new relay shards field with single valid index": + ## Given + let + shardCluster: uint16 = 22 + shardIndex: uint16 = 1 + + let topic = NsPubsubTopic.staticSharding(shardCluster, shardIndex) + + ## When + let shards = RelayShards.init(shardCluster, shardIndex) + + ## Then + check: + shards.cluster == shardCluster + shards.indices == @[1u16] + + let topics = shards.topics.mapIt($it) + check: + topics == @[$topic] + + check: + shards.contains(shardCluster, shardIndex) + not shards.contains(shardCluster, 33u16) + not shards.contains(20u16, 33u16) + + shards.contains(topic) + shards.contains("/waku/2/rs/22/1") + + test "new relay shards field with repeated but valid indices": + ## Given + let + shardCluster: uint16 = 22 + shardIndices: seq[uint16] = @[1u16, 2u16, 2u16, 3u16, 3u16, 3u16] + + ## When + let shards = RelayShards.init(shardCluster, shardIndices) + + ## Then + check: + shards.cluster == shardCluster + shards.indices == @[1u16, 2u16, 3u16] + + test "cannot decode relay shards from record if not present": + ## Given + let + enrSeqNum = 1u64 + enrPrivKey = generatesecp256k1key() + + let record = EnrBuilder.init(enrPrivKey, enrSeqNum).build().tryGet() + + ## When + let typedRecord = record.toTyped() + require typedRecord.isOk() + + let fieldOpt = typedRecord.value.relaySharding + + ## Then + check fieldOpt.isNone() + + test "encode and decode record with relay shards field (EnrBuilder ext - indices list)": + ## Given + let + enrSeqNum = 1u64 + enrPrivKey = generatesecp256k1key() + + let + shardCluster: uint16 = 22 + shardIndices: seq[uint16] = @[1u16, 1u16, 2u16, 3u16, 5u16, 8u16] + + let shards = RelayShards.init(shardCluster, shardIndices) + + ## When + var builder = EnrBuilder.init(enrPrivKey, seqNum = enrSeqNum) + require builder.withWakuRelaySharding(shards).isOk() + + let recordRes = builder.build() + + ## Then + check recordRes.isOk() + let record = recordRes.tryGet() + + let typedRecord = record.toTyped() + require typedRecord.isOk() + + let shardsOpt = typedRecord.value.relaySharding + check: + shardsOpt.isSome() + shardsOpt.get() == shards + + test "encode and decode record with relay shards field (EnrBuilder ext - bit vector)": + ## Given + let + enrSeqNum = 1u64 + enrPrivKey = generatesecp256k1key() + + let shards = RelayShards.init(33, toSeq(0u16 ..< 64u16)) + + var builder = EnrBuilder.init(enrPrivKey, seqNum = enrSeqNum) + require builder.withWakuRelaySharding(shards).isOk() + + let recordRes = builder.build() + require recordRes.isOk() + + let record = recordRes.tryGet() + + ## When + let typedRecord = record.toTyped() + require typedRecord.isOk() + + let shardsOpt = typedRecord.value.relaySharding + + ## Then + check: + shardsOpt.isSome() + shardsOpt.get() == shards + + test "decode record with relay shards indices list and bit vector fields": + ## Given + let + enrSeqNum = 1u64 + enrPrivKey = generatesecp256k1key() + + let + shardsIndicesList = RelayShards.init(22, @[1u16, 1u16, 2u16, 3u16, 5u16, 8u16]) + shardsBitVector = RelayShards.init(33, @[13u16, 24u16, 37u16, 61u16, 98u16, 159u16]) + + + var builder = EnrBuilder.init(enrPrivKey, seqNum = enrSeqNum) + require builder.withWakuRelayShardingIndicesList(shardsIndicesList).isOk() + require builder.withWakuRelayShardingBitVector(shardsBitVector).isOk() + + let recordRes = builder.build() + require recordRes.isOk() + + let record = recordRes.tryGet() + + ## When + let typedRecord = record.toTyped() + require typedRecord.isOk() + + let shardsOpt = typedRecord.value.relaySharding + + ## Then + check: + shardsOpt.isSome() + shardsOpt.get() == shardsIndicesList diff --git a/waku/v2/protocol/waku_discv5.nim b/waku/v2/protocol/waku_discv5.nim index 5d42d2effe..6561f02363 100644 --- a/waku/v2/protocol/waku_discv5.nim +++ b/waku/v2/protocol/waku_discv5.nim @@ -6,9 +6,11 @@ else: import std/[strutils, options], stew/results, + stew/shims/net, chronos, chronicles, metrics, + libp2p/multiaddress, eth/keys, eth/p2p/discoveryv5/enr, eth/p2p/discoveryv5/node, diff --git a/waku/v2/protocol/waku_enr.nim b/waku/v2/protocol/waku_enr.nim index 1b78626d5d..ba97cd0664 100644 --- a/waku/v2/protocol/waku_enr.nim +++ b/waku/v2/protocol/waku_enr.nim @@ -1,197 +1,11 @@ -## Collection of utilities related to Waku's use of EIP-778 ENR -## Implemented according to the specified Waku v2 ENR usage -## More at https://rfc.vac.dev/spec/31/ - -when (NimMajor, NimMinor) < (1, 4): - {.push raises: [Defect].} -else: - {.push raises: [].} - import - std/[options, bitops, sequtils], - stew/[endians2, results], - stew/shims/net, - eth/keys, - libp2p/[multiaddress, multicodec], - libp2p/crypto/crypto -import - ../../common/enr - -export enr, crypto, multiaddress, net - -const - MultiaddrEnrField* = "multiaddrs" - CapabilitiesEnrField* = "waku2" - - -## Node capabilities - -type - ## 8-bit flag field to indicate Waku node capabilities. - ## Only the 4 LSBs are currently defined according - ## to RFC31 (https://rfc.vac.dev/spec/31/). - CapabilitiesBitfield* = distinct uint8 - - ## See: https://rfc.vac.dev/spec/31/#waku2-enr-key - ## each enum numbers maps to a bit (where 0 is the LSB) - Capabilities*{.pure.} = enum - Relay = 0, - Store = 1, - Filter = 2, - Lightpush = 3 - - -func init*(T: type CapabilitiesBitfield, lightpush, filter, store, relay: bool): T = - ## Creates an waku2 ENR flag bit field according to RFC 31 (https://rfc.vac.dev/spec/31/) - var bitfield: uint8 - if relay: bitfield.setBit(0) - if store: bitfield.setBit(1) - if filter: bitfield.setBit(2) - if lightpush: bitfield.setBit(3) - CapabilitiesBitfield(bitfield) - -func init*(T: type CapabilitiesBitfield, caps: varargs[Capabilities]): T = - ## Creates an waku2 ENR flag bit field according to RFC 31 (https://rfc.vac.dev/spec/31/) - var bitfield: uint8 - for cap in caps: - bitfield.setBit(ord(cap)) - CapabilitiesBitfield(bitfield) - -converter toCapabilitiesBitfield*(field: uint8): CapabilitiesBitfield = - CapabilitiesBitfield(field) - -proc supportsCapability*(bitfield: CapabilitiesBitfield, cap: Capabilities): bool = - testBit(bitfield.uint8, ord(cap)) - -func toCapabilities*(bitfield: CapabilitiesBitfield): seq[Capabilities] = - toSeq(Capabilities.low..Capabilities.high).filterIt(supportsCapability(bitfield, it)) - - -# ENR builder extension - -proc withWakuCapabilities*(builder: var EnrBuilder, caps: CapabilitiesBitfield) = - builder.addFieldPair(CapabilitiesEnrField, @[caps.uint8]) - -proc withWakuCapabilities*(builder: var EnrBuilder, caps: varargs[Capabilities]) = - withWakuCapabilities(builder, CapabilitiesBitfield.init(caps)) - -proc withWakuCapabilities*(builder: var EnrBuilder, caps: openArray[Capabilities]) = - withWakuCapabilities(builder, CapabilitiesBitfield.init(@caps)) - - -# ENR record accessors (e.g., Record, TypedRecord, etc.) - -func waku2*(record: TypedRecord): Option[CapabilitiesBitfield] = - let field = record.tryGet(CapabilitiesEnrField, seq[uint8]) - if field.isNone(): - return none(CapabilitiesBitfield) - - some(CapabilitiesBitfield(field.get()[0])) - -proc supportsCapability*(r: Record, cap: Capabilities): bool = - let recordRes = r.toTyped() - if recordRes.isErr(): - return false - - let bitfieldOpt = recordRes.value.waku2 - if bitfieldOpt.isNone(): - return false - - let bitfield = bitfieldOpt.get() - bitfield.supportsCapability(cap) - -proc getCapabilities*(r: Record): seq[Capabilities] = - let recordRes = r.toTyped() - if recordRes.isErr(): - return @[] - - let bitfieldOpt = recordRes.value.waku2 - if bitfieldOpt.isNone(): - return @[] - - let bitfield = bitfieldOpt.get() - bitfield.toCapabilities() - - -## Multiaddress - -func encodeMultiaddrs*(multiaddrs: seq[MultiAddress]): seq[byte] = - var buffer = newSeq[byte]() - for multiaddr in multiaddrs: - - let - raw = multiaddr.data.buffer # binary encoded multiaddr - size = raw.len.uint16.toBytes(Endianness.bigEndian) # size as Big Endian unsigned 16-bit integer - - buffer.add(concat(@size, raw)) - - buffer - -func readBytes(rawBytes: seq[byte], numBytes: int, pos: var int = 0): Result[seq[byte], cstring] = - ## Attempts to read `numBytes` from a sequence, from - ## position `pos`. Returns the requested slice or - ## an error if `rawBytes` boundary is exceeded. - ## - ## If successful, `pos` is advanced by `numBytes` - if rawBytes[pos..^1].len() < numBytes: - return err("insufficient bytes") - - let slicedSeq = rawBytes[pos.. MaxShardIndex: + raise newException(Defect, "invalid index") + + RelayShards(cluster: cluster, indices: @[index]) + +func init*(T: type RelayShards, cluster: uint16, indices: varargs[uint16]): T = + if toSeq(indices).anyIt(it > MaxShardIndex): + raise newException(Defect, "invalid index") + + let indicesSeq = deduplicate(@indices) + if indices.len < 1: + raise newException(Defect, "invalid index count") + + RelayShards(cluster: cluster, indices: indicesSeq) + +func init*(T: type RelayShards, cluster: uint16, indices: seq[uint16]): T = + if indices.anyIt(it > MaxShardIndex): + raise newException(Defect, "invalid index") + + let indicesSeq = deduplicate(indices) + if indices.len < 1: + raise newException(Defect, "invalid index count") + + RelayShards(cluster: cluster, indices: indicesSeq) + + +func contains*(rs: RelayShards, cluster, index: uint16): bool = + rs.cluster == cluster and rs.indices.contains(index) + +func contains*(rs: RelayShards, topic: NsPubsubTopic): bool = + if topic.kind != NsPubsubTopicKind.StaticSharding: + return false + + rs.contains(topic.cluster, topic.shard) + +func contains*(rs: RelayShards, topic: PubsubTopic|string): bool = + let parseRes = NsPubsubTopic.parse(topic) + if parseRes.isErr(): + return false + + rs.contains(parseRes.value) + + +# ENR builder extension + +func toIndicesList(rs: RelayShards): EnrResult[seq[byte]] = + if rs.indices.len > high(uint8).int: + return err("indices list too long") + + var res: seq[byte] + res.add(rs.cluster.toBytesBE()) + + res.add(rs.indices.len.uint8) + for index in rs.indices: + res.add(index.toBytesBE()) + + ok(res) + +func fromIndicesList(buf: seq[byte]): Result[RelayShards, string] = + if buf.len < 3: + return err("insufficient data: expected at least 3 bytes, got " & $buf.len & " bytes") + + let cluster = uint16.fromBytesBE(buf[0..1]) + let length = int(buf[2]) + + if buf.len != 3 + 2 * length: + return err("invalid data: `length` field is " & $length & " but " & $buf.len & " bytes were provided") + + var indices: seq[uint16] + for i in 0..= 64: + builder.withWakuRelayShardingBitVector(rs) + else: + builder.withWakuRelayShardingIndicesList(rs) + + +# ENR record accessors (e.g., Record, TypedRecord, etc.) + +proc relayShardingIndicesList*(record: TypedRecord): Option[RelayShards] = + let field = record.tryGet(ShardingIndicesListEnrField, seq[byte]) + if field.isNone(): + return none(RelayShards) + + let indexList = fromIndicesList(field.get()) + if indexList.isErr(): + debug "invalid sharding indices list", error = indexList.error + return none(RelayShards) + + some(indexList.value) + +proc relayShardingBitVector*(record: TypedRecord): Option[RelayShards] = + let field = record.tryGet(ShardingBitVectorEnrField, seq[byte]) + if field.isNone(): + return none(RelayShards) + + let bitVector = fromBitVector(field.get()) + if bitVector.isErr(): + debug "invalid sharding bit vector", error = bitVector.error + return none(RelayShards) + + some(bitVector.value) + +proc relaySharding*(record: TypedRecord): Option[RelayShards] = + let indexList = record.relayShardingIndicesList() + if indexList.isSome(): + return indexList + + record.relayShardingBitVector() + + +## Utils + +proc containsShard*(r: Record, cluster, index: uint16): bool = + if index > MaxShardIndex: + return false + + let recordRes = r.toTyped() + if recordRes.isErr(): + debug "invalid ENR record", error = recordRes.error + return false + + let rs = recordRes.value.relaySharding() + if rs.isNone(): + return false + + rs.get().contains(cluster, index) + +proc containsShard*(r: Record, topic: NsPubsubTopic): bool = + if topic.kind != NsPubsubTopicKind.StaticSharding: + return false + + containsShard(r, topic.cluster, topic.shard) + +func containsShard*(r: Record, topic: PubsubTopic|string): bool = + let parseRes = NsPubsubTopic.parse(topic) + if parseRes.isErr(): + debug "invalid static sharding topic", topic = topic, error = parseRes.error + return false + + containsShard(r, parseRes.value)