Skip to content

Commit

Permalink
refactor: websocket controller creates a separate connection per pair (
Browse files Browse the repository at this point in the history
…#1773)

* websocket controller refactor

* Fix unit tests

* fix lint

* pong handling
  • Loading branch information
rbajollari authored Feb 3, 2023
1 parent 075b493 commit 0aba4fd
Show file tree
Hide file tree
Showing 18 changed files with 258 additions and 215 deletions.
10 changes: 0 additions & 10 deletions price-feeder/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,16 +237,6 @@ func ParseConfig(configPath string) (Config, error) {
}
}

gatePairs := []string{}
for base, providers := range pairs {
if _, ok := providers[provider.ProviderGate]; ok {
gatePairs = append(gatePairs, base)
}
}
if len(gatePairs) > 1 {
return cfg, fmt.Errorf("gate provider does not support multiple pairs: %v", gatePairs)
}

for _, deviation := range cfg.Deviations {
threshold, err := sdk.NewDecFromStr(deviation.Threshold)
if err != nil {
Expand Down
8 changes: 2 additions & 6 deletions price-feeder/oracle/oracle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,7 @@ func (m mockProvider) GetCandlePrices(_ ...types.CurrencyPair) (map[string][]typ
return candles, nil
}

func (m mockProvider) SubscribeCurrencyPairs(_ ...types.CurrencyPair) error {
return nil
}
func (m mockProvider) SubscribeCurrencyPairs(_ ...types.CurrencyPair) {}

func (m mockProvider) GetAvailablePairs() (map[string]struct{}, error) {
return map[string]struct{}{}, nil
Expand All @@ -58,9 +56,7 @@ func (m failingProvider) GetCandlePrices(_ ...types.CurrencyPair) (map[string][]
return nil, fmt.Errorf("unable to get candle prices")
}

func (m failingProvider) SubscribeCurrencyPairs(_ ...types.CurrencyPair) error {
return nil
}
func (m failingProvider) SubscribeCurrencyPairs(_ ...types.CurrencyPair) {}

func (m failingProvider) GetAvailablePairs() (map[string]struct{}, error) {
return map[string]struct{}{}, nil
Expand Down
18 changes: 10 additions & 8 deletions price-feeder/oracle/provider/binance.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,15 @@ func NewBinanceProvider(

provider.wsc = NewWebsocketController(
ctx,
ProviderBinance,
endpoints.Name,
wsURL,
provider.getSubscriptionMsgs(pairs...),
provider.messageReceived,
disabledPingDuration,
websocket.PingMessage,
binanceLogger,
)
go provider.wsc.Start()
provider.wsc.StartConnections()

return provider, nil
}
Expand All @@ -156,7 +156,7 @@ func (p *BinanceProvider) getSubscriptionMsgs(cps ...types.CurrencyPair) []inter

// SubscribeCurrencyPairs sends the new subscription messages to the websocket
// and adds them to the providers subscribedPairs array
func (p *BinanceProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) error {
func (p *BinanceProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) {
p.mtx.Lock()
defer p.mtx.Unlock()

Expand All @@ -168,11 +168,13 @@ func (p *BinanceProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) erro
}

newSubscriptionMsgs := p.getSubscriptionMsgs(newPairs...)
if err := p.wsc.AddSubscriptionMsgs(newSubscriptionMsgs); err != nil {
return err
}
p.wsc.AddWebsocketConnection(
newSubscriptionMsgs,
p.messageReceived,
disabledPingDuration,
websocket.PingMessage,
)
p.setSubscribedPairs(newPairs...)
return nil
}

// GetTickerPrices returns the tickerPrices based on the provided pairs.
Expand Down Expand Up @@ -267,7 +269,7 @@ func (p *BinanceProvider) getCandlePrices(key string) ([]types.CandlePrice, erro
return candleList, nil
}

func (p *BinanceProvider) messageReceived(_ int, bz []byte) {
func (p *BinanceProvider) messageReceived(_ int, _ *WebsocketConnection, bz []byte) {
var (
tickerResp BinanceTicker
tickerErr error
Expand Down
18 changes: 10 additions & 8 deletions price-feeder/oracle/provider/bitget.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,15 @@ func NewBitgetProvider(

provider.wsc = NewWebsocketController(
ctx,
ProviderBitget,
endpoints.Name,
wsURL,
provider.getSubscriptionMsgs(pairs...),
provider.messageReceived,
defaultPingDuration,
websocket.TextMessage,
bitgetLogger,
)
go provider.wsc.Start()
provider.wsc.StartConnections()

return provider, nil
}
Expand All @@ -163,7 +163,7 @@ func (p *BitgetProvider) getSubscriptionMsgs(cps ...types.CurrencyPair) []interf

// SubscribeCurrencyPairs sends the new subscription messages to the websocket
// and adds them to the providers subscribedPairs array
func (p *BitgetProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) error {
func (p *BitgetProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) {
p.mtx.Lock()
defer p.mtx.Unlock()

Expand All @@ -175,11 +175,13 @@ func (p *BitgetProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) error
}

newSubscriptionMsgs := p.getSubscriptionMsgs(newPairs...)
if err := p.wsc.AddSubscriptionMsgs(newSubscriptionMsgs); err != nil {
return err
}
p.wsc.AddWebsocketConnection(
newSubscriptionMsgs,
p.messageReceived,
defaultPingDuration,
websocket.PingMessage,
)
p.setSubscribedPairs(newPairs...)
return nil
}

// GetTickerPrices returns the tickerPrices based on the provided pairs.
Expand Down Expand Up @@ -233,7 +235,7 @@ func (p *BitgetProvider) GetCandlePrices(pairs ...types.CurrencyPair) (map[strin
}

// messageReceived handles the received data from the Bitget websocket.
func (p *BitgetProvider) messageReceived(_ int, bz []byte) {
func (p *BitgetProvider) messageReceived(_ int, _ *WebsocketConnection, bz []byte) {
var (
tickerResp BitgetTicker
tickerErr error
Expand Down
18 changes: 10 additions & 8 deletions price-feeder/oracle/provider/coinbase.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,15 @@ func NewCoinbaseProvider(

provider.wsc = NewWebsocketController(
ctx,
ProviderCoinbase,
endpoints.Name,
wsURL,
provider.getSubscriptionMsgs(pairs...),
provider.messageReceived,
defaultPingDuration,
websocket.PingMessage,
coinbaseLogger,
)
go provider.wsc.Start()
provider.wsc.StartConnections()

return provider, nil
}
Expand All @@ -153,7 +153,7 @@ func (p *CoinbaseProvider) getSubscriptionMsgs(cps ...types.CurrencyPair) []inte

// SubscribeCurrencyPairs sends the new subscription messages to the websocket
// and adds them to the providers subscribedPairs array
func (p *CoinbaseProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) error {
func (p *CoinbaseProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) {
p.mtx.Lock()
defer p.mtx.Unlock()

Expand All @@ -165,11 +165,13 @@ func (p *CoinbaseProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) err
}

newSubscriptionMsgs := p.getSubscriptionMsgs(newPairs...)
if err := p.wsc.AddSubscriptionMsgs(newSubscriptionMsgs); err != nil {
return err
}
p.wsc.AddWebsocketConnection(
newSubscriptionMsgs,
p.messageReceived,
defaultPingDuration,
websocket.PingMessage,
)
p.setSubscribedPairs(newPairs...)
return nil
}

// GetTickerPrices returns the tickerPrices based on the provided pairs.
Expand Down Expand Up @@ -327,7 +329,7 @@ func (p *CoinbaseProvider) getTradePrices(key string) ([]CoinbaseTrade, error) {
return trades, nil
}

func (p *CoinbaseProvider) messageReceived(_ int, bz []byte) {
func (p *CoinbaseProvider) messageReceived(_ int, _ *WebsocketConnection, bz []byte) {
var coinbaseTrade CoinbaseTradeResponse
if err := json.Unmarshal(bz, &coinbaseTrade); err != nil {
p.logger.Error().Err(err).Msg("unable to unmarshal response")
Expand Down
26 changes: 14 additions & 12 deletions price-feeder/oracle/provider/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,15 @@ func NewCryptoProvider(

provider.wsc = NewWebsocketController(
ctx,
ProviderCrypto,
endpoints.Name,
wsURL,
provider.getSubscriptionMsgs(pairs...),
provider.messageReceived,
disabledPingDuration,
websocket.PingMessage,
cryptoLogger,
)
go provider.wsc.Start()
provider.wsc.StartConnections()

return provider, nil
}
Expand All @@ -166,7 +166,7 @@ func (p *CryptoProvider) getSubscriptionMsgs(cps ...types.CurrencyPair) []interf

// SubscribeCurrencyPairs sends the new subscription messages to the websocket
// and adds them to the providers subscribedPairs array
func (p *CryptoProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) error {
func (p *CryptoProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) {
p.mtx.Lock()
defer p.mtx.Unlock()

Expand All @@ -178,11 +178,13 @@ func (p *CryptoProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) error
}

newSubscriptionMsgs := p.getSubscriptionMsgs(newPairs...)
if err := p.wsc.AddSubscriptionMsgs(newSubscriptionMsgs); err != nil {
return err
}
p.wsc.AddWebsocketConnection(
newSubscriptionMsgs,
p.messageReceived,
disabledPingDuration,
websocket.PingMessage,
)
p.setSubscribedPairs(newPairs...)
return nil
}

// GetTickerPrices returns the tickerPrices based on the provided pairs.
Expand Down Expand Up @@ -272,7 +274,7 @@ func (p *CryptoProvider) getCandlePrices(key string) ([]types.CandlePrice, error
return candleList, nil
}

func (p *CryptoProvider) messageReceived(messageType int, bz []byte) {
func (p *CryptoProvider) messageReceived(messageType int, conn *WebsocketConnection, bz []byte) {
if messageType != websocket.TextMessage {
return
}
Expand All @@ -289,7 +291,7 @@ func (p *CryptoProvider) messageReceived(messageType int, bz []byte) {
// sometimes the message received is not a ticker or a candle response.
heartbeatErr = json.Unmarshal(bz, &heartbeatResp)
if heartbeatResp.Method == cryptoHeartbeatMethod {
p.pong(heartbeatResp)
p.pong(conn, heartbeatResp)
return
}

Expand Down Expand Up @@ -325,19 +327,19 @@ func (p *CryptoProvider) messageReceived(messageType int, bz []byte) {
Msg("Error on receive message")
}

// pong return a heartbeat message when a "ping" is received and reset the
// pongReceived return a heartbeat message when a "ping" is received and reset the
// recconnect ticker because the connection is alive. After connected to crypto.com's
// Websocket server, the server will send heartbeat periodically (30s interval).
// When client receives an heartbeat message, it must respond back with the
// public/respond-heartbeat method, using the same matching id,
// within 5 seconds, or the connection will break.
func (p *CryptoProvider) pong(heartbeatResp CryptoHeartbeatResponse) {
func (p *CryptoProvider) pong(conn *WebsocketConnection, heartbeatResp CryptoHeartbeatResponse) {
heartbeatReq := CryptoHeartbeatRequest{
ID: heartbeatResp.ID,
Method: cryptoHeartbeatReqMethod,
}

if err := p.wsc.SendJSON(heartbeatReq); err != nil {
if err := conn.SendJSON(heartbeatReq); err != nil {
p.logger.Err(err).Msg("could not send pong message back")
}
}
Expand Down
18 changes: 10 additions & 8 deletions price-feeder/oracle/provider/gate.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ func NewGateProvider(

provider.wsc = NewWebsocketController(
ctx,
ProviderGate,
endpoints.Name,
wsURL,
provider.getSubscriptionMsgs(pairs...),
provider.messageReceived,
defaultPingDuration,
websocket.PingMessage,
gateLogger,
)
go provider.wsc.Start()
provider.wsc.StartConnections()

return provider, nil
}
Expand All @@ -162,7 +162,7 @@ func (p *GateProvider) getSubscriptionMsgs(cps ...types.CurrencyPair) []interfac

// SubscribeCurrencyPairs sends the new subscription messages to the websocket
// and adds them to the providers subscribedPairs array
func (p *GateProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) error {
func (p *GateProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) {
p.mtx.Lock()
defer p.mtx.Unlock()

Expand All @@ -174,11 +174,13 @@ func (p *GateProvider) SubscribeCurrencyPairs(cps ...types.CurrencyPair) error {
}

newSubscriptionMsgs := p.getSubscriptionMsgs(newPairs...)
if err := p.wsc.AddSubscriptionMsgs(newSubscriptionMsgs); err != nil {
return err
}
p.wsc.AddWebsocketConnection(
newSubscriptionMsgs,
p.messageReceived,
defaultPingDuration,
websocket.PingMessage,
)
p.setSubscribedPairs(newPairs...)
return nil
}

// GetTickerPrices returns the tickerPrices based on the provided pairs.
Expand Down Expand Up @@ -273,7 +275,7 @@ func (p *GateProvider) getTickerPrice(cp types.CurrencyPair) (types.TickerPrice,
)
}

func (p *GateProvider) messageReceived(_ int, bz []byte) {
func (p *GateProvider) messageReceived(_ int, _ *WebsocketConnection, bz []byte) {
var (
gateEvent GateEvent
gateErr error
Expand Down
Loading

0 comments on commit 0aba4fd

Please sign in to comment.