diff --git a/core/services/llo/observation/cache.go b/core/services/llo/observation/cache.go index 80dc4147484..e44487a66ef 100644 --- a/core/services/llo/observation/cache.go +++ b/core/services/llo/observation/cache.go @@ -94,7 +94,7 @@ func (c *Cache) Add(id llotypes.StreamID, value llo.StreamValue) { c.values[id] = item{value: value, createdAt: time.Now()} } -func (c *Cache) Get(id llotypes.StreamID) (llo.StreamValue, bool) { +func (c *Cache) Get(id llotypes.StreamID) llo.StreamValue { c.mu.RLock() defer c.mu.RUnlock() @@ -102,16 +102,16 @@ func (c *Cache) Get(id llotypes.StreamID) (llo.StreamValue, bool) { item, ok := c.values[id] if !ok { promCacheMissCount.WithLabelValues(label, "notFound").Inc() - return nil, false + return nil } if time.Since(item.createdAt) >= c.maxAge { promCacheMissCount.WithLabelValues(label, "maxAge").Inc() - return nil, false + return nil } promCacheHitCount.WithLabelValues(label).Inc() - return item.value, true + return item.value } func (c *Cache) cleanup() { diff --git a/core/services/llo/observation/cache_test.go b/core/services/llo/observation/cache_test.go index 4f8242cf8e2..df9504a8bde 100644 --- a/core/services/llo/observation/cache_test.go +++ b/core/services/llo/observation/cache_test.go @@ -87,7 +87,6 @@ func TestCache_Add_Get(t *testing.T) { value llo.StreamValue maxAge time.Duration wantValue llo.StreamValue - wantFound bool beforeGet func(cache *Cache) }{ { @@ -96,15 +95,12 @@ func TestCache_Add_Get(t *testing.T) { value: &mockStreamValue{value: []byte{42}}, maxAge: time.Second, wantValue: &mockStreamValue{value: []byte{42}}, - wantFound: true, }, { name: "get non-existent value", streamID: 1, - value: &mockStreamValue{value: []byte{42}}, maxAge: time.Second, wantValue: nil, - wantFound: false, }, { name: "get expired by age", @@ -112,7 +108,6 @@ func TestCache_Add_Get(t *testing.T) { value: &mockStreamValue{value: []byte{42}}, maxAge: time.Nanosecond * 100, wantValue: nil, - wantFound: false, beforeGet: func(_ *Cache) { time.Sleep(time.Millisecond) }, @@ -123,19 +118,13 @@ func TestCache_Add_Get(t *testing.T) { t.Run(tt.name, func(t *testing.T) { cache := NewCache(tt.maxAge, 0) - if tt.wantFound { - cache.Add(tt.streamID, tt.value) - } + cache.Add(tt.streamID, tt.value) if tt.beforeGet != nil { tt.beforeGet(cache) } - gotValue, gotFound := cache.Get(tt.streamID) - assert.Equal(t, tt.wantFound, gotFound) - if tt.wantFound { - assert.Equal(t, tt.wantValue, gotValue) - } + assert.Equal(t, tt.wantValue, cache.Get(tt.streamID)) }) } } @@ -148,8 +137,7 @@ func TestCache_Cleanup(t *testing.T) { cache.Add(streamID, value) time.Sleep(time.Millisecond * 2) - gotValue, gotFound := cache.Get(streamID) - assert.False(t, gotFound) + gotValue := cache.Get(streamID) assert.Nil(t, gotValue) } @@ -177,9 +165,7 @@ func TestCache_ConcurrentAccess(t *testing.T) { for i := uint32(0); i < numGoroutines; i++ { for j := uint32(0); j < numOperations; j++ { streamID := i*numOperations + j - value, found := cache.Get(streamID) - assert.True(t, found) - assert.Equal(t, &mockStreamValue{value: []byte{byte(i)}}, value) + assert.Equal(t, &mockStreamValue{value: []byte{byte(i)}}, cache.Get(streamID)) } } } diff --git a/core/services/llo/observation/data_source.go b/core/services/llo/observation/data_source.go index 432a5d8a340..42dbce5a0b0 100644 --- a/core/services/llo/observation/data_source.go +++ b/core/services/llo/observation/data_source.go @@ -91,7 +91,7 @@ type dataSource struct { cache *Cache observationLoopStarted atomic.Bool observationLoopCloseCh services.StopChan - waitForLoopToExitCh chan struct{} // will be closed when we exit the observation loop + observationLoopDoneCh chan struct{} // will be closed when we exit the observation loop configDigestToStreamMu sync.Mutex configDigestToStream map[types.ConfigDigest]observableStreamValues @@ -109,7 +109,7 @@ func newDataSource(lggr logger.Logger, registry Registry, t Telemeter, shouldCac cache: NewCache(500*time.Millisecond, time.Minute), configDigestToStream: make(map[types.ConfigDigest]observableStreamValues), observationLoopCloseCh: make(chan struct{}), - waitForLoopToExitCh: make(chan struct{}), + observationLoopDoneCh: make(chan struct{}), } } @@ -130,57 +130,45 @@ func (d *dataSource) Observe(ctx context.Context, streamValues llo.StreamValues, // Fetch the cached observations for all streams. for streamID := range streamValues { - val := d.fromCache(streamID) - if val != nil { - streamValues[streamID] = val - } + streamValues[streamID] = d.cache.Get(streamID) } return nil } -func (d *dataSource) setObservableStreams(ctx context.Context, streamValues llo.StreamValues, opts llo.DSOpts) { - values := make(llo.StreamValues, len(streamValues)) - for streamID := range streamValues { - values[streamID] = nil - } - - deadline, ok := ctx.Deadline() - if !ok { - deadline = time.Now().Add(100 * time.Millisecond) - } - - streamVals := make(llo.StreamValues) - for streamID := range values { - streamVals[streamID] = values[streamID] - } - - d.configDigestToStreamMu.Lock() - d.configDigestToStream[opts.ConfigDigest()] = observableStreamValues{ - opts: opts, - streamValues: streamVals, - observationInterval: time.Until(deadline), - } - d.configDigestToStreamMu.Unlock() -} - // startObservationLoop continuously makes observations for the streams in d.configDigestToStream and stores those in // the cache. It does not check for cached versions, it always calculates fresh values. // // NOTE: This method needs to be run in a goroutine. func (d *dataSource) startObservationLoop(loopStartedCh chan struct{}) { - var elapsed time.Duration + if !d.observationLoopStarted.CompareAndSwap(false, true) { + close(loopStartedCh) + return + } + loopStarting := true + var elapsed time.Duration stopChanCtx, stopChanCancel := d.observationLoopCloseCh.NewCtx() defer stopChanCancel() + for { if stopChanCtx.Err() != nil { - close(d.waitForLoopToExitCh) + close(d.observationLoopDoneCh) return } - loopStart := time.Now() + startTS := time.Now() opts, streamValues, observationInterval := d.getObservableStreams() + if len(streamValues) == 0 || opts == nil { + // There is nothing to observe, exit and let the next Observe() call reinitialize the loop. + d.lggr.Debugw("invalid observation loop parameters", "opts", opts, "streamValues", streamValues) + + // still at the loop initialization, notify the caller and return + if loopStarting { + close(loopStartedCh) + } + return + } ctx, cancel := context.WithTimeout(stopChanCtx, observationInterval) lggr := logger.With(d.lggr, "observationTimestamp", opts.ObservationTimestamp(), "configDigest", opts.ConfigDigest(), "seqNr", opts.OutCtx().SeqNr) @@ -241,16 +229,16 @@ func (d *dataSource) startObservationLoop(loopStartedCh chan struct{}) { } // cache the observed value - d.toCache(streamID, val) + d.cache.Add(streamID, val) }(streamID) } wg.Wait() - elapsed = time.Since(loopStart) + elapsed = time.Since(startTS) - // Notify the caller that we've completed our first round of observations. - if !d.observationLoopStarted.Load() { - d.observationLoopStarted.Store(true) + // notify the caller that we've completed our first round of observations. + if loopStarting { + loopStarting = false close(loopStartedCh) } @@ -305,24 +293,11 @@ func (d *dataSource) startObservationLoop(loopStartedCh chan struct{}) { func (d *dataSource) Close() error { close(d.observationLoopCloseCh) d.observationLoopStarted.Store(false) - <-d.waitForLoopToExitCh - - return nil -} + <-d.observationLoopDoneCh -func (d *dataSource) fromCache(streamID llotypes.StreamID) llo.StreamValue { - if streamValue, found := d.cache.Get(streamID); found && streamValue != nil { - return streamValue - } return nil } -func (d *dataSource) toCache(streamID llotypes.StreamID, val llo.StreamValue) { - if val != nil { - d.cache.Add(streamID, val) - } -} - type observableStreamValues struct { opts llo.DSOpts streamValues llo.StreamValues @@ -343,6 +318,32 @@ func (o *observableStreamValues) IsActive() (bool, error) { return false, nil } +// setObservableStreams sets the observable streams for the given config digest. +func (d *dataSource) setObservableStreams(ctx context.Context, streamValues llo.StreamValues, opts llo.DSOpts) { + values := make(llo.StreamValues, len(streamValues)) + for streamID := range streamValues { + values[streamID] = nil + } + + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(100 * time.Millisecond) + } + + streamVals := make(llo.StreamValues) + for streamID := range values { + streamVals[streamID] = values[streamID] + } + + d.configDigestToStreamMu.Lock() + d.configDigestToStream[opts.ConfigDigest()] = observableStreamValues{ + opts: opts, + streamValues: streamVals, + observationInterval: time.Until(deadline), + } + d.configDigestToStreamMu.Unlock() +} + // getObservableStreams returns the active plugin data source options, the streams to observe and the observation interval // the observation interval is the maximum time we can spend observing streams. We ensure that we don't exceed this time and // we wait for the remaining time in the observation loop. @@ -354,19 +355,17 @@ func (d *dataSource) getObservableStreams() (llo.DSOpts, llo.StreamValues, time. } d.configDigestToStreamMu.Unlock() - // deduplicate streams and get the active ocr instance options for _, vals := range streamsToObserve { active, err := vals.IsActive() - if !active { - continue - } - if err != nil { d.lggr.Errorw("getObservableStreams: failed to check if OCR instance is active", "error", err) continue } - return vals.opts, vals.streamValues, vals.observationInterval + if active { + return vals.opts, vals.streamValues, vals.observationInterval + } + } d.lggr.Errorw("getObservableStreams: no active OCR instance found")