Skip to content

Commit

Permalink
Merge pull request #8683 from lightningnetwork/funding-tapcript
Browse files Browse the repository at this point in the history
[1/?]: multi: add ability to fund+use musig2 channels that commit to a tapscript root
  • Loading branch information
guggero authored Apr 30, 2024
2 parents 7fb2333 + 26ce8ee commit bb44793
Show file tree
Hide file tree
Showing 18 changed files with 438 additions and 150 deletions.
265 changes: 170 additions & 95 deletions channeldb/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,27 +225,138 @@ const (
// A tlv type definition used to serialize an outpoint's indexStatus
// for use in the outpoint index.
indexStatusType tlv.Type = 0
)

// A tlv type definition used to serialize and deserialize a KeyLocator
// from the database.
keyLocType tlv.Type = 1
// chanAuxData houses the auxiliary data that is stored for each channel in a
// TLV stream within the root bucket. This is stored as a TLV stream appended
// to the existing hard-coded fields in the channel's root bucket.
type chanAuxData struct {
// revokeKeyLoc is the key locator for the revocation key.
revokeKeyLoc tlv.RecordT[tlv.TlvType1, keyLocRecord]

// A tlv type used to serialize and deserialize the
// `InitialLocalBalance` field.
initialLocalBalanceType tlv.Type = 2
// initialLocalBalance is the initial local balance of the channel.
initialLocalBalance tlv.RecordT[tlv.TlvType2, uint64]

// A tlv type used to serialize and deserialize the
// `InitialRemoteBalance` field.
initialRemoteBalanceType tlv.Type = 3
// initialRemoteBalance is the initial remote balance of the channel.
initialRemoteBalance tlv.RecordT[tlv.TlvType3, uint64]

// A tlv type definition used to serialize and deserialize the
// confirmed ShortChannelID for a zero-conf channel.
realScidType tlv.Type = 4
// realScid is the real short channel ID of the channel corresponding to
// the on-chain outpoint.
realScid tlv.RecordT[tlv.TlvType4, lnwire.ShortChannelID]

// A tlv type definition used to serialize and deserialize the
// Memo for the channel channel.
channelMemoType tlv.Type = 5
)
// memo is an optional text field that gives context to the user about
// the channel.
memo tlv.OptionalRecordT[tlv.TlvType5, []byte]

// tapscriptRoot is the optional Tapscript root the channel funding
// output commits to.
tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte]
}

// encode serializes the chanAuxData to the given io.Writer.
func (c *chanAuxData) encode(w io.Writer) error {
tlvRecords := []tlv.Record{
c.revokeKeyLoc.Record(),
c.initialLocalBalance.Record(),
c.initialRemoteBalance.Record(),
c.realScid.Record(),
}
c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) {
tlvRecords = append(tlvRecords, memo.Record())
})
c.tapscriptRoot.WhenSome(
func(root tlv.RecordT[tlv.TlvType6, [32]byte]) {
tlvRecords = append(tlvRecords, root.Record())
},
)

// Create the tlv stream.
tlvStream, err := tlv.NewStream(tlvRecords...)
if err != nil {
return err
}

return tlvStream.Encode(w)
}

// decode deserializes the chanAuxData from the given io.Reader.
func (c *chanAuxData) decode(r io.Reader) error {
memo := c.memo.Zero()
tapscriptRoot := c.tapscriptRoot.Zero()

// Create the tlv stream.
tlvStream, err := tlv.NewStream(
c.revokeKeyLoc.Record(),
c.initialLocalBalance.Record(),
c.initialRemoteBalance.Record(),
c.realScid.Record(),
memo.Record(),
tapscriptRoot.Record(),
)
if err != nil {
return err
}

tlvs, err := tlvStream.DecodeWithParsedTypes(r)
if err != nil {
return err
}

if _, ok := tlvs[memo.TlvType()]; ok {
c.memo = tlv.SomeRecordT(memo)
}
if _, ok := tlvs[tapscriptRoot.TlvType()]; ok {
c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot)
}

return nil
}

// toOpeChan converts the chanAuxData to an OpenChannel by setting the relevant
// fields in the OpenChannel struct.
func (c *chanAuxData) toOpenChan(o *OpenChannel) {
o.RevocationKeyLocator = c.revokeKeyLoc.Val.KeyLocator
o.InitialLocalBalance = lnwire.MilliSatoshi(c.initialLocalBalance.Val)
o.InitialRemoteBalance = lnwire.MilliSatoshi(c.initialRemoteBalance.Val)
o.confirmedScid = c.realScid.Val
c.memo.WhenSomeV(func(memo []byte) {
o.Memo = memo
})
c.tapscriptRoot.WhenSomeV(func(h [32]byte) {
o.TapscriptRoot = fn.Some[chainhash.Hash](h)
})
}

