Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make htlc_maximum_msat mandatory in channel updates #2361

Merged
merged 2 commits into from
Aug 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ class PgAuditDb(implicit ds: DataSource) extends AuditDb with Logging {
statement.setLong(4, u.channelUpdate.feeProportionalMillionths)
statement.setLong(5, u.channelUpdate.cltvExpiryDelta.toInt)
statement.setLong(6, u.channelUpdate.htlcMinimumMsat.toLong)
statement.setLong(7, u.channelUpdate.htlcMaximumMsat.map(_.toLong).getOrElse(-1))
statement.setLong(7, u.channelUpdate.htlcMaximumMsat.toLong)
statement.setTimestamp(8, Timestamp.from(Instant.now()))
statement.executeUpdate()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class SqliteAuditDb(val sqlite: Connection) extends AuditDb with Logging {
statement.setLong(4, u.channelUpdate.feeProportionalMillionths)
statement.setLong(5, u.channelUpdate.cltvExpiryDelta.toInt)
statement.setLong(6, u.channelUpdate.htlcMinimumMsat.toLong)
statement.setLong(7, u.channelUpdate.htlcMaximumMsat.map(_.toLong).getOrElse(-1))
statement.setLong(7, u.channelUpdate.htlcMaximumMsat.toLong)
statement.setLong(8, TimestampMilli.now().toLong)
statement.executeUpdate()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object Announcements {
def nodeAnnouncementWitnessEncode(timestamp: TimestampSecond, nodeId: PublicKey, rgbColor: Color, alias: String, features: Features[Feature], addresses: List[NodeAddress], tlvStream: TlvStream[NodeAnnouncementTlv]): ByteVector =
sha256(sha256(serializationResult(LightningMessageCodecs.nodeAnnouncementWitnessCodec.encode(features :: timestamp :: nodeId :: rgbColor :: alias :: addresses :: tlvStream :: HNil))))

def channelUpdateWitnessEncode(chainHash: ByteVector32, shortChannelId: ShortChannelId, timestamp: TimestampSecond, channelFlags: ChannelUpdate.ChannelFlags, cltvExpiryDelta: CltvExpiryDelta, htlcMinimumMsat: MilliSatoshi, feeBaseMsat: MilliSatoshi, feeProportionalMillionths: Long, htlcMaximumMsat: Option[MilliSatoshi], tlvStream: TlvStream[ChannelUpdateTlv]): ByteVector =
def channelUpdateWitnessEncode(chainHash: ByteVector32, shortChannelId: ShortChannelId, timestamp: TimestampSecond, channelFlags: ChannelUpdate.ChannelFlags, cltvExpiryDelta: CltvExpiryDelta, htlcMinimumMsat: MilliSatoshi, feeBaseMsat: MilliSatoshi, feeProportionalMillionths: Long, htlcMaximumMsat: MilliSatoshi, tlvStream: TlvStream[ChannelUpdateTlv]): ByteVector =
sha256(sha256(serializationResult(LightningMessageCodecs.channelUpdateWitnessCodec.encode(chainHash :: shortChannelId :: timestamp :: channelFlags :: cltvExpiryDelta :: htlcMinimumMsat :: feeBaseMsat :: feeProportionalMillionths :: htlcMaximumMsat :: tlvStream :: HNil))))

def generateChannelAnnouncementWitness(chainHash: ByteVector32, shortChannelId: RealShortChannelId, localNodeId: PublicKey, remoteNodeId: PublicKey, localFundingKey: PublicKey, remoteFundingKey: PublicKey, features: Features[Feature]): ByteVector =
Expand Down Expand Up @@ -116,8 +116,7 @@ object Announcements {

def makeChannelUpdate(chainHash: ByteVector32, nodeSecret: PrivateKey, remoteNodeId: PublicKey, shortChannelId: ShortChannelId, cltvExpiryDelta: CltvExpiryDelta, htlcMinimumMsat: MilliSatoshi, feeBaseMsat: MilliSatoshi, feeProportionalMillionths: Long, htlcMaximumMsat: MilliSatoshi, enable: Boolean = true, timestamp: TimestampSecond = TimestampSecond.now()): ChannelUpdate = {
val channelFlags = ChannelUpdate.ChannelFlags(isNode1 = isNode1(nodeSecret.publicKey, remoteNodeId), isEnabled = enable)
val htlcMaximumMsatOpt = Some(htlcMaximumMsat)
val witness = channelUpdateWitnessEncode(chainHash, shortChannelId, timestamp, channelFlags, cltvExpiryDelta, htlcMinimumMsat, feeBaseMsat, feeProportionalMillionths, htlcMaximumMsatOpt, TlvStream.empty)
val witness = channelUpdateWitnessEncode(chainHash, shortChannelId, timestamp, channelFlags, cltvExpiryDelta, htlcMinimumMsat, feeBaseMsat, feeProportionalMillionths, htlcMaximumMsat, TlvStream.empty)
val sig = Crypto.sign(witness, nodeSecret)
ChannelUpdate(
signature = sig,
Expand All @@ -129,7 +128,7 @@ object Announcements {
htlcMinimumMsat = htlcMinimumMsat,
feeBaseMsat = feeBaseMsat,
feeProportionalMillionths = feeProportionalMillionths,
htlcMaximumMsat = htlcMaximumMsatOpt
htlcMaximumMsat = htlcMaximumMsat
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ object Router {
override def cltvExpiryDelta: CltvExpiryDelta = channelUpdate.cltvExpiryDelta
override def relayFees: Relayer.RelayFees = channelUpdate.relayFees
override def htlcMinimum: MilliSatoshi = channelUpdate.htlcMinimumMsat
override def htlcMaximum_opt: Option[MilliSatoshi] = channelUpdate.htlcMaximumMsat
override def htlcMaximum_opt: Option[MilliSatoshi] = Some(channelUpdate.htlcMaximumMsat)
}
/** We learnt about this channel from hints in an invoice */
case class FromHint(extraHop: Invoice.ExtraEdge) extends ChannelRelayParams {
Expand Down
3 changes: 1 addition & 2 deletions eclair-core/src/main/scala/fr/acinq/eclair/router/Sync.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@ import akka.actor.{ActorContext, ActorRef}
import akka.event.LoggingAdapter
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.RealShortChannelId
import fr.acinq.eclair.crypto.TransportHandler
import fr.acinq.eclair.router.Monitoring.{Metrics, Tags}
import fr.acinq.eclair.router.Router._
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{BlockHeight, ShortChannelId, TimestampSecond, TimestampSecondLong, serializationResult}
import fr.acinq.eclair.{BlockHeight, RealShortChannelId, TimestampSecond, TimestampSecondLong, serializationResult}
import scodec.bits.ByteVector
import shapeless.HNil

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ import fr.acinq.bitcoin.scalacompat.ScriptWitness
import fr.acinq.eclair.wire.Monitoring.{Metrics, Tags}
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.{Feature, Features, InitFeature, KamonExt}
import scodec.bits.{BitVector, ByteVector}
import scodec.bits.{BitVector, ByteVector, HexStringSyntax}
import scodec.codecs._
import scodec.{Attempt, Codec}
import shapeless._

/**
* Created by PM on 15/11/2016.
Expand Down Expand Up @@ -321,10 +320,6 @@ object LightningMessageCodecs {
("signature" | bytes64) ::
nodeAnnouncementWitnessCodec).as[NodeAnnouncement]

private case class MessageFlags(optionChannelHtlcMax: Boolean)
pm47 marked this conversation as resolved.
Show resolved Hide resolved

private val messageFlagsCodec = ("messageFlags" | (ignore(7) :: bool)).as[MessageFlags]

val reverseBool: Codec[Boolean] = bool.xmap[Boolean](b => !b, b => !b)

/** BOLT 7 defines a 'disable' bit and a 'direction' bit, but it's easier to understand if we take the reverse. */
Expand All @@ -333,36 +328,26 @@ object LightningMessageCodecs {
val channelUpdateChecksumCodec =
("chainHash" | bytes32) ::
("shortChannelId" | shortchannelid) ::
(messageFlagsCodec >>:~ { messageFlags =>
channelFlagsCodec ::
("cltvExpiryDelta" | cltvExpiryDelta) ::
("htlcMinimumMsat" | millisatoshi) ::
("feeBaseMsat" | millisatoshi32) ::
("feeProportionalMillionths" | uint32) ::
("htlcMaximumMsat" | conditional(messageFlags.optionChannelHtlcMax, millisatoshi))
}).derive[MessageFlags].from {
// The purpose of this is to tell scodec how to derive the message flags from the data, so we can remove that field
// from the codec definition and the case class, making it purely a serialization detail.
// see: https://github.com/scodec/scodec/blob/series/1.11.x/unitTests/src/test/scala/scodec/examples/ProductsExample.scala#L108-L127
case _ :: _ :: _ :: _ :: _ :: htlcMaximumMsat_opt :: HNil => MessageFlags(optionChannelHtlcMax = htlcMaximumMsat_opt.isDefined)
}
("messageFlags" | constant(hex"01")) :~>:
channelFlagsCodec ::
("cltvExpiryDelta" | cltvExpiryDelta) ::
("htlcMinimumMsat" | millisatoshi) ::
("feeBaseMsat" | millisatoshi32) ::
("feeProportionalMillionths" | uint32) ::
("htlcMaximumMsat" | millisatoshi)

val channelUpdateWitnessCodec =
(("chainHash" | bytes32) ::
("chainHash" | bytes32) ::
("shortChannelId" | shortchannelid) ::
("timestamp" | timestampSecond) ::
(messageFlagsCodec >>:~ { messageFlags =>
channelFlagsCodec ::
("cltvExpiryDelta" | cltvExpiryDelta) ::
("htlcMinimumMsat" | millisatoshi) ::
("feeBaseMsat" | millisatoshi32) ::
("feeProportionalMillionths" | uint32) ::
("htlcMaximumMsat" | conditional(messageFlags.optionChannelHtlcMax, millisatoshi)) ::
("tlvStream" | ChannelUpdateTlv.channelUpdateTlvCodec)
})).derive[MessageFlags].from {
// same comment above
case _ :: _ :: _ :: _ :: _ :: _ :: _ :: _ :: htlcMaximumMsat_opt :: _ :: HNil => MessageFlags(optionChannelHtlcMax = htlcMaximumMsat_opt.isDefined)
}
("messageFlags" | constant(hex"01")) :~>:
channelFlagsCodec ::
("cltvExpiryDelta" | cltvExpiryDelta) ::
("htlcMinimumMsat" | millisatoshi) ::
("feeBaseMsat" | millisatoshi32) ::
("feeProportionalMillionths" | uint32) ::
("htlcMaximumMsat" | millisatoshi) ::
("tlvStream" | ChannelUpdateTlv.channelUpdateTlvCodec)

val channelUpdateCodec: Codec[ChannelUpdate] = (
("signature" | bytes64) ::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ import com.google.common.base.Charsets
import com.google.common.net.InetAddresses
import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.bitcoin.scalacompat.{ByteVector32, ByteVector64, Satoshi, ScriptWitness, Transaction}
import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, Feature, Features, InitFeature, MilliSatoshi, RealShortChannelId, ShortChannelId, TimestampSecond, UInt64}
import fr.acinq.eclair.blockchain.fee.FeeratePerKw
import fr.acinq.eclair.channel.{ChannelFlags, ChannelType}
import fr.acinq.eclair.payment.relay.Relayer
import fr.acinq.eclair.wire.protocol.ChannelReadyTlv.ShortChannelIdTlv
import fr.acinq.eclair.{Alias, BlockHeight, CltvExpiry, CltvExpiryDelta, Feature, Features, InitFeature, MilliSatoshi, RealShortChannelId, ShortChannelId, TimestampSecond, UInt64}
import scodec.bits.ByteVector

import java.net.{Inet4Address, Inet6Address, InetAddress}
Expand Down Expand Up @@ -227,8 +227,8 @@ case class FundingSigned(channelId: ByteVector32,
tlvStream: TlvStream[FundingSignedTlv] = TlvStream.empty) extends ChannelMessage with HasChannelId

case class ChannelReady(channelId: ByteVector32,
nextPerCommitmentPoint: PublicKey,
tlvStream: TlvStream[ChannelReadyTlv] = TlvStream.empty) extends ChannelMessage with HasChannelId {
nextPerCommitmentPoint: PublicKey,
tlvStream: TlvStream[ChannelReadyTlv] = TlvStream.empty) extends ChannelMessage with HasChannelId {
val alias_opt: Option[Alias] = tlvStream.get[ShortChannelIdTlv].map(_.alias)
}

Expand Down Expand Up @@ -260,7 +260,7 @@ object UpdateAddHtlc {
paymentHash: ByteVector32,
cltvExpiry: CltvExpiry,
onionRoutingPacket: OnionRoutingPacket,
blinding_opt: Option[PublicKey]):UpdateAddHtlc = {
blinding_opt: Option[PublicKey]): UpdateAddHtlc = {
val tlvs = Seq(blinding_opt.map(UpdateAddHtlcTlv.BlindingPoint)).flatten
UpdateAddHtlc(channelId, id, amountMsat, paymentHash, cltvExpiry, onionRoutingPacket, TlvStream[UpdateAddHtlcTlv](tlvs))
}
Expand Down Expand Up @@ -385,10 +385,9 @@ case class ChannelUpdate(signature: ByteVector64,
htlcMinimumMsat: MilliSatoshi,
feeBaseMsat: MilliSatoshi,
feeProportionalMillionths: Long,
htlcMaximumMsat: Option[MilliSatoshi],
htlcMaximumMsat: MilliSatoshi,
tlvStream: TlvStream[ChannelUpdateTlv] = TlvStream.empty) extends RoutingMessage with AnnouncementMessage with HasTimestamp with HasChainHash {

def messageFlags: Byte = if (htlcMaximumMsat.isDefined) 1 else 0
def messageFlags: Byte = 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this to construct ChannelDisabled error messages, which should really only use channelFlags.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it looks a bit weird here if it's just a constant, but if we reintroduce more message flags (such as in #2362) it makes a bit more sense. Let's see if #2362 is accepted at the spec level and we'll see.


def toStringShort: String = s"cltvExpiryDelta=$cltvExpiryDelta,feeBase=$feeBaseMsat,feeProportionalMillionths=$feeProportionalMillionths"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ class PaymentsDbSpec extends AnyFunSuite {
object PaymentsDbSpec {
val (alicePriv, bobPriv, carolPriv, davePriv) = (randomKey(), randomKey(), randomKey(), randomKey())
val (alice, bob, carol, dave) = (alicePriv.publicKey, bobPriv.publicKey, carolPriv.publicKey, davePriv.publicKey)
val hop_ab = ChannelHop(ShortChannelId(42), alice, bob, ChannelRelayParams.FromAnnouncement(ChannelUpdate(randomBytes64(), randomBytes32(), ShortChannelId(42), 1 unixsec, ChannelUpdate.ChannelFlags.DUMMY, CltvExpiryDelta(12), 1 msat, 1 msat, 1, None)))
val hop_ab = ChannelHop(ShortChannelId(42), alice, bob, ChannelRelayParams.FromAnnouncement(ChannelUpdate(randomBytes64(), randomBytes32(), ShortChannelId(42), 1 unixsec, ChannelUpdate.ChannelFlags.DUMMY, CltvExpiryDelta(12), 1 msat, 1 msat, 1, 500_000_000 msat)))
val hop_bc = NodeHop(bob, carol, CltvExpiryDelta(14), 1 msat)
val (preimage1, preimage2, preimage3, preimage4) = (randomBytes32(), randomBytes32(), randomBytes32(), randomBytes32())
val (paymentHash1, paymentHash2, paymentHash3, paymentHash4) = (Crypto.sha256(preimage1), Crypto.sha256(preimage2), Crypto.sha256(preimage3), Crypto.sha256(preimage4))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
router.expectMsgType[RouteRequest]
router.send(payFsm, RouteResponse(Seq(Route(500000 msat, hop_ab_1 :: hop_be :: Nil), Route(501000 msat, hop_ac_1 :: hop_ce :: Nil))))
val childPayments = childPayFsm.expectMsgType[SendPaymentToRoute] :: childPayFsm.expectMsgType[SendPaymentToRoute] :: Nil
childPayments.map(_.finalPayload.asInstanceOf[PaymentOnion.FinalTlvPayload]).foreach(p => {
assert(p.records.get[OnionPaymentPayloadTlv.TrampolineOnion] == Some(trampolineTlv))
childPayments.map(_.finalPayload).foreach(p => {
assert(p.records.get[OnionPaymentPayloadTlv.TrampolineOnion].contains(trampolineTlv))
assert(p.records.unknown.toSeq == Seq(userCustomTlv))
})

Expand Down Expand Up @@ -352,7 +352,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
// B doesn't have enough liquidity on this channel.
// NB: we need a channel update with a valid signature, otherwise we'll ignore the node instead of this specific channel.
val channelUpdateBE = hop_be.params.asInstanceOf[ChannelRelayParams.FromAnnouncement].channelUpdate
val channelUpdateBE1 = Announcements.makeChannelUpdate(channelUpdateBE.chainHash, priv_b, e, channelUpdateBE.shortChannelId, channelUpdateBE.cltvExpiryDelta, channelUpdateBE.htlcMinimumMsat, channelUpdateBE.feeBaseMsat, channelUpdateBE.feeProportionalMillionths, channelUpdateBE.htlcMaximumMsat.get)
val channelUpdateBE1 = Announcements.makeChannelUpdate(channelUpdateBE.chainHash, priv_b, e, channelUpdateBE.shortChannelId, channelUpdateBE.cltvExpiryDelta, channelUpdateBE.htlcMinimumMsat, channelUpdateBE.feeBaseMsat, channelUpdateBE.feeProportionalMillionths, channelUpdateBE.htlcMaximumMsat)
val childId = payFsm.stateData.asInstanceOf[PaymentProgress].pending.keys.head
childPayFsm.send(payFsm, PaymentFailed(childId, paymentHash, Seq(RemoteFailure(route.amount, route.hops, Sphinx.DecryptedFailurePacket(b, TemporaryChannelFailure(channelUpdateBE1))))))
// We update the routing hints accordingly before requesting a new route and ignore the channel.
Expand Down Expand Up @@ -381,7 +381,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
BasicEdge(a, b, ShortChannelId(1), 10 msat, 0, CltvExpiryDelta(12)), BasicEdge(b, c, ShortChannelId(2), 15 msat, 150, CltvExpiryDelta(48)),
BasicEdge(a, c, ShortChannelId(3), 1 msat, 10, CltvExpiryDelta(144))
)
assert(extraEdges1.zip(PaymentFailure.updateExtraEdges(failures, extraEdges)).forall{case (e1, e2) => e1 == e2})
assert(extraEdges1.zip(PaymentFailure.updateExtraEdges(failures, extraEdges)).forall { case (e1, e2) => e1 == e2 })
}
{
val failures = Seq(
Expand All @@ -394,7 +394,7 @@ class MultiPartPaymentLifecycleSpec extends TestKitBaseClass with FixtureAnyFunS
BasicEdge(a, b, ShortChannelId(1), 23 msat, 23, CltvExpiryDelta(23)), BasicEdge(b, c, ShortChannelId(2), 21 msat, 21, CltvExpiryDelta(21)),
BasicEdge(a, c, ShortChannelId(3), 22 msat, 22, CltvExpiryDelta(22))
)
assert(extraEdges1.zip(PaymentFailure.updateExtraEdges(failures, extraEdges)).forall{case (e1, e2) => e1 == e2})
assert(extraEdges1.zip(PaymentFailure.updateExtraEdges(failures, extraEdges)).forall { case (e1, e2) => e1 == e2 })
}
}

Expand Down Expand Up @@ -691,7 +691,7 @@ object MultiPartPaymentLifecycleSpec {
val channelId_ce = ShortChannelId(13)
val channelId_ad = ShortChannelId(21)
val channelId_de = ShortChannelId(22)
val defaultChannelUpdate = ChannelUpdate(randomBytes64(), Block.RegtestGenesisBlock.hash, ShortChannelId(0), 0 unixsec, ChannelUpdate.ChannelFlags.DUMMY, CltvExpiryDelta(12), 1 msat, 100 msat, 0, Some(2000000 msat))
val defaultChannelUpdate = ChannelUpdate(randomBytes64(), Block.RegtestGenesisBlock.hash, ShortChannelId(0), 0 unixsec, ChannelUpdate.ChannelFlags.DUMMY, CltvExpiryDelta(12), 1 msat, 100 msat, 0, 2_000_000 msat)
val channelUpdate_ab_1 = defaultChannelUpdate.copy(shortChannelId = channelId_ab_1)
val channelUpdate_ab_2 = defaultChannelUpdate.copy(shortChannelId = channelId_ab_2)
val channelUpdate_be = defaultChannelUpdate.copy(shortChannelId = channelId_be)
Expand Down
Loading