Skip to content

Commit

Permalink
fix: change WithOnlyAcceptingHTTP2Traffic to WithRefuseTrafficWithout…
Browse files Browse the repository at this point in the history
…ServiceName
  • Loading branch information
Marina-Sakai committed Jan 22, 2024
1 parent 8344224 commit 5b4fb56
Show file tree
Hide file tree
Showing 16 changed files with 85 additions and 59 deletions.
2 changes: 1 addition & 1 deletion internal/server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ type Options struct {

Streaming stream.StreamingConfig

OnlyAcceptingHTTP2Traffic bool
RefuseTrafficWithoutServiceName bool
}

type Limit struct {
Expand Down
2 changes: 1 addition & 1 deletion pkg/remote/codec/header_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func initServerRecvMsg() remote.Message {
mocks.MockErrorMethod: svcInfo,
mocks.MockOnewayMethod: svcInfo,
}
msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, mockSvrRPCInfo, remote.Call, remote.Server)
msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, mockSvrRPCInfo, remote.Call, remote.Server, false)
return msg
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/remote/codec/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestSetOrCheckMethodName(t *testing.T) {
mocks.MockErrorMethod: svcInfo,
mocks.MockOnewayMethod: svcInfo,
}
msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server)
msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false)
err := SetOrCheckMethodName("mock", msg)
test.Assert(t, err == nil)
ri = msg.RPCInfo()
Expand All @@ -50,7 +50,7 @@ func TestSetOrCheckMethodName(t *testing.T) {
test.Assert(t, ri.Invocation().MethodName() == "mock")
test.Assert(t, ri.To().Method() == "mock")

msg = remote.NewMessageWithNewer(svcInfo, map[string]*serviceinfo.ServiceInfo{}, ri, remote.Call, remote.Server)
msg = remote.NewMessageWithNewer(svcInfo, map[string]*serviceinfo.ServiceInfo{}, ri, remote.Call, remote.Server, false)
err = SetOrCheckMethodName("dummy", msg)
test.Assert(t, err != nil)
test.Assert(t, err.Error() == "unknown method dummy")
Expand Down
31 changes: 18 additions & 13 deletions pkg/remote/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,15 @@ func NewMessage(data interface{}, svcInfo *serviceinfo.ServiceInfo, ri rpcinfo.R
}

// NewMessageWithNewer creates a new Message and set data later.
func NewMessageWithNewer(targetSvcInfo *serviceinfo.ServiceInfo, svcSearchMap map[string]*serviceinfo.ServiceInfo, ri rpcinfo.RPCInfo, msgType MessageType, rpcRole RPCRole) Message {
func NewMessageWithNewer(targetSvcInfo *serviceinfo.ServiceInfo, svcSearchMap map[string]*serviceinfo.ServiceInfo, ri rpcinfo.RPCInfo, msgType MessageType, rpcRole RPCRole, refuseTrafficWithoutServiceName bool) Message {
msg := messagePool.Get().(*message)
msg.rpcInfo = ri
msg.targetSvcInfo = targetSvcInfo
msg.svcSearchMap = svcSearchMap
msg.msgType = msgType
msg.rpcRole = rpcRole
msg.transInfo = transInfoPool.Get().(*transInfo)
msg.refuseTrafficWithoutServiceName = refuseTrafficWithoutServiceName
return msg
}

Expand All @@ -137,18 +138,19 @@ func newMessage() interface{} {
}

type message struct {
msgType MessageType
data interface{}
rpcInfo rpcinfo.RPCInfo
targetSvcInfo *serviceinfo.ServiceInfo
svcSearchMap map[string]*serviceinfo.ServiceInfo
rpcRole RPCRole
compressType CompressType
payloadSize int
transInfo TransInfo
tags map[string]interface{}
protocol ProtocolInfo
payloadCodec PayloadCodec
msgType MessageType
data interface{}
rpcInfo rpcinfo.RPCInfo
targetSvcInfo *serviceinfo.ServiceInfo
svcSearchMap map[string]*serviceinfo.ServiceInfo
rpcRole RPCRole
compressType CompressType
payloadSize int
transInfo TransInfo
tags map[string]interface{}
protocol ProtocolInfo
payloadCodec PayloadCodec
refuseTrafficWithoutServiceName bool
}

func (m *message) zero() {
Expand Down Expand Up @@ -187,6 +189,9 @@ func (m *message) SetServiceInfo(svcName, methodName string) (*serviceinfo.Servi
}
return m.targetSvcInfo, nil
}
if svcName == "" && m.refuseTrafficWithoutServiceName {
return nil, NewTransErrorWithMsg(NoServiceName, "no service name while the server has WithRefuseTrafficWithoutServiceName option enabled")

Check warning on line 193 in pkg/remote/message.go

View check run for this annotation

Codecov / codecov/patch

pkg/remote/message.go#L192-L193

Added lines #L192 - L193 were not covered by tests
}
var key string
if svcName == serviceinfo.GenericService || svcName == "" {
key = methodName
Expand Down
3 changes: 2 additions & 1 deletion pkg/remote/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ type ServerOption struct {

GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error

OnlyAcceptingHTTP2Traffic bool
// RefuseTrafficWithoutServiceName is used for a server with multi services
RefuseTrafficWithoutServiceName bool

Option

Expand Down
5 changes: 1 addition & 4 deletions pkg/remote/trans/default_server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ import (

// NewDefaultSvrTransHandler to provide default impl of svrTransHandler, it can be reused in netpoll, shm-ipc, framework-sdk extensions
func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote.ServerTransHandler, error) {
if opt.OnlyAcceptingHTTP2Traffic {
return nil, remote.NewTransErrorWithMsg(remote.InvalidProtocol, "only http2 traffic is accepted")
}
svrHdlr := &svrTransHandler{
opt: opt,
codec: opt.Codec,
Expand Down Expand Up @@ -167,7 +164,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error)
}()
ctx = t.startTracer(ctx, ri)
ctx = t.startProfiler(ctx)
recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, ri, remote.Call, remote.Server)
recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, ri, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName)
recvMsg.SetPayloadCodec(t.opt.PayloadCodec)
ctx, err = t.transPipe.Read(ctx, conn, recvMsg)
if err != nil {
Expand Down
5 changes: 0 additions & 5 deletions pkg/remote/trans/default_server_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,6 @@ func TestDefaultSvrTransHandler(t *testing.T) {
test.Assert(t, err == nil, err)
test.Assert(t, tagEncode == 1, tagEncode)
test.Assert(t, tagDecode == 1, tagDecode)

opt = &remote.ServerOption{OnlyAcceptingHTTP2Traffic: true}
_, err = NewDefaultSvrTransHandler(opt, ext)
test.Assert(t, err != nil)
test.Assert(t, err.Error() == "only http2 traffic is accepted")
}

func TestSvrTransHandlerBizError(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/remote/trans/netpoll/http_client_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func TestHTTPOnMessage(t *testing.T) {
}
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, method), nil, rpcinfo.NewRPCStats())
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server)
recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false)
sendMsg := remote.NewMessage(svcInfo.MethodInfo(method).NewResult(), svcInfo, ri, remote.Reply, remote.Server)

