diff --git a/command.go b/command.go index 5fa347f43..56b225721 100644 --- a/command.go +++ b/command.go @@ -5620,3 +5620,59 @@ func (cmd *MonitorCmd) Stop() { defer cmd.mu.Unlock() cmd.status = monitorStatusStop } + +type VectorScoreSliceCmd struct { + baseCmd + + val []VectorScore +} + +var _ Cmder = (*VectorScoreSliceCmd)(nil) + +func NewVectorInfoSliceCmd(ctx context.Context, args ...any) *VectorScoreSliceCmd { + return &VectorScoreSliceCmd{ + baseCmd: baseCmd{ + ctx: ctx, + args: args, + }, + } +} + +func (cmd *VectorScoreSliceCmd) SetVal(val []VectorScore) { + cmd.val = val +} + +func (cmd *VectorScoreSliceCmd) Val() []VectorScore { + return cmd.val +} + +func (cmd *VectorScoreSliceCmd) Result() ([]VectorScore, error) { + return cmd.val, cmd.err +} + +func (cmd *VectorScoreSliceCmd) String() string { + return cmdString(cmd, cmd.val) +} + +func (cmd *VectorScoreSliceCmd) readReply(rd *proto.Reader) error { + n, err := rd.ReadMapLen() + if err != nil { + return err + } + + cmd.val = make([]VectorScore, n) + for i := 0; i < n; i++ { + name, err := rd.ReadString() + if err != nil { + return err + } + cmd.val[i].Name = name + + score, err := rd.ReadFloat() + if err != nil { + return err + } + cmd.val[i].Score = score + } + return nil +} diff --git a/commands.go b/commands.go index 271323242..c0358001d 100644 --- a/commands.go +++ b/commands.go @@ -234,6 +234,7 @@ type Cmdable interface { StreamCmdable TimeseriesCmdable JSONCmdable + VectorSetCmdable } type StatefulCmdable interface { diff --git a/unit_test.go b/unit_test.go new file mode 100644 index 000000000..e4d0e7b57 --- /dev/null +++ b/unit_test.go @@ -0,0 +1,26 @@ +package redis + +import ( + "context" +) + +// mockCmdable is a mock implementation of cmdable that records the last command. +// This is used for unit testing command construction without requiring a Redis server. +type mockCmdable struct { + lastCmd Cmder + returnErr error +} + +func (m *mockCmdable) call(ctx context.Context, cmd Cmder) error { + m.lastCmd = cmd + if m.returnErr != nil { + cmd.SetErr(m.returnErr) + } + return m.returnErr +} + +func (m *mockCmdable) asCmdable() cmdable { + return func(ctx context.Context, cmd Cmder) error { + return m.call(ctx, cmd) + } +} diff --git a/vectorset_commands.go b/vectorset_commands.go new file mode 100644 index 000000000..2bd9e2216 --- /dev/null +++ b/vectorset_commands.go @@ -0,0 +1,348 @@ +package redis + +import ( + "context" + "encoding/json" + "strconv" +) + +// note: the APIs is experimental and may be subject to change. +type VectorSetCmdable interface { + VAdd(ctx context.Context, key, element string, val Vector) *BoolCmd + VAddWithArgs(ctx context.Context, key, element string, val Vector, addArgs *VAddArgs) *BoolCmd + VCard(ctx context.Context, key string) *IntCmd + VDim(ctx context.Context, key string) *IntCmd + VEmb(ctx context.Context, key, element string, raw bool) *SliceCmd + VGetAttr(ctx context.Context, key, element string) *StringCmd + VInfo(ctx context.Context, key string) *MapStringInterfaceCmd + VLinks(ctx context.Context, key, element string) *StringSliceCmd + VLinksWithScores(ctx context.Context, key, element string) *VectorScoreSliceCmd + VRandMember(ctx context.Context, key string) *StringCmd + VRandMemberCount(ctx context.Context, key string, count int) *StringSliceCmd + VRem(ctx context.Context, key, element string) *BoolCmd + VSetAttr(ctx context.Context, key, element string, attr interface{}) *BoolCmd + VClearAttributes(ctx context.Context, key, element string) *BoolCmd + VSim(ctx context.Context, key string, val Vector) *StringSliceCmd + VSimWithScores(ctx context.Context, key string, val Vector) *VectorScoreSliceCmd + VSimWithArgs(ctx context.Context, key string, val Vector, args *VSimArgs) *StringSliceCmd + VSimWithArgsWithScores(ctx context.Context, key string, val Vector, args *VSimArgs) *VectorScoreSliceCmd +} + +type Vector interface { + Value() []any +} + +const ( + vectorFormatFP32 string = "FP32" + vectorFormatValues string = "Values" +) + +type VectorFP32 struct { + Val []byte +} + +func (v *VectorFP32) Value() []any { + return []any{vectorFormatFP32, v.Val} +} + +var _ Vector = (*VectorFP32)(nil) + +type VectorValues struct { + Val []float64 +} + +func (v *VectorValues) Value() []any { + res := make([]any, 2+len(v.Val)) + res[0] = vectorFormatValues + res[1] = len(v.Val) + for i, v := range v.Val { + res[2+i] = v + } + return res +} + +var _ Vector = (*VectorValues)(nil) + +type VectorRef struct { + Name string // the name of the referent vector +} + +func (v *VectorRef) Value() []any { + return []any{"ele", v.Name} +} + +var _ Vector = (*VectorRef)(nil) + +type VectorScore struct { + Name string + Score float64 +} + +// `VADD key (FP32 | VALUES num) vector element` +// note: the API is experimental and may be subject to change. +func (c cmdable) VAdd(ctx context.Context, key, element string, val Vector) *BoolCmd { + return c.VAddWithArgs(ctx, key, element, val, &VAddArgs{}) +} + +type VAddArgs struct { + // the REDUCE option must be passed immediately after the key + Reduce int64 + Cas bool + + // The NoQuant, Q8 and Bin options are mutually exclusive. + NoQuant bool + Q8 bool + Bin bool + + EF int64 + SetAttr string + M int64 +} + +func (v VAddArgs) reduce() int64 { + return v.Reduce +} + +func (v VAddArgs) appendArgs(args []any) []any { + if v.Cas { + args = append(args, "cas") + } + + if v.NoQuant { + args = append(args, "noquant") + } else if v.Q8 { + args = append(args, "q8") + } else if v.Bin { + args = append(args, "bin") + } + + if v.EF > 0 { + args = append(args, "ef", strconv.FormatInt(v.EF, 10)) + } + if len(v.SetAttr) > 0 { + args = append(args, "setattr", v.SetAttr) + } + if v.M > 0 { + args = append(args, "m", strconv.FormatInt(v.M, 10)) + } + return args +} + +// `VADD key [REDUCE dim] (FP32 | VALUES num) vector element [CAS] [NOQUANT | Q8 | BIN] [EF build-exploration-factor] [SETATTR attributes] [M numlinks]` +// note: the API is experimental and may be subject to change. +func (c cmdable) VAddWithArgs(ctx context.Context, key, element string, val Vector, addArgs *VAddArgs) *BoolCmd { + if addArgs == nil { + addArgs = &VAddArgs{} + } + args := []any{"vadd", key} + if addArgs.reduce() > 0 { + args = append(args, "reduce", addArgs.reduce()) + } + args = append(args, val.Value()...) + args = append(args, element) + args = addArgs.appendArgs(args) + cmd := NewBoolCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// `VCARD key` +// note: the API is experimental and may be subject to change. +func (c cmdable) VCard(ctx context.Context, key string) *IntCmd { + cmd := NewIntCmd(ctx, "vcard", key) + _ = c(ctx, cmd) + return cmd +} + +// `VDIM key` +// note: the API is experimental and may be subject to change. +func (c cmdable) VDim(ctx context.Context, key string) *IntCmd { + cmd := NewIntCmd(ctx, "vdim", key) + _ = c(ctx, cmd) + return cmd +} + +// `VEMB key element [RAW]` +// note: the API is experimental and may be subject to change. +func (c cmdable) VEmb(ctx context.Context, key, element string, raw bool) *SliceCmd { + args := []any{"vemb", key, element} + if raw { + args = append(args, "raw") + } + cmd := NewSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// `VGETATTR key element` +// note: the API is experimental and may be subject to change. +func (c cmdable) VGetAttr(ctx context.Context, key, element string) *StringCmd { + cmd := NewStringCmd(ctx, "vgetattr", key, element) + _ = c(ctx, cmd) + return cmd +} + +// `VINFO key` +// note: the API is experimental and may be subject to change. +func (c cmdable) VInfo(ctx context.Context, key string) *MapStringInterfaceCmd { + cmd := NewMapStringInterfaceCmd(ctx, "vinfo", key) + _ = c(ctx, cmd) + return cmd +} + +// `VLINKS key element` +// note: the API is experimental and may be subject to change. +func (c cmdable) VLinks(ctx context.Context, key, element string) *StringSliceCmd { + cmd := NewStringSliceCmd(ctx, "vlinks", key, element) + _ = c(ctx, cmd) + return cmd +} + +// `VLINKS key element WITHSCORES` +// note: the API is experimental and may be subject to change. +func (c cmdable) VLinksWithScores(ctx context.Context, key, element string) *VectorScoreSliceCmd { + cmd := NewVectorInfoSliceCmd(ctx, "vlinks", key, element, "withscores") + _ = c(ctx, cmd) + return cmd +} + +// `VRANDMEMBER key` +// note: the API is experimental and may be subject to change. +func (c cmdable) VRandMember(ctx context.Context, key string) *StringCmd { + cmd := NewStringCmd(ctx, "vrandmember", key) + _ = c(ctx, cmd) + return cmd +} + +// `VRANDMEMBER key [count]` +// note: the API is experimental and may be subject to change. +func (c cmdable) VRandMemberCount(ctx context.Context, key string, count int) *StringSliceCmd { + cmd := NewStringSliceCmd(ctx, "vrandmember", key, count) + _ = c(ctx, cmd) + return cmd +} + +// `VREM key element` +// note: the API is experimental and may be subject to change. +func (c cmdable) VRem(ctx context.Context, key, element string) *BoolCmd { + cmd := NewBoolCmd(ctx, "vrem", key, element) + _ = c(ctx, cmd) + return cmd +} + +// `VSETATTR key element "{ JSON obj }"` +// The `attr` must be something that can be marshaled to JSON (using encoding/JSON) unless +// the argument is a string or []byte when we assume that it can be passed directly as JSON. +// +// note: the API is experimental and may be subject to change. +func (c cmdable) VSetAttr(ctx context.Context, key, element string, attr interface{}) *BoolCmd { + var attrStr string + var err error + switch v := attr.(type) { + case string: + attrStr = v + case []byte: + attrStr = string(v) + default: + var bytes []byte + bytes, err = json.Marshal(v) + if err != nil { + // If marshalling fails, create the command and set the error; this command won't be executed. + cmd := NewBoolCmd(ctx, "vsetattr", key, element, "") + cmd.SetErr(err) + return cmd + } + attrStr = string(bytes) + } + cmd := NewBoolCmd(ctx, "vsetattr", key, element, attrStr) + _ = c(ctx, cmd) + return cmd +} + +// `VClearAttributes` clear attributes on a vector set element. +// The implementation of `VClearAttributes` is execute command `VSETATTR key element ""`. +// note: the API is experimental and may be subject to change. +func (c cmdable) VClearAttributes(ctx context.Context, key, element string) *BoolCmd { + cmd := NewBoolCmd(ctx, "vsetattr", key, element, "") + _ = c(ctx, cmd) + return cmd +} + +// `VSIM key (ELE | FP32 | VALUES num) (vector | element)` +// note: the API is experimental and may be subject to change. +func (c cmdable) VSim(ctx context.Context, key string, val Vector) *StringSliceCmd { + return c.VSimWithArgs(ctx, key, val, &VSimArgs{}) +} + +// `VSIM key (ELE | FP32 | VALUES num) (vector | element) WITHSCORES` +// note: the API is experimental and may be subject to change. +func (c cmdable) VSimWithScores(ctx context.Context, key string, val Vector) *VectorScoreSliceCmd { + return c.VSimWithArgsWithScores(ctx, key, val, &VSimArgs{}) +} + +type VSimArgs struct { + Count int64 + EF int64 + Filter string + FilterEF int64 + Truth bool + NoThread bool + // The `VSim` command in Redis has the option, by the doc in Redis.io don't have. + // Epsilon float64 +} + +func (v VSimArgs) appendArgs(args []any) []any { + if v.Count > 0 { + args = append(args, "count", v.Count) + } + if v.EF > 0 { + args = append(args, "ef", v.EF) + } + if len(v.Filter) > 0 { + args = append(args, "filter", v.Filter) + } + if v.FilterEF > 0 { + args = append(args, "filter-ef", v.FilterEF) + } + if v.Truth { + args = append(args, "truth") + } + if v.NoThread { + args = append(args, "nothread") + } + // if v.Epsilon > 0 { + // args = append(args, "Epsilon", v.Epsilon) + // } + return args +} + +// `VSIM key (ELE | FP32 | VALUES num) (vector | element) [COUNT num] +// [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort] [TRUTH] [NOTHREAD]` +// note: the API is experimental and may be subject to change. +func (c cmdable) VSimWithArgs(ctx context.Context, key string, val Vector, simArgs *VSimArgs) *StringSliceCmd { + if simArgs == nil { + simArgs = &VSimArgs{} + } + args := []any{"vsim", key} + args = append(args, val.Value()...) + args = simArgs.appendArgs(args) + cmd := NewStringSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} + +// `VSIM key (ELE | FP32 | VALUES num) (vector | element) [WITHSCORES] [COUNT num] +// [EF search-exploration-factor] [FILTER expression] [FILTER-EF max-filtering-effort] [TRUTH] [NOTHREAD]` +// note: the API is experimental and may be subject to change. +func (c cmdable) VSimWithArgsWithScores(ctx context.Context, key string, val Vector, simArgs *VSimArgs) *VectorScoreSliceCmd { + if simArgs == nil { + simArgs = &VSimArgs{} + } + args := []any{"vsim", key} + args = append(args, val.Value()...) + args = append(args, "withscores") + args = simArgs.appendArgs(args) + cmd := NewVectorInfoSliceCmd(ctx, args...) + _ = c(ctx, cmd) + return cmd +} diff --git a/vectorset_commands_integration_test.go b/vectorset_commands_integration_test.go new file mode 100644 index 000000000..147fb84c5 --- /dev/null +++ b/vectorset_commands_integration_test.go @@ -0,0 +1,326 @@ +package redis_test + +import ( + "context" + "fmt" + "math/rand" + "time" + + . "github.com/bsm/ginkgo/v2" + . "github.com/bsm/gomega" + "github.com/redis/go-redis/v9" + "github.com/redis/go-redis/v9/internal/proto" +) + +func expectNil(err error) { + Expect(err).NotTo(HaveOccurred()) +} + +func expectTrue(t bool) { + expectEqual(t, true) +} + +func expectEqual[T any, U any](a T, b U) { + Expect(a).To(BeEquivalentTo(b)) +} + +func generateRandomVector(dim int) redis.VectorValues { + rand.Seed(time.Now().UnixNano()) + v := make([]float64, dim) + for i := range v { + v[i] = float64(rand.Intn(1000)) + rand.Float64() + } + return redis.VectorValues{Val: v} +} + +var _ = Describe("Redis VectorSet commands", Label("vectorset"), func() { + ctx := context.TODO() + + setupRedisClient := func(protocolVersion int) *redis.Client { + return redis.NewClient(&redis.Options{ + Addr: "localhost:6379", + DB: 0, + Protocol: protocolVersion, + UnstableResp3: true, + }) + } + + protocols := []int{2, 3} + for _, protocol := range protocols { + protocol := protocol + + Context(fmt.Sprintf("with protocol version %d", protocol), func() { + var client *redis.Client + + BeforeEach(func() { + client = setupRedisClient(protocol) + Expect(client.FlushAll(ctx).Err()).NotTo(HaveOccurred()) + }) + + AfterEach(func() { + if client != nil { + client.FlushDB(ctx) + client.Close() + } + }) + + It("basic", func() { + SkipBeforeRedisVersion(8.0, "Redis 8.0 introduces support for VectorSet") + vecName := "basic" + val := &redis.VectorValues{ + Val: []float64{1.5, 2.4, 3.3, 4.2}, + } + ok, err := client.VAdd(ctx, vecName, "k1", val).Result() + expectNil(err) + expectTrue(ok) + + fp32 := "\x8f\xc2\xf9\x3e\xcb\xbe\xe9\xbe\xb0\x1e\xca\x3f\x5e\x06\x9e\x3f" + val2 := &redis.VectorFP32{ + Val: []byte(fp32), + } + ok, err = client.VAdd(ctx, vecName, "k2", val2).Result() + expectNil(err) + expectTrue(ok) + + dim, err := client.VDim(ctx, vecName).Result() + expectNil(err) + expectEqual(dim, 4) + + count, err := client.VCard(ctx, vecName).Result() + expectNil(err) + expectEqual(count, 2) + + ok, err = client.VRem(ctx, vecName, "k1").Result() + expectNil(err) + expectTrue(ok) + + count, err = client.VCard(ctx, vecName).Result() + expectNil(err) + expectEqual(count, 1) + }) + + It("basic similarity", func() { + SkipBeforeRedisVersion(8.0, "Redis 8.0 introduces support for VectorSet") + vecName := "basic_similarity" + + ok, err := client.VAdd(ctx, vecName, "k1", &redis.VectorValues{ + Val: []float64{1, 0, 0, 0}, + }).Result() + expectNil(err) + expectTrue(ok) + ok, err = client.VAdd(ctx, vecName, "k2", &redis.VectorValues{ + Val: []float64{0.99, 0.01, 0, 0}, + }).Result() + expectNil(err) + expectTrue(ok) + ok, err = client.VAdd(ctx, vecName, "k3", &redis.VectorValues{ + Val: []float64{0.1, 1, -1, 0.5}, + }).Result() + expectNil(err) + expectTrue(ok) + + sim, err := client.VSimWithScores(ctx, vecName, &redis.VectorValues{ + Val: []float64{1, 0, 0, 0}, + }).Result() + expectNil(err) + expectEqual(len(sim), 3) + simMap := make(map[string]float64) + for _, vi := range sim { + simMap[vi.Name] = vi.Score + } + expectTrue(simMap["k1"] > 0.99) + expectTrue(simMap["k2"] > 0.99) + expectTrue(simMap["k3"] < 0.8) + }) + + It("dimension operation", func() { + SkipBeforeRedisVersion(8.0, "Redis 8.0 introduces support for VectorSet") + vecName := "dimension_op" + originalDim := 100 + reducedDim := 50 + + v1 := generateRandomVector(originalDim) + ok, err := client.VAddWithArgs(ctx, vecName, "k1", &v1, &redis.VAddArgs{ + Reduce: int64(reducedDim), + }).Result() + expectNil(err) + expectTrue(ok) + + info, err := client.VInfo(ctx, vecName).Result() + expectNil(err) + dim := info["vector-dim"].(int64) + oriDim := info["projection-input-dim"].(int64) + expectEqual(dim, reducedDim) + expectEqual(oriDim, originalDim) + + wrongDim := 80 + wrongV := generateRandomVector(wrongDim) + _, err = client.VAddWithArgs(ctx, vecName, "kw", &wrongV, &redis.VAddArgs{ + Reduce: int64(reducedDim), + }).Result() + expectTrue(err != nil) + + v2 := generateRandomVector(originalDim) + ok, err = client.VAddWithArgs(ctx, vecName, "k2", &v2, &redis.VAddArgs{ + Reduce: int64(reducedDim), + }).Result() + expectNil(err) + expectTrue(ok) + }) + + It("remove", func() { + SkipBeforeRedisVersion(8.0, "Redis 8.0 introduces support for VectorSet") + vecName := "remove" + v1 := generateRandomVector(5) + ok, err := client.VAdd(ctx, vecName, "k1", &v1).Result() + expectNil(err) + expectTrue(ok) + + exist, err := client.Exists(ctx, vecName).Result() + expectNil(err) + expectEqual(exist, 1) + + ok, err = client.VRem(ctx, vecName, "k1").Result() + expectNil(err) + expectTrue(ok) + + exist, err = client.Exists(ctx, vecName).Result() + expectNil(err) + expectEqual(exist, 0) + }) + + It("all operations", func() { + SkipBeforeRedisVersion(8.0, "Redis 8.0 introduces support for VectorSet") + vecName := "commands" + vals := []struct { + name string + v redis.VectorValues + attr string + }{ + { + name: "k0", + v: redis.VectorValues{Val: []float64{1, 0, 0, 0}}, + attr: `{"age": 25, "name": "Alice", "active": true, "scores": [85, 90, 95], "city": "New York"}`, + }, + { + name: "k1", + v: redis.VectorValues{Val: []float64{0, 1, 0, 0}}, + attr: `{"age": 30, "name": "Bob", "active": false, "scores": [70, 75, 80], "city": "Boston"}`, + }, + { + name: "k2", + v: redis.VectorValues{Val: []float64{0, 0, 1, 0}}, + attr: `{"age": 35, "name": "Charlie", "scores": [60, 65, 70], "city": "Seattle"}`, + }, + { + name: "k3", + v: redis.VectorValues{Val: []float64{0, 0, 0, 1}}, + }, + { + name: "k4", + v: redis.VectorValues{Val: []float64{0.5, 0.5, 0, 0}}, + attr: `invalid json`, + }, + } + + // If the key doesn't exist, return null error + _, err := client.VRandMember(ctx, vecName).Result() + expectEqual(err.Error(), proto.Nil.Error()) + + // If the key doesn't exist, return an empty array + res, err := client.VRandMemberCount(ctx, vecName, 3).Result() + expectNil(err) + expectEqual(len(res), 0) + + for _, v := range vals { + ok, err := client.VAdd(ctx, vecName, v.name, &v.v).Result() + expectNil(err) + expectTrue(ok) + if len(v.attr) > 0 { + ok, err = client.VSetAttr(ctx, vecName, v.name, v.attr).Result() + expectNil(err) + expectTrue(ok) + } + } + + // VGetAttr + attr, err := client.VGetAttr(ctx, vecName, vals[1].name).Result() + expectNil(err) + expectEqual(attr, vals[1].attr) + + // VRandMember + _, err = client.VRandMember(ctx, vecName).Result() + expectNil(err) + + res, err = client.VRandMemberCount(ctx, vecName, 3).Result() + expectNil(err) + expectEqual(len(res), 3) + + res, err = client.VRandMemberCount(ctx, vecName, 10).Result() + expectNil(err) + expectEqual(len(res), len(vals)) + + // test equality + sim, err := client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{ + Filter: `.age == 25`, + }).Result() + expectNil(err) + expectEqual(len(sim), 1) + expectEqual(sim[0], vals[0].name) + + // test greater than + sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{ + Filter: `.age > 25`, + }).Result() + expectNil(err) + expectEqual(len(sim), 2) + + // test less than or equal + sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{ + Filter: `.age <= 30`, + }).Result() + expectNil(err) + expectEqual(len(sim), 2) + + // test string equality + sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{ + Filter: `.name == "Alice"`, + }).Result() + expectNil(err) + expectEqual(len(sim), 1) + expectEqual(sim[0], vals[0].name) + + // test string inequality + sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{ + Filter: `.name != "Alice"`, + }).Result() + expectNil(err) + expectEqual(len(sim), 2) + + // test bool + sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{ + Filter: `.active`, + }).Result() + expectNil(err) + expectEqual(len(sim), 1) + expectEqual(sim[0], vals[0].name) + + // test logical add + sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{ + Filter: `.age > 20 and .age < 30`, + }).Result() + expectNil(err) + expectEqual(len(sim), 1) + expectEqual(sim[0], vals[0].name) + + // test logical or + sim, err = client.VSimWithArgs(ctx, vecName, &vals[0].v, &redis.VSimArgs{ + Filter: `.age < 30 or .age > 35`, + }).Result() + expectNil(err) + expectEqual(len(sim), 1) + expectEqual(sim[0], vals[0].name) + }) + }) + } +}) diff --git a/vectorset_commands_test.go b/vectorset_commands_test.go new file mode 100644 index 000000000..9dbc8a78f --- /dev/null +++ b/vectorset_commands_test.go @@ -0,0 +1,542 @@ +package redis + +import ( + "context" + "encoding/json" + "reflect" + "testing" +) + +func TestVectorFP32_Value(t *testing.T) { + v := &VectorFP32{Val: []byte{1, 2, 3}} + got := v.Value() + want := []any{"FP32", []byte{1, 2, 3}} + if !reflect.DeepEqual(got, want) { + t.Errorf("VectorFP32.Value() = %v, want %v", got, want) + } +} + +func TestVectorValues_Value(t *testing.T) { + v := &VectorValues{Val: []float64{1.1, 2.2}} + got := v.Value() + want := []any{"Values", 2, 1.1, 2.2} + if !reflect.DeepEqual(got, want) { + t.Errorf("VectorValues.Value() = %v, want %v", got, want) + } +} + +func TestVectorRef_Value(t *testing.T) { + v := &VectorRef{Name: "foo"} + got := v.Value() + want := []any{"ele", "foo"} + if !reflect.DeepEqual(got, want) { + t.Errorf("VectorRef.Value() = %v, want %v", got, want) + } +} + +func TestVAdd(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + c.VAdd(context.Background(), "k", "e", vec) + cmd, ok := m.lastCmd.(*BoolCmd) + if !ok { + t.Fatalf("expected BoolCmd, got %T", m.lastCmd) + } + if cmd.args[0] != "vadd" || cmd.args[1] != "k" || cmd.args[len(cmd.args)-1] != "e" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVAddWithArgs_AllOptions(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + args := &VAddArgs{Reduce: 3, Cas: true, NoQuant: true, EF: 5, SetAttr: "attr", M: 2} + c.VAddWithArgs(context.Background(), "k", "e", vec, args) + cmd := m.lastCmd.(*BoolCmd) + found := map[string]bool{} + for _, a := range cmd.args { + if s, ok := a.(string); ok { + found[s] = true + } + } + for _, want := range []string{"reduce", "cas", "noquant", "ef", "setattr", "m"} { + if !found[want] { + t.Errorf("missing arg: %s", want) + } + } +} + +func TestVCard(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VCard(context.Background(), "k") + cmd := m.lastCmd.(*IntCmd) + if cmd.args[0] != "vcard" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVDim(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VDim(context.Background(), "k") + cmd := m.lastCmd.(*IntCmd) + if cmd.args[0] != "vdim" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVEmb(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VEmb(context.Background(), "k", "e", true) + cmd := m.lastCmd.(*SliceCmd) + if cmd.args[0] != "vemb" || cmd.args[1] != "k" || cmd.args[2] != "e" || cmd.args[3] != "raw" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVGetAttr(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VGetAttr(context.Background(), "k", "e") + cmd := m.lastCmd.(*StringCmd) + if cmd.args[0] != "vgetattr" || cmd.args[1] != "k" || cmd.args[2] != "e" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVInfo(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VInfo(context.Background(), "k") + cmd := m.lastCmd.(*MapStringInterfaceCmd) + if cmd.args[0] != "vinfo" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVLinks(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VLinks(context.Background(), "k", "e") + cmd := m.lastCmd.(*StringSliceCmd) + if cmd.args[0] != "vlinks" || cmd.args[1] != "k" || cmd.args[2] != "e" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVLinksWithScores(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VLinksWithScores(context.Background(), "k", "e") + cmd := m.lastCmd.(*VectorScoreSliceCmd) + if cmd.args[0] != "vlinks" || cmd.args[1] != "k" || cmd.args[2] != "e" || cmd.args[3] != "withscores" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVRandMember(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VRandMember(context.Background(), "k") + cmd := m.lastCmd.(*StringCmd) + if cmd.args[0] != "vrandmember" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVRandMemberCount(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VRandMemberCount(context.Background(), "k", 5) + cmd := m.lastCmd.(*StringSliceCmd) + if cmd.args[0] != "vrandmember" || cmd.args[1] != "k" || cmd.args[2] != 5 { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVRem(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VRem(context.Background(), "k", "e") + cmd := m.lastCmd.(*BoolCmd) + if cmd.args[0] != "vrem" || cmd.args[1] != "k" || cmd.args[2] != "e" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVSetAttr_String(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VSetAttr(context.Background(), "k", "e", "foo") + cmd := m.lastCmd.(*BoolCmd) + if cmd.args[0] != "vsetattr" || cmd.args[1] != "k" || cmd.args[2] != "e" || cmd.args[3] != "foo" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVSetAttr_Bytes(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VSetAttr(context.Background(), "k", "e", []byte("bar")) + cmd := m.lastCmd.(*BoolCmd) + if cmd.args[3] != "bar" { + t.Errorf("expected 'bar', got %v", cmd.args[3]) + } +} + +func TestVSetAttr_MarshalStruct(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + val := struct{ A int }{A: 1} + c.VSetAttr(context.Background(), "k", "e", val) + cmd := m.lastCmd.(*BoolCmd) + want, _ := json.Marshal(val) + if cmd.args[3] != string(want) { + t.Errorf("expected marshalled struct, got %v", cmd.args[3]) + } +} + +func TestVSetAttr_MarshalError(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + bad := func() {} + cmd := c.VSetAttr(context.Background(), "k", "e", bad) + if cmd.Err() == nil { + t.Error("expected error for non-marshallable value") + } +} + +func TestVClearAttributes(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VClearAttributes(context.Background(), "k", "e") + cmd := m.lastCmd.(*BoolCmd) + if cmd.args[0] != "vsetattr" || cmd.args[3] != "" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVSim(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + c.VSim(context.Background(), "k", vec) + cmd := m.lastCmd.(*StringSliceCmd) + if cmd.args[0] != "vsim" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVSimWithScores(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + c.VSimWithScores(context.Background(), "k", vec) + cmd := m.lastCmd.(*VectorScoreSliceCmd) + if cmd.args[0] != "vsim" || cmd.args[1] != "k" || cmd.args[len(cmd.args)-1] != "withscores" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVSimWithArgs_AllOptions(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + args := &VSimArgs{Count: 2, EF: 3, Filter: "f", FilterEF: 4, Truth: true, NoThread: true} + c.VSimWithArgs(context.Background(), "k", vec, args) + cmd := m.lastCmd.(*StringSliceCmd) + found := map[string]bool{} + for _, a := range cmd.args { + if s, ok := a.(string); ok { + found[s] = true + } + } + for _, want := range []string{"count", "ef", "filter", "filter-ef", "truth", "nothread"} { + if !found[want] { + t.Errorf("missing arg: %s", want) + } + } +} + +func TestVSimWithArgsWithScores_AllOptions(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + args := &VSimArgs{Count: 2, EF: 3, Filter: "f", FilterEF: 4, Truth: true, NoThread: true} + c.VSimWithArgsWithScores(context.Background(), "k", vec, args) + cmd := m.lastCmd.(*VectorScoreSliceCmd) + found := map[string]bool{} + for _, a := range cmd.args { + if s, ok := a.(string); ok { + found[s] = true + } + } + for _, want := range []string{"count", "ef", "filter", "filter-ef", "truth", "nothread", "withscores"} { + if !found[want] { + t.Errorf("missing arg: %s", want) + } + } +} + +// Additional tests for missing coverage + +func TestVectorValues_EmptySlice(t *testing.T) { + v := &VectorValues{Val: []float64{}} + got := v.Value() + want := []any{"Values", 0} + if !reflect.DeepEqual(got, want) { + t.Errorf("VectorValues.Value() with empty slice = %v, want %v", got, want) + } +} + +func TestVEmb_WithoutRaw(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + c.VEmb(context.Background(), "k", "e", false) + cmd := m.lastCmd.(*SliceCmd) + if cmd.args[0] != "vemb" || cmd.args[1] != "k" || cmd.args[2] != "e" { + t.Errorf("unexpected args: %v", cmd.args) + } + if len(cmd.args) != 3 { + t.Errorf("expected 3 args when raw=false, got %d", len(cmd.args)) + } +} + +func TestVAddWithArgs_Q8Option(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + args := &VAddArgs{Q8: true} + c.VAddWithArgs(context.Background(), "k", "e", vec, args) + cmd := m.lastCmd.(*BoolCmd) + found := false + for _, a := range cmd.args { + if s, ok := a.(string); ok && s == "q8" { + found = true + break + } + } + if !found { + t.Error("missing q8 arg") + } +} + +func TestVAddWithArgs_BinOption(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + args := &VAddArgs{Bin: true} + c.VAddWithArgs(context.Background(), "k", "e", vec, args) + cmd := m.lastCmd.(*BoolCmd) + found := false + for _, a := range cmd.args { + if s, ok := a.(string); ok && s == "bin" { + found = true + break + } + } + if !found { + t.Error("missing bin arg") + } +} + +func TestVAddWithArgs_NilArgs(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + c.VAddWithArgs(context.Background(), "k", "e", vec, nil) + cmd := m.lastCmd.(*BoolCmd) + if cmd.args[0] != "vadd" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVSimWithArgs_NilArgs(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + c.VSimWithArgs(context.Background(), "k", vec, nil) + cmd := m.lastCmd.(*StringSliceCmd) + if cmd.args[0] != "vsim" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } +} + +func TestVSimWithArgsWithScores_NilArgs(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + c.VSimWithArgsWithScores(context.Background(), "k", vec, nil) + cmd := m.lastCmd.(*VectorScoreSliceCmd) + if cmd.args[0] != "vsim" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } + // Should still have withscores + found := false + for _, a := range cmd.args { + if s, ok := a.(string); ok && s == "withscores" { + found = true + break + } + } + if !found { + t.Error("missing withscores arg") + } +} + +func TestVAdd_WithVectorFP32(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorFP32{Val: []byte{1, 2, 3, 4}} + c.VAdd(context.Background(), "k", "e", vec) + cmd := m.lastCmd.(*BoolCmd) + if cmd.args[0] != "vadd" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } + // Check that FP32 format is used + found := false + for _, a := range cmd.args { + if s, ok := a.(string); ok && s == "FP32" { + found = true + break + } + } + if !found { + t.Error("missing FP32 format in args") + } +} + +func TestVAdd_WithVectorRef(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorRef{Name: "ref-vector"} + c.VAdd(context.Background(), "k", "e", vec) + cmd := m.lastCmd.(*BoolCmd) + if cmd.args[0] != "vadd" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } + // Check that ele format is used + found := false + for _, a := range cmd.args { + if s, ok := a.(string); ok && s == "ele" { + found = true + break + } + } + if !found { + t.Error("missing ele format in args") + } +} + +func TestVSim_WithVectorFP32(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorFP32{Val: []byte{1, 2, 3, 4}} + c.VSim(context.Background(), "k", vec) + cmd := m.lastCmd.(*StringSliceCmd) + if cmd.args[0] != "vsim" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } + // Check that FP32 format is used + found := false + for _, a := range cmd.args { + if s, ok := a.(string); ok && s == "FP32" { + found = true + break + } + } + if !found { + t.Error("missing FP32 format in args") + } +} + +func TestVSim_WithVectorRef(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorRef{Name: "ref-vector"} + c.VSim(context.Background(), "k", vec) + cmd := m.lastCmd.(*StringSliceCmd) + if cmd.args[0] != "vsim" || cmd.args[1] != "k" { + t.Errorf("unexpected args: %v", cmd.args) + } + // Check that ele format is used + found := false + for _, a := range cmd.args { + if s, ok := a.(string); ok && s == "ele" { + found = true + break + } + } + if !found { + t.Error("missing ele format in args") + } +} + +func TestVAddWithArgs_ReduceOption(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + args := &VAddArgs{Reduce: 128} + c.VAddWithArgs(context.Background(), "k", "e", vec, args) + cmd := m.lastCmd.(*BoolCmd) + // Check that reduce appears early in args (after key) + if cmd.args[0] != "vadd" || cmd.args[1] != "k" || cmd.args[2] != "reduce" { + t.Errorf("unexpected args order: %v", cmd.args) + } +} + +func TestVAddWithArgs_ZeroValues(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + args := &VAddArgs{Reduce: 0, EF: 0, M: 0} // Zero values should not appear in args + c.VAddWithArgs(context.Background(), "k", "e", vec, args) + cmd := m.lastCmd.(*BoolCmd) + // Check that zero values don't appear + for _, a := range cmd.args { + if s, ok := a.(string); ok { + if s == "reduce" || s == "ef" || s == "m" { + t.Errorf("zero value option should not appear in args: %s", s) + } + } + } +} + +func TestVSimArgs_IndividualOptions(t *testing.T) { + tests := []struct { + name string + args *VSimArgs + want string + }{ + {"Count", &VSimArgs{Count: 5}, "count"}, + {"EF", &VSimArgs{EF: 10}, "ef"}, + {"Filter", &VSimArgs{Filter: "test"}, "filter"}, + {"FilterEF", &VSimArgs{FilterEF: 15}, "filter-ef"}, + {"Truth", &VSimArgs{Truth: true}, "truth"}, + {"NoThread", &VSimArgs{NoThread: true}, "nothread"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &mockCmdable{} + c := m.asCmdable() + vec := &VectorValues{Val: []float64{1, 2}} + c.VSimWithArgs(context.Background(), "k", vec, tt.args) + cmd := m.lastCmd.(*StringSliceCmd) + found := false + for _, a := range cmd.args { + if s, ok := a.(string); ok && s == tt.want { + found = true + break + } + } + if !found { + t.Errorf("missing arg: %s", tt.want) + } + }) + } +}