From b1734834a52d0b679b838100ce00dd31efa975ef Mon Sep 17 00:00:00 2001 From: husharp Date: Thu, 6 Jun 2024 13:02:56 +0800 Subject: [PATCH 1/6] add hex Signed-off-by: husharp --- pkg/core/region_tree.go | 2 +- server/api/region.go | 19 +++++++++++++++++++ server/api/region_test.go | 29 +++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/pkg/core/region_tree.go b/pkg/core/region_tree.go index d4ef4a880fc..cf146c05f90 100644 --- a/pkg/core/region_tree.go +++ b/pkg/core/region_tree.go @@ -272,7 +272,7 @@ func (t *regionTree) find(item *regionItem) *regionItem { // until f return false func (t *regionTree) scanRange(startKey []byte, f func(*RegionInfo) bool) { region := &RegionInfo{meta: &metapb.Region{StartKey: startKey}} - // find if there is a region with key range [s, d), s < startKey < d + // find if there is a region with key range [s, d), s <= startKey < d fn := func(item *regionItem) bool { r := item return f(r.RegionInfo) diff --git a/server/api/region.go b/server/api/region.go index 974b5e4fa12..41c3c27b471 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -176,6 +176,25 @@ func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) startKey := r.URL.Query().Get("key") endKey := r.URL.Query().Get("end_key") + + // decode hex if query has params with hex format + formatStr := r.URL.Query().Get("format") + if formatStr == "hex" { + keyBytes, err := hex.DecodeString(startKey) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + startKey = string(keyBytes) + + keyBytes, err = hex.DecodeString(endKey) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + endKey = string(keyBytes) + } + limit, err := h.AdjustLimit(r.URL.Query().Get("limit")) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) diff --git a/server/api/region_test.go b/server/api/region_test.go index 0e5dcd97678..97f9bfb5478 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -468,6 +468,35 @@ func (suite *getRegionTestSuite) TestScanRegionByKeys() { for i, v := range regionIDs { re.Equal(regions.Regions[i].ID, v) } + url = fmt.Sprintf("%s/regions/key?key=%s&format=hex", suite.urlPrefix, hex.EncodeToString([]byte("b"))) + regionIDs = []uint64{3, 4, 5, 99} + regions = &response.RegionsInfo{} + err = tu.ReadGetJSON(re, testDialClient, url, regions) + re.NoError(err) + re.Len(regionIDs, regions.Count) + for i, v := range regionIDs { + re.Equal(regions.Regions[i].ID, v) + } + url = fmt.Sprintf("%s/regions/key?key=%s&end_key=%s&format=hex", + suite.urlPrefix, hex.EncodeToString([]byte("b")), hex.EncodeToString([]byte("g"))) + regionIDs = []uint64{3, 4} + regions = &response.RegionsInfo{} + err = tu.ReadGetJSON(re, testDialClient, url, regions) + re.NoError(err) + re.Len(regionIDs, regions.Count) + for i, v := range regionIDs { + re.Equal(regions.Regions[i].ID, v) + } + url = fmt.Sprintf("%s/regions/key?key=%s&end_key=%s&format=hex", + suite.urlPrefix, hex.EncodeToString([]byte("b")), hex.EncodeToString([]byte{0xFF, 0xFF, 0xCC})) + regionIDs = []uint64{3, 4, 5, 99} + regions = &response.RegionsInfo{} + err = tu.ReadGetJSON(re, testDialClient, url, regions) + re.NoError(err) + re.Len(regionIDs, regions.Count) + for i, v := range regionIDs { + re.Equal(regions.Regions[i].ID, v) + } } // Start a new test suite to prevent from being interfered by other tests. From 616a2459e52c6583cc93e78c8dde2e1b547dd4ee Mon Sep 17 00:00:00 2001 From: husharp Date: Thu, 6 Jun 2024 14:02:28 +0800 Subject: [PATCH 2/6] extract common func Signed-off-by: husharp --- pkg/utils/apiutil/apiutil.go | 16 ++++++++++++++++ server/api/region.go | 32 ++++++++------------------------ 2 files changed, 24 insertions(+), 24 deletions(-) diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index d0745ada271..b98215bb3e6 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -336,6 +336,22 @@ func ParseKey(name string, input map[string]any) ([]byte, string, error) { return returned, rawKey, nil } +// ParseHexKeys decodes hexadecimal keys to bytes if the format is "hex". +func ParseHexKeys(format string, keys ...*string) error { + if format != "hex" { + return nil + } + + for _, key := range keys { + keyBytes, err := hex.DecodeString(*key) + if err != nil { + return err + } + *key = string(keyBytes) + } + return nil +} + // ReadJSON reads a JSON data from r and then closes it. // An error due to invalid json will be returned as a JSONError func ReadJSON(r io.ReadCloser, data any) error { diff --git a/server/api/region.go b/server/api/region.go index 41c3c27b471..e331987c7a4 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -16,7 +16,6 @@ package api import ( "container/heap" - "encoding/hex" "fmt" "net/http" "net/url" @@ -93,14 +92,10 @@ func (h *regionHandler) GetRegion(w http.ResponseWriter, r *http.Request) { return } // decode hex if query has params with hex format - formatStr := r.URL.Query().Get("format") - if formatStr == "hex" { - keyBytes, err := hex.DecodeString(key) - if err != nil { - h.rd.JSON(w, http.StatusBadRequest, err.Error()) - return - } - key = string(keyBytes) + err = apiutil.ParseHexKeys(r.URL.Query().Get("format"), &key) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return } regionInfo := rc.GetRegionByKey([]byte(key)) @@ -178,21 +173,10 @@ func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) { endKey := r.URL.Query().Get("end_key") // decode hex if query has params with hex format - formatStr := r.URL.Query().Get("format") - if formatStr == "hex" { - keyBytes, err := hex.DecodeString(startKey) - if err != nil { - h.rd.JSON(w, http.StatusBadRequest, err.Error()) - return - } - startKey = string(keyBytes) - - keyBytes, err = hex.DecodeString(endKey) - if err != nil { - h.rd.JSON(w, http.StatusBadRequest, err.Error()) - return - } - endKey = string(keyBytes) + err := apiutil.ParseHexKeys(r.URL.Query().Get("format"), &startKey, &endKey) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return } limit, err := h.AdjustLimit(r.URL.Query().Get("limit")) From 509c5fbc896c34c8a0c4ba38f52418c8f9ab86f9 Mon Sep 17 00:00:00 2001 From: husharp Date: Thu, 6 Jun 2024 18:02:10 +0800 Subject: [PATCH 3/6] address comment Signed-off-by: husharp --- pkg/utils/apiutil/apiutil.go | 12 ++++++------ pkg/utils/apiutil/apiutil_test.go | 18 ++++++++++++++++++ server/api/region.go | 14 +++++++------- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index b98215bb3e6..664da056c13 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -337,19 +337,19 @@ func ParseKey(name string, input map[string]any) ([]byte, string, error) { } // ParseHexKeys decodes hexadecimal keys to bytes if the format is "hex". -func ParseHexKeys(format string, keys ...*string) error { +func ParseHexKeys(format string, keys []string) (hexStrings []string, err error) { if format != "hex" { - return nil + return keys, nil } for _, key := range keys { - keyBytes, err := hex.DecodeString(*key) + keyBytes, err := hex.DecodeString(key) if err != nil { - return err + return keys, err } - *key = string(keyBytes) + hexStrings = append(hexStrings, string(keyBytes)) } - return nil + return hexStrings, nil } // ReadJSON reads a JSON data from r and then closes it. diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index aee21621dd2..ec82ab53af9 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -204,3 +204,21 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { re.Equal(testCase.port, port, "case %d", idx) } } + +func TestParseHexKeys(t *testing.T) { + re := require.New(t) + hexKeys := []string{"", "67", "0001020304050607", "08090a0b0c0d0e0f", "f0f1f2f3f4f5f6f7"} + parseKeys, err := ParseHexKeys("hex", hexKeys) + re.NoError(err) + re.Equal(parseKeys, []string{"", "g", "\x00\x01\x02\x03\x04\x05\x06\x07", "\x08\t\n\x0b\x0c\r\x0e\x0f", "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7"}) + // Test for other format + hexKeys = []string{"hello"} + parseKeys, err = ParseHexKeys("other", hexKeys) + re.NoError(err) + re.Equal(parseKeys, []string{"hello"}) + // Test for wrong keys + hexKeys = []string{"world"} + parseKeys, err = ParseHexKeys("hex", hexKeys) + re.Error(err) + re.Equal(parseKeys, []string{"world"}) +} diff --git a/server/api/region.go b/server/api/region.go index e331987c7a4..93eeb0c0320 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -92,13 +92,14 @@ func (h *regionHandler) GetRegion(w http.ResponseWriter, r *http.Request) { return } // decode hex if query has params with hex format - err = apiutil.ParseHexKeys(r.URL.Query().Get("format"), &key) + params := []string{key} + params, err = apiutil.ParseHexKeys(r.URL.Query().Get("format"), params) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } - regionInfo := rc.GetRegionByKey([]byte(key)) + regionInfo := rc.GetRegionByKey([]byte(params[0])) b, err := response.MarshalRegionInfoJSON(r.Context(), regionInfo) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) @@ -169,11 +170,10 @@ func (h *regionsHandler) GetRegions(w http.ResponseWriter, r *http.Request) { // @Router /regions/key [get] func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) - startKey := r.URL.Query().Get("key") - endKey := r.URL.Query().Get("end_key") - + var err error // decode hex if query has params with hex format - err := apiutil.ParseHexKeys(r.URL.Query().Get("format"), &startKey, &endKey) + params := []string{r.URL.Query().Get("key"), r.URL.Query().Get("end_key")} + params, err = apiutil.ParseHexKeys(r.URL.Query().Get("format"), params) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return @@ -185,7 +185,7 @@ func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) { return } - regions := rc.ScanRegions([]byte(startKey), []byte(endKey), limit) + regions := rc.ScanRegions([]byte(params[0]), []byte(params[1]), limit) b, err := response.MarshalRegionsInfoJSON(r.Context(), regions) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) From ea750a21cec0b018046e8b17e99abc77d77455ac Mon Sep 17 00:00:00 2001 From: husharp Date: Thu, 6 Jun 2024 18:07:12 +0800 Subject: [PATCH 4/6] make static happy Signed-off-by: husharp --- pkg/utils/apiutil/apiutil_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index ec82ab53af9..b8885b5ebfb 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -210,15 +210,15 @@ func TestParseHexKeys(t *testing.T) { hexKeys := []string{"", "67", "0001020304050607", "08090a0b0c0d0e0f", "f0f1f2f3f4f5f6f7"} parseKeys, err := ParseHexKeys("hex", hexKeys) re.NoError(err) - re.Equal(parseKeys, []string{"", "g", "\x00\x01\x02\x03\x04\x05\x06\x07", "\x08\t\n\x0b\x0c\r\x0e\x0f", "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7"}) + re.Equal([]string{"", "g", "\x00\x01\x02\x03\x04\x05\x06\x07", "\x08\t\n\x0b\x0c\r\x0e\x0f", "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7"}, parseKeys) // Test for other format hexKeys = []string{"hello"} parseKeys, err = ParseHexKeys("other", hexKeys) re.NoError(err) - re.Equal(parseKeys, []string{"hello"}) + re.Equal([]string{"hello"}, parseKeys) // Test for wrong keys hexKeys = []string{"world"} parseKeys, err = ParseHexKeys("hex", hexKeys) re.Error(err) - re.Equal(parseKeys, []string{"world"}) + re.Equal([]string{"world"}, parseKeys) } From 3f515a6d39dad824fb9499492988c6dd7af65dd4 Mon Sep 17 00:00:00 2001 From: husharp Date: Thu, 13 Jun 2024 14:07:43 +0800 Subject: [PATCH 5/6] change function to bytes and add more comment Signed-off-by: husharp --- pkg/utils/apiutil/apiutil.go | 18 ++++++++++---- pkg/utils/apiutil/apiutil_test.go | 40 ++++++++++++++++++++++--------- server/api/region.go | 20 +++++++--------- 3 files changed, 51 insertions(+), 27 deletions(-) diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index 664da056c13..20465d8376c 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -336,20 +336,28 @@ func ParseKey(name string, input map[string]any) ([]byte, string, error) { return returned, rawKey, nil } -// ParseHexKeys decodes hexadecimal keys to bytes if the format is "hex". -func ParseHexKeys(format string, keys []string) (hexStrings []string, err error) { +// ParseHexKeys decodes hexadecimal src into DecodedLen(len(src)) bytes if the format is "hex". +// +// ParseHexKeys expects that each key contains only +// hexadecimal characters and each key has even length. +// If existing one key is malformed, ParseHexKeys returns +// the original bytes. +func ParseHexKeys(format string, keys [][]byte) (decodedBytes [][]byte, err error) { if format != "hex" { return keys, nil } for _, key := range keys { - keyBytes, err := hex.DecodeString(key) + // We can use the source slice itself as the destination + // because the decode loop increments by one and then the 'seen' byte is not used anymore. + // Reference to hex.DecodeString() + n, err := hex.Decode(key, key) if err != nil { return keys, err } - hexStrings = append(hexStrings, string(keyBytes)) + decodedBytes = append(decodedBytes, key[:n]) } - return hexStrings, nil + return decodedBytes, nil } // ReadJSON reads a JSON data from r and then closes it. diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index b8885b5ebfb..3e8a998d5fd 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -207,18 +207,36 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { func TestParseHexKeys(t *testing.T) { re := require.New(t) - hexKeys := []string{"", "67", "0001020304050607", "08090a0b0c0d0e0f", "f0f1f2f3f4f5f6f7"} - parseKeys, err := ParseHexKeys("hex", hexKeys) + // Test for hex format + hexBytes := [][]byte{[]byte(""), []byte("67"), []byte("0001020304050607"), []byte("08090a0b0c0d0e0f"), []byte("f0f1f2f3f4f5f6f7")} + parseKeys, err := ParseHexKeys("hex", hexBytes) re.NoError(err) - re.Equal([]string{"", "g", "\x00\x01\x02\x03\x04\x05\x06\x07", "\x08\t\n\x0b\x0c\r\x0e\x0f", "\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7"}, parseKeys) - // Test for other format - hexKeys = []string{"hello"} - parseKeys, err = ParseHexKeys("other", hexKeys) + expectedBytes := [][]byte{[]byte(""), []byte("g"), []byte("\x00\x01\x02\x03\x04\x05\x06\x07"), []byte("\x08\t\n\x0b\x0c\r\x0e\x0f"), []byte("\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7")} + re.Equal(expectedBytes, parseKeys) + // Test for other format NOT hex + hexBytes = [][]byte{[]byte("hello")} + parseKeys, err = ParseHexKeys("other", hexBytes) re.NoError(err) - re.Equal([]string{"hello"}, parseKeys) - // Test for wrong keys - hexKeys = []string{"world"} - parseKeys, err = ParseHexKeys("hex", hexKeys) + re.Len(parseKeys, 1) + re.Equal([]byte("hello"), parseKeys[0]) + // Test for wrong key + hexBytes = [][]byte{[]byte("world")} + parseKeys, err = ParseHexKeys("hex", hexBytes) re.Error(err) - re.Equal([]string{"world"}, parseKeys) + re.Len(parseKeys, 1) + re.Equal([]byte("world"), parseKeys[0]) + // Test for the first key is not valid, but the second key is valid + hexBytes = [][]byte{[]byte("world"), []byte("0001020304050607")} + parseKeys, err = ParseHexKeys("hex", hexBytes) + re.Error(err) + re.Len(parseKeys, 2) + re.Equal([]byte("world"), parseKeys[0]) + re.NotEqual([]byte("\x00\x01\x02\x03\x04\x05\x06\x07"), parseKeys[1]) + // Test for the first key is valid, but the second key is not valid + hexBytes = [][]byte{[]byte("0001020304050607"), []byte("world")} + parseKeys, err = ParseHexKeys("hex", hexBytes) + re.Error(err) + re.Len(parseKeys, 2) + re.NotEqual([]byte("\x00\x01\x02\x03\x04\x05\x06\x07"), parseKeys[0]) + re.Equal([]byte("world"), parseKeys[1]) } diff --git a/server/api/region.go b/server/api/region.go index 93eeb0c0320..c6bc3d9e699 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -85,21 +85,20 @@ func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) { func (h *regionHandler) GetRegion(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) vars := mux.Vars(r) - key := vars["key"] - key, err := url.QueryUnescape(key) + key, err := url.QueryUnescape(vars["key"]) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } // decode hex if query has params with hex format - params := []string{key} - params, err = apiutil.ParseHexKeys(r.URL.Query().Get("format"), params) + paramsByte := [][]byte{[]byte(key)} + paramsByte, err = apiutil.ParseHexKeys(r.URL.Query().Get("format"), paramsByte) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } - regionInfo := rc.GetRegionByKey([]byte(params[0])) + regionInfo := rc.GetRegionByKey(paramsByte[0]) b, err := response.MarshalRegionInfoJSON(r.Context(), regionInfo) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) @@ -170,22 +169,21 @@ func (h *regionsHandler) GetRegions(w http.ResponseWriter, r *http.Request) { // @Router /regions/key [get] func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) - var err error - // decode hex if query has params with hex format - params := []string{r.URL.Query().Get("key"), r.URL.Query().Get("end_key")} - params, err = apiutil.ParseHexKeys(r.URL.Query().Get("format"), params) + query := r.URL.Query() + paramsByte := [][]byte{[]byte(query.Get("key")), []byte(query.Get("end_key"))} + paramsByte, err := apiutil.ParseHexKeys(query.Get("format"), paramsByte) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } - limit, err := h.AdjustLimit(r.URL.Query().Get("limit")) + limit, err := h.AdjustLimit(query.Get("limit")) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } - regions := rc.ScanRegions([]byte(params[0]), []byte(params[1]), limit) + regions := rc.ScanRegions(paramsByte[0], paramsByte[1], limit) b, err := response.MarshalRegionsInfoJSON(r.Context(), regions) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) From 9377fd5c33ba9e1cb2ddc519e7bf6d883c5f52e7 Mon Sep 17 00:00:00 2001 From: husharp Date: Thu, 13 Jun 2024 14:25:38 +0800 Subject: [PATCH 6/6] add invalid key test Signed-off-by: husharp --- server/api/region_test.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/server/api/region_test.go b/server/api/region_test.go index 97f9bfb5478..88632232175 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -497,6 +497,12 @@ func (suite *getRegionTestSuite) TestScanRegionByKeys() { for i, v := range regionIDs { re.Equal(regions.Regions[i].ID, v) } + // test invalid key + url = fmt.Sprintf("%s/regions/key?key=%s&format=hex", suite.urlPrefix, "invalid") + err = tu.CheckGetJSON(testDialClient, url, nil, + tu.Status(re, http.StatusBadRequest), + tu.StringEqual(re, "\"encoding/hex: invalid byte: U+0069 'i'\"\n")) + re.NoError(err) } // Start a new test suite to prevent from being interfered by other tests.