-
Notifications
You must be signed in to change notification settings - Fork 54
/
sharding.nim
234 lines (164 loc) · 6.28 KB
/
sharding.nim
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
# Enable strict exception tracking for the rest of this module: by default
# every routine below is declared to raise nothing. Nim < 1.4 still tracks
# `Defect` in `raises` lists, so it has to be allowed explicitly there.
when (NimMajor, NimMinor) < (1, 4):
  {.push raises: [Defect].}
else:
  {.push raises: [].}
import
std/[options, bitops, sequtils],
stew/[endians2, results],
stew/shims/net,
chronicles,
eth/keys,
libp2p/[multiaddress, multicodec],
libp2p/crypto/crypto
import
../../common/enr,
../waku_core
logScope:
  topics = "waku enr sharding"

# Highest shard index accepted within a cluster; indices run 0..1023,
# matching the 128-byte (1024-bit) vector used by the `rsv` encoding.
const MaxShardIndex: uint16 = 1023

const
  ShardingIndicesListEnrField* = "rs"   ## ENR key for the indices-list encoding.
  ShardingBitVectorEnrField* = "rsv"    ## ENR key for the bit-vector encoding.

type
  RelayShards* = object
    ## Relay shards of a single shard cluster that a node is part of.
    cluster: uint16        ## Shard cluster index.
    indices: seq[uint16]   ## Shard indices within `cluster`.
func cluster*(rs: RelayShards): uint16 =
  ## Shard cluster index that every shard in `rs` belongs to.
  result = rs.cluster
func indices*(rs: RelayShards): seq[uint16] =
  ## Shard indices (within `rs.cluster`) held by `rs`, in stored order.
  result = rs.indices
func topics*(rs: RelayShards): seq[NsPubsubTopic] =
  ## Expands `rs` into one static-sharding pubsub topic per shard index,
  ## in the same order as `indices`.
  result = newSeqOfCap[NsPubsubTopic](rs.indices.len)
  for shard in rs.indices:
    result.add(NsPubsubTopic.staticSharding(rs.cluster, shard))
func init*(T: type RelayShards, cluster, index: uint16): T =
  ## Builds a `RelayShards` holding exactly one shard, `index`, of `cluster`.
  ##
  ## Raises `Defect` when `index` is greater than `MaxShardIndex`.
  if index <= MaxShardIndex:
    return RelayShards(cluster: cluster, indices: @[index])
  raise newException(Defect, "invalid index")
func init*(T: type RelayShards, cluster: uint16, indices: varargs[uint16]): T =
  ## Builds a `RelayShards` for `cluster` from one or more shard `indices`.
  ## Duplicates are dropped, keeping the first occurrence of each index.
  ##
  ## Raises `Defect` if any index is greater than `MaxShardIndex` (checked
  ## first), or if no index was supplied.
  for index in indices:
    if index > MaxShardIndex:
      raise newException(Defect, "invalid index")
  if indices.len == 0:
    raise newException(Defect, "invalid index count")
  RelayShards(cluster: cluster, indices: deduplicate(@indices))
func init*(T: type RelayShards, cluster: uint16, indices: seq[uint16]): T =
  ## Builds a `RelayShards` for `cluster` from the `indices` sequence.
  ## Duplicates are dropped, keeping the first occurrence of each index.
  ##
  ## Raises `Defect` if any index is greater than `MaxShardIndex` (checked
  ## first), or if `indices` is empty.
  let outOfRange = indices.anyIt(it > MaxShardIndex)
  if outOfRange:
    raise newException(Defect, "invalid index")
  if indices.len == 0:
    raise newException(Defect, "invalid index count")
  result = RelayShards(cluster: cluster, indices: indices.deduplicate())
func contains*(rs: RelayShards, cluster, index: uint16): bool =
  ## True when `rs` belongs to `cluster` and lists shard `index`.
  if rs.cluster != cluster:
    return false
  index in rs.indices
func contains*(rs: RelayShards, topic: NsPubsubTopic): bool =
  ## True when `topic` is a static-sharding topic whose cluster and shard
  ## are both covered by `rs`; any other topic kind never matches.
  # Short-circuit: cluster/shard are only read for static-sharding topics.
  topic.kind == NsPubsubTopicKind.StaticSharding and
    rs.contains(topic.cluster, topic.shard)
