Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix MPP post-restart HTLC clean-up #1224

Merged
merged 2 commits into from
Nov 28, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions eclair-core/src/main/scala/fr/acinq/eclair/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ import fr.acinq.eclair.io.{Authenticator, Server, Switchboard}
import fr.acinq.eclair.payment.receive.PaymentHandler
import fr.acinq.eclair.payment.send.{Autoprobe, PaymentInitiator}
import fr.acinq.eclair.payment.Auditor
import fr.acinq.eclair.payment.relay.Relayer
import fr.acinq.eclair.payment.relay.{CommandBuffer, Relayer}
import fr.acinq.eclair.router._
import fr.acinq.eclair.tor.TorProtocolHandler.OnionServiceVersion
import fr.acinq.eclair.tor.{Controller, TorProtocolHandler}
Expand Down Expand Up @@ -280,9 +280,10 @@ class Setup(datadir: File,
if (config.hasPath("backup-notify-script")) Some(config.getString("backup-notify-script")) else None
), "backuphandler", SupervisorStrategy.Resume))
audit = system.actorOf(SimpleSupervisor.props(Auditor.props(nodeParams), "auditor", SupervisorStrategy.Resume))
paymentHandler = system.actorOf(SimpleSupervisor.props(PaymentHandler.props(nodeParams), "payment-handler", SupervisorStrategy.Resume))
register = system.actorOf(SimpleSupervisor.props(Props(new Register), "register", SupervisorStrategy.Resume))
relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, register, paymentHandler), "relayer", SupervisorStrategy.Resume))
commandBuffer = system.actorOf(SimpleSupervisor.props(Props(new CommandBuffer(nodeParams, register)), "command-buffer", SupervisorStrategy.Resume))
paymentHandler = system.actorOf(SimpleSupervisor.props(PaymentHandler.props(nodeParams, commandBuffer), "payment-handler", SupervisorStrategy.Resume))
relayer = system.actorOf(SimpleSupervisor.props(Relayer.props(nodeParams, register, commandBuffer, paymentHandler), "relayer", SupervisorStrategy.Resume))
authenticator = system.actorOf(SimpleSupervisor.props(Authenticator.props(nodeParams), "authenticator", SupervisorStrategy.Resume))
switchboard = system.actorOf(SimpleSupervisor.props(Switchboard.props(nodeParams, authenticator, watcher, router, relayer, paymentHandler, wallet), "switchboard", SupervisorStrategy.Resume))
server = system.actorOf(SimpleSupervisor.props(Server.props(nodeParams, authenticator, serverBindingAddress, Some(tcpBound)), "server", SupervisorStrategy.Restart))
Expand All @@ -295,6 +296,7 @@ class Setup(datadir: File,
watcher = watcher,
paymentHandler = paymentHandler,
register = register,
commandBuffer = commandBuffer,
relayer = relayer,
router = router,
switchboard = switchboard,
Expand Down Expand Up @@ -359,6 +361,7 @@ case class Kit(nodeParams: NodeParams,
watcher: ActorRef,
paymentHandler: ActorRef,
register: ActorRef,
commandBuffer: ActorRef,
relayer: ActorRef,
router: ActorRef,
switchboard: ActorRef,
Expand Down
13 changes: 1 addition & 12 deletions eclair-core/src/main/scala/fr/acinq/eclair/channel/Channel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import fr.acinq.bitcoin.{ByteVector32, OutPoint, Satoshi, Script, ScriptFlags, T
import fr.acinq.eclair._
import fr.acinq.eclair.blockchain._
import fr.acinq.eclair.channel.Helpers.{Closing, Funding}
import fr.acinq.eclair.crypto.{ShaChain, Sphinx}
import fr.acinq.eclair.crypto.ShaChain
import fr.acinq.eclair.io.Peer
import fr.acinq.eclair.payment._
import fr.acinq.eclair.payment.relay.{CommandBuffer, Origin, Relayer}
Expand Down Expand Up @@ -2201,17 +2201,6 @@ class Channel(val nodeParams: NodeParams, val wallet: EclairWallet, remoteNodeId
case _ => ()
}

// let's now fail all pending htlc for which we are the final payee
val htlcsToFail = commitments1.remoteCommit.spec.htlcs.collect {
case DirectedHtlc(OUT, add) if Sphinx.PaymentPacket.peel(nodeParams.privateKey, add.paymentHash, add.onionRoutingPacket).fold(
_ => true, // we also fail htlcs which onion we can't decode (message won't be precise)
p => p.isLastPacket
) => add
}

log.debug(s"failing htlcs=${htlcsToFail.map(Commitments.msg2String(_)).mkString(",")}")
htlcsToFail.foreach(add => self ! CMD_FAIL_HTLC(add.id, Right(TemporaryNodeFailure), commit = true))

// have I something to sign?
if (Commitments.localHasChanges(commitments1)) {
self ! CMD_SIGN
Expand Down
54 changes: 34 additions & 20 deletions eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import fr.acinq.eclair.NodeParams
import fr.acinq.eclair.blockchain.EclairWallet
import fr.acinq.eclair.channel.Helpers.Closing
import fr.acinq.eclair.channel._
import fr.acinq.eclair.db.PendingRelayDb
import fr.acinq.eclair.payment.relay.{CommandBuffer, Origin}
import fr.acinq.eclair.db.{IncomingPayment, IncomingPaymentStatus, IncomingPaymentsDb, PendingRelayDb}
import fr.acinq.eclair.payment.IncomingPacket
import fr.acinq.eclair.payment.relay.Origin
import fr.acinq.eclair.router.Rebroadcast
import fr.acinq.eclair.transactions.{IN, OUT}
import fr.acinq.eclair.wire.{TemporaryNodeFailure, UpdateAddHtlc}
Expand Down Expand Up @@ -59,7 +59,7 @@ class Switchboard(nodeParams: NodeParams, authenticator: ActorRef, watcher: Acto
})
val peers = nodeParams.db.peers.listPeers()

checkBrokenHtlcsLink(channels, nodeParams.privateKey, nodeParams.globalFeatures) match {
checkBrokenHtlcsLink(channels, nodeParams.db.payments, nodeParams.privateKey, nodeParams.globalFeatures) match {
case Nil => ()
case brokenHtlcs =>
val brokenHtlcKiller = context.system.actorOf(Props[HtlcReaper], name = "htlc-reaper")
Expand Down Expand Up @@ -152,15 +152,16 @@ object Switchboard {
def peerActorName(remoteNodeId: PublicKey): String = s"peer-$remoteNodeId"

/**
* If we have stopped eclair while it was forwarding HTLCs, it is possible that we are in a state where an incoming HTLC
* was committed by both sides, but we didn't have time to send and/or sign the corresponding HTLC to the downstream node.
* If we have stopped eclair while it was handling HTLCs, it is possible that we are in a state where an incoming HTLC
* was committed by both sides, but we didn't have time to send and/or sign the corresponding HTLC to the downstream
* node (if we're an intermediate node) or didn't have time to fail/fulfill the payment (if we're the recipient).
*
* In that case, if we do nothing, the incoming HTLC will eventually expire and we won't lose money, but the channel will
* get closed, which is a major inconvenience.
* In that case, if we do nothing, the incoming HTLC will eventually expire and we won't lose money, but the channel
* will get closed, which is a major inconvenience.
*
* This check will detect this and will allow us to fast-fail HTLCs and thus preserve channels.
* This check will detect this and will allow us to fast-settle HTLCs and thus preserve channels.
*/
def checkBrokenHtlcsLink(channels: Seq[HasCommitments], privateKey: PrivateKey, features: ByteVector)(implicit log: LoggingAdapter): Seq[UpdateAddHtlc] = {
def checkBrokenHtlcsLink(channels: Seq[HasCommitments], paymentsDb: IncomingPaymentsDb, privateKey: PrivateKey, features: ByteVector)(implicit log: LoggingAdapter): Seq[(UpdateAddHtlc, Option[ByteVector32])] = {
// We are interested in incoming HTLCs, that have been *cross-signed* (otherwise they wouldn't have been relayed).
// They signed it first, so the HTLC will first appear in our commitment tx, and later on in their commitment when
// we subsequently sign it. That's why we need to look in *their* commitment with direction=OUT.
Expand All @@ -169,7 +170,13 @@ object Switchboard {
.filter(_.direction == OUT)
.map(_.add)
.map(IncomingPacket.decrypt(_, privateKey, features))
.collect { case Right(IncomingPacket.ChannelRelayPacket(add, _, _)) => add } // we only consider htlcs that are relayed, not the ones for which we are the final node
.collect {
case Right(IncomingPacket.ChannelRelayPacket(add, _, _)) => (add, None) // we consider all relayed htlcs
case Right(IncomingPacket.FinalPacket(add, _)) => paymentsDb.getIncomingPayment(add.paymentHash) match {
case Some(IncomingPayment(_, preimage, _, IncomingPaymentStatus.Received(_, _))) => (add, Some(preimage)) // incoming payment that succeeded
case _ => (add, None) // incoming payment that didn't succeed
}
}

// TODO: @t-bast: will need to update this to take into account trampoline-relayed (and thoroughly test).

Expand All @@ -179,7 +186,9 @@ object Switchboard {
.collect { case r: Origin.Relayed => r }
.toSet

val htlcs_broken = htlcs_in.filterNot(htlc_in => relayed_out.exists(r => r.originChannelId == htlc_in.channelId && r.originHtlcId == htlc_in.id))
val htlcs_broken = htlcs_in.filterNot {
case (htlc_in, _) => relayed_out.exists(r => r.originChannelId == htlc_in.channelId && r.originHtlcId == htlc_in.id)
}

log.info(s"htlcs_in=${htlcs_in.size} htlcs_out=${relayed_out.size} htlcs_broken=${htlcs_broken.size}")

Expand Down Expand Up @@ -229,25 +238,30 @@ class HtlcReaper extends Actor with ActorLogging {
context.system.eventStream.subscribe(self, classOf[ChannelStateChanged])

override def receive: Receive = {
case initialHtlcs: Seq[UpdateAddHtlc]@unchecked => context become main(initialHtlcs)
case initialHtlcs: Seq[(UpdateAddHtlc, Option[ByteVector32])]@unchecked => context become main(initialHtlcs)
}

def main(htlcs: Seq[UpdateAddHtlc]): Receive = {
def main(htlcs: Seq[(UpdateAddHtlc, Option[ByteVector32])]): Receive = {
case ChannelStateChanged(channel, _, _, WAIT_FOR_INIT_INTERNAL | OFFLINE | SYNCING, NORMAL | SHUTDOWN | CLOSING, data: HasCommitments) =>
val acked = htlcs
.filter(_.channelId == data.channelId) // only consider htlcs related to this channel
.filter(_._1.channelId == data.channelId) // only consider htlcs related to this channel
.filter {
case htlc if Commitments.getHtlcCrossSigned(data.commitments, IN, htlc.id).isDefined =>
// this htlc is cross signed in the current commitment, we can fail it
log.info(s"failing broken htlc=$htlc")
channel ! CMD_FAIL_HTLC(htlc.id, Right(TemporaryNodeFailure), commit = true)
case (htlc, preimage) if Commitments.getHtlcCrossSigned(data.commitments, IN, htlc.id).isDefined =>
// this htlc is cross signed in the current commitment, we can settle it
preimage match {
case Some(preimage) =>
log.info(s"fulfilling broken htlc=$htlc")
channel ! CMD_FULFILL_HTLC(htlc.id, preimage, commit = true)
case None =>
log.info(s"failing broken htlc=$htlc")
channel ! CMD_FAIL_HTLC(htlc.id, Right(TemporaryNodeFailure), commit = true)
}
false // the channel may very well be disconnected before we sign (=ack) the fail, so we keep it for now
case _ =>
true // the htlc has already been failed, we can forget about it now
}
acked.foreach(htlc => log.info(s"forgetting htlc id=${htlc.id} channelId=${htlc.channelId}"))
acked.foreach { case (htlc, _) => log.info(s"forgetting htlc id=${htlc.id} channelId=${htlc.channelId}") }
context become main(htlcs diff acked)
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import fr.acinq.bitcoin.{ByteVector32, Crypto}
import fr.acinq.eclair.channel.{CMD_FAIL_HTLC, CMD_FULFILL_HTLC, Channel}
import fr.acinq.eclair.db.{IncomingPayment, IncomingPaymentStatus, IncomingPaymentsDb}
import fr.acinq.eclair.payment.PaymentRequest.ExtraHop
import fr.acinq.eclair.payment.relay.CommandBuffer
import fr.acinq.eclair.payment.{IncomingPacket, PaymentReceived, PaymentRequest}
import fr.acinq.eclair.wire._
import fr.acinq.eclair.{CltvExpiry, Features, MilliSatoshi, NodeParams, randomBytes32}
Expand All @@ -34,8 +35,7 @@ import scala.util.{Failure, Success, Try}
*
* Created by PM on 17/06/2016.
*/
class MultiPartHandler(nodeParams: NodeParams,
db: IncomingPaymentsDb) extends ReceiveHandler {
class MultiPartHandler(nodeParams: NodeParams, db: IncomingPaymentsDb, commandBuffer: ActorRef) extends ReceiveHandler {

import MultiPartHandler._

Expand Down Expand Up @@ -79,53 +79,59 @@ class MultiPartHandler(nodeParams: NodeParams,
case p: IncomingPacket.FinalPacket if doHandle(p) => db.getIncomingPayment(p.add.paymentHash) match {
case Some(record) => validatePayment(p, record, nodeParams.currentBlockHeight) match {
case Some(cmdFail) =>
ctx.sender ! cmdFail
commandBuffer ! CommandBuffer.CommandSend(p.add.channelId, p.add.id, cmdFail)
case None =>
log.info(s"received payment for paymentHash=${p.add.paymentHash} amount=${p.add.amountMsat} totalAmount=${p.payload.totalAmount}")
pendingPayments.get(p.add.paymentHash) match {
case Some((_, handler)) =>
handler forward MultiPartPaymentFSM.MultiPartHtlc(p.payload.totalAmount, p.add)
handler ! MultiPartPaymentFSM.MultiPartHtlc(p.payload.totalAmount, p.add)
case None =>
val handler = ctx.actorOf(MultiPartPaymentFSM.props(nodeParams, p.add.paymentHash, p.payload.totalAmount, ctx.self))
handler forward MultiPartPaymentFSM.MultiPartHtlc(p.payload.totalAmount, p.add)
handler ! MultiPartPaymentFSM.MultiPartHtlc(p.payload.totalAmount, p.add)
pendingPayments = pendingPayments + (p.add.paymentHash -> (record.paymentPreimage, handler))
}
}
case None =>
ctx.sender ! CMD_FAIL_HTLC(p.add.id, Right(IncorrectOrUnknownPaymentDetails(p.payload.totalAmount, nodeParams.currentBlockHeight)), commit = true)
val cmdFail = CMD_FAIL_HTLC(p.add.id, Right(IncorrectOrUnknownPaymentDetails(p.payload.totalAmount, nodeParams.currentBlockHeight)), commit = true)
commandBuffer ! CommandBuffer.CommandSend(p.add.channelId, p.add.id, cmdFail)
}

case MultiPartPaymentFSM.MultiPartHtlcFailed(paymentHash, failure, parts) =>
log.warning(s"payment with paymentHash=$paymentHash paidAmount=${parts.map(_.payment.amount).sum} failed ($failure)")
pendingPayments.get(paymentHash).foreach { case (_, handler: ActorRef) => handler ! PoisonPill }
parts.foreach(p => p.sender ! CMD_FAIL_HTLC(p.htlcId, Right(failure), commit = true))
parts.foreach(p => commandBuffer ! CommandBuffer.CommandSend(p.payment.fromChannelId, p.htlcId, CMD_FAIL_HTLC(p.htlcId, Right(failure), commit = true)))
pendingPayments = pendingPayments - paymentHash

case MultiPartPaymentFSM.MultiPartHtlcSucceeded(paymentHash, parts) =>
val received = PaymentReceived(paymentHash, parts.map(_.payment))
log.info(s"received complete payment for paymentHash=$paymentHash amount=${received.amount}")
// The first thing we do is store the payment. This allows us to reconcile pending HTLCs after a restart.
db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp)
pendingPayments.get(paymentHash).foreach {
case (preimage: ByteVector32, handler: ActorRef) =>
handler ! PoisonPill
parts.foreach(p => p.sender ! CMD_FULFILL_HTLC(p.htlcId, preimage, commit = true))
parts.foreach(p => commandBuffer ! CommandBuffer.CommandSend(p.payment.fromChannelId, p.htlcId, CMD_FULFILL_HTLC(p.htlcId, preimage, commit = true)))
}
db.receiveIncomingPayment(paymentHash, received.amount, received.timestamp)
ctx.system.eventStream.publish(received)
pendingPayments = pendingPayments - paymentHash
onSuccess(received)

case MultiPartPaymentFSM.ExtraHtlcReceived(paymentHash, p, failure) => failure match {
case Some(failure) => p.sender ! CMD_FAIL_HTLC(p.htlcId, Right(failure), commit = true)
case Some(failure) => commandBuffer ! CommandBuffer.CommandSend(p.payment.fromChannelId, p.htlcId, CMD_FAIL_HTLC(p.htlcId, Right(failure), commit = true))
// NB: this case shouldn't happen unless the sender violated the spec, so it's ok that we take a slightly more
// expensive code path by fetching the preimage from DB.
case None => db.getIncomingPayment(paymentHash).foreach(record => {
p.sender ! CMD_FULFILL_HTLC(p.htlcId, record.paymentPreimage, commit = true)
commandBuffer ! CommandBuffer.CommandSend(p.payment.fromChannelId, p.htlcId, CMD_FULFILL_HTLC(p.htlcId, record.paymentPreimage, commit = true))
db.receiveIncomingPayment(paymentHash, p.payment.amount, p.payment.timestamp)
ctx.system.eventStream.publish(PaymentReceived(paymentHash, p.payment :: Nil))
})
}

case GetPendingPayments => ctx.sender ! PendingPayments(pendingPayments.keySet)

case ack: CommandBuffer.CommandAck => commandBuffer forward ack

case "ok" => // ignoring responses from channels
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class MultiPartPaymentFSM(nodeParams: NodeParams, paymentHash: ByteVector32, tot

case Event(MultiPartHtlc(totalAmount2, htlc), d: WaitingForHtlc) =>
require(htlc.paymentHash == paymentHash, s"invalid payment hash (expected $paymentHash, received ${htlc.paymentHash}")
val pp = PendingPayment(htlc.id, PartialPayment(htlc.amountMsat, htlc.channelId), sender)
val pp = PendingPayment(htlc.id, PartialPayment(htlc.amountMsat, htlc.channelId))
val updatedParts = d.parts :+ pp
if (totalAmount != totalAmount2) {
log.warning(s"multi-part payment total amount mismatch: previously $totalAmount, now $totalAmount2")
Expand All @@ -67,7 +67,7 @@ class MultiPartPaymentFSM(nodeParams: NodeParams, paymentHash: ByteVector32, tot
case Event(MultiPartHtlc(_, htlc), _) =>
require(htlc.paymentHash == paymentHash, s"invalid payment hash (expected $paymentHash, received ${htlc.paymentHash}")
log.info(s"received extraneous htlc for payment hash $paymentHash")
parent ! ExtraHtlcReceived(paymentHash, PendingPayment(htlc.id, PartialPayment(htlc.amountMsat, htlc.channelId), sender), None)
parent ! ExtraHtlcReceived(paymentHash, PendingPayment(htlc.id, PartialPayment(htlc.amountMsat, htlc.channelId)), None)
stay
}

Expand All @@ -76,7 +76,7 @@ class MultiPartPaymentFSM(nodeParams: NodeParams, paymentHash: ByteVector32, tot
// The LocalPaymentHandler will create a new instance of MultiPartPaymentHandler to handle a new attempt.
case Event(MultiPartHtlc(_, htlc), PaymentFailed(failure, _)) =>
require(htlc.paymentHash == paymentHash, s"invalid payment hash (expected $paymentHash, received ${htlc.paymentHash}")
parent ! ExtraHtlcReceived(paymentHash, PendingPayment(htlc.id, PartialPayment(htlc.amountMsat, htlc.channelId), sender), Some(failure))
parent ! ExtraHtlcReceived(paymentHash, PendingPayment(htlc.id, PartialPayment(htlc.amountMsat, htlc.channelId)), Some(failure))
stay
}

Expand Down Expand Up @@ -121,7 +121,7 @@ object MultiPartPaymentFSM {

// @formatter:off
/** A payment that we're currently holding until we decide to fulfill or fail it. */
case class PendingPayment(htlcId: Long, payment: PartialPayment, sender: ActorRef)
case class PendingPayment(htlcId: Long, payment: PartialPayment)
/** An incoming partial payment. */
case class MultiPartHtlc(totalAmount: MilliSatoshi, htlc: UpdateAddHtlc)
/** We successfully received all parts of the payment. */
Expand Down
Loading