// newChanAuxDataFromChan creates a new chanAuxData from the given channel.
func newChanAuxDataFromChan(openChan *OpenChannel) *chanAuxData {
c := &chanAuxData{
revokeKeyLoc: tlv.NewRecordT[tlv.TlvType1](
keyLocRecord{openChan.RevocationKeyLocator},
),
initialLocalBalance: tlv.NewPrimitiveRecord[tlv.TlvType2](
uint64(openChan.InitialLocalBalance),
),
initialRemoteBalance: tlv.NewPrimitiveRecord[tlv.TlvType3](
uint64(openChan.InitialRemoteBalance),
),
realScid: tlv.NewRecordT[tlv.TlvType4](
openChan.confirmedScid,
),
}

if len(openChan.Memo) != 0 {
c.memo = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType5](openChan.Memo),
)
}
openChan.TapscriptRoot.WhenSome(func(h chainhash.Hash) {
c.tapscriptRoot = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h),
)
})

return c
}

// indexStatus is an enum-like type that describes what state the
// outpoint is in. Currently only two possible values.
Expand Down Expand Up @@ -324,6 +435,11 @@ const (
// SimpleTaprootFeatureBit indicates that the simple-taproot-chans
// feature bit was negotiated during the lifetime of the channel.
SimpleTaprootFeatureBit ChannelType = 1 << 10

// TapscriptRootBit indicates that this is a MuSig2 channel with a top
// level tapscript commitment. This MUST be set along with the
// SimpleTaprootFeatureBit.
TapscriptRootBit ChannelType = 1 << 11
)

// IsSingleFunder returns true if the channel type if one of the known single
Expand Down Expand Up @@ -394,6 +510,12 @@ func (c ChannelType) IsTaproot() bool {
return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit
}

// HasTapscriptRoot returns true if the channel is using a top level tapscript
// root commitment.
func (c ChannelType) HasTapscriptRoot() bool {
return c&TapscriptRootBit == TapscriptRootBit
}

