diff --git a/internal/server/option.go b/internal/server/option.go index c9a4ac5d3b..35aa34dcec 100644 --- a/internal/server/option.go +++ b/internal/server/option.go @@ -98,7 +98,7 @@ type Options struct { Streaming stream.StreamingConfig - OnlyAcceptingHTTP2Traffic bool + RefuseTrafficWithoutServiceName bool } type Limit struct { diff --git a/pkg/remote/codec/header_codec_test.go b/pkg/remote/codec/header_codec_test.go index 864f55f177..64275ef37b 100644 --- a/pkg/remote/codec/header_codec_test.go +++ b/pkg/remote/codec/header_codec_test.go @@ -319,7 +319,7 @@ func initServerRecvMsg() remote.Message { mocks.MockErrorMethod: svcInfo, mocks.MockOnewayMethod: svcInfo, } - msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, mockSvrRPCInfo, remote.Call, remote.Server) + msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, mockSvrRPCInfo, remote.Call, remote.Server, false) return msg } diff --git a/pkg/remote/codec/util_test.go b/pkg/remote/codec/util_test.go index a7f75fdb0d..672edb5e74 100644 --- a/pkg/remote/codec/util_test.go +++ b/pkg/remote/codec/util_test.go @@ -41,7 +41,7 @@ func TestSetOrCheckMethodName(t *testing.T) { mocks.MockErrorMethod: svcInfo, mocks.MockOnewayMethod: svcInfo, } - msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server) + msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) err := SetOrCheckMethodName("mock", msg) test.Assert(t, err == nil) ri = msg.RPCInfo() @@ -50,7 +50,7 @@ func TestSetOrCheckMethodName(t *testing.T) { test.Assert(t, ri.Invocation().MethodName() == "mock") test.Assert(t, ri.To().Method() == "mock") - msg = remote.NewMessageWithNewer(svcInfo, map[string]*serviceinfo.ServiceInfo{}, ri, remote.Call, remote.Server) + msg = remote.NewMessageWithNewer(svcInfo, map[string]*serviceinfo.ServiceInfo{}, ri, remote.Call, remote.Server, false) err = SetOrCheckMethodName("dummy", msg) test.Assert(t, err != nil) test.Assert(t, err.Error() == "unknown method dummy") diff --git a/pkg/remote/message.go b/pkg/remote/message.go index 5fdbbb39f8..f37eb248e5 100644 --- a/pkg/remote/message.go +++ b/pkg/remote/message.go @@ -114,7 +114,7 @@ func NewMessage(data interface{}, svcInfo *serviceinfo.ServiceInfo, ri rpcinfo.R } // NewMessageWithNewer creates a new Message and set data later. -func NewMessageWithNewer(targetSvcInfo *serviceinfo.ServiceInfo, svcSearchMap map[string]*serviceinfo.ServiceInfo, ri rpcinfo.RPCInfo, msgType MessageType, rpcRole RPCRole) Message { +func NewMessageWithNewer(targetSvcInfo *serviceinfo.ServiceInfo, svcSearchMap map[string]*serviceinfo.ServiceInfo, ri rpcinfo.RPCInfo, msgType MessageType, rpcRole RPCRole, refuseTrafficWithoutServiceName bool) Message { msg := messagePool.Get().(*message) msg.rpcInfo = ri msg.targetSvcInfo = targetSvcInfo @@ -122,6 +122,7 @@ func NewMessageWithNewer(targetSvcInfo *serviceinfo.ServiceInfo, svcSearchMap ma msg.msgType = msgType msg.rpcRole = rpcRole msg.transInfo = transInfoPool.Get().(*transInfo) + msg.refuseTrafficWithoutServiceName = refuseTrafficWithoutServiceName return msg } @@ -137,18 +138,19 @@ func newMessage() interface{} { } type message struct { - msgType MessageType - data interface{} - rpcInfo rpcinfo.RPCInfo - targetSvcInfo *serviceinfo.ServiceInfo - svcSearchMap map[string]*serviceinfo.ServiceInfo - rpcRole RPCRole - compressType CompressType - payloadSize int - transInfo TransInfo - tags map[string]interface{} - protocol ProtocolInfo - payloadCodec PayloadCodec + msgType MessageType + data interface{} + rpcInfo rpcinfo.RPCInfo + targetSvcInfo *serviceinfo.ServiceInfo + svcSearchMap map[string]*serviceinfo.ServiceInfo + rpcRole RPCRole + compressType CompressType + payloadSize int + transInfo TransInfo + tags map[string]interface{} + protocol ProtocolInfo + payloadCodec PayloadCodec + refuseTrafficWithoutServiceName bool } func (m *message) zero() { @@ -187,6 +189,9 @@ func (m *message) SetServiceInfo(svcName, methodName string) (*serviceinfo.Servi } return m.targetSvcInfo, nil } + if svcName == "" && m.refuseTrafficWithoutServiceName { + return nil, NewTransErrorWithMsg(NoServiceName, "no service name while the server has WithRefuseTrafficWithoutServiceName option enabled") + } var key string if svcName == serviceinfo.GenericService || svcName == "" { key = methodName diff --git a/pkg/remote/option.go b/pkg/remote/option.go index e7e09f40f5..afe0f567b9 100644 --- a/pkg/remote/option.go +++ b/pkg/remote/option.go @@ -113,7 +113,8 @@ type ServerOption struct { GRPCUnknownServiceHandler func(ctx context.Context, method string, stream streaming.Stream) error - OnlyAcceptingHTTP2Traffic bool + // RefuseTrafficWithoutServiceName is used for a server with multi services + RefuseTrafficWithoutServiceName bool Option diff --git a/pkg/remote/trans/default_server_handler.go b/pkg/remote/trans/default_server_handler.go index c63acf24e8..2e0f257f75 100644 --- a/pkg/remote/trans/default_server_handler.go +++ b/pkg/remote/trans/default_server_handler.go @@ -34,9 +34,6 @@ import ( // NewDefaultSvrTransHandler to provide default impl of svrTransHandler, it can be reused in netpoll, shm-ipc, framework-sdk extensions func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote.ServerTransHandler, error) { - if opt.OnlyAcceptingHTTP2Traffic { - return nil, remote.NewTransErrorWithMsg(remote.InvalidProtocol, "only http2 traffic is accepted") - } svrHdlr := &svrTransHandler{ opt: opt, codec: opt.Codec, @@ -167,7 +164,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) }() ctx = t.startTracer(ctx, ri) ctx = t.startProfiler(ctx) - recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, ri, remote.Call, remote.Server) + recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, ri, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName) recvMsg.SetPayloadCodec(t.opt.PayloadCodec) ctx, err = t.transPipe.Read(ctx, conn, recvMsg) if err != nil { diff --git a/pkg/remote/trans/default_server_handler_test.go b/pkg/remote/trans/default_server_handler_test.go index 6143f34ac8..eefb224f76 100644 --- a/pkg/remote/trans/default_server_handler_test.go +++ b/pkg/remote/trans/default_server_handler_test.go @@ -105,11 +105,6 @@ func TestDefaultSvrTransHandler(t *testing.T) { test.Assert(t, err == nil, err) test.Assert(t, tagEncode == 1, tagEncode) test.Assert(t, tagDecode == 1, tagDecode) - - opt = &remote.ServerOption{OnlyAcceptingHTTP2Traffic: true} - _, err = NewDefaultSvrTransHandler(opt, ext) - test.Assert(t, err != nil) - test.Assert(t, err.Error() == "only http2 traffic is accepted") } func TestSvrTransHandlerBizError(t *testing.T) { diff --git a/pkg/remote/trans/netpoll/http_client_handler_test.go b/pkg/remote/trans/netpoll/http_client_handler_test.go index 04814e5b9c..8826a2036a 100644 --- a/pkg/remote/trans/netpoll/http_client_handler_test.go +++ b/pkg/remote/trans/netpoll/http_client_handler_test.go @@ -132,7 +132,7 @@ func TestHTTPOnMessage(t *testing.T) { } ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, method), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) sendMsg := remote.NewMessage(svcInfo.MethodInfo(method).NewResult(), svcInfo, ri, remote.Reply, remote.Server) // 2. test diff --git a/pkg/remote/trans/netpollmux/server_handler.go b/pkg/remote/trans/netpollmux/server_handler.go index d64ea36609..052b10f751 100644 --- a/pkg/remote/trans/netpollmux/server_handler.go +++ b/pkg/remote/trans/netpollmux/server_handler.go @@ -62,9 +62,6 @@ func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remo } func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { - if opt.OnlyAcceptingHTTP2Traffic { - return nil, remote.NewTransErrorWithMsg(remote.InvalidProtocol, "only http2 traffic is accepted") - } svrHdlr := &svrTransHandler{ opt: opt, codec: opt.Codec, @@ -238,7 +235,7 @@ func (t *svrTransHandler) task(muxSvrConnCtx context.Context, conn net.Conn, rea }() // read - recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, rpcInfo, remote.Call, remote.Server) + recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, rpcInfo, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName) bufReader := np.NewReaderByteBuffer(reader) err = t.readWithByteBuffer(ctx, bufReader, recvMsg) if err != nil { diff --git a/pkg/remote/trans/netpollmux/server_handler_test.go b/pkg/remote/trans/netpollmux/server_handler_test.go index 90bfd01394..6c1f36620a 100644 --- a/pkg/remote/trans/netpollmux/server_handler_test.go +++ b/pkg/remote/trans/netpollmux/server_handler_test.go @@ -104,10 +104,6 @@ func TestNewTransHandler(t *testing.T) { handler, err := NewSvrTransHandlerFactory().NewTransHandler(&remote.ServerOption{}) test.Assert(t, err == nil, err) test.Assert(t, handler != nil) - - _, err = NewSvrTransHandlerFactory().NewTransHandler(&remote.ServerOption{OnlyAcceptingHTTP2Traffic: true}) - test.Assert(t, err != nil, err) - test.Assert(t, err.Error() == "only http2 traffic is accepted") } // TestOnActive test ServerTransHandler OnActive diff --git a/pkg/remote/trans_errors.go b/pkg/remote/trans_errors.go index 2b226a49f9..6adcf992ce 100644 --- a/pkg/remote/trans_errors.go +++ b/pkg/remote/trans_errors.go @@ -38,6 +38,7 @@ const ( UnsupportedClientType = 10 // kitex's own type id from number 20 UnknownService = 20 + NoServiceName = 21 ) var defaultTransErrorMessage = map[int32]string{ diff --git a/server/option.go b/server/option.go index 4d7b8f323c..82e83edd3a 100644 --- a/server/option.go +++ b/server/option.go @@ -357,9 +357,11 @@ func WithContextBackup(enable, async bool) Option { }} } -// WithOnlyAcceptingHTTP2Traffic returns an Option that accepts only http2 traffic. -func WithOnlyAcceptingHTTP2Traffic(enable bool) Option { +// WithRefuseTrafficWithoutServiceName returns an Option that only accepts traffics with service name. +// This is used for a server with multi services and is one of the options to avoid a server startup error +// when having conflicting method names between services without specifying a fallback service for the method. +func WithRefuseTrafficWithoutServiceName(enable bool) Option { return Option{F: func(o *internal_server.Options, di *utils.Slice) { - o.OnlyAcceptingHTTP2Traffic = enable + o.RefuseTrafficWithoutServiceName = enable }} } diff --git a/server/option_test.go b/server/option_test.go index 17a977e757..c61158ea80 100644 --- a/server/option_test.go +++ b/server/option_test.go @@ -453,7 +453,7 @@ func TestWithProfilerMessageTagging(t *testing.T) { mocks.MockErrorMethod: svcInfo, mocks.MockOnewayMethod: svcInfo, } - msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server) + msg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) newCtx, tags := iSvr.opt.RemoteOpt.ProfilerMessageTagging(ctx, msg) test.Assert(t, len(tags) == 8) @@ -461,3 +461,17 @@ func TestWithProfilerMessageTagging(t *testing.T) { test.Assert(t, newCtx.Value("ctx1").(int) == 1) test.Assert(t, newCtx.Value("ctx2").(int) == 2) } + +func TestRefuseTrafficWithoutServiceNamOption(t *testing.T) { + svr := NewServer(WithRefuseTrafficWithoutServiceName(true)) + err := svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) + test.Assert(t, err == nil, err) + time.AfterFunc(100*time.Millisecond, func() { + err := svr.Stop() + test.Assert(t, err == nil, err) + }) + err = svr.Run() + test.Assert(t, err == nil, err) + iSvr := svr.(*server) + test.Assert(t, iSvr.opt.RefuseTrafficWithoutServiceName) +} diff --git a/server/register_option_test.go b/server/register_option_test.go new file mode 100644 index 0000000000..0e6e4d9587 --- /dev/null +++ b/server/register_option_test.go @@ -0,0 +1,30 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "testing" + + internal_server "github.com/cloudwego/kitex/internal/server" + "github.com/cloudwego/kitex/internal/test" +) + +func TestWithFallbackService(t *testing.T) { + opts := []RegisterOption{WithFallbackService()} + registerOpts := internal_server.NewRegisterOptions(opts) + test.Assert(t, registerOpts.IsFallbackService) +} diff --git a/server/server.go b/server/server.go index c1078098e6..939db2b4ae 100644 --- a/server/server.go +++ b/server/server.go @@ -355,7 +355,7 @@ func (s *server) initBasicRemoteOption() { remoteOpt := s.opt.RemoteOpt remoteOpt.TargetSvcInfo = s.targetSvcInfo remoteOpt.SvcSearchMap = s.svcs.getSvcInfoSearchMap() - remoteOpt.OnlyAcceptingHTTP2Traffic = s.opt.OnlyAcceptingHTTP2Traffic + remoteOpt.RefuseTrafficWithoutServiceName = s.opt.RefuseTrafficWithoutServiceName remoteOpt.InitOrResetRPCInfoFunc = s.initOrResetRPCInfoFunc() remoteOpt.TracerCtl = s.opt.TracerCtl remoteOpt.ReadWriteTimeout = s.opt.Configs.ReadWriteTimeout() @@ -448,7 +448,7 @@ func (s *server) check() error { s.targetSvcInfo = getDefaultSvcInfo(s.svcs) return nil } - return checkFallbackServiceForConflictingMethods(s.svcs.conflictingMethodHasFallbackSvcMap, s.opt.OnlyAcceptingHTTP2Traffic) + return checkFallbackServiceForConflictingMethods(s.svcs.conflictingMethodHasFallbackSvcMap, s.opt.RefuseTrafficWithoutServiceName) } func doAddBoundHandlerToHead(h remote.BoundHandler, opt *remote.ServerOption) { @@ -572,8 +572,8 @@ func getDefaultSvcInfo(svcs *services) *serviceinfo.ServiceInfo { return nil } -func checkFallbackServiceForConflictingMethods(conflictingMethodHasFallbackSvcMap map[string]bool, onlyAcceptingHTTP2Traffic bool) error { - if onlyAcceptingHTTP2Traffic { +func checkFallbackServiceForConflictingMethods(conflictingMethodHasFallbackSvcMap map[string]bool, refuseTrafficWithoutServiceName bool) error { + if refuseTrafficWithoutServiceName { return nil } for name, hasFallbackSvc := range conflictingMethodHasFallbackSvcMap { diff --git a/server/server_test.go b/server/server_test.go index 742afc9ebe..612a678cab 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -655,7 +655,7 @@ func testInvokeHandlerWithSession(t *testing.T, fail bool, ad string) { { // mock server call ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) recvMsg.NewData(callMethod) sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) @@ -742,7 +742,7 @@ func TestInvokeHandlerExec(t *testing.T) { { // mock server call ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) recvMsg.NewData(callMethod) sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) @@ -805,7 +805,7 @@ func TestInvokeHandlerPanic(t *testing.T) { // mock server call ri := rpcinfo.NewRPCInfo(nil, nil, rpcinfo.NewInvocation(svcInfo.ServiceName, callMethod), nil, rpcinfo.NewRPCStats()) ctx := rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) - recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server) + recvMsg := remote.NewMessageWithNewer(svcInfo, svcSearchMap, ri, remote.Call, remote.Server, false) recvMsg.NewData(callMethod) sendMsg := remote.NewMessage(svcInfo.MethodInfo(callMethod).NewResult(), svcInfo, ri, remote.Reply, remote.Server) @@ -921,18 +921,6 @@ func TestRegisterService(t *testing.T) { test.Assert(t, err != nil) test.Assert(t, err.Error() == "method name [mock] is conflicted between services but no fallback service is specified") svr.Stop() - - svr = NewServer(WithOnlyAcceptingHTTP2Traffic(true)) - time.AfterFunc(time.Second, func() { - err := svr.Stop() - test.Assert(t, err == nil, err) - }) - _ = svr.RegisterService(mocks.ServiceInfo(), mocks.MyServiceHandler()) - _ = svr.RegisterService(mocks.Service3Info(), mocks.MyServiceHandler()) - err = svr.Run() - test.Assert(t, err != nil) - test.Assert(t, err.Error() == "only http2 traffic is accepted") - svr.Stop() } type noopMetahandler struct{}