Skip to content

Commit

Permalink
Add the support of JSON.ARRINDEX command (#1865)
Browse files Browse the repository at this point in the history
  • Loading branch information
skyitachi authored Nov 10, 2023
1 parent 1083142 commit 319de1e
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 1 deletion.
52 changes: 51 additions & 1 deletion src/commands/cmd_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,54 @@ class CommandJsonArrPop : public Commander {
int64_t index_ = -1;
};

class CommanderJsonArrIndex : public Commander {
public:
Status Parse(const std::vector<std::string> &args) override {
if (args.size() > 6) {
return {Status::RedisExecErr, errWrongNumOfArguments};
}
start_ = 0;
end_ = std::numeric_limits<ssize_t>::max();

if (args.size() > 4) {
start_ = GET_OR_RET(ParseInt<ssize_t>(args[4], 10));
}
if (args.size() > 5) {
end_ = GET_OR_RET(ParseInt<ssize_t>(args[5], 10));
}
return Status::OK();
}

Status Execute(Server *svr, Connection *conn, std::string *output) override {
redis::Json json(svr->storage, conn->GetNamespace());

std::vector<ssize_t> result;

auto s = json.ArrIndex(args_[1], args_[2], args_[3], start_, end_, &result);

if (s.IsNotFound()) {
*output = redis::NilString();
return Status::OK();
}

if (!s.ok()) return {Status::RedisExecErr, s.ToString()};

*output = redis::MultiLen(result.size());
for (const auto &found_index : result) {
if (found_index == NOT_ARRAY) {
*output += redis::NilString();
continue;
}
*output += redis::Integer(found_index);
}
return Status::OK();
}

private:
ssize_t start_;
ssize_t end_;
};

REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandJsonSet>("json.set", 4, "write", 1, 1, 1),
MakeCmdAttr<CommandJsonGet>("json.get", -2, "read-only", 1, 1, 1),
MakeCmdAttr<CommandJsonInfo>("json.info", 2, "read-only", 1, 1, 1),
Expand All @@ -329,5 +377,7 @@ REDIS_REGISTER_COMMANDS(MakeCmdAttr<CommandJsonSet>("json.set", 4, "write", 1, 1
MakeCmdAttr<CommandJsonToggle>("json.toggle", -2, "write", 1, 1, 1),
MakeCmdAttr<CommandJsonArrLen>("json.arrlen", -2, "read-only", 1, 1, 1),
MakeCmdAttr<CommandJsonObjkeys>("json.objkeys", -2, "read-only", 1, 1, 1),
MakeCmdAttr<CommandJsonArrPop>("json.arrpop", -2, "write", 1, 1, 1), );
MakeCmdAttr<CommandJsonArrPop>("json.arrpop", -2, "write", 1, 1, 1),
MakeCmdAttr<CommanderJsonArrIndex>("json.arrindex", -4, "read-only", 1, 1, 1), );

} // namespace redis
45 changes: 45 additions & 0 deletions src/types/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@

#include "status.h"

constexpr ssize_t NOT_FOUND_INDEX = -1;
constexpr ssize_t NOT_ARRAY = -2;

struct JsonValue {
JsonValue() = default;
explicit JsonValue(jsoncons::basic_json<char> value) : value(std::move(value)) {}
Expand Down Expand Up @@ -173,6 +176,48 @@ struct JsonValue {
return result_count;
}

static std::pair<ssize_t, ssize_t> NormalizeArrIndices(ssize_t start, ssize_t end, ssize_t len) {
if (start < 0) {
start = std::max<ssize_t>(0, len + start);
} else {
start = std::min<ssize_t>(start, len - 1);
}
if (end == 0) {
end = len;
} else if (end < 0) {
end = std::max<ssize_t>(0, len + end);
}
end = std::min<ssize_t>(end, len);
return {start, end};
}

StatusOr<std::vector<ssize_t>> ArrIndex(std::string_view path, const jsoncons::json &needle, ssize_t start,
ssize_t end) const {
std::vector<ssize_t> result;
try {
jsoncons::jsonpath::json_query(value, path, [&](const std::string & /*path*/, const jsoncons::json &val) {
if (!val.is_array()) {
result.emplace_back(NOT_ARRAY);
return;
}
auto [pstart, pend] = NormalizeArrIndices(start, end, static_cast<ssize_t>(val.size()));
auto arr_begin = val.array_range().begin();
auto begin_it = arr_begin + pstart;

auto end_it = arr_begin + pend;
auto it = std::find(begin_it, end_it, needle);
if (it != end_it) {
result.emplace_back(it - arr_begin);
return;
}
result.emplace_back(NOT_FOUND_INDEX);
});
} catch (const jsoncons::jsonpath::jsonpath_error &e) {
return {Status::NotOK, e.what()};
}
return result;
}

StatusOr<std::vector<std::string>> Type(std::string_view path) const {
std::vector<std::string> types;
try {
Expand Down
20 changes: 20 additions & 0 deletions src/types/redis_json.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,26 @@ rocksdb::Status Json::ArrAppend(const std::string &user_key, const std::string &
return write(ns_key, &metadata, value);
}

rocksdb::Status Json::ArrIndex(const std::string &user_key, const std::string &path, const std::string &needle,
ssize_t start, ssize_t end, std::vector<ssize_t> *result) {
auto ns_key = AppendNamespacePrefix(user_key);

auto needle_res = JsonValue::FromString(needle, storage_->GetConfig()->json_max_nesting_depth);
if (!needle_res) return rocksdb::Status::InvalidArgument(needle_res.Msg());
auto needle_value = *std::move(needle_res);

JsonMetadata metadata;
JsonValue value;
auto s = read(ns_key, &metadata, &value);
if (!s.ok()) return s;

auto index_res = value.ArrIndex(path, needle_value.value, start, end);
if (!index_res) return rocksdb::Status::InvalidArgument(index_res.Msg());
*result = *index_res;

return rocksdb::Status::OK();
}

rocksdb::Status Json::Type(const std::string &user_key, const std::string &path, std::vector<std::string> *results) {
auto ns_key = AppendNamespacePrefix(user_key);

Expand Down
2 changes: 2 additions & 0 deletions src/types/redis_json.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class Json : public Database {
std::vector<std::optional<std::vector<std::string>>> &keys);
rocksdb::Status ArrPop(const std::string &user_key, const std::string &path, int64_t index,
std::vector<std::optional<JsonValue>> *results);
rocksdb::Status ArrIndex(const std::string &user_key, const std::string &path, const std::string &needle,
ssize_t start, ssize_t end, std::vector<ssize_t> *result);

private:
rocksdb::Status write(Slice ns_key, JsonMetadata *metadata, const JsonValue &json_val);
Expand Down
57 changes: 57 additions & 0 deletions tests/cppunit/types/json_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -398,3 +398,60 @@ TEST_F(RedisJsonTest, ArrPop) {
ASSERT_EQ(res[3]->Dump().GetValue(), "1");
res.clear();
}

TEST_F(RedisJsonTest, ArrIndex) {
std::vector<ssize_t> res;
int max_end = std::numeric_limits<int>::max();

ASSERT_TRUE(json_->Set(key_, "$", R"({"arr":[0, 1, 2, 3, 2, 1, 0]})").ok());
ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 0, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], 0);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "3", 0, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res.size(), 1);
ASSERT_EQ(res[0], 3);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "4", 0, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], -1);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 1, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], 6);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", -1, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], 6);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 6, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], 6);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 5, -1, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], -1);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "0", 5, 0, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], 6);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "2", -2, 6, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], -1);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "\"foo\"", 0, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], -1);

