Skip to content

Commit

Permalink
feat(search): add vector type to kqir::Value (apache#2371)
Browse files Browse the repository at this point in the history
  • Loading branch information
PragmaTwice authored Jun 19, 2024
1 parent 47e6705 commit 7d48490
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 38 deletions.
105 changes: 69 additions & 36 deletions src/search/indexer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,73 @@ StatusOr<FieldValueRetriever> FieldValueRetriever::Create(IndexOnDataType type,
}
}

// placeholders, remove them after vector indexing is implemented
// Placeholder predicate for "is this field a vector field?".
// The metadata argument is ignored and the answer is always `false`.
static bool IsVectorType(const redis::IndexFieldMetadata * /*type*/) {
  return false;
}
// Placeholder accessor for the dimension of a vector field.
// The metadata argument is ignored and the dimension is always reported as 1.
static size_t GetVectorDim(const redis::IndexFieldMetadata * /*type*/) {
  return 1;
}

// Converts a JSON value into a kqir::Value according to the field type.
//
// - numeric fields: the JSON value must be a number and must not be a string;
//   the result is a kqir::Numeric built from `as_double()`.
// - tag fields: the JSON value is either a string (split on the tag separator)
//   or an array whose elements are all strings; the result is a kqir::StringArray.
// - vector fields: the JSON value must be an array of exactly `GetVectorDim(type)`
//   numbers; the result is a kqir::NumericArray.
//
// Any other field type, or a JSON value of the wrong shape, yields Status::NotOK.
StatusOr<kqir::Value> FieldValueRetriever::ParseFromJson(const jsoncons::json &val,
                                                         const redis::IndexFieldMetadata *type) {
  if (dynamic_cast<const redis::NumericFieldMetadata *>(type) != nullptr) {
    const bool is_plain_number = val.is_number() && !val.is_string();
    if (!is_plain_number) {
      return {Status::NotOK, "json value cannot be string for numeric fields"};
    }
    return kqir::MakeValue<kqir::Numeric>(val.as_double());
  }

  if (const auto *tag_meta = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
    // A single string holds all tags joined by the field's separator.
    if (val.is_string()) {
      const char separator[] = {tag_meta->separator, '\0'};
      auto tags = util::Split(val.as_string(), separator);
      return kqir::MakeValue<kqir::StringArray>(tags);
    }

    if (!val.is_array()) {
      return {Status::NotOK, "json value should be string or array of strings for tag fields"};
    }

    // An array holds one tag per element; every element has to be a string.
    std::vector<std::string> tags;
    for (size_t idx = 0; idx < val.size(); ++idx) {
      const auto &item = val[idx];
      if (!item.is_string()) {
        return {Status::NotOK, "json value should be string or array of strings for tag fields"};
      }
      tags.push_back(item.as_string());
    }
    return kqir::MakeValue<kqir::StringArray>(tags);
  }

  if (IsVectorType(type)) {
    const size_t expected_dim = GetVectorDim(type);
    if (!val.is_array()) {
      return {Status::NotOK, "json value should be array of numbers for vector fields"};
    }
    if (expected_dim != val.size()) {
      return {Status::NotOK, "the size of the json array is not equal to the dim of the vector"};
    }

    std::vector<double> components;
    for (size_t idx = 0; idx < expected_dim; ++idx) {
      const auto &item = val[idx];
      const bool is_plain_number = item.is_number() && !item.is_string();
      if (!is_plain_number) {
        return {Status::NotOK, "json value should be array of numbers for vector fields"};
      }
      components.push_back(item.as_double());
    }
    return kqir::MakeValue<kqir::NumericArray>(components);
  }

  return {Status::NotOK, "unknown field type to retrieve"};
}

// Converts the raw string stored in a hash field into a kqir::Value according
// to the field type.
//
// - numeric fields: the string is parsed with ParseFloat; a parse failure is
//   propagated to the caller through GET_OR_RET.
// - tag fields: the string is split on the tag separator into a kqir::StringArray.
// - vector fields: the string is treated as a packed binary array of
//   `GetVectorDim(type)` doubles and must be exactly that many bytes long.
//
// Any other field type yields Status::NotOK.
StatusOr<kqir::Value> FieldValueRetriever::ParseFromHash(const std::string &value,
                                                         const redis::IndexFieldMetadata *type) {
  if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
    auto num = GET_OR_RET(ParseFloat(value));
    return kqir::MakeValue<kqir::Numeric>(num);
  } else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
    const char delim[] = {tag->separator, '\0'};
    auto vec = util::Split(value, delim);
    return kqir::MakeValue<kqir::StringArray>(vec);
  } else if (IsVectorType(type)) {
    const size_t dim = GetVectorDim(type);
    if (value.size() != dim * sizeof(double)) {
      return {Status::NotOK, "field value is too short or too long to be parsed as a vector"};
    }
    std::vector<double> vec;
    vec.reserve(dim);
    for (size_t i = 0; i < dim; ++i) {
      // TODO: care about endian later
      // TODO: currently only support 64bit floating point
      //
      // Copy the bytes into a real double instead of dereferencing the string
      // buffer through a `const double *`: the buffer is a char array with no
      // alignment guarantee for double, so reading it through a double pointer
      // is undefined behaviour (strict aliasing / misaligned load).
      double component = 0;
      value.copy(reinterpret_cast<char *>(&component), sizeof(component), i * sizeof(component));
      vec.push_back(component);
    }
    return kqir::MakeValue<kqir::NumericArray>(vec);
  } else {
    return {Status::NotOK, "unknown field type to retrieve"};
  }
}

StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, const redis::IndexFieldMetadata *type) {
if (std::holds_alternative<HashData>(db)) {
auto &[hash, metadata, key] = std::get<HashData>(db);
Expand All @@ -71,17 +138,7 @@ StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, cons
if (s.IsNotFound()) return {Status::NotFound, s.ToString()};
if (!s.ok()) return {Status::NotOK, s.ToString()};

if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
auto num = GET_OR_RET(ParseFloat(value));
return kqir::MakeValue<kqir::Numeric>(num);
} else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
const char delim[] = {tag->separator, '\0'};
auto vec = util::Split(value, delim);
return kqir::MakeValue<kqir::StringArray>(vec);
} else {
return {Status::NotOK, "unknown field type to retrieve"};
}

return ParseFromHash(value, type);
} else if (std::holds_alternative<JsonData>(db)) {
auto &value = std::get<JsonData>(db);

Expand All @@ -91,31 +148,7 @@ StatusOr<kqir::Value> FieldValueRetriever::Retrieve(std::string_view field, cons
return {Status::NotFound, "json value specified by the field (json path) should exist and be unique"};
auto val = s->value[0];

if (auto numeric [[maybe_unused]] = dynamic_cast<const redis::NumericFieldMetadata *>(type)) {
if (val.is_string()) return {Status::NotOK, "json value cannot be string for numeric fields"};
return kqir::MakeValue<kqir::Numeric>(val.as_double());
} else if (auto tag = dynamic_cast<const redis::TagFieldMetadata *>(type)) {
if (val.is_string()) {
const char delim[] = {tag->separator, '\0'};
auto vec = util::Split(val.as_string(), delim);
return kqir::MakeValue<kqir::StringArray>(vec);
} else if (val.is_array()) {
std::vector<std::string> strs;
for (size_t i = 0; i < val.size(); ++i) {
if (!val[i].is_string())
return {Status::NotOK, "json value should be string or array of strings for tag fields"};
strs.push_back(val[i].as_string());
}
return kqir::MakeValue<kqir::StringArray>(strs);
} else {
return {Status::NotOK, "json value should be string or array of strings for tag fields"};
}
} else {
return {Status::NotOK, "unknown field type to retrieve"};
}

return Status::OK();

return ParseFromJson(val, type);
} else {
return {Status::NotOK, "unknown redis data type to retrieve"};
}
Expand Down
3 changes: 3 additions & 0 deletions src/search/indexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct FieldValueRetriever {
explicit FieldValueRetriever(JsonValue json) : db(std::in_place_type<JsonData>, std::move(json)) {}

StatusOr<kqir::Value> Retrieve(std::string_view field, const redis::IndexFieldMetadata *type);

static StatusOr<kqir::Value> ParseFromJson(const jsoncons::json &value, const redis::IndexFieldMetadata *type);
static StatusOr<kqir::Value> ParseFromHash(const std::string &value, const redis::IndexFieldMetadata *type);
};

struct IndexUpdater {
Expand Down
9 changes: 7 additions & 2 deletions src/search/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ using String = std::string; // e.g. a single tag
using NumericArray = std::vector<Numeric>; // used for vector fields
using StringArray = std::vector<String>; // used for tag fields, e.g. a list for tags

struct Value : std::variant<Null, Numeric, StringArray> {
using Base = std::variant<Null, Numeric, StringArray>;
struct Value : std::variant<Null, Numeric, StringArray, NumericArray> {
using Base = std::variant<Null, Numeric, StringArray, NumericArray>;

using Base::Base;

Expand Down Expand Up @@ -72,6 +72,9 @@ struct Value : std::variant<Null, Numeric, StringArray> {
} else if (Is<StringArray>()) {
return util::StringJoin(
Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v; }, sep);
} else if (Is<NumericArray>()) {
return util::StringJoin(
Get<NumericArray>(), [](const auto &v) -> decltype(auto) { return std::to_string(v); }, sep);
}

__builtin_unreachable();
Expand All @@ -87,6 +90,8 @@ struct Value : std::variant<Null, Numeric, StringArray> {
char sep = tag ? tag->separator : ',';
return util::StringJoin(
Get<StringArray>(), [](const auto &v) -> decltype(auto) { return v; }, std::string(1, sep));
} else if (Is<NumericArray>()) {
return util::StringJoin(Get<NumericArray>(), [](const auto &v) -> decltype(auto) { return std::to_string(v); });
}

__builtin_unreachable();
Expand Down

0 comments on commit 7d48490

Please sign in to comment.