Skip to content

Commit

Permalink
Merge pull request #13
Browse files Browse the repository at this point in the history
* fix: incorrect lock and panic in out of order cluster message

* fix: southbridge deadlock and race condition

* [p2pCluster] hackfix to improve message orders, by introducing delay …
  • Loading branch information
ehsannm authored Jul 21, 2024
1 parent 35991ea commit 4ff6257
Show file tree
Hide file tree
Showing 13 changed files with 417 additions and 232 deletions.
285 changes: 183 additions & 102 deletions kit/bridge_south.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"sync"
"time"

"github.com/clubpay/ronykit/kit/errors"
"github.com/clubpay/ronykit/kit/utils"
Expand Down Expand Up @@ -44,7 +45,7 @@ type southBridge struct {
l Logger

inProgressMtx utils.SpinLock
inProgress map[string]chan *envelopeCarrier
inProgress map[string]*clusterConn
msgFactories map[string]MessageFactoryFunc
}

Expand Down Expand Up @@ -72,31 +73,109 @@ func (sb *southBridge) OnMessage(data []byte) {
}

sb.wg.Add(1)

switch carrier.Kind {
case incomingCarrier:
go sb.onIncomingMessage(carrier)
default:
conn := sb.getConn(carrier.SessionID)
if conn != nil {
ctx := sb.acquireCtx(conn)
ctx.sb = sb
select {
case conn.carrierChan <- carrier:
default:
sb.eh(ctx, ErrWritingToClusterConnection)
}
sb.releaseCtx(ctx)
}
}

sb.wg.Done()
}

// createSenderConn builds the sender-side clusterConn for the session identified by
// carrier.SessionID, registers it in sb.inProgress and starts a goroutine that
// drains conn.carrierChan.
//
// Every outgoingCarrier received on the channel is handed to callbackFn; an
// eofCarrier ends the session. When timeout is greater than zero the session is
// also bounded by that duration. In every exit path the conn's context is
// cancelled and the session is removed from sb.inProgress, so callers can wait on
// conn.Done() and the map does not accumulate finished or timed-out sessions.
func (sb *southBridge) createSenderConn(
	carrier *envelopeCarrier, timeout time.Duration, callbackFn func(*envelopeCarrier),
) *clusterConn {
	// Create exactly one cancellable context so that no CancelFunc is dropped
	// (overwriting the first CancelFunc is what `go vet` reports as lostcancel).
	var (
		rxCtx    context.Context
		cancelFn context.CancelFunc
	)
	if timeout > 0 {
		rxCtx, cancelFn = context.WithTimeout(context.Background(), timeout)
	} else {
		rxCtx, cancelFn = context.WithCancel(context.Background())
	}

	conn := &clusterConn{
		cb:          sb.cb,
		ctx:         rxCtx,
		cf:          cancelFn,
		callbackFn:  callbackFn,
		cluster:     sb.cb,
		originID:    carrier.OriginID,
		sessionID:   carrier.SessionID,
		serverID:    sb.id,
		kv:          map[string]string{},
		wf:          sb.writeFunc,
		carrierChan: make(chan *envelopeCarrier, 32),
	}

	sb.inProgressMtx.Lock()
	sb.inProgress[carrier.SessionID] = conn
	sb.inProgressMtx.Unlock()

	go func(c *clusterConn) {
		// Runs on every exit path (EOF, timeout, external cancel, closed channel):
		// release the context and unregister the session.
		defer func() {
			c.cf()

			sb.inProgressMtx.Lock()
			// Only remove the entry if it still belongs to this conn.
			if sb.inProgress[c.sessionID] == c {
				delete(sb.inProgress, c.sessionID)
			}
			sb.inProgressMtx.Unlock()
		}()

		for {
			select {
			case <-c.ctx.Done():
				return
			case carrier, ok := <-c.carrierChan:
				if !ok {
					return
				}

				switch carrier.Kind {
				case outgoingCarrier:
					c.callbackFn(carrier)
				case eofCarrier:
					// The channel is intentionally NOT closed here: producers that
					// already looked this conn up may still attempt a send, and a
					// send on a closed channel panics. Cancelling the context (in
					// the deferred cleanup) is enough to signal completion.
					return
				default:
					// Unknown carrier kinds are dropped instead of panicking so a
					// single malformed cluster message cannot crash the process.
				}
			}
		}
	}(conn)

	return conn
}

