Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: rate limit requests to target registry #21

Merged
merged 1 commit into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cmd/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ var (
ignoredUserAgents []string
cacheDuration time.Duration
refreshInterval time.Duration
requestsPerSecond int
)

var serveCmd = &cobra.Command{
Expand All @@ -47,16 +48,21 @@ var serveCmd = &cobra.Command{
ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer stop()

if requestsPerSecond < 1 {
requestsPerSecond = 1
}

log := logger.FromContext(ctx)
log.Info("starting server",
slog.Duration("cache-duration", cacheDuration),
slog.String("bind-addr", bindAddr),
slog.Any("ignored-user-agents", ignoredUserAgents),
slog.Any("refresh-interval", refreshInterval),
slog.Int("requests-per-second", requestsPerSecond),
)

client := registry.New(rootCfg)
asyncClient := async.New(client, refreshInterval)
asyncClient := async.New(client, refreshInterval, requestsPerSecond)

filler := filler.New(asyncClient, rootCfg.RegistryHostname, "/")

Expand Down Expand Up @@ -93,5 +99,6 @@ func init() {
serveCmd.PersistentFlags().StringArrayVar(&ignoredUserAgents, "ignored-user-agent", []string{}, "user agents to ignore (reply with empty body and 200 OK). A user agent is ignored if it contains the one of the values passed to this flag")
serveCmd.PersistentFlags().DurationVar(&cacheDuration, "cache-duration", time.Minute*1, "how long to keep a generated page in cache before expiring it, 0 to never expire")
serveCmd.PersistentFlags().DurationVar(&refreshInterval, "refresh-interval", time.Minute*15, "how long to wait before trying to get fresh data from the target registry")
serveCmd.PersistentFlags().IntVar(&requestsPerSecond, "requests-per-second", 1, "limit the number of requests per second that can be done to the target registry")
jordeu marked this conversation as resolved.
Show resolved Hide resolved
rootCmd.AddCommand(serveCmd)
}
38 changes: 24 additions & 14 deletions pkg/registry/async/async_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ type Async struct {
underlying *registryimpl.Registry
// refreshInterval represents the time to wait to synchronize repositories again after a successful synchronization
refreshInterval time.Duration
// requestsPerSecond is the maximum amount of requests that this client
// will do against the underlying registry in the window of time of 1 second.
requestsPerSecond int

// repos is an in memory list of all the repository names in the registry
repos map[string]registry.RepoData
Expand All @@ -57,6 +60,11 @@ type Async struct {

// imageInfo contains the image information indexed by repo name and tag
imageInfo *xsync.MapOf[imageInfoKey, imageInfo]

// limiter is configured by the requestsPerSecond property
// it is used right before sending requests to the registry
// so they wait if needed.
limiter *rateLimiter
}

type imageInfoKey struct {
Expand Down Expand Up @@ -89,9 +97,12 @@ func (c *Async) Start(ctx context.Context) error {
// so that image info is retrieved for each <repo,tag> combination
imageInfoRequestsBuffer := make(chan imageInfoRequest, imageInfoRequestsBufSize)

c.limiter = newRateLimiter(c.requestsPerSecond)

defer func() {
close(repositoryRequestBuffer)
close(imageInfoRequestsBuffer)
c.limiter.Stop()
}()

g.Go(func() error {
Expand Down Expand Up @@ -168,16 +179,13 @@ func (c *Async) handleRepositoryRequest(ctx context.Context, reqChan chan<- imag
reqLog := log.With(slog.Any("req", req))
reqLog.Debug("handleRepositoryRequest")
tags, err := c.underlying.TagList(ctx, req.repo)

if err != nil {
reqLog.Warn("could not list tags for image", logger.ErrAttr(err))
return

}

c.repositoryTags.Store(req.repo, tags)

for _, t := range tags {
c.limiter.Allow()
select {
case reqChan <- imageInfoRequest{
repo: req.repo,
Expand All @@ -188,11 +196,11 @@ func (c *Async) handleRepositoryRequest(ctx context.Context, reqChan chan<- imag
}
}
}

func (c *Async) handleImageInfoRequest(ctx context.Context, req imageInfoRequest) {
log := logger.FromContext(ctx)
reqLog := log.With(slog.Any("req", req))
reqLog.Debug("handleImageInfoRequest")
c.limiter.Allow()
key := imageInfoKey(req)

// update image info
Expand All @@ -213,13 +221,11 @@ func (c *Async) handleImageInfoRequest(ctx context.Context, req imageInfoRequest
reqLog.Warn("could not get config file for tag", logger.ErrAttr(err))
return
}

if prev, ok := c.repos[req.repo]; ok {
if prev.LastUpdatedAt.After(cf.Created.Time) {
return
}
}

c.reposMutex.Lock()
defer c.reposMutex.Unlock()
c.repos[req.repo] = registry.RepoData{
Expand All @@ -228,7 +234,6 @@ func (c *Async) handleImageInfoRequest(ctx context.Context, req imageInfoRequest
PullReference: r,
}
}

func (c *Async) RepoList(ctx context.Context) (repos map[string]registry.RepoData, err error) {
return c.repos, nil
}
Expand All @@ -253,13 +258,18 @@ func (c *Async) ImageInfo(ctx context.Context, repo string, tag string) (image v
return info.image, info.reference, nil
}

func New(client *registryimpl.Registry, refreshInterval time.Duration) *Async {
func New(
client *registryimpl.Registry,
refreshInterval time.Duration,
requestsPerSecond int,
) *Async {
return &Async{
underlying: client,
refreshInterval: refreshInterval,
repositoryTags: xsync.NewMapOf[string, []string](),
imageInfo: xsync.NewMapOf[imageInfoKey, imageInfo](),
repos: map[string]registry.RepoData{},
underlying: client,
refreshInterval: refreshInterval,
repositoryTags: xsync.NewMapOf[string, []string](),
imageInfo: xsync.NewMapOf[imageInfoKey, imageInfo](),
repos: map[string]registry.RepoData{},
requestsPerSecond: requestsPerSecond,
}
}

Expand Down
45 changes: 45 additions & 0 deletions pkg/registry/async/ratelimiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright 2024 Seqera
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package async

import (
"time"
)

// rateLimiter is a simple implementation of a rate limiter that
// controls the rate of function calls. It imposes a delay between
// allowed calls based on the specified requests per second (RPS).
// It utilizes time.Ticker for scheduling the allowed times.
type rateLimiter struct {
ticker *time.Ticker
quit chan struct{}
}

func newRateLimiter(rps int) *rateLimiter {
limiter := &rateLimiter{
ticker: time.NewTicker(time.Second / time.Duration(rps)),
quit: make(chan struct{}),
}
return limiter
}

func (rl *rateLimiter) Allow() {
<-rl.ticker.C
}

func (rl *rateLimiter) Stop() {
close(rl.quit)
rl.ticker.Stop()
}
68 changes: 68 additions & 0 deletions pkg/registry/async/ratelimiter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright 2024 Seqera
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package async

import (
"testing"
"time"
)

func TestRateLimiter_Allow(t *testing.T) {
rps := 1
requestCount := 5
limiter := newRateLimiter(rps)
defer limiter.Stop()

start := time.Now()
for i := 0; i < requestCount; i++ {
limiter.Allow()
}
duration := time.Since(start)
if duration < 5*time.Second {
t.Errorf("expected duration to be at least 5 seconds, got %v", duration)
}
}

func TestRateLimiter_ConcurrentRequests(t *testing.T) {
rps := 10
limiter := newRateLimiter(rps)
defer limiter.Stop()

requestCount := 100
done := make(chan struct{}, requestCount)

for i := 0; i < requestCount; i++ {
go func() {
limiter.Allow()
done <- struct{}{}
}()
}

// Wait for all requests to finish
for i := 0; i < requestCount; i++ {
<-done
}

// Since rps is 10, allowing 100 calls should take approximately 10 seconds
// let's wait for that to happen before doing any assertions
duration := time.Duration(requestCount/rps) * time.Second
time.Sleep(duration)

// To check if we are somewhat in the range of expected duration
if duration < 9*time.Second || duration > 11*time.Second {
t.Errorf("expected duration to be around 10 seconds, got %v", duration)
}
}
Loading