Skip to content

Commit

Permalink
Prevent loops and improve shortest paths perf (#1747)
Browse files Browse the repository at this point in the history
Reverse the flow of yen's k-shortest path: go backwards like
we do in dijkstra.

Better tracking of already explored spur paths which improves
performance (especially tail latency).
  • Loading branch information
thomash-acinq authored Mar 31, 2021
1 parent c6a76af commit 75cb777
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 41 deletions.
76 changes: 35 additions & 41 deletions eclair-core/src/main/scala/fr/acinq/eclair/router/Graph.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,49 +98,50 @@ object Graph {
boundaries: RichWeight => Boolean): Seq[WeightedPath] = {
// find the shortest path (k = 0)
val targetWeight = RichWeight(amount, 0, CltvExpiryDelta(0), 0)
val shortestPath = dijkstraShortestPath(graph, sourceNode, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr)
val shortestPath = dijkstraShortestPath(graph, sourceNode, targetNode, ignoredEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr)
if (shortestPath.isEmpty) {
return Seq.empty // if we can't even find a single path, avoid returning a Seq(Seq.empty)
}

case class PathWithSpur(p: WeightedPath, spurIndex: Int)
implicit object PathWithSpurComparator extends Ordering[PathWithSpur] {
override def compare(x: PathWithSpur, y: PathWithSpur): Int = y.p.weight.compare(x.p.weight)
}

var allSpurPathsFound = false
val shortestPaths = new mutable.Queue[WeightedPath]
shortestPaths.enqueue(WeightedPath(shortestPath, pathWeight(sourceNode, shortestPath, amount, currentBlockHeight, wr)))
val shortestPaths = new mutable.Queue[PathWithSpur]
shortestPaths.enqueue(PathWithSpur(WeightedPath(shortestPath, pathWeight(sourceNode, shortestPath, amount, currentBlockHeight, wr)), 0))
// stores the candidates for the k-th shortest path, sorted by path cost
val candidates = new mutable.PriorityQueue[WeightedPath]
val candidates = new mutable.PriorityQueue[PathWithSpur]

// main loop
for (k <- 1 until pathsToFind) {
if (!allSpurPathsFound) {
val prevShortestPath = shortestPaths(k - 1).path
// for every edge in the path, we will try to find a different path after that edge
for (i <- prevShortestPath.indices) {
// select the spur node as the i-th element of the previous shortest path
val spurNode = prevShortestPath(i).desc.a
// select the sub-path from the source to the spur node
val rootPathEdges = prevShortestPath.take(i)
val PathWithSpur(WeightedPath(prevShortestPath, _), spurIndex) = shortestPaths(k - 1)
// for every new edge in the path, we will try to find a different path before that edge
for (i <- spurIndex until prevShortestPath.length) {
// select the spur node as the i-th element from the target of the previous shortest path
val spurNode = prevShortestPath(prevShortestPath.length - 1 - i).desc.b
// select the sub-path from the spur node to the target
val rootPathEdges = prevShortestPath.takeRight(i)
// we ignore all the paths that we have already fully explored in previous iterations
// if for example the spur node is B, and we already found shortest paths starting with A-B-C and A-B-D,
// we want to ignore the B-C and B-D edges
// +-- C -- [...]
// |
// A -- B --+-- D -- [...]
// |
// +-- E -- [...]
val alreadyExploredEdges = shortestPaths.collect { case p if p.path.take(i) == rootPathEdges => p.path(i).desc }.toSet
// we also want to ignore any link that can lead back to the previous node (we only want to go forward)
val returningEdges = rootPathEdges.lastOption.map(last => graph.getEdgesBetween(last.desc.b, last.desc.a).map(_.desc).toSet).getOrElse(Set.empty)
// if for example the spur node is D, and we already found shortest paths ending with A->D->E and B->D->E,
// we want to ignore the A->D and B->D edges
// [...] --> A --+
// |
// [...] --> B --+--> D --> E
// |
// [...] --> C --+
val alreadyExploredEdges = shortestPaths.collect { case p if p.p.path.takeRight(i) == rootPathEdges => p.p.path(p.p.path.length - 1 - i).desc }.toSet
// we also want to ignore any vertex on the root path to prevent loops
val alreadyExploredVertices = rootPathEdges.map(_.desc.b).toSet
val rootPathWeight = pathWeight(sourceNode, rootPathEdges, amount, currentBlockHeight, wr)
// find the "spur" path, a sub-path going from the spur node to the target avoiding previously found sub-paths
val spurPath = dijkstraShortestPath(graph, sourceNode, spurNode, targetNode, ignoredEdges ++ alreadyExploredEdges ++ returningEdges, ignoredVertices, extraEdges, targetWeight, boundaries, currentBlockHeight, wr)
val spurPath = dijkstraShortestPath(graph, sourceNode, spurNode, ignoredEdges ++ alreadyExploredEdges, ignoredVertices ++ alreadyExploredVertices, extraEdges, rootPathWeight, boundaries, currentBlockHeight, wr)
if (spurPath.nonEmpty) {
// candidate k-shortest path is made of the root path and the new spur path, but the cost of the spur
// path is likely higher than previous shortest paths, so we need to validate that the root path can
// relay the increased amount.
val completePath = rootPathEdges ++ spurPath
val completePath = spurPath ++ rootPathEdges
val candidatePath = WeightedPath(completePath, pathWeight(sourceNode, completePath, amount, currentBlockHeight, wr))
if (boundaries(candidatePath.weight) && !shortestPaths.contains(candidatePath) && !candidates.exists(_ == candidatePath) && validatePath(completePath, amount)) {
candidates.enqueue(candidatePath)
}
candidates.enqueue(PathWithSpur(candidatePath, i))
}
}
}
Expand All @@ -154,7 +155,7 @@ object Graph {
}
}

shortestPaths.toSeq
shortestPaths.map(_.p).toSeq
}

/**
Expand All @@ -163,8 +164,7 @@ object Graph {
* graph @param g is optimized for querying the incoming edges given a vertex.
*
* @param g the graph on which will be performed the search
* @param sender node sending the payment (may be different from sourceNode when calculating partial paths)
* @param sourceNode the starting node of the path we're looking for
* @param sourceNode the starting node of the path we're looking for (payer)
* @param targetNode the destination node of the path
* @param ignoredEdges channels that should be avoided
* @param ignoredVertices nodes that should be avoided
Expand All @@ -175,7 +175,6 @@ object Graph {
* @param wr ratios used to 'weight' edges when searching for the shortest path
*/
private def dijkstraShortestPath(g: DirectedGraph,
sender: PublicKey,
sourceNode: PublicKey,
targetNode: PublicKey,
ignoredEdges: Set[ChannelDesc],
Expand Down Expand Up @@ -222,7 +221,7 @@ object Graph {
val neighbor = edge.desc.a
// NB: this contains the amount (including fees) that will need to be sent to `neighbor`, but the amount that
// will be relayed through that edge is the one in `currentWeight`.
val neighborWeight = addEdgeWeight(sender, edge, current.weight, currentBlockHeight, wr)
val neighborWeight = addEdgeWeight(sourceNode, edge, current.weight, currentBlockHeight, wr)
val canRelayAmount = current.weight.cost <= edge.capacity &&
edge.balance_opt.forall(current.weight.cost <= _) &&
edge.update.htlcMaximumMsat.forall(current.weight.cost <= _) &&
Expand Down Expand Up @@ -295,19 +294,14 @@ object Graph {

/**
* Calculate the minimum amount that the start node needs to receive to be able to forward @amountWithFees to the end
* node. To avoid infinite loops caused by zero-fee edges, we use a lower bound fee of 1 msat.
* node.
*
* @param edge the edge we want to cross
* @param amountToForward the value that this edge will have to carry along
* @return the new amount updated with the necessary fees for this edge
*/
private def addEdgeFees(edge: GraphEdge, amountToForward: MilliSatoshi): MilliSatoshi = {
if (edgeHasZeroFee(edge)) amountToForward + nodeFee(baseFee = 1 msat, proportionalFee = 0, amountToForward)
else amountToForward + nodeFee(edge.update.feeBaseMsat, edge.update.feeProportionalMillionths, amountToForward)
}

private def edgeHasZeroFee(edge: GraphEdge): Boolean = {
edge.update.feeBaseMsat.toLong == 0 && edge.update.feeProportionalMillionths == 0
amountToForward + nodeFee(edge.update.feeBaseMsat, edge.update.feeProportionalMillionths, amountToForward)
}

/** Validate that all edges along the path can relay the amount with fees. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.scalatest.{ParallelTestExecution, Tag}
import scodec.bits._

import scala.collection.immutable.SortedMap
import scala.collection.mutable
import scala.util.{Failure, Random, Success}

/**
Expand Down Expand Up @@ -1545,6 +1546,122 @@ class RouteCalculationSpec extends AnyFunSuite with ParallelTestExecution {
}
}

test("loop trap") {
// +-----------------+
// | |
// | v
// A --> B --> C --> D --> E
// ^ |
// | |
// F <---+
val g = DirectedGraph(List(
makeEdge(1L, a, b, 1000 msat, 1000),
makeEdge(2L, b, c, 1000 msat, 1000),
makeEdge(3L, c, d, 1000 msat, 1000),
makeEdge(4L, d, e, 1000 msat, 1000),
makeEdge(5L, b, e, 1000 msat, 1000),
makeEdge(6L, c, f, 1000 msat, 1000),
makeEdge(7L, f, b, 1000 msat, 1000),
))

val Success(routes) = findRoute(g, a, e, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, numRoutes = 3, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000)
assert(routes.length == 2)
val route1 :: route2 :: Nil = routes
assert(route2Ids(route1) === 1 :: 5 :: Nil)
assert(route2Ids(route2) === 1 :: 2 :: 3 :: 4 :: Nil)
}

test("reversed loop trap") {
// +-----------------+
// | |
// v |
// A <-- B <-- C <-- D <-- E
// | ^
// | |
// F ----+
val g = DirectedGraph(List(
makeEdge(1L, b, a, 1000 msat, 1000),
makeEdge(2L, c, b, 1000 msat, 1000),
makeEdge(3L, d, c, 1000 msat, 1000),
makeEdge(4L, e, d, 1000 msat, 1000),
makeEdge(5L, e, b, 1000 msat, 1000),
makeEdge(6L, f, c, 1000 msat, 1000),
makeEdge(7L, b, f, 1000 msat, 1000),
))

val Success(routes) = findRoute(g, e, a, DEFAULT_AMOUNT_MSAT, DEFAULT_MAX_FEE, numRoutes = 3, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000)
assert(routes.length == 2)
val route1 :: route2 :: Nil = routes
assert(route2Ids(route1) === 5 :: 1 :: Nil)
assert(route2Ids(route2) === 4 :: 3 :: 2 :: 1 :: Nil)
}

test("k-shortest paths must be distinct") {
// +----> N ---> N N ---> N ----+
// / \ / \ / \
// A +--+ (...) +--+ B
// \ / \ / \ /
// +----> N ---> N N ---> N ----+

def makeEdges(n: Int): Seq[GraphEdge] = {
val nodes = new Array[(PublicKey, PublicKey)](n)
for (i <- nodes.indices) {
nodes(i) = (randomKey.publicKey, randomKey.publicKey)
}
val q = new mutable.Queue[GraphEdge]
// One path is shorter to maximise the overlap between the n-shortest paths, they will all be like the shortest path with a single hop changed.
q.enqueue(makeEdge(1L, a, nodes(0)._1, 100 msat, 90))
q.enqueue(makeEdge(2L, a, nodes(0)._2, 100 msat, 100))
for (i <- 0 until (n - 1)) {
q.enqueue(makeEdge(4 * i + 3, nodes(i)._1, nodes(i + 1)._1, 100 msat, 90))
q.enqueue(makeEdge(4 * i + 4, nodes(i)._1, nodes(i + 1)._2, 100 msat, 90))
q.enqueue(makeEdge(4 * i + 5, nodes(i)._2, nodes(i + 1)._1, 100 msat, 100))
q.enqueue(makeEdge(4 * i + 6, nodes(i)._2, nodes(i + 1)._2, 100 msat, 100))
}
q.enqueue(makeEdge(4 * n, nodes(n - 1)._1, b, 100 msat, 90))
q.enqueue(makeEdge(4 * n + 1, nodes(n - 1)._2, b, 100 msat, 100))
q.toSeq
}

val g = DirectedGraph(makeEdges(10))

val Success(routes) = findRoute(g, a, b, DEFAULT_AMOUNT_MSAT, 100000000 msat, numRoutes = 10, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000)
assert(routes.distinct.length == 10)
}

test("all paths are shortest") {
// +----> N ---> N N ---> N ----+
// / \ / \ / \
// A +--+ (...) +--+ B
// \ / \ / \ /
// +----> N ---> N N ---> N ----+

def makeEdges(n: Int): Seq[GraphEdge] = {
val nodes = new Array[(PublicKey, PublicKey)](n)
for (i <- nodes.indices) {
nodes(i) = (randomKey.publicKey, randomKey.publicKey)
}
val q = new mutable.Queue[GraphEdge]
q.enqueue(makeEdge(1L, a, nodes(0)._1, 100 msat, 100))
q.enqueue(makeEdge(2L, a, nodes(0)._2, 100 msat, 100))
for (i <- 0 until (n - 1)) {
q.enqueue(makeEdge(4 * i + 3, nodes(i)._1, nodes(i + 1)._1, 100 msat, 100))
q.enqueue(makeEdge(4 * i + 4, nodes(i)._1, nodes(i + 1)._2, 100 msat, 100))
q.enqueue(makeEdge(4 * i + 5, nodes(i)._2, nodes(i + 1)._1, 100 msat, 100))
q.enqueue(makeEdge(4 * i + 6, nodes(i)._2, nodes(i + 1)._2, 100 msat, 100))
}
q.enqueue(makeEdge(4 * n, nodes(n - 1)._1, b, 100 msat, 100))
q.enqueue(makeEdge(4 * n + 1, nodes(n - 1)._2, b, 100 msat, 100))
q.toSeq
}

val g = DirectedGraph(makeEdges(10))

val Success(routes) = findRoute(g, a, b, DEFAULT_AMOUNT_MSAT, 100000000 msat, numRoutes = 10, routeParams = DEFAULT_ROUTE_PARAMS, currentBlockHeight = 400000)
assert(routes.distinct.length == 10)
val fees = routes.map(_.fee)
assert(fees.forall(_ == fees.head))
}
}

object RouteCalculationSpec {
Expand Down

0 comments on commit 75cb777

Please sign in to comment.