// ChannelConstraints represents a set of constraints meant to allow a node to
// limit their exposure, enact flow control and ensure that all HTLCs are
// economically relevant. This struct will be mirrored for both sides of the
Expand Down Expand Up @@ -856,6 +978,10 @@ type OpenChannel struct {
// channel that will be useful to our future selves.
Memo []byte

// TapscriptRoot is an optional tapscript root used to derive the MuSig2
// funding output.
TapscriptRoot fn.Option[chainhash.Hash]

// TODO(roasbeef): eww
Db *ChannelStateDB

Expand Down Expand Up @@ -4007,32 +4133,9 @@ func putChanInfo(chanBucket kvdb.RwBucket, channel *OpenChannel) error {
return err
}

// Convert balance fields into uint64.
localBalance := uint64(channel.InitialLocalBalance)
remoteBalance := uint64(channel.InitialRemoteBalance)

// Create the tlv stream.
tlvStream, err := tlv.NewStream(
// Write the RevocationKeyLocator as the first entry in a tlv
// stream.
MakeKeyLocRecord(
keyLocType, &channel.RevocationKeyLocator,
),
tlv.MakePrimitiveRecord(
initialLocalBalanceType, &localBalance,
),
tlv.MakePrimitiveRecord(
initialRemoteBalanceType, &remoteBalance,
),
MakeScidRecord(realScidType, &channel.confirmedScid),
tlv.MakePrimitiveRecord(channelMemoType, &channel.Memo),
)
if err != nil {
return err
}

if err := tlvStream.Encode(&w); err != nil {
return err
auxData := newChanAuxDataFromChan(channel)
if err := auxData.encode(&w); err != nil {
return fmt.Errorf("unable to encode aux data: %w", err)
}

if err := chanBucket.Put(chanInfoKey, w.Bytes()); err != nil {
Expand Down Expand Up @@ -4221,45 +4324,14 @@ func fetchChanInfo(chanBucket kvdb.RBucket, channel *OpenChannel) error {
}
}

// Create balance fields in uint64, and Memo field as byte slice.
var (
localBalance uint64
remoteBalance uint64
memo []byte
)

// Create the tlv stream.
tlvStream, err := tlv.NewStream(
// Write the RevocationKeyLocator as the first entry in a tlv
// stream.
MakeKeyLocRecord(
keyLocType, &channel.RevocationKeyLocator,
),
tlv.MakePrimitiveRecord(
initialLocalBalanceType, &localBalance,
),
tlv.MakePrimitiveRecord(
initialRemoteBalanceType, &remoteBalance,
),
MakeScidRecord(realScidType, &channel.confirmedScid),
tlv.MakePrimitiveRecord(channelMemoType, &memo),
)
if err != nil {
return err
var auxData chanAuxData
if err := auxData.decode(r); err != nil {
return fmt.Errorf("unable to decode aux data: %w", err)
}

if err := tlvStream.Decode(r); err != nil {
return err
}

// Attach the balance fields.
channel.InitialLocalBalance = lnwire.MilliSatoshi(localBalance)
channel.InitialRemoteBalance = lnwire.MilliSatoshi(remoteBalance)

// Attach the memo field if non-empty.
if len(memo) > 0 {
channel.Memo = memo
}
// Assign all the relevant fields from the aux data into the actual
// open channel.
auxData.toOpenChan(channel)

channel.Packager = NewChannelPackager(channel.ShortChannelID)

Expand Down Expand Up @@ -4417,6 +4489,25 @@ func deleteThawHeight(chanBucket kvdb.RwBucket) error {
return chanBucket.Delete(frozenChanKey)
}

// keyLocRecord is a wrapper struct around keychain.KeyLocator to implement the
// tlv.RecordProducer interface.
type keyLocRecord struct {
keychain.KeyLocator
}

// Record creates a Record out of a KeyLocator using the passed Type and the
// EKeyLocator and DKeyLocator functions. The size will always be 8 as
// KeyFamily is uint32 and the Index is uint32.
//
// NOTE: This is part of the tlv.RecordProducer interface.
func (k *keyLocRecord) Record() tlv.Record {
// Note that we set the type here as zero, as when used with a
// tlv.RecordT, the type param will be used as the type.
return tlv.MakeStaticRecord(
0, &k.KeyLocator, 8, EKeyLocator, DKeyLocator,
)
}

// EKeyLocator is an encoder for keychain.KeyLocator.
func EKeyLocator(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*keychain.KeyLocator); ok {
Expand Down Expand Up @@ -4445,22 +4536,6 @@ func DKeyLocator(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
return tlv.NewTypeForDecodingErr(val, "keychain.KeyLocator", l, 8)
}

// MakeKeyLocRecord creates a Record out of a KeyLocator using the passed
// Type and the EKeyLocator and DKeyLocator functions. The size will always be
// 8 as KeyFamily is uint32 and the Index is uint32.
func MakeKeyLocRecord(typ tlv.Type, keyLoc *keychain.KeyLocator) tlv.Record {
return tlv.MakeStaticRecord(typ, keyLoc, 8, EKeyLocator, DKeyLocator)
}

// MakeScidRecord creates a Record out of a ShortChannelID using the passed
// Type and the EShortChannelID and DShortChannelID functions. The size will
// always be 8 for the ShortChannelID.
func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record {
return tlv.MakeStaticRecord(
typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID,
)
}

// ShutdownInfo contains various info about the shutdown initiation of a
// channel.
type ShutdownInfo struct {
Expand Down
8 changes: 7 additions & 1 deletion channeldb/channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnmock"
Expand Down Expand Up @@ -172,7 +173,7 @@ func fundingPointOption(chanPoint wire.OutPoint) testChannelOption {
}

// channelIDOption is an option which sets the short channel ID of the channel.
var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption {
func channelIDOption(chanID lnwire.ShortChannelID) testChannelOption {
return func(params *testChannelParams) {
params.channel.ShortChannelID = chanID
}
Expand Down Expand Up @@ -312,6 +313,9 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
uniqueOutputIndex.Add(1)
op := wire.OutPoint{Hash: key, Index: uniqueOutputIndex.Load()}

var tapscriptRoot chainhash.Hash
copy(tapscriptRoot[:], bytes.Repeat([]byte{1}, 32))

return &OpenChannel{
ChanType: SingleFunderBit | FrozenBit,
ChainHash: key,
Expand Down Expand Up @@ -354,6 +358,8 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
ThawHeight: uint32(defaultPendingHeight),
InitialLocalBalance: lnwire.MilliSatoshi(9000),
InitialRemoteBalance: lnwire.MilliSatoshi(3000),
Memo: []byte("test"),
TapscriptRoot: fn.Some(tapscriptRoot),
}
}

Expand Down
6 changes: 5 additions & 1 deletion contractcourt/chain_watcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet"
)
Expand Down Expand Up @@ -301,8 +302,11 @@ func (c *chainWatcher) Start() error {
err error
)
if chanState.ChanType.IsTaproot() {
fundingOpts := fn.MapOptionZ(
chanState.TapscriptRoot, lnwallet.TapscriptRootToOpt,
)
c.fundingPkScript, _, err = input.GenTaprootFundingScript(
localKey, remoteKey, 0,
localKey, remoteKey, 0, fundingOpts...,
)
if err != nil {
return err
Expand Down
Loading

0 comments on commit bb44793

Please sign in to comment.