Skip to content

Commit

Permalink
backport of commit d70c17f (hashicorp#20018)
Browse files Browse the repository at this point in the history
Co-authored-by: miagilepner <mia.epner@hashicorp.com>
Co-authored-by: Yoko Hyakuna <yoko@hashicorp.com>
  • Loading branch information
3 people authored Apr 11, 2023
1 parent 9bfc9f3 commit e87f04f
Show file tree
Hide file tree
Showing 2 changed files with 278 additions and 0 deletions.
95 changes: 95 additions & 0 deletions vault/activity_log_util_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"context"
"errors"
"fmt"
"io"
"sort"
"time"

"github.com/axiomhq/hyperloglog"
"github.com/hashicorp/vault/helper/timeutil"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/vault/activity"
"google.golang.org/protobuf/proto"
)

type HLLGetter func(ctx context.Context, startTime time.Time) (*hyperloglog.Sketch, error)
Expand Down Expand Up @@ -281,3 +283,96 @@ func (a *ActivityLog) mountAccessorToMountPath(mountAccessor string) string {
}
return displayPath
}

type singleTypeSegmentReader struct {
basePath string
startTime time.Time
paths []string
currentPathIndex int
a *ActivityLog
}
type segmentReader struct {
tokens *singleTypeSegmentReader
entities *singleTypeSegmentReader
}

// SegmentReader is an interface that provides methods to read tokens and entities in order
type SegmentReader interface {
ReadToken(ctx context.Context) (*activity.TokenCount, error)
ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error)
}

func (a *ActivityLog) NewSegmentFileReader(ctx context.Context, startTime time.Time) (SegmentReader, error) {
entities, err := a.newSingleTypeSegmentReader(ctx, startTime, activityEntityBasePath)
if err != nil {
return nil, err
}
tokens, err := a.newSingleTypeSegmentReader(ctx, startTime, activityTokenBasePath)
if err != nil {
return nil, err
}
return &segmentReader{entities: entities, tokens: tokens}, nil
}

func (a *ActivityLog) newSingleTypeSegmentReader(ctx context.Context, startTime time.Time, prefix string) (*singleTypeSegmentReader, error) {
basePath := prefix + fmt.Sprint(startTime.Unix()) + "/"
pathList, err := a.view.List(ctx, basePath)
if err != nil {
return nil, err
}
return &singleTypeSegmentReader{
basePath: basePath,
startTime: startTime,
paths: pathList,
currentPathIndex: 0,
a: a,
}, nil
}

func (s *singleTypeSegmentReader) nextValue(ctx context.Context, out proto.Message) error {
var raw *logical.StorageEntry
var path string
for raw == nil {
if s.currentPathIndex >= len(s.paths) {
return io.EOF
}
path = s.paths[s.currentPathIndex]
// increment the index to continue iterating for the next read call, even if an error occurs during this call
s.currentPathIndex++
var err error
raw, err = s.a.view.Get(ctx, s.basePath+path)
if err != nil {
return err
}
if raw == nil {
s.a.logger.Warn("expected log segment file has been deleted", "startTime", s.startTime, "segmentPath", path)
}
}
err := proto.Unmarshal(raw.Value, out)
if err != nil {
return fmt.Errorf("unable to parse segment file %v%v: %w", s.basePath, path, err)
}
return nil
}

// ReadToken reads a token from the segment
// If there is none available, then the error will be io.EOF
func (e *segmentReader) ReadToken(ctx context.Context) (*activity.TokenCount, error) {
out := &activity.TokenCount{}
err := e.tokens.nextValue(ctx, out)
if err != nil {
return nil, err
}
return out, nil
}

// ReadEntity reads an entity from the segment
// If there is none available, then the error will be io.EOF
func (e *segmentReader) ReadEntity(ctx context.Context) (*activity.EntityActivityLog, error) {
out := &activity.EntityActivityLog{}
err := e.entities.nextValue(ctx, out)
if err != nil {
return nil, err
}
return out, nil
}
183 changes: 183 additions & 0 deletions vault/activity_log_util_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ package vault

import (
"context"
"errors"
"fmt"
"io"
"testing"
"time"

"github.com/axiomhq/hyperloglog"
"github.com/hashicorp/vault/helper/timeutil"
"github.com/hashicorp/vault/vault/activity"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
)

func Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal(t *testing.T) {
Expand Down Expand Up @@ -150,3 +155,181 @@ func Test_ActivityLog_ComputeCurrentMonthForBillingPeriodInternal(t *testing.T)
t.Fatalf("wrong number of new non entity clients. Expected 0, got %d", monthRecord.NewClients.Counts.NonEntityClients)
}
}

// writeEntitySegment writes a single segment file with the given time and index for an entity
func writeEntitySegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.EntityActivityLog) {
t.Helper()
protoItem, err := proto.Marshal(item)
require.NoError(t, err)
WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, ts, index), protoItem)
}

// writeTokenSegment writes a single segment file with the given time and index for a token
func writeTokenSegment(t *testing.T, core *Core, ts time.Time, index int, item *activity.TokenCount) {
t.Helper()
protoItem, err := proto.Marshal(item)
require.NoError(t, err)
WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, ts, index), protoItem)
}

// makeSegmentPath formats the path for a segment at a particular time and index
func makeSegmentPath(t *testing.T, typ string, ts time.Time, index int) string {
t.Helper()
return fmt.Sprintf("%s%s%d/%d", ActivityPrefix, typ, ts.Unix(), index)
}