// 2. test
Expand Down
5 changes: 1 addition & 4 deletions pkg/remote/trans/netpollmux/server_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remo
}

func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) {
if opt.OnlyAcceptingHTTP2Traffic {
return nil, remote.NewTransErrorWithMsg(remote.InvalidProtocol, "only http2 traffic is accepted")
}
svrHdlr := &svrTransHandler{
opt: opt,
codec: opt.Codec,
Expand Down Expand Up @@ -238,7 +235,7 @@ func (t *svrTransHandler) task(muxSvrConnCtx context.Context, conn net.Conn, rea
}()

// read
recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, rpcInfo, remote.Call, remote.Server)
recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, rpcInfo, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName)
bufReader := np.NewReaderByteBuffer(reader)
err = t.readWithByteBuffer(ctx, bufReader, recvMsg)
if err != nil {
Expand Down
4 changes: 0 additions & 4 deletions pkg/remote/trans/netpollmux/server_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ func TestNewTransHandler(t *testing.T) {
handler, err := NewSvrTransHandlerFactory().NewTransHandler(&remote.ServerOption{})
test.Assert(t, err == nil, err)
test.Assert(t, handler != nil)

_, err = NewSvrTransHandlerFactory().NewTransHandler(&remote.ServerOption{OnlyAcceptingHTTP2Traffic: true})
test.Assert(t, err != nil, err)
test.Assert(t, err.Error() == "only http2 traffic is accepted")
}

// TestOnActive test ServerTransHandler OnActive
Expand Down
1 change: 1 addition & 0 deletions pkg/remote/trans_errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const (
UnsupportedClientType = 10
// kitex's own type id from number 20
UnknownService = 20
NoServiceName = 21
)

var defaultTransErrorMessage = map[int32]string{
Expand Down
8 changes: 5 additions & 3 deletions server/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,9 +357,11 @@ func WithContextBackup(enable, async bool) Option {
}}
}

// WithOnlyAcceptingHTTP2Traffic returns an Option that accepts only http2 traffic.
func WithOnlyAcceptingHTTP2Traffic(enable bool) Option {
// WithRefuseTrafficWithoutServiceName returns an Option that only accepts traffics with service name.
// This is used for a server with multi services and is one of the options to avoid a server startup error
// when having conflicting method names between services without specifying a fallback service for the method.
func WithRefuseTrafficWithoutServiceName(enable bool) Option {
return Option{F: func(o *internal_server.Options, di *utils.Slice) {
o.OnlyAcceptingHTTP2Traffic = enable
o.RefuseTrafficWithoutServiceName = enable
}}
}
16 changes: 15 additions & 1 deletion server/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,11 +453,25 @@ func TestWithProfilerMessageTagging(t *testing.T) {
mocks.MockErrorMethod: svcInfo,
mocks.MockOnewayMethod: svcInfo,
}
msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server)
msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false)

