From 0b75a492c15812a3b8fcbdb2911910dc2f7d01f7 Mon Sep 17 00:00:00 2001 From: Lorenzo Fontana Date: Wed, 18 Dec 2024 16:10:52 +0100 Subject: [PATCH] feat: rate limit requests to target registry --- cmd/serve.go | 9 +++- pkg/registry/async/async_registry.go | 38 ++++++++------ pkg/registry/async/ratelimiter.go | 45 +++++++++++++++++ pkg/registry/async/ratelimiter_test.go | 68 ++++++++++++++++++++++++++ 4 files changed, 145 insertions(+), 15 deletions(-) create mode 100644 pkg/registry/async/ratelimiter.go create mode 100644 pkg/registry/async/ratelimiter_test.go diff --git a/cmd/serve.go b/cmd/serve.go index 3b4b557..ed3c87b 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -37,6 +37,7 @@ var ( ignoredUserAgents []string cacheDuration time.Duration refreshInterval time.Duration + requestsPerSecond int ) var serveCmd = &cobra.Command{ @@ -47,16 +48,21 @@ var serveCmd = &cobra.Command{ ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) defer stop() + if requestsPerSecond < 1 { + requestsPerSecond = 1 + } + log := logger.FromContext(ctx) log.Info("starting server", slog.Duration("cache-duration", cacheDuration), slog.String("bind-addr", bindAddr), slog.Any("ignored-user-agents", ignoredUserAgents), slog.Any("refresh-interval", refreshInterval), + slog.Int("requests-per-second", requestsPerSecond), ) client := registry.New(rootCfg) - asyncClient := async.New(client, refreshInterval) + asyncClient := async.New(client, refreshInterval, requestsPerSecond) filler := filler.New(asyncClient, rootCfg.RegistryHostname, "/") @@ -93,5 +99,6 @@ func init() { serveCmd.PersistentFlags().StringArrayVar(&ignoredUserAgents, "ignored-user-agent", []string{}, "user agents to ignore (reply with empty body and 200 OK). A user agent is ignored if it contains the one of the values passed to this flag") serveCmd.PersistentFlags().DurationVar(&cacheDuration, "cache-duration", time.Minute*1, "how long to keep a generated page in cache before expiring it, 0 to never expire") serveCmd.PersistentFlags().DurationVar(&refreshInterval, "refresh-interval", time.Minute*15, "how long to wait before trying to get fresh data from the target registry") + serveCmd.PersistentFlags().IntVar(&requestsPerSecond, "requests-per-second", 1, "limit the number of requests per second that can be done to the target registry") rootCmd.AddCommand(serveCmd) } diff --git a/pkg/registry/async/async_registry.go b/pkg/registry/async/async_registry.go index 84d7693..6592d06 100644 --- a/pkg/registry/async/async_registry.go +++ b/pkg/registry/async/async_registry.go @@ -47,6 +47,9 @@ type Async struct { underlying *registryimpl.Registry // refreshInterval represents the time to wait to synchronize repositories again after a successful synchronization refreshInterval time.Duration + // requestsPerSecond is the maximum amount of requests that this client + // will do against the underlying registry in the window of time of 1 second. + requestsPerSecond int // repos is an in memory list of all the repository names in the registry repos map[string]registry.RepoData @@ -57,6 +60,11 @@ type Async struct { // imageInfo contains the image information indexed by repo name and tag imageInfo *xsync.MapOf[imageInfoKey, imageInfo] + + // limiter is configured by the requestsPerSecond property + // it is used right before sending requests to the registry + // so they wait if needed. + limiter *rateLimiter } type imageInfoKey struct { @@ -89,9 +97,12 @@ func (c *Async) Start(ctx context.Context) error { // so that image info is retrieved for each combination imageInfoRequestsBuffer := make(chan imageInfoRequest, imageInfoRequestsBufSize) + c.limiter = newRateLimiter(c.requestsPerSecond) + defer func() { close(repositoryRequestBuffer) close(imageInfoRequestsBuffer) + c.limiter.Stop() }() g.Go(func() error { @@ -168,16 +179,13 @@ func (c *Async) handleRepositoryRequest(ctx context.Context, reqChan chan<- imag reqLog := log.With(slog.Any("req", req)) reqLog.Debug("handleRepositoryRequest") tags, err := c.underlying.TagList(ctx, req.repo) - if err != nil { reqLog.Warn("could not list tags for image", logger.ErrAttr(err)) return - } - c.repositoryTags.Store(req.repo, tags) - for _, t := range tags { + c.limiter.Allow() select { case reqChan <- imageInfoRequest{ repo: req.repo, @@ -188,11 +196,11 @@ func (c *Async) handleRepositoryRequest(ctx context.Context, reqChan chan<- imag } } } - func (c *Async) handleImageInfoRequest(ctx context.Context, req imageInfoRequest) { log := logger.FromContext(ctx) reqLog := log.With(slog.Any("req", req)) reqLog.Debug("handleImageInfoRequest") + c.limiter.Allow() key := imageInfoKey(req) // update image info @@ -213,13 +221,11 @@ func (c *Async) handleImageInfoRequest(ctx context.Context, req imageInfoRequest reqLog.Warn("could not get config file for tag", logger.ErrAttr(err)) return } - if prev, ok := c.repos[req.repo]; ok { if prev.LastUpdatedAt.After(cf.Created.Time) { return } } - c.reposMutex.Lock() defer c.reposMutex.Unlock() c.repos[req.repo] = registry.RepoData{ @@ -228,7 +234,6 @@ func (c *Async) handleImageInfoRequest(ctx context.Context, req imageInfoRequest PullReference: r, } } - func (c *Async) RepoList(ctx context.Context) (repos map[string]registry.RepoData, err error) { return c.repos, nil } @@ -253,13 +258,18 @@ func (c *Async) ImageInfo(ctx context.Context, repo string, tag string) (image v return info.image, info.reference, nil } -func New(client *registryimpl.Registry, refreshInterval time.Duration) *Async { +func New( + client *registryimpl.Registry, + refreshInterval time.Duration, + requestsPerSecond int, +) *Async { return &Async{ - underlying: client, - refreshInterval: refreshInterval, - repositoryTags: xsync.NewMapOf[string, []string](), - imageInfo: xsync.NewMapOf[imageInfoKey, imageInfo](), - repos: map[string]registry.RepoData{}, + underlying: client, + refreshInterval: refreshInterval, + repositoryTags: xsync.NewMapOf[string, []string](), + imageInfo: xsync.NewMapOf[imageInfoKey, imageInfo](), + repos: map[string]registry.RepoData{}, + requestsPerSecond: requestsPerSecond, } } diff --git a/pkg/registry/async/ratelimiter.go b/pkg/registry/async/ratelimiter.go new file mode 100644 index 0000000..fe5c55c --- /dev/null +++ b/pkg/registry/async/ratelimiter.go @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2024 Seqera +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package async + +import ( + "time" +) + +// rateLimiter is a simple implementation of a rate limiter that +// controls the rate of function calls. It imposes a delay between +// allowed calls based on the specified requests per second (RPS). +// It utilizes time.Ticker for scheduling the allowed times. +type rateLimiter struct { + ticker *time.Ticker + quit chan struct{} +} + +func newRateLimiter(rps int) *rateLimiter { + limiter := &rateLimiter{ + ticker: time.NewTicker(time.Second / time.Duration(rps)), + quit: make(chan struct{}), + } + return limiter +} + +func (rl *rateLimiter) Allow() { + <-rl.ticker.C +} + +func (rl *rateLimiter) Stop() { + close(rl.quit) + rl.ticker.Stop() +} diff --git a/pkg/registry/async/ratelimiter_test.go b/pkg/registry/async/ratelimiter_test.go new file mode 100644 index 0000000..a4489f8 --- /dev/null +++ b/pkg/registry/async/ratelimiter_test.go @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright 2024 Seqera +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package async + +import ( + "testing" + "time" +) + +func TestRateLimiter_Allow(t *testing.T) { + rps := 1 + requestCount := 5 + limiter := newRateLimiter(rps) + defer limiter.Stop() + + start := time.Now() + for i := 0; i < requestCount; i++ { + limiter.Allow() + } + duration := time.Since(start) + if duration < 5*time.Second { + t.Errorf("expected duration to be at least 5 seconds, got %v", duration) + } +} + +func TestRateLimiter_ConcurrentRequests(t *testing.T) { + rps := 10 + limiter := newRateLimiter(rps) + defer limiter.Stop() + + requestCount := 100 + done := make(chan struct{}, requestCount) + + for i := 0; i < requestCount; i++ { + go func() { + limiter.Allow() + done <- struct{}{} + }() + } + + // Wait for all requests to finish + for i := 0; i < requestCount; i++ { + <-done + } + + // Since rps is 10, allowing 100 calls should take approximately 10 seconds + // let's wait for that to happen before doing any assertions + duration := time.Duration(requestCount/rps) * time.Second + time.Sleep(duration) + + // To check if we are somewhat in the range of expected duration + if duration < 9*time.Second || duration > 11*time.Second { + t.Errorf("expected duration to be around 10 seconds, got %v", duration) + } +}