Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send remote address in init #1973

Merged
merged 2 commits into from
Jan 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class TransportHandler[T: ClassTag](keyPair: KeyPair, rs: Option[ByteVector], co

val wireLog = new BusLogging(context.system.eventStream, "", classOf[Diagnostics], context.system.asInstanceOf[ExtendedActorSystem].logFilter) with DiagnosticLoggingAdapter

def diag(message: T, direction: String) = {
def diag(message: T, direction: String): Unit = {
require(direction == "IN" || direction == "OUT")
val channelId_opt = Logs.channelId(message)
wireLog.mdc(Logs.mdc(LogCategory(message), remoteNodeId_opt, channelId_opt))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,10 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A
d.transport ! TransportHandler.Listener(self)
Metrics.PeerConnectionsConnecting.withTag(Tags.ConnectionState, Tags.ConnectionStates.Initializing).increment()
log.info(s"using features=$localFeatures")
val localInit = protocol.Init(localFeatures, TlvStream(InitTlv.Networks(chainHash :: Nil)))
val localInit = IPAddress(d.pendingAuth.address.getAddress, d.pendingAuth.address.getPort) match {
case Some(remoteAddress) if !d.pendingAuth.outgoing && NodeAddress.isPublicIPAddress(remoteAddress) => protocol.Init(localFeatures, TlvStream(InitTlv.Networks(chainHash :: Nil), InitTlv.RemoteAddress(remoteAddress)))
case _ => protocol.Init(localFeatures, TlvStream(InitTlv.Networks(chainHash :: Nil)))
}
d.transport ! localInit
startSingleTimer(INIT_TIMER, InitTimeout, conf.initTimeout)
goto(INITIALIZING) using InitializingData(chainHash, d.pendingAuth, d.remoteNodeId, d.transport, peer, localInit, doSync, d.isPersistent)
Expand All @@ -117,6 +120,7 @@ class PeerConnection(keyPair: KeyPair, conf: PeerConnection.Conf, switchboard: A
d.transport ! TransportHandler.ReadAck(remoteInit)

log.info(s"peer is using features=${remoteInit.features}, networks=${remoteInit.networks.mkString(",")}")
remoteInit.remoteAddress_opt.foreach(address => log.info("peer reports that our IP address is {} (public={})", address.socketAddress.toString, NodeAddress.isPublicIPAddress(address)))

val featureGraphErr_opt = Features.validateFeatureGraph(remoteInit.features)
if (remoteInit.networks.nonEmpty && remoteInit.networks.intersect(d.localInit.networks).isEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ sealed trait HtlcSettlementMessage extends UpdateMessage { def id: Long } // <-

case class Init(features: Features, tlvStream: TlvStream[InitTlv] = TlvStream.empty) extends SetupMessage {
val networks = tlvStream.get[InitTlv.Networks].map(_.chainHashes).getOrElse(Nil)
val remoteAddress_opt = tlvStream.get[InitTlv.RemoteAddress].map(_.address)
}

case class Warning(channelId: ByteVector32, data: ByteVector, tlvStream: TlvStream[WarningTlv] = TlvStream.empty) extends SetupMessage with HasChannelId {
Expand Down Expand Up @@ -215,28 +216,48 @@ case class Color(r: Byte, g: Byte, b: Byte) {
// @formatter:off
sealed trait NodeAddress { def socketAddress: InetSocketAddress }
sealed trait OnionAddress extends NodeAddress
sealed trait IPAddress extends NodeAddress
// @formatter:on

object NodeAddress {
/**
* Creates a NodeAddress from a host and port.
*
* Note that non-onion hosts will be resolved.
*
* We don't attempt to resolve onion addresses (it will be done by the tor proxy), so we just recognize them based on
* the .onion TLD and rely on their length to separate v2/v3.
*/
* Creates a NodeAddress from a host and port.
*
* Note that non-onion hosts will be resolved.
*
* We don't attempt to resolve onion addresses (it will be done by the tor proxy), so we just recognize them based on
* the .onion TLD and rely on their length to separate v2/v3.
*/
def fromParts(host: String, port: Int): Try[NodeAddress] = Try {
host match {
case _ if host.endsWith(".onion") && host.length == 22 => Tor2(host.dropRight(6), port)
case _ if host.endsWith(".onion") && host.length == 62 => Tor3(host.dropRight(6), port)
case _ => InetAddress.getByName(host) match {
case a: Inet4Address => IPv4(a, port)
case a: Inet6Address => IPv6(a, port)
}
case _ => IPAddress(InetAddress.getByName(host), port).get
}
}

private def isPrivate(address: InetAddress): Boolean = address.isAnyLocalAddress || address.isLoopbackAddress || address.isLinkLocalAddress || address.isSiteLocalAddress
pm47 marked this conversation as resolved.
Show resolved Hide resolved

def isPublicIPAddress(address: NodeAddress): Boolean = {
address match {
case IPv4(ipv4, _) if !isPrivate(ipv4) => true
case IPv6(ipv6, _) if !isPrivate(ipv6) => true
case _ => false
}
}
}
case class IPv4(ipv4: Inet4Address, port: Int) extends NodeAddress { override def socketAddress = new InetSocketAddress(ipv4, port) }
case class IPv6(ipv6: Inet6Address, port: Int) extends NodeAddress { override def socketAddress = new InetSocketAddress(ipv6, port) }

object IPAddress {
def apply(inetAddress: InetAddress, port: Int): Option[IPAddress] = inetAddress match {
case address: Inet4Address => Some(IPv4(address, port))
case address: Inet6Address => Some(IPv6(address, port))
case _ => None
}
}

// @formatter:off
case class IPv4(ipv4: Inet4Address, port: Int) extends IPAddress { override def socketAddress = new InetSocketAddress(ipv4, port) }
case class IPv6(ipv6: Inet6Address, port: Int) extends IPAddress { override def socketAddress = new InetSocketAddress(ipv6, port) }
case class Tor2(tor2: String, port: Int) extends OnionAddress { override def socketAddress = InetSocketAddress.createUnresolved(tor2 + ".onion", port) }
case class Tor3(tor3: String, port: Int) extends OnionAddress { override def socketAddress = InetSocketAddress.createUnresolved(tor3 + ".onion", port) }
// @formatter:on
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,24 @@ object InitTlv {
/** The chains the node is interested in. */
case class Networks(chainHashes: List[ByteVector32]) extends InitTlv

/**
* When receiving an incoming connection, we can send back the public address our peer is connecting from.
* This lets our peer discover if its public IP has changed from within its local network.
*/
case class RemoteAddress(address: NodeAddress) extends InitTlv

}

object InitTlvCodecs {

import InitTlv._

private val networks: Codec[Networks] = variableSizeBytesLong(varintoverflow, list(bytes32)).as[Networks]
private val remoteAddress: Codec[RemoteAddress] = variableSizeBytesLong(varintoverflow, nodeaddress).as[RemoteAddress]

val initTlvCodec = tlvStream(discriminated[InitTlv].by(varint)
.typecase(UInt64(1), networks)
.typecase(UInt64(3), remoteAddress)
)

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,20 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
connect(nodeParams, remoteNodeId, switchboard, router, connection, transport, peerConnection, peer)
}

test("send incoming connection's remote address in init") { f =>
import f._
val probe = TestProbe()
val incomingConnection = PeerConnection.PendingAuth(connection.ref, None, fakeIPAddress.socketAddress, origin_opt = None, transport_opt = Some(transport.ref), isPersistent = true)
assert(!incomingConnection.outgoing)
probe.send(peerConnection, incomingConnection)
transport.send(peerConnection, TransportHandler.HandshakeCompleted(remoteNodeId))
switchboard.expectMsg(PeerConnection.Authenticated(peerConnection, remoteNodeId))
probe.send(peerConnection, PeerConnection.InitializeConnection(peer.ref, nodeParams.chainHash, nodeParams.features, doSync = false))
transport.expectMsgType[TransportHandler.Listener]
val localInit = transport.expectMsgType[protocol.Init]
assert(localInit.remoteAddress_opt === Some(fakeIPAddress))
}

test("handle connection closed during authentication") { f =>
import f._
val probe = TestProbe()
Expand Down Expand Up @@ -459,5 +473,23 @@ class PeerConnectionSpec extends TestKitBaseClass with FixtureAnyFunSuiteLike wi
peer.send(peerConnection, message)
transport.expectMsg(message)
}

test("filter private IP addresses") { _ =>
val testCases = Seq(
NodeAddress.fromParts("127.0.0.1", 9735).get -> false,
NodeAddress.fromParts("0.0.0.0", 9735).get -> false,
NodeAddress.fromParts("192.168.0.1", 9735).get -> false,
NodeAddress.fromParts("140.82.121.3", 9735).get -> true,
NodeAddress.fromParts("0000:0000:0000:0000:0000:0000:0000:0001", 9735).get -> false,
NodeAddress.fromParts("b643:8bb1:c1f9:0556:487c:0acb:2ba3:3cc2", 9735).get -> true,
NodeAddress.fromParts("hsmithsxurybd7uh.onion", 9735).get -> false,
NodeAddress.fromParts("iq7zhmhck54vcax2vlrdcavq2m32wao7ekh6jyeglmnuuvv3js57r4id.onion", 9735).get -> false,
pm47 marked this conversation as resolved.
Show resolved Hide resolved
)
for ((address, expected) <- testCases) {
val isPublicIP = NodeAddress.isPublicIPAddress(address)
assert(isPublicIP === expected)
}
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import org.scalatest.funsuite.AnyFunSuite
import scodec.DecodeResult
import scodec.bits.{BinStringSyntax, ByteVector, HexStringSyntax}

import java.net.{Inet4Address, InetAddress}
import java.net.{Inet4Address, Inet6Address, InetAddress}

/**
* Created by PM on 31/05/2016.
Expand All @@ -49,31 +49,36 @@ class LightningMessageCodecsSpec extends AnyFunSuite {
def publicKey(fill: Byte) = PrivateKey(ByteVector.fill(32)(fill)).publicKey

test("encode/decode init message") {
case class TestCase(encoded: ByteVector, rawFeatures: ByteVector, networks: List[ByteVector32], valid: Boolean, reEncoded: Option[ByteVector] = None)
case class TestCase(encoded: ByteVector, rawFeatures: ByteVector, networks: List[ByteVector32], address: Option[IPAddress], valid: Boolean, reEncoded: Option[ByteVector] = None)
val chainHash1 = ByteVector32(hex"0101010101010101010101010101010101010101010101010101010101010101")
val chainHash2 = ByteVector32(hex"0202020202020202020202020202020202020202020202020202020202020202")
val remoteAddress1 = IPv4(InetAddress.getByAddress(Array[Byte](140.toByte, 82.toByte, 121.toByte, 3.toByte)).asInstanceOf[Inet4Address], 9735)
val remoteAddress2 = IPv6(InetAddress.getByAddress(hex"b643 8bb1 c1f9 0556 487c 0acb 2ba3 3cc2".toArray).asInstanceOf[Inet6Address], 9736)
val testCases = Seq(
TestCase(hex"0000 0000", hex"", Nil, valid = true), // no features
TestCase(hex"0000 0002088a", hex"088a", Nil, valid = true), // no global features
TestCase(hex"00020200 0000", hex"0200", Nil, valid = true, Some(hex"0000 00020200")), // no local features
TestCase(hex"00020200 0002088a", hex"0a8a", Nil, valid = true, Some(hex"0000 00020a8a")), // local and global - no conflict - same size
TestCase(hex"00020200 0003020002", hex"020202", Nil, valid = true, Some(hex"0000 0003020202")), // local and global - no conflict - different sizes
TestCase(hex"00020a02 0002088a", hex"0a8a", Nil, valid = true, Some(hex"0000 00020a8a")), // local and global - conflict - same size
TestCase(hex"00022200 000302aaa2", hex"02aaa2", Nil, valid = true, Some(hex"0000 000302aaa2")), // local and global - conflict - different sizes
TestCase(hex"0000 0002088a 03012a05022aa2", hex"088a", Nil, valid = true), // unknown odd records
TestCase(hex"0000 0002088a 03012a04022aa2", hex"088a", Nil, valid = false), // unknown even records
TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101", hex"088a", Nil, valid = false), // invalid tlv stream
TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101", hex"088a", List(chainHash1), valid = true), // single network
TestCase(hex"0000 0002088a 014001010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202", hex"088a", List(chainHash1, chainHash2), valid = true), // multiple networks
TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010103012a", hex"088a", List(chainHash1), valid = true), // network and unknown odd records
TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101010102012a", hex"088a", Nil, valid = false) // network and unknown even records
TestCase(hex"0000 0000", hex"", Nil, None, valid = true), // no features
TestCase(hex"0000 0002088a", hex"088a", Nil, None, valid = true), // no global features
TestCase(hex"00020200 0000", hex"0200", Nil, None, valid = true, Some(hex"0000 00020200")), // no local features
TestCase(hex"00020200 0002088a", hex"0a8a", Nil, None, valid = true, Some(hex"0000 00020a8a")), // local and global - no conflict - same size
TestCase(hex"00020200 0003020002", hex"020202", Nil, None, valid = true, Some(hex"0000 0003020202")), // local and global - no conflict - different sizes
TestCase(hex"00020a02 0002088a", hex"0a8a", Nil, None, valid = true, Some(hex"0000 00020a8a")), // local and global - conflict - same size
TestCase(hex"00022200 000302aaa2", hex"02aaa2", Nil, None, valid = true, Some(hex"0000 000302aaa2")), // local and global - conflict - different sizes
TestCase(hex"0000 0002088a 03012a05022aa2", hex"088a", Nil, None, valid = true), // unknown odd records
TestCase(hex"0000 0002088a 03012a04022aa2", hex"088a", Nil, None, valid = false), // unknown even records
TestCase(hex"0000 0002088a 0120010101010101010101010101010101010101010101010101010101010101", hex"088a", Nil, None, valid = false), // invalid tlv stream
TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101", hex"088a", List(chainHash1), None, valid = true), // single network
TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101 0307018c5279032607", hex"088a", List(chainHash1), Some(remoteAddress1), valid = true), // single network and IPv4 address
TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101 031302b6438bb1c1f90556487c0acb2ba33cc22608", hex"088a", List(chainHash1), Some(remoteAddress2), valid = true), // single network and IPv6 address
TestCase(hex"0000 0002088a 014001010101010101010101010101010101010101010101010101010101010101010202020202020202020202020202020202020202020202020202020202020202", hex"088a", List(chainHash1, chainHash2), None, valid = true), // multiple networks
TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101 c9012a", hex"088a", List(chainHash1), None, valid = true), // network and unknown odd records
TestCase(hex"0000 0002088a 01200101010101010101010101010101010101010101010101010101010101010101 02012a", hex"088a", Nil, None, valid = false) // network and unknown even records
)

for (testCase <- testCases) {
if (testCase.valid) {
val init = initCodec.decode(testCase.encoded.bits).require.value
assert(init.features.toByteVector === testCase.rawFeatures)
assert(init.networks === testCase.networks)
assert(init.remoteAddress_opt === testCase.address)
val encoded = initCodec.encode(init).require
assert(encoded.bytes === testCase.reEncoded.getOrElse(testCase.encoded))
assert(initCodec.decode(encoded).require.value === init)
Expand Down