func contains*(rs: RelayShards, topic: PubsubTopic|string): bool =
  ## Parses `topic` as a namespaced pubsub topic and reports whether `rs`
  ## covers it. Topics that fail to parse yield `false`.
  let parsed = NsPubsubTopic.parse(topic)
  if parsed.isOk():
    rs.contains(parsed.value)
  else:
    false
# ENR builder extension
func toIndicesList(rs: RelayShards): EnrResult[seq[byte]] =
  ## Encodes `rs` as the indices-list ENR value: the cluster as 2 big-endian
  ## bytes, a 1-byte index count, then each index as 2 big-endian bytes.
  ## Fails when there are more indices than the 1-byte count can express.
  let count = rs.indices.len
  if count > high(uint8).int:
    return err("indices list too long")
  var encoded = newSeqOfCap[byte](3 + 2 * count)
  encoded.add(rs.cluster.toBytesBE())
  encoded.add(uint8(count))
  for index in rs.indices:
    encoded.add(index.toBytesBE())
  ok(encoded)
func fromIndicesList(buf: seq[byte]): Result[RelayShards, string] =
  ## Decodes an indices-list ENR value (the format produced by
  ## `toIndicesList`) into a `RelayShards`.
  ##
  ## Layout: cluster as 2 big-endian bytes, a 1-byte index count `n`, then
  ## `n` shard indices of 2 big-endian bytes each. The buffer length must
  ## match `3 + 2 * n` exactly.
  ##
  ## Returns an error for short or inconsistently sized buffers, and for any
  ## index greater than `MaxShardIndex`.
  if buf.len < 3:
    return err("insufficient data: expected at least 3 bytes, got " & $buf.len & " bytes")
  let cluster = uint16.fromBytesBE(buf[0..1])
  let length = int(buf[2])
  if buf.len != 3 + 2 * length:
    return err("invalid data: `length` field is " & $length & " but " & $buf.len & " bytes were provided")
  var indices = newSeqOfCap[uint16](length)
  for i in 0..<length:
    let index = uint16.fromBytesBE(buf[3 + 2*i ..< 5 + 2*i])
    # The value comes from a remote ENR. Without this check it would bypass
    # the range validation `RelayShards.init` enforces, and `toBitVector`
    # would index past its 128-byte vector if the value were re-encoded.
    if index > MaxShardIndex:
      return err("invalid data: shard index " & $index & " exceeds " & $MaxShardIndex)
    indices.add(index)
  ok(RelayShards(cluster: cluster, indices: indices))
func toBitVector(rs: RelayShards): seq[byte] =
  ## Encodes `rs` as the bit-vector ENR value: the shard cluster index as
  ## 2 bytes in network byte order, followed by a 128-byte (1024-bit) vector
  ## marking which shards of that cluster the node is part of.
  ##
  ## Bit layout: shard `i` is bit `i mod 8` (bit 0 = least significant) of
  ## vector byte `i div 8`. The first vector byte therefore holds shards 0-7
  ## and the last holds shards 1016-1023. Equivalently, the vector is a
  ## 1024-bit little-endian integer in which shard `i` is bit `i`.
  # NOTE(review): an earlier comment here said "right-most bit is shard 0,
  # left-most is shard 1023". That is only true under the little-endian
  # reading above. Read as one big-endian bit string, the left-most bit is
  # shard 7 and the right-most is shard 1016. Peers must agree on this
  # layout, so confirm against other implementations before changing the
  # encoding itself.
  var res = newSeqOfCap[byte](130)
  res.add(rs.cluster.toBytesBE())
  var vec = newSeq[byte](128)
  for index in rs.indices:
    # An index above 1023 would address past `vec` and raise IndexDefect.
    vec[index div 8].setBit(index mod 8)
  res.add(vec)
  res
func fromBitVector(buf: seq[byte]): EnrResult[RelayShards] =
  ## Decodes a bit-vector ENR value: the cluster as 2 big-endian bytes,
  ## followed by 128 bytes in which shard `k` is bit `k mod 8` of vector
  ## byte `k div 8`. Indices are returned in ascending order.
  ## Fails unless the buffer is exactly 130 bytes long.
  if buf.len != 130:
    return err("invalid data: expected 130 bytes")
  var shards: seq[uint16]
  for shard in 0u16 ..< 1024u16:
    let vecByte = buf[2 + int(shard div 8)]
    if vecByte.testBit(shard mod 8):
      shards.add(shard)
  ok(RelayShards(cluster: uint16.fromBytesBE(buf[0..1]), indices: shards))
