Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cluster message handling #13

Merged
merged 3 commits into from
Jul 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 183 additions & 102 deletions kit/bridge_south.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"reflect"
"sync"
"time"

"github.com/clubpay/ronykit/kit/errors"
"github.com/clubpay/ronykit/kit/utils"
Expand Down Expand Up @@ -44,7 +45,7 @@ type southBridge struct {
l Logger

inProgressMtx utils.SpinLock
inProgress map[string]chan *envelopeCarrier
inProgress map[string]*clusterConn
msgFactories map[string]MessageFactoryFunc
}

Expand Down Expand Up @@ -72,31 +73,109 @@ func (sb *southBridge) OnMessage(data []byte) {
}

sb.wg.Add(1)

switch carrier.Kind {
case incomingCarrier:
go sb.onIncomingMessage(carrier)
default:
conn := sb.getConn(carrier.SessionID)
if conn != nil {
ctx := sb.acquireCtx(conn)
ctx.sb = sb
select {
case conn.carrierChan <- carrier:
default:
sb.eh(ctx, ErrWritingToClusterConnection)
}
sb.releaseCtx(ctx)
}
}

sb.wg.Done()
}

func (sb *southBridge) createSenderConn(
carrier *envelopeCarrier, timeout time.Duration, callbackFn func(*envelopeCarrier),
) *clusterConn {
rxCtx, cancelFn := context.WithCancel(context.Background())
if timeout > 0 {
rxCtx, cancelFn = context.WithTimeout(rxCtx, timeout)
}

conn := &clusterConn{
cb: sb.cb,
ctx: rxCtx,
cf: cancelFn,
callbackFn: callbackFn,
cluster: sb.cb,
originID: carrier.OriginID,
sessionID: carrier.SessionID,
serverID: sb.id,
kv: map[string]string{},
wf: sb.writeFunc,
carrierChan: make(chan *envelopeCarrier, 32),
}

sb.inProgressMtx.Lock()
sb.inProgress[carrier.SessionID] = conn
sb.inProgressMtx.Unlock()

go func(c *clusterConn) {
for {
select {
case <-c.ctx.Done():
return
case carrier, ok := <-c.carrierChan:
if !ok {
c.cf()

return
}
switch carrier.Kind {
default:
panic("invalid carrier kind")
case outgoingCarrier:
c.callbackFn(carrier)
case eofCarrier:
sb.inProgressMtx.Lock()
delete(sb.inProgress, c.sessionID)
sb.inProgressMtx.Unlock()

close(c.carrierChan)
}
}
}
}(conn)

return conn
}

func (sb *southBridge) createTargetConn(
carrier *envelopeCarrier,
) *clusterConn {
conn := &clusterConn{
cluster: sb.cb,
originID: carrier.OriginID,
sessionID: carrier.SessionID,
serverID: sb.id,
kv: map[string]string{},
wf: sb.writeFunc,
}
ctx := sb.acquireCtx(conn)
ctx.sb = sb

switch carrier.Kind {
case incomingCarrier:
sb.onIncomingMessage(ctx, carrier)
case outgoingCarrier:
sb.onOutgoingMessage(ctx, carrier)
case eofCarrier:
sb.onEOF(carrier)
}
return conn
}

sb.releaseCtx(ctx)
sb.wg.Done()
func (sb *southBridge) getConn(sessionID string) *clusterConn {
sb.inProgressMtx.Lock()
conn := sb.inProgress[sessionID]
sb.inProgressMtx.Unlock()

return conn
}

