diff --git a/pkg/restore/split_client.go b/pkg/restore/split_client.go index e78eb7511..70f2763e1 100755 --- a/pkg/restore/split_client.go +++ b/pkg/restore/split_client.go @@ -52,6 +52,8 @@ type SplitClient interface { // BatchSplitRegions splits a region from a batch of keys. // note: the keys should not be encoded BatchSplitRegions(ctx context.Context, regionInfo *RegionInfo, keys [][]byte) ([]*RegionInfo, error) + // BatchSplitRegionsWithOrigin splits a region from a batch of keys and return the original region and split new regions + BatchSplitRegionsWithOrigin(ctx context.Context, regionInfo *RegionInfo, keys [][]byte) (*RegionInfo, []*RegionInfo, error) // ScatterRegion scatters a specified region. ScatterRegion(ctx context.Context, regionInfo *RegionInfo) error // GetOperator gets the status of operator of the specified region. @@ -324,22 +326,20 @@ func (c *pdClient) sendSplitRegionRequest( return nil, errors.Trace(splitErrors) } -func (c *pdClient) BatchSplitRegions( +func (c *pdClient) BatchSplitRegionsWithOrigin( ctx context.Context, regionInfo *RegionInfo, keys [][]byte, -) ([]*RegionInfo, error) { +) (*RegionInfo, []*RegionInfo, error) { resp, err := c.sendSplitRegionRequest(ctx, regionInfo, keys) if err != nil { - return nil, errors.Trace(err) + return nil, nil, errors.Trace(err) } regions := resp.GetRegions() newRegionInfos := make([]*RegionInfo, 0, len(regions)) + var originRegion *RegionInfo for _, region := range regions { - // Skip the original region - if region.GetId() == regionInfo.Region.GetId() { - continue - } var leader *metapb.Peer + // Assume the leaders will be at the same store. if regionInfo.Leader != nil { for _, p := range region.GetPeers() { @@ -349,12 +349,27 @@ func (c *pdClient) BatchSplitRegions( } } } + // original region + if region.GetId() == regionInfo.Region.GetId() { + originRegion = &RegionInfo{ + Region: region, + Leader: leader, + } + continue + } newRegionInfos = append(newRegionInfos, &RegionInfo{ Region: region, Leader: leader, }) } - return newRegionInfos, nil + return originRegion, newRegionInfos, nil +} + +func (c *pdClient) BatchSplitRegions( + ctx context.Context, regionInfo *RegionInfo, keys [][]byte, +) ([]*RegionInfo, error) { + _, newRegions, err := c.BatchSplitRegionsWithOrigin(ctx, regionInfo, keys) + return newRegions, err } func (c *pdClient) ScatterRegion(ctx context.Context, regionInfo *RegionInfo) error { diff --git a/pkg/restore/split_test.go b/pkg/restore/split_test.go index 703d84900..9d357817a 100644 --- a/pkg/restore/split_test.go +++ b/pkg/restore/split_test.go @@ -119,12 +119,13 @@ func (c *testClient) SplitRegion( return newRegion, nil } -func (c *testClient) BatchSplitRegions( +func (c *testClient) BatchSplitRegionsWithOrigin( ctx context.Context, regionInfo *restore.RegionInfo, keys [][]byte, -) ([]*restore.RegionInfo, error) { +) (*restore.RegionInfo, []*restore.RegionInfo, error) { c.mu.Lock() defer c.mu.Unlock() newRegions := make([]*restore.RegionInfo, 0) + var region *restore.RegionInfo for _, key := range keys { var target *restore.RegionInfo splitKey := codec.EncodeBytes([]byte{}, key) @@ -148,9 +149,17 @@ func (c *testClient) BatchSplitRegions( c.nextRegionID++ target.Region.StartKey = splitKey c.regions[target.Region.Id] = target + region = target newRegions = append(newRegions, newRegion) } - return newRegions, nil + return region, newRegions, nil +} + +func (c *testClient) BatchSplitRegions( + ctx context.Context, regionInfo *restore.RegionInfo, keys [][]byte, +) ([]*restore.RegionInfo, error) { + _, newRegions, err := c.BatchSplitRegionsWithOrigin(ctx, regionInfo, keys) + return newRegions, err } func (c *testClient) ScatterRegion(ctx context.Context, regionInfo *restore.RegionInfo) error {