Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track HTLCs in rfq policies #1186

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions chain_bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/lightninglabs/taproot-assets/tapgarden"
"github.com/lightningnetwork/lnd/chainntnfs"
"github.com/lightningnetwork/lnd/funding"
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
Expand Down Expand Up @@ -376,9 +377,17 @@ func (l *LndRouterClient) DeleteLocalAlias(ctx context.Context, alias,
return l.lnd.Router.XDeleteLocalChanAlias(ctx, alias, baseScid)
}

func (l *LndRouterClient) SubscribeHtlcEvents(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: requires Godoc comment (for the same method in the HtlcSubscriber interface as well).

ctx context.Context) (<-chan *routerrpc.HtlcEvent,
<-chan error, error) {

return l.lnd.Router.SubscribeHtlcEvents(ctx)
}

// Ensure LndRouterClient implements the rfq.HtlcInterceptor interface.
var _ rfq.HtlcInterceptor = (*LndRouterClient)(nil)
var _ rfq.ScidAliasManager = (*LndRouterClient)(nil)
var _ rfq.HtlcSubscriber = (*LndRouterClient)(nil)

// LndInvoicesClient is an LND invoices RPC client.
type LndInvoicesClient struct {
Expand Down
5 changes: 5 additions & 0 deletions rfq/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ type ManagerCfg struct {
// intercept and accept/reject HTLCs.
HtlcInterceptor HtlcInterceptor

// HtlcSubscriber is a subscriber that is used to retrieve live HTLC
// event updates.
HtlcSubscriber HtlcSubscriber

// PriceOracle is the price oracle that the RFQ manager will use to
// determine whether a quote is accepted or rejected.
PriceOracle PriceOracle
Expand Down Expand Up @@ -207,6 +211,7 @@ func (m *Manager) startSubsystems(ctx context.Context) error {
m.orderHandler, err = NewOrderHandler(OrderHandlerCfg{
CleanupInterval: CacheCleanupInterval,
HtlcInterceptor: m.cfg.HtlcInterceptor,
HtlcSubscriber: m.cfg.HtlcSubscriber,
AcceptHtlcEvents: m.acceptHtlcEvents,
})
if err != nil {
Expand Down
201 changes: 197 additions & 4 deletions rfq/order.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/lightninglabs/taproot-assets/rfqmath"
"github.com/lightninglabs/taproot-assets/rfqmsg"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnrpc/routerrpc"
"github.com/lightningnetwork/lnd/lnutils"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire"
Expand Down Expand Up @@ -70,6 +71,14 @@ type Policy interface {
// which the policy applies.
Scid() uint64

// TrackAcceptedHtlc makes the policy aware of this new accepted HTLC.
// This is important in cases where the set of existing HTLCs may affect
// whether the next compliance check passes.
TrackAcceptedHtlc(htlc lndclient.InterceptedHtlc)

// UntrackHtlc stops tracking the uniquely identified htlc.
UntrackHtlc(htlcID string)

// GenerateInterceptorResponse generates an interceptor response for the
// HTLC interceptor from the policy.
GenerateInterceptorResponse(
Expand All @@ -94,9 +103,17 @@ type AssetSalePolicy struct {
// the policy.
MaxOutboundAssetAmount uint64

// CurrentAssetAmountMsat is the total amount that is held currently in
// accepted htlcs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: use the same capitalization for HTLC everywhere? So s/htlcs/HTLCs here and throughout the file.

CurrentAmountMsat lnwire.MilliSatoshi

// AskAssetRate is the quote's asking asset unit to BTC conversion rate.
AskAssetRate rfqmath.BigIntFixedPoint

// htlcToAmt maps the unique htlc identifiers to the effective amount
// that they carry.
htlcToAmt lnutils.SyncMap[string, lnwire.MilliSatoshi]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So both the htlcToAmt and CurrentAmountMsat need to be in sync. If we use an atomic value for one and a sync map for the other, then it's still possible for a read from another goroutine to get a different value while we're updating things (e.g. the map was updated but the total not yet).

So I think because we have two separate values, we need a sync.RWMutex instead. But then we can use a normal map here at least.


// expiry is the policy's expiry unix timestamp after which the policy
// is no longer valid.
expiry uint64
Expand Down Expand Up @@ -151,7 +168,8 @@ func (c *AssetSalePolicy) CheckHtlcCompliance(
maxAssetAmount, c.AskAssetRate,
)

if htlc.AmountOutMsat > policyMaxOutMsat {
if (c.CurrentAmountMsat + htlc.AmountOutMsat) > policyMaxOutMsat {
// if htlc.AmountOutMsat > policyMaxOutMsat {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: commented out code.

return fmt.Errorf("htlc out amount is greater than the policy "+
"maximum (htlc_out_msat=%d, policy_max_out_msat=%d)",
htlc.AmountOutMsat, policyMaxOutMsat)
Expand All @@ -166,6 +184,29 @@ func (c *AssetSalePolicy) CheckHtlcCompliance(
return nil
}

// TrackAcceptedHtlc accounts for the newly accepted htlc. This may affect the
// acceptance of future htlcs.
func (c *AssetSalePolicy) TrackAcceptedHtlc(htlc lndclient.InterceptedHtlc) {
c.CurrentAmountMsat += htlc.AmountOutMsat
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if these need to be atomic values? Since multiple goroutines could manipulate them concurrently.


htlcIDStr := htlcIdentifierStr(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use structs as map keys directly. And as long as there are no pointer values in a struct (meaning the struct can be directly compared with == without the pointers leading to side effects), things work out as expected. So we can just use the circuit key directly as the map key, no need to convert to string.

htlc.IncomingCircuitKey.ChanID.ToUint64(),
htlc.IncomingCircuitKey.HtlcID,
)

c.htlcToAmt.Store(htlcIDStr, htlc.AmountOutMsat)
}

// UntrackHtlc stops tracking the uniquely identified htlc.
func (c *AssetSalePolicy) UntrackHtlc(htlcIDStr string) {
amt, found := c.htlcToAmt.LoadAndDelete(htlcIDStr)
if !found {
return
}

c.CurrentAmountMsat -= amt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we definitely need a mutex here. Otherwise another goroutine might read this value while we're waiting on the LoadAndDelete above to complete.

}

// Expiry returns the policy's expiry time as a unix timestamp.
func (c *AssetSalePolicy) Expiry() uint64 {
return c.expiry
Expand Down Expand Up @@ -245,12 +286,20 @@ type AssetPurchasePolicy struct {
// AcceptedQuoteId is the ID of the accepted quote.
AcceptedQuoteId rfqmsg.ID

// CurrentAssetAmountMsat is the total amount that is held currently in
// accepted htlcs.
CurrentAmountMsat lnwire.MilliSatoshi

// BidAssetRate is the quote's asset to BTC conversion rate.
BidAssetRate rfqmath.BigIntFixedPoint

// PaymentMaxAmt is the maximum agreed BTC payment.
PaymentMaxAmt lnwire.MilliSatoshi

// htlcToAmt maps the unique htlc identifiers to the effective amount
// that they carry.
htlcToAmt lnutils.SyncMap[string, lnwire.MilliSatoshi]

// expiry is the policy's expiry unix timestamp in seconds after which
// the policy is no longer valid.
expiry uint64
Expand Down Expand Up @@ -321,7 +370,7 @@ func (c *AssetPurchasePolicy) CheckHtlcCompliance(

// Ensure that the outbound HTLC amount is less than the maximum agreed
// BTC payment.
if htlc.AmountOutMsat > c.PaymentMaxAmt {
if (c.CurrentAmountMsat + htlc.AmountOutMsat) > c.PaymentMaxAmt {
GeorgeTsagk marked this conversation as resolved.
Show resolved Hide resolved
return fmt.Errorf("htlc out amount is more than the maximum "+
"agreed BTC payment (htlc_out_msat=%d, "+
"payment_max_amt=%d)", htlc.AmountOutMsat,
Expand All @@ -337,6 +386,31 @@ func (c *AssetPurchasePolicy) CheckHtlcCompliance(
return nil
}

// TrackAcceptedHtlc accounts for the newly accepted htlc. This may affect the
// acceptance of future htlcs.
func (c *AssetPurchasePolicy) TrackAcceptedHtlc(
htlc lndclient.InterceptedHtlc) {

c.CurrentAmountMsat += htlc.AmountOutMsat

htlcIDStr := htlcIdentifierStr(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can probably simplify the interface of this TrackAcceptedHtlc a bit by just passing in circuitKey and amount. Then it's IMO a bit more apparent what's happening (looks more like a map add and delete) and the policy doesn't need to know how the key is calculated.

htlc.IncomingCircuitKey.ChanID.ToUint64(),
htlc.IncomingCircuitKey.HtlcID,
)

c.htlcToAmt.Store(htlcIDStr, htlc.AmountOutMsat)
}

// UntrackHtlc stops tracking the uniquely identified htlc.
func (c *AssetPurchasePolicy) UntrackHtlc(htlcIDStr string) {
amt, found := c.htlcToAmt.LoadAndDelete(htlcIDStr)
if !found {
return
}

c.CurrentAmountMsat -= amt
}

// Expiry returns the policy's expiry time as a unix timestamp in seconds.
func (c *AssetPurchasePolicy) Expiry() uint64 {
return c.expiry
Expand Down Expand Up @@ -435,6 +509,25 @@ func (a *AssetForwardPolicy) CheckHtlcCompliance(
return nil
}

// TrackAcceptedHtlc accounts for the newly accepted htlc. This may affect the
// acceptance of future htlcs.
func (a *AssetForwardPolicy) TrackAcceptedHtlc(htlc lndclient.InterceptedHtlc) {
// Track accepted htlc in the incoming policy.
a.incomingPolicy.TrackAcceptedHtlc(htlc)

// Track accepted htlc in the outgoing policy.
a.outgoingPolicy.TrackAcceptedHtlc(htlc)
}

// UntrackHtlc stops tracking the uniquely identified htlc.
func (a *AssetForwardPolicy) UntrackHtlc(htlcIDStr string) {
// Untrack htlc in the incoming policy.
a.incomingPolicy.UntrackHtlc(htlcIDStr)

// Untrack htlc in the outgoing policy.
a.outgoingPolicy.UntrackHtlc(htlcIDStr)
}

// Expiry returns the policy's expiry time as a unix timestamp in seconds. The
// returned expiry time is the earliest expiry time of the incoming and outgoing
// policies.
Expand Down Expand Up @@ -513,6 +606,10 @@ type OrderHandlerCfg struct {

// AcceptHtlcEvents is a channel that receives accepted HTLCs.
AcceptHtlcEvents chan<- *AcceptHtlcEvent

// HtlcSubscriber is a subscriber that is used to retrieve live HTLC
// event updates.
HtlcSubscriber HtlcSubscriber
}

// OrderHandler orchestrates management of accepted quote bundles. It monitors
Expand All @@ -529,6 +626,12 @@ type OrderHandler struct {
// associated asset transaction policies.
policies lnutils.SyncMap[SerialisedScid, Policy]

// htlcToPolicy maps a unique htlc identifier encoded as a string, to
// the policy that applies to it. We need this map because for failed
// HTLCs we don't have the RFQ data available, so we need to cache this
// info.
htlcToPolicy lnutils.SyncMap[string, Policy]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, can use the circuit key as the map key directly. Here a sync map is enough and we don't need a mutex, as it's just a single map that needs to be synced.


// ContextGuard provides a wait group and main quit channel that can be
// used to create guarded contexts.
*fn.ContextGuard
Expand Down Expand Up @@ -592,6 +695,17 @@ func (h *OrderHandler) handleIncomingHtlc(_ context.Context,
}, nil
}

htlcIDStr := htlcIdentifierStr(
htlc.IncomingCircuitKey.ChanID.ToUint64(),
htlc.IncomingCircuitKey.HtlcID,
)

h.htlcToPolicy.Store(htlcIDStr, policy)

// The htlc passed the compliance checks, so now we keep track of the
// accepted htlc.
policy.TrackAcceptedHtlc(htlc)

log.Debug("HTLC complies with policy. Broadcasting accept event.")
h.cfg.AcceptHtlcEvents <- NewAcceptHtlcEvent(htlc, policy)

Expand Down Expand Up @@ -639,12 +753,64 @@ func (h *OrderHandler) mainEventLoop() {
}
}

// subscribeHtlcs subscribes the OrderHandler to HTLC events provided by the lnd
// RPC interface. We use this subscription to track HTLC forwarding failures,
// which we use to performn a live update of our policies.
func (h *OrderHandler) subscribeHtlcs(ctx context.Context) error {
events, chErr, err := h.cfg.HtlcSubscriber.SubscribeHtlcEvents(ctx)
if err != nil {
return err
}

for {
select {
case event := <-events:
// We only care about forwarding events.
if event.GetEventType() != routerrpc.HtlcEvent_FORWARD {
continue
}

// Retrieve the two instances that may be relevant.
failEvent := event.GetForwardFailEvent()
linkFail := event.GetLinkFailEvent()

// Craft the string representation of the unique htlc
// identifier. This is later on used to map to an rfq
// policy.
htlcIDStr := htlcIdentifierStr(
event.IncomingChannelId, event.IncomingHtlcId,
)

switch {
case failEvent != nil:
fallthrough
case linkFail != nil:
// Fetch the policy that is related to this
// htlc.
policy, found :=
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: format as:

				policy, found := h.htlcToPolicy.LoadAndDelete(
					htlcIDStr,
				)

instead.

h.htlcToPolicy.LoadAndDelete(htlcIDStr)

if !found {
continue
}

// Stop tracking this htlc as it failed.
policy.UntrackHtlc(htlcIDStr)
}

case err := <-chErr:
return err

case <-ctx.Done():
return ctx.Err()
}
}
}

// Start starts the service.
func (h *OrderHandler) Start() error {
var startErr error
h.startOnce.Do(func() {
log.Info("Starting subsystem: order handler")

// Start the main event loop in a separate goroutine.
h.Wg.Add(1)
go func() {
Expand All @@ -662,6 +828,20 @@ func (h *OrderHandler) Start() error {

h.mainEventLoop()
}()

// Start the HTLC event subscription loop.
h.Wg.Add(1)
go func() {
defer h.Wg.Done()

ctx, cancel := h.WithCtxQuitNoTimeout()
defer cancel()

err := h.subscribeHtlcs(ctx)
if err != nil {
log.Errorf("htlc subscriber error: %v", err)
}
}()
})

return startErr
Expand Down Expand Up @@ -843,3 +1023,16 @@ type HtlcInterceptor interface {
// to respond to HTLCs.
InterceptHtlcs(context.Context, lndclient.HtlcInterceptHandler) error
}

// HtlcSubscriber is an interface that contains the function necessary for
// retrieving live HTLC event updates.
type HtlcSubscriber interface {
SubscribeHtlcEvents(ctx context.Context) (<-chan *routerrpc.HtlcEvent,
<-chan error, error)
}

// htlcIdentifierStr is a deterministic method that blends the chanID and htlcID
// of an in-flight HTLC to create a string that uniquely identifies it.
func htlcIdentifierStr(chanID, htlcID uint64) string {
return fmt.Sprintf("%v:%v", chanID, htlcID)
}
GeorgeTsagk marked this conversation as resolved.
Show resolved Hide resolved
21 changes: 21 additions & 0 deletions rpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -7034,6 +7034,27 @@ func (r *rpcServer) SendPayment(req *tchrpc.SendPaymentRequest,
case len(firstHopRecords) > 0:
// Continue below.

case req.Scid != 0:
var quote rfqmsg.SellAccept
for id, q := range r.cfg.RfqManager.PeerAcceptedSellQuotes() {
if id == rfqmsg.SerialisedScid(req.Scid) {
quote = q
break
}
}

htlc := rfqmsg.NewHtlc(nil, fn.Some(quote.ID))

// We'll now map the HTLC struct into a set of TLV records,
// which we can then encode into the expected map format.
htlcMapRecords, err := tlv.RecordsToMap(htlc.Records())
if err != nil {
return fmt.Errorf("unable to encode records as map: %w",
err)
}

pReq.FirstHopCustomRecords = htlcMapRecords

// The request wants to pay a specific invoice.
case pReq.PaymentRequest != "":
invoice, err := zpay32.Decode(
Expand Down
1 change: 1 addition & 0 deletions tapcfg/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ func genServerConfig(cfg *Config, cfgLogger btclog.Logger,
rfq.ManagerCfg{
PeerMessenger: msgTransportClient,
HtlcInterceptor: lndRouterClient,
HtlcSubscriber: lndRouterClient,
PriceOracle: priceOracle,
ChannelLister: walletAnchor,
AliasManager: lndRouterClient,
Expand Down
Loading