Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
- extract Start at the top-level
- register only once to EventStream
- handle peer missing in map
- rename classes for clarity
- remove unnecessary comments
- use receiveMessagePartial
- simplify tests
  • Loading branch information
t-bast committed Nov 22, 2022
1 parent a1bd1b4 commit ac5118b
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,19 @@ import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.typed.{ActorRef, Behavior, SupervisorStrategy}
import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.BlockHeight
import fr.acinq.eclair.Logs.LogCategory
import fr.acinq.eclair.blockchain.CurrentBlockHeight
import fr.acinq.eclair.io.PeerReadyNotifier.NotifyWhenPeerReady
import fr.acinq.eclair.io.{PeerReadyNotifier, Switchboard}
import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer.Command
import fr.acinq.eclair.{BlockHeight, Logs}

import scala.concurrent.duration.Duration

/**
* This actor waits for an async payment receiver to become ready to receive a payment or for a block timeout to expire.
* If the receiver of the payment is a connected peer, spawn a PeerReadyNotifier actor.
* TODO: If the receiver is not a connected peer, wait for a `ReceiverReady` onion message containing the specified paymentHash.
*/

object AsyncPaymentTriggerer {
// @formatter:off
sealed trait Command
Expand All @@ -50,73 +49,69 @@ object AsyncPaymentTriggerer {
// @formatter:on

def apply(): Behavior[Command] = Behaviors.setup { context =>
new AsyncPaymentTriggerer(context).initializing()
Behaviors.withMdc(Logs.mdc(category_opt = Some(LogCategory.PAYMENT))) {
Behaviors.receiveMessagePartial {
case Start(switchboard) => new AsyncPaymentTriggerer(switchboard, context).start()
}
}
}
}

private class AsyncPaymentTriggerer(context: ActorContext[Command]) {
private class AsyncPaymentTriggerer(switchboard: ActorRef[Switchboard.GetPeerInfo], context: ActorContext[Command]) {

import AsyncPaymentTriggerer._

case class Watcher(replyTo: ActorRef[Result], timeout: BlockHeight, paymentHash: ByteVector32) {
case class Payment(replyTo: ActorRef[Result], timeout: BlockHeight, paymentHash: ByteVector32) {
def expired(currentBlockHeight: BlockHeight): Boolean = timeout <= currentBlockHeight
}
case class AsyncPaymentTrigger(notifier: ActorRef[PeerReadyNotifier.Command], watchers: Set[Watcher]) {
def update(currentBlockHeight: BlockHeight): Option[AsyncPaymentTrigger] = {
// notify watchers that timeout occurred before offline peer reconnected
val expiredWatchers = watchers.filter(_.expired(currentBlockHeight))
expiredWatchers.foreach(e => e.replyTo ! AsyncPaymentTimeout)
// remove timed out watchers from set
val updatedWatchers: Set[Watcher] = watchers.removedAll(expiredWatchers)
if (updatedWatchers.isEmpty) {
// stop notifier for offline peer when all watchers time out

case class PeerPayments(notifier: ActorRef[PeerReadyNotifier.Command], pendingPayments: Set[Payment]) {
def update(currentBlockHeight: BlockHeight): Option[PeerPayments] = {
val expiredPayments = pendingPayments.filter(_.expired(currentBlockHeight))
expiredPayments.foreach(e => e.replyTo ! AsyncPaymentTimeout)
val pendingPayments1 = pendingPayments.removedAll(expiredPayments)
if (pendingPayments1.isEmpty) {
context.stop(notifier)
None
} else {
Some(AsyncPaymentTrigger(notifier, updatedWatchers))
Some(PeerPayments(notifier, pendingPayments1))
}
}
def trigger(): Unit = watchers.foreach(e => e.replyTo ! AsyncPaymentTriggered)
}

private def initializing(): Behavior[Command] = {
Behaviors.receiveMessage[Command] {
case Start(switchboard) => watching(switchboard, Map())
case m => context.log.error(s"received unhandled message ${m.getClass.getSimpleName} before Start received.")
Behaviors.same
}
def trigger(): Unit = pendingPayments.foreach(e => e.replyTo ! AsyncPaymentTriggered)
}

private def watching(switchboard: ActorRef[Switchboard.GetPeerInfo], triggers: Map[PublicKey, AsyncPaymentTrigger]): Behavior[Command] = {
val peerReadyResultAdapter = context.messageAdapter[PeerReadyNotifier.Result](WrappedPeerReadyResult)
def start(): Behavior[Command] = {
context.system.eventStream ! EventStream.Subscribe(context.messageAdapter[CurrentBlockHeight](WrappedCurrentBlockHeight))
watching(Map.empty)
}

Behaviors.receiveMessage[Command] {
private def watching(peers: Map[PublicKey, PeerPayments]): Behavior[Command] = {
Behaviors.receiveMessagePartial {
case Watch(replyTo, remoteNodeId, paymentHash, timeout) =>
triggers.get(remoteNodeId) match {
peers.get(remoteNodeId) match {
case None =>
// add a new trigger
val notifier = context.spawn(Behaviors.supervise(PeerReadyNotifier(remoteNodeId, switchboard, Left(Duration.Inf)))
.onFailure(SupervisorStrategy.restart), s"peer-ready-notifier-$remoteNodeId-$timeout")
notifier ! NotifyWhenPeerReady(peerReadyResultAdapter)
val newTrigger = AsyncPaymentTrigger(notifier, Set(Watcher(replyTo, timeout, paymentHash)))
watching(switchboard, triggers + (remoteNodeId -> newTrigger))
case Some(trigger) =>
// add a new watcher to an existing trigger
val updatedTrigger = AsyncPaymentTrigger(trigger.notifier, trigger.watchers + Watcher(replyTo, timeout, paymentHash))
watching(switchboard, triggers + (remoteNodeId -> updatedTrigger))
val notifier = context.spawn(
Behaviors.supervise(PeerReadyNotifier(remoteNodeId, switchboard, Left(Duration.Inf))).onFailure(SupervisorStrategy.restart),
s"peer-ready-notifier-$remoteNodeId",
)
notifier ! NotifyWhenPeerReady(context.messageAdapter[PeerReadyNotifier.Result](WrappedPeerReadyResult))
val peer = PeerPayments(notifier, Set(Payment(replyTo, timeout, paymentHash)))
watching(peers + (remoteNodeId -> peer))
case Some(peer) =>
val peer1 = PeerPayments(peer.notifier, peer.pendingPayments + Payment(replyTo, timeout, paymentHash))
watching(peers + (remoteNodeId -> peer1))
}
case WrappedCurrentBlockHeight(CurrentBlockHeight(currentBlockHeight)) =>
// update watchers, and remove triggers with no more active watchers
val newTriggers = triggers.collect(m => m._2.update(currentBlockHeight) match {
case Some(t) => m._1 -> t
})
watching(switchboard, newTriggers)
val peers1 = peers.flatMap {
case (remoteNodeId, peer) => peer.update(currentBlockHeight).map(peer1 => remoteNodeId -> peer1)
}
watching(peers1)
case WrappedPeerReadyResult(PeerReadyNotifier.PeerReady(remoteNodeId, _)) =>
// notify watcher that destination peer is ready to receive async payments; PeerReadyNotifier will stop itself
triggers(remoteNodeId).trigger()
watching(switchboard, triggers - remoteNodeId)
case m => context.log.error(s"received unhandled message ${m.getClass.getSimpleName} after Start received.")
Behaviors.same
peers.get(remoteNodeId).foreach(_.trigger())
watching(peers - remoteNodeId)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ import fr.acinq.eclair.router.Router.RouteParams
import fr.acinq.eclair.router.{BalanceTooLow, RouteNotFound}
import fr.acinq.eclair.wire.protocol.PaymentOnion.{FinalPayload, IntermediatePayload}
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{BlockHeight, CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, UInt64, nodeFee, randomBytes32}
import fr.acinq.eclair.{CltvExpiry, Features, Logs, MilliSatoshi, NodeParams, UInt64, nodeFee, randomBytes32}

import java.util.UUID
import scala.collection.immutable.Queue
Expand Down Expand Up @@ -215,23 +215,17 @@ class NodeRelay private(nodeParams: NodeParams,

private def waitForTrigger(upstream: Upstream.Trampoline, nextPayload: IntermediatePayload.NodeRelay.Standard, nextPacket: OnionRoutingPacket): Behavior[Command] = {
context.log.info(s"waiting for async payment to trigger before relaying trampoline payment (amountIn=${upstream.amountIn} expiryIn=${upstream.expiryIn} amountOut=${nextPayload.amountToForward} expiryOut=${nextPayload.outgoingCltv}, asyncPaymentsParams=${nodeParams.relayParams.asyncPaymentsParams})")
// a trigger must be received before waiting more than `holdTimeoutBlocks`
val timeoutBlock: BlockHeight = nodeParams.currentBlockHeight + nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks
// a trigger must be received `cancelSafetyBeforeTimeoutBlocks` before the incoming payment cltv expiry
val safetyBlock: BlockHeight = (upstream.expiryIn - nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout).blockHeight
val timeoutBlock = nodeParams.currentBlockHeight + nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks
val safetyBlock = (upstream.expiryIn - nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout).blockHeight
// wait for notification until which ever occurs first: the hold timeout block or the safety block
val notifierTimeout: BlockHeight = Seq(timeoutBlock, safetyBlock).min
val notifierTimeout = Seq(timeoutBlock, safetyBlock).min
val peerReadyResultAdapter = context.messageAdapter[AsyncPaymentTriggerer.Result](WrappedPeerReadyResult)

triggerer ! AsyncPaymentTriggerer.Watch(peerReadyResultAdapter, nextPayload.outgoingNodeId, paymentHash, notifierTimeout)
context.system.eventStream ! EventStream.Publish(WaitingToRelayPayment(nextPayload.outgoingNodeId, paymentHash))
Behaviors.receiveMessagePartial {
case WrappedPeerReadyResult(AsyncPaymentTriggerer.AsyncPaymentTimeout) =>
if (safetyBlock < timeoutBlock) {
context.log.warn(s"rejecting async payment; was not triggered ${nodeParams.relayParams.asyncPaymentsParams.cancelSafetyBeforeTimeout} safety blocks before upstream cltv expiry of ${upstream.expiryIn}")
} else {
context.log.warn(s"rejecting async payment; was not triggered after waiting ${nodeParams.relayParams.asyncPaymentsParams.holdTimeoutBlocks} blocks")
}
context.log.warn("rejecting async payment; was not triggered before block {}", notifierTimeout)
rejectPayment(upstream, Some(TemporaryNodeFailure)) // TODO: replace failure type when async payment spec is finalized
stopping()
case CancelAsyncPayment =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,29 @@ import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.bitcoin.scalacompat.Crypto.PublicKey
import fr.acinq.eclair.blockchain.CurrentBlockHeight
import fr.acinq.eclair.channel.{CMD_GET_CHANNEL_STATE, NEGOTIATING, RES_GET_CHANNEL_STATE}
import fr.acinq.eclair.{BlockHeight, TestConstants, randomKey}
import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer._
import fr.acinq.eclair.io.{Peer, PeerConnected, Switchboard}
import fr.acinq.eclair.io.Switchboard.GetPeerInfo
import fr.acinq.eclair.io.{Peer, PeerConnected, Switchboard}
import fr.acinq.eclair.payment.relay.AsyncPaymentTriggerer._
import fr.acinq.eclair.{BlockHeight, TestConstants, randomKey}
import org.scalatest.Outcome
import org.scalatest.funsuite.FixtureAnyFunSuiteLike

import scala.concurrent.duration.DurationInt

class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.load("application")) with FixtureAnyFunSuiteLike {

case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], channelProbes: Seq[TestProbe[CMD_GET_CHANNEL_STATE]], probe: TestProbe[Result], triggerer: ActorRef[Command]) {
def channels: Set[akka.actor.ActorRef] = channelProbes.map(_.ref.toClassic).toSet
case class FixtureParam(remoteNodeId: PublicKey, switchboard: TestProbe[Switchboard.GetPeerInfo], channel: TestProbe[CMD_GET_CHANNEL_STATE], probe: TestProbe[Result], triggerer: ActorRef[Command]) {
def channels: Set[akka.actor.ActorRef] = Set(channel.ref.toClassic)
}

override def withFixture(test: OneArgTest): Outcome = {
val remoteNodeId = TestConstants.Alice.nodeParams.nodeId
val switchboard = TestProbe[Switchboard.GetPeerInfo]("switchboard")
val channelProbes = Seq(TestProbe[CMD_GET_CHANNEL_STATE]("channel1"), TestProbe[CMD_GET_CHANNEL_STATE]("channel2"))
val channel = TestProbe[CMD_GET_CHANNEL_STATE]("channel")
val probe = TestProbe[Result]()
val triggerer = testKit.spawn(AsyncPaymentTriggerer())
triggerer ! Start(switchboard.ref)
withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, channelProbes, probe, triggerer)))
withFixture(test.toNoArgTest(FixtureParam(remoteNodeId, switchboard, channel, probe, triggerer)))
}

test("remote node does not connect before timeout") { f =>
Expand All @@ -50,7 +50,7 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.

// Only get the timeout message once.
system.eventStream ! EventStream.Publish(CurrentBlockHeight(BlockHeight(111)))
probe.expectNoMessage()
probe.expectNoMessage(100 millis)
}

test("duplicate watches should emit only one trigger") { f =>
Expand All @@ -60,7 +60,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100))
assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId)
triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100))
switchboard.expectNoMessage()

// We trigger one timeout messages when we reach the timeout
system.eventStream ! EventStream.Publish(CurrentBlockHeight(BlockHeight(100)))
Expand All @@ -72,7 +71,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100))
assert(switchboard.expectMessageType[GetPeerInfo].remoteNodeId == remoteNodeId)
triggerer ! Watch(probe2.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100))
switchboard.expectNoMessage()

// We get two timeout messages when we reach the timeout
system.eventStream ! EventStream.Publish(CurrentBlockHeight(BlockHeight(100)))
Expand All @@ -86,24 +84,21 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100))
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerInfo(TestProbe().ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, channels)
channelProbes.head.expectNoMessage(100 millis)

// An unrelated peer connects.
system.eventStream ! EventStream.Publish(PeerConnected(TestProbe().ref.toClassic, randomKey().publicKey, null))
switchboard.expectNoMessage(100 millis)
probe.expectNoMessage()
probe.expectNoMessage(100 millis)

// The target peer connects.
system.eventStream ! EventStream.Publish(PeerConnected(TestProbe().ref.toClassic, remoteNodeId, null))
val request2 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request2.replyTo ! Peer.PeerInfo(TestProbe().ref.toClassic, remoteNodeId, Peer.CONNECTED, None, channels)
channelProbes.foreach(_.expectMessageType[CMD_GET_CHANNEL_STATE].replyTo ! RES_GET_CHANNEL_STATE(NEGOTIATING))
channel.expectMessageType[CMD_GET_CHANNEL_STATE].replyTo ! RES_GET_CHANNEL_STATE(NEGOTIATING)
probe.expectMessage(AsyncPaymentTriggered)

// Only get the trigger message once.
system.eventStream ! EventStream.Publish(PeerConnected(TestProbe().ref.toClassic, remoteNodeId, null))
switchboard.expectNoMessage()
probe.expectNoMessage()
probe.expectNoMessage(100 millis)
}

test("remote node connects after one watch timeout and before another") { f =>
Expand All @@ -112,7 +107,6 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100))
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerInfo(TestProbe().ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, channels)
channelProbes.head.expectNoMessage(100 millis)