// TestSegmentFileReader_BadData verifies that the reader returns errors when the data is unable to be parsed
// However, the next time that Read*() is called, the reader should still progress and be able to then return any
// valid data without errors
func TestSegmentFileReader_BadData(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
now := time.Now()

// write bad data that won't be able to be unmarshaled at index 0
WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, 0), []byte("fake data"))
WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, 0), []byte("fake data"))

// write entity at index 1
entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
{
ClientID: "id",
},
}}
writeEntitySegment(t, core, now, 1, entity)

// write token at index 1
token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
"ns": 1,
}}
writeTokenSegment(t, core, now, 1, token)
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
require.NoError(t, err)

// first the bad entity is read, which returns an error
_, err = reader.ReadEntity(context.Background())
require.Error(t, err)
// then, the reader can read the good entity at index 1
gotEntity, err := reader.ReadEntity(context.Background())
require.True(t, proto.Equal(gotEntity, entity))
require.Nil(t, err)

// the bad token causes an error
_, err = reader.ReadToken(context.Background())
require.Error(t, err)
// but the good token is able to be read
gotToken, err := reader.ReadToken(context.Background())
require.True(t, proto.Equal(gotToken, token))
require.Nil(t, err)
}

// TestSegmentFileReader_MissingData verifies that the segment file reader will skip over missing segment paths without
// errorring until it is able to find a valid segment path
func TestSegmentFileReader_MissingData(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
now := time.Now()
// write entities and tokens at indexes 0, 1, 2
for i := 0; i < 3; i++ {
WriteToStorage(t, core, makeSegmentPath(t, activityTokenBasePath, now, i), []byte("fake data"))
WriteToStorage(t, core, makeSegmentPath(t, activityEntityBasePath, now, i), []byte("fake data"))

}
// write entity at index 3
entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
{
ClientID: "id",
},
}}
writeEntitySegment(t, core, now, 3, entity)
// write token at index 3
token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
"ns": 1,
}}
writeTokenSegment(t, core, now, 3, token)
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
require.NoError(t, err)

// delete the indexes 0, 1, 2
for i := 0; i < 3; i++ {
require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityTokenBasePath, now, i)))
require.NoError(t, core.barrier.Delete(context.Background(), makeSegmentPath(t, activityEntityBasePath, now, i)))
}

// we expect the reader to only return the data at index 3, and then be done
gotEntity, err := reader.ReadEntity(context.Background())
require.NoError(t, err)
require.True(t, proto.Equal(gotEntity, entity))
_, err = reader.ReadEntity(context.Background())
require.Equal(t, err, io.EOF)

gotToken, err := reader.ReadToken(context.Background())
require.NoError(t, err)
require.True(t, proto.Equal(gotToken, token))
_, err = reader.ReadToken(context.Background())
require.Equal(t, err, io.EOF)
}

// TestSegmentFileReader_NoData verifies that the reader return io.EOF when there is no data
func TestSegmentFileReader_NoData(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
now := time.Now()
reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
require.NoError(t, err)
entity, err := reader.ReadEntity(context.Background())
require.Nil(t, entity)
require.Equal(t, err, io.EOF)
token, err := reader.ReadToken(context.Background())
require.Nil(t, token)
require.Equal(t, err, io.EOF)
}

// TestSegmentFileReader verifies that the reader iterates through all segments paths in ascending order and returns
// io.EOF when it's done
func TestSegmentFileReader(t *testing.T) {
core, _, _ := TestCoreUnsealed(t)
now := time.Now()
entities := make([]*activity.EntityActivityLog, 0, 3)
tokens := make([]*activity.TokenCount, 0, 3)

// write 3 entity segment pieces and 3 token segment pieces
for i := 0; i < 3; i++ {
entity := &activity.EntityActivityLog{Clients: []*activity.EntityRecord{
{
ClientID: fmt.Sprintf("id-%d", i),
},
}}
token := &activity.TokenCount{CountByNamespaceID: map[string]uint64{
fmt.Sprintf("ns-%d", i): uint64(i),
}}
writeEntitySegment(t, core, now, i, entity)
writeTokenSegment(t, core, now, i, token)
entities = append(entities, entity)
tokens = append(tokens, token)
}

reader, err := core.activityLog.NewSegmentFileReader(context.Background(), now)
require.NoError(t, err)

gotEntities := make([]*activity.EntityActivityLog, 0, 3)
gotTokens := make([]*activity.TokenCount, 0, 3)

// read the entities from the reader
for entity, err := reader.ReadEntity(context.Background()); !errors.Is(err, io.EOF); entity, err = reader.ReadEntity(context.Background()) {
require.NoError(t, err)
gotEntities = append(gotEntities, entity)
}

// read the tokens from the reader
for token, err := reader.ReadToken(context.Background()); !errors.Is(err, io.EOF); token, err = reader.ReadToken(context.Background()) {
require.NoError(t, err)
gotTokens = append(gotTokens, token)
}
require.Len(t, gotEntities, 3)
require.Len(t, gotTokens, 3)

// verify that the entities and tokens we got from the reader are correct
// we can't use require.Equals() here because there are protobuf differences in unexported fields
for i := 0; i < 3; i++ {
require.True(t, proto.Equal(gotEntities[i], entities[i]))
require.True(t, proto.Equal(gotTokens[i], tokens[i]))
}
}

0 comments on commit e87f04f

Please sign in to comment.