Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Eric-Warehime committed Feb 26, 2024
1 parent d8e5e90 commit e0a8c37
Show file tree
Hide file tree
Showing 8 changed files with 579 additions and 23 deletions.
2 changes: 1 addition & 1 deletion protocol/daemons/slinky/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func (c *Client) RunMarketPairFetcher(ctx context.Context, appFlags appflags.Fla
for {
select {
case <-ticker.C:
err := c.marketPairFetcher.FetchIdMappings(ctx)
err = c.marketPairFetcher.FetchIdMappings(ctx)
if err != nil {
c.logger.Error("Failed to run fetch id mappings for slinky daemon", "error", err)
c.ReportFailure(errors.Wrap(err, "failed to run FetchIdMappings for slinky daemon"))
Expand Down
123 changes: 123 additions & 0 deletions protocol/daemons/slinky/client/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package client_test

import (
"context"
"net"
"sync"
"testing"
"time"

"cosmossdk.io/log"
"github.com/skip-mev/slinky/service/servers/oracle/types"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"

appflags "github.com/dydxprotocol/v4-chain/protocol/app/flags"
daemonflags "github.com/dydxprotocol/v4-chain/protocol/daemons/flags"
daemonserver "github.com/dydxprotocol/v4-chain/protocol/daemons/server"
pricefeed_types "github.com/dydxprotocol/v4-chain/protocol/daemons/server/types/pricefeed"
"github.com/dydxprotocol/v4-chain/protocol/daemons/slinky/client"
daemontypes "github.com/dydxprotocol/v4-chain/protocol/daemons/types"
"github.com/dydxprotocol/v4-chain/protocol/mocks"
"github.com/dydxprotocol/v4-chain/protocol/testutil/appoptions"
pricetypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types"
)

func TestClientTestSuite(t *testing.T) {
suite.Run(t, &ClientTestSuite{})
}

type ClientTestSuite struct {
suite.Suite
daemonFlags daemonflags.DaemonFlags
appFlags appflags.Flags
daemonServer *daemonserver.Server
exchangePriceCache *pricefeed_types.MarketToExchangePrices
grpcServer *grpc.Server
pricesMockQueryServer *mocks.QueryServer
wg sync.WaitGroup
}

func (c *ClientTestSuite) SetupTest() {
// Setup daemon and grpc servers.
c.daemonFlags = daemonflags.GetDefaultDaemonFlags()
c.appFlags = appflags.GetFlagValuesFromOptions(appoptions.GetDefaultTestAppOptions("", nil))
c.grpcServer = grpc.NewServer()

// Configure and run daemon server.
c.daemonServer = daemonserver.NewServer(
log.NewNopLogger(),
c.grpcServer,
&daemontypes.FileHandlerImpl{},
c.daemonFlags.Shared.SocketAddress,
)

c.pricesMockQueryServer = &mocks.QueryServer{}
pricetypes.RegisterQueryServer(c.grpcServer, c.pricesMockQueryServer)
c.daemonServer.WithPriceFeedMarketToExchangePrices(
pricefeed_types.NewMarketToExchangePrices(5 * time.Second),
)

c.wg.Add(1)
go func() {
defer c.wg.Done()
c.daemonServer.Start()
}()

c.wg.Add(1)
go func() {
defer c.wg.Done()
ls, err := net.Listen("tcp", c.appFlags.GrpcAddress)
c.Require().NoError(err)
_ = c.grpcServer.Serve(ls)
}()
}

func (c *ClientTestSuite) TearDownTest() {
c.daemonServer.Stop()
c.grpcServer.Stop()
c.wg.Wait()
}

func (c *ClientTestSuite) TestClient() {
var cli *client.Client
slinky := mocks.NewOracleClient(c.T())
logger := log.NewTestLogger(c.T())

c.pricesMockQueryServer.On("AllMarketParams", mock.Anything, mock.Anything).
Return(
&pricetypes.QueryAllMarketParamsResponse{
MarketParams: []pricetypes.MarketParam{
{Id: 0, Pair: "FOO-BAR"},
{Id: 1, Pair: "BAR-FOO"},
}},
nil,
)

c.Run("services are all started and call their deps", func() {
slinky.On("Stop").Return(nil)
slinky.On("Start", mock.Anything).Return(nil).Once()
slinky.On("Prices", mock.Anything, mock.Anything).
Return(&types.QueryPricesResponse{
Prices: map[string]string{
"FOO/BAR": "100000000000",
},
Timestamp: time.Now(),
}, nil)
client.SlinkyPriceFetchDelay = time.Millisecond
client.SlinkyMarketParamFetchDelay = time.Millisecond
cli = client.StartNewClient(
context.Background(),
slinky,
&daemontypes.GrpcClientImpl{},
c.daemonFlags,
c.appFlags,
logger,
)
// Need to wait until a single cycle is done
time.Sleep(time.Millisecond * 20)
cli.Stop()
c.Require().NoError(cli.HealthCheck())
})
}
36 changes: 17 additions & 19 deletions protocol/daemons/slinky/client/market_pair_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@ package client

import (
"context"
"cosmossdk.io/log"
"fmt"
"sync"

"cosmossdk.io/log"
"google.golang.org/grpc"

oracletypes "github.com/skip-mev/slinky/x/oracle/types"

appflags "github.com/dydxprotocol/v4-chain/protocol/app/flags"
daemontypes "github.com/dydxprotocol/v4-chain/protocol/daemons/types"
"github.com/dydxprotocol/v4-chain/protocol/lib/slinky"
pricetypes "github.com/dydxprotocol/v4-chain/protocol/x/prices/types"
oracletypes "github.com/skip-mev/slinky/x/oracle/types"
"google.golang.org/grpc"
"sync"
)

// MarketPairFetcher is a lightweight process run in a goroutine by the slinky client.
Expand All @@ -26,9 +29,9 @@ type MarketPairFetcher interface {

// MarketPairFetcherImpl implements the MarketPairFetcher interface.
type MarketPairFetcherImpl struct {
logger log.Logger
queryConn *grpc.ClientConn
pricesQueryClient pricetypes.QueryClient
Logger log.Logger
QueryConn *grpc.ClientConn
PricesQueryClient pricetypes.QueryClient

// compatMappings stores a mapping between CurrencyPair and the corresponding market(param|price) ID
compatMappings map[oracletypes.CurrencyPair]uint32
Expand All @@ -37,7 +40,7 @@ type MarketPairFetcherImpl struct {

func NewMarketPairFetcher(logger log.Logger) MarketPairFetcher {
return &MarketPairFetcherImpl{
logger: logger,
Logger: logger,
compatMappings: make(map[oracletypes.CurrencyPair]uint32),
}
}
Expand All @@ -50,21 +53,21 @@ func (m *MarketPairFetcherImpl) Start(
// Create the query client connection
queryConn, err := grpcClient.NewTcpConnection(ctx, appFlags.GrpcAddress)
if err != nil {
m.logger.Error(
m.Logger.Error(
"Failed to establish gRPC connection",
"gRPC address", appFlags.GrpcAddress,
"error", err,
)
return err
}
m.pricesQueryClient = pricetypes.NewQueryClient(queryConn)
m.PricesQueryClient = pricetypes.NewQueryClient(queryConn)
return nil
}

// Stop closes all existing connections.
func (m *MarketPairFetcherImpl) Stop() {
if m.queryConn != nil {
_ = m.queryConn.Close()
if m.QueryConn != nil {
_ = m.QueryConn.Close()
}
}

Expand All @@ -85,22 +88,17 @@ func (m *MarketPairFetcherImpl) GetIDForPair(cp oracletypes.CurrencyPair) (uint3
// CurrencyPair and MarketParam ID.
func (m *MarketPairFetcherImpl) FetchIdMappings(ctx context.Context) error {
// fetch all market params
resp, err := m.pricesQueryClient.AllMarketParams(ctx, &pricetypes.QueryAllMarketParamsRequest{})
resp, err := m.PricesQueryClient.AllMarketParams(ctx, &pricetypes.QueryAllMarketParamsRequest{})
if err != nil {
return err
}
// Exit early if there are no changes
// This assumes there will not be an addition and a removal of markets in the same block
if len(resp.MarketParams) == len(m.compatMappings) {
return nil
}
var compatMappings = make(map[oracletypes.CurrencyPair]uint32, len(resp.MarketParams))
for _, mp := range resp.MarketParams {
cp, err := slinky.MarketPairToCurrencyPair(mp.Pair)
if err != nil {
return err
}
m.logger.Info("Mapped market to pair", "market id", mp.Id, "currency pair", cp.String())
m.Logger.Debug("Mapped market to pair", "market id", mp.Id, "currency pair", cp.String())
compatMappings[cp] = mp.Id
}
m.compatMu.Lock()
Expand Down
92 changes: 92 additions & 0 deletions protocol/daemons/slinky/client/market_pair_fetcher_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package client_test

import (
"context"
"fmt"
"testing"

"cosmossdk.io/log"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"

slinkytypes "github.com/skip-mev/slinky/x/oracle/types"

"github.com/dydxprotocol/v4-chain/protocol/daemons/slinky/client"
"github.com/dydxprotocol/v4-chain/protocol/mocks"
"github.com/dydxprotocol/v4-chain/protocol/x/prices/types"
)

func TestMarketPairFetcher(t *testing.T) {
logger := log.NewTestLogger(t)
queryClient := mocks.NewQueryClient(t)
fetcher := client.MarketPairFetcherImpl{
Logger: logger,
PricesQueryClient: queryClient,
}
asset0 := "FOO"
asset1 := "BAR"
pair0 := types.MarketParam{Id: 0, Pair: fmt.Sprintf("%s-%s", asset0, asset1)}
pair1 := types.MarketParam{Id: 1, Pair: fmt.Sprintf("%s-%s", asset1, asset0)}
invalidPair := types.MarketParam{Id: 2, Pair: "foobar"}

t.Run("caches and returns valid pairs", func(t *testing.T) {
queryClient.
On("AllMarketParams", mock.Anything, mock.Anything).
Return(
&types.QueryAllMarketParamsResponse{
MarketParams: []types.MarketParam{
pair0,
pair1,
}},
nil,
).Once()
err := fetcher.FetchIdMappings(context.Background())
require.NoError(t, err)
id, err := fetcher.GetIDForPair(slinkytypes.CurrencyPair{Base: asset0, Quote: asset1})
require.NoError(t, err)
require.Equal(t, pair0.Id, id)
id, err = fetcher.GetIDForPair(slinkytypes.CurrencyPair{Base: asset1, Quote: asset0})
require.NoError(t, err)
require.Equal(t, pair1.Id, id)
})

t.Run("errors on fetch non-cached pair", func(t *testing.T) {
queryClient.
On("AllMarketParams", mock.Anything, mock.Anything).
Return(
&types.QueryAllMarketParamsResponse{
MarketParams: []types.MarketParam{}},
nil,
).Once()
err := fetcher.FetchIdMappings(context.Background())
require.NoError(t, err)
_, err = fetcher.GetIDForPair(slinkytypes.CurrencyPair{Base: asset0, Quote: asset1})
require.Error(t, err, fmt.Errorf("pair %s/%s not found in compatMappings", asset0, asset1))
})

t.Run("fails on fetching invalid pairs", func(t *testing.T) {
queryClient.
On("AllMarketParams", mock.Anything, mock.Anything).
Return(
&types.QueryAllMarketParamsResponse{
MarketParams: []types.MarketParam{
invalidPair,
}},
nil,
).Once()
err := fetcher.FetchIdMappings(context.Background())
require.Error(t, err, "incorrectly formatted CurrencyPair: foobar")
})

t.Run("fails on prices query error", func(t *testing.T) {
queryClient.
On("AllMarketParams", mock.Anything, mock.Anything).
Return(
&types.QueryAllMarketParamsResponse{},
fmt.Errorf("test error"),
).Once()
err := fetcher.FetchIdMappings(context.Background())
require.Error(t, err, "test error")
})

}
5 changes: 2 additions & 3 deletions protocol/daemons/slinky/client/price_fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ package client
import (
"context"
"fmt"
"github.com/skip-mev/slinky/service/servers/oracle/types"
"strconv"

"cosmossdk.io/log"
"google.golang.org/grpc"

oracleclient "github.com/skip-mev/slinky/service/clients/oracle"
"github.com/skip-mev/slinky/service/servers/oracle/types"
oracletypes "github.com/skip-mev/slinky/x/oracle/types"

"github.com/dydxprotocol/v4-chain/protocol/daemons/pricefeed/api"
Expand Down Expand Up @@ -129,15 +129,14 @@ func (p *PriceFetcherImpl) FetchPrices(ctx context.Context) error {
// send the updates to the app's price-feed service -> these will then be piped to the
// x/prices indexPriceCache via the pricefeed service
if p.priceFeedServiceClient == nil {
p.logger.Error("nil price feed service client")
return fmt.Errorf("price feed service client was not initialized in slinky client")
}
if len(updates) == 0 {
p.logger.Info("Slinky returned 0 valid market price updates")
return nil
}
_, err = p.priceFeedServiceClient.UpdateMarketPrices(ctx, &api.UpdateMarketPricesRequest{MarketPriceUpdates: updates})
if err != nil {
p.logger.Error(err.Error())
return err
}
return nil
Expand Down
Loading

0 comments on commit e0a8c37

Please sign in to comment.