diff --git a/go/api/base_client.go b/go/api/base_client.go index 59d25c53d5..632625ae67 100644 --- a/go/api/base_client.go +++ b/go/api/base_client.go @@ -7289,7 +7289,12 @@ func (client *baseClient) ZInter(keys options.KeyArray) ([]string, error) { // // Parameters: // +// keysOrWeightedKeys - The keys or weighted keys of the sorted sets, see - [options.KeysOrWeightedKeys]. +// - Use `options.NewKeyArray()` for keys only. +// - Use `options.NewWeightedKeys()` for weighted keys with score multipliers. // options - The options for the ZInter command, see - [options.ZInterOptions]. +// Optional `aggregate` option specifies the aggregation strategy to apply when combining the scores of +// elements. // // Return value: // @@ -7301,11 +7306,16 @@ func (client *baseClient) ZInter(keys options.KeyArray) ([]string, error) { // fmt.Println(res) // map[member1:1.0 member2:2.0 member3:3.0] // // [valkey.io]: https://valkey.io/commands/zinter/ -func (client *baseClient) ZInterWithScores(zInterOptions *options.ZInterOptions) (map[string]float64, error) { - args, err := zInterOptions.ToArgs() +func (client *baseClient) ZInterWithScores( + keysOrWeightedKeys options.KeysOrWeightedKeys, + zInterOptions *options.ZInterOptions, +) (map[string]float64, error) { + args := keysOrWeightedKeys.ToArgs() + optionsArgs, err := zInterOptions.ToArgs() if err != nil { return nil, err } + args = append(args, optionsArgs...) args = append(args, options.WithScores) result, err := client.executeCommand(C.ZInter, args) if err != nil { @@ -7427,3 +7437,84 @@ func (client *baseClient) ZDiffStore(destination string, keys []string) (int64, } return handleIntResponse(result) } + +// Computes the intersection of sorted sets given by the specified `keysOrWeightedKeys` +// and stores the result in `destination`. If `destination` already exists, it is overwritten. +// Otherwise, a new sorted set will be created. +// +// Note: +// +// When in cluster mode, all keys must map to the same hash slot. +// +// See [valkey.io] for details. +// +// Parameters: +// +// destination - The destination key for the result. +// keysOrWeightedKeys - The keys or weighted keys of the sorted sets, see - [options.KeysOrWeightedKeys]. +// - Use `options.NewKeyArray()` for keys only. +// - Use `options.NewWeightedKeys()` for weighted keys with score multipliers. +// +// Return value: +// +// The number of elements in the resulting sorted set stored at destination. +// +// Example: +// +// res, err := client.ZInterStore("destination", options.NewKeyArray("key1", "key2", "key3")) +// fmt.Println(res) // 3 +// +// [valkey.io]: https://valkey.io/commands/zinterstore/ +func (client *baseClient) ZInterStore(destination string, keysOrWeightedKeys options.KeysOrWeightedKeys) (int64, error) { + return client.ZInterStoreWithOptions(destination, keysOrWeightedKeys, nil) +} + +// Computes the intersection of sorted sets given by the specified `keysOrWeightedKeys` +// and stores the result in `destination`. If `destination` already exists, it is overwritten. +// Otherwise, a new sorted set will be created. +// +// Note: +// +// When in cluster mode, all keys must map to the same hash slot. +// +// See [valkey.io] for details. +// +// Parameters: +// +// destination - The destination key for the result. +// keysOrWeightedKeys - The keys or weighted keys of the sorted sets, see - [options.KeysOrWeightedKeys]. +// - Use `options.NewKeyArray()` for keys only. +// - Use `options.NewWeightedKeys()` for weighted keys with score multipliers. +// options - The options for the ZInterStore command, see - [options.ZInterOptions]. +// Optional `aggregate` option specifies the aggregation strategy to apply when combining the scores of +// elements. +// +// Return value: +// +// The number of elements in the resulting sorted set stored at destination. +// +// Example: +// +// res, err := client.ZInterStore("destination", options.NewZInterOptionsBuilder(options.NewKeyArray("key1", "key2", "key3"))) +// fmt.Println(res) // 3 +// +// [valkey.io]: https://valkey.io/commands/zinterstore/ +func (client *baseClient) ZInterStoreWithOptions( + destination string, + keysOrWeightedKeys options.KeysOrWeightedKeys, + zInterOptions *options.ZInterOptions, +) (int64, error) { + args := append([]string{destination}, keysOrWeightedKeys.ToArgs()...) + if zInterOptions != nil { + optionsArgs, err := zInterOptions.ToArgs() + if err != nil { + return defaultIntResponse, err + } + args = append(args, optionsArgs...) + } + result, err := client.executeCommand(C.ZInterStore, args) + if err != nil { + return defaultIntResponse, err + } + return handleIntResponse(result) +} diff --git a/go/api/options/zinter_options.go b/go/api/options/zinter_options.go index c36bf4ef07..0b0728886d 100644 --- a/go/api/options/zinter_options.go +++ b/go/api/options/zinter_options.go @@ -4,12 +4,11 @@ package options // This struct represents the optional arguments for the ZINTER command. type ZInterOptions struct { - keysOrWeightedKeys KeysOrWeightedKeys - aggregate Aggregate + aggregate Aggregate } -func NewZInterOptionsBuilder(keysOrWeightedKeys KeysOrWeightedKeys) *ZInterOptions { - return &ZInterOptions{keysOrWeightedKeys: keysOrWeightedKeys} +func NewZInterOptionsBuilder() *ZInterOptions { + return &ZInterOptions{} } // SetAggregate sets the aggregate method for the ZInter command. @@ -21,10 +20,6 @@ func (options *ZInterOptions) SetAggregate(aggregate Aggregate) *ZInterOptions { func (options *ZInterOptions) ToArgs() ([]string, error) { args := []string{} - if options.keysOrWeightedKeys != nil { - args = append(args, options.keysOrWeightedKeys.ToArgs()...) - } - if options.aggregate != "" { args = append(args, options.aggregate.ToArgs()...) } diff --git a/go/api/sorted_set_commands.go b/go/api/sorted_set_commands.go index c05ab4a8da..3031d3640c 100644 --- a/go/api/sorted_set_commands.go +++ b/go/api/sorted_set_commands.go @@ -83,9 +83,17 @@ type SortedSetCommands interface { ZMScore(key string, members []string) ([]Result[float64], error) + ZDiffStore(destination string, keys []string) (int64, error) + ZInter(keys options.KeyArray) ([]string, error) - ZInterWithScores(options *options.ZInterOptions) (map[string]float64, error) + ZInterWithScores(keysOrWeightedKeys options.KeysOrWeightedKeys, options *options.ZInterOptions) (map[string]float64, error) - ZDiffStore(destination string, keys []string) (int64, error) + ZInterStore(destination string, keysOrWeightedKeys options.KeysOrWeightedKeys) (int64, error) + + ZInterStoreWithOptions( + destination string, + keysOrWeightedKeys options.KeysOrWeightedKeys, + options *options.ZInterOptions, + ) (int64, error) } diff --git a/go/integTest/shared_commands_test.go b/go/integTest/shared_commands_test.go index c035402572..bda0018aba 100644 --- a/go/integTest/shared_commands_test.go +++ b/go/integTest/shared_commands_test.go @@ -7888,56 +7888,61 @@ func (suite *GlideTestSuite) TestZInter() { assert.Equal(suite.T(), []string{"two"}, zinterResult) // intersection with scores - zinterWithScoresResult, err := client.ZInterWithScores( - options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key2}}).SetAggregate(options.AggregateSum), + zinterWithScoresResult, err := client.ZInterWithScores(options.KeyArray{Keys: []string{key1, key2}}, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateSum), ) assert.NoError(suite.T(), err) assert.Equal(suite.T(), map[string]float64{"two": 5.5}, zinterWithScoresResult) // intersect results with max aggregate zinterWithMaxAggregateResult, err := client.ZInterWithScores( - options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key2}}).SetAggregate(options.AggregateMax), + options.KeyArray{Keys: []string{key1, key2}}, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateMax), ) assert.NoError(suite.T(), err) assert.Equal(suite.T(), map[string]float64{"two": 3.5}, zinterWithMaxAggregateResult) // intersect results with min aggregate zinterWithMinAggregateResult, err := client.ZInterWithScores( - options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key2}}).SetAggregate(options.AggregateMin), + options.KeyArray{Keys: []string{key1, key2}}, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateMin), ) assert.NoError(suite.T(), err) assert.Equal(suite.T(), map[string]float64{"two": 2.0}, zinterWithMinAggregateResult) // intersect results with sum aggregate zinterWithSumAggregateResult, err := client.ZInterWithScores( - options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key2}}).SetAggregate(options.AggregateSum), + options.KeyArray{Keys: []string{key1, key2}}, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateSum), ) assert.NoError(suite.T(), err) assert.Equal(suite.T(), map[string]float64{"two": 5.5}, zinterWithSumAggregateResult) // Scores are multiplied by a 2.0 weight for key1 and key2 during aggregation zinterWithWeightedKeysResult, err := client.ZInterWithScores( - options.NewZInterOptionsBuilder( - options.WeightedKeys{ - KeyWeightPairs: []options.KeyWeightPair{ - {Key: key1, Weight: 2.0}, - {Key: key2, Weight: 2.0}, - }, + options.WeightedKeys{ + KeyWeightPairs: []options.KeyWeightPair{ + {Key: key1, Weight: 2.0}, + {Key: key2, Weight: 2.0}, }, - ).SetAggregate(options.AggregateSum), + }, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateSum), ) assert.NoError(suite.T(), err) assert.Equal(suite.T(), map[string]float64{"two": 11.0}, zinterWithWeightedKeysResult) // non-existent key - empty intersection zinterWithNonExistentKeyResult, err := client.ZInterWithScores( - options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key3}}).SetAggregate(options.AggregateSum), + options.KeyArray{Keys: []string{key1, key3}}, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateSum), ) assert.NoError(suite.T(), err) assert.Empty(suite.T(), zinterWithNonExistentKeyResult) // empty key list - request error - _, err = client.ZInterWithScores(options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{}})) + _, err = client.ZInterWithScores(options.KeyArray{Keys: []string{}}, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateSum), + ) assert.NotNil(suite.T(), err) assert.IsType(suite.T(), &errors.RequestError{}, err) @@ -7950,8 +7955,124 @@ func (suite *GlideTestSuite) TestZInter() { assert.IsType(suite.T(), &errors.RequestError{}, err) _, err = client.ZInterWithScores( - options.NewZInterOptionsBuilder(options.KeyArray{Keys: []string{key1, key3}}).SetAggregate(options.AggregateSum), + options.KeyArray{Keys: []string{key1, key3}}, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateSum), + ) + assert.NotNil(suite.T(), err) + assert.IsType(suite.T(), &errors.RequestError{}, err) + }) +} + +func (suite *GlideTestSuite) TestZInterStore() { + suite.runWithDefaultClients(func(client api.BaseClient) { + key1 := "{key}-" + uuid.New().String() + key2 := "{key}-" + uuid.New().String() + key3 := "{key}-" + uuid.New().String() + key4 := "{key}-" + uuid.New().String() + query := options.NewRangeByIndexQuery(0, -1) + memberScoreMap1 := map[string]float64{ + "one": 1.0, + "two": 2.0, + } + memberScoreMap2 := map[string]float64{ + "one": 1.5, + "two": 2.5, + "three": 3.5, + } + + // Add members to sorted sets + res, err := client.ZAdd(key1, memberScoreMap1) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), int64(2), res) + + res, err = client.ZAdd(key2, memberScoreMap2) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), int64(3), res) + + // Store the intersection of key1 and key2 in key3 + res, err = client.ZInterStore(key3, options.KeyArray{Keys: []string{key1, key2}}) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), int64(2), res) + + // checking stored intersection result + zrangeResult, err := client.ZRangeWithScores(key3, query) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"one": 2.5, "two": 4.5}, zrangeResult) + + // Store the intersection of key1 and key2 in key4 with max aggregate + res, err = client.ZInterStoreWithOptions(key3, options.KeyArray{Keys: []string{key1, key2}}, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateMax), ) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), int64(2), res) + + // checking stored intersection result with max aggregate + zrangeResult, err = client.ZRangeWithScores(key3, query) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"one": 1.5, "two": 2.5}, zrangeResult) + + // Store the intersection of key1 and key2 in key5 with min aggregate + res, err = client.ZInterStoreWithOptions(key3, options.KeyArray{Keys: []string{key1, key2}}, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateMin), + ) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), int64(2), res) + + // checking stored intersection result with min aggregate + zrangeResult, err = client.ZRangeWithScores(key3, query) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"one": 1.0, "two": 2.0}, zrangeResult) + + // Store the intersection of key1 and key2 in key6 with sum aggregate + res, err = client.ZInterStoreWithOptions(key3, options.KeyArray{Keys: []string{key1, key2}}, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateSum), + ) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), int64(2), res) + + // checking stored intersection result with sum aggregate (same as default aggregate) + zrangeResult, err = client.ZRangeWithScores(key3, query) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"one": 2.5, "two": 4.5}, zrangeResult) + + // Store the intersection of key1 and key2 in key3 with 2.0 weights + res, err = client.ZInterStore(key3, options.WeightedKeys{ + KeyWeightPairs: []options.KeyWeightPair{ + {Key: key1, Weight: 2.0}, + {Key: key2, Weight: 2.0}, + }, + }) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), int64(2), res) + + // checking stored intersection result with weighted keys + zrangeResult, err = client.ZRangeWithScores(key3, query) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"one": 5.0, "two": 9.0}, zrangeResult) + + // Store the intersection of key1 with 1.0 weight and key2 with -2.0 weight in key3 with 2.0 weights + // and min aggregate + res, err = client.ZInterStoreWithOptions(key3, options.WeightedKeys{ + KeyWeightPairs: []options.KeyWeightPair{ + {Key: key1, Weight: 1.0}, + {Key: key2, Weight: -2.0}, + }, + }, + options.NewZInterOptionsBuilder().SetAggregate(options.AggregateMin), + ) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), int64(2), res) + + // checking stored intersection result with weighted keys + zrangeResult, err = client.ZRangeWithScores(key3, query) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), map[string]float64{"one": -3.0, "two": -5.0}, zrangeResult) + + // key exists but not a set + _, err = client.Set(key4, "value") + assert.NoError(suite.T(), err) + + _, err = client.ZInterStore(key3, options.KeyArray{Keys: []string{key1, key4}}) assert.NotNil(suite.T(), err) assert.IsType(suite.T(), &errors.RequestError{}, err) })