Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize(exit): server graceful shutdown logic to avoid EOF when idle connections receive new requests after being closed #1681

Merged
merged 2 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions pkg/remote/trans/default_server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
"net"
"runtime/debug"
"sync/atomic"

"github.com/cloudwego/kitex/pkg/endpoint"
"github.com/cloudwego/kitex/pkg/kerrors"
Expand Down Expand Up @@ -49,13 +50,14 @@ func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote.
}

type svrTransHandler struct {
opt *remote.ServerOption
svcSearcher remote.ServiceSearcher
targetSvcInfo *serviceinfo.ServiceInfo
inkHdlFunc endpoint.Endpoint
codec remote.Codec
transPipe *remote.TransPipeline
ext Extension
opt *remote.ServerOption
svcSearcher remote.ServiceSearcher
targetSvcInfo *serviceinfo.ServiceInfo
inkHdlFunc endpoint.Endpoint
codec remote.Codec
transPipe *remote.TransPipeline
ext Extension
inGracefulShutdown uint32
}

// Write implements the remote.ServerTransHandler interface.
Expand Down Expand Up @@ -115,13 +117,22 @@ func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, recvMsg remot
}

func (t *svrTransHandler) newCtxWithRPCInfo(ctx context.Context, conn net.Conn) (context.Context, rpcinfo.RPCInfo) {
var ri rpcinfo.RPCInfo
if rpcinfo.PoolEnabled() { // reuse per-connection rpcinfo
return ctx, rpcinfo.GetRPCInfo(ctx)
ri = rpcinfo.GetRPCInfo(ctx)
// delayed reinitialize for faster response
} else {
// new rpcinfo if reuse is disabled
ri = t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr())
ctx = rpcinfo.NewCtxWithRPCInfo(ctx, ri)
}
// new rpcinfo if reuse is disabled
ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr())
return rpcinfo.NewCtxWithRPCInfo(ctx, ri), ri
if atomic.LoadUint32(&t.inGracefulShutdown) == 1 {
// If server is in graceful shutdown status, mark connection reset flag to all responses to let client close the connections.
if ei := rpcinfo.AsTaggable(ri.To()); ei != nil {
ei.SetTag(rpcinfo.ConnResetTag, "1")
}
}
return ctx, ri
}

// OnRead implements the remote.ServerTransHandler interface.
Expand Down Expand Up @@ -343,6 +354,11 @@ func (t *svrTransHandler) finishProfiler(ctx context.Context) {
t.opt.Profiler.Untag(ctx)
}

func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error {
atomic.StoreUint32(&t.inGracefulShutdown, 1)
return nil
}

func getRemoteInfo(ri rpcinfo.RPCInfo, conn net.Conn) (string, net.Addr) {
rAddr := conn.RemoteAddr()
if ri == nil {
Expand Down
9 changes: 9 additions & 0 deletions pkg/remote/trans/netpoll/trans_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"runtime/debug"
"sync"
"syscall"
"time"

"github.com/cloudwego/netpoll"

Expand Down Expand Up @@ -119,6 +120,14 @@ func (ts *transServer) Shutdown() (err error) {
if err != nil {
klog.Warnf("KITEX: server graceful shutdown error: %v", err)
}
// 3. wait some time to receive requests before closing idle conns
/*
When the netpoll eventloop shutdown, all idle connections will be closed.
At this time, these connections may just receive requests, and then the peer side will report an EOF error.
To reduce such cases, wait for some time to try to receive these requests as much as possible,
so that the closing of connections can be controlled by the upper-layer protocol and the EOF problem can be reduced.
*/
time.Sleep(100 * time.Millisecond)
}
}
if ts.evl != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/retry/failure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestFixedBackOff_Wait(t *testing.T) {
bk.Wait(1)
waitTime := time.Since(startTime)
test.Assert(t, time.Millisecond*fix <= waitTime)
test.Assert(t, waitTime < time.Millisecond*(fix+5))
test.Assert(t, waitTime < time.Millisecond*(fix*2))
}

func TestFixedBackOff_String(t *testing.T) {
Expand Down
8 changes: 8 additions & 0 deletions pkg/transmeta/ttheader.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ func (ch *clientTTHeaderHandler) ReadMeta(ctx context.Context, msg remote.Messag
if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok && bizErr != nil {
setter.SetBizStatusErr(bizErr)
}
if val, ok := strInfo[transmeta.HeaderConnectionReadyToReset]; ok {
if ei := rpcinfo.AsTaggable(ri.To()); ei != nil {
ei.SetTag(rpcinfo.ConnResetTag, val)
}
}
return ctx, nil
}

Expand Down Expand Up @@ -190,6 +195,9 @@ func (sh *serverTTHeaderHandler) WriteMeta(ctx context.Context, msg remote.Messa
strInfo[bizExtra], _ = utils.Map2JSONStr(bizErr.BizExtra())
}
}
if val, ok := ri.To().Tag(rpcinfo.ConnResetTag); ok {
strInfo[transmeta.HeaderConnectionReadyToReset] = val
}

return ctx, nil
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/transmeta/ttheader_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ func TestTTHeaderServerReadMetainfo(t *testing.T) {

func TestTTHeaderServerWriteMetainfo(t *testing.T) {
ctx := context.Background()
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation("", ""),
ri := rpcinfo.NewRPCInfo(nil, rpcinfo.NewEndpointInfo("", "mock", nil, nil), rpcinfo.NewInvocation("", ""),
rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
msg := remote.NewMessage(nil, mocks.ServiceInfo(), ri, remote.Call, remote.Client)

Expand Down
136 changes: 71 additions & 65 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -979,80 +979,86 @@ func TestInvokeHandlerPanic(t *testing.T) {
}

func TestRegisterService(t *testing.T) {
svr := NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
{
svr := NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})

svr.Run()
svr.Run()

test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "server is running")
}
return true
})
svr.Stop()
test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "server is running")
}
return true
})
svr.Stop()
}

svr = NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
{
svr := NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})

test.PanicAt(t, func() {
_ = svr.RegisterService(nil, mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "svcInfo is nil")
}
return true
})
test.PanicAt(t, func() {
_ = svr.RegisterService(nil, mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "svcInfo is nil")
}
return true
})

test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), nil)
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "handler is nil")
}
return true
})
test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), nil)
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "handler is nil")
}
return true
})

test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler(), WithFallbackService())
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "Service[MockService] is already defined")
}
return true
})
test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler(), WithFallbackService())
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "Service[MockService] is already defined")
}
return true
})

test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler(), WithFallbackService())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "multiple fallback services cannot be registered")
}
return true
})
svr.Stop()
test.PanicAt(t, func() {
_ = svr.RegisterService(mocks.Service2Info(), mocks.MyServiceHandler(), WithFallbackService())
}, func(err interface{}) bool {
if errMsg, ok := err.(string); ok {
return strings.Contains(errMsg, "multiple fallback services cannot be registered")
}
return true
})
svr.Stop()
}

svr = NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
{
svr := NewServer()
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})

_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
_ = svr.RegisterService(mocks.Service3Info(), mocks.MyServiceHandler())
err := svr.Run()
test.Assert(t, err != nil)
test.Assert(t, err.Error() == "method name [mock] is conflicted between services but no fallback service is specified")
svr.Stop()
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
_ = svr.RegisterService(mocks.Service3Info(), mocks.MyServiceHandler())
err := svr.Run()
test.Assert(t, err != nil)
test.Assert(t, err.Error() == "method name [mock] is conflicted between services but no fallback service is specified")
svr.Stop()
}
}

func TestRegisterServiceWithMiddleware(t *testing.T) {
Expand Down
Loading