Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add networks to init message #1254

Merged
merged 4 commits into from
Jan 21, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A
transport ! TransportHandler.Listener(self)
context watch transport
val localInit = nodeParams.overrideFeatures.get(remoteNodeId) match {
case Some(f) => wire.Init(f)
case Some(f) => wire.Init(f, TlvStream(InitTlv.Networks(nodeParams.chainHash :: Nil)))
case None =>
// Eclair-mobile thinks feature bit 15 (payment_secret) is gossip_queries_ex which creates issues, so we mask
// off basic_mpp and payment_secret. As long as they're provided in the invoice it's not an issue.
Expand All @@ -116,7 +116,7 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A
// ... and leave the others untouched
case (value, _) => value
}).reverse.bytes.dropWhile(_ == 0)
wire.Init(tweakedFeatures)
wire.Init(tweakedFeatures, TlvStream(InitTlv.Networks(nodeParams.chainHash :: Nil)))
}
log.info(s"using features=${localInit.features.toBin}")
transport ! localInit
Expand Down Expand Up @@ -148,9 +148,19 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A
case Event(remoteInit: wire.Init, d: InitializingData) =>
d.transport ! TransportHandler.ReadAck(remoteInit)

log.info(s"peer is using features=${remoteInit.features.toBin}")
log.info(s"peer is using features=${remoteInit.features.toBin}, network=${remoteInit.networks}")

