Skip to content

Commit

Permalink
Remove invalid channel updates from DB at startup (#2379)
Browse files Browse the repository at this point in the history
Following #2361, we reject channel updates that don't contain the
`htlc_maximum_msat` field. However, the network DB may contain such
channel updates, that we need to remove when starting up.
  • Loading branch information
t-bast authored Aug 11, 2022
1 parent 2f590a8 commit e8dda28
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ trait NetworkDb {

def updateChannel(u: ChannelUpdate): Unit

def removeChannel(shortChannelId: ShortChannelId) = removeChannels(Set(shortChannelId)): Unit
def removeChannel(shortChannelId: ShortChannelId): Unit = removeChannels(Set(shortChannelId))

def removeChannels(shortChannelIds: Iterable[ShortChannelId]): Unit

Expand Down
14 changes: 11 additions & 3 deletions eclair-core/src/main/scala/fr/acinq/eclair/db/pg/PgNetworkDb.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
package fr.acinq.eclair.db.pg

import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, Satoshi}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.RealShortChannelId
import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics
import fr.acinq.eclair.db.Monitoring.Tags.DbBackends
import fr.acinq.eclair.db.NetworkDb
import fr.acinq.eclair.router.Router.PublicChannel
import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec}
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}
import fr.acinq.eclair.{RealShortChannelId, ShortChannelId}
import grizzled.slf4j.Logging
import scodec.bits.BitVector

Expand Down Expand Up @@ -80,7 +79,16 @@ class PgNetworkDb(implicit ds: DataSource) extends NetworkDb with Logging {
if (v < 4) {
migration34(statement)
}
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(CURRENT_VERSION) =>
// We clean up channels that contain an invalid channel update (e.g. missing htlc_maximum_msat).
statement.executeQuery("SELECT short_channel_id, channel_update_1, channel_update_2 FROM network.public_channels").map(rs => {
val shortChannelId = rs.getLong("short_channel_id")
val validChannelUpdate1 = rs.getBitVectorOpt("channel_update_1").forall(channelUpdateCodec.decode(_).isSuccessful)
val validChannelUpdate2 = rs.getBitVectorOpt("channel_update_2").forall(channelUpdateCodec.decode(_).isSuccessful)
(shortChannelId, validChannelUpdate1 && validChannelUpdate2)
}).collect {
case (scid, false) => statement.executeUpdate(s"DELETE FROM network.public_channels WHERE short_channel_id=$scid")
}
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
setVersion(statement, DB_NAME, CURRENT_VERSION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
package fr.acinq.eclair.db.sqlite

import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto, Satoshi}
import fr.acinq.eclair.ShortChannelId
import fr.acinq.eclair.RealShortChannelId
import fr.acinq.eclair.db.Monitoring.Metrics.withMetrics
import fr.acinq.eclair.db.Monitoring.Tags.DbBackends
import fr.acinq.eclair.db.NetworkDb
import fr.acinq.eclair.router.Router.PublicChannel
import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncementCodec, channelUpdateCodec, nodeAnnouncementCodec}
import fr.acinq.eclair.wire.protocol.{ChannelAnnouncement, ChannelUpdate, NodeAnnouncement}
import fr.acinq.eclair.{RealShortChannelId, ShortChannelId}
import grizzled.slf4j.Logging

import java.sql.{Connection, Statement}
Expand Down Expand Up @@ -61,7 +60,16 @@ class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging {
case Some(v@1) =>
logger.warn(s"migrating db $DB_NAME, found version=$v current=$CURRENT_VERSION")
migration12(statement)
case Some(CURRENT_VERSION) => () // table is up-to-date, nothing to do
case Some(CURRENT_VERSION) =>
// We clean up channels that contain an invalid channel update (e.g. missing htlc_maximum_msat).
statement.executeQuery("SELECT short_channel_id, channel_update_1, channel_update_2 FROM channels").map(rs => {
val shortChannelId = rs.getLong("short_channel_id")
val validChannelUpdate1 = rs.getBitVectorOpt("channel_update_1").forall(channelUpdateCodec.decode(_).isSuccessful)
val validChannelUpdate2 = rs.getBitVectorOpt("channel_update_2").forall(channelUpdateCodec.decode(_).isSuccessful)
(shortChannelId, validChannelUpdate1 && validChannelUpdate2)
}).collect {
case (scid, false) => statement.executeUpdate(s"DELETE FROM channels WHERE short_channel_id=$scid")
}
case Some(unknownVersion) => throw new RuntimeException(s"Unknown version of DB $DB_NAME found, version=$unknownVersion")
}
setVersion(statement, DB_NAME, CURRENT_VERSION)
Expand Down Expand Up @@ -129,12 +137,12 @@ class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging {
using(sqlite.createStatement()) { statement =>
statement.executeQuery("SELECT channel_announcement, txid, capacity_sat, channel_update_1, channel_update_2 FROM channels")
.foldLeft(SortedMap.empty[RealShortChannelId, PublicChannel]) { (m, rs) =>
val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
val capacity = rs.getLong("capacity_sat")
val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value)
val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value)
m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None))
val ann = channelAnnouncementCodec.decode(rs.getBitVectorOpt("channel_announcement").get).require.value
val txId = ByteVector32.fromValidHex(rs.getString("txid"))
val capacity = rs.getLong("capacity_sat")
val channel_update_1_opt = rs.getBitVectorOpt("channel_update_1").map(channelUpdateCodec.decode(_).require.value)
val channel_update_2_opt = rs.getBitVectorOpt("channel_update_2").map(channelUpdateCodec.decode(_).require.value)
m + (ann.shortChannelId -> PublicChannel(ann, txId, Satoshi(capacity), channel_update_1_opt, channel_update_2_opt, None))
}
}
}
Expand Down Expand Up @@ -166,7 +174,7 @@ class SqliteNetworkDb(val sqlite: Connection) extends NetworkDb with Logging {
}

