
Commit

add listPrefix in awsS3WriteCommitPrefix (#31776)
* add listPrefix in awsS3WriteCommitPrefix

* linting and changelog

* linting

* try fixing integration test

* try fixing integration test
Andrea Spacca authored and chrisberkhout committed Jun 1, 2023
1 parent 53ef09a commit 68b9331
Showing 6 changed files with 142 additions and 65 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.next.asciidoc
@@ -49,6 +49,7 @@ https://github.com/elastic/beats/compare/v8.2.0\...main[Check the HEAD diff]
- sophos.xg: Update module to handle new log fields. {issue}31038[31038] {pull}31388[31388]
- Fix MISP documentation for `var.filters` config option. {pull}31434[31434]
- Fix type mapping of client.as.number in okta module. {pull}31676[31676]
- Fix last write pagination commit checkpoint on `aws-s3` input for s3 direct polling when using the same bucket and different list prefixes. {pull}31776[31776]

*Heartbeat*

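The changelog entry above captures the bug this commit fixes: with S3 direct polling, the last-write commit checkpoint was keyed by bucket name alone, so two inputs polling the same bucket under different list prefixes overwrote each other's checkpoint. A minimal sketch of the new key scheme follows; the value of the prefix constant is an assumption for illustration, not taken from the source.

package main

import "fmt"

// awsS3WriteCommitPrefix's real value is defined in the input's source;
// this value is assumed here purely for illustration.
const awsS3WriteCommitPrefix = "commit_write::"

// commitKey sketches the key scheme after this change: the checkpoint is
// scoped to the (bucket, listPrefix) pair instead of the bucket alone.
func commitKey(bucket, listPrefix string) string {
	return awsS3WriteCommitPrefix + bucket + listPrefix
}

func main() {
	// Two inputs polling the same bucket with different prefixes now write
	// to distinct checkpoint entries instead of overwriting each other.
	fmt.Println(commitKey("my-bucket", "logs/app-a/"))
	fmt.Println(commitKey("my-bucket", "logs/app-b/"))
}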
95 changes: 65 additions & 30 deletions x-pack/filebeat/input/awss3/input_benchmark_test.go
@@ -6,34 +6,37 @@ package awss3

import (
"context"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"time"

"github.com/elastic/beats/v7/libbeat/beat"
"github.com/elastic/beats/v7/libbeat/statestore"
"github.com/elastic/beats/v7/libbeat/statestore/storetest"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/dustin/go-humanize"
"github.com/olekukonko/tablewriter"
"github.com/pkg/errors"

"github.com/elastic/beats/v7/libbeat/beat"
pubtest "github.com/elastic/beats/v7/libbeat/publisher/testing"
"github.com/elastic/beats/v7/libbeat/statestore"
"github.com/elastic/beats/v7/libbeat/statestore/storetest"
awscommon "github.com/elastic/beats/v7/x-pack/libbeat/common/aws"
conf "github.com/elastic/elastic-agent-libs/config"
"github.com/elastic/elastic-agent-libs/logp"
"github.com/elastic/elastic-agent-libs/monitoring"
)

const (
cloudtrailTestFile = "testdata/aws-cloudtrail.json.gz"
totalListingObjects = 10000
cloudtrailTestFile = "testdata/aws-cloudtrail.json.gz"
totalListingObjects = 10000
totalListingObjectsForInputS3 = totalListingObjects / 5
)

type constantSQS struct {
@@ -54,11 +57,11 @@ func (c *constantSQS) ReceiveMessage(ctx context.Context, maxMessages int) ([]sq
return c.msgs, nil
}

func (_ *constantSQS) DeleteMessage(ctx context.Context, msg *sqs.Message) error {
func (*constantSQS) DeleteMessage(ctx context.Context, msg *sqs.Message) error {
return nil
}

func (_ *constantSQS) ChangeMessageVisibility(ctx context.Context, msg *sqs.Message, timeout time.Duration) error {
func (*constantSQS) ChangeMessageVisibility(ctx context.Context, msg *sqs.Message, timeout time.Duration) error {
return nil
}

@@ -93,16 +96,16 @@ func (c *s3PagerConstant) Err() error {
return nil
}

func newS3PagerConstant() *s3PagerConstant {
func newS3PagerConstant(listPrefix string) *s3PagerConstant {
lastModified := time.Now()
ret := &s3PagerConstant{
currentIndex: 0,
}

for i := 0; i < totalListingObjects; i++ {
for i := 0; i < totalListingObjectsForInputS3; i++ {
ret.objects = append(ret.objects, s3.Object{
Key: aws.String(fmt.Sprintf("key-%d.json.gz", i)),
ETag: aws.String(fmt.Sprintf("etag-%d", i)),
Key: aws.String(fmt.Sprintf("%s-%d.json.gz", listPrefix, i)),
ETag: aws.String(fmt.Sprintf("etag-%s-%d", listPrefix, i)),
LastModified: aws.Time(lastModified),
})
}
@@ -213,7 +216,7 @@ func benchmarkInputSQS(t *testing.T, maxMessagesInflight int) testing.BenchmarkR
}

func TestBenchmarkInputSQS(t *testing.T) {
logp.TestingSetup(logp.WithLevel(logp.InfoLevel))
_ = logp.TestingSetup(logp.WithLevel(logp.InfoLevel))

results := []testing.BenchmarkResult{
benchmarkInputSQS(t, 1),
@@ -236,7 +239,7 @@ func TestBenchmarkInputSQS(t *testing.T) {
"Time (sec)",
"CPUs",
}
var data [][]string
data := make([][]string, 0)
for _, r := range results {
data = append(data, []string{
fmt.Sprintf("%v", r.Extra["max_messages_inflight"]),
@@ -258,8 +261,7 @@ func benchmarkInputS3(t *testing.T, numberOfWorkers int) testing.BenchmarkResult
log := logp.NewLogger(inputName)
metricRegistry := monitoring.NewRegistry()
metrics := newInputMetrics(metricRegistry, "test_id")
s3API := newConstantS3(t)
s3API.pagerConstant = newS3PagerConstant()

client := pubtest.NewChanClientWithCallback(100, func(event beat.Event) {
event.Private.(*awscommon.EventACKTracker).ACK()
})
@@ -273,14 +275,8 @@ func benchmarkInputS3(t *testing.T, numberOfWorkers int) testing.BenchmarkResult
t.Fatalf("Failed to access store: %v", err)
}

err = store.Set(awsS3WriteCommitPrefix+"bucket", &commitWriteState{time.Time{}})
if err != nil {
t.Fatalf("Failed to reset store: %v", err)
}

s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, client, conf.FileSelectors)
s3Poller := newS3Poller(logp.NewLogger(inputName), metrics, s3API, s3EventHandlerFactory, newStates(inputCtx), store, "bucket", "key-", "region", "provider", numberOfWorkers, time.Second)

b.ResetTimer()
start := time.Now()
ctx, cancel := context.WithCancel(context.Background())
b.Cleanup(cancel)

@@ -291,13 +287,42 @@ func benchmarkInputS3(t *testing.T, numberOfWorkers int) testing.BenchmarkResult
cancel()
}()

b.ResetTimer()
start := time.Now()
if err := s3Poller.Poll(ctx); err != nil {
if !errors.Is(err, context.DeadlineExceeded) {
errChan := make(chan error)
wg := new(sync.WaitGroup)
for i := 0; i < 5; i++ {
wg.Add(1)
go func(i int, wg *sync.WaitGroup) {
defer wg.Done()
listPrefix := fmt.Sprintf("list_prefix_%d", i)
s3API := newConstantS3(t)
s3API.pagerConstant = newS3PagerConstant(listPrefix)
err = store.Set(awsS3WriteCommitPrefix+"bucket"+listPrefix, &commitWriteState{time.Time{}})
if err != nil {
errChan <- err
return
}

s3EventHandlerFactory := newS3ObjectProcessorFactory(log.Named("s3"), metrics, s3API, client, conf.FileSelectors)
s3Poller := newS3Poller(logp.NewLogger(inputName), metrics, s3API, s3EventHandlerFactory, newStates(inputCtx), store, "bucket", listPrefix, "region", "provider", numberOfWorkers, time.Second)

if err := s3Poller.Poll(ctx); err != nil {
if !errors.Is(err, context.DeadlineExceeded) {
errChan <- err
}
}
}(i, wg)
}

wg.Wait()
select {
case err := <-errChan:
if err != nil {
t.Fatal(err)
}
default:

}

b.StopTimer()
elapsed := time.Since(start)

@@ -322,7 +347,7 @@ func benchmarkInputS3(t *testing.T, numberOfWorkers int) testing.BenchmarkResult
}

func TestBenchmarkInputS3(t *testing.T) {
logp.TestingSetup(logp.WithLevel(logp.InfoLevel))
_ = logp.TestingSetup(logp.WithLevel(logp.InfoLevel))

results := []testing.BenchmarkResult{
benchmarkInputS3(t, 1),
@@ -340,22 +365,32 @@ func TestBenchmarkInputS3(t *testing.T) {

headers := []string{
"Number of workers",
"Objects listed total",
"Objects listed per sec",
"Objects processed total",
"Objects processed per sec",
"Objects acked total",
"Objects acked per sec",
"Events total",
"Events per sec",
"S3 Bytes total",
"S3 Bytes per sec",
"Time (sec)",
"CPUs",
}
var data [][]string
data := make([][]string, 0)
for _, r := range results {
data = append(data, []string{
fmt.Sprintf("%v", r.Extra["number_of_workers"]),
fmt.Sprintf("%v", r.Extra["objects_listed"]),
fmt.Sprintf("%v", r.Extra["objects_listed_per_sec"]),
fmt.Sprintf("%v", r.Extra["objects_processed"]),
fmt.Sprintf("%v", r.Extra["objects_processed_per_sec"]),
fmt.Sprintf("%v", r.Extra["objects_acked"]),
fmt.Sprintf("%v", r.Extra["objects_acked_per_sec"]),
fmt.Sprintf("%v", r.Extra["events"]),
fmt.Sprintf("%v", r.Extra["events_per_sec"]),
fmt.Sprintf("%v", humanize.Bytes(uint64(r.Extra["s3_bytes"]))),
fmt.Sprintf("%v", humanize.Bytes(uint64(r.Extra["s3_bytes_per_sec"]))),
fmt.Sprintf("%v", r.Extra["sec"]),
fmt.Sprintf("%v", runtime.GOMAXPROCS(0)),
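The reworked benchmark above fans the polling out: five goroutines, each with its own list prefix, pager, and poller against the shared store, funneling errors into errChan and checking it non-blockingly after wg.Wait(). One caveat worth noting: errChan is unbuffered and drained only after wg.Wait(), so a worker that actually errors would block in the send before its deferred wg.Done() runs. Below is a sketch of a buffered variant that avoids that stall, using only the standard library; runWorkers and doWork are hypothetical names, not part of this commit.

package main

import (
	"fmt"
	"sync"
)

// runWorkers fans doWork out over `workers` goroutines. Buffering errChan
// to the worker count guarantees a failing worker can always send its error
// and return, so wg.Wait() can never stall on a blocked channel send.
func runWorkers(workers int, doWork func(int) error) error {
	errChan := make(chan error, workers)
	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			if err := doWork(i); err != nil {
				errChan <- err // never blocks: capacity == workers
			}
		}(i)
	}
	wg.Wait()
	close(errChan)   // safe: all senders have returned
	return <-errChan // first buffered error, or nil if the channel is empty
}

func main() {
	err := runWorkers(5, func(i int) error {
		if i == 3 {
			return fmt.Errorf("worker %d failed", i)
		}
		return nil
	})
	fmt.Println(err) // worker 3 failed
}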
49 changes: 32 additions & 17 deletions x-pack/filebeat/input/awss3/s3.go
@@ -6,12 +6,13 @@ package awss3

import (
"context"
"errors"
"fmt"
"net/url"
"sync"
"time"

"github.com/gofrs/uuid"
"github.com/pkg/errors"
"go.uber.org/multierr"

"github.com/elastic/beats/v7/libbeat/statestore"
@@ -126,9 +127,11 @@ func (p *s3Poller) ProcessObject(s3ObjectPayloadChan <-chan *s3ObjectPayload) er

if err != nil {
event := s3ObjectPayload.s3ObjectEvent
errs = append(errs, errors.Wrapf(err,
"failed processing S3 event for object key %q in bucket %q",
event.S3.Object.Key, event.S3.Bucket.Name))
errs = append(errs,
fmt.Errorf(
fmt.Sprintf("failed processing S3 event for object key %q in bucket %q: %%w",
event.S3.Object.Key, event.S3.Bucket.Name),
err))

p.handlePurgingLock(info, false)
continue
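The hunk above swaps github.com/pkg/errors for standard-library error wrapping, building the format string with a nested fmt.Sprintf so that %w survives into the outer fmt.Errorf. The same chain can be built in a single call; a small self-contained illustration, with the key and bucket literals invented for the example:

package main

import (
	"errors"
	"fmt"
)

func main() {
	base := errors.New("access denied")
	// Equivalent to the nested Sprintf form: one fmt.Errorf with a
	// trailing %w both formats the message and keeps the error chain.
	err := fmt.Errorf("failed processing S3 event for object key %q in bucket %q: %w",
		"logs/app.json.gz", "my-bucket", base)
	fmt.Println(err)
	fmt.Println(errors.Is(err, base)) // true: %w keeps the chain intact
}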
@@ -178,7 +181,7 @@ func (p *s3Poller) GetS3Objects(ctx context.Context, s3ObjectPayloadChan chan<-
continue
}

state := newState(bucketName, filename, *object.ETag, *object.LastModified)
state := newState(bucketName, filename, *object.ETag, p.listPrefix, *object.LastModified)
if p.states.MustSkip(state, p.store) {
p.log.Debugw("skipping state.", "state", state)
continue
@@ -197,6 +200,7 @@ func (p *s3Poller) GetS3Objects(ctx context.Context, s3ObjectPayloadChan chan<-

s3Processor := p.s3ObjectHandler.Create(ctx, p.log, acker, event)
if s3Processor == nil {
p.log.Debugw("empty s3 processor.", "state", state)
continue
}

@@ -216,6 +220,7 @@ func (p *s3Poller) GetS3Objects(ctx context.Context, s3ObjectPayloadChan chan<-
}

if totProcessableObjects == 0 {
p.log.Debugw("0 processable objects on bucket pagination.", "bucket", p.bucket, "listPrefix", p.listPrefix, "listingID", listingID)
// nothing to be ACKed, unlock here
p.states.DeleteListing(listingID.String())
lock.Unlock()
@@ -236,12 +241,11 @@ func (p *s3Poller) GetS3Objects(ctx context.Context, s3ObjectPayloadChan chan<-
if err := paginator.Err(); err != nil {
p.log.Warnw("Error when paginating listing.", "error", err)
}

return
}

func (p *s3Poller) Purge() {
listingIDs := p.states.GetListingIDs()
p.log.Debugw("purging listing.", "listingIDs", listingIDs)
for _, listingID := range listingIDs {
// we lock here in order to process the purge only after
// full listing page is ACKed by all the workers
@@ -250,39 +254,45 @@ func (p *s3Poller) Purge() {
// purge calls can overlap, GetListingIDs can return
// an outdated snapshot with listing already purged
p.states.DeleteListing(listingID)
p.log.Debugw("deleting already purged listing from states.", "listingID", listingID)
continue
}

lock.(*sync.Mutex).Lock()

keys := map[string]struct{}{}
latestStoredTimeByBucket := make(map[string]time.Time, 0)
latestStoredTimeByBucketAndListPrefix := make(map[string]time.Time, 0)

for _, state := range p.states.GetStatesByListingID(listingID) {
// it is not stored, keep
if !state.Stored {
p.log.Debugw("state not stored, skip purge", "state", state)
continue
}

var latestStoredTime time.Time
keys[state.ID] = struct{}{}
latestStoredTime, ok := latestStoredTimeByBucket[state.Bucket]
latestStoredTime, ok := latestStoredTimeByBucketAndListPrefix[state.Bucket+state.ListPrefix]
if !ok {
var commitWriteState commitWriteState
err := p.store.Get(awsS3WriteCommitPrefix+state.Bucket, &commitWriteState)
err := p.store.Get(awsS3WriteCommitPrefix+state.Bucket+state.ListPrefix, &commitWriteState)
if err == nil {
// we have no entry in the map and we have no entry in the store
// set zero time
latestStoredTime = time.Time{}
p.log.Debugw("last stored time is zero time", "bucket", state.Bucket, "listPrefix", state.ListPrefix)
} else {
latestStoredTime = commitWriteState.Time
p.log.Debugw("last stored time is commitWriteState", "commitWriteState", commitWriteState, "bucket", state.Bucket, "listPrefix", state.ListPrefix)
}
} else {
p.log.Debugw("last stored time from memory", "latestStoredTime", latestStoredTime, "bucket", state.Bucket, "listPrefix", state.ListPrefix)
}

if state.LastModified.After(latestStoredTime) {
latestStoredTimeByBucket[state.Bucket] = state.LastModified
p.log.Debugw("last stored time updated", "state.LastModified", state.LastModified, "bucket", state.Bucket, "listPrefix", state.ListPrefix)
latestStoredTimeByBucketAndListPrefix[state.Bucket+state.ListPrefix] = state.LastModified
}

}

for key := range keys {
@@ -293,8 +303,8 @@ func (p *s3Poller) Purge() {
p.log.Errorw("Failed to write states to the registry", "error", err)
}

for bucket, latestStoredTime := range latestStoredTimeByBucket {
if err := p.store.Set(awsS3WriteCommitPrefix+bucket, commitWriteState{latestStoredTime}); err != nil {
for bucketAndListPrefix, latestStoredTime := range latestStoredTimeByBucketAndListPrefix {
if err := p.store.Set(awsS3WriteCommitPrefix+bucketAndListPrefix, commitWriteState{latestStoredTime}); err != nil {
p.log.Errorw("Failed to write commit time to the registry", "error", err)
}
}
@@ -304,8 +314,6 @@ func (p *s3Poller) Purge() {
p.workersListingMap.Delete(listingID)
p.states.DeleteListing(listingID)
}

return
}
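Purge now keys both the in-memory map and the stored checkpoint by state.Bucket+state.ListPrefix. Plain concatenation distinguishes pairs only as long as bucket and prefix never overlap at the seam ("logs"+"app/" and "logsapp"+"/" concatenate identically); for the in-memory side, a struct key would be collision-free, as in this sketch. The type is hypothetical, not part of the commit.

package main

import "fmt"

// bucketPrefix is a hypothetical composite map key. Unlike concatenated
// strings, two distinct (bucket, listPrefix) pairs can never alias.
type bucketPrefix struct {
	bucket     string
	listPrefix string
}

func main() {
	latest := map[bucketPrefix]string{}
	latest[bucketPrefix{"logs", "app/"}] = "t1"
	latest[bucketPrefix{"logsapp", "/"}] = "t2" // stays a distinct key
	fmt.Println(len(latest))                    // 2
}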

func (p *s3Poller) Poll(ctx context.Context) error {
@@ -349,8 +357,15 @@ func (p *s3Poller) Poll(ctx context.Context) error {
}()
}

timed.Wait(ctx, p.bucketPollInterval)
err = timed.Wait(ctx, p.bucketPollInterval)
if err != nil {
if errors.Is(err, context.Canceled) {
// A canceled context is a normal shutdown.
return nil
}

return err
}
}

// Wait for all workers to finish.
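The final hunk makes Poll surface errors from timed.Wait while treating context.Canceled as an orderly shutdown. A self-contained sketch of the same pattern, approximating timed.Wait from elastic-agent-libs with a plain timer select (an assumption, not that package's implementation):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// waitOrShutdown waits for d or for ctx to end. A canceled context is a
// normal shutdown and maps to nil; any other context error propagates.
func waitOrShutdown(ctx context.Context, d time.Duration) error {
	select {
	case <-time.After(d):
		return nil
	case <-ctx.Done():
		if errors.Is(ctx.Err(), context.Canceled) {
			return nil // normal shutdown
		}
		return ctx.Err() // e.g. context.DeadlineExceeded
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	fmt.Println(waitOrShutdown(ctx, time.Second)) // <nil>
}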
