diff --git a/vault/activity_log.go b/vault/activity_log.go index e9bc03ea1b8b..211af1c9090b 100644 --- a/vault/activity_log.go +++ b/vault/activity_log.go @@ -1504,6 +1504,12 @@ func (a *ActivityLog) DefaultStartTime(endTime time.Time) time.Time { func (a *ActivityLog) handleQuery(ctx context.Context, startTime, endTime time.Time, limitNamespaces int) (map[string]interface{}, error) { var computePartial bool + + // Change the start time to the beginning of the month, and the end time to be the end + // of the month. + startTime = timeutil.StartOfMonth(startTime) + endTime = timeutil.EndOfMonth(endTime) + // If the endTime of the query is the current month, request data from the queryStore // with the endTime equal to the end of the last month, and add in the current month // data. @@ -1513,16 +1519,26 @@ func (a *ActivityLog) handleQuery(ctx context.Context, startTime, endTime time.T computePartial = true } - // From the precomputed queries stored in the queryStore (computed at the end of each month) - // get the query associated with the start and end time specified - pq, err := a.queryStore.Get(ctx, startTime, precomputedQueryEndTime) - if err != nil { - return nil, err - } - if pq == nil { - return nil, nil + pq := &activity.PrecomputedQuery{} + if startTime.After(precomputedQueryEndTime) && timeutil.IsCurrentMonth(startTime, time.Now().UTC()) { + // We're only calculating the partial month client count. Skip the precomputation + // get call. + pq = &activity.PrecomputedQuery{ + StartTime: startTime, + EndTime: endTime, + Namespaces: make([]*activity.NamespaceRecord, 0), + Months: make([]*activity.MonthRecord, 0), + } + } else { + storedQuery, err := a.queryStore.Get(ctx, startTime, precomputedQueryEndTime) + if err != nil { + return nil, err + } + if storedQuery == nil { + return nil, nil + } + pq = storedQuery } - // Calculate the namespace response breakdowns and totals for entities and tokens from the initial // namespace data. totalEntities, totalTokens, byNamespaceResponse, err := a.calculateByNamespaceResponseForQuery(ctx, pq.Namespaces) @@ -1634,7 +1650,7 @@ func modifyResponseMonths(months []*ResponseMonth, start time.Time, end time.Tim if err != nil { return months } - for start.Before(firstMonth) { + for start.Before(firstMonth) && !timeutil.IsCurrentMonth(start, firstMonth) { monthPlaceholder := &ResponseMonth{Timestamp: start.UTC().Format(time.RFC3339)} modifiedResponseMonths = append(modifiedResponseMonths, monthPlaceholder) start = timeutil.StartOfMonth(start.AddDate(0, 1, 0)) @@ -1645,7 +1661,7 @@ func modifyResponseMonths(months []*ResponseMonth, start time.Time, end time.Tim return modifiedResponseMonths } lastMonth := timeutil.EndOfMonth(lastMonthStart) - for lastMonth.Before(end) { + for lastMonth.Before(end) && !timeutil.IsCurrentMonth(end, lastMonth) { lastMonth = timeutil.StartOfMonth(lastMonth).AddDate(0, 1, 0) monthPlaceholder := &ResponseMonth{Timestamp: lastMonth.UTC().Format(time.RFC3339)} modifiedResponseMonths = append(modifiedResponseMonths, monthPlaceholder) diff --git a/vault/activity_log_test.go b/vault/activity_log_test.go index ca434fbb205c..dd25bacf0f89 100644 --- a/vault/activity_log_test.go +++ b/vault/activity_log_test.go @@ -3826,3 +3826,129 @@ func TestActivityLog_partialMonthClientCount(t *testing.T) { t.Errorf("bad client count. expected %d, got %d", len(clients), clientCount) } } + +func TestActivityLog_partialMonthClientCountUsingHandleQuery(t *testing.T) { + timeutil.SkipAtEndOfMonth(t) + + ctx := namespace.RootContext(nil) + now := time.Now().UTC() + a, clients, _ := setupActivityRecordsInStorage(t, timeutil.StartOfMonth(now), true, true) + + // clients[0] belongs to previous month + clients = clients[1:] + + clientCounts := make(map[string]uint64) + for _, client := range clients { + clientCounts[client.NamespaceID] += 1 + } + + a.SetEnable(true) + var wg sync.WaitGroup + err := a.refreshFromStoredLog(ctx, &wg, now) + if err != nil { + t.Fatalf("error loading clients: %v", err) + } + wg.Wait() + + results, err := a.handleQuery(ctx, time.Now().UTC(), time.Now().UTC(), 0) + if err != nil { + t.Fatal(err) + } + if results == nil { + t.Fatal("no results to test") + } + if err != nil { + t.Fatal(err) + } + if results == nil { + t.Fatal("no results to test") + } + + byNamespace, ok := results["by_namespace"] + if !ok { + t.Fatalf("malformed results. got %v", results) + } + + clientCountResponse := make([]*ResponseNamespace, 0) + err = mapstructure.Decode(byNamespace, &clientCountResponse) + if err != nil { + t.Fatal(err) + } + + for _, clientCount := range clientCountResponse { + if int(clientCounts[clientCount.NamespaceID]) != clientCount.Counts.DistinctEntities { + t.Errorf("bad entity count for namespace %s . expected %d, got %d", clientCount.NamespaceID, int(clientCounts[clientCount.NamespaceID]), clientCount.Counts.DistinctEntities) + } + totalCount := int(clientCounts[clientCount.NamespaceID]) + if totalCount != clientCount.Counts.Clients { + t.Errorf("bad client count for namespace %s . expected %d, got %d", clientCount.NamespaceID, totalCount, clientCount.Counts.Clients) + } + } + + totals, ok := results["total"] + if !ok { + t.Fatalf("malformed results. got %v", results) + } + totalCounts := ResponseCounts{} + err = mapstructure.Decode(totals, &totalCounts) + distinctEntities := totalCounts.DistinctEntities + if distinctEntities != len(clients) { + t.Errorf("bad entity count. expected %d, got %d", len(clients), distinctEntities) + } + + clientCount := totalCounts.Clients + if clientCount != len(clients) { + t.Errorf("bad client count. expected %d, got %d", len(clients), clientCount) + } + // Ensure that the month response is the same as the totals, because all clients + // are new clients and there will be no approximation in the single month partial + // case + monthsRaw, ok := results["months"] + if !ok { + t.Fatalf("malformed results. got %v", results) + } + monthsResponse := make([]ResponseMonth, 0) + err = mapstructure.Decode(monthsRaw, &monthsResponse) + if len(monthsResponse) != 1 { + t.Fatalf("wrong number of months returned. got %v", monthsResponse) + } + if monthsResponse[0].Counts.Clients != totalCounts.Clients { + t.Fatalf("wrong client count. got %v, expected %v", monthsResponse[0].Counts.Clients, totalCounts.Clients) + } + if monthsResponse[0].Counts.EntityClients != totalCounts.EntityClients { + t.Fatalf("wrong entity client count. got %v, expected %v", monthsResponse[0].Counts.EntityClients, totalCounts.EntityClients) + } + if monthsResponse[0].Counts.NonEntityClients != totalCounts.NonEntityClients { + t.Fatalf("wrong non-entity client count. got %v, expected %v", monthsResponse[0].Counts.NonEntityClients, totalCounts.NonEntityClients) + } + if monthsResponse[0].Counts.NonEntityTokens != totalCounts.NonEntityTokens { + t.Fatalf("wrong non-entity client count. got %v, expected %v", monthsResponse[0].Counts.NonEntityTokens, totalCounts.NonEntityTokens) + } + if monthsResponse[0].Counts.Clients != monthsResponse[0].NewClients.Counts.Clients { + t.Fatalf("wrong client count. got %v, expected %v", monthsResponse[0].Counts.Clients, monthsResponse[0].NewClients.Counts.Clients) + } + if monthsResponse[0].Counts.DistinctEntities != monthsResponse[0].NewClients.Counts.DistinctEntities { + t.Fatalf("wrong distinct entities count. got %v, expected %v", monthsResponse[0].Counts.DistinctEntities, monthsResponse[0].NewClients.Counts.DistinctEntities) + } + if monthsResponse[0].Counts.EntityClients != monthsResponse[0].NewClients.Counts.EntityClients { + t.Fatalf("wrong entity client count. got %v, expected %v", monthsResponse[0].Counts.EntityClients, monthsResponse[0].NewClients.Counts.EntityClients) + } + if monthsResponse[0].Counts.NonEntityClients != monthsResponse[0].NewClients.Counts.NonEntityClients { + t.Fatalf("wrong non-entity client count. got %v, expected %v", monthsResponse[0].Counts.NonEntityClients, monthsResponse[0].NewClients.Counts.NonEntityClients) + } + if monthsResponse[0].Counts.NonEntityTokens != monthsResponse[0].NewClients.Counts.NonEntityTokens { + t.Fatalf("wrong non-entity token count. got %v, expected %v", monthsResponse[0].Counts.NonEntityTokens, monthsResponse[0].NewClients.Counts.NonEntityTokens) + } + + namespaceResponseMonth := monthsResponse[0].Namespaces + + for _, clientCount := range namespaceResponseMonth { + if int(clientCounts[clientCount.NamespaceID]) != clientCount.Counts.EntityClients { + t.Errorf("bad entity count for namespace %s . expected %d, got %d", clientCount.NamespaceID, int(clientCounts[clientCount.NamespaceID]), clientCount.Counts.DistinctEntities) + } + totalCount := int(clientCounts[clientCount.NamespaceID]) + if totalCount != clientCount.Counts.Clients { + t.Errorf("bad client count for namespace %s . expected %d, got %d", clientCount.NamespaceID, totalCount, clientCount.Counts.Clients) + } + } +} diff --git a/vault/activity_log_util_common.go b/vault/activity_log_util_common.go index 65b75fd8abd5..db45a8ee54b9 100644 --- a/vault/activity_log_util_common.go +++ b/vault/activity_log_util_common.go @@ -69,6 +69,15 @@ func (a *ActivityLog) StoreHyperlogLog(ctx context.Context, startTime time.Time, } func (a *ActivityLog) computeCurrentMonthForBillingPeriodInternal(ctx context.Context, byMonth map[int64]*processMonth, hllGetFunc HLLGetter, startTime time.Time, endTime time.Time) (*activity.MonthRecord, error) { + if timeutil.IsCurrentMonth(startTime, time.Now().UTC()) { + monthlyComputation := a.transformMonthBreakdowns(byMonth) + if len(monthlyComputation) > 1 { + a.logger.Warn("monthly in-memory activitylog computation returned multiple months of data", "months returned", len(byMonth)) + } + if len(monthlyComputation) >= 0 { + return monthlyComputation[0], nil + } + } // Fetch all hyperloglogs for months from startMonth to endMonth. If a month doesn't have an associated // hll, warn and continue. @@ -144,7 +153,6 @@ func (a *ActivityLog) computeCurrentMonthForBillingPeriodInternal(ctx context.Co // the current month's entities minus the size of the initial billing period hll. currentMonthNewEntities := billingPeriodHLLWithCurrentMonthEntityClients.Estimate() - billingPeriodHLL.Estimate() currentMonthNewNonEntities := billingPeriodHLLWithCurrentMonthNonEntityClients.Estimate() - billingPeriodHLL.Estimate() - return &activity.MonthRecord{ Timestamp: timeutil.StartOfMonth(endTime).UTC().Unix(), NewClients: &activity.NewClientRecord{Counts: &activity.CountsRecord{EntityClients: int(currentMonthNewEntities), NonEntityClients: int(currentMonthNewNonEntities)}},