if (Features.areSupported(remoteInit.features)) {
if (remoteInit.networks.nonEmpty && !remoteInit.networks.contains(nodeParams.chainHash)) {
log.warning(s"incompatible networks (${remoteInit.networks}), disconnecting")
d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible networks")))
d.transport ! PoisonPill
stay
} else if (!Features.areSupported(remoteInit.features)) {
log.warning("incompatible features, disconnecting")
d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible features")))
d.transport ! PoisonPill
stay
} else {
d.origin_opt.foreach(origin => origin ! "connected")

def localHasFeature(f: Feature): Boolean = Features.hasFeature(d.localInit.features, f)
Expand Down Expand Up @@ -181,11 +191,6 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, authenticator: A
val rebroadcastDelay = Random.nextInt(nodeParams.routerConf.routerBroadcastInterval.toSeconds.toInt).seconds
log.info(s"rebroadcast will be delayed by $rebroadcastDelay")
goto(CONNECTED) using ConnectedData(d.address_opt, d.transport, d.localInit, remoteInit, d.channels.map { case (k: ChannelId, v) => (k, v) }, rebroadcastDelay) forMax (30 seconds) // forMax will trigger a StateTimeout
} else {
log.warning(s"incompatible features, disconnecting")
d.origin_opt.foreach(origin => origin ! Status.Failure(new RuntimeException("incompatible features")))
d.transport ! PoisonPill
stay
}

case Event(Authenticator.Authenticated(connection, _, _, _, _, origin_opt), _) =>
Expand Down
49 changes: 49 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/wire/InitTlv.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright 2019 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package fr.acinq.eclair.wire

import fr.acinq.bitcoin.ByteVector32
import fr.acinq.eclair.UInt64
import fr.acinq.eclair.wire.CommonCodecs._
import scodec.Codec
import scodec.codecs.{discriminated, list, variableSizeBytesLong}

/**
* Created by t-bast on 13/12/2019.
*/

/** Tlv types used inside Init messages. */
sealed trait InitTlv extends Tlv

object InitTlv {

/** The chains the node is interested in. */
case class Networks(chainHashes: List[ByteVector32]) extends InitTlv

}

object InitTlvCodecs {

import InitTlv._

private val networks: Codec[Networks] = variableSizeBytesLong(varintoverflow, list(bytes32)).as[Networks]

val initTlvCodec = TlvCodecs.tlvStream(discriminated[InitTlv].by(varint)
.typecase(UInt64(1), networks)
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ object LightningMessageCodecs {
},
{ features => (ByteVector.empty, features) })

val initCodec: Codec[Init] = combinedFeaturesCodec.as[Init]
val initCodec: Codec[Init] = (("features" | combinedFeaturesCodec) :: ("tlvStream" | InitTlvCodecs.initTlvCodec)).as[Init]

val errorCodec: Codec[Error] = (
("channelId" | bytes32) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ sealed trait HasChainHash extends LightningMessage { def chainHash: ByteVector32
sealed trait UpdateMessage extends HtlcMessage // <- not in the spec
// @formatter:on

case class Init(features: ByteVector) extends SetupMessage
case class Init(features: ByteVector, tlvs: TlvStream[InitTlv] = TlvStream.empty) extends SetupMessage {
val networks = tlvs.get[InitTlv.Networks].map(_.chainHashes).getOrElse(Nil)
}

case class Error(channelId: ByteVector32, data: ByteVector) extends SetupMessage with HasChannelId {
def toAscii: String = if (fr.acinq.eclair.isAsciiPrintable(data)) new String(data.toArray, StandardCharsets.US_ASCII) else "n/a"
Expand Down
43 changes: 21 additions & 22 deletions eclair-core/src/main/scala/fr/acinq/eclair/wire/TlvTypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,41 +22,40 @@ import scodec.bits.ByteVector
import scala.reflect.ClassTag

/**
* Created by t-bast on 20/06/2019.
*/
* Created by t-bast on 20/06/2019.
*/

trait Tlv

/**
* Generic tlv type we fallback to if we don't understand the incoming tlv.
*
* @param tag tlv tag.
* @param value tlv value (length is implicit, and encoded as a varint).
*/
* Generic tlv type we fallback to if we don't understand the incoming tlv.
*
* @param tag tlv tag.
* @param value tlv value (length is implicit, and encoded as a varint).
*/
case class GenericTlv(tag: UInt64, value: ByteVector) extends Tlv

/**
* A tlv stream is a collection of tlv records.
* A tlv stream is constrained to a specific tlv namespace that dictates how to parse the tlv records.
* That namespace is provided by a trait extending the top-level tlv trait.
*
* @param records known tlv records.
* @param unknown unknown tlv records.
* @tparam T the stream namespace is a trait extending the top-level tlv trait.
*/
* A tlv stream is a collection of tlv records.
* A tlv stream is constrained to a specific tlv namespace that dictates how to parse the tlv records.
* That namespace is provided by a trait extending the top-level tlv trait.
*
* @param records known tlv records.
* @param unknown unknown tlv records.
* @tparam T the stream namespace is a trait extending the top-level tlv trait.
*/
case class TlvStream[T <: Tlv](records: Traversable[T], unknown: Traversable[GenericTlv] = Nil) {
/**
*
* @tparam R input type parameter, must be a subtype of the main TLV type
* @return the TLV record of type that matches the input type parameter if any (there can be at most one, since BOLTs specify
* that TLV records are supposed to be unique)
*/
*
* @tparam R input type parameter, must be a subtype of the main TLV type
* @return the TLV record of type that matches the input type parameter if any (there can be at most one, since BOLTs specify
* that TLV records are supposed to be unique)
*/
def get[R <: T : ClassTag]: Option[R] = records.collectFirst { case r: R => r }
}

object TlvStream {
def empty[T <: Tlv] = TlvStream[T](Nil, Nil)
def empty[T <: Tlv]: TlvStream[T] = TlvStream[T](Nil, Nil)

def apply[T <: Tlv](records: T*): TlvStream[T] = TlvStream(records, Nil)

}
19 changes: 17 additions & 2 deletions eclair-core/src/test/scala/fr/acinq/eclair/io/PeerSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.net.{Inet4Address, InetAddress, InetSocketAddress, ServerSocket}
import akka.actor.FSM.{CurrentState, SubscribeTransitionCallBack, Transition}
import akka.actor.{ActorRef, PoisonPill}
import akka.testkit.{TestFSMRef, TestProbe}
import fr.acinq.bitcoin.Block
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.TestConstants._
import fr.acinq.eclair._
Expand All @@ -31,7 +32,7 @@ import fr.acinq.eclair.crypto.TransportHandler
import fr.acinq.eclair.io.Peer._
import fr.acinq.eclair.router.RoutingSyncSpec.makeFakeRoutingInfo
import fr.acinq.eclair.router.{Rebroadcast, RoutingSyncSpec, SendChannelQuery}
import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, LightningMessageCodecs, NodeAddress, NodeAnnouncement, Ping, Pong, QueryShortChannelIds, TlvStream}
import fr.acinq.eclair.wire.{ChannelCodecsSpec, Color, EncodedShortChannelIds, EncodingType, Error, IPv4, InitTlv, LightningMessageCodecs, NodeAddress, NodeAnnouncement, Ping, Pong, QueryShortChannelIds, TlvStream}
import org.scalatest.{Outcome, Tag}
import scodec.bits.{ByteVector, _}

Expand Down Expand Up @@ -81,7 +82,8 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods {
probe.send(peer, Peer.Init(None, channels))
authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, fakeIPAddress.socketAddress, outgoing = true, None))
transport.expectMsgType[TransportHandler.Listener]
transport.expectMsgType[wire.Init]
val localInit = transport.expectMsgType[wire.Init]
assert(localInit.networks === List(Block.RegtestGenesisBlock.hash))
transport.send(peer, remoteInit)
transport.expectMsgType[TransportHandler.ReadAck]
if (expectSync) {
Expand Down Expand Up @@ -255,6 +257,19 @@ class PeerSpec extends TestkitBaseClass with StateTestsHelperMethods {
assert(init.features === sentFeatures.bytes)
}
}

test("disconnect if incompatible networks") { f =>
import f._
val probe = TestProbe()
probe.watch(transport.ref)
probe.send(peer, Peer.Init(None, Set.empty))
authenticator.send(peer, Authenticator.Authenticated(connection.ref, transport.ref, remoteNodeId, new InetSocketAddress("1.2.3.4", 42000), outgoing = true, None))
transport.expectMsgType[TransportHandler.Listener]
transport.expectMsgType[wire.Init]
transport.send(peer, wire.Init(Bob.nodeParams.features, TlvStream(InitTlv.Networks(Block.LivenetGenesisBlock.hash :: Block.SegnetGenesisBlock.hash :: Nil))))
transport.expectMsgType[TransportHandler.ReadAck]
probe.expectTerminated(transport.ref)
}

test("handle disconnect in status INITIALIZING") { f =>
import f._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,36 @@ class LightningMessageCodecsSpec extends FunSuite {
def publicKey(fill: Byte) = PrivateKey(ByteVector.fill(32)(fill)).publicKey

test("encode/decode init message") {
val chainHash1 = ByteVector32(hex"0101010101010101010101010101010101010101010101010101010101010101")
val chainHash2 = ByteVector32(hex"0202020202020202020202020202020202020202020202020202020202020202")
val testCases = Seq(
(hex"0000 0000", hex"", hex"0000 0000"), // no features
(hex"0000 0002088a", hex"088a", hex"0000 0002088a"), // no global features
(hex"00020200 0000", hex"0200", hex"0000 00020200"), // no local features
(hex"00020200 0002088a", hex"0a8a", hex"0000 00020a8a"), // local and global - no conflict - same size
(hex"00020200 0003020002", hex"020202", hex"0000 0003020202"), // local and global - no conflict - different sizes
(hex"00020a02 0002088a", hex"0a8a", hex"0000 00020a8a"), // local and global - conflict - same size
(hex"00022200 000302aaa2", hex"02aaa2", hex"0000 000302aaa2") // local and global - conflict - different sizes
(hex"0000 0000", hex"", Nil, true, None), // no features
(hex"0000 0002088a", hex"088a", Nil, true, None), // no global features
(hex"00020200 0000", hex"0200", Nil, true, Some(hex"0000 00020200")), // no local features
(hex"00020200 0002088a", hex"0a8a", Nil, true, Some(hex"0000 00020a8a")), // local and global - no conflict - same size
(hex"00020200 0003020002", hex"020202", Nil, true, Some(hex"0000 0003020202")), // local and global - no conflict - different sizes
(hex"00020a02 0002088a", hex"0a8a", Nil, true, Some(hex"0000 00020a8a")), // local and global - conflict - same size
(hex"00022200 000302aaa2", hex"02aaa2", Nil, true, Some(hex"0000 000302aaa2")), // local and global - conflict - different sizes
(hex"0000 0002088a 03012a05022aa2", hex"088a", Nil, true, None), // unknown odd records
(hex"0000 0002088a 03012a04022aa2", hex"088a", Nil, false, None), // unknown even records
(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101", hex"088a", Nil, false, None), // invalid tlv stream
(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101", hex"088a", List(chainHash1), true, None), // single network
(hex"0000 0002088a 014001010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202", hex"088a", List(chainHash1, chainHash2), true, None), // multiple networks
(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010103012a", hex"088a", List(chainHash1), true, None), // network and unknown odd records
(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010102012a", hex"088a", Nil, false, None) // network and unknown even records
)

for ((bin, features, encoded) <- testCases) {
val init = initCodec.decode(bin.bits).require.value
assert(init.features === features)
assert(initCodec.encode(init).require.bytes === encoded)
assert(initCodec.decode(encoded.bits).require.value === init)
for ((bin, features, networks, valid, encodedOverride) <- testCases) {
if (valid) {
val init = initCodec.decode(bin.bits).require.value
assert(init.features === features)
assert(init.networks === networks)
val encoded = initCodec.encode(init).require
assert(encoded.bytes === encodedOverride.getOrElse(bin))
assert(initCodec.decode(encoded).require.value === init)
} else {
assert(initCodec.decode(bin.bits).isFailure)
}
}
}

Expand Down