func (sb *southBridge) createTargetConn(
carrier *envelopeCarrier,
) *clusterConn {
conn := &clusterConn{
cluster: sb.cb,
originID: carrier.OriginID,
sessionID: carrier.SessionID,
serverID: sb.id,
kv: map[string]string{},
wf: sb.writeFunc,
}
ctx := sb.acquireCtx(conn)
ctx.sb = sb

switch carrier.Kind {
case incomingCarrier:
sb.onIncomingMessage(ctx, carrier)
case outgoingCarrier:
sb.onOutgoingMessage(ctx, carrier)
case eofCarrier:
sb.onEOF(carrier)
}
return conn
}

sb.releaseCtx(ctx)
sb.wg.Done()
// getConn looks up the in-progress cluster connection registered under sessionID.
// It returns nil when no such session is registered.
func (sb *southBridge) getConn(sessionID string) *clusterConn {
	sb.inProgressMtx.Lock()
	defer sb.inProgressMtx.Unlock()

	return sb.inProgress[sessionID]
}

func (sb *southBridge) onIncomingMessage(ctx *Context, carrier *envelopeCarrier) {
func (sb *southBridge) onIncomingMessage(carrier *envelopeCarrier) {
conn := sb.createTargetConn(carrier)
ctx := sb.acquireCtx(conn)
ctx.sb = sb
ctx.forwarded = true

msg := sb.msgFactories[carrier.Data.MsgType]()
Expand Down Expand Up @@ -142,48 +221,19 @@ func (sb *southBridge) onIncomingMessage(ctx *Context, carrier *envelopeCarrier)
if err != nil {
sb.eh(ctx, err)
}
}

func (sb *southBridge) onOutgoingMessage(ctx *Context, carrier *envelopeCarrier) {
sb.inProgressMtx.Lock()
ch, ok := sb.inProgress[carrier.SessionID]
sb.inProgressMtx.Unlock()

if ok {
select {
case ch <- carrier:
default:
sb.eh(ctx, ErrWritingToClusterConnection)
}
}
}

func (sb *southBridge) onEOF(carrier *envelopeCarrier) {
sb.inProgressMtx.Lock()
ch, ok := sb.inProgress[carrier.SessionID]
delete(sb.inProgress, carrier.SessionID)
sb.inProgressMtx.Unlock()
if ok {
close(ch)
}
sb.releaseCtx(ctx)
}

func (sb *southBridge) sendMessage(sessionID string, targetID string, data []byte) (<-chan *envelopeCarrier, error) {
ch := make(chan *envelopeCarrier, 4)
sb.inProgressMtx.Lock()
sb.inProgress[sessionID] = ch
sb.inProgressMtx.Unlock()

err := sb.cb.Publish(targetID, data)
func (sb *southBridge) sendMessage(carrier *envelopeCarrier) error {
err := sb.cb.Publish(carrier.TargetID, carrier.ToJSON())
if err != nil {
sb.inProgressMtx.Lock()
delete(sb.inProgress, sessionID)
delete(sb.inProgress, carrier.SessionID)
sb.inProgressMtx.Unlock()

return nil, err
}

return ch, nil
return err
}

func (sb *southBridge) wrapWithCoordinator(c Contract) Contract {
Expand Down Expand Up @@ -216,53 +266,65 @@ func (sb *southBridge) genForwarderHandler(sel EdgeSelectorFunc) HandlerFunc {
return
}

err = ctx.executeRemote(
executeRemoteArg{
Target: target,
In: newEnvelopeCarrier(
incomingCarrier,
utils.RandomID(32),
ctx.sb.id,
target,
).FillWithContext(ctx),
OutCallback: func(carrier *envelopeCarrier) {
if carrier.Data == nil {
return
}
f, ok := sb.msgFactories[carrier.Data.MsgType]
if !ok {
return
}

msg := f()
switch msg.(type) {
case RawMessage:
msg = RawMessage(carrier.Data.Msg)
default:
unmarshalEnvelopeCarrier(carrier.Data.Msg, msg)
}

for k, v := range carrier.Data.ConnHdr {
ctx.Conn().Set(k, v)
}

ctx.Out().
SetID(carrier.Data.EnvelopeID).
SetHdrMap(carrier.Data.Hdr).
SetMsg(msg).
Send()
},
},
)

ctx.Error(err)
carrier := newEnvelopeCarrier(
incomingCarrier,
utils.RandomID(32),
ctx.sb.id,
target,
).FillWithContext(ctx)

err = ctx.sb.sendMessage(carrier)
if err != nil {
ctx.Error(err)
ctx.StopExecution()

return
}

conn := sb.createSenderConn(carrier, ctx.rxt, sb.genCallback(ctx))
select {
case <-conn.Done():
ctx.Error(conn.Err())
case <-ctx.ctx.Done():
ctx.Error(ctx.ctx.Err())
}

// We should stop executing next handlers, since our request has been executed on
// a remote machine
ctx.StopExecution()
}
}

// genCallback returns the function that relays a carrier received from the cluster
// back through ctx. Carriers without data, or whose message type has no registered
// factory, are ignored. Otherwise the message is rebuilt (raw messages are taken
// as-is, everything else is unmarshalled), the carried connection headers are set
// on ctx.Conn(), and the envelope is sent with the carried ID and headers.
func (sb *southBridge) genCallback(ctx *Context) func(carrier *envelopeCarrier) {
	return func(carrier *envelopeCarrier) {
		data := carrier.Data
		if data == nil {
			return
		}

		factory, found := sb.msgFactories[data.MsgType]
		if !found {
			return
		}

		msg := factory()
		if _, isRaw := msg.(RawMessage); isRaw {
			// Raw messages carry the payload bytes directly.
			msg = RawMessage(data.Msg)
		} else {
			unmarshalEnvelopeCarrier(data.Msg, msg)
		}

		for key, val := range data.ConnHdr {
			ctx.Conn().Set(key, val)
		}

		out := ctx.Out().
			SetID(data.EnvelopeID).
			SetHdrMap(data.Hdr).
			SetMsg(msg)
		out.Send()
	}
}

func (sb *southBridge) writeFunc(c *clusterConn, e *Envelope) error {
ec := newEnvelopeCarrier(
outgoingCarrier,
Expand All @@ -276,28 +338,33 @@ func (sb *southBridge) writeFunc(c *clusterConn, e *Envelope) error {
sb.tp.Inject(e.ctx.ctx, ec.Data)
}

return c.cb.Publish(c.originID, ec.ToJSON())
return c.cluster.Publish(c.originID, ec.ToJSON())
}

type clusterConn struct {
sessionID string
originID string
serverID string
cb Cluster
var _ Conn = (*clusterConn)(nil)

id uint64
type clusterConn struct {
clientIP string
stream bool
kvMtx sync.Mutex
kv map[string]string

kvMtx sync.Mutex
kv map[string]string
wf func(c *clusterConn, e *Envelope) error
// target
serverID string
sessionID string
originID string
wf func(c *clusterConn, e *Envelope) error
cluster Cluster

// sender
ctx context.Context //nolint
cf context.CancelFunc
callbackFn func(carrier *envelopeCarrier)
carrierChan chan *envelopeCarrier
}

var _ Conn = (*clusterConn)(nil)

func (c *clusterConn) ConnID() uint64 {
return c.id
return 0
}

func (c *clusterConn) ClientIP() string {
Expand Down Expand Up @@ -348,6 +415,20 @@ func (c *clusterConn) Keys() []string {
return keys
}

// Done returns a channel that is closed when the connection's context is cancelled
// or its deadline expires. A connection that was built without a context yields a
// nil channel, which never becomes ready in a select, instead of panicking on a
// nil context.
func (c *clusterConn) Done() <-chan struct{} {
	if c.ctx == nil {
		return nil
	}

	return c.ctx.Done()
}

// Err reports why the connection's context ended, as returned by the context's
// own Err method. It returns nil while the context is still active, and also for
// a connection that was built without a context (instead of panicking on a nil
// context).
func (c *clusterConn) Err() error {
	if c.ctx == nil {
		return nil
	}

	return c.ctx.Err()
}

// Cancel invokes the connection's cancel function, if one was set.
// It is a no-op for connections that have none.
func (c *clusterConn) Cancel() {
	cancel := c.cf
	if cancel == nil {
		return
	}

	cancel()
}

var (
ErrSouthBridgeDisabled = errors.New("south bridge is disabled")
ErrWritingToClusterConnection = errors.New("writing to cluster connection is not possible")
Expand Down
Loading

0 comments on commit 4ff6257

Please sign in to comment.