ASSERT_TRUE(json_->Set(key_, "$", R"({"arr":[0, 1, 2, 3, 4, 2, 1, 0]})").ok());

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "3", 0, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], 3);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "2", 3, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], 5);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "1", 0, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], 1);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "2", 1, 4, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], 2);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "6", 0, max_end, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], -1);

ASSERT_TRUE(json_->ArrIndex(key_, "$.arr", "3", 0, 2, &res).ok() && res.size() == 1);
ASSERT_EQ(res[0], -1);
}
38 changes: 38 additions & 0 deletions tests/gocase/unit/type/json/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,4 +275,42 @@ func TestJson(t *testing.T) {
require.Equal(t, rdb.Do(ctx, "JSON.GET", "a").Val(), `[99,false,99]`)
})

t.Run("JSON.ARRINDEX basics", func(t *testing.T) {
arrIndexCmd := "JSON.ARRINDEX"
require.NoError(t, rdb.Do(ctx, "SET", "a", `1`).Err())
require.Error(t, rdb.Do(ctx, arrIndexCmd, "a", "$", `1`).Err())
require.NoError(t, rdb.Do(ctx, "DEL", "a").Err())

require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", ` {"x":1, "y": {"x":1} } `).Err())
require.Equal(t, []interface{}{}, rdb.Do(ctx, arrIndexCmd, "a", "$..k", `1`).Val())
require.Error(t, rdb.Do(ctx, arrIndexCmd, "a", "$").Err())
require.Error(t, rdb.Do(ctx, arrIndexCmd, "a", "$", ` 1, 2, 3`).Err())

require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"arr":[0,1,2,3,2,1,0]}`).Err())
require.Equal(t, []interface{}{int64(0)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`).Val())
require.Equal(t, []interface{}{int64(3)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `3`).Val())
require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `4`).Val())
require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 1).Val())
require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, -1).Val())
require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 6).Val())
require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 4, -0).Val())
require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 5, -1).Val())
require.Equal(t, []interface{}{int64(6)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `0`, 5, 0).Val())
require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `2`, -2, 6).Val())
require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `"foo"`).Val())

require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"arr":[0,1,2,3,4,2,1,0]}`).Err())

require.Equal(t, []interface{}{int64(3)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `3`).Val())
require.Equal(t, []interface{}{int64(5)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `2`, 3).Val())
require.Equal(t, []interface{}{int64(1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `1`).Val())
require.Equal(t, []interface{}{int64(2)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `2`, 1, 4).Val())
require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `6`).Val())
require.Equal(t, []interface{}{int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr", `3`, 0, 2).Val())

require.NoError(t, rdb.Do(ctx, "JSON.SET", "a", "$", `{"arr":[0,1,2]}`).Err())
require.Equal(t, []interface{}{nil, nil, nil}, rdb.Do(ctx, arrIndexCmd, "a", "$.arr.*", `1`).Val())
require.NoError(t, rdb.Do(ctx, "JSON.SET", "a1", "$", `{"arr":[[1],[2],[3]]}`).Err())
require.Equal(t, []interface{}{int64(0), int64(-1), int64(-1)}, rdb.Do(ctx, arrIndexCmd, "a1", "$.arr.*", `1`).Val())
})
}

0 comments on commit 319de1e

Please sign in to comment.