override def removeFromPruned(shortChannelId: RealShortChannelId): Unit = withMetrics("network/remove-from-pruned", DbBackends.Sqlite) {
using(sqlite.prepareStatement(s"DELETE FROM pruned WHERE short_channel_id=?")) { statement =>
using(sqlite.prepareStatement("DELETE FROM pruned WHERE short_channel_id=?")) { statement =>
statement.setLong(1, shortChannelId.toLong)
statement.executeUpdate()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import fr.acinq.eclair.wire.protocol.LightningMessageCodecs.{channelAnnouncement
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiryDelta, Features, MilliSatoshiLong, RealShortChannelId, ShortChannelId, TestDatabases, randomBytes32, randomKey}
import org.scalatest.funsuite.AnyFunSuite
import scodec.bits.HexStringSyntax

import scala.collection.{SortedMap, mutable}
import scala.util.Random
Expand Down Expand Up @@ -63,7 +64,7 @@ class NetworkDbSpec extends AnyFunSuite {
assert(db.listNodes().toSet == Set.empty)
db.addNode(node_1)
db.addNode(node_1) // duplicate is ignored
assert(db.getNode(node_1.nodeId) == Some(node_1))
assert(db.getNode(node_1.nodeId).contains(node_1))
assert(db.listNodes().size == 1)
db.addNode(node_2)
db.addNode(node_3)
Expand Down Expand Up @@ -326,6 +327,40 @@ class NetworkDbSpec extends AnyFunSuite {
)
}

test("remove channel updates without htlc_maximum_msat") {
forAllDbs { dbs =>
val t1 = channelTestCases(0)
val t2 = channelTestCases(1)
val db1 = dbs.network
db1.addChannel(t1.channel, t1.txid, t2.capacity)
db1.addChannel(t2.channel, t2.txid, t2.capacity)
// The DB contains a channel update missing the `htlc_maximum_msat` field.
val channelUpdateWithoutHtlcMax = hex"12540b6a236e21932622d61432f52913d9442cc09a1057c386119a286153f8681c66d2a0f17d32505ba71bb37c8edcfa9c11e151b2b38dae98b825eff1c040b36fe28c0ab6f1b372c1a6a246ae63f74f931e8365e15a089c68d619000000000008850f00058e00015e6a782e0000009000000000000003e8000003e800000002"
dbs match {
case sqlite: TestSqliteDatabases =>
using(sqlite.connection.prepareStatement("UPDATE channels SET channel_update_1=? WHERE short_channel_id=?")) { statement =>
statement.setBytes(1, channelUpdateWithoutHtlcMax.toArray)
statement.setLong(2, t1.shortChannelId.toLong)
statement.executeUpdate()
}
case pg: TestPgDatabases =>
using(pg.connection.prepareStatement("UPDATE network.public_channels SET channel_update_1=? WHERE short_channel_id=?")) { statement =>
statement.setBytes(1, channelUpdateWithoutHtlcMax.toArray)
statement.setLong(2, t1.shortChannelId.toLong)
statement.executeUpdate()
}
}
assertThrows[IllegalArgumentException](db1.listChannels())
// We restart eclair and automatically clean up invalid entries.
val db2 = dbs match {
case sqlite: TestSqliteDatabases => new SqliteNetworkDb(sqlite.connection)
case pg: TestPgDatabases => new PgNetworkDb()(pg.datasource)
}
val channels = db2.listChannels()
assert(channels.keySet == Set(t2.shortChannelId))
}
}

test("json column reset (postgres)") {
val dbs = TestPgDatabases()
val db = dbs.network
Expand Down

0 comments on commit e8dda28

Please sign in to comment.