From d6173de3c86dd8533d2c31150e4f8d81f6225306 Mon Sep 17 00:00:00 2001 From: Ehsan Noureddin Moosa Date: Sat, 20 Jul 2024 18:39:30 +0300 Subject: [PATCH 1/3] fix: incorrect lock and panic in out of order cluster message --- kit/bridge_south.go | 3 +- std/clusters/p2pcluster/cluster.go | 13 +++++ std/clusters/p2pcluster/go.mod | 8 +-- std/clusters/p2pcluster/go.sum | 16 +++--- testenv/.golangci.yml | 48 +++++++++++++++++ testenv/cluster_test.go | 86 ++++++++++++++++-------------- testenv/services/key_value.go | 6 --- 7 files changed, 121 insertions(+), 59 deletions(-) create mode 100644 testenv/.golangci.yml diff --git a/kit/bridge_south.go b/kit/bridge_south.go index 29bdda30..cfc16578 100644 --- a/kit/bridge_south.go +++ b/kit/bridge_south.go @@ -147,8 +147,6 @@ func (sb *southBridge) onIncomingMessage(ctx *Context, carrier *envelopeCarrier) func (sb *southBridge) onOutgoingMessage(ctx *Context, carrier *envelopeCarrier) { sb.inProgressMtx.Lock() ch, ok := sb.inProgress[carrier.SessionID] - sb.inProgressMtx.Unlock() - if ok { select { case ch <- carrier: @@ -156,6 +154,7 @@ func (sb *southBridge) onOutgoingMessage(ctx *Context, carrier *envelopeCarrier) sb.eh(ctx, ErrWritingToClusterConnection) } } + sb.inProgressMtx.Unlock() } func (sb *southBridge) onEOF(carrier *envelopeCarrier) { diff --git a/std/clusters/p2pcluster/cluster.go b/std/clusters/p2pcluster/cluster.go index bef9dd6f..97a78a37 100644 --- a/std/clusters/p2pcluster/cluster.go +++ b/std/clusters/p2pcluster/cluster.go @@ -2,6 +2,7 @@ package p2pcluster import ( "context" + "errors" "fmt" "time" @@ -138,6 +139,12 @@ func (c *cluster) startBroadcast(ctx context.Context) error { for { msg, err := broadcastSub.Next(ctx) if err != nil { + if errors.Is(err, context.Canceled) { + c.log.Debugf("[p2pCluster] broadcast stopped") + + return + } + c.log.Errorf("[p2pCluster] failed to receive broadcast message: %v", err) time.Sleep(time.Second) @@ -173,6 +180,12 @@ func (c *cluster) startMyTopic(ctx context.Context) error { for { msg, err := sub.Next(ctx) if err != nil { + if errors.Is(err, context.Canceled) { + return + } + + c.log.Errorf("[p2pCluster] failed to receive message from topic[%s]: %v", topic, err) + continue } diff --git a/std/clusters/p2pcluster/go.mod b/std/clusters/p2pcluster/go.mod index 6f6b9823..a9704c03 100644 --- a/std/clusters/p2pcluster/go.mod +++ b/std/clusters/p2pcluster/go.mod @@ -4,7 +4,7 @@ go 1.21 require ( github.com/clubpay/ronykit/kit v0.16.6 - github.com/libp2p/go-libp2p v0.34.1 + github.com/libp2p/go-libp2p v0.35.4 github.com/libp2p/go-libp2p-pubsub v0.11.0 ) @@ -30,7 +30,7 @@ require ( github.com/google/gopacket v1.1.19 // indirect github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/gorilla/websocket v1.5.1 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect github.com/huin/goupnp v1.3.0 // indirect github.com/ipfs/go-cid v0.4.1 // indirect @@ -73,7 +73,7 @@ require ( github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect github.com/pion/datachannel v1.5.6 // indirect github.com/pion/dtls/v2 v2.2.11 // indirect - github.com/pion/ice/v2 v2.3.24 // indirect + github.com/pion/ice/v2 v2.3.25 // indirect github.com/pion/interceptor v0.1.29 // indirect github.com/pion/logging v0.2.2 // indirect github.com/pion/mdns v0.0.12 // indirect @@ -101,7 +101,7 @@ require ( github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/stretchr/testify v1.9.0 // indirect go.uber.org/dig v1.17.1 // indirect - go.uber.org/fx v1.21.1 // indirect + go.uber.org/fx v1.22.1 // indirect go.uber.org/mock v0.4.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.uber.org/zap v1.27.0 // indirect diff --git a/std/clusters/p2pcluster/go.sum b/std/clusters/p2pcluster/go.sum index 44c227db..ac2f1dcd 100644 --- a/std/clusters/p2pcluster/go.sum +++ b/std/clusters/p2pcluster/go.sum @@ -99,8 +99,8 @@ github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY= github.com/googleapis/gax-go/v2 v2.0.3/go.mod h1:LLvjysVCY1JZeum8Z6l8qUty8fiNwE08qbEPm1M08qg= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= -github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= @@ -142,8 +142,8 @@ github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6 github.com/libp2p/go-buffer-pool v0.1.0/go.mod h1:N+vh8gMqimBzdKkSMVuydVDq+UV5QTWy5HSiZacSbPg= github.com/libp2p/go-flow-metrics v0.1.0 h1:0iPhMI8PskQwzh57jB9WxIuIOQ0r+15PChFGkx3Q3WM= github.com/libp2p/go-flow-metrics v0.1.0/go.mod h1:4Xi8MX8wj5aWNDAZttg6UPmc0ZrnFNsMtpsYUClFtro= -github.com/libp2p/go-libp2p v0.34.1 h1:fxn9vyLo7vJcXQRNvdRbyPjbzuQgi2UiqC8hEbn8a18= -github.com/libp2p/go-libp2p v0.34.1/go.mod h1:snyJQix4ET6Tj+LeI0VPjjxTtdWpeOhYt5lEY0KirkQ= +github.com/libp2p/go-libp2p v0.35.4 h1:FDiBUYLkueFwsuNJUZaxKRdpKvBOWU64qQPL768bSeg= +github.com/libp2p/go-libp2p v0.35.4/go.mod h1:RKCDNt30IkFipGL0tl8wQW/3zVWEGFUZo8g2gAKxwjU= github.com/libp2p/go-libp2p-asn-util v0.4.1 h1:xqL7++IKD9TBFMgnLPZR6/6iYhawHKHl950SO9L6n94= github.com/libp2p/go-libp2p-asn-util v0.4.1/go.mod h1:d/NI6XZ9qxw67b4e+NgpQexCIiFYJjErASrYW4PFDN8= github.com/libp2p/go-libp2p-pubsub v0.11.0 h1:+JvS8Kty0OiyUiN0i8H5JbaCgjnJTRnTHe4rU88dLFc= @@ -233,8 +233,8 @@ github.com/pion/datachannel v1.5.6/go.mod h1:1eKT6Q85pRnr2mHiWHxJwO50SfZRtWHTsNI github.com/pion/dtls/v2 v2.2.7/go.mod h1:8WiMkebSHFD0T+dIU+UeBaoV7kDhOW5oDCzZ7WZ/F9s= github.com/pion/dtls/v2 v2.2.11 h1:9U/dpCYl1ySttROPWJgqWKEylUdT0fXp/xst6JwY5Ks= github.com/pion/dtls/v2 v2.2.11/go.mod h1:d9SYc9fch0CqK90mRk1dC7AkzzpwJj6u2GU3u+9pqFE= -github.com/pion/ice/v2 v2.3.24 h1:RYgzhH/u5lH0XO+ABatVKCtRd+4U1GEaCXSMjNr13tI= -github.com/pion/ice/v2 v2.3.24/go.mod h1:KXJJcZK7E8WzrBEYnV4UtqEZsGeWfHxsNqhVcVvgjxw= +github.com/pion/ice/v2 v2.3.25 h1:M5rJA07dqhi3nobJIg+uPtcVjFECTrhcR3n0ns8kDZs= +github.com/pion/ice/v2 v2.3.25/go.mod h1:KXJJcZK7E8WzrBEYnV4UtqEZsGeWfHxsNqhVcVvgjxw= github.com/pion/interceptor v0.1.29 h1:39fsnlP1U8gw2JzOFWdfCU82vHvhW9o0rZnZF56wF+M= github.com/pion/interceptor v0.1.29/go.mod h1:ri+LGNjRUc5xUNtDEPzfdkmSqISixVTBF/z/Zms/6T4= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= @@ -359,8 +359,8 @@ go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/dig v1.17.1 h1:Tga8Lz8PcYNsWsyHMZ1Vm0OQOUaJNDyvPImgbAu9YSc= go.uber.org/dig v1.17.1/go.mod h1:Us0rSJiThwCv2GteUN0Q7OKvU7n5J4dxZ9JKUXozFdE= -go.uber.org/fx v1.21.1 h1:RqBh3cYdzZS0uqwVeEjOX2p73dddLpym315myy/Bpb0= -go.uber.org/fx v1.21.1/go.mod h1:HT2M7d7RHo+ebKGh9NRcrsrHHfpZ60nW3QRubMRfv48= +go.uber.org/fx v1.22.1 h1:nvvln7mwyT5s1q201YE29V/BFrGor6vMiDNpU/78Mys= +go.uber.org/fx v1.22.1/go.mod h1:HT2M7d7RHo+ebKGh9NRcrsrHHfpZ60nW3QRubMRfv48= go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= diff --git a/testenv/.golangci.yml b/testenv/.golangci.yml new file mode 100644 index 00000000..43584355 --- /dev/null +++ b/testenv/.golangci.yml @@ -0,0 +1,48 @@ +run: + tests: false +linters: + enable-all: true + disable: + - wsl + - gomnd + - gochecknoglobals + - paralleltest + - gochecknoinits + - funlen + - godot + - godox + - nonamedreturns + - testpackage + - tagliatelle + - exhaustruct + - wrapcheck + - varnamelen + - ireturn + - depguard + - perfsprint + - mnd + - execinquery + fast: false +linters-settings: + errcheck: + # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. + # Such cases aren't reported by default. + # Default: false + check-type-assertions: true + # report about assignment of errors to blank identifier: `num, _ := strconv.Atoi(numStr)`. + # Such cases aren't reported by default. + # Default: false + check-blank: true + depguard: + rules: + whilelist: + list-mode: lax + files: + - $all + - "!*_test.go" + deny: + - pkg: "io/ioutil" + desc: "io/ioutil is deprecated, use io instead" + + + diff --git a/testenv/cluster_test.go b/testenv/cluster_test.go index f46972a1..c1d8dca2 100644 --- a/testenv/cluster_test.go +++ b/testenv/cluster_test.go @@ -140,44 +140,52 @@ func kitWithCluster(t *testing.T, opt fx.Option) func(c C) { ), ) - time.Sleep(time.Second * 15) - - // Set Key to instance 1 - - resp := &services.KeyValue{} - err := stub.New("localhost:8082").REST(). - SetMethod("POST"). - DefaultResponseHandler( - func(ctx context.Context, r stub.RESTResponse) *stub.Error { - c.So(r.StatusCode(), ShouldEqual, http.StatusOK) - - return stub.WrapError(json.Unmarshal(r.GetBody(), resp)) - }, - ). - AutoRun(ctx, "/set-key", kit.JSON, &services.SetRequest{Key: "test", Value: "testValue"}). - Error() - c.So(err, ShouldBeNil) - c.So(resp.Key, ShouldEqual, "test") - c.So(resp.Value, ShouldEqual, "testValue") - - // Get Key from instance 2 - err = stub.New("localhost:8083").REST(). - SetMethod("GET"). - SetHeader("Conn-Hdr-In", "MyValue"). - SetHeader("Envelope-Hdr-In", "EnvelopeValue"). - DefaultResponseHandler( - func(ctx context.Context, r stub.RESTResponse) *stub.Error { - c.So(r.GetHeader("Conn-Hdr-Out"), ShouldEqual, "MyValue") - c.So(r.GetHeader("Envelope-Hdr-Out"), ShouldEqual, "EnvelopeValue") - c.So(r.StatusCode(), ShouldEqual, http.StatusOK) - - return stub.WrapError(json.Unmarshal(r.GetBody(), resp)) - }, - ). - AutoRun(ctx, "/get-key/{key}", kit.JSON, &services.GetRequest{Key: "test"}). - Error() - c.So(err, ShouldBeNil) - c.So(resp.Key, ShouldEqual, "test") - c.So(resp.Value, ShouldEqual, "testValue") + time.Sleep(time.Second * 5) + hosts := []string{"localhost:8082", "localhost:8083"} + for range 100 { + key := "K_" + utils.RandomID(10) + value := "V_" + utils.RandomID(10) + setHostIndex := utils.RandomInt(len(hosts)) + setHost := hosts[setHostIndex] + getHost := hosts[(setHostIndex+1)%len(hosts)] + // Set Key to instance 1 + resp := &services.KeyValue{} + err := stub.New(setHost).REST(). + SetMethod("POST"). + DefaultResponseHandler( + func(ctx context.Context, r stub.RESTResponse) *stub.Error { + c.So(r.StatusCode(), ShouldEqual, http.StatusOK) + + return stub.WrapError(json.Unmarshal(r.GetBody(), resp)) + }, + ). + AutoRun(ctx, "/set-key", kit.JSON, &services.SetRequest{Key: key, Value: value}). + Error() + c.So(err, ShouldBeNil) + c.So(resp.Key, ShouldEqual, key) + c.So(resp.Value, ShouldEqual, value) + + // Get Key from instance 2 + connHdrIn := utils.RandomID(12) + envelopeHdrIn := utils.RandomID(12) + err = stub.New(getHost).REST(). + SetMethod("GET"). + SetHeader("Conn-Hdr-In", connHdrIn). + SetHeader("Envelope-Hdr-In", envelopeHdrIn). + DefaultResponseHandler( + func(ctx context.Context, r stub.RESTResponse) *stub.Error { + c.So(r.GetHeader("Conn-Hdr-Out"), ShouldEqual, connHdrIn) + c.So(r.GetHeader("Envelope-Hdr-Out"), ShouldEqual, envelopeHdrIn) + c.So(r.StatusCode(), ShouldEqual, http.StatusOK) + + return stub.WrapError(json.Unmarshal(r.GetBody(), resp)) + }, + ). + AutoRun(ctx, "/get-key/{key}", kit.JSON, &services.GetRequest{Key: key}). + Error() + c.So(err, ShouldBeNil) + c.So(resp.Key, ShouldEqual, key) + c.So(resp.Value, ShouldEqual, value) + } } } diff --git a/testenv/services/key_value.go b/testenv/services/key_value.go index 940c0e28..977ca9d5 100644 --- a/testenv/services/key_value.go +++ b/testenv/services/key_value.go @@ -67,14 +67,8 @@ var SimpleKeyValueService kit.ServiceBuilder = desc.NewService("simpleKeyValueSe return } - ctx.Conn().Walk(func(key string, val string) bool { - fmt.Println("Conn:", key, val) - - return true - }) ctx.Conn().Set("Conn-Hdr-Out", ctx.Conn().Get("Conn-Hdr-In")) - ctx.Out(). SetHdr("Envelope-Hdr-Out", ctx.In().GetHdr("Envelope-Hdr-In")). SetMsg(&KeyValue{Key: req.Key, Value: value.(string)}). //nolint:forcetypeassert From 73381218981246e5233e4e01704ab1cabe30859b Mon Sep 17 00:00:00 2001 From: Ehsan Noureddin Moosa Date: Sat, 20 Jul 2024 22:30:59 +0300 Subject: [PATCH 2/3] fix: southbridge deadlock and race condition --- kit/bridge_south.go | 279 +++++++++++++++++---------- kit/ctx.go | 48 ----- kit/edge.go | 2 +- kit/envelope_carrier.go | 8 +- std/clusters/p2pcluster/cluster.go | 90 +++++++-- std/clusters/rediscluster/cluster.go | 2 +- testenv/cluster_test.go | 2 +- 7 files changed, 262 insertions(+), 169 deletions(-) diff --git a/kit/bridge_south.go b/kit/bridge_south.go index cfc16578..8d9ee5e9 100644 --- a/kit/bridge_south.go +++ b/kit/bridge_south.go @@ -5,6 +5,7 @@ import ( "fmt" "reflect" "sync" + "time" "github.com/clubpay/ronykit/kit/errors" "github.com/clubpay/ronykit/kit/utils" @@ -44,7 +45,7 @@ type southBridge struct { l Logger inProgressMtx utils.SpinLock - inProgress map[string]chan *envelopeCarrier + inProgress map[string]*clusterConn msgFactories map[string]MessageFactoryFunc } @@ -72,31 +73,84 @@ func (sb *southBridge) OnMessage(data []byte) { } sb.wg.Add(1) + + switch carrier.Kind { + case incomingCarrier: + go sb.onIncomingMessage(carrier) + default: + conn := sb.getConn(carrier.SessionID) + if conn != nil { + ctx := sb.acquireCtx(conn) + ctx.sb = sb + select { + case conn.carrierChan <- carrier: + default: + sb.eh(ctx, ErrWritingToClusterConnection) + } + sb.releaseCtx(ctx) + } + } + + sb.wg.Done() +} + +func (sb *southBridge) createSenderConn( + carrier *envelopeCarrier, timeout time.Duration, callbackFn func(*envelopeCarrier), +) *clusterConn { + rxCtx, cancelFn := context.WithCancel(context.Background()) + if timeout > 0 { + rxCtx, cancelFn = context.WithTimeout(rxCtx, timeout) + } + conn := &clusterConn{ - cb: sb.cb, + ctx: rxCtx, + cf: cancelFn, + callbackFn: callbackFn, + cluster: sb.cb, + originID: carrier.OriginID, + sessionID: carrier.SessionID, + serverID: sb.id, + kv: map[string]string{}, + wf: sb.writeFunc, + carrierChan: make(chan *envelopeCarrier, 32), + } + + sb.inProgressMtx.Lock() + sb.inProgress[carrier.SessionID] = conn + sb.inProgressMtx.Unlock() + + go conn.run() + + return conn +} + +func (sb *southBridge) createTargetConn( + carrier *envelopeCarrier, +) *clusterConn { + conn := &clusterConn{ + cluster: sb.cb, originID: carrier.OriginID, sessionID: carrier.SessionID, serverID: sb.id, kv: map[string]string{}, wf: sb.writeFunc, } - ctx := sb.acquireCtx(conn) - ctx.sb = sb - switch carrier.Kind { - case incomingCarrier: - sb.onIncomingMessage(ctx, carrier) - case outgoingCarrier: - sb.onOutgoingMessage(ctx, carrier) - case eofCarrier: - sb.onEOF(carrier) - } + return conn +} - sb.releaseCtx(ctx) - sb.wg.Done() +func (sb *southBridge) getConn(sessionID string) *clusterConn { + sb.inProgressMtx.Lock() + conn := sb.inProgress[sessionID] + sb.inProgressMtx.Unlock() + + return conn } -func (sb *southBridge) onIncomingMessage(ctx *Context, carrier *envelopeCarrier) { +func (sb *southBridge) onIncomingMessage(carrier *envelopeCarrier) { + conn := sb.createTargetConn(carrier) + ctx := sb.acquireCtx(conn) + ctx.sb = sb ctx.forwarded = true msg := sb.msgFactories[carrier.Data.MsgType]() @@ -142,47 +196,19 @@ func (sb *southBridge) onIncomingMessage(ctx *Context, carrier *envelopeCarrier) if err != nil { sb.eh(ctx, err) } -} -func (sb *southBridge) onOutgoingMessage(ctx *Context, carrier *envelopeCarrier) { - sb.inProgressMtx.Lock() - ch, ok := sb.inProgress[carrier.SessionID] - if ok { - select { - case ch <- carrier: - default: - sb.eh(ctx, ErrWritingToClusterConnection) - } - } - sb.inProgressMtx.Unlock() -} - -func (sb *southBridge) onEOF(carrier *envelopeCarrier) { - sb.inProgressMtx.Lock() - ch, ok := sb.inProgress[carrier.SessionID] - delete(sb.inProgress, carrier.SessionID) - sb.inProgressMtx.Unlock() - if ok { - close(ch) - } + sb.releaseCtx(ctx) } -func (sb *southBridge) sendMessage(sessionID string, targetID string, data []byte) (<-chan *envelopeCarrier, error) { - ch := make(chan *envelopeCarrier, 4) - sb.inProgressMtx.Lock() - sb.inProgress[sessionID] = ch - sb.inProgressMtx.Unlock() - - err := sb.cb.Publish(targetID, data) +func (sb *southBridge) sendMessage(carrier *envelopeCarrier) error { + err := sb.cb.Publish(carrier.TargetID, carrier.ToJSON()) if err != nil { sb.inProgressMtx.Lock() - delete(sb.inProgress, sessionID) + delete(sb.inProgress, carrier.SessionID) sb.inProgressMtx.Unlock() - - return nil, err } - return ch, nil + return err } func (sb *southBridge) wrapWithCoordinator(c Contract) Contract { @@ -215,46 +241,28 @@ func (sb *southBridge) genForwarderHandler(sel EdgeSelectorFunc) HandlerFunc { return } - err = ctx.executeRemote( - executeRemoteArg{ - Target: target, - In: newEnvelopeCarrier( - incomingCarrier, - utils.RandomID(32), - ctx.sb.id, - target, - ).FillWithContext(ctx), - OutCallback: func(carrier *envelopeCarrier) { - if carrier.Data == nil { - return - } - f, ok := sb.msgFactories[carrier.Data.MsgType] - if !ok { - return - } - - msg := f() - switch msg.(type) { - case RawMessage: - msg = RawMessage(carrier.Data.Msg) - default: - unmarshalEnvelopeCarrier(carrier.Data.Msg, msg) - } - - for k, v := range carrier.Data.ConnHdr { - ctx.Conn().Set(k, v) - } - - ctx.Out(). - SetID(carrier.Data.EnvelopeID). - SetHdrMap(carrier.Data.Hdr). - SetMsg(msg). - Send() - }, - }, - ) - - ctx.Error(err) + carrier := newEnvelopeCarrier( + incomingCarrier, + utils.RandomID(32), + ctx.sb.id, + target, + ).FillWithContext(ctx) + + err = ctx.sb.sendMessage(carrier) + if err != nil { + ctx.Error(err) + ctx.StopExecution() + + return + } + + conn := sb.createSenderConn(carrier, ctx.rxt, sb.genCallback(ctx)) + select { + case <-conn.Done(): + ctx.Error(conn.Err()) + case <-ctx.ctx.Done(): + ctx.Error(ctx.ctx.Err()) + } // We should stop executing next handlers, since our request has been executed on // a remote machine @@ -262,6 +270,36 @@ func (sb *southBridge) genForwarderHandler(sel EdgeSelectorFunc) HandlerFunc { } } +func (sb *southBridge) genCallback(ctx *Context) func(carrier *envelopeCarrier) { + return func(carrier *envelopeCarrier) { + if carrier.Data == nil { + return + } + f, ok := sb.msgFactories[carrier.Data.MsgType] + if !ok { + return + } + + msg := f() + switch msg.(type) { + case RawMessage: + msg = RawMessage(carrier.Data.Msg) + default: + unmarshalEnvelopeCarrier(carrier.Data.Msg, msg) + } + + for k, v := range carrier.Data.ConnHdr { + ctx.Conn().Set(k, v) + } + + ctx.Out(). + SetID(carrier.Data.EnvelopeID). + SetHdrMap(carrier.Data.Hdr). + SetMsg(msg). + Send() + } +} + func (sb *southBridge) writeFunc(c *clusterConn, e *Envelope) error { ec := newEnvelopeCarrier( outgoingCarrier, @@ -275,28 +313,55 @@ func (sb *southBridge) writeFunc(c *clusterConn, e *Envelope) error { sb.tp.Inject(e.ctx.ctx, ec.Data) } - return c.cb.Publish(c.originID, ec.ToJSON()) + return c.cluster.Publish(c.originID, ec.ToJSON()) } -type clusterConn struct { - sessionID string - originID string - serverID string - cb Cluster +var _ Conn = (*clusterConn)(nil) - id uint64 +type clusterConn struct { clientIP string stream bool + kvMtx sync.Mutex + kv map[string]string - kvMtx sync.Mutex - kv map[string]string - wf func(c *clusterConn, e *Envelope) error + // target + serverID string + sessionID string + originID string + wf func(c *clusterConn, e *Envelope) error + cluster Cluster + + // sender + ctx context.Context //nolint + cf context.CancelFunc + callbackFn func(carrier *envelopeCarrier) + carrierChan chan *envelopeCarrier } -var _ Conn = (*clusterConn)(nil) +func (c *clusterConn) run() { + for { + select { + case <-c.ctx.Done(): + case carrier, ok := <-c.carrierChan: + if !ok { + c.cf() + + return + } + switch carrier.Kind { + default: + panic("invalid carrier kind") + case outgoingCarrier: + c.callbackFn(carrier) + case eofCarrier: + close(c.carrierChan) + } + } + } +} func (c *clusterConn) ConnID() uint64 { - return c.id + return 0 } func (c *clusterConn) ClientIP() string { @@ -347,6 +412,20 @@ func (c *clusterConn) Keys() []string { return keys } +func (c *clusterConn) Done() <-chan struct{} { + return c.ctx.Done() +} + +func (c *clusterConn) Err() error { + return c.ctx.Err() +} + +func (c *clusterConn) Cancel() { + if c.cf != nil { + c.cf() + } +} + var ( ErrSouthBridgeDisabled = errors.New("south bridge is disabled") ErrWritingToClusterConnection = errors.New("writing to cluster connection is not possible") diff --git a/kit/ctx.go b/kit/ctx.go index 575cf914..dda4d3e1 100644 --- a/kit/ctx.go +++ b/kit/ctx.go @@ -85,53 +85,6 @@ func (ctx *Context) execute(arg ExecuteArg, c Contract) { ctx.Next() } -type executeRemoteArg struct { - Target string - In *envelopeCarrier - OutCallback func(carrier *envelopeCarrier) -} - -func (ctx *Context) executeRemote(arg executeRemoteArg) error { - if ctx.sb == nil { - return ErrSouthBridgeDisabled - } - - ch, err := ctx.sb.sendMessage( - arg.In.SessionID, - arg.Target, - arg.In.ToJSON(), - ) - if err != nil { - return err - } - - var ( - cancelFn context.CancelFunc - rxCtx = context.Background() - ) - - if ctx.rxt > 0 { - rxCtx, cancelFn = context.WithTimeout(rxCtx, ctx.rxt) - defer cancelFn() - } -LOOP: - for { - select { - case <-rxCtx.Done(): - return rxCtx.Err() - case <-ctx.ctx.Done(): - return ctx.ctx.Err() - case c, ok := <-ch: - if !ok { - break LOOP - } - arg.OutCallback(c) - } - } - - return nil -} - // Next sets the next handler which will be called after the current handler. /* Here's a brief explanation of the Next() method: @@ -148,7 +101,6 @@ LOOP: you allow the processing flow to continue and pass control to the subsequent middleware functions in the chain. */ - func (ctx *Context) Next() { ctx.handlerIndex++ for ctx.handlerIndex <= len(ctx.handlers) { diff --git a/kit/edge.go b/kit/edge.go index 31c6db32..db482c9d 100644 --- a/kit/edge.go +++ b/kit/edge.go @@ -131,7 +131,7 @@ func (s *EdgeServer) registerCluster(id string, cb Cluster) *EdgeServer { cb: cb, tp: s.t, inProgressMtx: utils.SpinLock{}, - inProgress: map[string]chan *envelopeCarrier{}, + inProgress: map[string]*clusterConn{}, msgFactories: map[string]MessageFactoryFunc{}, l: s.l, } diff --git a/kit/envelope_carrier.go b/kit/envelope_carrier.go index 5bccfdee..0640c113 100644 --- a/kit/envelope_carrier.go +++ b/kit/envelope_carrier.go @@ -15,16 +15,16 @@ const ( eofCarrier ) -// envelopeCarrier is a serializable message which is used by Cluster component of the +// envelopeCarrier is a serializable message which is used by the Cluster component of the // EdgeServer to send information from one instance to another instance. type envelopeCarrier struct { // SessionID is a unique identifier for each remote-execution session. SessionID string `json:"id"` - // Kind identifies what type of the data this carrier has + // Kind identifies the purpose of the message Kind carrierKind `json:"kind"` - // OriginID the instance's id of the sender of this message + // OriginID the sender's id of the message OriginID string `json:"originID"` - // TargetID the instance's id of the receiver of this message + // TargetID the receiver's id of the message TargetID string `json:"targetID"` Data *carrierData `json:"data"` } diff --git a/std/clusters/p2pcluster/cluster.go b/std/clusters/p2pcluster/cluster.go index 97a78a37..b6a7f36d 100644 --- a/std/clusters/p2pcluster/cluster.go +++ b/std/clusters/p2pcluster/cluster.go @@ -2,6 +2,7 @@ package p2pcluster import ( "context" + "encoding/binary" "errors" "fmt" "time" @@ -176,28 +177,84 @@ func (c *cluster) startMyTopic(ctx context.Context) error { return err } - go func() { - for { - msg, err := sub.Next(ctx) - if err != nil { - if errors.Is(err, context.Canceled) { - return - } + var ( + inputChan = make(chan psMsg, 1024) + orderedChan = make(chan psMsg, 1024) + ) - c.log.Errorf("[p2pCluster] failed to receive message from topic[%s]: %v", topic, err) + go c.receiveMessage(ctx, sub, inputChan, orderedChan) + go c.handleMessage(orderedChan) + go c.tryOrderMessage(inputChan, orderedChan) - continue + return nil +} + +func (c *cluster) receiveMessage( + ctx context.Context, sub *pubsub.Subscription, + inputChan, orderedChan chan psMsg, +) { + for { + msg, err := sub.Next(ctx) + if err != nil { + if errors.Is(err, context.Canceled) { + close(inputChan) + close(orderedChan) + + return } - if msg.ReceivedFrom == c.host.ID() { - continue + c.log.Errorf("[p2pCluster] failed to receive message from topic[%s]: %v", sub.Topic(), err) + + continue + } + + if msg.ReceivedFrom == c.host.ID() { + continue + } + + inputChan <- psMsg{m: msg} + } +} + +func (c *cluster) handleMessage(orderedChan chan psMsg) { + for msg := range orderedChan { + c.d.OnMessage(msg.m.GetData()) + } +} + +func (c *cluster) tryOrderMessage(inputChan, orderedChan chan psMsg) { + var lastSeq uint64 + debounceTime := time.Millisecond * 100 + for msg := range inputChan { + seq := binary.BigEndian.Uint64(msg.m.GetSeqno()) + if lastSeq == 0 || seq == lastSeq+1 { + orderedChan <- msg + lastSeq = seq + + continue + } + + if seq < lastSeq { + orderedChan <- msg + + continue + } + + if msg.debounced > 1 { + orderedChan <- msg + if seq > lastSeq { + lastSeq = seq } - go c.d.OnMessage(msg.GetData()) + continue } - }() - return nil + go func(msg psMsg) { + time.Sleep(time.Duration(msg.debounced) * debounceTime) + msg.debounced++ + inputChan <- msg + }(msg) + } } func (c *cluster) getTopic(id string) *pubsub.Topic { @@ -257,3 +314,8 @@ func (c *cluster) HandlePeerFound(pi peer.AddrInfo) { c.log.Errorf("[p2pCluster] failed to connect to peer(%s): %v", pi.String(), err) } } + +type psMsg struct { + m *pubsub.Message + debounced int +} diff --git a/std/clusters/rediscluster/cluster.go b/std/clusters/rediscluster/cluster.go index e266f22d..7704e5c5 100644 --- a/std/clusters/rediscluster/cluster.go +++ b/std/clusters/rediscluster/cluster.go @@ -86,7 +86,7 @@ func (c *cluster) Start(ctx context.Context) error { if !ok { return } - go c.d.OnMessage(utils.S2B(msg.Payload)) + c.d.OnMessage(utils.S2B(msg.Payload)) case <-runCtx.Done(): _ = c.ps.Close() diff --git a/testenv/cluster_test.go b/testenv/cluster_test.go index c1d8dca2..1db015df 100644 --- a/testenv/cluster_test.go +++ b/testenv/cluster_test.go @@ -142,7 +142,7 @@ func kitWithCluster(t *testing.T, opt fx.Option) func(c C) { time.Sleep(time.Second * 5) hosts := []string{"localhost:8082", "localhost:8083"} - for range 100 { + for range 500 { key := "K_" + utils.RandomID(10) value := "V_" + utils.RandomID(10) setHostIndex := utils.RandomInt(len(hosts)) From 794667ce3561fcb35eff1847338b08402a8f7d63 Mon Sep 17 00:00:00 2001 From: Ehsan Noureddin Moosa Date: Sun, 21 Jul 2024 09:06:24 +0300 Subject: [PATCH 3/3] [p2pCluster] hackfix to improve message orders, by introducing delay and batch --- kit/bridge_south.go | 49 +++++++++--------- kit/utils/batch/batcher.go | 19 ++++--- kit/utils/batch/batcher_test.go | 35 +++++++++++++ std/clusters/p2pcluster/cluster.go | 81 +++++++++++++----------------- 4 files changed, 107 insertions(+), 77 deletions(-) diff --git a/kit/bridge_south.go b/kit/bridge_south.go index 8d9ee5e9..769fcaa3 100644 --- a/kit/bridge_south.go +++ b/kit/bridge_south.go @@ -119,7 +119,32 @@ func (sb *southBridge) createSenderConn( sb.inProgress[carrier.SessionID] = conn sb.inProgressMtx.Unlock() - go conn.run() + go func(c *clusterConn) { + for { + select { + case <-c.ctx.Done(): + return + case carrier, ok := <-c.carrierChan: + if !ok { + c.cf() + + return + } + switch carrier.Kind { + default: + panic("invalid carrier kind") + case outgoingCarrier: + c.callbackFn(carrier) + case eofCarrier: + sb.inProgressMtx.Lock() + delete(sb.inProgress, c.sessionID) + sb.inProgressMtx.Unlock() + + close(c.carrierChan) + } + } + } + }(conn) return conn } @@ -338,28 +363,6 @@ type clusterConn struct { carrierChan chan *envelopeCarrier } -func (c *clusterConn) run() { - for { - select { - case <-c.ctx.Done(): - case carrier, ok := <-c.carrierChan: - if !ok { - c.cf() - - return - } - switch carrier.Kind { - default: - panic("invalid carrier kind") - case outgoingCarrier: - c.callbackFn(carrier) - case eofCarrier: - close(c.carrierChan) - } - } - } -} - func (c *clusterConn) ConnID() uint64 { return 0 } diff --git a/kit/utils/batch/batcher.go b/kit/utils/batch/batcher.go index 186146a3..5e50b122 100644 --- a/kit/utils/batch/batcher.go +++ b/kit/utils/batch/batcher.go @@ -18,7 +18,7 @@ import ( type NA = struct{} -type Func[IN, OUT any] func(targetID string, entries []Entry[IN, OUT]) +type Func[IN, OUT any] func(tagID string, entries []Entry[IN, OUT]) type MultiBatcher[IN, OUT any] struct { cfg config @@ -66,7 +66,7 @@ func (fp *MultiBatcher[IN, OUT]) EnterAndWait(targetID string, entry Entry[IN, O } type Batcher[IN, OUT any] struct { - utils.SpinLock + spin utils.SpinLock readyWorkers int32 batchSize int32 @@ -76,6 +76,9 @@ type Batcher[IN, OUT any] struct { tagID string } +// NewBatcher construct a new Batcher with tagID. `tagID` is the value that will be passed to +// Func on every batch. This lets you define the same batch func with multiple Batcher objects; MultiBatcher +// is using `tagID` internally to handle different batches of entries in parallel. func NewBatcher[IN, OUT any](f Func[IN, OUT], tagID string, opt ...Option) *Batcher[IN, OUT] { cfg := defaultConfig for _, o := range opt { @@ -97,14 +100,14 @@ func newBatcher[IN, OUT any](f Func[IN, OUT], tagID string, cfg config) *Batcher } func (f *Batcher[IN, OUT]) startWorker() { - f.Lock() + f.spin.Lock() if atomic.AddInt32(&f.readyWorkers, -1) < 0 { atomic.AddInt32(&f.readyWorkers, 1) - f.Unlock() + f.spin.Unlock() return } - f.Unlock() + f.spin.Unlock() w := &worker[IN, OUT]{ f: f, @@ -156,15 +159,15 @@ func (w *worker[IN, OUT]) run() { continue } } - w.f.Lock() + w.f.spin.Lock() if len(el) == 0 { // clean up and shutdown the worker atomic.AddInt32(&w.f.readyWorkers, 1) - w.f.Unlock() + w.f.spin.Unlock() break } - w.f.Unlock() + w.f.spin.Unlock() w.f.flusherFunc(w.f.tagID, el) for idx := range el { el[idx].done() diff --git a/kit/utils/batch/batcher_test.go b/kit/utils/batch/batcher_test.go index d9bd3015..dccb659e 100644 --- a/kit/utils/batch/batcher_test.go +++ b/kit/utils/batch/batcher_test.go @@ -135,3 +135,38 @@ var _ = Describe("Flusher With Callback", func() { Expect(sum).To(Equal(total * (total - 1) / 2)) }) }) + +func ExampleBatcher() { + averageAll := func(targetID string, entries []batch.Entry[float64, float64]) { + var ( + sum float64 + n int + ) + for _, entry := range entries { + sum += entry.Value() + n++ + } + avg := sum / float64(n) + + for _, e := range entries { + e.Callback(avg) + } + } + b := batch.NewBatcher( + averageAll, "tag1", + batch.WithBatchSize(10), + batch.WithMinWaitTime(time.Second), + ) + wg := sync.WaitGroup{} + for i := 0.0; i < 10.0; i++ { + wg.Add(1) + go func(i float64) { + t := time.Now() + b.EnterAndWait( + batch.NewEntry(i, func(out float64) { fmt.Println("duration:", time.Now().Sub(t), "avg:", out) }), + ) + wg.Done() + }(i) + } + wg.Wait() +} diff --git a/std/clusters/p2pcluster/cluster.go b/std/clusters/p2pcluster/cluster.go index b6a7f36d..c5f0c6dd 100644 --- a/std/clusters/p2pcluster/cluster.go +++ b/std/clusters/p2pcluster/cluster.go @@ -1,14 +1,16 @@ package p2pcluster import ( + "bytes" "context" - "encoding/binary" "errors" "fmt" + "sort" "time" "github.com/clubpay/ronykit/kit" "github.com/clubpay/ronykit/kit/utils" + "github.com/clubpay/ronykit/kit/utils/batch" "github.com/libp2p/go-libp2p" pubsub "github.com/libp2p/go-libp2p-pubsub" "github.com/libp2p/go-libp2p/core/host" @@ -177,29 +179,27 @@ func (c *cluster) startMyTopic(ctx context.Context) error { return err } - var ( - inputChan = make(chan psMsg, 1024) - orderedChan = make(chan psMsg, 1024) + inputChan := make(chan *pubsub.Message, 1024) + batcher := batch.NewMulti( + genSortMessages(inputChan), + batch.WithMaxWorkers(1), + batch.WithMinWaitTime(time.Millisecond*25), ) - go c.receiveMessage(ctx, sub, inputChan, orderedChan) - go c.handleMessage(orderedChan) - go c.tryOrderMessage(inputChan, orderedChan) + go c.receiveMessage(ctx, sub, batcher) + go c.handleMessage(ctx, inputChan) return nil } func (c *cluster) receiveMessage( ctx context.Context, sub *pubsub.Subscription, - inputChan, orderedChan chan psMsg, + b *batch.MultiBatcher[*pubsub.Message, batch.NA], ) { for { msg, err := sub.Next(ctx) if err != nil { if errors.Is(err, context.Canceled) { - close(inputChan) - close(orderedChan) - return } @@ -212,48 +212,37 @@ func (c *cluster) receiveMessage( continue } - inputChan <- psMsg{m: msg} + b.Enter( + msg.GetFrom().String(), + batch.NewEntry[*pubsub.Message, batch.NA](msg, nil), + ) } } -func (c *cluster) handleMessage(orderedChan chan psMsg) { - for msg := range orderedChan { - c.d.OnMessage(msg.m.GetData()) +func (c *cluster) handleMessage(ctx context.Context, inputChan chan *pubsub.Message) { + for { + select { + case <-ctx.Done(): + return + case msg := <-inputChan: + c.d.OnMessage(msg.GetData()) + } } } -func (c *cluster) tryOrderMessage(inputChan, orderedChan chan psMsg) { - var lastSeq uint64 - debounceTime := time.Millisecond * 100 - for msg := range inputChan { - seq := binary.BigEndian.Uint64(msg.m.GetSeqno()) - if lastSeq == 0 || seq == lastSeq+1 { - orderedChan <- msg - lastSeq = seq - - continue +func genSortMessages( + inputChan chan *pubsub.Message, +) func(tagID string, entries []batch.Entry[*pubsub.Message, batch.NA]) { + return func(tagID string, entries []batch.Entry[*pubsub.Message, batch.NA]) { + sort.Slice( + entries, func(i, j int) bool { + return bytes.Compare(entries[i].Value().GetSeqno(), entries[j].Value().GetSeqno()) < 1 + }, + ) + + for _, entry := range entries { + inputChan <- entry.Value() } - - if seq < lastSeq { - orderedChan <- msg - - continue - } - - if msg.debounced > 1 { - orderedChan <- msg - if seq > lastSeq { - lastSeq = seq - } - - continue - } - - go func(msg psMsg) { - time.Sleep(time.Duration(msg.debounced) * debounceTime) - msg.debounced++ - inputChan <- msg - }(msg) } }