Skip to content

Commit

Permalink
Experimental support for fat errors
Browse files Browse the repository at this point in the history
  • Loading branch information
thomash-acinq committed Dec 7, 2022
1 parent a0b7a49 commit e7da96b
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 3 deletions.
87 changes: 86 additions & 1 deletion eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.bitcoin.scalacompat.{ByteVector32, Crypto}
import fr.acinq.eclair.wire.protocol._
import grizzled.slf4j.Logging
import scodec.Attempt
import scodec.bits.ByteVector
import scodec.{Attempt, Codec}

import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration.DurationInt
import scala.util.{Failure, Success, Try}

/**
Expand Down Expand Up @@ -332,6 +334,89 @@ object Sphinx extends Logging {

}

case class InvalidFatErrorPacket(hopPayloads: Seq[(PublicKey, FatError.HopPayload)], failingNode: PublicKey)

object FatErrorPacket {

import FatError._

private val payloadAndPadLength = 256
private val hopPayloadLength = 9
private val maxNumHop = 27
private val codec: Codec[FatError] = fatErrorCodec(payloadAndPadLength, hopPayloadLength, maxNumHop)

def create(sharedSecret: ByteVector32, failure: FailureMessage): ByteVector = {
val failurePayload = FailureMessageCodecs.failureOnionPayload(payloadAndPadLength).encode(failure).require.toByteVector
val hopPayload = HopPayload(ErrorSource, 0 millis)
val zeroPayloads = Seq.fill(maxNumHop)(ByteVector.fill(hopPayloadLength)(0))
val zeroHmacs = (maxNumHop.to(1, -1)).map(Seq.fill(_)(ByteVector32.Zeroes))
val plainError = codec.encode(FatError(failurePayload, zeroPayloads, zeroHmacs)).require.bytes
wrap(plainError, sharedSecret, hopPayload).get
}

private def computeHmacs(mac: Mac32, failurePayload: ByteVector, hopPayloads: Seq[ByteVector], hmacs: Seq[Seq[ByteVector32]], minNumHop: Int): Seq[ByteVector32] = {
val newHmacs = (minNumHop until maxNumHop).map(i => {
val y = maxNumHop - i
mac.mac(failurePayload ++
ByteVector.concat(hopPayloads.take(y)) ++
ByteVector.concat((0 until y - 1).map(j => hmacs(j)(i))))
})
newHmacs
}

def wrap(errorPacket: ByteVector, sharedSecret: ByteVector32, hopPayload: HopPayload): Try[ByteVector] = Try {
val um = generateKey("um", sharedSecret)
val error = codec.decode(errorPacket.bits).require.value
val hopPayloads = hopPayloadCodec.encode(hopPayload).require.bytes +: error.hopPayloads.dropRight(1)
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, hopPayloads, error.hmacs, 0) +: error.hmacs.dropRight(1).map(_.drop(1))
val newError = codec.encode(FatError(error.failurePayload, hopPayloads, hmacs)).require.bytes
val key = generateKey("ammag", sharedSecret)
val stream = generateStream(key, newError.length.toInt)
newError xor stream
}

private def unwrap(errorPacket: ByteVector, sharedSecret: ByteVector32, minNumHop: Int): Try[(ByteVector, HopPayload)] = Try {
val key = generateKey("ammag", sharedSecret)
val stream = generateStream(key, errorPacket.length.toInt)
val error = codec.decode((errorPacket xor stream).bits).require.value
val um = generateKey("um", sharedSecret)
val shiftedHmacs = error.hmacs.tail.map(ByteVector32.Zeroes +: _) :+ Seq(ByteVector32.Zeroes)
val hmacs = computeHmacs(Hmac256(um), error.failurePayload, error.hopPayloads, shiftedHmacs, minNumHop)
require(hmacs == error.hmacs.head.drop(minNumHop), "Invalid HMAC")
val shiftedHopPayloads = error.hopPayloads.tail :+ ByteVector.fill(hopPayloadLength)(0)
val unwrapedError = FatError(error.failurePayload, shiftedHopPayloads, shiftedHmacs)
(codec.encode(unwrapedError).require.bytes,
hopPayloadCodec.decode(error.hopPayloads.head.bits).require.value)
}

def decrypt(errorPacket: ByteVector, sharedSecrets: Seq[(ByteVector32, PublicKey)]): Either[InvalidFatErrorPacket, DecryptedFailurePacket] = {
var packet = errorPacket
var minNumHop = 1
val hopPayloads = ArrayBuffer.empty[(PublicKey, HopPayload)]
for ((sharedSecret, nodeId) <- sharedSecrets) {
unwrap(packet, sharedSecret, minNumHop) match {
case Failure(_) => return Left(InvalidFatErrorPacket(hopPayloads.toSeq, nodeId))
case Success((unwrapedPacket, hopPayload)) =>
hopPayload.payloadType match {
case FatError.IntermediateHop =>
packet = unwrapedPacket
minNumHop += 1
hopPayloads += ((nodeId, hopPayload))
case FatError.ErrorSource =>
val failurePayload = codec.decode(unwrapedPacket.bits).require.value.failurePayload
FailureMessageCodecs.failureOnionPayload(payloadAndPadLength).decode(failurePayload.bits) match {
case Attempt.Successful(failureMessage) =>
return Right(DecryptedFailurePacket(nodeId, failureMessage.value))
case Attempt.Failure(_) =>
return Left(InvalidFatErrorPacket(hopPayloads.toSeq, nodeId))
}
}
}
}
Left(InvalidFatErrorPacket(hopPayloads.toSeq, sharedSecrets.last._2))
}
}

/**
* Route blinding is a lightweight technique to provide recipient anonymity by blinding an arbitrary amount of hops at
* the end of an onion path. It can be used for payments or onion messages.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ object FailureMessageCodecs {
}, (_: FailureMessage).code)
)

private def failureOnionPayload(payloadAndPadLength: Int): Codec[FailureMessage] = Codec(
def failureOnionPayload(payloadAndPadLength: Int): Codec[FailureMessage] = Codec(
encoder = f => variableSizeBytes(uint16, failureMessageCodec).encode(f).flatMap(bits => {
val payloadLength = bits.bytes.length - 2
val padLen = payloadAndPadLength - payloadLength
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright 2022 ACINQ SAS
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package fr.acinq.eclair.wire.protocol

import fr.acinq.bitcoin.scalacompat.ByteVector32
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import scodec.Codec
import scodec.bits.ByteVector
import scodec.codecs._

import scala.concurrent.duration.{DurationLong, FiniteDuration}


case class FatError(failurePayload: ByteVector, hopPayloads: Seq[ByteVector], hmacs: Seq[Seq[ByteVector32]])

object FatError {
// @formatter:off
sealed trait PayloadType
object IntermediateHop extends PayloadType
object ErrorSource extends PayloadType
// @formatter:on

def payloadTypeCodec: Codec[PayloadType] = mappedEnum(uint8, (IntermediateHop -> 0), (ErrorSource -> 1))

case class HopPayload(payloadType: PayloadType, holdTime: FiniteDuration)

def hopPayloadCodec: Codec[HopPayload] = (
("payload_type" | payloadTypeCodec) ::
("hold_time_ms" | uint64overflow.xmap[FiniteDuration](_.millis, _.toMillis))).as[HopPayload]

private def hmacsCodec(n: Int): Codec[Seq[Seq[ByteVector32]]] =
if (n == 0) {
provide(Nil)
}
else {
(listOfN(provide(n), bytes32).xmap[Seq[ByteVector32]](_.toSeq, _.toList) ::
hmacsCodec(n - 1)).as[(Seq[ByteVector32], Seq[Seq[ByteVector32]])]
.xmap(pair => pair._1 +: pair._2, seq => (seq.head, seq.tail))
}

def fatErrorCodec(payloadAndPadLength: Int = 256, hopPayloadLength: Int = 9, maxHop: Int = 27): Codec[FatError] = (
("failure_payload" | bytes(payloadAndPadLength + 4)) ::
("hop_payloads" | listOfN(provide(maxHop), bytes(hopPayloadLength)).xmap[Seq[ByteVector]](_.toSeq, _.toList)) ::
("hmacs" | hmacsCodec(maxHop))).as[FatError].complete
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ import fr.acinq.bitcoin.scalacompat.Crypto.{PrivateKey, PublicKey}
import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedRoute, BlindedRouteDetails}
import fr.acinq.eclair.wire.protocol
import fr.acinq.eclair.wire.protocol._
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, UInt64, randomKey}
import fr.acinq.eclair.{CltvExpiry, CltvExpiryDelta, MilliSatoshiLong, ShortChannelId, UInt64, randomBytes, randomKey}
import org.scalatest.funsuite.AnyFunSuite
import scodec.bits._

import scala.concurrent.duration.DurationInt
import scala.util.Success

/**
Expand Down Expand Up @@ -386,6 +387,63 @@ class SphinxSpec extends AnyFunSuite {
assert(failure == InvalidRealm)
}

test("decrypt fat error") {
val sharedSecrets = Seq(
hex"0101010101010101010101010101010101010101010101010101010101010101",
hex"0202020202020202020202020202020202020202020202020202020202020202",
hex"0303030303030303030303030303030303030303030303030303030303030303",
hex"0404040404040404040404040404040404040404040404040404040404040404",
hex"0505050505050505050505050505050505050505050505050505050505050505",
).map(ByteVector32(_))

val dummyHopPayload = FatError.HopPayload(FatError.IntermediateHop, 0 millis)

val expected = DecryptedFailurePacket(publicKeys(2), InvalidOnionKey(ByteVector32.One))

val packet1 = FatErrorPacket.create(sharedSecrets(2), expected.failureMessage)
assert(packet1.length == 12599)

val Right(decrypted1) = FatErrorPacket.decrypt(packet1, (2 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
assert(decrypted1 == expected)

val Success(packet2) = FatErrorPacket.wrap(packet1, sharedSecrets(1), dummyHopPayload)
assert(packet2.length == 12599)

val Right(decrypted2) = FatErrorPacket.decrypt(packet2, (1 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
assert(decrypted2 == expected)

val Success(packet3) = FatErrorPacket.wrap(packet2, sharedSecrets(0), dummyHopPayload)
assert(packet3.length == 12599)

val Right(decrypted3) = FatErrorPacket.decrypt(packet3, (0 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
assert(decrypted3 == expected)
}

test("decrypt fat error with random data") {
val sharedSecrets = Seq(
hex"0101010101010101010101010101010101010101010101010101010101010101",
hex"0202020202020202020202020202020202020202020202020202020202020202",
hex"0303030303030303030303030303030303030303030303030303030303030303",
hex"0404040404040404040404040404040404040404040404040404040404040404",
hex"0505050505050505050505050505050505050505050505050505050505050505",
).map(ByteVector32(_))

// publicKeys(2) creates an invalid random packet, or publicKeys(1) tries to shift blame by pretending to receive random data from publicKeys(2)
val packet1 = randomBytes(12599)

val hopPayload2 = FatError.HopPayload(FatError.IntermediateHop, 50 millis)
val Success(packet2) = FatErrorPacket.wrap(packet1, sharedSecrets(1), hopPayload2)
assert(packet2.length == 12599)

val hopPayload3 = FatError.HopPayload(FatError.IntermediateHop, 100 millis)
val Success(packet3) = FatErrorPacket.wrap(packet2, sharedSecrets(0), hopPayload3)
assert(packet3.length == 12599)

val Left(decryptionError) = FatErrorPacket.decrypt(packet3, (0 to 4).map(i => (sharedSecrets(i), publicKeys(i))))
val expected = InvalidFatErrorPacket(Seq((publicKeys(0), hopPayload3), (publicKeys(1), hopPayload2)), publicKeys(2))
assert(decryptionError == expected)
}

test("create blinded route (reference test vector)") {
val alice = PrivateKey(hex"4141414141414141414141414141414141414141414141414141414141414141")
val bob = PrivateKey(hex"4242424242424242424242424242424242424242424242424242424242424242")
Expand Down

0 comments on commit e7da96b

Please sign in to comment.