From bd79cc2b95a24cb9ba9dacb7d9baf915e651c44d Mon Sep 17 00:00:00 2001 From: JmPotato Date: Tue, 31 May 2022 16:06:27 +0800 Subject: [PATCH 1/4] *: use require.New to reduce code (#5076) ref tikv/pd#4813 Use `require.New` to reduce code. Signed-off-by: JmPotato --- pkg/apiutil/apiutil_test.go | 18 +- pkg/assertutil/assertutil_test.go | 5 +- pkg/audit/audit_test.go | 28 +-- pkg/autoscaling/calculation_test.go | 33 ++-- pkg/autoscaling/prometheus_test.go | 32 +-- pkg/cache/cache_test.go | 275 +++++++++++++------------- pkg/codec/codec_test.go | 16 +- pkg/encryption/config_test.go | 18 +- pkg/encryption/crypter_test.go | 70 +++---- pkg/encryption/master_key_test.go | 66 ++++--- pkg/encryption/region_crypter_test.go | 74 +++---- pkg/errs/errs_test.go | 23 ++- pkg/etcdutil/etcdutil_test.go | 81 ++++---- pkg/grpcutil/grpcutil_test.go | 19 +- pkg/keyutil/util_test.go | 3 +- pkg/logutil/log_test.go | 22 ++- pkg/metricutil/metricutil_test.go | 3 +- 17 files changed, 422 insertions(+), 364 deletions(-) diff --git a/pkg/apiutil/apiutil_test.go b/pkg/apiutil/apiutil_test.go index 94cf96c3f26..8a79edc9784 100644 --- a/pkg/apiutil/apiutil_test.go +++ b/pkg/apiutil/apiutil_test.go @@ -25,6 +25,7 @@ import ( ) func TestJsonRespondErrorOk(t *testing.T) { + re := require.New(t) rd := render.New(render.Options{ IndentJSON: true, }) @@ -33,15 +34,16 @@ func TestJsonRespondErrorOk(t *testing.T) { var input map[string]string output := map[string]string{"zone": "cn", "host": "local"} err := ReadJSONRespondError(rd, response, body, &input) - require.NoError(t, err) - require.Equal(t, output["zone"], input["zone"]) - require.Equal(t, output["host"], input["host"]) + re.NoError(err) + re.Equal(output["zone"], input["zone"]) + re.Equal(output["host"], input["host"]) result := response.Result() defer result.Body.Close() - require.Equal(t, 200, result.StatusCode) + re.Equal(200, result.StatusCode) } func TestJsonRespondErrorBadInput(t *testing.T) { + re := require.New(t) rd := render.New(render.Options{ IndentJSON: true, }) @@ -49,18 +51,18 @@ func TestJsonRespondErrorBadInput(t *testing.T) { body := io.NopCloser(bytes.NewBufferString("{\"zone\":\"cn\", \"host\":\"local\"}")) var input []string err := ReadJSONRespondError(rd, response, body, &input) - require.EqualError(t, err, "json: cannot unmarshal object into Go value of type []string") + re.EqualError(err, "json: cannot unmarshal object into Go value of type []string") result := response.Result() defer result.Body.Close() - require.Equal(t, 400, result.StatusCode) + re.Equal(400, result.StatusCode) { body := io.NopCloser(bytes.NewBufferString("{\"zone\":\"cn\",")) var input []string err := ReadJSONRespondError(rd, response, body, &input) - require.EqualError(t, err, "unexpected end of JSON input") + re.EqualError(err, "unexpected end of JSON input") result := response.Result() defer result.Body.Close() - require.Equal(t, 400, result.StatusCode) + re.Equal(400, result.StatusCode) } } diff --git a/pkg/assertutil/assertutil_test.go b/pkg/assertutil/assertutil_test.go index 6cdfd591937..324e403f7b6 100644 --- a/pkg/assertutil/assertutil_test.go +++ b/pkg/assertutil/assertutil_test.go @@ -22,11 +22,12 @@ import ( ) func TestNilFail(t *testing.T) { + re := require.New(t) var failErr error checker := NewChecker(func() { failErr = errors.New("called assert func not exist") }) - require.Nil(t, checker.IsNil) + re.Nil(checker.IsNil) checker.AssertNil(nil) - require.NotNil(t, failErr) + re.NotNil(failErr) } diff --git a/pkg/audit/audit_test.go b/pkg/audit/audit_test.go index 66df5298b8b..2b33b62ca55 100644 --- a/pkg/audit/audit_test.go +++ b/pkg/audit/audit_test.go @@ -32,14 +32,16 @@ import ( ) func TestLabelMatcher(t *testing.T) { + re := require.New(t) matcher := &LabelMatcher{"testSuccess"} labels1 := &BackendLabels{Labels: []string{"testFail", "testSuccess"}} - require.True(t, matcher.Match(labels1)) + re.True(matcher.Match(labels1)) labels2 := &BackendLabels{Labels: []string{"testFail"}} - require.False(t, matcher.Match(labels2)) + re.False(matcher.Match(labels2)) } func TestPrometheusHistogramBackend(t *testing.T) { + re := require.New(t) serviceAuditHistogramTest := prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "pd", @@ -60,43 +62,43 @@ func TestPrometheusHistogramBackend(t *testing.T) { info.ServiceLabel = "test" info.Component = "user1" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) - require.False(t, backend.ProcessHTTPRequest(req)) + re.False(backend.ProcessHTTPRequest(req)) endTime := time.Now().Unix() + 20 req = req.WithContext(requestutil.WithEndTime(req.Context(), endTime)) - require.True(t, backend.ProcessHTTPRequest(req)) - require.True(t, backend.ProcessHTTPRequest(req)) + re.True(backend.ProcessHTTPRequest(req)) + re.True(backend.ProcessHTTPRequest(req)) info.Component = "user2" req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) - require.True(t, backend.ProcessHTTPRequest(req)) + re.True(backend.ProcessHTTPRequest(req)) // For test, sleep time needs longer than the push interval time.Sleep(1 * time.Second) req, _ = http.NewRequest("GET", ts.URL, nil) resp, err := http.DefaultClient.Do(req) - require.NoError(t, err) + re.NoError(err) defer resp.Body.Close() content, _ := io.ReadAll(resp.Body) output := string(content) - require.Contains(t, output, "pd_service_audit_handling_seconds_test_count{component=\"user1\",method=\"HTTP\",service=\"test\"} 2") - require.Contains(t, output, "pd_service_audit_handling_seconds_test_count{component=\"user2\",method=\"HTTP\",service=\"test\"} 1") + re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user1\",method=\"HTTP\",service=\"test\"} 2") + re.Contains(output, "pd_service_audit_handling_seconds_test_count{component=\"user2\",method=\"HTTP\",service=\"test\"} 1") } func TestLocalLogBackendUsingFile(t *testing.T) { + re := require.New(t) backend := NewLocalLogBackend(true) fname := initLog() defer os.Remove(fname) req, _ := http.NewRequest("GET", "http://127.0.0.1:2379/test?test=test", strings.NewReader("testBody")) - require.False(t, backend.ProcessHTTPRequest(req)) + re.False(backend.ProcessHTTPRequest(req)) info := requestutil.GetRequestInfo(req) req = req.WithContext(requestutil.WithRequestInfo(req.Context(), info)) - require.True(t, backend.ProcessHTTPRequest(req)) + re.True(backend.ProcessHTTPRequest(req)) b, _ := os.ReadFile(fname) output := strings.SplitN(string(b), "]", 4) - require.Equal( - t, + re.Equal( fmt.Sprintf(" [\"Audit Log\"] [service-info=\"{ServiceLabel:, Method:HTTP/1.1/GET:/test, Component:anonymous, IP:, "+ "StartTime:%s, URLParam:{\\\"test\\\":[\\\"test\\\"]}, BodyParam:testBody}\"]\n", time.Unix(info.StartTimeStamp, 0).String()), diff --git a/pkg/autoscaling/calculation_test.go b/pkg/autoscaling/calculation_test.go index 8334c96ecc5..f5ac3313ba4 100644 --- a/pkg/autoscaling/calculation_test.go +++ b/pkg/autoscaling/calculation_test.go @@ -30,6 +30,7 @@ import ( ) func TestGetScaledTiKVGroups(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() // case1 indicates the tikv cluster with not any group existed @@ -193,10 +194,10 @@ func TestGetScaledTiKVGroups(t *testing.T) { t.Log(testCase.name) plans, err := getScaledTiKVGroups(testCase.informer, testCase.healthyInstances) if testCase.expectedPlan == nil { - require.Len(t, plans, 0) - require.Equal(t, testCase.noError, err == nil) + re.Len(plans, 0) + re.Equal(testCase.noError, err == nil) } else { - require.True(t, reflect.DeepEqual(testCase.expectedPlan, plans)) + re.True(reflect.DeepEqual(testCase.expectedPlan, plans)) } } } @@ -213,6 +214,7 @@ func (q *mockQuerier) Query(options *QueryOptions) (QueryResult, error) { } func TestGetTotalCPUUseTime(t *testing.T) { + re := require.New(t) querier := &mockQuerier{} instances := []instance{ { @@ -230,10 +232,11 @@ func TestGetTotalCPUUseTime(t *testing.T) { } totalCPUUseTime, _ := getTotalCPUUseTime(querier, TiDB, instances, time.Now(), 0) expected := mockResultValue * float64(len(instances)) - require.True(t, math.Abs(expected-totalCPUUseTime) < 1e-6) + re.True(math.Abs(expected-totalCPUUseTime) < 1e-6) } func TestGetTotalCPUQuota(t *testing.T) { + re := require.New(t) querier := &mockQuerier{} instances := []instance{ { @@ -251,10 +254,11 @@ func TestGetTotalCPUQuota(t *testing.T) { } totalCPUQuota, _ := getTotalCPUQuota(querier, TiDB, instances, time.Now()) expected := uint64(mockResultValue * float64(len(instances)*milliCores)) - require.Equal(t, expected, totalCPUQuota) + re.Equal(expected, totalCPUQuota) } func TestScaleOutGroupLabel(t *testing.T) { + re := require.New(t) var jsonStr = []byte(` { "rules":[ @@ -288,14 +292,15 @@ func TestScaleOutGroupLabel(t *testing.T) { }`) strategy := &Strategy{} err := json.Unmarshal(jsonStr, strategy) - require.NoError(t, err) + re.NoError(err) plan := findBestGroupToScaleOut(strategy, nil, TiKV) - require.Equal(t, "hotRegion", plan.Labels["specialUse"]) + re.Equal("hotRegion", plan.Labels["specialUse"]) plan = findBestGroupToScaleOut(strategy, nil, TiDB) - require.Equal(t, "", plan.Labels["specialUse"]) + re.Equal("", plan.Labels["specialUse"]) } func TestStrategyChangeCount(t *testing.T) { + re := require.New(t) var count uint64 = 2 strategy := &Strategy{ Rules: []*Rule{ @@ -343,21 +348,21 @@ func TestStrategyChangeCount(t *testing.T) { // exist two scaled TiKVs and plan does not change due to the limit of resource count groups, err := getScaledTiKVGroups(cluster, instances) - require.NoError(t, err) + re.NoError(err) plans := calculateScaleOutPlan(strategy, TiKV, scaleOutQuota, groups) - require.Equal(t, uint64(2), plans[0].Count) + re.Equal(uint64(2), plans[0].Count) // change the resource count to 3 and plan increates one more tikv groups, err = getScaledTiKVGroups(cluster, instances) - require.NoError(t, err) + re.NoError(err) *strategy.Resources[0].Count = 3 plans = calculateScaleOutPlan(strategy, TiKV, scaleOutQuota, groups) - require.Equal(t, uint64(3), plans[0].Count) + re.Equal(uint64(3), plans[0].Count) // change the resource count to 1 and plan decreases to 1 tikv due to the limit of resource count groups, err = getScaledTiKVGroups(cluster, instances) - require.NoError(t, err) + re.NoError(err) *strategy.Resources[0].Count = 1 plans = calculateScaleOutPlan(strategy, TiKV, scaleOutQuota, groups) - require.Equal(t, uint64(1), plans[0].Count) + re.Equal(uint64(1), plans[0].Count) } diff --git a/pkg/autoscaling/prometheus_test.go b/pkg/autoscaling/prometheus_test.go index 2c541446d2b..2906645b180 100644 --- a/pkg/autoscaling/prometheus_test.go +++ b/pkg/autoscaling/prometheus_test.go @@ -181,6 +181,7 @@ func (c *normalClient) Do(_ context.Context, req *http.Request) (response *http. } func TestRetrieveCPUMetrics(t *testing.T) { + re := require.New(t) client := &normalClient{ mockData: make(map[string]*response), } @@ -191,15 +192,15 @@ func TestRetrieveCPUMetrics(t *testing.T) { for _, metric := range metrics { options := NewQueryOptions(component, metric, addresses[:len(addresses)-1], time.Now(), mockDuration) result, err := querier.Query(options) - require.NoError(t, err) + re.NoError(err) for i := 0; i < len(addresses)-1; i++ { value, ok := result[addresses[i]] - require.True(t, ok) - require.True(t, math.Abs(value-mockResultValue) < 1e-6) + re.True(ok) + re.True(math.Abs(value-mockResultValue) < 1e-6) } _, ok := result[addresses[len(addresses)-1]] - require.False(t, ok) + re.False(ok) } } } @@ -224,12 +225,13 @@ func (c *emptyResponseClient) Do(_ context.Context, req *http.Request) (r *http. } func TestEmptyResponse(t *testing.T) { + re := require.New(t) client := &emptyResponseClient{} querier := NewPrometheusQuerier(client) options := NewQueryOptions(TiDB, CPUUsage, podAddresses[TiDB], time.Now(), mockDuration) result, err := querier.Query(options) - require.Nil(t, result) - require.Error(t, err) + re.Nil(result) + re.Error(err) } type errorHTTPStatusClient struct{} @@ -250,12 +252,13 @@ func (c *errorHTTPStatusClient) Do(_ context.Context, req *http.Request) (r *htt } func TestErrorHTTPStatus(t *testing.T) { + re := require.New(t) client := &errorHTTPStatusClient{} querier := NewPrometheusQuerier(client) options := NewQueryOptions(TiDB, CPUUsage, podAddresses[TiDB], time.Now(), mockDuration) result, err := querier.Query(options) - require.Nil(t, result) - require.Error(t, err) + re.Nil(result) + re.Error(err) } type errorPrometheusStatusClient struct{} @@ -274,15 +277,17 @@ func (c *errorPrometheusStatusClient) Do(_ context.Context, req *http.Request) ( } func TestErrorPrometheusStatus(t *testing.T) { + re := require.New(t) client := &errorPrometheusStatusClient{} querier := NewPrometheusQuerier(client) options := NewQueryOptions(TiDB, CPUUsage, podAddresses[TiDB], time.Now(), mockDuration) result, err := querier.Query(options) - require.Nil(t, result) - require.Error(t, err) + re.Nil(result) + re.Error(err) } func TestGetInstanceNameFromAddress(t *testing.T) { + re := require.New(t) testCases := []struct { address string expectedInstanceName string @@ -311,14 +316,15 @@ func TestGetInstanceNameFromAddress(t *testing.T) { for _, testCase := range testCases { instanceName, err := getInstanceNameFromAddress(testCase.address) if testCase.expectedInstanceName == "" { - require.Error(t, err) + re.Error(err) } else { - require.Equal(t, testCase.expectedInstanceName, instanceName) + re.Equal(testCase.expectedInstanceName, instanceName) } } } func TestGetDurationExpression(t *testing.T) { + re := require.New(t) testCases := []struct { duration time.Duration expectedExpression string @@ -343,6 +349,6 @@ func TestGetDurationExpression(t *testing.T) { for _, testCase := range testCases { expression := getDurationExpression(testCase.duration) - require.Equal(t, testCase.expectedExpression, expression) + re.Equal(testCase.expectedExpression, expression) } } diff --git a/pkg/cache/cache_test.go b/pkg/cache/cache_test.go index bd633fef525..bf1626450f7 100644 --- a/pkg/cache/cache_test.go +++ b/pkg/cache/cache_test.go @@ -25,36 +25,37 @@ import ( ) func TestExpireRegionCache(t *testing.T) { + re := require.New(t) ctx, cancel := context.WithCancel(context.Background()) defer cancel() cache := NewIDTTL(ctx, time.Second, 2*time.Second) // Test Pop cache.PutWithTTL(9, "9", 5*time.Second) cache.PutWithTTL(10, "10", 5*time.Second) - require.Equal(t, 2, cache.Len()) + re.Equal(2, cache.Len()) k, v, success := cache.pop() - require.True(t, success) - require.Equal(t, 1, cache.Len()) + re.True(success) + re.Equal(1, cache.Len()) k2, v2, success := cache.pop() - require.True(t, success) + re.True(success) // we can't ensure the order which the key/value pop from cache, so we save into a map kvMap := map[uint64]string{ 9: "9", 10: "10", } expV, ok := kvMap[k.(uint64)] - require.True(t, ok) - require.Equal(t, expV, v.(string)) + re.True(ok) + re.Equal(expV, v.(string)) expV, ok = kvMap[k2.(uint64)] - require.True(t, ok) - require.Equal(t, expV, v2.(string)) + re.True(ok) + re.Equal(expV, v2.(string)) cache.PutWithTTL(11, "11", 1*time.Second) time.Sleep(5 * time.Second) k, v, success = cache.pop() - require.False(t, success) - require.Nil(t, k) - require.Nil(t, v) + re.False(success) + re.Nil(k) + re.Nil(v) // Test Get cache.PutWithTTL(1, 1, 1*time.Second) @@ -62,50 +63,50 @@ func TestExpireRegionCache(t *testing.T) { cache.PutWithTTL(3, 3.0, 5*time.Second) value, ok := cache.Get(1) - require.True(t, ok) - require.Equal(t, 1, value) + re.True(ok) + re.Equal(1, value) value, ok = cache.Get(2) - require.True(t, ok) - require.Equal(t, "v2", value) + re.True(ok) + re.Equal("v2", value) value, ok = cache.Get(3) - require.True(t, ok) - require.Equal(t, 3.0, value) + re.True(ok) + re.Equal(3.0, value) - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) - require.True(t, reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{1, 2, 3})) + re.True(reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{1, 2, 3})) time.Sleep(2 * time.Second) value, ok = cache.Get(1) - require.False(t, ok) - require.Nil(t, value) + re.False(ok) + re.Nil(value) value, ok = cache.Get(2) - require.True(t, ok) - require.Equal(t, "v2", value) + re.True(ok) + re.Equal("v2", value) value, ok = cache.Get(3) - require.True(t, ok) - require.Equal(t, 3.0, value) + re.True(ok) + re.Equal(3.0, value) - require.Equal(t, 2, cache.Len()) - require.True(t, reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{2, 3})) + re.Equal(2, cache.Len()) + re.True(reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{2, 3})) cache.Remove(2) value, ok = cache.Get(2) - require.False(t, ok) - require.Nil(t, value) + re.False(ok) + re.Nil(value) value, ok = cache.Get(3) - require.True(t, ok) - require.Equal(t, 3.0, value) + re.True(ok) + re.Equal(3.0, value) - require.Equal(t, 1, cache.Len()) - require.True(t, reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{3})) + re.Equal(1, cache.Len()) + re.True(reflect.DeepEqual(sortIDs(cache.GetAllID()), []uint64{3})) } func sortIDs(ids []uint64) []uint64 { @@ -115,6 +116,7 @@ func sortIDs(ids []uint64) []uint64 { } func TestLRUCache(t *testing.T) { + re := require.New(t) cache := newLRU(3) cache.Put(1, "1") @@ -122,173 +124,175 @@ func TestLRUCache(t *testing.T) { cache.Put(3, "3") val, ok := cache.Get(3) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "3")) + re.True(ok) + re.True(reflect.DeepEqual(val, "3")) val, ok = cache.Get(2) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "2")) + re.True(ok) + re.True(reflect.DeepEqual(val, "2")) val, ok = cache.Get(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) cache.Put(4, "4") - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) val, ok = cache.Get(3) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) val, ok = cache.Get(2) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "2")) + re.True(ok) + re.True(reflect.DeepEqual(val, "2")) val, ok = cache.Get(4) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "4")) + re.True(ok) + re.True(reflect.DeepEqual(val, "4")) - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) val, ok = cache.Peek(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) elems := cache.Elems() - require.Len(t, elems, 3) - require.True(t, reflect.DeepEqual(elems[0].Value, "4")) - require.True(t, reflect.DeepEqual(elems[1].Value, "2")) - require.True(t, reflect.DeepEqual(elems[2].Value, "1")) + re.Len(elems, 3) + re.True(reflect.DeepEqual(elems[0].Value, "4")) + re.True(reflect.DeepEqual(elems[1].Value, "2")) + re.True(reflect.DeepEqual(elems[2].Value, "1")) cache.Remove(1) cache.Remove(2) cache.Remove(4) - require.Equal(t, 0, cache.Len()) + re.Equal(0, cache.Len()) val, ok = cache.Get(1) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(2) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(3) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(4) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) } func TestFifoCache(t *testing.T) { + re := require.New(t) cache := NewFIFO(3) cache.Put(1, "1") cache.Put(2, "2") cache.Put(3, "3") - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) cache.Put(4, "4") - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) elems := cache.Elems() - require.Len(t, elems, 3) - require.True(t, reflect.DeepEqual(elems[0].Value, "2")) - require.True(t, reflect.DeepEqual(elems[1].Value, "3")) - require.True(t, reflect.DeepEqual(elems[2].Value, "4")) + re.Len(elems, 3) + re.True(reflect.DeepEqual(elems[0].Value, "2")) + re.True(reflect.DeepEqual(elems[1].Value, "3")) + re.True(reflect.DeepEqual(elems[2].Value, "4")) elems = cache.FromElems(3) - require.Len(t, elems, 1) - require.True(t, reflect.DeepEqual(elems[0].Value, "4")) + re.Len(elems, 1) + re.True(reflect.DeepEqual(elems[0].Value, "4")) cache.Remove() cache.Remove() cache.Remove() - require.Equal(t, 0, cache.Len()) + re.Equal(0, cache.Len()) } func TestTwoQueueCache(t *testing.T) { + re := require.New(t) cache := newTwoQueue(3) cache.Put(1, "1") cache.Put(2, "2") cache.Put(3, "3") val, ok := cache.Get(3) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "3")) + re.True(ok) + re.True(reflect.DeepEqual(val, "3")) val, ok = cache.Get(2) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "2")) + re.True(ok) + re.True(reflect.DeepEqual(val, "2")) val, ok = cache.Get(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) cache.Put(4, "4") - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) val, ok = cache.Get(3) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) val, ok = cache.Get(2) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "2")) + re.True(ok) + re.True(reflect.DeepEqual(val, "2")) val, ok = cache.Get(4) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "4")) + re.True(ok) + re.True(reflect.DeepEqual(val, "4")) - require.Equal(t, 3, cache.Len()) + re.Equal(3, cache.Len()) val, ok = cache.Peek(1) - require.True(t, ok) - require.True(t, reflect.DeepEqual(val, "1")) + re.True(ok) + re.True(reflect.DeepEqual(val, "1")) elems := cache.Elems() - require.Len(t, elems, 3) - require.True(t, reflect.DeepEqual(elems[0].Value, "4")) - require.True(t, reflect.DeepEqual(elems[1].Value, "2")) - require.True(t, reflect.DeepEqual(elems[2].Value, "1")) + re.Len(elems, 3) + re.True(reflect.DeepEqual(elems[0].Value, "4")) + re.True(reflect.DeepEqual(elems[1].Value, "2")) + re.True(reflect.DeepEqual(elems[2].Value, "1")) cache.Remove(1) cache.Remove(2) cache.Remove(4) - require.Equal(t, 0, cache.Len()) + re.Equal(0, cache.Len()) val, ok = cache.Get(1) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(2) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(3) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) val, ok = cache.Get(4) - require.False(t, ok) - require.Nil(t, val) + re.False(ok) + re.Nil(val) } var _ PriorityQueueItem = PriorityQueueItemTest(0) @@ -300,53 +304,54 @@ func (pq PriorityQueueItemTest) ID() uint64 { } func TestPriorityQueue(t *testing.T) { + re := require.New(t) testData := []PriorityQueueItemTest{0, 1, 2, 3, 4, 5} pq := NewPriorityQueue(0) - require.False(t, pq.Put(1, testData[1])) + re.False(pq.Put(1, testData[1])) // it will have priority-value pair as 1-1 2-2 3-3 pq = NewPriorityQueue(3) - require.True(t, pq.Put(1, testData[1])) - require.True(t, pq.Put(2, testData[2])) - require.True(t, pq.Put(3, testData[4])) - require.True(t, pq.Put(5, testData[4])) - require.False(t, pq.Put(5, testData[5])) - require.True(t, pq.Put(3, testData[3])) - require.True(t, pq.Put(3, testData[3])) - require.Nil(t, pq.Get(4)) - require.Equal(t, 3, pq.Len()) + re.True(pq.Put(1, testData[1])) + re.True(pq.Put(2, testData[2])) + re.True(pq.Put(3, testData[4])) + re.True(pq.Put(5, testData[4])) + re.False(pq.Put(5, testData[5])) + re.True(pq.Put(3, testData[3])) + re.True(pq.Put(3, testData[3])) + re.Nil(pq.Get(4)) + re.Equal(3, pq.Len()) // case1 test getAll, the highest element should be the first entries := pq.Elems() - require.Len(t, entries, 3) - require.Equal(t, 1, entries[0].Priority) - require.Equal(t, testData[1], entries[0].Value) - require.Equal(t, 2, entries[1].Priority) - require.Equal(t, testData[2], entries[1].Value) - require.Equal(t, 3, entries[2].Priority) - require.Equal(t, testData[3], entries[2].Value) + re.Len(entries, 3) + re.Equal(1, entries[0].Priority) + re.Equal(testData[1], entries[0].Value) + re.Equal(2, entries[1].Priority) + re.Equal(testData[2], entries[1].Value) + re.Equal(3, entries[2].Priority) + re.Equal(testData[3], entries[2].Value) // case2 test remove the high element, and the second element should be the first pq.Remove(uint64(1)) - require.Nil(t, pq.Get(1)) - require.Equal(t, 2, pq.Len()) + re.Nil(pq.Get(1)) + re.Equal(2, pq.Len()) entry := pq.Peek() - require.Equal(t, 2, entry.Priority) - require.Equal(t, testData[2], entry.Value) + re.Equal(2, entry.Priority) + re.Equal(testData[2], entry.Value) // case3 update 3's priority to highest pq.Put(-1, testData[3]) entry = pq.Peek() - require.Equal(t, -1, entry.Priority) - require.Equal(t, testData[3], entry.Value) + re.Equal(-1, entry.Priority) + re.Equal(testData[3], entry.Value) pq.Remove(entry.Value.ID()) - require.Equal(t, testData[2], pq.Peek().Value) - require.Equal(t, 1, pq.Len()) + re.Equal(testData[2], pq.Peek().Value) + re.Equal(1, pq.Len()) // case4 remove all element pq.Remove(uint64(2)) - require.Equal(t, 0, pq.Len()) - require.Len(t, pq.items, 0) - require.Nil(t, pq.Peek()) - require.Nil(t, pq.Tail()) + re.Equal(0, pq.Len()) + re.Len(pq.items, 0) + re.Nil(pq.Peek()) + re.Nil(pq.Tail()) } diff --git a/pkg/codec/codec_test.go b/pkg/codec/codec_test.go index cd73c1da0cc..50bf552a60d 100644 --- a/pkg/codec/codec_test.go +++ b/pkg/codec/codec_test.go @@ -21,27 +21,29 @@ import ( ) func TestDecodeBytes(t *testing.T) { + re := require.New(t) key := "abcdefghijklmnopqrstuvwxyz" for i := 0; i < len(key); i++ { _, k, err := DecodeBytes(EncodeBytes([]byte(key[:i]))) - require.NoError(t, err) - require.Equal(t, key[:i], string(k)) + re.NoError(err) + re.Equal(key[:i], string(k)) } } func TestTableID(t *testing.T) { + re := require.New(t) key := EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\x00\xff")) - require.Equal(t, int64(0xff), key.TableID()) + re.Equal(int64(0xff), key.TableID()) key = EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\x00\xff_i\x01\x02")) - require.Equal(t, int64(0xff), key.TableID()) + re.Equal(int64(0xff), key.TableID()) key = []byte("t\x80\x00\x00\x00\x00\x00\x00\xff") - require.Equal(t, int64(0), key.TableID()) + re.Equal(int64(0), key.TableID()) key = EncodeBytes([]byte("T\x00\x00\x00\x00\x00\x00\x00\xff")) - require.Equal(t, int64(0), key.TableID()) + re.Equal(int64(0), key.TableID()) key = EncodeBytes([]byte("t\x80\x00\x00\x00\x00\x00\xff")) - require.Equal(t, int64(0), key.TableID()) + re.Equal(int64(0), key.TableID()) } diff --git a/pkg/encryption/config_test.go b/pkg/encryption/config_test.go index 04e9d417686..1e3231b0903 100644 --- a/pkg/encryption/config_test.go +++ b/pkg/encryption/config_test.go @@ -23,26 +23,30 @@ import ( ) func TestAdjustDefaultValue(t *testing.T) { + re := require.New(t) config := &Config{} err := config.Adjust() - require.NoError(t, err) - require.Equal(t, methodPlaintext, config.DataEncryptionMethod) + re.NoError(err) + re.Equal(methodPlaintext, config.DataEncryptionMethod) defaultRotationPeriod, _ := time.ParseDuration(defaultDataKeyRotationPeriod) - require.Equal(t, defaultRotationPeriod, config.DataKeyRotationPeriod.Duration) - require.Equal(t, masterKeyTypePlaintext, config.MasterKey.Type) + re.Equal(defaultRotationPeriod, config.DataKeyRotationPeriod.Duration) + re.Equal(masterKeyTypePlaintext, config.MasterKey.Type) } func TestAdjustInvalidDataEncryptionMethod(t *testing.T) { + re := require.New(t) config := &Config{DataEncryptionMethod: "unknown"} - require.NotNil(t, config.Adjust()) + re.NotNil(config.Adjust()) } func TestAdjustNegativeRotationDuration(t *testing.T) { + re := require.New(t) config := &Config{DataKeyRotationPeriod: typeutil.NewDuration(time.Duration(int64(-1)))} - require.NotNil(t, config.Adjust()) + re.NotNil(config.Adjust()) } func TestAdjustInvalidMasterKeyType(t *testing.T) { + re := require.New(t) config := &Config{MasterKey: MasterKeyConfig{Type: "unknown"}} - require.NotNil(t, config.Adjust()) + re.NotNil(config.Adjust()) } diff --git a/pkg/encryption/crypter_test.go b/pkg/encryption/crypter_test.go index 716d15ecdcb..c29ed6a8725 100644 --- a/pkg/encryption/crypter_test.go +++ b/pkg/encryption/crypter_test.go @@ -24,77 +24,81 @@ import ( ) func TestEncryptionMethodSupported(t *testing.T) { - require.NotNil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_PLAINTEXT)) - require.NotNil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_UNKNOWN)) - require.Nil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES128_CTR)) - require.Nil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES192_CTR)) - require.Nil(t, CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES256_CTR)) + re := require.New(t) + re.NotNil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_PLAINTEXT)) + re.NotNil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_UNKNOWN)) + re.Nil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES128_CTR)) + re.Nil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES192_CTR)) + re.Nil(CheckEncryptionMethodSupported(encryptionpb.EncryptionMethod_AES256_CTR)) } func TestKeyLength(t *testing.T) { + re := require.New(t) _, err := KeyLength(encryptionpb.EncryptionMethod_PLAINTEXT) - require.NotNil(t, err) + re.NotNil(err) _, err = KeyLength(encryptionpb.EncryptionMethod_UNKNOWN) - require.NotNil(t, err) + re.NotNil(err) length, err := KeyLength(encryptionpb.EncryptionMethod_AES128_CTR) - require.NoError(t, err) - require.Equal(t, 16, length) + re.NoError(err) + re.Equal(16, length) length, err = KeyLength(encryptionpb.EncryptionMethod_AES192_CTR) - require.NoError(t, err) - require.Equal(t, 24, length) + re.NoError(err) + re.Equal(24, length) length, err = KeyLength(encryptionpb.EncryptionMethod_AES256_CTR) - require.NoError(t, err) - require.Equal(t, 32, length) + re.NoError(err) + re.Equal(32, length) } func TestNewIv(t *testing.T) { + re := require.New(t) ivCtr, err := NewIvCTR() - require.NoError(t, err) - require.Len(t, []byte(ivCtr), ivLengthCTR) + re.NoError(err) + re.Len([]byte(ivCtr), ivLengthCTR) ivGcm, err := NewIvGCM() - require.NoError(t, err) - require.Len(t, []byte(ivGcm), ivLengthGCM) + re.NoError(err) + re.Len([]byte(ivGcm), ivLengthGCM) } func TestNewDataKey(t *testing.T) { + re := require.New(t) for _, method := range []encryptionpb.EncryptionMethod{ encryptionpb.EncryptionMethod_AES128_CTR, encryptionpb.EncryptionMethod_AES192_CTR, encryptionpb.EncryptionMethod_AES256_CTR, } { _, key, err := NewDataKey(method, uint64(123)) - require.NoError(t, err) + re.NoError(err) length, err := KeyLength(method) - require.NoError(t, err) - require.Len(t, key.Key, length) - require.Equal(t, method, key.Method) - require.False(t, key.WasExposed) - require.Equal(t, uint64(123), key.CreationTime) + re.NoError(err) + re.Len(key.Key, length) + re.Equal(method, key.Method) + re.False(key.WasExposed) + re.Equal(uint64(123), key.CreationTime) } } func TestAesGcmCrypter(t *testing.T) { + re := require.New(t) key, err := hex.DecodeString("ed568fbd8c8018ed2d042a4e5d38d6341486922d401d2022fb81e47c900d3f07") - require.NoError(t, err) + re.NoError(err) plaintext, err := hex.DecodeString( "5c873a18af5e7c7c368cb2635e5a15c7f87282085f4b991e84b78c5967e946d4") - require.NoError(t, err) + re.NoError(err) // encrypt ivBytes, err := hex.DecodeString("ba432b70336c40c39ba14c1b") - require.NoError(t, err) + re.NoError(err) iv := IvGCM(ivBytes) ciphertext, err := aesGcmEncryptImpl(key, plaintext, iv) - require.NoError(t, err) - require.Len(t, []byte(iv), ivLengthGCM) - require.Equal( - t, + re.NoError(err) + re.Len([]byte(iv), ivLengthGCM) + re.Equal( "bbb9b49546350880cf55d4e4eaccc831c506a4aeae7f6cda9c821d4cb8cfc269dcdaecb09592ef25d7a33b40d3f02208", hex.EncodeToString(ciphertext), ) // decrypt plaintext2, err := AesGcmDecrypt(key, ciphertext, iv) - require.NoError(t, err) - require.True(t, bytes.Equal(plaintext2, plaintext)) + re.NoError(err) + re.True(bytes.Equal(plaintext2, plaintext)) // Modify ciphertext to test authentication failure. We modify the beginning of the ciphertext, // which is the real ciphertext part, not the tag. fakeCiphertext := make([]byte, len(ciphertext)) @@ -102,5 +106,5 @@ func TestAesGcmCrypter(t *testing.T) { // ignore overflow fakeCiphertext[0] = ciphertext[0] + 1 _, err = AesGcmDecrypt(key, fakeCiphertext, iv) - require.NotNil(t, err) + re.NotNil(err) } diff --git a/pkg/encryption/master_key_test.go b/pkg/encryption/master_key_test.go index 0fc1d376ca7..990d6322c3e 100644 --- a/pkg/encryption/master_key_test.go +++ b/pkg/encryption/master_key_test.go @@ -24,59 +24,63 @@ import ( ) func TestPlaintextMasterKey(t *testing.T) { + re := require.New(t) config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_Plaintext{ Plaintext: &encryptionpb.MasterKeyPlaintext{}, }, } masterKey, err := NewMasterKey(config, nil) - require.NoError(t, err) - require.NotNil(t, masterKey) - require.Len(t, masterKey.key, 0) + re.NoError(err) + re.NotNil(masterKey) + re.Len(masterKey.key, 0) plaintext := "this is a plaintext" ciphertext, iv, err := masterKey.Encrypt([]byte(plaintext)) - require.NoError(t, err) - require.Len(t, iv, 0) - require.Equal(t, plaintext, string(ciphertext)) + re.NoError(err) + re.Len(iv, 0) + re.Equal(plaintext, string(ciphertext)) plaintext2, err := masterKey.Decrypt(ciphertext, iv) - require.NoError(t, err) - require.Equal(t, plaintext, string(plaintext2)) + re.NoError(err) + re.Equal(plaintext, string(plaintext2)) - require.True(t, masterKey.IsPlaintext()) + re.True(masterKey.IsPlaintext()) } func TestEncrypt(t *testing.T) { + re := require.New(t) keyHex := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" key, err := hex.DecodeString(keyHex) - require.NoError(t, err) + re.NoError(err) masterKey := &MasterKey{key: key} plaintext := "this-is-a-plaintext" ciphertext, iv, err := masterKey.Encrypt([]byte(plaintext)) - require.NoError(t, err) - require.Len(t, iv, ivLengthGCM) + re.NoError(err) + re.Len(iv, ivLengthGCM) plaintext2, err := AesGcmDecrypt(key, ciphertext, iv) - require.NoError(t, err) - require.Equal(t, plaintext, string(plaintext2)) + re.NoError(err) + re.Equal(plaintext, string(plaintext2)) } func TestDecrypt(t *testing.T) { + re := require.New(t) keyHex := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" key, err := hex.DecodeString(keyHex) - require.NoError(t, err) + re.NoError(err) plaintext := "this-is-a-plaintext" iv, err := hex.DecodeString("ba432b70336c40c39ba14c1b") - require.NoError(t, err) + re.NoError(err) ciphertext, err := aesGcmEncryptImpl(key, []byte(plaintext), iv) - require.NoError(t, err) + re.NoError(err) masterKey := &MasterKey{key: key} plaintext2, err := masterKey.Decrypt(ciphertext, iv) - require.NoError(t, err) - require.Equal(t, plaintext, string(plaintext2)) + re.NoError(err) + re.Equal(plaintext, string(plaintext2)) } func TestNewFileMasterKeyMissingPath(t *testing.T) { + re := require.New(t) config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ File: &encryptionpb.MasterKeyFile{ @@ -85,12 +89,13 @@ func TestNewFileMasterKeyMissingPath(t *testing.T) { }, } _, err := NewMasterKey(config, nil) - require.Error(t, err) + re.Error(err) } func TestNewFileMasterKeyMissingFile(t *testing.T) { + re := require.New(t) dir, err := os.MkdirTemp("", "test_key_files") - require.NoError(t, err) + re.NoError(err) path := dir + "/key" config := &encryptionpb.MasterKey{ Backend: &encryptionpb.MasterKey_File{ @@ -100,12 +105,13 @@ func TestNewFileMasterKeyMissingFile(t *testing.T) { }, } _, err = NewMasterKey(config, nil) - require.Error(t, err) + re.Error(err) } func TestNewFileMasterKeyNotHexString(t *testing.T) { + re := require.New(t) dir, err := os.MkdirTemp("", "test_key_files") - require.NoError(t, err) + re.NoError(err) path := dir + "/key" os.WriteFile(path, []byte("not-a-hex-string"), 0600) config := &encryptionpb.MasterKey{ @@ -116,12 +122,13 @@ func TestNewFileMasterKeyNotHexString(t *testing.T) { }, } _, err = NewMasterKey(config, nil) - require.Error(t, err) + re.Error(err) } func TestNewFileMasterKeyLengthMismatch(t *testing.T) { + re := require.New(t) dir, err := os.MkdirTemp("", "test_key_files") - require.NoError(t, err) + re.NoError(err) path := dir + "/key" os.WriteFile(path, []byte("2f07ec61e5a50284f47f2b402a962ec6"), 0600) config := &encryptionpb.MasterKey{ @@ -132,13 +139,14 @@ func TestNewFileMasterKeyLengthMismatch(t *testing.T) { }, } _, err = NewMasterKey(config, nil) - require.Error(t, err) + re.Error(err) } func TestNewFileMasterKey(t *testing.T) { + re := require.New(t) key := "2f07ec61e5a50284f47f2b402a962ec672e500b26cb3aa568bb1531300c74806" dir, err := os.MkdirTemp("", "test_key_files") - require.NoError(t, err) + re.NoError(err) path := dir + "/key" os.WriteFile(path, []byte(key), 0600) config := &encryptionpb.MasterKey{ @@ -149,6 +157,6 @@ func TestNewFileMasterKey(t *testing.T) { }, } masterKey, err := NewMasterKey(config, nil) - require.NoError(t, err) - require.Equal(t, key, hex.EncodeToString(masterKey.key)) + re.NoError(err) + re.Equal(key, hex.EncodeToString(masterKey.key)) } diff --git a/pkg/encryption/region_crypter_test.go b/pkg/encryption/region_crypter_test.go index 06398ebc7ff..b1ca558063c 100644 --- a/pkg/encryption/region_crypter_test.go +++ b/pkg/encryption/region_crypter_test.go @@ -70,15 +70,17 @@ func (m *testKeyManager) GetKey(keyID uint64) (*encryptionpb.DataKey, error) { } func TestNilRegion(t *testing.T) { + re := require.New(t) m := newTestKeyManager() region, err := EncryptRegion(nil, m) - require.Error(t, err) - require.Nil(t, region) + re.Error(err) + re.Nil(region) err = DecryptRegion(nil, m) - require.Error(t, err) + re.Error(err) } func TestEncryptRegionWithoutKeyManager(t *testing.T) { + re := require.New(t) region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -86,14 +88,15 @@ func TestEncryptRegionWithoutKeyManager(t *testing.T) { EncryptionMeta: nil, } region, err := EncryptRegion(region, nil) - require.NoError(t, err) + re.NoError(err) // check the region isn't changed - require.Equal(t, "abc", string(region.StartKey)) - require.Equal(t, "xyz", string(region.EndKey)) - require.Nil(t, region.EncryptionMeta) + re.Equal("abc", string(region.StartKey)) + re.Equal("xyz", string(region.EndKey)) + re.Nil(region.EncryptionMeta) } func TestEncryptRegionWhileEncryptionDisabled(t *testing.T) { + re := require.New(t) region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -103,14 +106,15 @@ func TestEncryptRegionWhileEncryptionDisabled(t *testing.T) { m := newTestKeyManager() m.EncryptionEnabled = false region, err := EncryptRegion(region, m) - require.NoError(t, err) + re.NoError(err) // check the region isn't changed - require.Equal(t, "abc", string(region.StartKey)) - require.Equal(t, "xyz", string(region.EndKey)) - require.Nil(t, region.EncryptionMeta) + re.Equal("abc", string(region.StartKey)) + re.Equal("xyz", string(region.EndKey)) + re.Nil(region.EncryptionMeta) } func TestEncryptRegion(t *testing.T) { + re := require.New(t) startKey := []byte("abc") endKey := []byte("xyz") region := &metapb.Region{ @@ -123,27 +127,28 @@ func TestEncryptRegion(t *testing.T) { copy(region.EndKey, endKey) m := newTestKeyManager() outRegion, err := EncryptRegion(region, m) - require.NoError(t, err) - require.NotEqual(t, region, outRegion) + re.NoError(err) + re.NotEqual(region, outRegion) // check region is encrypted - require.NotNil(t, outRegion.EncryptionMeta) - require.Equal(t, uint64(2), outRegion.EncryptionMeta.KeyId) - require.Len(t, outRegion.EncryptionMeta.Iv, ivLengthCTR) + re.NotNil(outRegion.EncryptionMeta) + re.Equal(uint64(2), outRegion.EncryptionMeta.KeyId) + re.Len(outRegion.EncryptionMeta.Iv, ivLengthCTR) // Check encrypted content _, currentKey, err := m.GetCurrentKey() - require.NoError(t, err) + re.NoError(err) block, err := aes.NewCipher(currentKey.Key) - require.NoError(t, err) + re.NoError(err) stream := cipher.NewCTR(block, outRegion.EncryptionMeta.Iv) ciphertextStartKey := make([]byte, len(startKey)) stream.XORKeyStream(ciphertextStartKey, startKey) - require.Equal(t, string(ciphertextStartKey), string(outRegion.StartKey)) + re.Equal(string(ciphertextStartKey), string(outRegion.StartKey)) ciphertextEndKey := make([]byte, len(endKey)) stream.XORKeyStream(ciphertextEndKey, endKey) - require.Equal(t, string(ciphertextEndKey), string(outRegion.EndKey)) + re.Equal(string(ciphertextEndKey), string(outRegion.EndKey)) } func TestDecryptRegionNotEncrypted(t *testing.T) { + re := require.New(t) region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -152,14 +157,15 @@ func TestDecryptRegionNotEncrypted(t *testing.T) { } m := newTestKeyManager() err := DecryptRegion(region, m) - require.NoError(t, err) + re.NoError(err) // check the region isn't changed - require.Equal(t, "abc", string(region.StartKey)) - require.Equal(t, "xyz", string(region.EndKey)) - require.Nil(t, region.EncryptionMeta) + re.Equal("abc", string(region.StartKey)) + re.Equal("xyz", string(region.EndKey)) + re.Nil(region.EncryptionMeta) } func TestDecryptRegionWithoutKeyManager(t *testing.T) { + re := require.New(t) region := &metapb.Region{ Id: 10, StartKey: []byte("abc"), @@ -170,14 +176,15 @@ func TestDecryptRegionWithoutKeyManager(t *testing.T) { }, } err := DecryptRegion(region, nil) - require.Error(t, err) + re.Error(err) } func TestDecryptRegionWhileKeyMissing(t *testing.T) { + re := require.New(t) keyID := uint64(3) m := newTestKeyManager() _, err := m.GetKey(3) - require.Error(t, err) + re.Error(err) region := &metapb.Region{ Id: 10, @@ -189,10 +196,11 @@ func TestDecryptRegionWhileKeyMissing(t *testing.T) { }, } err = DecryptRegion(region, m) - require.Error(t, err) + re.Error(err) } func TestDecryptRegion(t *testing.T) { + re := require.New(t) keyID := uint64(1) startKey := []byte("abc") endKey := []byte("xyz") @@ -211,19 +219,19 @@ func TestDecryptRegion(t *testing.T) { copy(region.EncryptionMeta.Iv, iv) m := newTestKeyManager() err := DecryptRegion(region, m) - require.NoError(t, err) + re.NoError(err) // check region is decrypted - require.Nil(t, region.EncryptionMeta) + re.Nil(region.EncryptionMeta) // Check decrypted content key, err := m.GetKey(keyID) - require.NoError(t, err) + re.NoError(err) block, err := aes.NewCipher(key.Key) - require.NoError(t, err) + re.NoError(err) stream := cipher.NewCTR(block, iv) plaintextStartKey := make([]byte, len(startKey)) stream.XORKeyStream(plaintextStartKey, startKey) - require.Equal(t, string(plaintextStartKey), string(region.StartKey)) + re.Equal(string(plaintextStartKey), string(region.StartKey)) plaintextEndKey := make([]byte, len(endKey)) stream.XORKeyStream(plaintextEndKey, endKey) - require.Equal(t, string(plaintextEndKey), string(region.EndKey)) + re.Equal(string(plaintextEndKey), string(region.EndKey)) } diff --git a/pkg/errs/errs_test.go b/pkg/errs/errs_test.go index 65ebb6460d0..74e55257d70 100644 --- a/pkg/errs/errs_test.go +++ b/pkg/errs/errs_test.go @@ -72,44 +72,46 @@ func newZapTestLogger(cfg *log.Config, opts ...zap.Option) verifyLogger { } func TestError(t *testing.T) { + re := require.New(t) conf := &log.Config{Level: "debug", File: log.FileLogConfig{}, DisableTimestamp: true} lg := newZapTestLogger(conf) log.ReplaceGlobals(lg.Logger, nil) rfc := `[error="[PD:member:ErrEtcdLeaderNotFound]etcd leader not found` log.Error("test", zap.Error(ErrEtcdLeaderNotFound.FastGenByArgs())) - require.Contains(t, lg.Message(), rfc) + re.Contains(lg.Message(), rfc) err := errors.New("test error") log.Error("test", ZapError(ErrEtcdLeaderNotFound, err)) rfc = `[error="[PD:member:ErrEtcdLeaderNotFound]test error` - require.Contains(t, lg.Message(), rfc) + re.Contains(lg.Message(), rfc) } func TestErrorEqual(t *testing.T) { + re := require.New(t) err1 := ErrSchedulerNotFound.FastGenByArgs() err2 := ErrSchedulerNotFound.FastGenByArgs() - require.True(t, errors.ErrorEqual(err1, err2)) + re.True(errors.ErrorEqual(err1, err2)) err := errors.New("test") err1 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() err2 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() - require.True(t, errors.ErrorEqual(err1, err2)) + re.True(errors.ErrorEqual(err1, err2)) err1 = ErrSchedulerNotFound.FastGenByArgs() err2 = ErrSchedulerNotFound.Wrap(err).FastGenWithCause() - require.False(t, errors.ErrorEqual(err1, err2)) + re.False(errors.ErrorEqual(err1, err2)) err3 := errors.New("test") err4 := errors.New("test") err1 = ErrSchedulerNotFound.Wrap(err3).FastGenWithCause() err2 = ErrSchedulerNotFound.Wrap(err4).FastGenWithCause() - require.True(t, errors.ErrorEqual(err1, err2)) + re.True(errors.ErrorEqual(err1, err2)) err3 = errors.New("test1") err4 = errors.New("test") err1 = ErrSchedulerNotFound.Wrap(err3).FastGenWithCause() err2 = ErrSchedulerNotFound.Wrap(err4).FastGenWithCause() - require.False(t, errors.ErrorEqual(err1, err2)) + re.False(errors.ErrorEqual(err1, err2)) } func TestZapError(t *testing.T) { @@ -121,6 +123,7 @@ func TestZapError(t *testing.T) { } func TestErrorWithStack(t *testing.T) { + re := require.New(t) conf := &log.Config{Level: "debug", File: log.FileLogConfig{}, DisableTimestamp: true} lg := newZapTestLogger(conf) log.ReplaceGlobals(lg.Logger, nil) @@ -133,8 +136,8 @@ func TestErrorWithStack(t *testing.T) { // This test is based on line number and the first log is in line 141, the second is in line 142. // So they have the same length stack. Move this test to another place need to change the corresponding length. idx1 := strings.Index(m1, "[stack=") - require.GreaterOrEqual(t, idx1, -1) + re.GreaterOrEqual(idx1, -1) idx2 := strings.Index(m2, "[stack=") - require.GreaterOrEqual(t, idx2, -1) - require.Equal(t, len(m1[idx1:]), len(m2[idx2:])) + re.GreaterOrEqual(idx2, -1) + re.Equal(len(m1[idx1:]), len(m2[idx2:])) } diff --git a/pkg/etcdutil/etcdutil_test.go b/pkg/etcdutil/etcdutil_test.go index 7bc73f12cbe..7731a319a94 100644 --- a/pkg/etcdutil/etcdutil_test.go +++ b/pkg/etcdutil/etcdutil_test.go @@ -28,28 +28,29 @@ import ( ) func TestMemberHelpers(t *testing.T) { + re := require.New(t) cfg1 := NewTestSingleConfig() etcd1, err := embed.StartEtcd(cfg1) defer func() { etcd1.Close() CleanConfig(cfg1) }() - require.NoError(t, err) + re.NoError(err) ep1 := cfg1.LCUrls[0].String() client1, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep1}, }) - require.NoError(t, err) + re.NoError(err) <-etcd1.Server.ReadyNotify() // Test ListEtcdMembers listResp1, err := ListEtcdMembers(client1) - require.NoError(t, err) - require.Len(t, listResp1.Members, 1) + re.NoError(err) + re.Len(listResp1.Members, 1) // types.ID is an alias of uint64. - require.Equal(t, uint64(etcd1.Server.ID()), listResp1.Members[0].ID) + re.Equal(uint64(etcd1.Server.ID()), listResp1.Members[0].ID) // Test AddEtcdMember // Make a new etcd config. @@ -61,28 +62,28 @@ func TestMemberHelpers(t *testing.T) { // Add it to the cluster above. peerURL := cfg2.LPUrls[0].String() addResp, err := AddEtcdMember(client1, []string{peerURL}) - require.NoError(t, err) + re.NoError(err) etcd2, err := embed.StartEtcd(cfg2) defer func() { etcd2.Close() CleanConfig(cfg2) }() - require.NoError(t, err) - require.Equal(t, uint64(etcd2.Server.ID()), addResp.Member.ID) + re.NoError(err) + re.Equal(uint64(etcd2.Server.ID()), addResp.Member.ID) ep2 := cfg2.LCUrls[0].String() client2, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep2}, }) - require.NoError(t, err) + re.NoError(err) <-etcd2.Server.ReadyNotify() - require.NoError(t, err) + re.NoError(err) listResp2, err := ListEtcdMembers(client2) - require.NoError(t, err) - require.Len(t, listResp2.Members, 2) + re.NoError(err) + re.Len(listResp2.Members, 2) for _, m := range listResp2.Members { switch m.ID { case uint64(etcd1.Server.ID()): @@ -94,34 +95,35 @@ func TestMemberHelpers(t *testing.T) { // Test CheckClusterID urlsMap, err := types.NewURLsMap(cfg2.InitialCluster) - require.NoError(t, err) + re.NoError(err) err = CheckClusterID(etcd1.Server.Cluster().ID(), urlsMap, &tls.Config{MinVersion: tls.VersionTLS12}) - require.NoError(t, err) + re.NoError(err) // Test RemoveEtcdMember _, err = RemoveEtcdMember(client1, uint64(etcd2.Server.ID())) - require.NoError(t, err) + re.NoError(err) listResp3, err := ListEtcdMembers(client1) - require.NoError(t, err) - require.Len(t, listResp3.Members, 1) - require.Equal(t, uint64(etcd1.Server.ID()), listResp3.Members[0].ID) + re.NoError(err) + re.Len(listResp3.Members, 1) + re.Equal(uint64(etcd1.Server.ID()), listResp3.Members[0].ID) } func TestEtcdKVGet(t *testing.T) { + re := require.New(t) cfg := NewTestSingleConfig() etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() CleanConfig(cfg) }() - require.NoError(t, err) + re.NoError(err) ep := cfg.LCUrls[0].String() client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) - require.NoError(t, err) + re.NoError(err) <-etcd.Server.ReadyNotify() @@ -131,69 +133,70 @@ func TestEtcdKVGet(t *testing.T) { kv := clientv3.NewKV(client) for i := range keys { _, err = kv.Put(context.TODO(), keys[i], vals[i]) - require.NoError(t, err) + re.NoError(err) } // Test simple point get resp, err := EtcdKVGet(client, "test/key1") - require.NoError(t, err) - require.Equal(t, "val1", string(resp.Kvs[0].Value)) + re.NoError(err) + re.Equal("val1", string(resp.Kvs[0].Value)) // Test range get withRange := clientv3.WithRange("test/zzzz") withLimit := clientv3.WithLimit(3) resp, err = EtcdKVGet(client, "test/", withRange, withLimit, clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend)) - require.NoError(t, err) - require.Len(t, resp.Kvs, 3) + re.NoError(err) + re.Len(resp.Kvs, 3) for i := range resp.Kvs { - require.Equal(t, keys[i], string(resp.Kvs[i].Key)) - require.Equal(t, vals[i], string(resp.Kvs[i].Value)) + re.Equal(keys[i], string(resp.Kvs[i].Key)) + re.Equal(vals[i], string(resp.Kvs[i].Value)) } lastKey := string(resp.Kvs[len(resp.Kvs)-1].Key) next := clientv3.GetPrefixRangeEnd(lastKey) resp, err = EtcdKVGet(client, next, withRange, withLimit, clientv3.WithSort(clientv3.SortByKey, clientv3.SortAscend)) - require.NoError(t, err) - require.Len(t, resp.Kvs, 2) + re.NoError(err) + re.Len(resp.Kvs, 2) } func TestEtcdKVPutWithTTL(t *testing.T) { + re := require.New(t) cfg := NewTestSingleConfig() etcd, err := embed.StartEtcd(cfg) defer func() { etcd.Close() CleanConfig(cfg) }() - require.NoError(t, err) + re.NoError(err) ep := cfg.LCUrls[0].String() client, err := clientv3.New(clientv3.Config{ Endpoints: []string{ep}, }) - require.NoError(t, err) + re.NoError(err) <-etcd.Server.ReadyNotify() _, err = EtcdKVPutWithTTL(context.TODO(), client, "test/ttl1", "val1", 2) - require.NoError(t, err) + re.NoError(err) _, err = EtcdKVPutWithTTL(context.TODO(), client, "test/ttl2", "val2", 4) - require.NoError(t, err) + re.NoError(err) time.Sleep(3 * time.Second) // test/ttl1 is outdated resp, err := EtcdKVGet(client, "test/ttl1") - require.NoError(t, err) - require.Equal(t, int64(0), resp.Count) + re.NoError(err) + re.Equal(int64(0), resp.Count) // but test/ttl2 is not resp, err = EtcdKVGet(client, "test/ttl2") - require.NoError(t, err) - require.Equal(t, "val2", string(resp.Kvs[0].Value)) + re.NoError(err) + re.Equal("val2", string(resp.Kvs[0].Value)) time.Sleep(2 * time.Second) // test/ttl2 is also outdated resp, err = EtcdKVGet(client, "test/ttl2") - require.NoError(t, err) - require.Equal(t, int64(0), resp.Count) + re.NoError(err) + re.Equal(int64(0), resp.Count) } diff --git a/pkg/grpcutil/grpcutil_test.go b/pkg/grpcutil/grpcutil_test.go index d1b9d3a8830..44eee64b85e 100644 --- a/pkg/grpcutil/grpcutil_test.go +++ b/pkg/grpcutil/grpcutil_test.go @@ -9,18 +9,19 @@ import ( "github.com/tikv/pd/pkg/errs" ) -func loadTLSContent(t *testing.T, caPath, certPath, keyPath string) (caData, certData, keyData []byte) { +func loadTLSContent(re *require.Assertions, caPath, certPath, keyPath string) (caData, certData, keyData []byte) { var err error caData, err = os.ReadFile(caPath) - require.NoError(t, err) + re.NoError(err) certData, err = os.ReadFile(certPath) - require.NoError(t, err) + re.NoError(err) keyData, err = os.ReadFile(keyPath) - require.NoError(t, err) + re.NoError(err) return } func TestToTLSConfig(t *testing.T) { + re := require.New(t) tlsConfig := TLSConfig{ KeyPath: "../../tests/client/cert/pd-server-key.pem", CertPath: "../../tests/client/cert/pd-server.pem", @@ -28,24 +29,24 @@ func TestToTLSConfig(t *testing.T) { } // test without bytes _, err := tlsConfig.ToTLSConfig() - require.NoError(t, err) + re.NoError(err) // test with bytes - caData, certData, keyData := loadTLSContent(t, tlsConfig.CAPath, tlsConfig.CertPath, tlsConfig.KeyPath) + caData, certData, keyData := loadTLSContent(re, tlsConfig.CAPath, tlsConfig.CertPath, tlsConfig.KeyPath) tlsConfig.SSLCABytes = caData tlsConfig.SSLCertBytes = certData tlsConfig.SSLKEYBytes = keyData _, err = tlsConfig.ToTLSConfig() - require.NoError(t, err) + re.NoError(err) // test wrong cert bytes tlsConfig.SSLCertBytes = []byte("invalid cert") _, err = tlsConfig.ToTLSConfig() - require.True(t, errors.ErrorEqual(err, errs.ErrCryptoX509KeyPair)) + re.True(errors.ErrorEqual(err, errs.ErrCryptoX509KeyPair)) // test wrong ca bytes tlsConfig.SSLCertBytes = certData tlsConfig.SSLCABytes = []byte("invalid ca") _, err = tlsConfig.ToTLSConfig() - require.True(t, errors.ErrorEqual(err, errs.ErrCryptoAppendCertsFromPEM)) + re.True(errors.ErrorEqual(err, errs.ErrCryptoAppendCertsFromPEM)) } diff --git a/pkg/keyutil/util_test.go b/pkg/keyutil/util_test.go index 6603c61b131..f69463c5060 100644 --- a/pkg/keyutil/util_test.go +++ b/pkg/keyutil/util_test.go @@ -21,8 +21,9 @@ import ( ) func TestKeyUtil(t *testing.T) { + re := require.New(t) startKey := []byte("a") endKey := []byte("b") key := BuildKeyRangeKey(startKey, endKey) - require.Equal(t, "61-62", key) + re.Equal("61-62", key) } diff --git a/pkg/logutil/log_test.go b/pkg/logutil/log_test.go index 42a9126ea33..270a8e5b0ba 100644 --- a/pkg/logutil/log_test.go +++ b/pkg/logutil/log_test.go @@ -24,16 +24,18 @@ import ( ) func TestStringToZapLogLevel(t *testing.T) { - require.Equal(t, zapcore.FatalLevel, StringToZapLogLevel("fatal")) - require.Equal(t, zapcore.ErrorLevel, StringToZapLogLevel("ERROR")) - require.Equal(t, zapcore.WarnLevel, StringToZapLogLevel("warn")) - require.Equal(t, zapcore.WarnLevel, StringToZapLogLevel("warning")) - require.Equal(t, zapcore.DebugLevel, StringToZapLogLevel("debug")) - require.Equal(t, zapcore.InfoLevel, StringToZapLogLevel("info")) - require.Equal(t, zapcore.InfoLevel, StringToZapLogLevel("whatever")) + re := require.New(t) + re.Equal(zapcore.FatalLevel, StringToZapLogLevel("fatal")) + re.Equal(zapcore.ErrorLevel, StringToZapLogLevel("ERROR")) + re.Equal(zapcore.WarnLevel, StringToZapLogLevel("warn")) + re.Equal(zapcore.WarnLevel, StringToZapLogLevel("warning")) + re.Equal(zapcore.DebugLevel, StringToZapLogLevel("debug")) + re.Equal(zapcore.InfoLevel, StringToZapLogLevel("info")) + re.Equal(zapcore.InfoLevel, StringToZapLogLevel("whatever")) } func TestRedactLog(t *testing.T) { + re := require.New(t) testCases := []struct { name string arg interface{} @@ -71,11 +73,11 @@ func TestRedactLog(t *testing.T) { SetRedactLog(testCase.enableRedactLog) switch r := testCase.arg.(type) { case []byte: - require.True(t, reflect.DeepEqual(testCase.expect, RedactBytes(r))) + re.True(reflect.DeepEqual(testCase.expect, RedactBytes(r))) case string: - require.True(t, reflect.DeepEqual(testCase.expect, RedactString(r))) + re.True(reflect.DeepEqual(testCase.expect, RedactString(r))) case fmt.Stringer: - require.True(t, reflect.DeepEqual(testCase.expect, RedactStringer(r))) + re.True(reflect.DeepEqual(testCase.expect, RedactStringer(r))) default: panic("unmatched case") } diff --git a/pkg/metricutil/metricutil_test.go b/pkg/metricutil/metricutil_test.go index 512732c7f7e..a72eb7ee5f5 100644 --- a/pkg/metricutil/metricutil_test.go +++ b/pkg/metricutil/metricutil_test.go @@ -23,6 +23,7 @@ import ( ) func TestCamelCaseToSnakeCase(t *testing.T) { + re := require.New(t) inputs := []struct { name string newName string @@ -50,7 +51,7 @@ func TestCamelCaseToSnakeCase(t *testing.T) { } for _, input := range inputs { - require.Equal(t, input.newName, camelCaseToSnakeCase(input.name)) + re.Equal(input.newName, camelCaseToSnakeCase(input.name)) } } From a7ac85daa078f913610e5bd1ac101f824d084608 Mon Sep 17 00:00:00 2001 From: buffer <1045931706@qq.com> Date: Tue, 31 May 2022 16:22:27 +0800 Subject: [PATCH 2/4] config: fix the bug that the type of bucket size is not right. (#5074) close tikv/pd#5073 Signed-off-by: bufferflies <1045931706@qq.com> Co-authored-by: Ti Chi Robot --- server/config/store_config.go | 11 ++++++++--- server/config/store_config_test.go | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/server/config/store_config.go b/server/config/store_config.go index 6e8ba7c22f7..27fc456dd08 100644 --- a/server/config/store_config.go +++ b/server/config/store_config.go @@ -35,6 +35,8 @@ var ( defaultRegionMaxSize = uint64(144) // default region split size is 96MB defaultRegionSplitSize = uint64(96) + // default bucket size is 96MB + defaultBucketSize = uint64(96) // default region max key is 144000 defaultRegionMaxKey = uint64(1440000) // default region split key is 960000 @@ -58,7 +60,7 @@ type Coprocessor struct { RegionMaxKeys int `json:"region-max-keys"` RegionSplitKeys int `json:"region-split-keys"` EnableRegionBucket bool `json:"enable-region-bucket"` - RegionBucketSize int `json:"region-bucket-size"` + RegionBucketSize string `json:"region-bucket-size"` } // String implements fmt.Stringer interface. @@ -111,11 +113,14 @@ func (c *StoreConfig) IsEnableRegionBucket() bool { } // GetRegionBucketSize returns region bucket size if enable region buckets. -func (c *StoreConfig) GetRegionBucketSize() int { +func (c *StoreConfig) GetRegionBucketSize() uint64 { if c == nil || !c.Coprocessor.EnableRegionBucket { return 0 } - return c.Coprocessor.RegionBucketSize + if len(c.Coprocessor.RegionBucketSize) == 0 { + return defaultBucketSize + } + return typeutil.ParseMBFromText(c.Coprocessor.RegionBucketSize, defaultBucketSize) } // CheckRegionSize return error if the smallest region's size is less than mergeSize diff --git a/server/config/store_config_test.go b/server/config/store_config_test.go index 106d8b7bf4e..478e1ebb3d7 100644 --- a/server/config/store_config_test.go +++ b/server/config/store_config_test.go @@ -77,6 +77,30 @@ func (t *testTiKVConfigSuite) TestUpdateConfig(c *C) { c.Assert(manager.source.(*TiKVConfigSource).schema, Equals, "http") } +func (t *testTiKVConfigSuite) TestParseConfig(c *C) { + body := ` +{ +"coprocessor":{ +"split-region-on-table":false, +"batch-split-limit":10, +"region-max-size":"384MiB", +"region-split-size":"256MiB", +"region-max-keys":3840000, +"region-split-keys":2560000, +"consistency-check-method":"mvcc", +"enable-region-bucket":true, +"region-bucket-size":"96MiB", +"region-size-threshold-for-approximate":"384MiB", +"region-bucket-merge-size-ratio":0.33 +} +} +` + + var config StoreConfig + c.Assert(json.Unmarshal([]byte(body), &config), IsNil) + c.Assert(config.GetRegionBucketSize(), Equals, uint64(96)) +} + func (t *testTiKVConfigSuite) TestMergeCheck(c *C) { testdata := []struct { size uint64 From b92303c6a0395e0a231d4b00e89dd9a105e196b4 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Tue, 31 May 2022 18:04:27 +0800 Subject: [PATCH 3/4] workflow: change timeout-minutes of `statics` to 8 (#5079) close tikv/pd#5078 change timeout-minutes of `statics` to 8 Signed-off-by: Cabinfever_B --- .github/workflows/check.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/check.yaml b/.github/workflows/check.yaml index 4fbed3641e5..47ef287d73f 100644 --- a/.github/workflows/check.yaml +++ b/.github/workflows/check.yaml @@ -6,6 +6,7 @@ concurrency: jobs: statics: runs-on: ubuntu-latest + timeout-minutes: 8 steps: - uses: actions/setup-go@v2 with: From 52dd58715804faa5e6ee88bf11183b61b5273570 Mon Sep 17 00:00:00 2001 From: Yongbo Jiang Date: Tue, 31 May 2022 18:16:28 +0800 Subject: [PATCH 4/4] *: Add Limiter Config (#4839) ref tikv/pd#4666 Add Rate Limiter Config for server Signed-off-by: Cabinfever_B Co-authored-by: Ryan Leung Co-authored-by: Ti Chi Robot --- pkg/ratelimit/limiter.go | 27 ++++++++--- pkg/ratelimit/limiter_test.go | 45 +++++++++++------- pkg/ratelimit/option.go | 88 +++++++++++++++++++++++++++++------ 3 files changed, 124 insertions(+), 36 deletions(-) diff --git a/pkg/ratelimit/limiter.go b/pkg/ratelimit/limiter.go index 43f01cea41b..4bf930ed6c5 100644 --- a/pkg/ratelimit/limiter.go +++ b/pkg/ratelimit/limiter.go @@ -20,6 +20,15 @@ import ( "golang.org/x/time/rate" ) +// DimensionConfig is the limit dimension config of one label +type DimensionConfig struct { + // qps conifg + QPS float64 + QPSBurst int + // concurrency config + ConcurrencyLimit uint64 +} + // Limiter is a controller for the request rate. type Limiter struct { qpsLimiter sync.Map @@ -30,7 +39,9 @@ type Limiter struct { // NewLimiter returns a global limiter which can be updated in the later. func NewLimiter() *Limiter { - return &Limiter{labelAllowList: make(map[string]struct{})} + return &Limiter{ + labelAllowList: make(map[string]struct{}), + } } // Allow is used to check whether it has enough token. @@ -65,10 +76,12 @@ func (l *Limiter) Release(label string) { } // Update is used to update Ratelimiter with Options -func (l *Limiter) Update(label string, opts ...Option) { +func (l *Limiter) Update(label string, opts ...Option) UpdateStatus { + var status UpdateStatus for _, opt := range opts { - opt(label, l) + status |= opt(label, l) } + return status } // GetQPSLimiterStatus returns the status of a given label's QPS limiter. @@ -80,8 +93,8 @@ func (l *Limiter) GetQPSLimiterStatus(label string) (limit rate.Limit, burst int return 0, 0 } -// DeleteQPSLimiter deletes QPS limiter of given label -func (l *Limiter) DeleteQPSLimiter(label string) { +// QPSUnlimit deletes QPS limiter of the given label +func (l *Limiter) QPSUnlimit(label string) { l.qpsLimiter.Delete(label) } @@ -94,8 +107,8 @@ func (l *Limiter) GetConcurrencyLimiterStatus(label string) (limit uint64, curre return 0, 0 } -// DeleteConcurrencyLimiter deletes concurrency limiter of given label -func (l *Limiter) DeleteConcurrencyLimiter(label string) { +// ConcurrencyUnlimit deletes concurrency limiter of the given label +func (l *Limiter) ConcurrencyUnlimit(label string) { l.concurrencyLimiter.Delete(label) } diff --git a/pkg/ratelimit/limiter_test.go b/pkg/ratelimit/limiter_test.go index bd095543a05..cf75d76152a 100644 --- a/pkg/ratelimit/limiter_test.go +++ b/pkg/ratelimit/limiter_test.go @@ -34,9 +34,8 @@ func (s *testRatelimiterSuite) TestUpdateConcurrencyLimiter(c *C) { limiter := NewLimiter() label := "test" - for _, opt := range opts { - opt(label, limiter) - } + status := limiter.Update(label, opts...) + c.Assert(status&ConcurrencyChanged != 0, IsTrue) var lock sync.Mutex successCount, failedCount := 0, 0 var wg sync.WaitGroup @@ -57,7 +56,11 @@ func (s *testRatelimiterSuite) TestUpdateConcurrencyLimiter(c *C) { c.Assert(limit, Equals, uint64(10)) c.Assert(current, Equals, uint64(0)) - limiter.Update(label, UpdateConcurrencyLimiter(5)) + status = limiter.Update(label, UpdateConcurrencyLimiter(10)) + c.Assert(status&ConcurrencyNoChange != 0, IsTrue) + + status = limiter.Update(label, UpdateConcurrencyLimiter(5)) + c.Assert(status&ConcurrencyChanged != 0, IsTrue) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { @@ -71,7 +74,8 @@ func (s *testRatelimiterSuite) TestUpdateConcurrencyLimiter(c *C) { limiter.Release(label) } - limiter.DeleteConcurrencyLimiter(label) + status = limiter.Update(label, UpdateConcurrencyLimiter(0)) + c.Assert(status&ConcurrencyDeleted != 0, IsTrue) failedCount = 0 successCount = 0 for i := 0; i < 15; i++ { @@ -99,7 +103,8 @@ func (s *testRatelimiterSuite) TestBlockList(c *C) { } c.Assert(limiter.IsInAllowList(label), Equals, true) - UpdateQPSLimiter(rate.Every(time.Second), 1)(label, limiter) + status := UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)(label, limiter) + c.Assert(status&InAllowList != 0, Equals, true) for i := 0; i < 10; i++ { c.Assert(limiter.Allow(label), Equals, true) } @@ -107,13 +112,12 @@ func (s *testRatelimiterSuite) TestBlockList(c *C) { func (s *testRatelimiterSuite) TestUpdateQPSLimiter(c *C) { c.Parallel() - opts := []Option{UpdateQPSLimiter(rate.Every(time.Second), 1)} + opts := []Option{UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)} limiter := NewLimiter() label := "test" - for _, opt := range opts { - opt(label, limiter) - } + status := limiter.Update(label, opts...) + c.Assert(status&QPSChanged != 0, IsTrue) var lock sync.Mutex successCount, failedCount := 0, 0 @@ -130,7 +134,11 @@ func (s *testRatelimiterSuite) TestUpdateQPSLimiter(c *C) { c.Assert(limit, Equals, rate.Limit(1)) c.Assert(burst, Equals, 1) - limiter.Update(label, UpdateQPSLimiter(5, 5)) + status = limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(time.Second)), 1)) + c.Assert(status&QPSNoChange != 0, IsTrue) + + status = limiter.Update(label, UpdateQPSLimiter(5, 5)) + c.Assert(status&QPSChanged != 0, IsTrue) limit, burst = limiter.GetQPSLimiterStatus(label) c.Assert(limit, Equals, rate.Limit(5)) c.Assert(burst, Equals, 5) @@ -144,7 +152,9 @@ func (s *testRatelimiterSuite) TestUpdateQPSLimiter(c *C) { } } time.Sleep(time.Second) - limiter.DeleteQPSLimiter(label) + + status = limiter.Update(label, UpdateQPSLimiter(0, 0)) + c.Assert(status&QPSDeleted != 0, IsTrue) for i := 0; i < 10; i++ { c.Assert(limiter.Allow(label), Equals, true) } @@ -155,7 +165,7 @@ func (s *testRatelimiterSuite) TestUpdateQPSLimiter(c *C) { func (s *testRatelimiterSuite) TestQPSLimiter(c *C) { c.Parallel() - opts := []Option{UpdateQPSLimiter(rate.Every(3*time.Second), 100)} + opts := []Option{UpdateQPSLimiter(float64(rate.Every(3*time.Second)), 100)} limiter := NewLimiter() label := "test" @@ -184,9 +194,12 @@ func (s *testRatelimiterSuite) TestQPSLimiter(c *C) { func (s *testRatelimiterSuite) TestTwoLimiters(c *C) { c.Parallel() - opts := []Option{UpdateQPSLimiter(100, 100), - UpdateConcurrencyLimiter(100), + cfg := &DimensionConfig{ + QPS: 100, + QPSBurst: 100, + ConcurrencyLimit: 100, } + opts := []Option{UpdateDimensionConfig(cfg)} limiter := NewLimiter() label := "test" @@ -217,7 +230,7 @@ func (s *testRatelimiterSuite) TestTwoLimiters(c *C) { for i := 0; i < 100; i++ { limiter.Release(label) } - limiter.Update(label, UpdateQPSLimiter(rate.Every(10*time.Second), 1)) + limiter.Update(label, UpdateQPSLimiter(float64(rate.Every(10*time.Second)), 1)) wg.Add(100) for i := 0; i < 100; i++ { go CountRateLimiterHandleResult(limiter, label, &successCount, &failedCount, &lock, &wg) diff --git a/pkg/ratelimit/option.go b/pkg/ratelimit/option.go index af98eddb827..53afb9926d4 100644 --- a/pkg/ratelimit/option.go +++ b/pkg/ratelimit/option.go @@ -16,39 +16,101 @@ package ratelimit import "golang.org/x/time/rate" +// UpdateStatus is flags for updating limiter config. +type UpdateStatus uint32 + +// Flags for limiter. +const ( + eps float64 = 1e-8 + // QPSNoChange shows that limiter's config isn't changed. + QPSNoChange UpdateStatus = 1 << iota + // QPSChanged shows that limiter's config is changed and not deleted. + QPSChanged + // QPSDeleted shows that limiter's config is deleted. + QPSDeleted + // ConcurrencyNoChange shows that limiter's config isn't changed. + ConcurrencyNoChange + // ConcurrencyChanged shows that limiter's config is changed and not deleted. + ConcurrencyChanged + // ConcurrencyDeleted shows that limiter's config is deleted. + ConcurrencyDeleted + // InAllowList shows that limiter's config isn't changed because it is in in allow list. + InAllowList +) + // Option is used to create a limiter with the optional settings. // these setting is used to add a kind of limiter for a service -type Option func(string, *Limiter) +type Option func(string, *Limiter) UpdateStatus // AddLabelAllowList adds a label into allow list. // It means the given label will not be limited func AddLabelAllowList() Option { - return func(label string, l *Limiter) { + return func(label string, l *Limiter) UpdateStatus { l.labelAllowList[label] = struct{}{} + return 0 + } +} + +func updateConcurrencyConfig(l *Limiter, label string, limit uint64) UpdateStatus { + oldConcurrencyLimit, _ := l.GetConcurrencyLimiterStatus(label) + if oldConcurrencyLimit == limit { + return ConcurrencyNoChange + } + if limit < 1 { + l.ConcurrencyUnlimit(label) + return ConcurrencyDeleted + } + if limiter, exist := l.concurrencyLimiter.LoadOrStore(label, newConcurrencyLimiter(limit)); exist { + limiter.(*concurrencyLimiter).setLimit(limit) + } + return ConcurrencyChanged +} + +func updateQPSConfig(l *Limiter, label string, limit float64, burst int) UpdateStatus { + oldQPSLimit, oldBurst := l.GetQPSLimiterStatus(label) + + if (float64(oldQPSLimit)-limit < eps && float64(oldQPSLimit)-limit > -eps) && oldBurst == burst { + return QPSNoChange + } + if limit <= eps || burst < 1 { + l.QPSUnlimit(label) + return QPSDeleted } + if limiter, exist := l.qpsLimiter.LoadOrStore(label, NewRateLimiter(limit, burst)); exist { + limiter.(*RateLimiter).SetLimit(rate.Limit(limit)) + limiter.(*RateLimiter).SetBurst(burst) + } + return QPSChanged } // UpdateConcurrencyLimiter creates a concurrency limiter for a given label if it doesn't exist. func UpdateConcurrencyLimiter(limit uint64) Option { - return func(label string, l *Limiter) { + return func(label string, l *Limiter) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { - return - } - if limiter, exist := l.concurrencyLimiter.LoadOrStore(label, newConcurrencyLimiter(limit)); exist { - limiter.(*concurrencyLimiter).setLimit(limit) + return InAllowList } + return updateConcurrencyConfig(l, label, limit) } } // UpdateQPSLimiter creates a QPS limiter for a given label if it doesn't exist. -func UpdateQPSLimiter(limit rate.Limit, burst int) Option { - return func(label string, l *Limiter) { +func UpdateQPSLimiter(limit float64, burst int) Option { + return func(label string, l *Limiter) UpdateStatus { if _, allow := l.labelAllowList[label]; allow { - return + return InAllowList } - if limiter, exist := l.qpsLimiter.LoadOrStore(label, NewRateLimiter(float64(limit), burst)); exist { - limiter.(*RateLimiter).SetLimit(limit) - limiter.(*RateLimiter).SetBurst(burst) + return updateQPSConfig(l, label, limit, burst) + } +} + +// UpdateDimensionConfig creates QPS limiter and concurrency limiter for a given label by config if it doesn't exist. +func UpdateDimensionConfig(cfg *DimensionConfig) Option { + return func(label string, l *Limiter) UpdateStatus { + if _, allow := l.labelAllowList[label]; allow { + return InAllowList } + status := updateQPSConfig(l, label, cfg.QPS, cfg.QPSBurst) + status |= updateConcurrencyConfig(l, label, cfg.ConcurrencyLimit) + return status } }