Skip to content

Commit

Permalink
server: refactor the independent service check (#8738)
Browse files Browse the repository at this point in the history
ref #8477

Signed-off-by: Ryan Leung <rleungx@gmail.com>

Co-authored-by: ti-chi-bot[bot] <108142056+ti-chi-bot[bot]@users.noreply.github.com>
  • Loading branch information
rleungx and ti-chi-bot[bot] authored Oct 25, 2024
1 parent 24343e4 commit 988c9a3
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 67 deletions.
2 changes: 1 addition & 1 deletion pkg/utils/apiutil/serverapi/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func (h *redirector) matchMicroServiceRedirectRules(r *http.Request) (bool, stri
for _, rule := range h.microserviceRedirectRules {
// Now we only support checking the scheduling service whether it is independent
if rule.targetServiceName == constant.SchedulingServiceName {
if !h.s.GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName) {
if !h.s.IsServiceIndependent(constant.SchedulingServiceName) {
continue
}
}
Expand Down
6 changes: 3 additions & 3 deletions server/api/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func newConfHandler(svr *server.Server, rd *render.Render) *confHandler {
// @Router /config [get]
func (h *confHandler) GetConfig(w http.ResponseWriter, r *http.Request) {
cfg := h.svr.GetConfig()
if h.svr.GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName) &&
if h.svr.IsServiceIndependent(constant.SchedulingServiceName) &&
r.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) != "true" {
schedulingServerConfig, err := h.getSchedulingServerConfig()
if err != nil {
Expand Down Expand Up @@ -336,7 +336,7 @@ func getConfigMap(cfg map[string]any, key []string, value any) map[string]any {
// @Success 200 {object} sc.ScheduleConfig
// @Router /config/schedule [get]
func (h *confHandler) GetScheduleConfig(w http.ResponseWriter, r *http.Request) {
if h.svr.GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName) &&
if h.svr.IsServiceIndependent(constant.SchedulingServiceName) &&
r.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) != "true" {
cfg, err := h.getSchedulingServerConfig()
if err != nil {
Expand Down Expand Up @@ -409,7 +409,7 @@ func (h *confHandler) SetScheduleConfig(w http.ResponseWriter, r *http.Request)
// @Success 200 {object} sc.ReplicationConfig
// @Router /config/replicate [get]
func (h *confHandler) GetReplicationConfig(w http.ResponseWriter, r *http.Request) {
if h.svr.GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName) &&
if h.svr.IsServiceIndependent(constant.SchedulingServiceName) &&
r.Header.Get(apiutil.XForbiddenForwardToMicroServiceHeader) != "true" {
cfg, err := h.getSchedulingServerConfig()
if err != nil {
Expand Down
9 changes: 0 additions & 9 deletions server/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -2510,25 +2510,16 @@ func IsClientURL(addr string, etcdClient *clientv3.Client) bool {

// IsServiceIndependent returns whether the service is independent.
func (c *RaftCluster) IsServiceIndependent(name string) bool {
if c == nil {
return false
}
_, exist := c.independentServices.Load(name)
return exist
}

// SetServiceIndependent sets the service to be independent.
func (c *RaftCluster) SetServiceIndependent(name string) {
if c == nil {
return
}
c.independentServices.Store(name, struct{}{})
}

// UnsetServiceIndependent unsets the service to be independent.
func (c *RaftCluster) UnsetServiceIndependent(name string) {
if c == nil {
return
}
c.independentServices.Delete(name)
}
112 changes: 64 additions & 48 deletions server/forward.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,66 +133,82 @@ func (s *GrpcServer) forwardTSO(stream pdpb.PD_TsoServer) error {
err = errs.ErrGenerateTimestamp.FastGenByArgs("tso count should be positive")
return status.Error(codes.Unknown, err.Error())
}

forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), constant.TSOServiceName)
if !ok || len(forwardedHost) == 0 {
tsoStreamErr = errors.WithStack(ErrNotFoundTSOAddr)
forwardCtx, cancelForward, forwardStream, lastForwardedHost, tsoStreamErr, err = s.handleTSOForwarding(forwardCtx, forwardStream, stream, server, request, tsDeadlineCh, lastForwardedHost, cancelForward)
if tsoStreamErr != nil {
return tsoStreamErr
}
if forwardStream == nil || lastForwardedHost != forwardedHost {
if cancelForward != nil {
cancelForward()
}
if err != nil {
return err
}
}
}

clientConn, err := s.getDelegateClient(s.ctx, forwardedHost)
if err != nil {
tsoStreamErr = errors.WithStack(err)
return tsoStreamErr
}
forwardStream, forwardCtx, cancelForward, err = createTSOForwardStream(stream.Context(), clientConn)
if err != nil {
tsoStreamErr = errors.WithStack(err)
return tsoStreamErr
}
lastForwardedHost = forwardedHost
func (s *GrpcServer) handleTSOForwarding(forwardCtx context.Context, forwardStream tsopb.TSO_TsoClient, stream pdpb.PD_TsoServer, server *tsoServer,
request *pdpb.TsoRequest, tsDeadlineCh chan<- *tsoutil.TSDeadline, lastForwardedHost string, cancelForward context.CancelFunc) (
context.Context,
context.CancelFunc,
tsopb.TSO_TsoClient,
string,
error, // tso stream error
error, // send error
) {
forwardedHost, ok := s.GetServicePrimaryAddr(stream.Context(), constant.TSOServiceName)
if !ok || len(forwardedHost) == 0 {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(ErrNotFoundTSOAddr), nil
}
if forwardStream == nil || lastForwardedHost != forwardedHost {
if cancelForward != nil {
cancelForward()
}

tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh)
clientConn, err := s.getDelegateClient(s.ctx, forwardedHost)
if err != nil {
tsoStreamErr = errors.WithStack(err)
return tsoStreamErr
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil
}
forwardStream, forwardCtx, cancelForward, err = createTSOForwardStream(stream.Context(), clientConn)
if err != nil {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil
}
lastForwardedHost = forwardedHost
}

// The error types defined for tsopb and pdpb are different, so we need to convert them.
var pdpbErr *pdpb.Error
tsopbErr := tsopbResp.GetHeader().GetError()
if tsopbErr != nil {
if tsopbErr.Type == tsopb.ErrorType_OK {
pdpbErr = &pdpb.Error{
Type: pdpb.ErrorType_OK,
Message: tsopbErr.GetMessage(),
}
} else {
// TODO: specify FORWARD FAILURE error type instead of UNKNOWN.
pdpbErr = &pdpb.Error{
Type: pdpb.ErrorType_UNKNOWN,
Message: tsopbErr.GetMessage(),
}
tsopbResp, err := s.forwardTSORequestWithDeadLine(forwardCtx, cancelForward, forwardStream, request, tsDeadlineCh)
if err != nil {
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, errors.WithStack(err), nil
}

// The error types defined for tsopb and pdpb are different, so we need to convert them.
var pdpbErr *pdpb.Error
tsopbErr := tsopbResp.GetHeader().GetError()
if tsopbErr != nil {
if tsopbErr.Type == tsopb.ErrorType_OK {
pdpbErr = &pdpb.Error{
Type: pdpb.ErrorType_OK,
Message: tsopbErr.GetMessage(),
}
} else {
// TODO: specify FORWARD FAILURE error type instead of UNKNOWN.
pdpbErr = &pdpb.Error{
Type: pdpb.ErrorType_UNKNOWN,
Message: tsopbErr.GetMessage(),
}
}
}

response := &pdpb.TsoResponse{
Header: &pdpb.ResponseHeader{
ClusterId: tsopbResp.GetHeader().GetClusterId(),
Error: pdpbErr,
},
Count: tsopbResp.GetCount(),
Timestamp: tsopbResp.GetTimestamp(),
}
if err := server.send(response); err != nil {
return errors.WithStack(err)
}
response := &pdpb.TsoResponse{
Header: &pdpb.ResponseHeader{
ClusterId: tsopbResp.GetHeader().GetClusterId(),
Error: pdpbErr,
},
Count: tsopbResp.GetCount(),
Timestamp: tsopbResp.GetTimestamp(),
}
if server != nil {
err = server.send(response)
} else {
err = stream.Send(response)
}
return forwardCtx, cancelForward, forwardStream, lastForwardedHost, nil, errors.WithStack(err)
}

func (s *GrpcServer) forwardTSORequestWithDeadLine(
Expand Down
8 changes: 4 additions & 4 deletions server/grpc_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func (s *GrpcServer) GetClusterInfo(context.Context, *pdpb.GetClusterInfoRequest

var tsoServiceAddrs []string
svcModes := make([]pdpb.ServiceMode, 0)
if s.IsAPIServiceMode() {
if s.IsServiceIndependent(constant.TSOServiceName) {
svcModes = append(svcModes, pdpb.ServiceMode_API_SVC_MODE)
tsoServiceAddrs = s.keyspaceGroupManager.GetTSOServiceAddrs()
} else {
Expand Down Expand Up @@ -318,7 +318,7 @@ func (s *GrpcServer) GetMinTS(
minTS *pdpb.Timestamp
err error
)
if s.IsAPIServiceMode() {
if s.IsServiceIndependent(constant.TSOServiceName) {
minTS, err = s.GetMinTSFromTSOService(tso.GlobalDCLocation)
} else {
start := time.Now()
Expand Down Expand Up @@ -486,7 +486,7 @@ func (s *GrpcServer) GetMembers(context.Context, *pdpb.GetMembersRequest) (*pdpb
}

tsoAllocatorLeaders := make(map[string]*pdpb.Member)
if !s.IsAPIServiceMode() {
if !s.IsServiceIndependent(constant.TSOServiceName) {
tsoAllocatorManager := s.GetTSOAllocatorManager()
tsoAllocatorLeaders, err = tsoAllocatorManager.GetLocalAllocatorLeaders()
}
Expand Down Expand Up @@ -524,7 +524,7 @@ func (s *GrpcServer) Tso(stream pdpb.PD_TsoServer) error {
return err
}
}
if s.IsAPIServiceMode() {
if s.IsServiceIndependent(constant.TSOServiceName) {
return s.forwardTSO(stream)
}

Expand Down
12 changes: 12 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,18 @@ func (s *Server) GetRaftCluster() *cluster.RaftCluster {
return s.cluster
}

// IsServiceIndependent returns whether the service is independent.
func (s *Server) IsServiceIndependent(name string) bool {
if s.mode == APIServiceMode && !s.IsClosed() {
// TODO: remove it after we support tso discovery
if name == constant.TSOServiceName {
return true
}
return s.cluster.IsServiceIndependent(name)
}
return false
}

// DirectlyGetRaftCluster returns raft cluster directly.
// Only used for test.
func (s *Server) DirectlyGetRaftCluster() *cluster.RaftCluster {
Expand Down
2 changes: 1 addition & 1 deletion tests/integrations/mcs/scheduling/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (suite *apiTestSuite) checkAPIForward(cluster *tests.TestCluster) {
var respSlice []string
var resp map[string]any
testutil.Eventually(re, func() bool {
return leader.GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName)
return leader.IsServiceIndependent(constant.SchedulingServiceName)
})

// Test operators
Expand Down
2 changes: 1 addition & 1 deletion tests/testutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ func (s *SchedulingTestEnvironment) startCluster(m SchedulerMode) {
cluster.SetSchedulingCluster(tc)
time.Sleep(200 * time.Millisecond) // wait for scheduling cluster to update member
testutil.Eventually(re, func() bool {
return cluster.GetLeaderServer().GetServer().GetRaftCluster().IsServiceIndependent(constant.SchedulingServiceName)
return cluster.GetLeaderServer().GetServer().IsServiceIndependent(constant.SchedulingServiceName)
})
s.clusters[APIMode] = cluster
}
Expand Down

0 comments on commit 988c9a3

Please sign in to comment.