newCtx, tags := iSvr.opt.RemoteOpt.ProfilerMessageTagging(ctx, msg)
test.Assert(t, len(tags) == 8)
test.DeepEqual(t, tags, []string{"b", "2", "c", "2", "a", "1", "b", "1"})
test.Assert(t, newCtx.Value("ctx1").(int) == 1)
test.Assert(t, newCtx.Value("ctx2").(int) == 2)
}

func TestRefuseTrafficWithoutServiceNamOption(t *testing.T) {
svr := NewServer(WithRefuseTrafficWithoutServiceName(true))
err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
test.Assert(t, err == nil, err)
time.AfterFunc(100*time.Millisecond, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
err = svr.Run()
test.Assert(t, err == nil, err)
iSvr := svr.(*server)
test.Assert(t, iSvr.opt.RefuseTrafficWithoutServiceName)
}
30 changes: 30 additions & 0 deletions server/register_option_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Copyright 2024 CloudWeGo Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package server

import (
"testing"

internal_server "github.com/cloudwego/kitex/internal/server"
"github.com/cloudwego/kitex/internal/test"
)

func TestWithFallbackService(t *testing.T) {
opts := []RegisterOption{WithFallbackService()}
registerOpts := internal_server.NewRegisterOptions(opts)
test.Assert(t, registerOpts.IsFallbackService)
}
8 changes: 4 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ func (s *server) initBasicRemoteOption() {
remoteOpt := s.opt.RemoteOpt
remoteOpt.TargetSvcInfo = s.targetSvcInfo
remoteOpt.SvcSearchMap = s.svcs.getSvcInfoSearchMap()
remoteOpt.OnlyAcceptingHTTP2Traffic = s.opt.OnlyAcceptingHTTP2Traffic
remoteOpt.RefuseTrafficWithoutServiceName = s.opt.RefuseTrafficWithoutServiceName
remoteOpt.InitOrResetRPCInfoFunc = s.initOrResetRPCInfoFunc()
remoteOpt.TracerCtl = s.opt.TracerCtl
remoteOpt.ReadWriteTimeout = s.opt.Configs.ReadWriteTimeout()
Expand Down Expand Up @@ -448,7 +448,7 @@ func (s *server) check() error {
s.targetSvcInfo = getDefaultSvcInfo(s.svcs)
return nil
}
return checkFallbackServiceForConflictingMethods(s.svcs.conflictingMethodHasFallbackSvcMap, s.opt.OnlyAcceptingHTTP2Traffic)
return checkFallbackServiceForConflictingMethods(s.svcs.conflictingMethodHasFallbackSvcMap, s.opt.RefuseTrafficWithoutServiceName)
}

func doAddBoundHandlerToHead(h remote.BoundHandler, opt *remote.ServerOption) {
Expand Down Expand Up @@ -572,8 +572,8 @@ func getDefaultSvcInfo(svcs *services) *serviceinfo.ServiceInfo {
return nil

Check warning on line 572 in server/server.go

View check run for this annotation

Codecov / codecov/patch

server/server.go#L572

Added line #L572 was not covered by tests
}

func checkFallbackServiceForConflictingMethods(conflictingMethodHasFallbackSvcMap map[string]bool, onlyAcceptingHTTP2Traffic bool) error {
if onlyAcceptingHTTP2Traffic {
func checkFallbackServiceForConflictingMethods(conflictingMethodHasFallbackSvcMap map[string]bool, refuseTrafficWithoutServiceName bool) error {
if refuseTrafficWithoutServiceName {
return nil

Check warning on line 577 in server/server.go

View check run for this annotation

Codecov / codecov/patch

server/server.go#L577

Added line #L577 was not covered by tests
}
for name, hasFallbackSvc := range conflictingMethodHasFallbackSvcMap {
Expand Down
18 changes: 3 additions & 15 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ func testInvokeHandlerWithSession(t *testing.T, fail bool, ad string) {
{ // mock server call
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats())
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server)
recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false)
recvMsg.NewData(callMethod)
sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server)

Expand Down Expand Up @@ -742,7 +742,7 @@ func TestInvokeHandlerExec(t *testing.T) {
{ // mock server call
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats())
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server)
recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false)
recvMsg.NewData(callMethod)
sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server)

Expand Down Expand Up @@ -805,7 +805,7 @@ func TestInvokeHandlerPanic(t *testing.T) {
// mock server call
ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats())
ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server)
recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false)
recvMsg.NewData(callMethod)
sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server)

Expand Down Expand Up @@ -921,18 +921,6 @@ func TestRegisterService(t *testing.T) {
test.Assert(t, err != nil)
test.Assert(t, err.Error() == "method name [mock] is conflicted between services but no fallback service is specified")
svr.Stop()

svr = NewServer(WithOnlyAcceptingHTTP2Traffic(true))
time.AfterFunc(time.Second, func() {
err := svr.Stop()
test.Assert(t, err == nil, err)
})
_ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler())
_ = svr.RegisterService(mocks.Service3Info(), mocks.MyServiceHandler())
err = svr.Run()
test.Assert(t, err != nil)
test.Assert(t, err.Error() == "only http2 traffic is accepted")
svr.Stop()
}

type noopMetahandler struct{}
Expand Down

0 comments on commit 5b4fb56

Please sign in to comment.