diff --git a/pkg/core/region_tree.go b/pkg/core/region_tree.go index d4ef4a880fc..cf146c05f90 100644 --- a/pkg/core/region_tree.go +++ b/pkg/core/region_tree.go @@ -272,7 +272,7 @@ func (t *regionTree) find(item *regionItem) *regionItem { // until f return false func (t *regionTree) scanRange(startKey []byte, f func(*RegionInfo) bool) { region := &RegionInfo{meta: &metapb.Region{StartKey: startKey}} - // find if there is a region with key range [s, d), s < startKey < d + // find if there is a region with key range [s, d), s <= startKey < d fn := func(item *regionItem) bool { r := item return f(r.RegionInfo) diff --git a/pkg/utils/apiutil/apiutil.go b/pkg/utils/apiutil/apiutil.go index d0745ada271..20465d8376c 100644 --- a/pkg/utils/apiutil/apiutil.go +++ b/pkg/utils/apiutil/apiutil.go @@ -336,6 +336,30 @@ func ParseKey(name string, input map[string]any) ([]byte, string, error) { return returned, rawKey, nil } +// ParseHexKeys decodes hexadecimal src into DecodedLen(len(src)) bytes if the format is "hex". +// +// ParseHexKeys expects that each key contains only +// hexadecimal characters and each key has even length. +// If existing one key is malformed, ParseHexKeys returns +// the original bytes. +func ParseHexKeys(format string, keys [][]byte) (decodedBytes [][]byte, err error) { + if format != "hex" { + return keys, nil + } + + for _, key := range keys { + // We can use the source slice itself as the destination + // because the decode loop increments by one and then the 'seen' byte is not used anymore. + // Reference to hex.DecodeString() + n, err := hex.Decode(key, key) + if err != nil { + return keys, err + } + decodedBytes = append(decodedBytes, key[:n]) + } + return decodedBytes, nil +} + // ReadJSON reads a JSON data from r and then closes it. // An error due to invalid json will be returned as a JSONError func ReadJSON(r io.ReadCloser, data any) error { diff --git a/pkg/utils/apiutil/apiutil_test.go b/pkg/utils/apiutil/apiutil_test.go index aee21621dd2..3e8a998d5fd 100644 --- a/pkg/utils/apiutil/apiutil_test.go +++ b/pkg/utils/apiutil/apiutil_test.go @@ -204,3 +204,39 @@ func TestGetIPPortFromHTTPRequest(t *testing.T) { re.Equal(testCase.port, port, "case %d", idx) } } + +func TestParseHexKeys(t *testing.T) { + re := require.New(t) + // Test for hex format + hexBytes := [][]byte{[]byte(""), []byte("67"), []byte("0001020304050607"), []byte("08090a0b0c0d0e0f"), []byte("f0f1f2f3f4f5f6f7")} + parseKeys, err := ParseHexKeys("hex", hexBytes) + re.NoError(err) + expectedBytes := [][]byte{[]byte(""), []byte("g"), []byte("\x00\x01\x02\x03\x04\x05\x06\x07"), []byte("\x08\t\n\x0b\x0c\r\x0e\x0f"), []byte("\xf0\xf1\xf2\xf3\xf4\xf5\xf6\xf7")} + re.Equal(expectedBytes, parseKeys) + // Test for other format NOT hex + hexBytes = [][]byte{[]byte("hello")} + parseKeys, err = ParseHexKeys("other", hexBytes) + re.NoError(err) + re.Len(parseKeys, 1) + re.Equal([]byte("hello"), parseKeys[0]) + // Test for wrong key + hexBytes = [][]byte{[]byte("world")} + parseKeys, err = ParseHexKeys("hex", hexBytes) + re.Error(err) + re.Len(parseKeys, 1) + re.Equal([]byte("world"), parseKeys[0]) + // Test for the first key is not valid, but the second key is valid + hexBytes = [][]byte{[]byte("world"), []byte("0001020304050607")} + parseKeys, err = ParseHexKeys("hex", hexBytes) + re.Error(err) + re.Len(parseKeys, 2) + re.Equal([]byte("world"), parseKeys[0]) + re.NotEqual([]byte("\x00\x01\x02\x03\x04\x05\x06\x07"), parseKeys[1]) + // Test for the first key is valid, but the second key is not valid + hexBytes = [][]byte{[]byte("0001020304050607"), []byte("world")} + parseKeys, err = ParseHexKeys("hex", hexBytes) + re.Error(err) + re.Len(parseKeys, 2) + re.NotEqual([]byte("\x00\x01\x02\x03\x04\x05\x06\x07"), parseKeys[0]) + re.Equal([]byte("world"), parseKeys[1]) +} diff --git a/server/api/region.go b/server/api/region.go index 974b5e4fa12..c6bc3d9e699 100644 --- a/server/api/region.go +++ b/server/api/region.go @@ -16,7 +16,6 @@ package api import ( "container/heap" - "encoding/hex" "fmt" "net/http" "net/url" @@ -86,24 +85,20 @@ func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) { func (h *regionHandler) GetRegion(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) vars := mux.Vars(r) - key := vars["key"] - key, err := url.QueryUnescape(key) + key, err := url.QueryUnescape(vars["key"]) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } // decode hex if query has params with hex format - formatStr := r.URL.Query().Get("format") - if formatStr == "hex" { - keyBytes, err := hex.DecodeString(key) - if err != nil { - h.rd.JSON(w, http.StatusBadRequest, err.Error()) - return - } - key = string(keyBytes) + paramsByte := [][]byte{[]byte(key)} + paramsByte, err = apiutil.ParseHexKeys(r.URL.Query().Get("format"), paramsByte) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return } - regionInfo := rc.GetRegionByKey([]byte(key)) + regionInfo := rc.GetRegionByKey(paramsByte[0]) b, err := response.MarshalRegionInfoJSON(r.Context(), regionInfo) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) @@ -174,15 +169,21 @@ func (h *regionsHandler) GetRegions(w http.ResponseWriter, r *http.Request) { // @Router /regions/key [get] func (h *regionsHandler) ScanRegions(w http.ResponseWriter, r *http.Request) { rc := getCluster(r) - startKey := r.URL.Query().Get("key") - endKey := r.URL.Query().Get("end_key") - limit, err := h.AdjustLimit(r.URL.Query().Get("limit")) + query := r.URL.Query() + paramsByte := [][]byte{[]byte(query.Get("key")), []byte(query.Get("end_key"))} + paramsByte, err := apiutil.ParseHexKeys(query.Get("format"), paramsByte) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + + limit, err := h.AdjustLimit(query.Get("limit")) if err != nil { h.rd.JSON(w, http.StatusBadRequest, err.Error()) return } - regions := rc.ScanRegions([]byte(startKey), []byte(endKey), limit) + regions := rc.ScanRegions(paramsByte[0], paramsByte[1], limit) b, err := response.MarshalRegionsInfoJSON(r.Context(), regions) if err != nil { h.rd.JSON(w, http.StatusInternalServerError, err.Error()) diff --git a/server/api/region_test.go b/server/api/region_test.go index 0e5dcd97678..88632232175 100644 --- a/server/api/region_test.go +++ b/server/api/region_test.go @@ -468,6 +468,41 @@ func (suite *getRegionTestSuite) TestScanRegionByKeys() { for i, v := range regionIDs { re.Equal(regions.Regions[i].ID, v) } + url = fmt.Sprintf("%s/regions/key?key=%s&format=hex", suite.urlPrefix, hex.EncodeToString([]byte("b"))) + regionIDs = []uint64{3, 4, 5, 99} + regions = &response.RegionsInfo{} + err = tu.ReadGetJSON(re, testDialClient, url, regions) + re.NoError(err) + re.Len(regionIDs, regions.Count) + for i, v := range regionIDs { + re.Equal(regions.Regions[i].ID, v) + } + url = fmt.Sprintf("%s/regions/key?key=%s&end_key=%s&format=hex", + suite.urlPrefix, hex.EncodeToString([]byte("b")), hex.EncodeToString([]byte("g"))) + regionIDs = []uint64{3, 4} + regions = &response.RegionsInfo{} + err = tu.ReadGetJSON(re, testDialClient, url, regions) + re.NoError(err) + re.Len(regionIDs, regions.Count) + for i, v := range regionIDs { + re.Equal(regions.Regions[i].ID, v) + } + url = fmt.Sprintf("%s/regions/key?key=%s&end_key=%s&format=hex", + suite.urlPrefix, hex.EncodeToString([]byte("b")), hex.EncodeToString([]byte{0xFF, 0xFF, 0xCC})) + regionIDs = []uint64{3, 4, 5, 99} + regions = &response.RegionsInfo{} + err = tu.ReadGetJSON(re, testDialClient, url, regions) + re.NoError(err) + re.Len(regionIDs, regions.Count) + for i, v := range regionIDs { + re.Equal(regions.Regions[i].ID, v) + } + // test invalid key + url = fmt.Sprintf("%s/regions/key?key=%s&format=hex", suite.urlPrefix, "invalid") + err = tu.CheckGetJSON(testDialClient, url, nil, + tu.Status(re, http.StatusBadRequest), + tu.StringEqual(re, "\"encoding/hex: invalid byte: U+0069 'i'\"\n")) + re.NoError(err) } // Start a new test suite to prevent from being interfered by other tests.