From ca2a9e1d4cb6d32119b6840ef0eff8c70be870b9 Mon Sep 17 00:00:00 2001 From: David Terpay <35130517+davidterpay@users.noreply.github.com> Date: Wed, 24 Jul 2024 15:34:11 -0400 Subject: [PATCH] feat: Refresh prices on heartbeats (#622) --- oracle/types/provider.go | 26 +++++++++++++++++ providers/base/fetch.go | 10 +++++++ providers/base/provider_test.go | 29 +++++++++++++++++++ providers/websockets/huobi/ws_data_handler.go | 6 +++- .../websockets/huobi/ws_data_handler_test.go | 16 ++++++++-- .../websockets/kucoin/ws_data_handler.go | 2 +- .../websockets/kucoin/ws_data_handler_test.go | 18 ++++++++++-- 7 files changed, 100 insertions(+), 7 deletions(-) diff --git a/oracle/types/provider.go b/oracle/types/provider.go index ecc1ecb29..bd0c4f4dc 100644 --- a/oracle/types/provider.go +++ b/oracle/types/provider.go @@ -2,8 +2,12 @@ package types import ( "fmt" + "math/big" "strings" "sync" + "time" + + providertypes "github.com/skip-mev/slinky/providers/types" ) type ( @@ -62,7 +66,9 @@ func (t DefaultProviderTicker) String() string { func NewProviderTickers(tickers ...ProviderTicker) ProviderTickers { cache := make(map[string]ProviderTicker) for _, ticker := range tickers { + cache[strings.ToLower(ticker.GetOffChainTicker())] = ticker cache[ticker.GetOffChainTicker()] = ticker + cache[strings.ToUpper(ticker.GetOffChainTicker())] = ticker } return ProviderTickers{ cache: cache, @@ -87,3 +93,23 @@ func (t *ProviderTickers) Add(ticker ProviderTicker) { t.cache[ticker.GetOffChainTicker()] = ticker t.cache[strings.ToUpper(ticker.GetOffChainTicker())] = ticker } + +// NoPriceChangeResponse is used to handle a message that indicates that the price has not changed. +// In particular, this will update the base provider with the ResponseCodeUnchanged code for all tickers. +func (t *ProviderTickers) NoPriceChangeResponse() PriceResponse { + resolved := make(ResolvedPrices) + seen := make(map[ProviderTicker]struct{}) + for _, ticker := range t.cache { + if _, ok := seen[ticker]; ok { + continue + } + + resolved[ticker] = NewPriceResultWithCode( + big.NewFloat(0), + time.Now().UTC(), + providertypes.ResponseCodeUnchanged, + ) + seen[ticker] = struct{}{} + } + return NewPriceResponse(resolved, nil) +} diff --git a/providers/base/fetch.go b/providers/base/fetch.go index 0d77e326e..2cc577755 100644 --- a/providers/base/fetch.go +++ b/providers/base/fetch.go @@ -211,6 +211,16 @@ func (p *Provider[K, V]) updateData(id K, result providertypes.ResolvedResult[V] current, ok := p.data[id] if !ok { + // Deal with the case where we have no received any updates but may have received a heartbeat. + if result.ResponseCode == providertypes.ResponseCodeUnchanged { + p.logger.Debug( + "result is unchanged but no current data", + zap.String("id", fmt.Sprint(id)), + zap.String("result", result.String()), + ) + return + } + p.data[id] = result return } diff --git a/providers/base/provider_test.go b/providers/base/provider_test.go index 3debedc64..fd3b9bec8 100644 --- a/providers/base/provider_test.go +++ b/providers/base/provider_test.go @@ -595,6 +595,35 @@ func TestWebSocketProvider(t *testing.T) { pairs[0]: big.NewInt(100), }, }, + { + name: "does not update the base provider if the result is unchanged but the cache has no entry for the id", + handler: func() wshandlers.WebSocketQueryHandler[slinkytypes.CurrencyPair, *big.Int] { + // First response is valid and sets the data. + resolved := map[slinkytypes.CurrencyPair]providertypes.ResolvedResult[*big.Int]{ + pairs[0]: { + Value: big.NewInt(100), + Timestamp: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + ResponseCode: providertypes.ResponseCodeUnchanged, + }, + } + + responses := []providertypes.GetResponse[slinkytypes.CurrencyPair, *big.Int]{ + providertypes.NewGetResponse(resolved, nil), + } + + return testutils.CreateWebSocketQueryHandlerWithGetResponses( + t, + time.Second, + logger, + responses, + ) + }, + pairs: []slinkytypes.CurrencyPair{ + pairs[0], + }, + cfg: wsCfg, + expectedPrices: map[slinkytypes.CurrencyPair]*big.Int{}, + }, } for _, tc := range testCases { diff --git a/providers/websockets/huobi/ws_data_handler.go b/providers/websockets/huobi/ws_data_handler.go index 68b96a327..271a74b21 100644 --- a/providers/websockets/huobi/ws_data_handler.go +++ b/providers/websockets/huobi/ws_data_handler.go @@ -89,7 +89,11 @@ func (h *WebSocketHandler) HandleMessage( if err := json.Unmarshal(uncompressed.Bytes(), &pingMessage); err == nil && pingMessage.Ping != 0 { h.logger.Debug("received ping message") updateMessage, err := NewPongMessage(pingMessage) - return resp, updateMessage, err + + // The receipt of a ping message means that the connection is still alive and that all market's corresponding + // to the tickers subscribed to are still being tracked. Therefore, the response can include a message to let + // the provider know that market prices are still valid. + return h.cache.NoPriceChangeResponse(), updateMessage, err } // attempt to unmarshal to subscription response message and check if field values are not nil diff --git a/providers/websockets/huobi/ws_data_handler_test.go b/providers/websockets/huobi/ws_data_handler_test.go index b96ccd676..c12c9d6f5 100644 --- a/providers/websockets/huobi/ws_data_handler_test.go +++ b/providers/websockets/huobi/ws_data_handler_test.go @@ -6,6 +6,8 @@ import ( "math/big" "testing" + providertypes "github.com/skip-mev/slinky/providers/types" + "github.com/klauspost/compress/gzip" "github.com/stretchr/testify/require" @@ -26,7 +28,7 @@ var ( logger = zap.NewExample() ) -func TestHandlerMessage(t *testing.T) { +func TestHandleMessage(t *testing.T) { testCases := []struct { name string msg func() []byte @@ -205,7 +207,16 @@ func TestHandlerMessage(t *testing.T) { return buf.Bytes() }, resp: types.NewPriceResponse( - types.ResolvedPrices{}, + types.ResolvedPrices{ + btcusdt: { + Value: big.NewFloat(0), + ResponseCode: providertypes.ResponseCodeUnchanged, + }, + ethusdt: { + Value: big.NewFloat(0), + ResponseCode: providertypes.ResponseCodeUnchanged, + }, + }, types.UnResolvedPrices{}, ), updateMessage: func() []handlers.WebsocketEncodedMessage { @@ -267,6 +278,7 @@ func TestHandlerMessage(t *testing.T) { for cp, result := range tc.resp.Resolved { require.Contains(t, resp.Resolved, cp) require.Equal(t, result.Value, resp.Resolved[cp].Value) + require.Equal(t, result.ResponseCode, resp.Resolved[cp].ResponseCode) } for cp := range tc.resp.UnResolved { diff --git a/providers/websockets/kucoin/ws_data_handler.go b/providers/websockets/kucoin/ws_data_handler.go index 0f15e02e2..fe066c182 100644 --- a/providers/websockets/kucoin/ws_data_handler.go +++ b/providers/websockets/kucoin/ws_data_handler.go @@ -81,7 +81,7 @@ func (h *WebSocketHandler) HandleMessage( return resp, nil, nil case PongMessage: h.logger.Debug("received pong message") - return resp, nil, nil + return h.cache.NoPriceChangeResponse(), nil, nil case AckMessage: h.logger.Debug("received ack message; markets were successfully subscribed to") return resp, nil, nil diff --git a/providers/websockets/kucoin/ws_data_handler_test.go b/providers/websockets/kucoin/ws_data_handler_test.go index abaebc2bb..d6e92a513 100644 --- a/providers/websockets/kucoin/ws_data_handler_test.go +++ b/providers/websockets/kucoin/ws_data_handler_test.go @@ -68,7 +68,18 @@ func TestHandleMessage(t *testing.T) { "type": "pong" }`) }, - resp: types.PriceResponse{}, + resp: types.PriceResponse{ + Resolved: types.ResolvedPrices{ + btcusdt: { + Value: big.NewFloat(0), + ResponseCode: providertypes.ResponseCodeUnchanged, + }, + ethusdt: { + Value: big.NewFloat(0), + ResponseCode: providertypes.ResponseCodeUnchanged, + }, + }, + }, updateMsg: func() []handlers.WebsocketEncodedMessage { return nil }, expectedErr: false, }, @@ -292,14 +303,15 @@ func TestHandleMessage(t *testing.T) { require.NoError(t, err) // The response should contain a single resolved price update. - require.LessOrEqual(t, len(resp.Resolved), 1) - require.LessOrEqual(t, len(resp.UnResolved), 1) + require.LessOrEqual(t, len(resp.Resolved), 2) + require.LessOrEqual(t, len(resp.UnResolved), 2) require.Equal(t, tc.updateMsg(), updateMsg) for cp, result := range tc.resp.Resolved { require.Contains(t, resp.Resolved, cp) require.Equal(t, result.Value.SetPrec(18), resp.Resolved[cp].Value.SetPrec(18)) + require.Equal(t, result.ResponseCode, resp.Resolved[cp].ResponseCode) } for cp := range tc.resp.UnResolved {