diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index cadbbec31b..bef04f1e63 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -117,6 +117,11 @@ eclair { broadcast-interval = 60 seconds // see BOLT #7 init-timeout = 5 minutes + sync { + request-node-announcements = true // if true we will ask for node announcements when we receive channel ids that we don't know + encoding-type = zlib // encoding for short_channel_ids and timestamps in query channel sync messages; other possible value is "uncompressed" + } + // the values below will be used to perform route searching path-finding { max-route-length = 6 // max route length for the 'first pass', if none is found then a second pass is made with no limit diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 45a8d40b00..5b13c67c64 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -31,7 +31,7 @@ import fr.acinq.eclair.crypto.KeyManager import fr.acinq.eclair.db._ import fr.acinq.eclair.router.RouterConf import fr.acinq.eclair.tor.Socks5ProxyParams -import fr.acinq.eclair.wire.{Color, NodeAddress} +import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress} import scodec.bits.ByteVector import scala.collection.JavaConversions._ @@ -77,7 +77,6 @@ case class NodeParams(keyManager: KeyManager, routerConf: RouterConf, socksProxy_opt: Option[Socks5ProxyParams], maxPaymentAttempts: Int) { - val privateKey = keyManager.nodeKey.privateKey val nodeId = keyManager.nodeId } @@ -187,6 +186,16 @@ object NodeParams { claimMainBlockTarget = config.getInt("on-chain-fees.target-blocks.claim-main") ) + val feeBase = MilliSatoshi(config.getInt("fee-base-msat")) + // fee base is in msat but is encoded on 32 bits and not 64 in the BOLTs, which is why it has + // to be below 0x100000000 msat which is about 42 mbtc + require(feeBase <= MilliSatoshi(0xFFFFFFFFL), "fee-base-msat must be below 42 mbtc") + + val routerSyncEncodingType = config.getString("router.sync.encoding-type") match { + case "uncompressed" => EncodingType.UNCOMPRESSED + case "zlib" => EncodingType.COMPRESSED_ZLIB + } + NodeParams( keyManager = keyManager, alias = nodeAlias, @@ -210,7 +219,7 @@ object NodeParams { toRemoteDelayBlocks = CltvExpiryDelta(config.getInt("to-remote-delay-blocks")), maxToLocalDelayBlocks = CltvExpiryDelta(config.getInt("max-to-local-delay-blocks")), minDepthBlocks = config.getInt("mindepth-blocks"), - feeBase = MilliSatoshi(config.getInt("fee-base-msat")), + feeBase = feeBase, feeProportionalMillionth = config.getInt("fee-proportional-millionths"), reserveToFundingRatio = config.getDouble("reserve-to-funding-ratio"), maxReserveToFundingRatio = config.getDouble("max-reserve-to-funding-ratio"), @@ -231,6 +240,8 @@ object NodeParams { channelExcludeDuration = FiniteDuration(config.getDuration("router.channel-exclude-duration").getSeconds, TimeUnit.SECONDS), routerBroadcastInterval = FiniteDuration(config.getDuration("router.broadcast-interval").getSeconds, TimeUnit.SECONDS), randomizeRouteSelection = config.getBoolean("router.randomize-route-selection"), + requestNodeAnnouncements = config.getBoolean("router.sync.request-node-announcements"), + encodingType = routerSyncEncodingType, searchMaxRouteLength = config.getInt("router.path-finding.max-route-length"), searchMaxCltv = 
CltvExpiryDelta(config.getInt("router.path-finding.max-cltv")), searchMaxFeeBase = Satoshi(config.getLong("router.path-finding.fee-threshold-sat")), diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index 3f672d2e71..695afc06b9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -25,8 +25,7 @@ import akka.event.Logging.MDC import akka.util.Timeout import com.google.common.net.HostAndPort import fr.acinq.bitcoin.Crypto.PublicKey -import fr.acinq.bitcoin.{ByteVector32, DeterministicWallet, Protocol, Satoshi} -import fr.acinq.eclair +import fr.acinq.bitcoin.{Block, ByteVector32, DeterministicWallet, Protocol, Satoshi} import fr.acinq.eclair.blockchain.EclairWallet import fr.acinq.eclair.channel._ import fr.acinq.eclair.crypto.TransportHandler @@ -145,15 +144,23 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A if (remoteHasInitialRoutingSync) { if (remoteHasChannelRangeQueriesOptional || remoteHasChannelRangeQueriesMandatory) { // if they support channel queries we do nothing, they will send us their filters - log.info("{} has set initial routing sync and support channel range queries, we do nothing (they will send us a query)", remoteNodeId) + log.info("peer has set initial routing sync and supports channel range queries, we do nothing (they will send us a query)") } else { // "old" nodes, do as before + log.info("peer requested a full routing table dump") router ! GetRoutingState } } if (remoteHasChannelRangeQueriesOptional || remoteHasChannelRangeQueriesMandatory) { // if they support channel queries, always ask for their filter - router ! SendChannelQuery(remoteNodeId, d.transport) + // TODO: for now we do not activate extended queries on mainnet + val flags_opt = nodeParams.chainHash match { + case Block.RegtestGenesisBlock.hash | Block.TestnetGenesisBlock.hash => + Some(QueryChannelRangeTlv.QueryFlags(QueryChannelRangeTlv.QueryFlags.WANT_ALL)) + case _ => None + } + log.info(s"sending sync channel range query with flags_opt=$flags_opt") + router ! SendChannelQuery(remoteNodeId, d.transport, flags_opt = flags_opt) } // let's bring existing/requested channels online diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/ChannelRangeQueries.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/ChannelRangeQueries.scala deleted file mode 100644 index b2fb131ac2..0000000000 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/ChannelRangeQueries.scala +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright 2019 ACINQ SAS - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package fr.acinq.eclair.router - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} -import java.nio.ByteOrder -import java.util.zip.{DeflaterOutputStream, GZIPInputStream, GZIPOutputStream, InflaterInputStream} - -import fr.acinq.bitcoin.Protocol -import fr.acinq.eclair.ShortChannelId -import scodec.bits.ByteVector - -import scala.annotation.tailrec -import scala.collection.SortedSet - -object ChannelRangeQueries { - - val UNCOMPRESSED_FORMAT = 0.toByte - val ZLIB_FORMAT = 1.toByte - - case class ShortChannelIdsBlock(val firstBlock: Long, val numBlocks: Long, shortChannelIds: ByteVector) - - /** - * Compressed a sequence of *sorted* short channel id. - * - * @param shortChannelIds must be sorted beforehand - * @return a sequence of short channel id blocks - */ - def encodeShortChannelIds(firstBlockIn: Long, numBlocksIn: Long, shortChannelIds: SortedSet[ShortChannelId], format: Byte, useGzip: Boolean = false): List[ShortChannelIdsBlock] = { - if (shortChannelIds.isEmpty) { - // special case: reply with an "empty" block - List(ShortChannelIdsBlock(firstBlockIn, numBlocksIn, ByteVector(0))) - } else { - // LN messages must fit in 65 Kb so we split ids into groups to make sure that the output message will be valid - val count = format match { - case UNCOMPRESSED_FORMAT => 7000 - case ZLIB_FORMAT => 12000 // TODO: do something less simplistic... - } - shortChannelIds.grouped(count).map(ids => { - val (firstBlock, numBlocks) = if (ids.isEmpty) (firstBlockIn, numBlocksIn) else { - val firstBlock: Long = ShortChannelId.coordinates(ids.head).blockHeight - val numBlocks: Long = ShortChannelId.coordinates(ids.last).blockHeight - firstBlock + 1 - (firstBlock, numBlocks) - } - val encoded = encodeShortChannelIdsSingle(ids, format, useGzip) - ShortChannelIdsBlock(firstBlock, numBlocks, encoded) - }).toList - } - } - - def encodeShortChannelIdsSingle(shortChannelIds: Iterable[ShortChannelId], format: Byte, useGzip: Boolean): ByteVector = { - val bos = new ByteArrayOutputStream() - bos.write(format) - format match { - case UNCOMPRESSED_FORMAT => - shortChannelIds.foreach(id => Protocol.writeUInt64(id.toLong, bos, ByteOrder.BIG_ENDIAN)) - case ZLIB_FORMAT => - val output = if (useGzip) new GZIPOutputStream(bos) else new DeflaterOutputStream(bos) - shortChannelIds.foreach(id => Protocol.writeUInt64(id.toLong, output, ByteOrder.BIG_ENDIAN)) - output.finish() - } - ByteVector.view(bos.toByteArray) - } - - /** - * Decompress a zipped sequence of sorted short channel ids. 
- * - * @param data - * @return a sorted set of short channel ids - */ - def decodeShortChannelIds(data: ByteVector): (Byte, SortedSet[ShortChannelId], Boolean) = { - val format = data.head - if (data.tail.isEmpty) (format, SortedSet.empty[ShortChannelId], false) else { - val buffer = new Array[Byte](8) - - // read 8 bytes from input - // zipped input stream often returns less bytes than what you want to read - @tailrec - def read8(input: InputStream, offset: Int = 0): Int = input.read(buffer, offset, 8 - offset) match { - case len if len <= 0 => len - case 8 => 8 - case len if offset + len == 8 => 8 - case len => read8(input, offset + len) - } - - // read until there's nothing left - @tailrec - def loop(input: InputStream, acc: SortedSet[ShortChannelId]): SortedSet[ShortChannelId] = { - val check = read8(input) - if (check <= 0) acc else loop(input, acc + ShortChannelId(Protocol.uint64(buffer, ByteOrder.BIG_ENDIAN))) - } - - def readAll(useGzip: Boolean) = { - val bis = new ByteArrayInputStream(data.tail.toArray) - val input = format match { - case UNCOMPRESSED_FORMAT => bis - case ZLIB_FORMAT if useGzip => new GZIPInputStream(bis) - case ZLIB_FORMAT => new InflaterInputStream(bis) - } - try { - (format, loop(input, SortedSet.empty[ShortChannelId]), useGzip) - } - finally { - input.close() - } - } - - try { - readAll(useGzip = false) - } - catch { - case _: Throwable if format == ZLIB_FORMAT => readAll(useGzip = true) - } - } - } -} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala index f9983d76fe..dd4eee4e01 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/router/Router.scala @@ -16,9 +16,12 @@ package fr.acinq.eclair.router +import java.util.zip.CRC32C + import akka.Done import akka.actor.{ActorRef, Props, Status} import akka.event.Logging.MDC +import akka.event.LoggingAdapter import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.Script.{pay2wsh, write} import fr.acinq.bitcoin.{ByteVector32, ByteVector64, Satoshi} @@ -33,7 +36,9 @@ import fr.acinq.eclair.router.Graph.GraphStructure.{DirectedGraph, GraphEdge} import fr.acinq.eclair.router.Graph.{RichWeight, WeightRatios} import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ +import shapeless.HNil +import scala.annotation.tailrec import scala.collection.immutable.{SortedMap, TreeMap} import scala.collection.{SortedSet, mutable} import scala.compat.Platform @@ -46,6 +51,8 @@ import scala.util.{Random, Try} case class RouterConf(randomizeRouteSelection: Boolean, channelExcludeDuration: FiniteDuration, routerBroadcastInterval: FiniteDuration, + requestNodeAnnouncements: Boolean, + encodingType: EncodingType, searchMaxFeeBase: Satoshi, searchMaxFeePct: Double, searchMaxRouteLength: Int, @@ -72,13 +79,15 @@ case class RouteResponse(hops: Seq[Hop], ignoreNodes: Set[PublicKey], ignoreChan } case class ExcludeChannel(desc: ChannelDesc) // this is used when we get a TemporaryChannelFailure, to give time for the channel to recover (note that exclusions are directed) case class LiftChannelExclusion(desc: ChannelDesc) -case class SendChannelQuery(remoteNodeId: PublicKey, to: ActorRef) +case class SendChannelQuery(remoteNodeId: PublicKey, to: ActorRef, flags_opt: Option[QueryChannelRangeTlv]) case object GetRoutingState case class RoutingState(channels: Iterable[ChannelAnnouncement], updates: Iterable[ChannelUpdate], nodes: Iterable[NodeAnnouncement]) case 
class Stash(updates: Map[ChannelUpdate, Set[ActorRef]], nodes: Map[NodeAnnouncement, Set[ActorRef]]) case class Rebroadcast(channels: Map[ChannelAnnouncement, Set[ActorRef]], updates: Map[ChannelUpdate, Set[ActorRef]], nodes: Map[NodeAnnouncement, Set[ActorRef]]) -case class Sync(missing: SortedSet[ShortChannelId], totalMissingCount: Int) +case class ShortChannelIdAndFlag(shortChannelId: ShortChannelId, flag: Long) + +case class Sync(pending: List[RoutingMessage], total: Int) case class Data(nodes: Map[PublicKey, NodeAnnouncement], channels: SortedMap[ShortChannelId, ChannelAnnouncement], @@ -106,12 +115,15 @@ case object TickPruneStaleChannels * Created by PM on 24/05/2016. */ -class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Promise[Done]] = None) extends FSMDiagnosticActorLogging[State, Data] { +class Router(val nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Promise[Done]] = None) extends FSMDiagnosticActorLogging[State, Data] { import Router._ import ExecutionContext.Implicits.global + // we pass these to helpers classes so that they have the logging context + implicit def implicitLog: LoggingAdapter = log + context.system.eventStream.subscribe(self, classOf[LocalChannelUpdate]) context.system.eventStream.subscribe(self, classOf[LocalChannelDown]) @@ -421,11 +433,11 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom .recover { case t => sender ! Status.Failure(t) } stay - case Event(SendChannelQuery(remoteNodeId, remote), d) => + case Event(SendChannelQuery(remoteNodeId, remote, flags_opt), d) => // ask for everything // we currently send only one query_channel_range message per peer, when we just (re)connected to it, so we don't // have to worry about sending a new query_channel_range when another query is still in progress - val query = QueryChannelRange(nodeParams.chainHash, firstBlockNum = 0, numberOfBlocks = Int.MaxValue) + val query = QueryChannelRange(nodeParams.chainHash, firstBlockNum = 0L, numberOfBlocks = Int.MaxValue.toLong, TlvStream(flags_opt.toList)) log.info("sending query_channel_range={}", query) remote ! query @@ -495,79 +507,125 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom log.debug("received node announcement for nodeId={}", n.nodeId) stay using handle(n, sender, d) - case Event(PeerRoutingMessage(transport, _, routingMessage@QueryChannelRange(chainHash, firstBlockNum, numberOfBlocks)), d) => + case Event(PeerRoutingMessage(transport, _, routingMessage@QueryChannelRange(chainHash, firstBlockNum, numberOfBlocks, extendedQueryFlags_opt)), d) => sender ! TransportHandler.ReadAck(routingMessage) - log.info("received query_channel_range={}", routingMessage) - // sort channel ids and keep the ones which are in [firstBlockNum, firstBlockNum + numberOfBlocks] + log.info("received query_channel_range with firstBlockNum={} numberOfBlocks={} extendedQueryFlags_opt={}", firstBlockNum, numberOfBlocks, extendedQueryFlags_opt) + // keep channel ids that are in [firstBlockNum, firstBlockNum + numberOfBlocks] val shortChannelIds: SortedSet[ShortChannelId] = d.channels.keySet.filter(keep(firstBlockNum, numberOfBlocks, _, d.channels, d.updates)) - // TODO: we don't compress to be compatible with old mobile apps, switch to ZLIB ASAP - // Careful: when we remove GZIP support, eclair-wallet 0.3.0 will stop working i.e. 
channels to ACINQ nodes will not - // work anymore - val blocks = ChannelRangeQueries.encodeShortChannelIds(firstBlockNum, numberOfBlocks, shortChannelIds, ChannelRangeQueries.UNCOMPRESSED_FORMAT) - log.info("sending back reply_channel_range with {} items for range=({}, {})", shortChannelIds.size, firstBlockNum, numberOfBlocks) - // there could be several reply_channel_range messages for a single query - val replies = blocks.map(block => ReplyChannelRange(chainHash, block.firstBlock, block.numBlocks, 1, block.shortChannelIds)) - replies.foreach(reply => transport ! reply) + log.info("replying with {} items for range=({}, {})", shortChannelIds.size, firstBlockNum, numberOfBlocks) + split(shortChannelIds) + .foreach(chunk => { + val (timestamps, checksums) = routingMessage.queryFlags_opt match { + case Some(extension) if extension.wantChecksums | extension.wantTimestamps => + // we always compute timestamps and checksums even if we don't need both, overhead is negligible + val (timestamps, checksums) = chunk.shortChannelIds.map(getChannelDigestInfo(d.channels, d.updates)).unzip + val encodedTimestamps = if (extension.wantTimestamps) Some(ReplyChannelRangeTlv.EncodedTimestamps(nodeParams.routerConf.encodingType, timestamps)) else None + val encodedChecksums = if (extension.wantChecksums) Some(ReplyChannelRangeTlv.EncodedChecksums(checksums)) else None + (encodedTimestamps, encodedChecksums) + case _ => (None, None) + } + val reply = ReplyChannelRange(chainHash, chunk.firstBlock, chunk.numBlocks, + complete = 1, + shortChannelIds = EncodedShortChannelIds(nodeParams.routerConf.encodingType, chunk.shortChannelIds), + timestamps = timestamps, + checksums = checksums) + transport ! reply + }) stay - case Event(PeerRoutingMessage(transport, remoteNodeId, routingMessage@ReplyChannelRange(chainHash, firstBlockNum, numberOfBlocks, _, data)), d) => + case Event(PeerRoutingMessage(transport, remoteNodeId, routingMessage@ReplyChannelRange(chainHash, _, _, _, shortChannelIds, _)), d) => sender ! TransportHandler.ReadAck(routingMessage) - val (format, theirShortChannelIds, useGzip) = ChannelRangeQueries.decodeShortChannelIds(data) - val ourShortChannelIds: SortedSet[ShortChannelId] = d.channels.keySet.filter(keep(firstBlockNum, numberOfBlocks, _, d.channels, d.updates)) - val missing: SortedSet[ShortChannelId] = theirShortChannelIds -- ourShortChannelIds - log.info("received reply_channel_range, we're missing {} channel announcements/updates, format={} useGzip={}", missing.size, format, useGzip) - val d1 = if (missing.nonEmpty) { - // they may send back several reply_channel_range messages for a single query_channel_range query, and we must not - // send another query_short_channel_ids query if they're still processing one - d.sync.get(remoteNodeId) match { - case None => - // we don't have a pending query with this peer - val (slice, rest) = missing.splitAt(SHORTID_WINDOW) - transport ! 
QueryShortChannelIds(chainHash, ChannelRangeQueries.encodeShortChannelIdsSingle(slice, format, useGzip))
-              d.copy(sync = d.sync + (remoteNodeId -> Sync(rest, missing.size)))
-            case Some(sync) =>
-              // we already have a pending query with this peer, add missing ids to our "sync" state
-              d.copy(sync = d.sync + (remoteNodeId -> Sync(sync.missing ++ missing, sync.totalMissingCount + missing.size)))
+
+      @tailrec
+      def loop(ids: List[ShortChannelId], timestamps: List[ReplyChannelRangeTlv.Timestamps], checksums: List[ReplyChannelRangeTlv.Checksums], acc: List[ShortChannelIdAndFlag] = List.empty[ShortChannelIdAndFlag]): List[ShortChannelIdAndFlag] = {
+        ids match {
+          case Nil => acc.reverse
+          case head :: tail =>
+            val flag = computeFlag(d.channels, d.updates)(head, timestamps.headOption, checksums.headOption, nodeParams.routerConf.requestNodeAnnouncements)
+            // 0 means nothing to query, just don't include it
+            val acc1 = if (flag != 0) ShortChannelIdAndFlag(head, flag) :: acc else acc
+            loop(tail, timestamps.drop(1), checksums.drop(1), acc1)
        }
-      } else d
-      context.system.eventStream.publish(syncProgress(d1))
-      stay using d1
+      }

-    case Event(PeerRoutingMessage(transport, _, routingMessage@QueryShortChannelIds(chainHash, data)), d) =>
+      val timestamps_opt = routingMessage.timestamps_opt.map(_.timestamps).getOrElse(List.empty[ReplyChannelRangeTlv.Timestamps])
+      val checksums_opt = routingMessage.checksums_opt.map(_.checksums).getOrElse(List.empty[ReplyChannelRangeTlv.Checksums])
+
+      val shortChannelIdAndFlags = loop(shortChannelIds.array, timestamps_opt, checksums_opt)
+
+      val (channelCount, updatesCount) = shortChannelIdAndFlags.foldLeft((0, 0)) {
+        case ((c, u), ShortChannelIdAndFlag(_, flag)) =>
+          val c1 = c + (if (QueryShortChannelIdsTlv.QueryFlagType.includeChannelAnnouncement(flag)) 1 else 0)
+          val u1 = u + (if (QueryShortChannelIdsTlv.QueryFlagType.includeUpdate1(flag)) 1 else 0) + (if (QueryShortChannelIdsTlv.QueryFlagType.includeUpdate2(flag)) 1 else 0)
+          (c1, u1)
+      }
+      log.info(s"received reply_channel_range with {} channels, we're missing {} channel announcements and {} updates, format={}", shortChannelIds.array.size, channelCount, updatesCount, shortChannelIds.encoding)
+      // we update our sync data for this node (there may be multiple channel range responses and we can only query one set of ids at a time)
+      val replies = shortChannelIdAndFlags
+        .grouped(SHORTID_WINDOW)
+        .map(chunk => QueryShortChannelIds(chainHash,
+          shortChannelIds = EncodedShortChannelIds(shortChannelIds.encoding, chunk.map(_.shortChannelId)),
+          if (routingMessage.timestamps_opt.isDefined || routingMessage.checksums_opt.isDefined)
+            TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(shortChannelIds.encoding, chunk.map(_.flag)))
+          else
+            TlvStream.empty
+        ))
+        .toList
+      val (sync1, replynow_opt) = addToSync(d.sync, remoteNodeId, replies)
+      // we only send a reply right away if there were no pending requests
+      replynow_opt.foreach(transport ! _)
+      context.system.eventStream.publish(syncProgress(sync1))
+      stay using d.copy(sync = sync1)
+
+    case Event(PeerRoutingMessage(transport, _, routingMessage@QueryShortChannelIds(chainHash, shortChannelIds, queryFlags_opt)), d) =>
      sender ! 
TransportHandler.ReadAck(routingMessage) - val (_, shortChannelIds, useGzip) = ChannelRangeQueries.decodeShortChannelIds(data) - log.info("received query_short_channel_ids for {} channel announcements, useGzip={}", shortChannelIds.size, useGzip) - shortChannelIds.foreach(shortChannelId => { - d.channels.get(shortChannelId) match { - case None => log.warning("received query for shortChannelId={} that we don't have", shortChannelId) - case Some(ca) => - transport ! ca - d.updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId1, ca.nodeId2)).map(u => transport ! u) - d.updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId2, ca.nodeId1)).map(u => transport ! u) + val flags = routingMessage.queryFlags_opt.map(_.array).getOrElse(List.empty[Long]) + + var channelCount = 0 + var updateCount = 0 + var nodeCount = 0 + + Router.processChannelQuery(d.nodes, d.channels, d.updates)( + shortChannelIds.array, + flags, + ca => { + channelCount = channelCount + 1 + transport ! ca + }, + cu => { + updateCount = updateCount + 1 + transport ! cu + }, + na => { + nodeCount = nodeCount + 1 + transport ! na } - }) + ) + log.info("received query_short_channel_ids with {} items, sent back {} channels and {} updates and {} nodes", shortChannelIds.array.size, channelCount, updateCount, nodeCount) transport ! ReplyShortChannelIdsEnd(chainHash, 1) stay - case Event(PeerRoutingMessage(transport, remoteNodeId, routingMessage@ReplyShortChannelIdsEnd(chainHash, complete)), d) => + case Event(PeerRoutingMessage(transport, remoteNodeId, routingMessage: ReplyShortChannelIdsEnd), d) => sender ! TransportHandler.ReadAck(routingMessage) - log.info("received reply_short_channel_ids_end={}", routingMessage) // have we more channels to ask this peer? - val d1 = d.sync.get(remoteNodeId) match { - case Some(sync) if sync.missing.nonEmpty => - log.info(s"asking {} for the next slice of short_channel_ids", remoteNodeId) - val (slice, rest) = sync.missing.splitAt(SHORTID_WINDOW) - transport ! QueryShortChannelIds(chainHash, ChannelRangeQueries.encodeShortChannelIdsSingle(slice, ChannelRangeQueries.UNCOMPRESSED_FORMAT, useGzip = false)) - d.copy(sync = d.sync + (remoteNodeId -> sync.copy(missing = rest))) - case Some(sync) if sync.missing.isEmpty => - // we received reply_short_channel_ids_end for our last query aand have not sent another one, we can now remove - // the remote peer from our map - d.copy(sync = d.sync - remoteNodeId) - case _ => - d + val sync1 = d.sync.get(remoteNodeId) match { + case Some(sync) => + sync.pending match { + case nextRequest +: rest => + log.info(s"asking for the next slice of short_channel_ids (remaining=${sync.pending.size}/${sync.total})") + transport ! 
nextRequest + d.sync + (remoteNodeId -> sync.copy(pending = rest)) + case Nil => + // we received reply_short_channel_ids_end for our last query and have not sent another one, we can now remove + // the remote peer from our map + log.info(s"sync complete (total=${sync.total})") + d.sync - remoteNodeId + } + case _ => d.sync } - context.system.eventStream.publish(syncProgress(d1)) - stay using d1 + context.system.eventStream.publish(syncProgress(sync1)) + stay using d.copy(sync = sync1) + } initialize() @@ -697,18 +755,20 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom // when we're sending updates to ourselves (transport_opt, remoteNodeId_opt) match { case (Some(transport), Some(remoteNodeId)) => + val query = QueryShortChannelIds(u.chainHash, EncodedShortChannelIds(nodeParams.routerConf.encodingType, List(u.shortChannelId)), TlvStream.empty) d.sync.get(remoteNodeId) match { case Some(sync) => // we already have a pending request to that node, let's add this channel to the list and we'll get it later - d.copy(sync = d.sync + (remoteNodeId -> sync.copy(missing = sync.missing + u.shortChannelId, totalMissingCount = sync.totalMissingCount + 1))) + // TODO: we only request channels with old style channel_query + d.copy(sync = d.sync + (remoteNodeId -> sync.copy(pending = sync.pending :+ query, total = sync.total + 1))) case None => // we send the query right away - transport ! QueryShortChannelIds(u.chainHash, ChannelRangeQueries.encodeShortChannelIdsSingle(Seq(u.shortChannelId), ChannelRangeQueries.UNCOMPRESSED_FORMAT, useGzip = false)) - d.copy(sync = d.sync + (remoteNodeId -> Sync(missing = SortedSet(u.shortChannelId), totalMissingCount = 1))) + transport ! query + d.copy(sync = d.sync + (remoteNodeId -> Sync(pending = Nil, total = 1))) } case _ => // we don't know which node this update came from (maybe it was stashed and the channel got pruned in the meantime or some other corner case). - // or we don't have a transport to send our query with. + // or we don't have a transport to send our query to. 
// anyway, that's not really a big deal because we have removed the channel from the pruned db so next time it shows up we will revalidate it d } @@ -718,13 +778,14 @@ class Router(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Prom } override def mdc(currentMessage: Any): MDC = currentMessage match { - case SendChannelQuery(remoteNodeId, _) => Logs.mdc(remoteNodeId_opt = Some(remoteNodeId)) + case SendChannelQuery(remoteNodeId, _, _) => Logs.mdc(remoteNodeId_opt = Some(remoteNodeId)) case PeerRoutingMessage(_, remoteNodeId, _) => Logs.mdc(remoteNodeId_opt = Some(remoteNodeId)) case _ => akka.event.Logging.emptyMDC } } object Router { + val SHORTID_WINDOW = 100 def props(nodeParams: NodeParams, watcher: ActorRef, initialized: Option[Promise[Done]] = None) = Props(new Router(nodeParams, watcher, initialized)) @@ -750,11 +811,19 @@ object Router { def hasChannels(nodeId: PublicKey, channels: Iterable[ChannelAnnouncement]): Boolean = channels.exists(c => isRelatedTo(c, nodeId)) - def isStale(u: ChannelUpdate): Boolean = { + def isStale(u: ChannelUpdate): Boolean = isStale(u.timestamp) + + def isStale(timestamp: Long): Boolean = { // BOLT 7: "nodes MAY prune channels should the timestamp of the latest channel_update be older than 2 weeks" // but we don't want to prune brand new channels for which we didn't yet receive a channel update val staleThresholdSeconds = (Platform.currentTime.milliseconds - 14.days).toSeconds - u.timestamp < staleThresholdSeconds + timestamp < staleThresholdSeconds + } + + def isAlmostStale(timestamp: Long): Boolean = { + // we define almost stale as 2 weeks minus 4 days + val staleThresholdSeconds = (Platform.currentTime.milliseconds - 10.days).toSeconds + timestamp < staleThresholdSeconds } /** @@ -793,12 +862,153 @@ object Router { height >= firstBlockNum && height <= (firstBlockNum + numberOfBlocks) } - def syncProgress(d: Data): SyncProgress = - if (d.sync.isEmpty) { + def shouldRequestUpdate(ourTimestamp: Long, ourChecksum: Long, theirTimestamp_opt: Option[Long], theirChecksum_opt: Option[Long]): Boolean = { + (theirTimestamp_opt, theirChecksum_opt) match { + case (Some(theirTimestamp), Some(theirChecksum)) => + // we request their channel_update if all those conditions are met: + // - it is more recent than ours + // - it is different from ours, or it is the same but ours is about to be stale + // - it is not stale + val theirsIsMoreRecent = ourTimestamp < theirTimestamp + val areDifferent = ourChecksum != theirChecksum + val oursIsAlmostStale = isAlmostStale(ourTimestamp) + val theirsIsStale = isStale(theirTimestamp) + theirsIsMoreRecent && (areDifferent || oursIsAlmostStale) && !theirsIsStale + case (Some(theirTimestamp), None) => + // if we only have their timestamp, we request their channel_update if theirs is more recent than ours + val theirsIsMoreRecent = ourTimestamp < theirTimestamp + val theirsIsStale = isStale(theirTimestamp) + theirsIsMoreRecent && !theirsIsStale + case (None, Some(theirChecksum)) => + // if we only have their checksum, we request their channel_update if it is different from ours + val areDifferent = theirChecksum != 0 && ourChecksum != theirChecksum + areDifferent + case (None, None) => + // if we have neither their timestamp nor their checksum we request their channel_update + true + } + } + + def computeFlag(channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate])( + shortChannelId: ShortChannelId, + theirTimestamps_opt: Option[ReplyChannelRangeTlv.Timestamps], + 
theirChecksums_opt: Option[ReplyChannelRangeTlv.Checksums],
+                  includeNodeAnnouncements: Boolean): Long = {
+    import QueryShortChannelIdsTlv.QueryFlagType._
+
+    val flagsNodes = if (includeNodeAnnouncements) INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2 else 0
+
+    val flags = if (!channels.contains(shortChannelId)) {
+      INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2
+    } else {
+      // we already know this channel
+      val (ourTimestamps, ourChecksums) = Router.getChannelDigestInfo(channels, updates)(shortChannelId)
+      // if they don't provide timestamps or checksums, we set appropriate default values:
+      // - we assume their timestamp is more recent than ours by setting timestamp = Long.MaxValue
+      // - we assume their update is different from ours by setting checksum = Long.MaxValue (NB: our default value for checksum is 0)
+      val shouldRequestUpdate1 = shouldRequestUpdate(ourTimestamps.timestamp1, ourChecksums.checksum1, theirTimestamps_opt.map(_.timestamp1), theirChecksums_opt.map(_.checksum1))
+      val shouldRequestUpdate2 = shouldRequestUpdate(ourTimestamps.timestamp2, ourChecksums.checksum2, theirTimestamps_opt.map(_.timestamp2), theirChecksums_opt.map(_.checksum2))
+      val flagUpdate1 = if (shouldRequestUpdate1) INCLUDE_CHANNEL_UPDATE_1 else 0
+      val flagUpdate2 = if (shouldRequestUpdate2) INCLUDE_CHANNEL_UPDATE_2 else 0
+      flagUpdate1 | flagUpdate2
+    }
+
+    if (flags == 0) 0 else flags | flagsNodes
+  }
+
+  /**
+    * Handle a query message, which includes a list of channel ids and flags.
+    *
+    * @param nodes     node id -> node announcement
+    * @param channels  channel id -> channel announcement
+    * @param updates   channel description -> channel update
+    * @param ids       list of channel ids
+    * @param flags     list of query flags, either empty or one flag per channel id
+    * @param onChannel called when a channel announcement matches (i.e. its bit is set in the query flag and we have it)
+    * @param onUpdate  called when a channel update matches
+    * @param onNode    called when a node announcement matches
+    *
+    */
+  def processChannelQuery(nodes: Map[PublicKey, NodeAnnouncement],
+                          channels: SortedMap[ShortChannelId, ChannelAnnouncement],
+                          updates: Map[ChannelDesc, ChannelUpdate])(
+                           ids: List[ShortChannelId],
+                           flags: List[Long],
+                           onChannel: ChannelAnnouncement => Unit,
+                           onUpdate: ChannelUpdate => Unit,
+                           onNode: NodeAnnouncement => Unit)(implicit log: LoggingAdapter): Unit = {
+    import QueryShortChannelIdsTlv.QueryFlagType
+
+    // we loop over channel ids and query flags. 
We track node Ids for node announcement + // we've already sent to avoid sending them multiple times, as requested by the BOLTs + @tailrec + def loop(ids: List[ShortChannelId], flags: List[Long], numca: Int = 0, numcu: Int = 0, nodesSent: Set[PublicKey] = Set.empty[PublicKey]): (Int, Int, Int) = ids match { + case Nil => (numca, numcu, nodesSent.size) + case head :: tail if !channels.contains(head) => + log.warning("received query for shortChannelId={} that we don't have", head) + loop(tail, flags.drop(1), numca, numcu, nodesSent) + case head :: tail => + var numca1 = numca + var numcu1 = numcu + var sent1 = nodesSent + val ca = channels(head) + val flag_opt = flags.headOption + // no flag means send everything + + val includeChannel = flag_opt.forall(QueryFlagType.includeChannelAnnouncement) + val includeUpdate1 = flag_opt.forall(QueryFlagType.includeUpdate1) + val includeUpdate2 = flag_opt.forall(QueryFlagType.includeUpdate2) + val includeNode1 = flag_opt.forall(QueryFlagType.includeNodeAnnouncement1) + val includeNode2 = flag_opt.forall(QueryFlagType.includeNodeAnnouncement2) + + if (includeChannel) { + onChannel(ca) + } + if (includeUpdate1) { + updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId1, ca.nodeId2)).foreach { u => + onUpdate(u) + } + } + if (includeUpdate2) { + updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId2, ca.nodeId1)).foreach { u => + onUpdate(u) + } + } + if (includeNode1 && !sent1.contains(ca.nodeId1)) { + nodes.get(ca.nodeId1).foreach { n => + onNode(n) + sent1 = sent1 + ca.nodeId1 + } + } + if (includeNode2 && !sent1.contains(ca.nodeId2)) { + nodes.get(ca.nodeId2).foreach { n => + onNode(n) + sent1 = sent1 + ca.nodeId2 + } + } + loop(tail, flags.drop(1), numca1, numcu1, sent1) + } + + loop(ids, flags) + } + + /** + * Returns overall progress on synchronization + * + * @param sync + * @return a sync progress indicator (1 means fully synced) + */ + def syncProgress(sync: Map[PublicKey, Sync]): SyncProgress = { + //NB: progress is in terms of requests, not individual channels + val (pending, total) = sync.foldLeft((0, 0)) { + case ((p, t), (_, sync)) => (p + sync.pending.size, t + sync.total) + } + if (total == 0) { SyncProgress(1) } else { - SyncProgress(1 - d.sync.values.map(_.missing.size).sum * 1.0 / d.sync.values.map(_.totalMissingCount).sum) + SyncProgress((total - pending) / (1.0 * total)) } + } /** * This method is used after a payment failed, and we want to exclude some nodes that we know are failing @@ -813,6 +1023,92 @@ object Router { desc } + /** + * + * @param channels id -> announcement map + * @param updates channel updates + * @param id short channel id + * @return the timestamp of the most recent update for this channel id, 0 if we don't have any + */ + def getTimestamp(channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate])(id: ShortChannelId): Long = { + val ca = channels(id) + val opt1 = updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId1, ca.nodeId2)) + val opt2 = updates.get(ChannelDesc(ca.shortChannelId, ca.nodeId2, ca.nodeId1)) + val timestamp = (opt1, opt2) match { + case (Some(u1), Some(u2)) => Math.max(u1.timestamp, u2.timestamp) + case (Some(u1), None) => u1.timestamp + case (None, Some(u2)) => u2.timestamp + case (None, None) => 0L + } + timestamp + } + + def getChannelDigestInfo(channels: SortedMap[ShortChannelId, ChannelAnnouncement], updates: Map[ChannelDesc, ChannelUpdate])(shortChannelId: ShortChannelId): (ReplyChannelRangeTlv.Timestamps, ReplyChannelRangeTlv.Checksums) = { + val c 
= channels(shortChannelId) + val u1_opt = updates.get(ChannelDesc(c.shortChannelId, c.nodeId1, c.nodeId2)) + val u2_opt = updates.get(ChannelDesc(c.shortChannelId, c.nodeId2, c.nodeId1)) + val timestamp1 = u1_opt.map(_.timestamp).getOrElse(0L) + val timestamp2 = u2_opt.map(_.timestamp).getOrElse(0L) + val checksum1 = u1_opt.map(getChecksum).getOrElse(0L) + val checksum2 = u2_opt.map(getChecksum).getOrElse(0L) + (ReplyChannelRangeTlv.Timestamps(timestamp1 = timestamp1, timestamp2 = timestamp2), ReplyChannelRangeTlv.Checksums(checksum1 = checksum1, checksum2 = checksum2)) + } + + def getChecksum(u: ChannelUpdate): Long = { + import u._ + val data = serializationResult(LightningMessageCodecs.channelUpdateChecksumCodec.encode(chainHash :: shortChannelId :: messageFlags :: channelFlags :: cltvExpiryDelta :: htlcMinimumMsat :: feeBaseMsat :: feeProportionalMillionths :: htlcMaximumMsat :: HNil)) + val checksum = new CRC32C() + checksum.update(data.toArray) + checksum.getValue + } + + case class ShortChannelIdsChunk(firstBlock: Long, numBlocks: Long, shortChannelIds: List[ShortChannelId]) + + /** + * Have to split ids because otherwise message could be too big + * there could be several reply_channel_range messages for a single query + * + * @param shortChannelIds + * @return + */ + def split(shortChannelIds: SortedSet[ShortChannelId]): List[ShortChannelIdsChunk] = { + // this algorithm can split blocks (meaning that we can in theory generate several replies with the same first_block/num_blocks + // and a different set of short_channel_ids) but it doesn't matter + val SPLIT_SIZE = 3500 // we can theoretically fit 4091 uncompressed channel ids in a single lightning message (max size 65 Kb) + if (shortChannelIds.isEmpty) { + List(ShortChannelIdsChunk(0, 0, List.empty)) + } else { + shortChannelIds + .grouped(SPLIT_SIZE) + .toList + .map { group => + // NB: group is never empty + val firstBlock: Long = ShortChannelId.coordinates(group.head).blockHeight.toLong + val numBlocks: Long = ShortChannelId.coordinates(group.last).blockHeight.toLong - firstBlock + 1 + ShortChannelIdsChunk(firstBlock, numBlocks, group.toList) + } + } + } + + def addToSync(syncMap: Map[PublicKey, Sync], remoteNodeId: PublicKey, pending: List[RoutingMessage]): (Map[PublicKey, Sync], Option[RoutingMessage]) = { + pending match { + case head +: rest => + // they may send back several reply_channel_range messages for a single query_channel_range query, and we must not + // send another query_short_channel_ids query if they're still processing one + syncMap.get(remoteNodeId) match { + case None => + // we don't have a pending query with this peer, let's send it + (syncMap + (remoteNodeId -> Sync(rest, pending.size)), Some(head)) + case Some(sync) => + // we already have a pending query with this peer, add missing ids to our "sync" state + (syncMap + (remoteNodeId -> Sync(sync.pending ++ pending, sync.total + pending.size)), None) + } + case Nil => + // there is nothing to send + (syncMap, None) + } + } + /** * https://github.com/lightningnetwork/lightning-rfc/blob/master/04-onion-routing.md#clarifications */ diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala index 898605c187..0fffae7eca 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/CommonCodecs.scala @@ -60,6 +60,10 @@ object CommonCodecs { val cltvExpiry: Codec[CltvExpiry] = 
uint32.xmapc(CltvExpiry)((_: CltvExpiry).toLong) val cltvExpiryDelta: Codec[CltvExpiryDelta] = uint16.xmapc(CltvExpiryDelta)((_: CltvExpiryDelta).toInt) + // this is needed because some millisatoshi values are encoded on 32 bits in the BOLTs + // this codec will fail if the amount does not fit on 32 bits + val millisatoshi32: Codec[MilliSatoshi] = uint32.xmapc(l => MilliSatoshi(l))(_.amount) + /** * We impose a minimal encoding on some values (such as varint and truncated int) to ensure that signed hashes can be * re-computed correctly. diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala index 4765487d68..5c65700c51 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageCodecs.scala @@ -16,8 +16,8 @@ package fr.acinq.eclair.wire +import fr.acinq.eclair.wire import fr.acinq.eclair.wire.CommonCodecs._ -import fr.acinq.eclair.{MilliSatoshi, wire} import scodec.Codec import scodec.codecs._ @@ -182,6 +182,18 @@ object LightningMessageCodecs { ("signature" | bytes64) :: nodeAnnouncementWitnessCodec).as[NodeAnnouncement] + val channelUpdateChecksumCodec = + ("chainHash" | bytes32) :: + ("shortChannelId" | shortchannelid) :: + (("messageFlags" | byte) >>:~ { messageFlags => + ("channelFlags" | byte) :: + ("cltvExpiryDelta" | cltvExpiryDelta) :: + ("htlcMinimumMsat" | millisatoshi) :: + ("feeBaseMsat" | millisatoshi32) :: + ("feeProportionalMillionths" | uint32) :: + ("htlcMaximumMsat" | conditional((messageFlags & 1) != 0, millisatoshi)) + }) + val channelUpdateWitnessCodec = ("chainHash" | bytes32) :: ("shortChannelId" | shortchannelid) :: @@ -190,7 +202,7 @@ object LightningMessageCodecs { ("channelFlags" | byte) :: ("cltvExpiryDelta" | cltvExpiryDelta) :: ("htlcMinimumMsat" | millisatoshi) :: - ("feeBaseMsat" | uint32.xmapc(l => MilliSatoshi(l))(_.amount)) :: + ("feeBaseMsat" | millisatoshi32) :: ("feeProportionalMillionths" | uint32) :: ("htlcMaximumMsat" | conditional((messageFlags & 1) != 0, millisatoshi)) :: ("unknownFields" | bytes) @@ -200,29 +212,43 @@ object LightningMessageCodecs { ("signature" | bytes64) :: channelUpdateWitnessCodec).as[ChannelUpdate] - val queryShortChannelIdsCodec: Codec[QueryShortChannelIds] = ( - ("chainHash" | bytes32) :: - ("data" | varsizebinarydata) + val encodedShortChannelIdsCodec: Codec[EncodedShortChannelIds] = + discriminated[EncodedShortChannelIds].by(byte) + .\(0) { case a@EncodedShortChannelIds(EncodingType.UNCOMPRESSED, _) => a }((provide[EncodingType](EncodingType.UNCOMPRESSED) :: list(shortchannelid)).as[EncodedShortChannelIds]) + .\(1) { case a@EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, _) => a }((provide[EncodingType](EncodingType.COMPRESSED_ZLIB) :: zlib(list(shortchannelid))).as[EncodedShortChannelIds]) + + val queryShortChannelIdsCodec: Codec[QueryShortChannelIds] = { + Codec( + ("chainHash" | bytes32) :: + ("shortChannelIds" | variableSizeBytes(uint16, encodedShortChannelIdsCodec)) :: + ("tlvStream" | QueryShortChannelIdsTlv.codec) ).as[QueryShortChannelIds] + } val replyShortChanelIdsEndCodec: Codec[ReplyShortChannelIdsEnd] = ( ("chainHash" | bytes32) :: ("complete" | byte) ).as[ReplyShortChannelIdsEnd] - val queryChannelRangeCodec: Codec[QueryChannelRange] = ( - ("chainHash" | bytes32) :: - ("firstBlockNum" | uint32) :: - ("numberOfBlocks" | uint32) - ).as[QueryChannelRange] - - val replyChannelRangeCodec: 
Codec[ReplyChannelRange] = ( - ("chainHash" | bytes32) :: - ("firstBlockNum" | uint32) :: - ("numberOfBlocks" | uint32) :: - ("complete" | byte) :: - ("data" | varsizebinarydata) - ).as[ReplyChannelRange] + val queryChannelRangeCodec: Codec[QueryChannelRange] = { + Codec( + ("chainHash" | bytes32) :: + ("firstBlockNum" | uint32) :: + ("numberOfBlocks" | uint32) :: + ("tlvStream" | QueryChannelRangeTlv.codec) + ).as[QueryChannelRange] + } + + val replyChannelRangeCodec: Codec[ReplyChannelRange] = { + Codec( + ("chainHash" | bytes32) :: + ("firstBlockNum" | uint32) :: + ("numberOfBlocks" | uint32) :: + ("complete" | byte) :: + ("shortChannelIds" | variableSizeBytes(uint16, encodedShortChannelIdsCodec)) :: + ("tlvStream" | ReplyChannelRangeTlv.codec) + ).as[ReplyChannelRange] + } val gossipTimestampFilterCodec: Codec[GossipTimestampFilter] = ( ("chainHash" | bytes32) :: diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala index 48bfb8c49e..a78fef82d8 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/LightningMessageTypes.scala @@ -225,21 +225,61 @@ case class ChannelUpdate(signature: ByteVector64, require(((messageFlags & 1) != 0) == htlcMaximumMsat.isDefined, "htlcMaximumMsat is not consistent with messageFlags") } +// @formatter:off +sealed trait EncodingType +object EncodingType { + case object UNCOMPRESSED extends EncodingType + case object COMPRESSED_ZLIB extends EncodingType +} +// @formatter:on + + +case class EncodedShortChannelIds(encoding: EncodingType, + array: List[ShortChannelId]) + + case class QueryShortChannelIds(chainHash: ByteVector32, - data: ByteVector) extends RoutingMessage with HasChainHash + shortChannelIds: EncodedShortChannelIds, + tlvStream: TlvStream[QueryShortChannelIdsTlv] = TlvStream.empty) extends RoutingMessage with HasChainHash { + val queryFlags_opt: Option[QueryShortChannelIdsTlv.EncodedQueryFlags] = tlvStream.get[QueryShortChannelIdsTlv.EncodedQueryFlags] +} + +case class ReplyShortChannelIdsEnd(chainHash: ByteVector32, + complete: Byte) extends RoutingMessage with HasChainHash + case class QueryChannelRange(chainHash: ByteVector32, firstBlockNum: Long, - numberOfBlocks: Long) extends RoutingMessage with HasChainHash + numberOfBlocks: Long, + tlvStream: TlvStream[QueryChannelRangeTlv] = TlvStream.empty) extends RoutingMessage { + val queryFlags_opt: Option[QueryChannelRangeTlv.QueryFlags] = tlvStream.get[QueryChannelRangeTlv.QueryFlags] +} case class ReplyChannelRange(chainHash: ByteVector32, firstBlockNum: Long, numberOfBlocks: Long, complete: Byte, - data: ByteVector) extends RoutingMessage with HasChainHash + shortChannelIds: EncodedShortChannelIds, + tlvStream: TlvStream[ReplyChannelRangeTlv] = TlvStream.empty) extends RoutingMessage { + val timestamps_opt: Option[ReplyChannelRangeTlv.EncodedTimestamps] = tlvStream.get[ReplyChannelRangeTlv.EncodedTimestamps] + + val checksums_opt: Option[ReplyChannelRangeTlv.EncodedChecksums] = tlvStream.get[ReplyChannelRangeTlv.EncodedChecksums] +} + +object ReplyChannelRange { + def apply(chainHash: ByteVector32, + firstBlockNum: Long, + numberOfBlocks: Long, + complete: Byte, + shortChannelIds: EncodedShortChannelIds, + timestamps: Option[ReplyChannelRangeTlv.EncodedTimestamps], + checksums: Option[ReplyChannelRangeTlv.EncodedChecksums]) = { + timestamps.foreach(ts => require(ts.timestamps.length == 
shortChannelIds.array.length)) + checksums.foreach(cs => require(cs.checksums.length == shortChannelIds.array.length)) + new ReplyChannelRange(chainHash, firstBlockNum, numberOfBlocks, complete, shortChannelIds, TlvStream(timestamps.toList ::: checksums.toList)) + } +} -case class ReplyShortChannelIdsEnd(chainHash: ByteVector32, - complete: Byte) extends RoutingMessage with HasChainHash case class GossipTimestampFilter(chainHash: ByteVector32, firstTimestamp: Long, diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryChannelRangeTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryChannelRangeTlv.scala new file mode 100644 index 0000000000..0dc5f57050 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryChannelRangeTlv.scala @@ -0,0 +1,37 @@ +package fr.acinq.eclair.wire + +import fr.acinq.eclair.UInt64 +import fr.acinq.eclair.wire.CommonCodecs.{shortchannelid, varint, varintoverflow} +import scodec.Codec +import scodec.codecs._ + +sealed trait QueryChannelRangeTlv extends Tlv + +object QueryChannelRangeTlv { + /** + * Optional query flag that is appended to QueryChannelRange + * @param flag bit 1 set means I want timestamps, bit 2 set means I want checksums + */ + case class QueryFlags(flag: Long) extends QueryChannelRangeTlv { + val wantTimestamps = QueryFlags.wantTimestamps(flag) + + val wantChecksums = QueryFlags.wantChecksums(flag) + } + + case object QueryFlags { + val WANT_TIMESTAMPS: Long = 1 + val WANT_CHECKSUMS: Long = 2 + val WANT_ALL: Long = (WANT_TIMESTAMPS | WANT_CHECKSUMS) + + def wantTimestamps(flag: Long) = (flag & WANT_TIMESTAMPS) != 0 + + def wantChecksums(flag: Long) = (flag & WANT_CHECKSUMS) != 0 + } + + val queryFlagsCodec: Codec[QueryFlags] = Codec(("flag" | varintoverflow)).as[QueryFlags] + + val codec: Codec[TlvStream[QueryChannelRangeTlv]] = TlvCodecs.tlvStream(discriminated.by(varint) + .typecase(UInt64(1), variableSizeBytesLong(varintoverflow, queryFlagsCodec)) + ) + +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryShortChannelIdsTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryShortChannelIdsTlv.scala new file mode 100644 index 0000000000..12f5ad96fa --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/QueryShortChannelIdsTlv.scala @@ -0,0 +1,46 @@ +package fr.acinq.eclair.wire + +import fr.acinq.eclair.UInt64 +import fr.acinq.eclair.wire.CommonCodecs.{varint, varintoverflow} +import scodec.Codec +import scodec.codecs.{byte, discriminated, list, provide, variableSizeBytesLong, zlib} + +sealed trait QueryShortChannelIdsTlv extends Tlv + +object QueryShortChannelIdsTlv { + + /** + * Optional TLV-based query message that can be appended to QueryShortChannelIds + * @param encoding 0 means uncompressed, 1 means compressed with zlib + * @param array array of query flags, each flags specifies the info we want for a given channel + */ + case class EncodedQueryFlags(encoding: EncodingType, array: List[Long]) extends QueryShortChannelIdsTlv + + case object QueryFlagType { + val INCLUDE_CHANNEL_ANNOUNCEMENT: Long = 1 + val INCLUDE_CHANNEL_UPDATE_1: Long = 2 + val INCLUDE_CHANNEL_UPDATE_2: Long = 4 + val INCLUDE_NODE_ANNOUNCEMENT_1: Long = 8 + val INCLUDE_NODE_ANNOUNCEMENT_2: Long = 16 + + def includeChannelAnnouncement(flag: Long) = (flag & INCLUDE_CHANNEL_ANNOUNCEMENT) != 0 + + def includeUpdate1(flag: Long) = (flag & INCLUDE_CHANNEL_UPDATE_1) != 0 + + def includeUpdate2(flag: Long) = (flag & INCLUDE_CHANNEL_UPDATE_2) != 0 + + def includeNodeAnnouncement1(flag: Long) = (flag & 
INCLUDE_NODE_ANNOUNCEMENT_1) != 0
+
+    def includeNodeAnnouncement2(flag: Long) = (flag & INCLUDE_NODE_ANNOUNCEMENT_2) != 0
+  }
+
+  val encodedQueryFlagsCodec: Codec[EncodedQueryFlags] =
+    discriminated[EncodedQueryFlags].by(byte)
+      .\(0) { case a@EncodedQueryFlags(EncodingType.UNCOMPRESSED, _) => a }((provide[EncodingType](EncodingType.UNCOMPRESSED) :: list(varintoverflow)).as[EncodedQueryFlags])
+      .\(1) { case a@EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, _) => a }((provide[EncodingType](EncodingType.COMPRESSED_ZLIB) :: zlib(list(varintoverflow))).as[EncodedQueryFlags])
+
+
+  val codec: Codec[TlvStream[QueryShortChannelIdsTlv]] = TlvCodecs.tlvStream(discriminated.by(varint)
+    .typecase(UInt64(1), variableSizeBytesLong(varintoverflow, encodedQueryFlagsCodec))
+  )
+}
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/ReplyChannelRangeTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/ReplyChannelRangeTlv.scala
new file mode 100644
index 0000000000..bde4605551
--- /dev/null
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/ReplyChannelRangeTlv.scala
@@ -0,0 +1,64 @@
+package fr.acinq.eclair.wire
+
+import fr.acinq.eclair.{UInt64, wire}
+import fr.acinq.eclair.wire.CommonCodecs.{varint, varintoverflow}
+import scodec.Codec
+import scodec.codecs._
+
+sealed trait ReplyChannelRangeTlv extends Tlv
+
+object ReplyChannelRangeTlv {
+
+  /**
+    *
+    * @param timestamp1 timestamp for node 1, or 0
+    * @param timestamp2 timestamp for node 2, or 0
+    */
+  case class Timestamps(timestamp1: Long, timestamp2: Long)
+
+  /**
+    * Optional timestamps TLV that can be appended to ReplyChannelRange
+    *
+    * @param encoding   same convention as for short channel ids
+    * @param timestamps
+    */
+  case class EncodedTimestamps(encoding: EncodingType, timestamps: List[Timestamps]) extends ReplyChannelRangeTlv
+
+  /**
+    *
+    * @param checksum1 checksum for node 1, or 0
+    * @param checksum2 checksum for node 2, or 0
+    */
+  case class Checksums(checksum1: Long, checksum2: Long)
+
+  /**
+    * Optional checksums TLV that can be appended to ReplyChannelRange
+    *
+    * @param checksums
+    */
+  case class EncodedChecksums(checksums: List[Checksums]) extends ReplyChannelRangeTlv
+
+  val timestampsCodec: Codec[Timestamps] = (
+    ("timestamp1" | uint32) ::
+      ("timestamp2" | uint32)
+    ).as[Timestamps]
+
+  val encodedTimestampsCodec: Codec[EncodedTimestamps] = variableSizeBytesLong(varintoverflow,
+    discriminated[EncodedTimestamps].by(byte)
+      .\(0) { case a@EncodedTimestamps(EncodingType.UNCOMPRESSED, _) => a }((provide[EncodingType](EncodingType.UNCOMPRESSED) :: list(timestampsCodec)).as[EncodedTimestamps])
+      .\(1) { case a@EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, _) => a }((provide[EncodingType](EncodingType.COMPRESSED_ZLIB) :: zlib(list(timestampsCodec))).as[EncodedTimestamps])
+    )
+
+  val checksumsCodec: Codec[Checksums] = (
+    ("checksum1" | uint32) ::
+      ("checksum2" | uint32)
+    ).as[Checksums]
+
+  val encodedChecksumsCodec: Codec[EncodedChecksums] = variableSizeBytesLong(varintoverflow, list(checksumsCodec)).as[EncodedChecksums]
+
+  val innerCodec = discriminated[ReplyChannelRangeTlv].by(varint)
+    .typecase(UInt64(1), encodedTimestampsCodec)
+    .typecase(UInt64(3), encodedChecksumsCodec)
+
+  val codec: Codec[TlvStream[ReplyChannelRangeTlv]] = TlvCodecs.tlvStream(innerCodec)
+}
diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala
index e4eea124f8..b87b1ebcb6 100644
--- 
a/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala
+++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala
@@ -19,6 +19,8 @@ package fr.acinq.eclair.wire
 import fr.acinq.eclair.UInt64
 import scodec.bits.ByteVector
+import scala.reflect.ClassTag
+
 /**
  * Created by t-bast on 20/06/2019.
  */
@@ -45,9 +47,18 @@ case class GenericTlv(tag: UInt64, value: ByteVector) extends Tlv
  * @param unknown unknown tlv records.
  * @tparam T the stream namespace is a trait extending the top-level tlv trait.
  */
-case class TlvStream[T <: Tlv](records: Traversable[T], unknown: Traversable[GenericTlv] = Nil)
+case class TlvStream[T <: Tlv](records: Traversable[T], unknown: Traversable[GenericTlv] = Nil) {
+  /**
+    *
+    * @tparam R input type parameter, must be a subtype of the main TLV type
+    * @return the TLV record of the type that matches the input type parameter, if any (there can be at most one, since the BOLTs specify
+    *         that TLV records are supposed to be unique)
+    */
+  def get[R <: T : ClassTag]: Option[R] = records.collectFirst { case r: R => r }
+}

 object TlvStream {
+  def empty[T <: Tlv] = TlvStream[T](Nil, Nil)
 def apply[T <: Tlv](records: T*): TlvStream[T] = TlvStream(records, Nil)
diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala
index 188ec8794b..cc38cb8e6f 100644
--- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala
+++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala
@@ -26,7 +26,7 @@ import fr.acinq.eclair.crypto.LocalKeyManager
 import fr.acinq.eclair.db._
 import fr.acinq.eclair.io.Peer
 import fr.acinq.eclair.router.RouterConf
-import fr.acinq.eclair.wire.{Color, NodeAddress}
+import fr.acinq.eclair.wire.{Color, EncodingType, NodeAddress}
 import scodec.bits.ByteVector

 import scala.concurrent.duration._
@@ -106,6 +106,8 @@ object TestConstants {
 randomizeRouteSelection = false,
 channelExcludeDuration = 60 seconds,
 routerBroadcastInterval = 5 seconds,
+ requestNodeAnnouncements = true,
+ encodingType = EncodingType.COMPRESSED_ZLIB,
 searchMaxFeeBase = Satoshi(21),
 searchMaxFeePct = 0.03,
 searchMaxCltv = CltvExpiryDelta(2016),
@@ -176,6 +178,8 @@
 randomizeRouteSelection = false,
 channelExcludeDuration = 60 seconds,
 routerBroadcastInterval = 5 seconds,
+ requestNodeAnnouncements = true,
+ encodingType = EncodingType.UNCOMPRESSED,
 searchMaxFeeBase = Satoshi(21),
 searchMaxFeePct = 0.03,
 searchMaxCltv = CltvExpiryDelta(2016),
diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala
index d53d86ddba..e9f0414147 100644
--- a/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala
+++ b/eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala
@@ -31,8 +31,8 @@ import fr.acinq.eclair.channel.{ChannelCreated, HasCommitments}
 import fr.acinq.eclair.crypto.TransportHandler
 import fr.acinq.eclair.io.Peer._
 import fr.acinq.eclair.router.RoutingSyncSpec.makeFakeRoutingInfo
-import fr.acinq.eclair.router.{ChannelRangeQueries, ChannelRangeQueriesSpec, Rebroadcast}
-import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, Error, IPv4, NodeAddress, NodeAnnouncement, Ping, Pong}
+import fr.acinq.eclair.router.{Rebroadcast, RoutingSyncSpec}
+import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, NodeAddress, NodeAnnouncement, Ping, Pong, QueryShortChannelIds, Tlv, TlvStream}
 import org.scalatest.{Outcome, Tag}
 import 
scodec.bits.ByteVector @@ -43,7 +43,7 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods { def ipv4FromInet4(address: InetSocketAddress) = IPv4.apply(address.getAddress.asInstanceOf[Inet4Address], address.getPort) val fakeIPAddress = NodeAddress.fromParts("1.2.3.4", 42000).get - val shortChannelIds = ChannelRangeQueriesSpec.shortChannelIds.take(100) + val shortChannelIds = RoutingSyncSpec.shortChannelIds.take(100) val fakeRoutingInfo = shortChannelIds.map(makeFakeRoutingInfo) val channels = fakeRoutingInfo.map(_._1).toList val updates = (fakeRoutingInfo.map(_._2) ++ fakeRoutingInfo.map(_._3)).toList @@ -336,7 +336,10 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods { val probe = TestProbe() connect(remoteNodeId, authenticator, watcher, router, relayer, connection, transport, peer) - val query = wire.QueryShortChannelIds(Alice.nodeParams.chainHash, ChannelRangeQueries.encodeShortChannelIdsSingle(Seq(ShortChannelId(42000)), ChannelRangeQueries.UNCOMPRESSED_FORMAT, useGzip = false)) + val query = QueryShortChannelIds( + Alice.nodeParams.chainHash, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(42000))), + TlvStream.empty) // make sure that routing messages go through for (ann <- channels ++ updates) { diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala index 4aa74b7999..7428007a8f 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/ChannelRangeQueriesSpec.scala @@ -16,63 +16,112 @@ package fr.acinq.eclair.router -import fr.acinq.bitcoin.Block -import fr.acinq.eclair.ShortChannelId -import fr.acinq.eclair.wire.ReplyChannelRange +import fr.acinq.eclair.wire.ReplyChannelRangeTlv._ +import fr.acinq.eclair.{MilliSatoshi, randomKey} import org.scalatest.FunSuite -import scala.collection.{SortedSet, immutable} +import scala.collection.immutable.SortedMap +import scala.compat.Platform class ChannelRangeQueriesSpec extends FunSuite { - import ChannelRangeQueriesSpec._ - - test("create `reply_channel_range` messages (uncompressed format)") { - val blocks = ChannelRangeQueries.encodeShortChannelIds(400000, 20000, shortChannelIds, ChannelRangeQueries.UNCOMPRESSED_FORMAT) - val replies = blocks.map(block => ReplyChannelRange(Block.RegtestGenesisBlock.blockId, block.firstBlock, block.numBlocks, 1, block.shortChannelIds)) - var decoded = Set.empty[ShortChannelId] - replies.foreach(reply => decoded = decoded ++ ChannelRangeQueries.decodeShortChannelIds(reply.data)._2) - assert(decoded == shortChannelIds) - } - test("create `reply_channel_range` messages (ZLIB format)") { - val blocks = ChannelRangeQueries.encodeShortChannelIds(400000, 20000, shortChannelIds, ChannelRangeQueries.ZLIB_FORMAT, useGzip = false) - val replies = blocks.map(block => ReplyChannelRange(Block.RegtestGenesisBlock.blockId, block.firstBlock, block.numBlocks, 1, block.shortChannelIds)) - var decoded = Set.empty[ShortChannelId] - replies.foreach(reply => decoded = decoded ++ { - val (ChannelRangeQueries.ZLIB_FORMAT, ids, false) = ChannelRangeQueries.decodeShortChannelIds(reply.data) - ids - }) - assert(decoded == shortChannelIds) - } + test("ask for update test") { + // they don't provide anything => we always ask for the update + assert(Router.shouldRequestUpdate(0, 0, None, None)) + assert(Router.shouldRequestUpdate(Int.MaxValue, 12345, None, None)) - 
test("create `reply_channel_range` messages (GZIP format)") { - val blocks = ChannelRangeQueries.encodeShortChannelIds(400000, 20000, shortChannelIds, ChannelRangeQueries.ZLIB_FORMAT, useGzip = true) - val replies = blocks.map(block => ReplyChannelRange(Block.RegtestGenesisBlock.blockId, block.firstBlock, block.numBlocks, 1, block.shortChannelIds)) - var decoded = Set.empty[ShortChannelId] - replies.foreach(reply => decoded = decoded ++ { - val (ChannelRangeQueries.ZLIB_FORMAT, ids, true) = ChannelRangeQueries.decodeShortChannelIds(reply.data) - ids - }) - assert(decoded == shortChannelIds) - } + // their update is older => don't ask + val now = Platform.currentTime / 1000 + assert(!Router.shouldRequestUpdate(now, 0, Some(now - 1), None)) + assert(!Router.shouldRequestUpdate(now, 0, Some(now - 1), Some(12345))) + assert(!Router.shouldRequestUpdate(now, 12344, Some(now - 1), None)) + assert(!Router.shouldRequestUpdate(now, 12344, Some(now - 1), Some(12345))) + + // their update is newer but stale => don't ask + val old = now - 4 * 2016 * 24 * 3600 + assert(!Router.shouldRequestUpdate(old - 1, 0, Some(old), None)) + assert(!Router.shouldRequestUpdate(old - 1, 0, Some(old), Some(12345))) + assert(!Router.shouldRequestUpdate(old - 1, 12344, Some(old), None)) + assert(!Router.shouldRequestUpdate(old - 1, 12344, Some(old), Some(12345))) + + // their update is newer but with the same checksum, and ours is stale or about to be => ask (we want to renew our update) + assert(Router.shouldRequestUpdate(old, 12345, Some(now), Some(12345))) + + // their update is newer but with the same checksum => don't ask + assert(!Router.shouldRequestUpdate(now - 1, 12345, Some(now), Some(12345))) - test("create empty `reply_channel_range` message") { - val blocks = ChannelRangeQueries.encodeShortChannelIds(400000, 20000, SortedSet.empty[ShortChannelId], ChannelRangeQueries.ZLIB_FORMAT, useGzip = false) - val replies = blocks.map(block => ReplyChannelRange(Block.RegtestGenesisBlock.blockId, block.firstBlock, block.numBlocks, 1, block.shortChannelIds)) - var decoded = Set.empty[ShortChannelId] - replies.foreach(reply => decoded = decoded ++ { - val (format, ids, false) = ChannelRangeQueries.decodeShortChannelIds(reply.data) - ids - }) - assert(decoded.isEmpty) + // their update is newer with a different checksum => always ask + assert(Router.shouldRequestUpdate(now - 1, 0, Some(now), None)) + assert(Router.shouldRequestUpdate(now - 1, 0, Some(now), Some(12345))) + assert(Router.shouldRequestUpdate(now - 1, 12344, Some(now), None)) + assert(Router.shouldRequestUpdate(now - 1, 12344, Some(now), Some(12345))) + + // they just provided a 0 checksum => don't ask + assert(!Router.shouldRequestUpdate(0, 0, None, Some(0))) + assert(!Router.shouldRequestUpdate(now, 1234, None, Some(0))) + + // they just provided a checksum that is the same as us => don't ask + assert(!Router.shouldRequestUpdate(now, 1234, None, Some(1234))) + + // they just provided a different checksum that is the same as us => ask + assert(Router.shouldRequestUpdate(now, 1234, None, Some(1235))) } -} -object ChannelRangeQueriesSpec { - lazy val shortChannelIds: immutable.SortedSet[ShortChannelId] = (for { - block <- 400000 to 420000 - txindex <- 0 to 5 - outputIndex <- 0 to 1 - } yield ShortChannelId(block, txindex, outputIndex)).foldLeft(SortedSet.empty[ShortChannelId])(_ + _) + test("compute flag tests") { + + val now = Platform.currentTime / 1000 + + val a = randomKey.publicKey + val b = randomKey.publicKey + val ab = 
RouteCalculationSpec.makeChannel(123466L, a, b) + val (ab1, uab1) = RouteCalculationSpec.makeUpdateShort(ab.shortChannelId, ab.nodeId1, ab.nodeId2, MilliSatoshi(0), 0, timestamp = now) + val (ab2, uab2) = RouteCalculationSpec.makeUpdateShort(ab.shortChannelId, ab.nodeId2, ab.nodeId1, MilliSatoshi(0), 0, timestamp = now) + + val c = randomKey.publicKey + val d = randomKey.publicKey + val cd = RouteCalculationSpec.makeChannel(451312L, c, d) + val (cd1, ucd1) = RouteCalculationSpec.makeUpdateShort(cd.shortChannelId, cd.nodeId1, cd.nodeId2, MilliSatoshi(0), 0, timestamp = now) + val (_, ucd2) = RouteCalculationSpec.makeUpdateShort(cd.shortChannelId, cd.nodeId2, cd.nodeId1, MilliSatoshi(0), 0, timestamp = now) + + val e = randomKey.publicKey + val f = randomKey.publicKey + val ef = RouteCalculationSpec.makeChannel(167514L, e, f) + + val channels = SortedMap( + ab.shortChannelId -> ab, + cd.shortChannelId -> cd + ) + + val updates = Map( + ab1 -> uab1, + ab2 -> uab2, + cd1 -> ucd1 + ) + + import fr.acinq.eclair.wire.QueryShortChannelIdsTlv.QueryFlagType._ + + assert(Router.getChannelDigestInfo(channels, updates)(ab.shortChannelId) == (Timestamps(now, now), Checksums(1697591108L, 1697591108L))) + + // no extended info but we know the channel: we ask for the updates + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, None, None, false) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2)) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, None, None, true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + // same checksums, newer timestamps: we don't ask anything + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(1697591108L, 1697591108L)), true) === 0) + // different checksums, newer timestamps: we ask for the updates + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now + 1, now)), Some(Checksums(154654604, 1697591108L)), true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now, now + 1)), Some(Checksums(1697591108L, 45664546)), true) === (INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now + 1, now + 1)), Some(Checksums(154654604, 45664546 + 6)), true) === (INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2| INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + // different checksums, older timestamps: we don't ask anything + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now - 1, now)), Some(Checksums(154654604, 1697591108L)), true) === 0) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now, now - 1)), Some(Checksums(1697591108L, 45664546)), true) === 0) + assert(Router.computeFlag(channels, updates)(ab.shortChannelId, Some(Timestamps(now - 1, now - 1)), Some(Checksums(154654604, 45664546)), true) === 0) + + // missing channel update: we ask for it + assert(Router.computeFlag(channels, updates)(cd.shortChannelId, Some(Timestamps(now, now)), Some(Checksums(3297511804L, 3297511804L)), true) === (INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + + // unknown channel: we ask everything + assert(Router.computeFlag(channels, 
updates)(ef.shortChannelId, None, None, false) === (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2)) + assert(Router.computeFlag(channels, updates)(ef.shortChannelId, None, None, true) === (INCLUDE_CHANNEL_ANNOUNCEMENT | INCLUDE_CHANNEL_UPDATE_1 | INCLUDE_CHANNEL_UPDATE_2 | INCLUDE_NODE_ANNOUNCEMENT_1 | INCLUDE_NODE_ANNOUNCEMENT_2)) + } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala index 958d0e4113..840bbea8d0 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouteCalculationSpec.scala @@ -965,12 +965,12 @@ object RouteCalculationSpec { makeUpdateShort(ShortChannelId(shortChannelId), nodeId1, nodeId2, feeBase, feeProportionalMillionth, minHtlc, maxHtlc, cltvDelta) } - def makeUpdateShort(shortChannelId: ShortChannelId, nodeId1: PublicKey, nodeId2: PublicKey, feeBase: MilliSatoshi, feeProportionalMillionth: Int, minHtlc: MilliSatoshi = DEFAULT_AMOUNT_MSAT, maxHtlc: Option[MilliSatoshi] = None, cltvDelta: CltvExpiryDelta = CltvExpiryDelta(0)): (ChannelDesc, ChannelUpdate) = + def makeUpdateShort(shortChannelId: ShortChannelId, nodeId1: PublicKey, nodeId2: PublicKey, feeBase: MilliSatoshi, feeProportionalMillionth: Int, minHtlc: MilliSatoshi = DEFAULT_AMOUNT_MSAT, maxHtlc: Option[MilliSatoshi] = None, cltvDelta: CltvExpiryDelta = CltvExpiryDelta(0), timestamp: Long = 0): (ChannelDesc, ChannelUpdate) = ChannelDesc(shortChannelId, nodeId1, nodeId2) -> ChannelUpdate( signature = DUMMY_SIG, chainHash = Block.RegtestGenesisBlock.hash, shortChannelId = shortChannelId, - timestamp = 0L, + timestamp = timestamp, messageFlags = maxHtlc match { case Some(_) => 1 case None => 0 diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala index f5bddc0f68..23bae1499e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RouterSpec.scala @@ -33,7 +33,6 @@ import fr.acinq.eclair.wire.QueryShortChannelIds import fr.acinq.eclair.{CltvExpiryDelta, Globals, MilliSatoshi, ShortChannelId, randomKey} import scodec.bits._ -import scala.collection.SortedSet import scala.compat.Platform import scala.concurrent.duration._ @@ -276,6 +275,6 @@ class RouterSpec extends BaseRouterSpec { val transport = TestProbe() probe.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, update1)) val query = transport.expectMsgType[QueryShortChannelIds] - assert(ChannelRangeQueries.decodeShortChannelIds(query.data)._2 == SortedSet(channelId)) + assert(query.shortChannelIds.array == List(channelId)) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala index 05db9102d3..f0f7cfa1ec 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/router/RoutingSyncSpec.scala @@ -16,81 +16,234 @@ package fr.acinq.eclair.router -import akka.actor.ActorSystem +import akka.actor.{Actor, ActorSystem, Props} import akka.testkit.{TestFSMRef, TestKit, TestProbe} -import fr.acinq.bitcoin.Block +import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} +import fr.acinq.bitcoin.{Block, Satoshi, Script, Transaction, TxIn, TxOut} +import 
fr.acinq.eclair.TestConstants.{Alice, Bob} import fr.acinq.eclair._ +import fr.acinq.eclair.blockchain.{UtxoStatus, ValidateRequest, ValidateResult} import fr.acinq.eclair.crypto.TransportHandler import fr.acinq.eclair.io.Peer.PeerRoutingMessage import fr.acinq.eclair.router.Announcements.{makeChannelUpdate, makeNodeAnnouncement} import fr.acinq.eclair.router.BaseRouterSpec.channelAnnouncement +import fr.acinq.eclair.transactions.Scripts import fr.acinq.eclair.wire._ import org.scalatest.FunSuiteLike +import scala.collection.immutable.TreeMap +import scala.collection.{SortedSet, immutable, mutable} +import scala.compat.Platform import scala.concurrent.duration._ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { - import RoutingSyncSpec.makeFakeRoutingInfo + import RoutingSyncSpec._ - val shortChannelIds = ChannelRangeQueriesSpec.shortChannelIds.take(350) - val fakeRoutingInfo = shortChannelIds.map(makeFakeRoutingInfo).map(t => t._1.shortChannelId -> t).toMap + val fakeRoutingInfo: TreeMap[ShortChannelId, (ChannelAnnouncement, ChannelUpdate, ChannelUpdate, NodeAnnouncement, NodeAnnouncement)] = RoutingSyncSpec + .shortChannelIds + .take(4567) + .foldLeft(TreeMap.empty[ShortChannelId, (ChannelAnnouncement, ChannelUpdate, ChannelUpdate, NodeAnnouncement, NodeAnnouncement)]) { + case (m, shortChannelId) => m + (shortChannelId -> makeFakeRoutingInfo(shortChannelId)) + } - test("handle channel range queries") { - val params = TestConstants.Alice.nodeParams - val router = TestFSMRef(new Router(params, TestProbe().ref)) - val transport = TestProbe() + class YesWatcher extends Actor { + override def receive: Receive = { + case ValidateRequest(c) => + val pubkeyScript = Script.write(Script.pay2wsh(Scripts.multiSig2of2(c.bitcoinKey1, c.bitcoinKey2))) + val TxCoordinates(_, _, outputIndex) = ShortChannelId.coordinates(c.shortChannelId) + val fakeFundingTx = Transaction( + version = 2, + txIn = Seq.empty[TxIn], + txOut = List.fill(outputIndex + 1)(TxOut(Satoshi(0), pubkeyScript)), // quick and dirty way to be sure that the outputIndex'th output is of the expected format + lockTime = 0) + sender ! 
ValidateResult(c, Right(fakeFundingTx, UtxoStatus.Unspent)) + } + } + + case class BasicSyncResult(ranges: Int, queries: Int, channels: Int, updates: Int, nodes: Int) + + case class SyncResult(ranges: Seq[ReplyChannelRange], queries: Seq[QueryShortChannelIds], channels: Seq[ChannelAnnouncement], updates: Seq[ChannelUpdate], nodes: Seq[NodeAnnouncement]) { + def counts = BasicSyncResult(ranges.size, queries.size, channels.size, updates.size, nodes.size) + } + + def sync(src: TestFSMRef[State, Data, Router], tgt: TestFSMRef[State, Data, Router], extendedQueryFlags_opt: Option[QueryChannelRangeTlv]): SyncResult = { val sender = TestProbe() - sender.ignoreMsg { case _: TransportHandler.ReadAck => true } - val remoteNodeId = TestConstants.Bob.nodeParams.nodeId + val pipe = TestProbe() + pipe.ignoreMsg { + case _: TransportHandler.ReadAck => true + case _: GossipTimestampFilter => true + } + val srcId = src.underlyingActor.nodeParams.nodeId + val tgtId = tgt.underlyingActor.nodeParams.nodeId + sender.send(src, SendChannelQuery(tgtId, pipe.ref, extendedQueryFlags_opt)) + // src sends a query_channel_range to tgt + val qcr = pipe.expectMsgType[QueryChannelRange] + pipe.send(tgt, PeerRoutingMessage(pipe.ref, srcId, qcr)) + // this allows us to know when the last reply_channel_range has been sent + pipe.send(tgt, 'data) + // tgt answers with reply_channel_ranges + val rcrs = pipe.receiveWhile() { + case rcr: ReplyChannelRange => rcr + } + pipe.expectMsgType[Data] + rcrs.foreach(rcr => pipe.send(src, PeerRoutingMessage(pipe.ref, tgtId, rcr))) + // then src will now query announcements + var queries = Vector.empty[QueryShortChannelIds] + var channels = Vector.empty[ChannelAnnouncement] + var updates = Vector.empty[ChannelUpdate] + var nodes = Vector.empty[NodeAnnouncement] + while (src.stateData.sync.nonEmpty) { + // for each chunk, src sends a query_short_channel_ids + val query = pipe.expectMsgType[QueryShortChannelIds] + pipe.send(tgt, PeerRoutingMessage(pipe.ref, srcId, query)) + queries = queries :+ query + val announcements = pipe.receiveWhile() { + case c: ChannelAnnouncement => + channels = channels :+ c + c + case u: ChannelUpdate => + updates = updates :+ u + u + case n: NodeAnnouncement => + nodes = nodes :+ n + n + } + // tgt replies with announcements + announcements.foreach(ann => pipe.send(src, PeerRoutingMessage(pipe.ref, tgtId, ann))) + // and tgt ends this chunk with a reply_short_channel_ids_end + val rscie = pipe.expectMsgType[ReplyShortChannelIdsEnd] + pipe.send(src, PeerRoutingMessage(pipe.ref, tgtId, rscie)) + } + SyncResult(rcrs, queries, channels, updates, nodes) + } - // ask router to send a channel range query - sender.send(router, SendChannelQuery(remoteNodeId, sender.ref)) - val QueryChannelRange(chainHash, firstBlockNum, numberOfBlocks) = sender.expectMsgType[QueryChannelRange] - sender.expectMsgType[GossipTimestampFilter] + test("sync with standard channel queries") { + val watcher = system.actorOf(Props(new YesWatcher())) + val alice = TestFSMRef(new Router(Alice.nodeParams, watcher)) + val bob = TestFSMRef(new Router(Bob.nodeParams, watcher)) + val charlieId = randomKey.publicKey + val sender = TestProbe() + val extendedQueryFlags_opt = None - // split our answer in 3 blocks - val List(block1) = ChannelRangeQueries.encodeShortChannelIds(firstBlockNum, numberOfBlocks, shortChannelIds.take(100), ChannelRangeQueries.UNCOMPRESSED_FORMAT) - val List(block2) = ChannelRangeQueries.encodeShortChannelIds(firstBlockNum, numberOfBlocks, shortChannelIds.drop(100).take(100),
ChannelRangeQueries.UNCOMPRESSED_FORMAT) - val List(block3) = ChannelRangeQueries.encodeShortChannelIds(firstBlockNum, numberOfBlocks, shortChannelIds.drop(200).take(150), ChannelRangeQueries.UNCOMPRESSED_FORMAT) + // tell alice to sync with bob + assert(BasicSyncResult(ranges = 1, queries = 0, channels = 0, updates = 0, nodes = 0) === sync(alice, bob, extendedQueryFlags_opt).counts) + awaitCond(alice.stateData.channels === bob.stateData.channels) + awaitCond(alice.stateData.updates === bob.stateData.updates) + awaitCond(alice.stateData.nodes === bob.stateData.nodes) - // send first block - sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, ReplyChannelRange(chainHash, block1.firstBlock, block1.numBlocks, 1, block1.shortChannelIds))) - // router should ask for our first block of ids - val QueryShortChannelIds(_, data1) = transport.expectMsgType[QueryShortChannelIds] - val (_, shortChannelIds1, false) = ChannelRangeQueries.decodeShortChannelIds(data1) - assert(shortChannelIds1 == shortChannelIds.take(100)) - - // send second block - sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, ReplyChannelRange(chainHash, block2.firstBlock, block2.numBlocks, 1, block2.shortChannelIds))) - - // send the first 50 items - shortChannelIds1.take(50).foreach(id => { - val (ca, cu1, cu2, _, _) = fakeRoutingInfo(id) - sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, ca)) - sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, cu1)) - sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, cu2)) - }) - - // send the last 50 items - shortChannelIds1.drop(50).foreach(id => { - val (ca, cu1, cu2, _, _) = fakeRoutingInfo(id) - sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, ca)) - sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, cu1)) - sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, cu2)) - }) - - // during that time, router should not have asked for more ids, it already has a pending query ! 
- transport.expectNoMsg(200 millis) - - // now send our ReplyShortChannelIdsEnd message - sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, ReplyShortChannelIdsEnd(chainHash, 1.toByte))) - - // router should ask for our second block of ids - val QueryShortChannelIds(_, data2) = transport.expectMsgType[QueryShortChannelIds] - val (_, shortChannelIds2, false) = ChannelRangeQueries.decodeShortChannelIds(data2) - assert(shortChannelIds2 == shortChannelIds.drop(100).take(100)) + // add some channels and updates to bob and resync + fakeRoutingInfo.take(40).values.foreach { + case (ca, cu1, cu2, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, ca)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, cu1)) + // we don't send channel_update #2 + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) + } + awaitCond(bob.stateData.channels.size === 40 && bob.stateData.updates.size === 40) + assert(BasicSyncResult(ranges = 1, queries = 1, channels = 40, updates = 40, nodes = 80) === sync(alice, bob, extendedQueryFlags_opt).counts) + awaitCond(alice.stateData.channels === bob.stateData.channels) + awaitCond(alice.stateData.updates === bob.stateData.updates) + + // add some updates to bob and resync + fakeRoutingInfo.take(40).values.foreach { + case (ca, cu1, cu2, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, cu2)) + } + awaitCond(bob.stateData.channels.size === 40 && bob.stateData.updates.size === 80) + assert(BasicSyncResult(ranges = 1, queries = 1, channels = 40, updates = 80, nodes = 80) === sync(alice, bob, extendedQueryFlags_opt).counts) + awaitCond(alice.stateData.channels === bob.stateData.channels) + awaitCond(alice.stateData.updates === bob.stateData.updates) + + // add everything (duplicates will be ignored) + fakeRoutingInfo.values.foreach { + case (c, u1, u2, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, c)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, u1)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, u2)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) + } + awaitCond(bob.stateData.channels.size === fakeRoutingInfo.size && bob.stateData.updates.size === 2 * fakeRoutingInfo.size, max = 60 seconds) + assert(BasicSyncResult(ranges = 2, queries = 46, channels = fakeRoutingInfo.size, updates = 2 * fakeRoutingInfo.size, nodes = 2 * fakeRoutingInfo.size) === sync(alice, bob, extendedQueryFlags_opt).counts) + awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) + awaitCond(alice.stateData.updates === bob.stateData.updates) + } + + def syncWithExtendedQueries(requestNodeAnnouncements: Boolean) = { + val watcher = system.actorOf(Props(new YesWatcher())) + val alice = TestFSMRef(new Router(Alice.nodeParams.copy(routerConf = Alice.nodeParams.routerConf.copy(requestNodeAnnouncements = requestNodeAnnouncements)), watcher)) + val bob = TestFSMRef(new Router(Bob.nodeParams, watcher)) + val charlieId = randomKey.publicKey + val sender = TestProbe() + val extendedQueryFlags_opt = Some(QueryChannelRangeTlv.QueryFlags(QueryChannelRangeTlv.QueryFlags.WANT_ALL)) + + // tell alice to sync with bob + assert(BasicSyncResult(ranges = 1, queries = 0, channels = 0, updates = 0, nodes = 0) === sync(alice, bob, extendedQueryFlags_opt).counts) + awaitCond(alice.stateData.channels === 
bob.stateData.channels) + awaitCond(alice.stateData.updates === bob.stateData.updates) + + // add some channels and updates to bob and resync + fakeRoutingInfo.take(40).values.foreach { + case (ca, cu1, cu2, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, ca)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, cu1)) + // we don't send channel_update #2 + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) + } + awaitCond(bob.stateData.channels.size === 40 && bob.stateData.updates.size === 40) + assert(BasicSyncResult(ranges = 1, queries = 1, channels = 40, updates = 40, nodes = if (requestNodeAnnouncements) 80 else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) + awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) + awaitCond(alice.stateData.updates === bob.stateData.updates) + if (requestNodeAnnouncements) awaitCond(alice.stateData.nodes === bob.stateData.nodes) + + // add some updates to bob and resync + fakeRoutingInfo.take(40).values.foreach { + case (ca, cu1, cu2, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, cu2)) + } + awaitCond(bob.stateData.channels.size === 40 && bob.stateData.updates.size === 80) + assert(BasicSyncResult(ranges = 1, queries = 1, channels = 0, updates = 40, nodes = if (requestNodeAnnouncements) 80 else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) + awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) + awaitCond(alice.stateData.updates === bob.stateData.updates) + + // add everything (duplicates will be ignored) + fakeRoutingInfo.values.foreach { + case (c, u1, u2, na1, na2) => + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, c)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, u1)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, u2)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na1)) + sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, na2)) + } + awaitCond(bob.stateData.channels.size === fakeRoutingInfo.size && bob.stateData.updates.size === 2 * fakeRoutingInfo.size, max = 60 seconds) + assert(BasicSyncResult(ranges = 2, queries = 46, channels = fakeRoutingInfo.size - 40, updates = 2 * (fakeRoutingInfo.size - 40), nodes = if (requestNodeAnnouncements) 2 * (fakeRoutingInfo.size - 40) else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) + awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) + awaitCond(alice.stateData.updates === bob.stateData.updates) + + // bump random channel_updates + def touchUpdate(shortChannelId: Int, side: Boolean) = { + val (c, u1, u2, _, _) = fakeRoutingInfo.values.toList(shortChannelId) + makeNewerChannelUpdate(c, if (side) u1 else u2) + } + + val bumpedUpdates = (List(0, 42, 147, 153, 654, 834, 4301).map(touchUpdate(_, true)) ++ List(1, 42, 150, 200).map(touchUpdate(_, false))).toSet + bumpedUpdates.foreach(c => sender.send(bob, PeerRoutingMessage(sender.ref, charlieId, c))) + assert(BasicSyncResult(ranges = 2, queries = 2, channels = 0, updates = bumpedUpdates.size, nodes = if (requestNodeAnnouncements) 20 else 0) === sync(alice, bob, extendedQueryFlags_opt).counts) + awaitCond(alice.stateData.channels === bob.stateData.channels, max = 60 seconds) + awaitCond(alice.stateData.updates === bob.stateData.updates) + if (requestNodeAnnouncements) awaitCond(alice.stateData.nodes === bob.stateData.nodes) + } + + test("sync with 
extended channel queries (don't request node announcements)") { + syncWithExtendedQueries(false) + } + + test("sync with extended channel queries (request node announcements)") { + syncWithExtendedQueries(true) } test("reset sync state on reconnection") { @@ -102,38 +255,86 @@ class RoutingSyncSpec extends TestKit(ActorSystem("test")) with FunSuiteLike { val remoteNodeId = TestConstants.Bob.nodeParams.nodeId // ask router to send a channel range query - sender.send(router, SendChannelQuery(remoteNodeId, sender.ref)) - val QueryChannelRange(chainHash, firstBlockNum, numberOfBlocks) = sender.expectMsgType[QueryChannelRange] + sender.send(router, SendChannelQuery(remoteNodeId, sender.ref, None)) + val QueryChannelRange(chainHash, firstBlockNum, numberOfBlocks, _) = sender.expectMsgType[QueryChannelRange] sender.expectMsgType[GossipTimestampFilter] - val List(block1) = ChannelRangeQueries.encodeShortChannelIds(firstBlockNum, numberOfBlocks, shortChannelIds.take(100), ChannelRangeQueries.UNCOMPRESSED_FORMAT) + val block1 = ReplyChannelRange(chainHash, firstBlockNum, numberOfBlocks, 1, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, fakeRoutingInfo.take(100).keys.toList), None, None) // send first block - sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, ReplyChannelRange(chainHash, block1.firstBlock, block1.numBlocks, 1, block1.shortChannelIds))) + sender.send(router, PeerRoutingMessage(transport.ref, remoteNodeId, block1)) // router should ask for our first block of ids - val QueryShortChannelIds(_, data1) = transport.expectMsgType[QueryShortChannelIds] - // router should think that it is mssing 100 channels + assert(transport.expectMsgType[QueryShortChannelIds] === QueryShortChannelIds(chainHash, block1.shortChannelIds, TlvStream.empty)) + // router should think that it is missing 100 channels, in one request val Some(sync) = router.stateData.sync.get(remoteNodeId) - assert(sync.totalMissingCount == 100) + assert(sync.total == 1) // simulate a re-connection - sender.send(router, SendChannelQuery(remoteNodeId, sender.ref)) + sender.send(router, SendChannelQuery(remoteNodeId, sender.ref, None)) sender.expectMsgType[QueryChannelRange] sender.expectMsgType[GossipTimestampFilter] assert(router.stateData.sync.get(remoteNodeId).isEmpty) } + + test("sync progress") { + + def req = QueryShortChannelIds(Block.RegtestGenesisBlock.hash, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(42))), TlvStream.empty) + + val nodeidA = randomKey.publicKey + val nodeidB = randomKey.publicKey + + val (sync1, _) = Router.addToSync(Map.empty, nodeidA, List(req, req, req, req)) + assert(Router.syncProgress(sync1) == SyncProgress(0.25D)) + + val (sync2, _) = Router.addToSync(sync1, nodeidB, List(req, req, req, req, req, req, req, req, req, req, req, req)) + assert(Router.syncProgress(sync2) == SyncProgress(0.125D)) + + // let's assume we made some progress + val sync3 = sync2 + .updated(nodeidA, sync2(nodeidA).copy(pending = List(req))) + .updated(nodeidB, sync2(nodeidB).copy(pending = List(req))) + assert(Router.syncProgress(sync3) == SyncProgress(0.875D)) + } } object RoutingSyncSpec { + + lazy val shortChannelIds: immutable.SortedSet[ShortChannelId] = (for { + block <- 400000 to 420000 + txindex <- 0 to 5 + outputIndex <- 0 to 1 + } yield ShortChannelId(block, txindex, outputIndex)).foldLeft(SortedSet.empty[ShortChannelId])(_ + _) + + // this map will store private keys so that we can sign new announcements at will + val pub2priv: mutable.Map[PublicKey, PrivateKey] = 
mutable.HashMap.empty + + val unused = randomKey + def makeFakeRoutingInfo(shortChannelId: ShortChannelId): (ChannelAnnouncement, ChannelUpdate, ChannelUpdate, NodeAnnouncement, NodeAnnouncement) = { - val (priv_a, priv_b, priv_funding_a, priv_funding_b) = (randomKey, randomKey, randomKey, randomKey) - val channelAnn_ab = channelAnnouncement(shortChannelId, priv_a, priv_b, priv_funding_a, priv_funding_b) - val TxCoordinates(blockHeight, _, _) = ShortChannelId.coordinates(shortChannelId) - val channelUpdate_ab = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_a, priv_b.publicKey, shortChannelId, CltvExpiryDelta(7), MilliSatoshi(0), feeBaseMsat = MilliSatoshi(766000), feeProportionalMillionths = 10, MilliSatoshi(500000000L), timestamp = blockHeight) - val channelUpdate_ba = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv_b, priv_a.publicKey, shortChannelId, CltvExpiryDelta(7), MilliSatoshi(0), feeBaseMsat = MilliSatoshi(766000), feeProportionalMillionths = 10, MilliSatoshi(500000000L), timestamp = blockHeight) - val nodeAnnouncement_a = makeNodeAnnouncement(priv_a, "a", Color(0, 0, 0), List()) - val nodeAnnouncement_b = makeNodeAnnouncement(priv_b, "b", Color(0, 0, 0), List()) - (channelAnn_ab, channelUpdate_ab, channelUpdate_ba, nodeAnnouncement_a, nodeAnnouncement_b) + val timestamp = Platform.currentTime / 1000 + val (priv1, priv2) = { + val (priv_a, priv_b) = (randomKey, randomKey) + if (Announcements.isNode1(priv_a.publicKey, priv_b.publicKey)) (priv_a, priv_b) else (priv_b, priv_a) + } + val priv_funding1 = unused + val priv_funding2 = unused + pub2priv += (priv1.publicKey -> priv1) + pub2priv += (priv2.publicKey -> priv2) + val channelAnn_12 = channelAnnouncement(shortChannelId, priv1, priv2, priv_funding1, priv_funding2) + val channelUpdate_12 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv1, priv2.publicKey, shortChannelId, cltvExpiryDelta = CltvExpiryDelta(7), MilliSatoshi(0), feeBaseMsat = MilliSatoshi(766000), feeProportionalMillionths = 10, MilliSatoshi(500000000L), timestamp = timestamp) + val channelUpdate_21 = makeChannelUpdate(Block.RegtestGenesisBlock.hash, priv2, priv1.publicKey, shortChannelId, cltvExpiryDelta = CltvExpiryDelta(7), MilliSatoshi(0), feeBaseMsat = MilliSatoshi(766000), feeProportionalMillionths = 10, MilliSatoshi(500000000L), timestamp = timestamp) + val nodeAnnouncement_1 = makeNodeAnnouncement(priv1, "", Color(0, 0, 0), List()) + val nodeAnnouncement_2 = makeNodeAnnouncement(priv2, "", Color(0, 0, 0), List()) + (channelAnn_12, channelUpdate_12, channelUpdate_21, nodeAnnouncement_1, nodeAnnouncement_2) + } + + def makeNewerChannelUpdate(channelAnnouncement: ChannelAnnouncement, channelUpdate: ChannelUpdate): ChannelUpdate = { + val (local, remote) = if (Announcements.isNode1(channelUpdate.channelFlags)) (channelAnnouncement.nodeId1, channelAnnouncement.nodeId2) else (channelAnnouncement.nodeId2, channelAnnouncement.nodeId1) + val priv = pub2priv(local) + makeChannelUpdate(channelUpdate.chainHash, priv, remote, channelUpdate.shortChannelId, + channelUpdate.cltvExpiryDelta, channelUpdate.htlcMinimumMsat, + channelUpdate.feeBaseMsat, channelUpdate.feeProportionalMillionths, + channelUpdate.htlcMinimumMsat, Announcements.isEnabled(channelUpdate.channelFlags), channelUpdate.timestamp + 5000) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala new file mode 100644 index 0000000000..672439717b --- /dev/null +++ 
b/eclair-core/src/test/scala/fr/acinq/eclair/wire/ExtendedQueriesCodecsSpec.scala @@ -0,0 +1,137 @@ +package fr.acinq.eclair.wire + +import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64} +import fr.acinq.eclair.router.Router +import fr.acinq.eclair.wire.LightningMessageCodecs._ +import fr.acinq.eclair.wire.ReplyChannelRangeTlv._ +import fr.acinq.eclair.{CltvExpiryDelta, MilliSatoshi, ShortChannelId, UInt64} +import org.scalatest.FunSuite +import scodec.bits.ByteVector + +class ExtendedQueriesCodecsSpec extends FunSuite { + test("encode query_short_channel_ids (no optional data)") { + val query_short_channel_id = QueryShortChannelIds( + Block.RegtestGenesisBlock.blockId, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream.empty) + + val encoded = queryShortChannelIdsCodec.encode(query_short_channel_id).require + val decoded = queryShortChannelIdsCodec.decode(encoded).require.value + assert(decoded === query_short_channel_id) + } + + test("encode query_short_channel_ids (with optional data)") { + val query_short_channel_id = QueryShortChannelIds( + Block.RegtestGenesisBlock.blockId, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.UNCOMPRESSED, List(1.toByte, 2.toByte, 3.toByte, 4.toByte, 5.toByte)))) + + val encoded = queryShortChannelIdsCodec.encode(query_short_channel_id).require + val decoded = queryShortChannelIdsCodec.decode(encoded).require.value + assert(decoded === query_short_channel_id) + } + + test("encode query_short_channel_ids (with optional data including unknown data)") { + val query_short_channel_id = QueryShortChannelIds( + Block.RegtestGenesisBlock.blockId, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream( + QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.UNCOMPRESSED, List(1.toByte, 2.toByte, 3.toByte, 4.toByte, 5.toByte)) :: Nil, + GenericTlv(UInt64(43), ByteVector.fromValidHex("deadbeef")) :: Nil + ) + ) + + val encoded = queryShortChannelIdsCodec.encode(query_short_channel_id).require + val decoded = queryShortChannelIdsCodec.decode(encoded).require.value + assert(decoded === query_short_channel_id) + } + + test("encode reply_channel_range (no optional data)") { + val replyChannelRange = ReplyChannelRange( + Block.RegtestGenesisBlock.blockId, + 1, 100, + 1.toByte, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + None, None) + + val encoded = replyChannelRangeCodec.encode(replyChannelRange).require + val decoded = replyChannelRangeCodec.decode(encoded).require.value + assert(decoded === replyChannelRange) + } + + test("encode reply_channel_range (with optional timestamps)") { + val replyChannelRange = ReplyChannelRange( + Block.RegtestGenesisBlock.blockId, + 1, 100, + 1.toByte, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + Some(EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, List(Timestamps(1, 1), Timestamps(2, 2), Timestamps(3, 3)))), + None) + + val encoded = replyChannelRangeCodec.encode(replyChannelRange).require + val decoded = replyChannelRangeCodec.decode(encoded).require.value + assert(decoded === replyChannelRange) + } + + test("encode reply_channel_range (with 
optional timestamps, checksums, and unknown data)") { + val replyChannelRange = ReplyChannelRange( + Block.RegtestGenesisBlock.blockId, + 1, 100, + 1.toByte, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream( + List( + EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, List(Timestamps(1, 1), Timestamps(2, 2), Timestamps(3, 3))), + EncodedChecksums(List(Checksums(1, 1), Checksums(2, 2), Checksums(3, 3))) + ), + GenericTlv(UInt64(7), ByteVector.fromValidHex("deadbeef")) :: Nil + ) + ) + + val encoded = replyChannelRangeCodec.encode(replyChannelRange).require + val decoded = replyChannelRangeCodec.decode(encoded).require.value + assert(decoded === replyChannelRange) + } + + test("compute checksums correctly (CL test #1)") { + val update = ChannelUpdate( + chainHash = ByteVector32.fromValidHex("06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f"), + signature = ByteVector64.fromValidHex("76df7e70c63cc2b63ef1c062b99c6d934a80ef2fd4dae9e1d86d277f47674af3255a97fa52ade7f129263f591ed784996eba6383135896cc117a438c80293282"), + shortChannelId = ShortChannelId("103x1x0"), + timestamp = 1565587763L, + messageFlags = 0, + channelFlags = 0, + cltvExpiryDelta = CltvExpiryDelta(144), + htlcMinimumMsat = MilliSatoshi(0), + htlcMaximumMsat = None, + feeBaseMsat = MilliSatoshi(1000), + feeProportionalMillionths = 10 + ) + val check = ByteVector.fromValidHex("010276df7e70c63cc2b63ef1c062b99c6d934a80ef2fd4dae9e1d86d277f47674af3255a97fa52ade7f129263f591ed784996eba6383135896cc117a438c8029328206226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f00006700000100005d50f933000000900000000000000000000003e80000000a") + assert(LightningMessageCodecs.channelUpdateCodec.encode(update).require.bytes == check.drop(2)) + + val checksum = Router.getChecksum(update) + assert(checksum == 0x1112fa30L) + } + + test("compute checksums correctly (CL test #2)") { + val update = ChannelUpdate( + chainHash = ByteVector32.fromValidHex("06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f"), + signature = ByteVector64.fromValidHex("06737e9e18d3e4d0ab4066ccaecdcc10e648c5f1c5413f1610747e0d463fa7fa39c1b02ea2fd694275ecfefe4fe9631f24afd182ab75b805e16cd550941f858c"), + shortChannelId = ShortChannelId("109x1x0"), + timestamp = 1565587765L, + messageFlags = 1, + channelFlags = 0, + cltvExpiryDelta = CltvExpiryDelta(48), + htlcMinimumMsat = MilliSatoshi(0), + htlcMaximumMsat = Some(MilliSatoshi(100000)), + feeBaseMsat = MilliSatoshi(100), + feeProportionalMillionths = 11 + ) + val check = ByteVector.fromValidHex("010206737e9e18d3e4d0ab4066ccaecdcc10e648c5f1c5413f1610747e0d463fa7fa39c1b02ea2fd694275ecfefe4fe9631f24afd182ab75b805e16cd550941f858c06226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f00006d00000100005d50f935010000300000000000000000000000640000000b00000000000186a0") + assert(LightningMessageCodecs.channelUpdateCodec.encode(update).require.bytes == check.drop(2)) + + val checksum = Router.getChecksum(update) + assert(checksum == 0xf32ce968L) + } +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala index a63a15f025..06d48ff3f8 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/LightningMessageCodecsSpec.scala @@ -23,6 +23,7 @@ import fr.acinq.bitcoin.{Block, ByteVector32, ByteVector64, 
Satoshi} import fr.acinq.eclair._ import fr.acinq.eclair.router.Announcements import fr.acinq.eclair.wire.LightningMessageCodecs._ +import ReplyChannelRangeTlv._ import org.scalatest.FunSuite import scodec.bits.{ByteVector, HexStringSyntax} @@ -71,9 +72,18 @@ class LightningMessageCodecsSpec extends FunSuite { val channel_update = ChannelUpdate(randomBytes64, Block.RegtestGenesisBlock.hash, ShortChannelId(1), 2, 42, 0, CltvExpiryDelta(3), MilliSatoshi(4), MilliSatoshi(5), 6, None) val announcement_signatures = AnnouncementSignatures(randomBytes32, ShortChannelId(42), randomBytes64, randomBytes64) val gossip_timestamp_filter = GossipTimestampFilter(Block.RegtestGenesisBlock.blockId, 100000, 1500) - val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, randomBytes(7515)) - val query_channel_range = QueryChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500) - val reply_channel_range = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500, 1, randomBytes(3200)) + val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty) + val unknownTlv = GenericTlv(UInt64(5), ByteVector.fromValidHex("deadbeef")) + val query_channel_range = QueryChannelRange(Block.RegtestGenesisBlock.blockId, + 100000, + 1500, + TlvStream(QueryChannelRangeTlv.QueryFlags((QueryChannelRangeTlv.QueryFlags.WANT_ALL)) :: Nil, unknownTlv :: Nil)) + val reply_channel_range = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500, 1, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), + TlvStream( + EncodedTimestamps(EncodingType.UNCOMPRESSED, List(Timestamps(1, 1), Timestamps(2, 2), Timestamps(3, 3))) :: EncodedChecksums(List(Checksums(1, 1), Checksums(2, 2), Checksums(3, 3))) :: Nil, + unknownTlv :: Nil) + ) val ping = Ping(100, bin(10, 1)) val pong = Pong(bin(10, 1)) val channel_reestablish = ChannelReestablish(randomBytes32, 242842L, 42L) @@ -92,6 +102,125 @@ class LightningMessageCodecsSpec extends FunSuite { } } + test("non-reg encoding type") { + val refs = Map( + hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001900000000000000008e0000000000003c69000000000045a6c4" + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty), + hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001601789c636000833e08659309a65c971d0100126e02e3" + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty), + hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001900000000000000008e0000000000003c69000000000045a6c4010400010204" + -> QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.UNCOMPRESSED, List(1, 2, 4)))), + hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001601789c636000833e08659309a65c971d0100126e02e3010c01789c6364620100000e0008" + -> 
QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) + ) + + refs.forall { + case (bin, obj) => + lightningMessageCodec.decode(bin.toBitVector).require.value == obj && lightningMessageCodec.encode(obj).require == bin.toBitVector + } + } + + case class TestItem(msg: Any, hex: String) + + test("test vectors for extended channel queries ") { + import org.json4s.{CustomSerializer, ShortTypeHints} + import org.json4s.JsonAST.JString + import org.json4s.jackson.Serialization + import fr.acinq.eclair.api._ + + val query_channel_range = QueryChannelRange(Block.RegtestGenesisBlock.blockId, 100000, 1500, TlvStream.empty) + val query_channel_range_timestamps_checksums = QueryChannelRange(Block.RegtestGenesisBlock.blockId, + 35000, + 100, + TlvStream(QueryChannelRangeTlv.QueryFlags((QueryChannelRangeTlv.QueryFlags.WANT_ALL)))) + val reply_channel_range = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 756230, 1500, 1, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), None, None) + val reply_channel_range_zlib = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 1600, 110, 1, + EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(265462))), None, None) + val reply_channel_range_timestamps_checksums = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 122334, 1500, 1, + EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(12355), ShortChannelId(489686), ShortChannelId(4645313))), + Some(EncodedTimestamps(EncodingType.UNCOMPRESSED, List(Timestamps(164545, 948165), Timestamps(489645, 4786864), Timestamps(46456, 9788415)))), + Some(EncodedChecksums(List(Checksums(1111, 2222), Checksums(3333, 4444), Checksums(5555, 6666))))) + val reply_channel_range_timestamps_checksums_zlib = ReplyChannelRange(Block.RegtestGenesisBlock.blockId, 122334, 1500, 1, + EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(12355), ShortChannelId(489686), ShortChannelId(4645313))), + Some(EncodedTimestamps(EncodingType.COMPRESSED_ZLIB, List(Timestamps(164545, 948165), Timestamps(489645, 4786864), Timestamps(46456, 9788415)))), + Some(EncodedChecksums(List(Checksums(1111, 2222), Checksums(3333, 4444), Checksums(5555, 6666))))) + val query_short_channel_id = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(142), ShortChannelId(15465), ShortChannelId(4564676))), TlvStream.empty) + val query_short_channel_id_zlib = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(4564), ShortChannelId(178622), ShortChannelId(4564676))), TlvStream.empty) + val query_short_channel_id_flags = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.UNCOMPRESSED, List(ShortChannelId(12232), ShortChannelId(15556), ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) + val query_short_channel_id_flags_zlib = QueryShortChannelIds(Block.RegtestGenesisBlock.blockId, EncodedShortChannelIds(EncodingType.COMPRESSED_ZLIB, List(ShortChannelId(14200), ShortChannelId(46645), 
ShortChannelId(4564676))), TlvStream(QueryShortChannelIdsTlv.EncodedQueryFlags(EncodingType.COMPRESSED_ZLIB, List(1, 2, 4)))) + + + + val refs = Map( + query_channel_range -> hex"01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206000186a0000005dc", + query_channel_range_timestamps_checksums -> hex"01070f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206000088b800000064010103", + reply_channel_range -> hex"01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206000b8a06000005dc01001900000000000000008e0000000000003c69000000000045a6c4", + reply_channel_range_zlib -> hex"01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206000006400000006e01001601789c636000833e08659309a65878be010010a9023a", + reply_channel_range_timestamps_checksums -> hex"01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e22060001ddde000005dc01001900000000000000304300000000000778d6000000000046e1c1011900000282c1000e77c5000778ad00490ab00000b57800955bff031800000457000008ae00000d050000115c000015b300001a0a", + reply_channel_range_timestamps_checksums_zlib -> hex"01080f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e22060001ddde000005dc01001801789c63600001036730c55e710d4cbb3d3c080017c303b1012201789c63606a3ac8c0577e9481bd622d8327d7060686ad150c53a3ff0300554707db031800000457000008ae00000d050000115c000015b300001a0a", + query_short_channel_id -> hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001900000000000000008e0000000000003c69000000000045a6c4", + query_short_channel_id_zlib -> hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001801789c63600001c12b608a69e73e30edbaec0800203b040e", + query_short_channel_id_flags -> hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e22060019000000000000002fc80000000000003cc4000000000045a6c4010c01789c6364620100000e0008", + query_short_channel_id_flags_zlib -> hex"01050f9188f13cb7b2c71f2a335e3a4fc328bf5beb436012afca590b1a11466e2206001801789c63600001f30a30c5b0cd144cb92e3b020017c6034a010c01789c6364620100000e0008" + ) + + val items = refs.map { case (obj, refbin) => + val bin = lightningMessageCodec.encode(obj).require + assert(refbin.bits === bin) + TestItem(obj, bin.toHex) + } + + // NB: uncomment this to update the test vectors + + /*class EncodingTypeSerializer extends CustomSerializer[EncodingType](format => ( { + null + }, { + case EncodingType.UNCOMPRESSED => JString("UNCOMPRESSED") + case EncodingType.COMPRESSED_ZLIB => JString("COMPRESSED_ZLIB") + })) + + class ExtendedQueryFlagsSerializer extends CustomSerializer[QueryChannelRangeTlv.QueryFlags](format => ( { + null + }, { + case QueryChannelRangeTlv.QueryFlags(flag) => + JString(((if (QueryChannelRangeTlv.QueryFlags.wantTimestamps(flag)) List("WANT_TIMESTAMPS") else List()) ::: (if (QueryChannelRangeTlv.QueryFlags.wantChecksums(flag)) List("WANT_CHECKSUMS") else List())) mkString (" | ")) + })) + + implicit val formats = org.json4s.DefaultFormats.withTypeHintFieldName("type") + + new EncodingTypeSerializer + + new ExtendedQueryFlagsSerializer + + new ByteVectorSerializer + + new ByteVector32Serializer + + new UInt64Serializer + + new MilliSatoshiSerializer + + new ShortChannelIdSerializer + + new StateSerializer + + new ShaChainSerializer + + new PublicKeySerializer + + new PrivateKeySerializer + + new TransactionSerializer + + new TransactionWithInputInfoSerializer + + new InetSocketAddressSerializer + + new OutPointSerializer + + new OutPointKeySerializer + + new InputInfoSerializer + + new 
ColorSerializer + + new RouteResponseSerializer + + new ThrowableSerializer + + new FailureMessageSerializer + + new NodeAddressSerializer + + new DirectionSerializer + + new PaymentRequestSerializer + + ShortTypeHints(List( + classOf[QueryChannelRange], + classOf[ReplyChannelRange], + classOf[QueryShortChannelIds])) + + val json = Serialization.writePretty(items) + println(json)*/ + } + test("decode channel_update with htlc_maximum_msat") { // this was generated by c-lightning val bin = hex"010258fff7d0e987e2cdd560e3bb5a046b4efe7b26c969c2f51da1dceec7bcb8ae1b634790503d5290c1a6c51d681cf8f4211d27ed33a257dcc1102862571bf1792306226e46111a0b59caaf126043eb5bbf28c34f3a5e332a1fc7b2b73cf188910f0005a100000200005bc75919010100060000000000000001000000010000000a000000003a699d00" diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala index e2c5dbce1b..c21033085e 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/TlvCodecsSpec.scala @@ -294,6 +294,12 @@ class TlvCodecsSpec extends FunSuite { } } + test("get optional TLV field") { + val stream = TlvStream[TestTlv](Seq(TestType254(42), TestType1(42)), Seq(GenericTlv(13, hex"2a"), GenericTlv(11, hex"2b"))) + assert(stream.get[TestType254] == Some(TestType254(42))) + assert(stream.get[TestType1] == Some(TestType1(42))) + assert(stream.get[TestType2] == None) + } } object TlvCodecsSpec {
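The new TlvStream.get[R] helper exercised in the TlvCodecsSpec test above relies on a ClassTag-bound type parameter so that collectFirst can match on the runtime class of each record. The snippet below is a minimal standalone sketch of that pattern, using made-up DemoTlv types rather than the eclair ones, to illustrate why at most one record is returned per type:

import scala.reflect.ClassTag

sealed trait DemoTlv
case class DemoTlvA(value: Long) extends DemoTlv
case class DemoTlvB(value: String) extends DemoTlv

// same shape as TlvStream.get: the ClassTag in scope lets the `case r: R` pattern
// perform a runtime class check; collectFirst keeps the first matching record, and
// since BOLT TLV streams contain at most one record per type, "first" is also "only"
case class DemoStream(records: Seq[DemoTlv]) {
  def get[R <: DemoTlv : ClassTag]: Option[R] = records.collectFirst { case r: R => r }
}

object DemoStreamApp extends App {
  val stream = DemoStream(Seq(DemoTlvA(42), DemoTlvB("x")))
  println(stream.get[DemoTlvA]) // Some(DemoTlvA(42))
  println(stream.get[DemoTlvB]) // Some(DemoTlvB(x))
}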