diff --git a/store/ristretto/ristretto.go b/store/ristretto/ristretto.go index 146ae8db..a24d73a9 100644 --- a/store/ristretto/ristretto.go +++ b/store/ristretto/ristretto.go @@ -20,6 +20,7 @@ const ( // RistrettoClientInterface represents a dgraph-io/ristretto client type RistrettoClientInterface interface { Get(key any) (any, bool) + GetTTL(key any) (time.Duration, bool) SetWithTTL(key, value any, cost int64, ttl time.Duration) bool Del(key any) Clear() @@ -55,7 +56,11 @@ func (s *RistrettoStore) Get(_ context.Context, key any) (any, error) { // GetWithTTL returns data stored from a given key and its corresponding TTL func (s *RistrettoStore) GetWithTTL(ctx context.Context, key any) (any, time.Duration, error) { value, err := s.Get(ctx, key) - return value, 0, err + if err != nil { + return value, 0, err + } + ttl, _ := s.client.GetTTL(key) + return value, ttl, nil } // Set defines data in Ristretto memory cache for given key identifier diff --git a/store/ristretto/ristretto_mock.go b/store/ristretto/ristretto_mock.go index 1c96d9b8..7ab5e6d4 100644 --- a/store/ristretto/ristretto_mock.go +++ b/store/ristretto/ristretto_mock.go @@ -1,5 +1,10 @@ // Code generated by MockGen. DO NOT EDIT. // Source: store/ristretto/ristretto.go +// +// Generated by this command: +// +// mockgen -source=store/ristretto/ristretto.go -destination=store/ristretto/ristretto_mock.go -package=ristretto +// // Package ristretto is a generated GoMock package. package ristretto @@ -15,6 +20,7 @@ import ( type MockRistrettoClientInterface struct { ctrl *gomock.Controller recorder *MockRistrettoClientInterfaceMockRecorder + isgomock struct{} } // MockRistrettoClientInterfaceMockRecorder is the mock recorder for MockRistrettoClientInterface. @@ -53,7 +59,7 @@ func (m *MockRistrettoClientInterface) Del(key any) { } // Del indicates an expected call of Del. -func (mr *MockRistrettoClientInterfaceMockRecorder) Del(key interface{}) *gomock.Call { +func (mr *MockRistrettoClientInterfaceMockRecorder) Del(key any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Del", reflect.TypeOf((*MockRistrettoClientInterface)(nil).Del), key) } @@ -68,11 +74,26 @@ func (m *MockRistrettoClientInterface) Get(key any) (any, bool) { } // Get indicates an expected call of Get. -func (mr *MockRistrettoClientInterfaceMockRecorder) Get(key interface{}) *gomock.Call { +func (mr *MockRistrettoClientInterfaceMockRecorder) Get(key any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockRistrettoClientInterface)(nil).Get), key) } +// GetTTL mocks base method. +func (m *MockRistrettoClientInterface) GetTTL(key any) (time.Duration, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetTTL", key) + ret0, _ := ret[0].(time.Duration) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetTTL indicates an expected call of GetTTL. +func (mr *MockRistrettoClientInterfaceMockRecorder) GetTTL(key any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTTL", reflect.TypeOf((*MockRistrettoClientInterface)(nil).GetTTL), key) +} + // SetWithTTL mocks base method. func (m *MockRistrettoClientInterface) SetWithTTL(key, value any, cost int64, ttl time.Duration) bool { m.ctrl.T.Helper() @@ -82,7 +103,7 @@ func (m *MockRistrettoClientInterface) SetWithTTL(key, value any, cost int64, tt } // SetWithTTL indicates an expected call of SetWithTTL. -func (mr *MockRistrettoClientInterfaceMockRecorder) SetWithTTL(key, value, cost, ttl interface{}) *gomock.Call { +func (mr *MockRistrettoClientInterfaceMockRecorder) SetWithTTL(key, value, cost, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWithTTL", reflect.TypeOf((*MockRistrettoClientInterface)(nil).SetWithTTL), key, value, cost, ttl) } diff --git a/store/ristretto/ristretto_test.go b/store/ristretto/ristretto_test.go index 5b9eaa00..8c7bc530 100644 --- a/store/ristretto/ristretto_test.go +++ b/store/ristretto/ristretto_test.go @@ -80,6 +80,7 @@ func TestRistrettoGetWithTTL(t *testing.T) { client := NewMockRistrettoClientInterface(ctrl) client.EXPECT().Get(cacheKey).Return(cacheValue, true) + client.EXPECT().GetTTL(cacheKey).Return(time.Minute, true) store := NewRistretto(client) @@ -89,7 +90,7 @@ func TestRistrettoGetWithTTL(t *testing.T) { // Then assert.Nil(t, err) assert.Equal(t, cacheValue, value) - assert.Equal(t, 0*time.Second, ttl) + assert.Equal(t, time.Minute, ttl) } func TestRistrettoGetWithTTLWhenError(t *testing.T) {