diff --git a/client/http/api.go b/client/http/api.go index 2fae562dd20..1826e2231ee 100644 --- a/client/http/api.go +++ b/client/http/api.go @@ -32,7 +32,7 @@ const ( regionsByKey = "/pd/api/v1/regions/key" RegionsByStoreIDPrefix = "/pd/api/v1/regions/store" EmptyRegions = "/pd/api/v1/regions/check/empty-region" - accelerateSchedule = "/pd/api/v1/regions/accelerate-schedule" + AccelerateSchedule = "/pd/api/v1/regions/accelerate-schedule" store = "/pd/api/v1/store" Stores = "/pd/api/v1/stores" StatsRegion = "/pd/api/v1/stats/region" @@ -45,7 +45,10 @@ const ( PlacementRule = "/pd/api/v1/config/rule" PlacementRules = "/pd/api/v1/config/rules" placementRulesByGroup = "/pd/api/v1/config/rules/group" + PlacementRuleBundle = "/pd/api/v1/config/placement-rule" RegionLabelRule = "/pd/api/v1/config/region-label/rule" + RegionLabelRules = "/pd/api/v1/config/region-label/rules" + RegionLabelRulesByIDs = "/pd/api/v1/config/region-label/rules/ids" // Scheduler Schedulers = "/pd/api/v1/schedulers" scatterRangeScheduler = "/pd/api/v1/schedulers/scatter-range-" @@ -123,6 +126,16 @@ func PlacementRuleByGroupAndID(group, id string) string { return fmt.Sprintf("%s/%s/%s", PlacementRule, group, id) } +// PlacementRuleBundleByGroup returns the path of PD HTTP API to get placement rule bundle by group. +func PlacementRuleBundleByGroup(group string) string { + return fmt.Sprintf("%s/%s", PlacementRuleBundle, group) +} + +// PlacementRuleBundleWithPartialParameter returns the path of PD HTTP API to get placement rule bundle with partial parameter. +func PlacementRuleBundleWithPartialParameter(partial bool) string { + return fmt.Sprintf("%s?partial=%t", PlacementRuleBundle, partial) +} + // SchedulerByName returns the scheduler API with the given scheduler name. func SchedulerByName(name string) string { return fmt.Sprintf("%s/%s", Schedulers, name) diff --git a/client/http/client.go b/client/http/client.go index 6fa2dd8cdfd..23a019d0611 100644 --- a/client/http/client.go +++ b/client/http/client.go @@ -42,6 +42,7 @@ const ( // Client is a PD (Placement Driver) HTTP client. type Client interface { + /* Meta-related interfaces */ GetRegionByID(context.Context, uint64) (*RegionInfo, error) GetRegionByKey(context.Context, []byte) (*RegionInfo, error) GetRegions(context.Context) (*RegionsInfo, error) @@ -51,11 +52,24 @@ type Client interface { GetHotWriteRegions(context.Context) (*StoreHotPeersInfos, error) GetRegionStatusByKeyRange(context.Context, []byte, []byte) (*RegionStats, error) GetStores(context.Context) (*StoresInfo, error) + /* Rule-related interfaces */ + GetAllPlacementRuleBundles(context.Context) ([]*GroupBundle, error) + GetPlacementRuleBundleByGroup(context.Context, string) (*GroupBundle, error) GetPlacementRulesByGroup(context.Context, string) ([]*Rule, error) SetPlacementRule(context.Context, *Rule) error + SetPlacementRuleBundles(context.Context, []*GroupBundle, bool) error DeletePlacementRule(context.Context, string, string) error - GetMinResolvedTSByStoresIDs(context.Context, []uint64) (uint64, map[uint64]uint64, error) + GetAllRegionLabelRules(context.Context) ([]*LabelRule, error) + GetRegionLabelRulesByIDs(context.Context, []string) ([]*LabelRule, error) + SetRegionLabelRule(context.Context, *LabelRule) error + PatchRegionLabelRules(context.Context, *LabelRulePatch) error + /* Scheduling-related interfaces */ AccelerateSchedule(context.Context, []byte, []byte) error + /* Other interfaces */ + GetMinResolvedTSByStoresIDs(context.Context, []uint64) (uint64, map[uint64]uint64, error) + + /* Client-related methods */ + WithRespHandler(func(resp *http.Response) error) Client Close() } @@ -66,6 +80,8 @@ type client struct { tlsConf *tls.Config cli *http.Client + respHandler func(resp *http.Response) error + requestCounter *prometheus.CounterVec executionDuration *prometheus.HistogramVec } @@ -143,6 +159,13 @@ func (c *client) Close() { log.Info("[pd] http client closed") } +// WithRespHandler sets the client with the given HTTP response handler. +// This allows the caller to customize how the response is handled, including error handling logic. +func (c *client) WithRespHandler(handler func(resp *http.Response) error) Client { + c.respHandler = handler + return c +} + func (c *client) reqCounter(name, status string) { if c.requestCounter == nil { return @@ -204,6 +227,12 @@ func (c *client) request( } c.execDuration(name, time.Since(start)) c.reqCounter(name, resp.Status) + + // Give away the response handling to the caller if the handler is set. + if c.respHandler != nil { + return c.respHandler(resp) + } + defer func() { err = resp.Body.Close() if err != nil { @@ -345,6 +374,30 @@ func (c *client) GetStores(ctx context.Context) (*StoresInfo, error) { return &stores, nil } +// GetAllPlacementRuleBundles gets all placement rules bundles. +func (c *client) GetAllPlacementRuleBundles(ctx context.Context) ([]*GroupBundle, error) { + var bundles []*GroupBundle + err := c.requestWithRetry(ctx, + "GetPlacementRuleBundle", PlacementRuleBundle, + http.MethodGet, nil, &bundles) + if err != nil { + return nil, err + } + return bundles, nil +} + +// GetPlacementRuleBundleByGroup gets the placement rules bundle by group. +func (c *client) GetPlacementRuleBundleByGroup(ctx context.Context, group string) (*GroupBundle, error) { + var bundle GroupBundle + err := c.requestWithRetry(ctx, + "GetPlacementRuleBundleByGroup", PlacementRuleBundleByGroup(group), + http.MethodGet, nil, &bundle) + if err != nil { + return nil, err + } + return &bundle, nil +} + // GetPlacementRulesByGroup gets the placement rules by group. func (c *client) GetPlacementRulesByGroup(ctx context.Context, group string) ([]*Rule, error) { var rules []*Rule @@ -368,6 +421,18 @@ func (c *client) SetPlacementRule(ctx context.Context, rule *Rule) error { http.MethodPost, bytes.NewBuffer(ruleJSON), nil) } +// SetPlacementRuleBundles sets the placement rule bundles. +// If `partial` is false, all old configurations will be over-written and dropped. +func (c *client) SetPlacementRuleBundles(ctx context.Context, bundles []*GroupBundle, partial bool) error { + bundlesJSON, err := json.Marshal(bundles) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, + "SetPlacementRuleBundles", PlacementRuleBundleWithPartialParameter(partial), + http.MethodPost, bytes.NewBuffer(bundlesJSON), nil) +} + // DeletePlacementRule deletes the placement rule. func (c *client) DeletePlacementRule(ctx context.Context, group, id string) error { return c.requestWithRetry(ctx, @@ -375,6 +440,71 @@ func (c *client) DeletePlacementRule(ctx context.Context, group, id string) erro http.MethodDelete, nil, nil) } +// GetAllRegionLabelRules gets all region label rules. +func (c *client) GetAllRegionLabelRules(ctx context.Context) ([]*LabelRule, error) { + var labelRules []*LabelRule + err := c.requestWithRetry(ctx, + "GetAllRegionLabelRules", RegionLabelRules, + http.MethodGet, nil, &labelRules) + if err != nil { + return nil, err + } + return labelRules, nil +} + +// GetRegionLabelRulesByIDs gets the region label rules by IDs. +func (c *client) GetRegionLabelRulesByIDs(ctx context.Context, ruleIDs []string) ([]*LabelRule, error) { + idsJSON, err := json.Marshal(ruleIDs) + if err != nil { + return nil, errors.Trace(err) + } + var labelRules []*LabelRule + err = c.requestWithRetry(ctx, + "GetRegionLabelRulesByIDs", RegionLabelRulesByIDs, + http.MethodGet, bytes.NewBuffer(idsJSON), &labelRules) + if err != nil { + return nil, err + } + return labelRules, nil +} + +// SetRegionLabelRule sets the region label rule. +func (c *client) SetRegionLabelRule(ctx context.Context, labelRule *LabelRule) error { + labelRuleJSON, err := json.Marshal(labelRule) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, + "SetRegionLabelRule", RegionLabelRule, + http.MethodPost, bytes.NewBuffer(labelRuleJSON), nil) +} + +// PatchRegionLabelRules patches the region label rules. +func (c *client) PatchRegionLabelRules(ctx context.Context, labelRulePatch *LabelRulePatch) error { + labelRulePatchJSON, err := json.Marshal(labelRulePatch) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, + "PatchRegionLabelRules", RegionLabelRules, + http.MethodPatch, bytes.NewBuffer(labelRulePatchJSON), nil) +} + +// AccelerateSchedule accelerates the scheduling of the regions within the given key range. +func (c *client) AccelerateSchedule(ctx context.Context, startKey, endKey []byte) error { + input := map[string]string{ + "start_key": url.QueryEscape(string(startKey)), + "end_key": url.QueryEscape(string(endKey)), + } + inputJSON, err := json.Marshal(input) + if err != nil { + return errors.Trace(err) + } + return c.requestWithRetry(ctx, + "AccelerateSchedule", AccelerateSchedule, + http.MethodPost, bytes.NewBuffer(inputJSON), nil) +} + // GetMinResolvedTSByStoresIDs get min-resolved-ts by stores IDs. func (c *client) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uint64) (uint64, map[uint64]uint64, error) { uri := MinResolvedTSPrefix @@ -406,18 +536,3 @@ func (c *client) GetMinResolvedTSByStoresIDs(ctx context.Context, storeIDs []uin } return resp.MinResolvedTS, resp.StoresMinResolvedTS, nil } - -// AccelerateSchedule accelerates the scheduling of the regions within the given key range. -func (c *client) AccelerateSchedule(ctx context.Context, startKey, endKey []byte) error { - input := map[string]string{ - "start_key": url.QueryEscape(string(startKey)), - "end_key": url.QueryEscape(string(endKey)), - } - inputJSON, err := json.Marshal(input) - if err != nil { - return errors.Trace(err) - } - return c.requestWithRetry(ctx, - "AccelerateSchedule", accelerateSchedule, - http.MethodPost, bytes.NewBuffer(inputJSON), nil) -} diff --git a/client/http/types.go b/client/http/types.go index c6bb0256c14..f948286c2b5 100644 --- a/client/http/types.go +++ b/client/http/types.go @@ -246,3 +246,34 @@ type Rule struct { Version uint64 `json:"version,omitempty"` // only set at runtime, add 1 each time rules updated, begin from 0. CreateTimestamp uint64 `json:"create_timestamp,omitempty"` // only set at runtime, recorded rule create timestamp } + +// GroupBundle represents a rule group and all rules belong to the group. +type GroupBundle struct { + ID string `json:"group_id"` + Index int `json:"group_index"` + Override bool `json:"group_override"` + Rules []*Rule `json:"rules"` +} + +// RegionLabel is the label of a region. +type RegionLabel struct { + Key string `json:"key"` + Value string `json:"value"` + TTL string `json:"ttl,omitempty"` + StartAt string `json:"start_at,omitempty"` +} + +// LabelRule is the rule to assign labels to a region. +type LabelRule struct { + ID string `json:"id"` + Index int `json:"index"` + Labels []RegionLabel `json:"labels"` + RuleType string `json:"rule_type"` + Data interface{} `json:"data"` +} + +// LabelRulePatch is the patch to update the label rules. +type LabelRulePatch struct { + SetRules []*LabelRule `json:"sets"` + DeleteRules []string `json:"deletes"` +} diff --git a/tests/integrations/client/http_client_test.go b/tests/integrations/client/http_client_test.go index d2c88d01f09..213aa57de46 100644 --- a/tests/integrations/client/http_client_test.go +++ b/tests/integrations/client/http_client_test.go @@ -17,10 +17,12 @@ package client_test import ( "context" "math" + "sort" "testing" "github.com/stretchr/testify/suite" pd "github.com/tikv/pd/client/http" + "github.com/tikv/pd/pkg/schedule/labeler" "github.com/tikv/pd/pkg/schedule/placement" "github.com/tikv/pd/tests" ) @@ -89,6 +91,13 @@ func (suite *httpClientTestSuite) TestGetMinResolvedTSByStoresIDs() { func (suite *httpClientTestSuite) TestRule() { re := suite.Require() + bundles, err := suite.client.GetAllPlacementRuleBundles(suite.ctx) + re.NoError(err) + re.Len(bundles, 1) + re.Equal(bundles[0].ID, placement.DefaultGroupID) + bundle, err := suite.client.GetPlacementRuleBundleByGroup(suite.ctx, placement.DefaultGroupID) + re.NoError(err) + re.Equal(bundles[0], bundle) rules, err := suite.client.GetPlacementRulesByGroup(suite.ctx, placement.DefaultGroupID) re.NoError(err) re.Len(rules, 1) @@ -96,19 +105,22 @@ func (suite *httpClientTestSuite) TestRule() { re.Equal(placement.DefaultRuleID, rules[0].ID) re.Equal(pd.Voter, rules[0].Role) re.Equal(3, rules[0].Count) - err = suite.client.SetPlacementRule(suite.ctx, &pd.Rule{ + // Should be the same as the rules in the bundle. + re.Equal(bundle.Rules, rules) + testRule := &pd.Rule{ GroupID: placement.DefaultGroupID, ID: "test", - Role: pd.Learner, + Role: pd.Voter, Count: 3, - }) + } + err = suite.client.SetPlacementRule(suite.ctx, testRule) re.NoError(err) rules, err = suite.client.GetPlacementRulesByGroup(suite.ctx, placement.DefaultGroupID) re.NoError(err) re.Len(rules, 2) re.Equal(placement.DefaultGroupID, rules[1].GroupID) re.Equal("test", rules[1].ID) - re.Equal(pd.Learner, rules[1].Role) + re.Equal(pd.Voter, rules[1].Role) re.Equal(3, rules[1].Count) err = suite.client.DeletePlacementRule(suite.ctx, placement.DefaultGroupID, "test") re.NoError(err) @@ -117,6 +129,75 @@ func (suite *httpClientTestSuite) TestRule() { re.Len(rules, 1) re.Equal(placement.DefaultGroupID, rules[0].GroupID) re.Equal(placement.DefaultRuleID, rules[0].ID) + err = suite.client.SetPlacementRuleBundles(suite.ctx, []*pd.GroupBundle{ + { + ID: placement.DefaultGroupID, + Rules: []*pd.Rule{testRule}, + }, + }, true) + re.NoError(err) + bundles, err = suite.client.GetAllPlacementRuleBundles(suite.ctx) + re.NoError(err) + re.Len(bundles, 1) + re.Equal(placement.DefaultGroupID, bundles[0].ID) + re.Len(bundles[0].Rules, 1) + // Make sure the create timestamp is not zero to pass the later assertion. + testRule.CreateTimestamp = bundles[0].Rules[0].CreateTimestamp + re.Equal(testRule, bundles[0].Rules[0]) +} + +func (suite *httpClientTestSuite) TestRegionLabel() { + re := suite.Require() + labelRules, err := suite.client.GetAllRegionLabelRules(suite.ctx) + re.NoError(err) + re.Len(labelRules, 1) + re.Equal("keyspaces/0", labelRules[0].ID) + // Set a new region label rule. + labelRule := &pd.LabelRule{ + ID: "rule1", + Labels: []pd.RegionLabel{{Key: "k1", Value: "v1"}}, + RuleType: "key-range", + Data: labeler.MakeKeyRanges("1234", "5678"), + } + err = suite.client.SetRegionLabelRule(suite.ctx, labelRule) + re.NoError(err) + labelRules, err = suite.client.GetAllRegionLabelRules(suite.ctx) + re.NoError(err) + re.Len(labelRules, 2) + sort.Slice(labelRules, func(i, j int) bool { + return labelRules[i].ID < labelRules[j].ID + }) + re.Equal(labelRule.ID, labelRules[1].ID) + re.Equal(labelRule.Labels, labelRules[1].Labels) + re.Equal(labelRule.RuleType, labelRules[1].RuleType) + // Patch the region label rule. + labelRule = &pd.LabelRule{ + ID: "rule2", + Labels: []pd.RegionLabel{{Key: "k2", Value: "v2"}}, + RuleType: "key-range", + Data: labeler.MakeKeyRanges("ab12", "cd12"), + } + patch := &pd.LabelRulePatch{ + SetRules: []*pd.LabelRule{labelRule}, + DeleteRules: []string{"rule1"}, + } + err = suite.client.PatchRegionLabelRules(suite.ctx, patch) + re.NoError(err) + allLabelRules, err := suite.client.GetAllRegionLabelRules(suite.ctx) + re.NoError(err) + re.Len(labelRules, 2) + sort.Slice(allLabelRules, func(i, j int) bool { + return allLabelRules[i].ID < allLabelRules[j].ID + }) + re.Equal(labelRule.ID, allLabelRules[1].ID) + re.Equal(labelRule.Labels, allLabelRules[1].Labels) + re.Equal(labelRule.RuleType, allLabelRules[1].RuleType) + labelRules, err = suite.client.GetRegionLabelRulesByIDs(suite.ctx, []string{"keyspaces/0", "rule2"}) + re.NoError(err) + sort.Slice(labelRules, func(i, j int) bool { + return labelRules[i].ID < labelRules[j].ID + }) + re.Equal(allLabelRules, labelRules) } func (suite *httpClientTestSuite) TestAccelerateSchedule() {