Skip to content

Commit

Permalink
oss port of vault-7225-bugfix (#16745)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hridoy Roy authored Aug 16, 2022
1 parent ff504ad commit 687b3d1
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 12 deletions.
38 changes: 27 additions & 11 deletions vault/activity_log.go
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,12 @@ func (a *ActivityLog) DefaultStartTime(endTime time.Time) time.Time {

func (a *ActivityLog) handleQuery(ctx context.Context, startTime, endTime time.Time, limitNamespaces int) (map[string]interface{}, error) {
var computePartial bool

// Change the start time to the beginning of the month, and the end time to be the end
// of the month.
startTime = timeutil.StartOfMonth(startTime)
endTime = timeutil.EndOfMonth(endTime)

// If the endTime of the query is the current month, request data from the queryStore
// with the endTime equal to the end of the last month, and add in the current month
// data.
Expand All @@ -1513,16 +1519,26 @@ func (a *ActivityLog) handleQuery(ctx context.Context, startTime, endTime time.T
computePartial = true
}

// From the precomputed queries stored in the queryStore (computed at the end of each month)
// get the query associated with the start and end time specified
pq, err := a.queryStore.Get(ctx, startTime, precomputedQueryEndTime)
if err != nil {
return nil, err
}
if pq == nil {
return nil, nil
pq := &activity.PrecomputedQuery{}
if startTime.After(precomputedQueryEndTime) && timeutil.IsCurrentMonth(startTime, time.Now().UTC()) {
// We're only calculating the partial month client count. Skip the precomputation
// get call.
pq = &activity.PrecomputedQuery{
StartTime: startTime,
EndTime: endTime,
Namespaces: make([]*activity.NamespaceRecord, 0),
Months: make([]*activity.MonthRecord, 0),
}
} else {
storedQuery, err := a.queryStore.Get(ctx, startTime, precomputedQueryEndTime)
if err != nil {
return nil, err
}
if storedQuery == nil {
return nil, nil
}
pq = storedQuery
}

// Calculate the namespace response breakdowns and totals for entities and tokens from the initial
// namespace data.
totalEntities, totalTokens, byNamespaceResponse, err := a.calculateByNamespaceResponseForQuery(ctx, pq.Namespaces)
Expand Down Expand Up @@ -1634,7 +1650,7 @@ func modifyResponseMonths(months []*ResponseMonth, start time.Time, end time.Tim
if err != nil {
return months
}
for start.Before(firstMonth) {
for start.Before(firstMonth) && !timeutil.IsCurrentMonth(start, firstMonth) {
monthPlaceholder := &ResponseMonth{Timestamp: start.UTC().Format(time.RFC3339)}
modifiedResponseMonths = append(modifiedResponseMonths, monthPlaceholder)
start = timeutil.StartOfMonth(start.AddDate(0, 1, 0))
Expand All @@ -1645,7 +1661,7 @@ func modifyResponseMonths(months []*ResponseMonth, start time.Time, end time.Tim
return modifiedResponseMonths
}
lastMonth := timeutil.EndOfMonth(lastMonthStart)
for lastMonth.Before(end) {
for lastMonth.Before(end) && !timeutil.IsCurrentMonth(end, lastMonth) {
lastMonth = timeutil.StartOfMonth(lastMonth).AddDate(0, 1, 0)
monthPlaceholder := &ResponseMonth{Timestamp: lastMonth.UTC().Format(time.RFC3339)}
modifiedResponseMonths = append(modifiedResponseMonths, monthPlaceholder)
Expand Down
126 changes: 126 additions & 0 deletions vault/activity_log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3826,3 +3826,129 @@ func TestActivityLog_partialMonthClientCount(t *testing.T) {
t.Errorf("bad client count. expected %d, got %d", len(clients), clientCount)
}
}

func TestActivityLog_partialMonthClientCountUsingHandleQuery(t *testing.T) {
timeutil.SkipAtEndOfMonth(t)

ctx := namespace.RootContext(nil)
now := time.Now().UTC()
a, clients, _ := setupActivityRecordsInStorage(t, timeutil.StartOfMonth(now), true, true)

// clients[0] belongs to previous month
clients = clients[1:]

clientCounts := make(map[string]uint64)
for _, client := range clients {
clientCounts[client.NamespaceID] += 1
}

a.SetEnable(true)
var wg sync.WaitGroup
err := a.refreshFromStoredLog(ctx, &wg, now)
if err != nil {
t.Fatalf("error loading clients: %v", err)
}
wg.Wait()

results, err := a.handleQuery(ctx, time.Now().UTC(), time.Now().UTC(), 0)
if err != nil {
t.Fatal(err)
}
if results == nil {
t.Fatal("no results to test")
}
if err != nil {
t.Fatal(err)
}
if results == nil {
t.Fatal("no results to test")
}

byNamespace, ok := results["by_namespace"]
if !ok {
t.Fatalf("malformed results. got %v", results)
}

clientCountResponse := make([]*ResponseNamespace, 0)
err = mapstructure.Decode(byNamespace, &clientCountResponse)
if err != nil {
t.Fatal(err)
}

for _, clientCount := range clientCountResponse {
if int(clientCounts[clientCount.NamespaceID]) != clientCount.Counts.DistinctEntities {
t.Errorf("bad entity count for namespace %s . expected %d, got %d", clientCount.NamespaceID, int(clientCounts[clientCount.NamespaceID]), clientCount.Counts.DistinctEntities)
}
totalCount := int(clientCounts[clientCount.NamespaceID])
if totalCount != clientCount.Counts.Clients {
t.Errorf("bad client count for namespace %s . expected %d, got %d", clientCount.NamespaceID, totalCount, clientCount.Counts.Clients)
}
}

totals, ok := results["total"]
if !ok {
t.Fatalf("malformed results. got %v", results)
}
totalCounts := ResponseCounts{}
err = mapstructure.Decode(totals, &totalCounts)
distinctEntities := totalCounts.DistinctEntities
if distinctEntities != len(clients) {
t.Errorf("bad entity count. expected %d, got %d", len(clients), distinctEntities)
}

clientCount := totalCounts.Clients
if clientCount != len(clients) {
t.Errorf("bad client count. expected %d, got %d", len(clients), clientCount)
}
// Ensure that the month response is the same as the totals, because all clients
// are new clients and there will be no approximation in the single month partial
// case
monthsRaw, ok := results["months"]
if !ok {
t.Fatalf("malformed results. got %v", results)
}
monthsResponse := make([]ResponseMonth, 0)
err = mapstructure.Decode(monthsRaw, &monthsResponse)
if len(monthsResponse) != 1 {
t.Fatalf("wrong number of months returned. got %v", monthsResponse)
}
if monthsResponse[0].Counts.Clients != totalCounts.Clients {
t.Fatalf("wrong client count. got %v, expected %v", monthsResponse[0].Counts.Clients, totalCounts.Clients)
}
if monthsResponse[0].Counts.EntityClients != totalCounts.EntityClients {
t.Fatalf("wrong entity client count. got %v, expected %v", monthsResponse[0].Counts.EntityClients, totalCounts.EntityClients)
}
if monthsResponse[0].Counts.NonEntityClients != totalCounts.NonEntityClients {
t.Fatalf("wrong non-entity client count. got %v, expected %v", monthsResponse[0].Counts.NonEntityClients, totalCounts.NonEntityClients)
}
if monthsResponse[0].Counts.NonEntityTokens != totalCounts.NonEntityTokens {
t.Fatalf("wrong non-entity client count. got %v, expected %v", monthsResponse[0].Counts.NonEntityTokens, totalCounts.NonEntityTokens)
}
if monthsResponse[0].Counts.Clients != monthsResponse[0].NewClients.Counts.Clients {
t.Fatalf("wrong client count. got %v, expected %v", monthsResponse[0].Counts.Clients, monthsResponse[0].NewClients.Counts.Clients)
}
if monthsResponse[0].Counts.DistinctEntities != monthsResponse[0].NewClients.Counts.DistinctEntities {
t.Fatalf("wrong distinct entities count. got %v, expected %v", monthsResponse[0].Counts.DistinctEntities, monthsResponse[0].NewClients.Counts.DistinctEntities)
}
if monthsResponse[0].Counts.EntityClients != monthsResponse[0].NewClients.Counts.EntityClients {
t.Fatalf("wrong entity client count. got %v, expected %v", monthsResponse[0].Counts.EntityClients, monthsResponse[0].NewClients.Counts.EntityClients)
}
if monthsResponse[0].Counts.NonEntityClients != monthsResponse[0].NewClients.Counts.NonEntityClients {
t.Fatalf("wrong non-entity client count. got %v, expected %v", monthsResponse[0].Counts.NonEntityClients, monthsResponse[0].NewClients.Counts.NonEntityClients)
}
if monthsResponse[0].Counts.NonEntityTokens != monthsResponse[0].NewClients.Counts.NonEntityTokens {
t.Fatalf("wrong non-entity token count. got %v, expected %v", monthsResponse[0].Counts.NonEntityTokens, monthsResponse[0].NewClients.Counts.NonEntityTokens)
}

namespaceResponseMonth := monthsResponse[0].Namespaces

for _, clientCount := range namespaceResponseMonth {
if int(clientCounts[clientCount.NamespaceID]) != clientCount.Counts.EntityClients {
t.Errorf("bad entity count for namespace %s . expected %d, got %d", clientCount.NamespaceID, int(clientCounts[clientCount.NamespaceID]), clientCount.Counts.DistinctEntities)
}
totalCount := int(clientCounts[clientCount.NamespaceID])
if totalCount != clientCount.Counts.Clients {
t.Errorf("bad client count for namespace %s . expected %d, got %d", clientCount.NamespaceID, totalCount, clientCount.Counts.Clients)
}
}
}
10 changes: 9 additions & 1 deletion vault/activity_log_util_common.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ func (a *ActivityLog) StoreHyperlogLog(ctx context.Context, startTime time.Time,
}

func (a *ActivityLog) computeCurrentMonthForBillingPeriodInternal(ctx context.Context, byMonth map[int64]*processMonth, hllGetFunc HLLGetter, startTime time.Time, endTime time.Time) (*activity.MonthRecord, error) {
if timeutil.IsCurrentMonth(startTime, time.Now().UTC()) {
monthlyComputation := a.transformMonthBreakdowns(byMonth)
if len(monthlyComputation) > 1 {
a.logger.Warn("monthly in-memory activitylog computation returned multiple months of data", "months returned", len(byMonth))
}
if len(monthlyComputation) >= 0 {
return monthlyComputation[0], nil
}
}
// Fetch all hyperloglogs for months from startMonth to endMonth. If a month doesn't have an associated
// hll, warn and continue.

Expand Down Expand Up @@ -144,7 +153,6 @@ func (a *ActivityLog) computeCurrentMonthForBillingPeriodInternal(ctx context.Co
// the current month's entities minus the size of the initial billing period hll.
currentMonthNewEntities := billingPeriodHLLWithCurrentMonthEntityClients.Estimate() - billingPeriodHLL.Estimate()
currentMonthNewNonEntities := billingPeriodHLLWithCurrentMonthNonEntityClients.Estimate() - billingPeriodHLL.Estimate()

return &activity.MonthRecord{
Timestamp: timeutil.StartOfMonth(endTime).UTC().Unix(),
NewClients: &activity.NewClientRecord{Counts: &activity.CountsRecord{EntityClients: int(currentMonthNewEntities), NonEntityClients: int(currentMonthNewNonEntities)}},
Expand Down

0 comments on commit 687b3d1

Please sign in to comment.