diff --git a/src/commands/cmd_json.cc b/src/commands/cmd_json.cc index 08bcf7abe85..44e5d2098f8 100644 --- a/src/commands/cmd_json.cc +++ b/src/commands/cmd_json.cc @@ -320,6 +320,64 @@ class CommandJsonArrPop : public Commander { int64_t index_ = -1; }; +class CommanderJsonArrIndex : public Commander { + public: + Status Parse(const std::vector &args) override { + if (args.size() < 4 || args.size() > 6) { + return {Status::RedisExecErr, errWrongNumOfArguments}; + } + start_ = 0; + end_ = std::numeric_limits::max(); + + if (args.size() > 4) { + auto parse_start = ParseInt(args[4], 10); + if (parse_start.IsOK()) { + start_ = parse_start.GetValue(); + } else { + return {Status::RedisParseErr, errValueNotInteger}; + } + } + if (args.size() > 5) { + auto parse_end = ParseInt(args[5], 10); + if (parse_end.IsOK()) { + end_ = parse_end.GetValue(); + } else { + return {Status::RedisParseErr, errValueNotInteger}; + } + } + return Status::OK(); + } + + Status Execute(Server *svr, Connection *conn, std::string *output) override { + redis::Json json(svr->storage, conn->GetNamespace()); + + std::vector result; + + auto s = json.ArrIndex(args_[1], args_[2], args_[3], start_, end_, &result); + + if (s.IsNotFound()) { + *output = redis::NilString(); + return Status::OK(); + } + + if (!s.ok()) return {Status::RedisExecErr, s.ToString()}; + + *output = redis::MultiLen(result.size()); + for (const auto &found_index : result) { + if (found_index == NOT_ARRAY) { + *output += redis::NilString(); + continue; + } + *output += redis::Integer(found_index); + } + return Status::OK(); + } + + private: + ssize_t start_; + ssize_t end_; +}; + REDIS_REGISTER_COMMANDS(MakeCmdAttr("json.set", 4, "write", 1, 1, 1), MakeCmdAttr("json.get", -2, "read-only", 1, 1, 1), MakeCmdAttr("json.info", 2, "read-only", 1, 1, 1), @@ -329,5 +387,7 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr("json.set", 4, "write", 1, 1 MakeCmdAttr("json.toggle", -2, "write", 1, 1, 1), MakeCmdAttr("json.arrlen", -2, "read-only", 1, 1, 1), MakeCmdAttr("json.objkeys", -2, "read-only", 1, 1, 1), - MakeCmdAttr("json.arrpop", -2, "write", 1, 1, 1), ); + MakeCmdAttr("json.arrpop", -2, "write", 1, 1, 1), + MakeCmdAttr("json.arrindex", -4, "read-only", 1, 1, 1), ); + } // namespace redis diff --git a/src/types/json.h b/src/types/json.h index 1b78ec063d1..4aa4786b57a 100644 --- a/src/types/json.h +++ b/src/types/json.h @@ -33,6 +33,9 @@ #include "status.h" +constexpr ssize_t NOT_FOUND_INDEX = -1; +constexpr ssize_t NOT_ARRAY = -2; + struct JsonValue { JsonValue() = default; explicit JsonValue(jsoncons::basic_json value) : value(std::move(value)) {} @@ -173,6 +176,48 @@ struct JsonValue { return result_count; } + static std::pair NormalizeArrIndices(ssize_t start, ssize_t end, ssize_t len) { + if (start < 0) { + start = std::max(0, len + start); + } else { + start = std::min(start, len - 1); + } + if (end == 0) { + end = len; + } else if (end < 0) { + end = std::max(0, len + end); + } + end = std::min(end, len); + return {start, end}; + } + + StatusOr> ArrIndex(std::string_view path, const jsoncons::json &needle, ssize_t start, + ssize_t end) const { + std::vector result; + try { + jsoncons::jsonpath::json_query(value, path, [&](const std::string & /*path*/, const jsoncons::json &val) { + if (!val.is_array()) { + result.emplace_back(NOT_ARRAY); + return; + } + auto [pstart, pend] = NormalizeArrIndices(start, end, static_cast(val.size())); + auto arr_begin = val.array_range().begin(); + auto begin_it = arr_begin + pstart; + + auto end_it = arr_begin + pend; + auto it = std::find(begin_it, end_it, needle); + if (it != end_it) { + result.emplace_back(it - arr_begin); + return; + } + result.emplace_back(NOT_FOUND_INDEX); + }); + } catch (const jsoncons::jsonpath::jsonpath_error &e) { + return {Status::NotOK, e.what()}; + } + return result; + } + StatusOr> Type(std::string_view path) const { std::vector types; try { diff --git a/src/types/redis_json.cc b/src/types/redis_json.cc index 94536605257..3de38a884ba 100644 --- a/src/types/redis_json.cc +++ b/src/types/redis_json.cc @@ -180,6 +180,26 @@ rocksdb::Status Json::ArrAppend(const std::string &user_key, const std::string & return write(ns_key, &metadata, value); } +rocksdb::Status Json::ArrIndex(const std::string &user_key, const std::string &path, const std::string &needle, + ssize_t start, ssize_t end, std::vector *result) { + auto ns_key = AppendNamespacePrefix(user_key); + + auto needle_res = JsonValue::FromString(needle, storage_->GetConfig()->json_max_nesting_depth); + if (!needle_res) return rocksdb::Status::InvalidArgument(needle_res.Msg()); + auto needle_value = *std::move(needle_res); + + JsonMetadata metadata; + JsonValue value; + auto s = read(ns_key, &metadata, &value); + if (!s.ok()) return s; + + auto index_res = value.ArrIndex(path, needle_value.value, start, end); + if (!index_res) return rocksdb::Status::InvalidArgument(index_res.Msg()); + *result = *index_res; + + return rocksdb::Status::OK(); +} + rocksdb::Status Json::Type(const std::string &user_key, const std::string &path, std::vector *results) { auto ns_key = AppendNamespacePrefix(user_key); diff --git a/src/types/redis_json.h b/src/types/redis_json.h index e1f5cde961f..94a054aee4f 100644 --- a/src/types/redis_json.h +++ b/src/types/redis_json.h @@ -48,6 +48,8 @@ class Json : public Database { std::vector>> &keys); rocksdb::Status ArrPop(const std::string &user_key, const std::string &path, int64_t index, std::vector> *results); + rocksdb::Status ArrIndex(const std::string &user_key, const std::string &path, const std::string &needle, + ssize_t start, ssize_t end, std::vector *result); private: rocksdb::Status write(Slice ns_key, JsonMetadata *metadata, const JsonValue &json_val); diff --git a/tests/cppunit/types/json_test.cc b/tests/cppunit/types/json_test.cc index c1c10e8dc2b..c371c6d6bce 100644 --- a/tests/cppunit/types/json_test.cc +++ b/tests/cppunit/types/json_test.cc @@ -398,3 +398,60 @@ TEST_F(RedisJsonTest, ArrPop) { ASSERT_EQ(res[3]->Dump().GetValue(), "1"); res.clear(); } + +TEST_F(RedisJsonTest, ArrIndex) { + std::vector res; + int max_end = std::numeric_limits::max(); + + ASSERT_TRUE(json_->Set(key_, "$", R"({"arr":[0, 1, 2, 3, 2, 1, 0]})").ok()); + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 0, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], 0); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "3", 0, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res.size(), 1); + ASSERT_EQ(res[0], 3); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "4", 0, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], -1); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 1, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], 6); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", -1, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], 6); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 6, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], 6); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 5, -1, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], -1); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 5, 0, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], 6); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "2", -2, 6, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], -1); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "\"foo\"", 0, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], -1); + + ASSERT_TRUE(json_->Set(key_, "$", R"({"arr":[0, 1, 2, 3, 4, 2, 1, 0]})").ok()); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "3", 0, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], 3); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "2", 3, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], 5); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "1", 0, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], 1); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "2", 1, 4, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], 2); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "6", 0, max_end, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], -1); + + ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "3", 0, 2, &res).ok() && res.size() == 1); + ASSERT_EQ(res[0], -1); +} diff --git a/tests/gocase/unit/type/json/json_test.go b/tests/gocase/unit/type/json/json_test.go index 0d5f3fd9517..2e97b3b8e32 100644 --- a/tests/gocase/unit/type/json/json_test.go +++ b/tests/gocase/unit/type/json/json_test.go @@ -275,4 +275,42 @@ func TestJson(t *testing.T) { require.Equal(t, rdb.Do(ctx, "JSON.GET", "a").Val(), `[99,false,99]`) }) + t.Run("JSON.ARRINDEX basics", func(t *testing.T) { + arrIndexCmd := "JSON.ARRINDEX" + require.NoError(t, rdb.Do(ctx, "SET", "a", `1`).Err()) + require.Error(t, rdb.Do(ctx, arrIndexCmd, "a", "$", `1`).Err()) + require.NoError(t, rdb.Do(ctx, "DEL", "a").Err()) + + require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", ` {"x":1, "y": {"x":1} } `).Err()) + require.Equal(t, []interface{}{}, rdb.Do(ctx, arrIndexCmd, "a", "$..k", `1`).Val()) + require.Error(t, rdb.Do(ctx, arrIndexCmd, "a", "$").Err()) + require.Error(t, rdb.Do(ctx, arrIndexCmd, "a", "$", ` 1, 2, 3`).Err()) + + require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"arr":[0,1,2,3,2,1,0]}`).Err()) + require.Equal(t, []interface{}{int64(0)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`).Val()) + require.Equal(t, []interface{}{int64(3)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `3`).Val()) + require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `4`).Val()) + require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 1).Val()) + require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, -1).Val()) + require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 6).Val()) + require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 4, -0).Val()) + require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 5, -1).Val()) + require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 5, 0).Val()) + require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `2`, -2, 6).Val()) + require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `"foo"`).Val()) + + require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"arr":[0,1,2,3,4,2,1,0]}`).Err()) + + require.Equal(t, []interface{}{int64(3)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `3`).Val()) + require.Equal(t, []interface{}{int64(5)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `2`, 3).Val()) + require.Equal(t, []interface{}{int64(1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `1`).Val()) + require.Equal(t, []interface{}{int64(2)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `2`, 1, 4).Val()) + require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `6`).Val()) + require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `3`, 0, 2).Val()) + + require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"arr":[0,1,2]}`).Err()) + require.Equal(t, []interface{}{nil, nil, nil}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr.*", `1`).Val()) + require.NoError(t, rdb.Do(ctx, "JSON.SET", "a1", "$", `{"arr":[[1],[2],[3]]}`).Err()) + require.Equal(t, []interface{}{int64(0), int64(-1), int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a1", "$.arr.*", `1`).Val()) + }) }