func (sb *southBridge) onIncomingMessage(ctx *Context, carrier *envelopeCarrier) {
func (sb *southBridge) onIncomingMessage(carrier *envelopeCarrier) {
conn := sb.createTargetConn(carrier)
ctx := sb.acquireCtx(conn)
ctx.sb = sb
ctx.forwarded = true

msg := sb.msgFactories[carrier.Data.MsgType]()
Expand Down Expand Up @@ -142,48 +221,19 @@ func (sb *southBridge) onIncomingMessage(ctx *Context, carrier *envelopeCarrier)
if err != nil {
sb.eh(ctx, err)
}
}

func (sb *southBridge) onOutgoingMessage(ctx *Context, carrier *envelopeCarrier) {
sb.inProgressMtx.Lock()
ch, ok := sb.inProgress[carrier.SessionID]
sb.inProgressMtx.Unlock()

if ok {
select {
case ch <- carrier:
default:
sb.eh(ctx, ErrWritingToClusterConnection)
}
}
}

func (sb *southBridge) onEOF(carrier *envelopeCarrier) {
sb.inProgressMtx.Lock()
ch, ok := sb.inProgress[carrier.SessionID]
delete(sb.inProgress, carrier.SessionID)
sb.inProgressMtx.Unlock()
if ok {
close(ch)
}
sb.releaseCtx(ctx)
}

func (sb *southBridge) sendMessage(sessionID string, targetID string, data []byte) (<-chan *envelopeCarrier, error) {
ch := make(chan *envelopeCarrier, 4)
sb.inProgressMtx.Lock()
sb.inProgress[sessionID] = ch
sb.inProgressMtx.Unlock()

err := sb.cb.Publish(targetID, data)
func (sb *southBridge) sendMessage(carrier *envelopeCarrier) error {
err := sb.cb.Publish(carrier.TargetID, carrier.ToJSON())
if err != nil {
sb.inProgressMtx.Lock()
delete(sb.inProgress, sessionID)
delete(sb.inProgress, carrier.SessionID)
sb.inProgressMtx.Unlock()

return nil, err
}

return ch, nil
return err
}

func (sb *southBridge) wrapWithCoordinator(c Contract) Contract {
Expand Down Expand Up @@ -216,53 +266,65 @@ func (sb *southBridge) genForwarderHandler(sel EdgeSelectorFunc) HandlerFunc {
return
}

err = ctx.executeRemote(
executeRemoteArg{
Target: target,
In: newEnvelopeCarrier(
incomingCarrier,
utils.RandomID(32),
ctx.sb.id,
target,
).FillWithContext(ctx),
OutCallback: func(carrier *envelopeCarrier) {
if carrier.Data == nil {
return
}
f, ok := sb.msgFactories[carrier.Data.MsgType]
if !ok {
return
}

msg := f()
switch msg.(type) {
case RawMessage:
msg = RawMessage(carrier.Data.Msg)
default:
unmarshalEnvelopeCarrier(carrier.Data.Msg, msg)
}

for k, v := range carrier.Data.ConnHdr {
ctx.Conn().Set(k, v)
}

ctx.Out().
SetID(carrier.Data.EnvelopeID).
SetHdrMap(carrier.Data.Hdr).
SetMsg(msg).
Send()
},
},
)

ctx.Error(err)
carrier := newEnvelopeCarrier(
incomingCarrier,
utils.RandomID(32),
ctx.sb.id,
target,
).FillWithContext(ctx)

err = ctx.sb.sendMessage(carrier)
if err != nil {
ctx.Error(err)
ctx.StopExecution()

return
}

conn := sb.createSenderConn(carrier, ctx.rxt, sb.genCallback(ctx))
select {
case <-conn.Done():
ctx.Error(conn.Err())
case <-ctx.ctx.Done():
ctx.Error(ctx.ctx.Err())
}

// We should stop executing next handlers, since our request has been executed on
// a remote machine
ctx.StopExecution()
}
}

func (sb *southBridge) genCallback(ctx *Context) func(carrier *envelopeCarrier) {
return func(carrier *envelopeCarrier) {
if carrier.Data == nil {
return
}
f, ok := sb.msgFactories[carrier.Data.MsgType]
if !ok {
return
}

msg := f()
switch msg.(type) {
case RawMessage:
msg = RawMessage(carrier.Data.Msg)
default:
unmarshalEnvelopeCarrier(carrier.Data.Msg, msg)
}

for k, v := range carrier.Data.ConnHdr {
ctx.Conn().Set(k, v)
}

ctx.Out().
SetID(carrier.Data.EnvelopeID).
SetHdrMap(carrier.Data.Hdr).
SetMsg(msg).
Send()
}
}

func (sb *southBridge) writeFunc(c *clusterConn, e *Envelope) error {
ec := newEnvelopeCarrier(
outgoingCarrier,
Expand All @@ -276,28 +338,33 @@ func (sb *southBridge) writeFunc(c *clusterConn, e *Envelope) error {
sb.tp.Inject(e.ctx.ctx, ec.Data)
}

return c.cb.Publish(c.originID, ec.ToJSON())
return c.cluster.Publish(c.originID, ec.ToJSON())
}

type clusterConn struct {
sessionID string
originID string
serverID string
cb Cluster
var _ Conn = (*clusterConn)(nil)

id uint64
type clusterConn struct {
clientIP string
stream bool
kvMtx sync.Mutex
kv map[string]string

kvMtx sync.Mutex
kv map[string]string
wf func(c *clusterConn, e *Envelope) error
// target
serverID string
sessionID string
originID string
wf func(c *clusterConn, e *Envelope) error
cluster Cluster

// sender
ctx context.Context //nolint
cf context.CancelFunc
callbackFn func(carrier *envelopeCarrier)
carrierChan chan *envelopeCarrier
}

var _ Conn = (*clusterConn)(nil)

func (c *clusterConn) ConnID() uint64 {
return c.id
return 0
}

func (c *clusterConn) ClientIP() string {
Expand Down Expand Up @@ -348,6 +415,20 @@ func (c *clusterConn) Keys() []string {
return keys
}

func (c *clusterConn) Done() <-chan struct{} {
return c.ctx.Done()
}

func (c *clusterConn) Err() error {
return c.ctx.Err()
}

func (c *clusterConn) Cancel() {
if c.cf != nil {
c.cf()
}
}

var (
ErrSouthBridgeDisabled = errors.New("south bridge is disabled")
ErrWritingToClusterConnection = errors.New("writing to cluster connection is not possible")
Expand Down
Loading
Loading