Skip to content

Commit

Permalink
Add unit test for the server module
Browse files Browse the repository at this point in the history
Signed-off-by: SimFG <bang.fu@zilliz.com>
  • Loading branch information
SimFG committed Oct 30, 2023
1 parent 3b4fd5a commit 9a36991
Show file tree
Hide file tree
Showing 7 changed files with 1,020 additions and 1,194 deletions.
13 changes: 12 additions & 1 deletion server/cdc_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

package server

import "github.com/zilliztech/milvus-cdc/server/model/request"
import (
"github.com/milvus-io/milvus/pkg/log"
"github.com/zilliztech/milvus-cdc/server/model/request"
)

type CDCService interface {
ReloadTask()
Expand All @@ -36,33 +39,41 @@ func NewBaseCDC() *BaseCDC {
}

func (b *BaseCDC) ReloadTask() {
log.Warn("ReloadTask is not implemented, please check it")
}

func (b *BaseCDC) Create(request *request.CreateRequest) (*request.CreateResponse, error) {
log.Warn("Create is not implemented, please check it")
return nil, nil
}

func (b *BaseCDC) Delete(request *request.DeleteRequest) (*request.DeleteResponse, error) {
log.Warn("Delete is not implemented, please check it")
return nil, nil
}

func (b *BaseCDC) Pause(request *request.PauseRequest) (*request.PauseResponse, error) {
log.Warn("Pause is not implemented, please check it")
return nil, nil
}

func (b *BaseCDC) Resume(request *request.ResumeRequest) (*request.ResumeResponse, error) {
log.Warn("Resume is not implemented, please check it")
return nil, nil
}

func (b *BaseCDC) Get(request *request.GetRequest) (*request.GetResponse, error) {
log.Warn("Get is not implemented, please check it")
return nil, nil
}

func (b *BaseCDC) GetPosition(req *request.GetPositionRequest) (*request.GetPositionResponse, error) {
log.Warn("GetPosition is not implemented, please check it")
return nil, nil
}

func (b *BaseCDC) List(request *request.ListRequest) (*request.ListResponse, error) {
log.Warn("List is not implemented, please check it")
return nil, nil
}

Expand Down
40 changes: 40 additions & 0 deletions server/cdc_api_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package server

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestDefaultCDCServer(t *testing.T) {
baseCDC := NewBaseCDC()
baseCDC.ReloadTask()
{
_, err := baseCDC.Create(nil)
assert.NoError(t, err)
}
{
_, err := baseCDC.Delete(nil)
assert.NoError(t, err)
}
{
_, err := baseCDC.Pause(nil)
assert.NoError(t, err)
}
{
_, err := baseCDC.Resume(nil)
assert.NoError(t, err)
}
{
_, err := baseCDC.Get(nil)
assert.NoError(t, err)
}
{
_, err := baseCDC.GetPosition(nil)
assert.NoError(t, err)
}
{
_, err := baseCDC.List(nil)
assert.NoError(t, err)
}
}
50 changes: 6 additions & 44 deletions server/cdc_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type ReplicateEntity struct {
type MetaCDC struct {
BaseCDC
metaStoreFactory serverapi.MetaStoreFactory
mqFactoryCreator cdcreader.FactoryCreator
rootPath string
config *CDCServerConfig

Expand Down Expand Up @@ -118,6 +119,7 @@ func NewMetaCDC(serverConfig *CDCServerConfig) *MetaCDC {
cdc := &MetaCDC{
metaStoreFactory: factory,
config: serverConfig,
mqFactoryCreator: cdcreader.NewDefaultFactoryCreator(),
}
cdc.collectionNames.data = make(map[string][]string)
cdc.collectionNames.excludeData = make(map[string][]string)
Expand All @@ -143,33 +145,6 @@ func (e *MetaCDC) ReloadTask() {
log.Panic("fail to get all task info", zap.Error(err))
}

// if reverse {
// var err error
// reverseTxn, commitFunc, err := e.metaStoreFactory.Txn(ctx)
// if err != nil {
// log.Panic("fail to new the reverse txn", zap.Error(err))
// }
// for _, taskInfo := range taskInfos {
// if taskInfo.MilvusConnectParam.Host == currentConfig.Host && taskInfo.MilvusConnectParam.Port == currentConfig.Port {
// taskInfo.MilvusConnectParam.Host = reverseConfig.Host
// taskInfo.MilvusConnectParam.Port = reverseConfig.Port
// taskInfo.MilvusConnectParam.Username = reverseConfig.Username
// taskInfo.MilvusConnectParam.Password = reverseConfig.Password
// taskInfo.MilvusConnectParam.EnableTLS = reverseConfig.EnableTLS
// if err = e.metaStoreFactory.GetTaskInfoMetaStore(ctx).Put(ctx, taskInfo, reverseTxn); err != nil {
// log.Panic("fail to put the task info to metastore when reversing", zap.Error(err))
// }
// // TODO need to use new target position in the future, not delete and receive the msg from the latest position
// if err = e.metaStoreFactory.GetTaskCollectionPositionMetaStore(ctx).Delete(ctx, &meta.TaskCollectionPosition{TaskID: taskInfo.TaskID}, reverseTxn); err != nil {
// log.Panic("fail to delete the task collection position to metastore when reversing", zap.Error(err))
// }
// }
// }
// if err = commitFunc(err); err != nil {
// log.Panic("fail to commit the reverse txn", zap.Error(err))
// }
// }

for _, taskInfo := range taskInfos {
milvusAddress := fmt.Sprintf("%s:%d", taskInfo.MilvusConnectParam.Host, taskInfo.MilvusConnectParam.Port)
newCollectionNames := lo.Map(taskInfo.CollectionInfos, func(t model.CollectionInfo, _ int) string {
Expand Down Expand Up @@ -234,9 +209,6 @@ func (e *MetaCDC) Create(req *request.CreateRequest) (resp *request.CreateRespon
existCollectionNames := e.collectionNames.data[milvusAddress]
excludeCollectionNames = make([]string, len(existCollectionNames))
copy(excludeCollectionNames, existCollectionNames)
// if !lo.Contains(excludeCollectionNames, util.RPCRequestCollectionName) {
// excludeCollectionNames = append(excludeCollectionNames, util.RPCRequestCollectionName)
// }
e.collectionNames.excludeData[milvusAddress] = excludeCollectionNames
}
e.collectionNames.data[milvusAddress] = append(e.collectionNames.data[milvusAddress], newCollectionNames...)
Expand Down Expand Up @@ -335,17 +307,6 @@ func (e *MetaCDC) validCreateRequest(req *request.CreateRequest) error {
return servererror.NewClientError("the cache size is less zero")
}

// if req.RPCChannelInfo.Name == "" {
// if err := e.checkCollectionInfos(req.CollectionInfos); err != nil {
// return err
// }
// } else {
// if len(req.CollectionInfos) > 0 {
// return servererror.NewClientError("the collection info should be empty when the rpc channel is not empty")
// }
// req.CollectionInfos = []model.CollectionInfo{{Name: util.RPCRequestCollectionName}}
// }

if err := e.checkCollectionInfos(req.CollectionInfos); err != nil {
return err
}
Expand Down Expand Up @@ -418,6 +379,7 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err
milvusAddress := fmt.Sprintf("%s:%d", milvusConnectParam.Host, milvusConnectParam.Port)
e.replicateEntityMap.RLock()
replicateEntity, ok := e.replicateEntityMap.data[milvusAddress]
log.Info("ok", zap.Any("ok", ok))
e.replicateEntityMap.RUnlock()

newReplicateEntity := func() (*ReplicateEntity, error) {
Expand All @@ -439,7 +401,7 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err
channelManager, err := cdcreader.NewReplicateChannelManager(config.MQConfig{
Pulsar: e.config.SourceConfig.Pulsar,
Kafka: e.config.SourceConfig.Kafka,
}, cdcreader.NewDefaultFactoryCreator(), milvusClient, bufferSize)
}, e.mqFactoryCreator, milvusClient, bufferSize)
if err != nil {
log.Warn("fail to create replicate channel manager", zap.Error(err))
return nil, servererror.NewClientError("fail to create replicate channel manager")
Expand Down Expand Up @@ -600,7 +562,7 @@ func (e *MetaCDC) startInternal(info *meta.TaskInfo, ignoreUpdateState bool) err
writeCallback.UpdateTaskCollectionPosition(TmpCollectionID, TmpCollectionName, channelName,
metaPosition, metaPosition, nil)
return true
}, cdcreader.NewDefaultFactoryCreator())
}, e.mqFactoryCreator)
readCtx, cancelReadFunc := context.WithCancel(context.Background())
e.replicateEntityMap.Lock()
// replicateEntity.readerObj = collectionReader
Expand Down Expand Up @@ -726,7 +688,7 @@ func (e *MetaCDC) GetPosition(req *request.GetPositionRequest) (*request.GetPosi
ctx := context.Background()
positions, err := e.metaStoreFactory.GetTaskCollectionPositionMetaStore(ctx).Get(ctx, &meta.TaskCollectionPosition{TaskID: req.TaskID}, nil)
if err != nil {
return nil, err
return nil, servererror.NewServerError(err)
}
resp := &request.GetPositionResponse{}
if len(positions) > 0 {
Expand Down
Loading

0 comments on commit 9a36991

Please sign in to comment.