// Another async payment node relay watches the peer
val probe2 = TestProbe[Result]()
Expand All @@ -126,8 +120,8 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
system.eventStream ! EventStream.Publish(PeerConnected(TestProbe().ref.toClassic, remoteNodeId, null))
val request2 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request2.replyTo ! Peer.PeerInfo(TestProbe().ref.toClassic, remoteNodeId, Peer.CONNECTED, None, channels)
channelProbes.foreach(_.expectMessageType[CMD_GET_CHANNEL_STATE].replyTo ! RES_GET_CHANNEL_STATE(NEGOTIATING))
probe.expectNoMessage()
channel.expectMessageType[CMD_GET_CHANNEL_STATE].replyTo ! RES_GET_CHANNEL_STATE(NEGOTIATING)
probe.expectNoMessage(100 millis)
probe2.expectMessage(AsyncPaymentTriggered)
}

Expand All @@ -138,30 +132,29 @@ class AsyncPaymentTriggererSpec extends ScalaTestWithActorTestKit(ConfigFactory.
triggerer ! Watch(probe.ref, remoteNodeId, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(100))
val request1 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request1.replyTo ! Peer.PeerInfo(TestProbe().ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, channels)
channelProbes.head.expectNoMessage(100 millis)

// watch another remote node
val remoteNodeId2 = TestConstants.Bob.nodeParams.nodeId
val probe2 = TestProbe[Result]()
triggerer ! Watch(probe2.ref, remoteNodeId2, paymentHash = ByteVector32.Zeroes, timeout = BlockHeight(101))
val request2 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
request2.replyTo ! Peer.PeerInfo(TestProbe().ref.toClassic, remoteNodeId, Peer.DISCONNECTED, None, channels)
channelProbes.head.expectNoMessage(100 millis)

// First remote node times out
system.eventStream ! EventStream.Publish(CurrentBlockHeight(BlockHeight(100)))
probe.expectMessage(AsyncPaymentTimeout)

// First remote node connects, but does not trigger expired watch
system.eventStream ! EventStream.Publish(PeerConnected(TestProbe().ref.toClassic, remoteNodeId, null))
switchboard.expectNoMessage()

// Second remote node connects and triggers watch
system.eventStream ! EventStream.Publish(PeerConnected(TestProbe().ref.toClassic, remoteNodeId2, null))
val request3 = switchboard.expectMessageType[Switchboard.GetPeerInfo]
assert(request3.remoteNodeId == remoteNodeId2)
request3.replyTo ! Peer.PeerInfo(TestProbe().ref.toClassic, remoteNodeId2, Peer.CONNECTED, None, channels)
channelProbes.foreach(_.expectMessageType[CMD_GET_CHANNEL_STATE].replyTo ! RES_GET_CHANNEL_STATE(NEGOTIATING))
probe.expectNoMessage()
channel.expectMessageType[CMD_GET_CHANNEL_STATE].replyTo ! RES_GET_CHANNEL_STATE(NEGOTIATING)
probe.expectNoMessage(100 millis)
probe2.expectMessage(AsyncPaymentTriggered)
}

}

0 comments on commit ac5118b

Please sign in to comment.