Skip to content
44 changes: 44 additions & 0 deletions state/redis/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const (
infoReplicationDelimiter = "\r\n"
maxRetries = "maxRetries"
maxRetryBackoff = "maxRetryBackoff"
ttlInSeconds = "ttlInSeconds"
defaultBase = 10
defaultBitSize = 0
defaultDB = 0
Expand Down Expand Up @@ -230,6 +231,10 @@ func (r *StateStore) setValue(req *state.SetRequest) error {
if err != nil {
return err
}
ttl, err := r.parseTTL(req)
if err != nil {
return fmt.Errorf("failed to parse ttl from metadata: %s", err)
}

bt, _ := utils.Marshal(req.Value, r.json.Marshal)

Expand All @@ -242,6 +247,20 @@ func (r *StateStore) setValue(req *state.SetRequest) error {
return fmt.Errorf("failed to set key %s: %s", req.Key, err)
}

if ttl != nil && *ttl > 0 {
_, err = r.client.Do(r.ctx, "EXPIRE", req.Key, *ttl).Result()
if err != nil {
return fmt.Errorf("failed to set key %s ttl: %s", req.Key, err)
}
}

if ttl != nil && *ttl <= 0 {
_, err = r.client.Do(r.ctx, "PERSIST", req.Key).Result()
if err != nil {
return fmt.Errorf("failed to persist key %s: %s", req.Key, err)
}
}

if req.Options.Consistency == state.Strong && r.replicas > 0 {
_, err = r.client.Do(r.ctx, "WAIT", r.replicas, 1000).Result()
if err != nil {
Expand All @@ -261,14 +280,25 @@ func (r *StateStore) Set(req *state.SetRequest) error {
func (r *StateStore) Multi(request *state.TransactionalStateRequest) error {
pipe := r.client.TxPipeline()
for _, o := range request.Operations {
//nolint:golint,nestif
if o.Operation == state.Upsert {
req := o.Request.(state.SetRequest)
ver, err := r.parseETag(&req)
if err != nil {
return err
}
ttl, err := r.parseTTL(&req)
if err != nil {
return fmt.Errorf("failed to parse ttl from metadata: %s", err)
}
bt, _ := utils.Marshal(req.Value, r.json.Marshal)
pipe.Do(r.ctx, "EVAL", setQuery, 1, req.Key, ver, bt)
if ttl != nil && *ttl > 0 {
pipe.Do(r.ctx, "EXPIRE", req.Key, *ttl)
}
if ttl != nil && *ttl <= 0 {
pipe.Do(r.ctx, "PERSIST", req.Key)
}
} else if o.Operation == state.Delete {
req := o.Request.(state.DeleteRequest)
if req.ETag == nil {
Expand Down Expand Up @@ -318,6 +348,20 @@ func (r *StateStore) parseETag(req *state.SetRequest) (int, error) {
return ver, nil
}

func (r *StateStore) parseTTL(req *state.SetRequest) (*int, error) {
if val, ok := req.Metadata[ttlInSeconds]; ok && val != "" {
parsedVal, err := strconv.ParseInt(val, defaultBase, defaultBitSize)
if err != nil {
return nil, err
}
ttl := int(parsedVal)

return &ttl, nil
}

return nil, nil
}

func (r *StateStore) Close() error {
r.cancel()

Expand Down
163 changes: 157 additions & 6 deletions state/redis/redis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ package redis

import (
"context"
"strconv"
"testing"
"time"

"github.com/agrea/ptr"
miniredis "github.com/alicebob/miniredis/v2"
Expand Down Expand Up @@ -90,6 +92,51 @@ func TestParseEtag(t *testing.T) {
})
}

func TestParseTTL(t *testing.T) {
store := NewRedisStateStore(logger.NewLogger("test"))
t.Run("TTL Not an integer", func(t *testing.T) {
ttlInSeconds := "not an integer"
ttl, err := store.parseTTL(&state.SetRequest{
Metadata: map[string]string{
"ttlInSeconds": ttlInSeconds,
},
})
assert.Error(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL specified with wrong key", func(t *testing.T) {
ttlInSeconds := 12345
ttl, err := store.parseTTL(&state.SetRequest{
Metadata: map[string]string{
"expirationTime": strconv.Itoa(ttlInSeconds),
},
})
assert.NoError(t, err)
assert.Nil(t, ttl)
})
t.Run("TTL is a number", func(t *testing.T) {
ttlInSeconds := 12345
ttl, err := store.parseTTL(&state.SetRequest{
Metadata: map[string]string{
"ttlInSeconds": strconv.Itoa(ttlInSeconds),
},
})
assert.NoError(t, err)
assert.Equal(t, *ttl, ttlInSeconds)
})

t.Run("TTL never expires", func(t *testing.T) {
ttlInSeconds := -1
ttl, err := store.parseTTL(&state.SetRequest{
Metadata: map[string]string{
"ttlInSeconds": strconv.Itoa(ttlInSeconds),
},
})
assert.NoError(t, err)
assert.Equal(t, *ttl, ttlInSeconds)
})
}

func TestParseConnectedSlavs(t *testing.T) {
store := NewRedisStateStore(logger.NewLogger("test"))

Expand Down Expand Up @@ -126,13 +173,35 @@ func TestTransactionalUpsert(t *testing.T) {
ss.ctx, ss.cancel = context.WithCancel(context.Background())

err := ss.Multi(&state.TransactionalStateRequest{
Operations: []state.TransactionalStateOperation{{
Operation: state.Upsert,
Request: state.SetRequest{
Key: "weapon",
Value: "deathstar",
Operations: []state.TransactionalStateOperation{
{
Operation: state.Upsert,
Request: state.SetRequest{
Key: "weapon",
Value: "deathstar",
},
},
}},
{
Operation: state.Upsert,
Request: state.SetRequest{
Key: "weapon2",
Value: "deathstar2",
Metadata: map[string]string{
"ttlInSeconds": "123",
},
},
},
{
Operation: state.Upsert,
Request: state.SetRequest{
Key: "weapon3",
Value: "deathstar3",
Metadata: map[string]string{
"ttlInSeconds": "-1",
},
},
},
},
})
assert.Equal(t, nil, err)

Expand All @@ -144,6 +213,18 @@ func TestTransactionalUpsert(t *testing.T) {
assert.Equal(t, nil, err)
assert.Equal(t, ptr.String("1"), version)
assert.Equal(t, `"deathstar"`, data)

res, err = c.Do(context.Background(), "TTL", "weapon").Result()
assert.Equal(t, nil, err)
assert.Equal(t, int64(-1), res)

res, err = c.Do(context.Background(), "TTL", "weapon2").Result()
assert.Equal(t, nil, err)
assert.Equal(t, int64(123), res)

res, err = c.Do(context.Background(), "TTL", "weapon3").Result()
assert.Equal(t, nil, err)
assert.Equal(t, int64(-1), res)
}

func TestTransactionalDelete(t *testing.T) {
Expand Down Expand Up @@ -201,6 +282,76 @@ func TestPing(t *testing.T) {
assert.Error(t, err)
}

func TestSetRequestWithTTL(t *testing.T) {
s, c := setupMiniredis()
defer s.Close()

ss := &StateStore{
client: c,
json: jsoniter.ConfigFastest,
logger: logger.NewLogger("test"),
}
ss.ctx, ss.cancel = context.WithCancel(context.Background())

t.Run("TTL specified", func(t *testing.T) {
ttlInSeconds := 100
ss.Set(&state.SetRequest{
Key: "weapon100",
Value: "deathstar100",
Metadata: map[string]string{
"ttlInSeconds": strconv.Itoa(ttlInSeconds),
},
})

ttl, _ := ss.client.TTL(ss.ctx, "weapon100").Result()

assert.Equal(t, time.Duration(ttlInSeconds)*time.Second, ttl)
})

t.Run("TTL not specified", func(t *testing.T) {
ss.Set(&state.SetRequest{
Key: "weapon200",
Value: "deathstar200",
})

ttl, _ := ss.client.TTL(ss.ctx, "weapon200").Result()

assert.Equal(t, time.Duration(-1), ttl)
})

t.Run("TTL Changed for Existing Key", func(t *testing.T) {
ss.Set(&state.SetRequest{
Key: "weapon300",
Value: "deathstar300",
})
ttl, _ := ss.client.TTL(ss.ctx, "weapon300").Result()
assert.Equal(t, time.Duration(-1), ttl)

// make the key no longer persistent
ttlInSeconds := 123
ss.Set(&state.SetRequest{
Key: "weapon300",
Value: "deathstar300",
Metadata: map[string]string{
"ttlInSeconds": strconv.Itoa(ttlInSeconds),
},
})
ttl, _ = ss.client.TTL(ss.ctx, "weapon300").Result()
assert.Equal(t, time.Duration(ttlInSeconds)*time.Second, ttl)

// make the key persistent again
ss.Set(&state.SetRequest{
Key: "weapon300",
Value: "deathstar301",
Metadata: map[string]string{
"ttlInSeconds": strconv.Itoa(-1),
},
})
ttl, _ = ss.client.TTL(ss.ctx, "weapon300").Result()
assert.Equal(t, time.Duration(-1), ttl)
})
}

func TestTransactionalDeleteNoEtag(t *testing.T) {
s, c := setupMiniredis()
defer s.Close()
Expand Down