func withWakuRelayShardingIndicesList*(builder: var EnrBuilder, rs: RelayShards): EnrResult[void] =
  ## Stores `rs` in `builder` under the indices-list ENR field.
  ## If `rs` cannot be encoded, returns that error and leaves `builder`
  ## untouched.
  let encoded = rs.toIndicesList()
  if encoded.isErr():
    return err(encoded.error)
  builder.addFieldPair(ShardingIndicesListEnrField, encoded.value)
  ok()
func withWakuRelayShardingBitVector*(builder: var EnrBuilder, rs: RelayShards): EnrResult[void] =
  ## Stores `rs` in `builder` under the bit-vector ENR field.
  ## The bit-vector encoder has no error path, so this returns `ok()`.
  builder.addFieldPair(ShardingBitVectorEnrField, rs.toBitVector())
  ok()
func withWakuRelaySharding*(builder: var EnrBuilder, rs: RelayShards): EnrResult[void] =
  ## Stores `rs` in `builder` using the more compact encoding: the indices
  ## list for fewer than 64 shards, the bit vector otherwise.
  # The indices list costs 3 + 2*n bytes and the bit vector a fixed 130,
  # so the bit vector is smaller from n = 64 onwards.
  if rs.indices.len < 64:
    return builder.withWakuRelayShardingIndicesList(rs)
  builder.withWakuRelayShardingBitVector(rs)
# ENR record accessors (e.g., Record, TypedRecord, etc.)
proc relayShardingIndicesList*(record: TypedRecord): Option[RelayShards] =
  ## Reads and decodes the indices-list sharding field from `record`.
  ## Returns `none` when the field is missing or cannot be decoded; decode
  ## failures are logged at debug level.
  let field = record.tryGet(ShardingIndicesListEnrField, seq[byte])
  if field.isSome():
    let decoded = fromIndicesList(field.get())
    if decoded.isOk():
      return some(decoded.value)
    debug "invalid sharding indices list", error = decoded.error
  none(RelayShards)
proc relayShardingBitVector*(record: TypedRecord): Option[RelayShards] =
  ## Reads and decodes the bit-vector sharding field from `record`.
  ## Returns `none` when the field is missing or cannot be decoded; decode
  ## failures are logged at debug level.
  let field = record.tryGet(ShardingBitVectorEnrField, seq[byte])
  if field.isNone():
    return none(RelayShards)
  let decoded = fromBitVector(field.get())
  if decoded.isOk():
    some(decoded.value)
  else:
    debug "invalid sharding bit vector", error = decoded.error
    none(RelayShards)
proc relaySharding*(record: TypedRecord): Option[RelayShards] =
  ## Relay shards advertised by `record`. The indices-list field is used
  ## when it decodes; otherwise the bit-vector field is tried.
  let fromList = record.relayShardingIndicesList()
  if fromList.isNone():
    return record.relayShardingBitVector()
  fromList
## Utils
proc containsShard*(r: Record, cluster, index: uint16): bool =
  ## True if the ENR `r` advertises shard `index` of `cluster`.
  ## Returns `false` for out-of-range indices, for records that fail to
  ## convert to a typed record (logged at debug level), and for records
  ## without sharding information.
  if index > MaxShardIndex:
    return false
  let typed = r.toTyped()
  if typed.isErr():
    debug "invalid ENR record", error = typed.error
    return false
  let shards = typed.value.relaySharding()
  shards.isSome() and shards.get().contains(cluster, index)
proc containsShard*(r: Record, topic: NsPubsubTopic): bool =
  ## True if `topic` is a static-sharding topic whose cluster and shard are
  ## advertised by `r`. Any other topic kind yields `false`.
  if topic.kind == NsPubsubTopicKind.StaticSharding:
    r.containsShard(topic.cluster, topic.shard)
  else:
    false
func containsShard*(r: Record, topic: PubsubTopic|string): bool =
  ## Parses `topic` as a namespaced pubsub topic and reports whether `r`
  ## advertises its shard. Topics that fail to parse are logged at debug
  ## level and yield `false`.
  let parsed = NsPubsubTopic.parse(topic)
  if parsed.isOk():
    return r.containsShard(parsed.value)
  debug "invalid static sharding topic", topic = topic, error = parsed.error
  false