diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index ee174e9b9b5b9..6588db91409de 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -1310,7 +1310,7 @@ func (s *APIServer) searchSessionEvents(auth ClientI, w http.ResponseWriter, r * } } // only pull back start and end events to build list of completed sessions - eventsList, _, err := auth.SearchSessionEvents(from, to, limit, types.EventOrderDescending, "", nil) + eventsList, _, err := auth.SearchSessionEvents(from, to, limit, types.EventOrderDescending, "", nil, "") if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index bb5bab0135234..f8efcb6e4af06 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -2955,7 +2955,7 @@ func (a *ServerWithRoles) findSessionEndEvent(namespace string, sid session.ID) &types.WhereExpr{Equals: types.WhereExpr2{ L: &types.WhereExpr{Field: events.SessionEventID}, R: &types.WhereExpr{Literal: sid.String()}, - }}, + }}, sid.String(), ) if err != nil { return nil, trace.Wrap(err) @@ -4149,7 +4149,7 @@ func (a *ServerWithRoles) SearchEvents(fromUTC, toUTC time.Time, namespace strin } // SearchSessionEvents allows searching session audit events with pagination support. -func (a *ServerWithRoles) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr) (events []apievents.AuditEvent, lastKey string, err error) { +func (a *ServerWithRoles) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) (events []apievents.AuditEvent, lastKey string, err error) { if cond != nil { return nil, "", trace.BadParameter("cond is an internal parameter, should not be set by client") } @@ -4160,7 +4160,7 @@ func (a *ServerWithRoles) SearchSessionEvents(fromUTC, toUTC time.Time, limit in } // TODO(codingllama): Refactor cond out of SearchSessionEvents and simplify signature. - events, lastKey, err = a.alog.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond) + events, lastKey, err = a.alog.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond, sessionID) if err != nil { return nil, "", trace.Wrap(err) } diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 98400993c4d65..c8635971e0492 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -1145,7 +1145,7 @@ func (c *Client) SearchEvents(fromUTC, toUTC time.Time, namespace string, eventT } // SearchSessionEvents returns session related events to find completed sessions. -func (c *Client) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr) ([]apievents.AuditEvent, string, error) { +func (c *Client) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) { events, lastKey, err := c.APIClient.SearchSessionEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey) if err != nil { return nil, "", trace.Wrap(err) diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index d25c9a257ebd8..d1786b6d082c4 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -3208,7 +3208,7 @@ func (g *GRPCServer) GetSessionEvents(ctx context.Context, req *proto.GetSession return nil, trace.Wrap(err) } - rawEvents, lastkey, err := auth.ServerWithRoles.SearchSessionEvents(req.StartDate, req.EndDate, int(req.Limit), types.EventOrder(req.Order), req.StartKey, nil) + rawEvents, lastkey, err := auth.ServerWithRoles.SearchSessionEvents(req.StartDate, req.EndDate, int(req.Limit), types.EventOrder(req.Order), req.StartKey, nil, "") if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/events/api.go b/lib/events/api.go index af48a37428adf..2ee156087ab68 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -733,7 +733,7 @@ type IAuditLog interface { // a query to be resumed. // // This function may never return more than 1 MiB of event data. - SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr) ([]apievents.AuditEvent, string, error) + SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) // StreamSessionEvents streams all events from a given session recording. An error is returned on the first // channel if one is encountered. Otherwise the event channel is closed when the stream ends. diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 32d66b0787698..d673c375bdf20 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -31,6 +31,11 @@ import ( "sync" "time" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" @@ -38,10 +43,6 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" ) const ( @@ -870,12 +871,12 @@ func (l *AuditLog) SearchEvents(fromUTC, toUTC time.Time, namespace string, even return l.localLog.SearchEvents(fromUTC, toUTC, namespace, eventType, limit, order, startKey) } -func (l *AuditLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr) ([]apievents.AuditEvent, string, error) { +func (l *AuditLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) { l.log.Debugf("SearchSessionEvents(%v, %v, %v)", fromUTC, toUTC, limit) if l.ExternalLog != nil { - return l.ExternalLog.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond) + return l.ExternalLog.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond, sessionID) } - return l.localLog.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond) + return l.localLog.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond, sessionID) } // StreamSessionEvents streams all events from a given session recording. An error is returned on the first diff --git a/lib/events/discard.go b/lib/events/discard.go index 3d2ff5081ace6..3f9d25ce02be8 100644 --- a/lib/events/discard.go +++ b/lib/events/discard.go @@ -47,7 +47,7 @@ func (d *DiscardAuditLog) GetSessionEvents(namespace string, sid session.ID, aft func (d *DiscardAuditLog) SearchEvents(fromUTC, toUTC time.Time, namespace string, eventType []string, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) { return make([]apievents.AuditEvent, 0), "", nil } -func (d *DiscardAuditLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr) ([]apievents.AuditEvent, string, error) { +func (d *DiscardAuditLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) { return make([]apievents.AuditEvent, 0), "", nil } func (d *DiscardAuditLog) EmitAuditEvent(ctx context.Context, event apievents.AuditEvent) error { diff --git a/lib/events/dynamoevents/dynamoevents.go b/lib/events/dynamoevents/dynamoevents.go index fd6241b34b1c4..c7ff88a8a2102 100644 --- a/lib/events/dynamoevents/dynamoevents.go +++ b/lib/events/dynamoevents/dynamoevents.go @@ -549,11 +549,11 @@ type checkpointKey struct { // // This function may never return more than 1 MiB of event data. func (l *Log) SearchEvents(fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) { - return l.searchEventsWithFilter(fromUTC, toUTC, namespace, limit, order, startKey, searchEventsFilter{eventTypes: eventTypes}) + return l.searchEventsWithFilter(fromUTC, toUTC, namespace, limit, order, startKey, searchEventsFilter{eventTypes: eventTypes}, "") } -func (l *Log) searchEventsWithFilter(fromUTC, toUTC time.Time, namespace string, limit int, order types.EventOrder, startKey string, filter searchEventsFilter) ([]apievents.AuditEvent, string, error) { - rawEvents, lastKey, err := l.searchEventsRaw(fromUTC, toUTC, namespace, limit, order, startKey, filter) +func (l *Log) searchEventsWithFilter(fromUTC, toUTC time.Time, namespace string, limit int, order types.EventOrder, startKey string, filter searchEventsFilter, sessionID string) ([]apievents.AuditEvent, string, error) { + rawEvents, lastKey, err := l.searchEventsRaw(fromUTC, toUTC, namespace, limit, order, startKey, filter, sessionID) if err != nil { return nil, "", trace.Wrap(err) } @@ -629,25 +629,19 @@ func reverseStrings(slice []string) []string { // searchEventsRaw is a low level function for searching for events. This is kept // separate from the SearchEvents function in order to allow tests to grab more metadata. -func (l *Log) searchEventsRaw(fromUTC, toUTC time.Time, namespace string, limit int, order types.EventOrder, startKey string, filter searchEventsFilter) ([]event, string, error) { - var checkpoint checkpointKey - - // If a checkpoint key is provided, unmarshal it so we can work with it's parts. - if startKey != "" { - if err := json.Unmarshal([]byte(startKey), &checkpoint); err != nil { - return nil, "", trace.Wrap(err) - } +func (l *Log) searchEventsRaw(fromUTC, toUTC time.Time, namespace string, limit int, order types.EventOrder, startKey string, filter searchEventsFilter, sessionID string) ([]event, string, error) { + checkpoint, err := getCheckpointFromStartKey(startKey) + if err != nil { + return nil, "", trace.Wrap(err) } - var values []event totalSize := 0 dates := daysBetween(fromUTC, toUTC) if order == types.EventOrderDescending { dates = reverseStrings(dates) } - query := "CreatedAtDate = :date AND CreatedAt BETWEEN :start and :end" - g := l.WithFields(log.Fields{"From": fromUTC, "To": toUTC, "Namespace": namespace, "Filter": filter, "Limit": limit, "StartKey": startKey, "Order": order}) + indexName := aws.String(indexTimeSearchV2) var left int64 if limit != 0 { left = int64(limit) @@ -655,19 +649,6 @@ func (l *Log) searchEventsRaw(fromUTC, toUTC time.Time, namespace string, limit left = math.MaxInt64 } - var filterConds []string - if len(filter.eventTypes) > 0 { - typeList := eventFilterList(len(filter.eventTypes)) - filterConds = append(filterConds, fmt.Sprintf("EventType IN %s", typeList)) - } - if filter.condExpr != "" { - filterConds = append(filterConds, filter.condExpr) - } - var filterExpr *string - if len(filterConds) > 0 { - filterExpr = aws.String(strings.Join(filterConds, " AND ")) - } - // Resume scanning at the correct date. We need to do this because we send individual queries per date // and you can't resume a query with the wrong iterator checkpoint. // @@ -679,7 +660,6 @@ func (l *Log) searchEventsRaw(fromUTC, toUTC time.Time, namespace string, limit } } - hasLeft := false foundStart := checkpoint.EventKey == "" var forward bool @@ -692,117 +672,49 @@ func (l *Log) searchEventsRaw(fromUTC, toUTC time.Time, namespace string, limit return nil, "", trace.BadParameter("invalid event order: %v", order) } - var attributeNames map[string]*string - if len(filter.condParams.attrNames) > 0 { - attributeNames = aws.StringMap(filter.condParams.attrNames) + logger := l.WithFields(log.Fields{ + "From": fromUTC, + "To": toUTC, + "Namespace": namespace, + "Filter": filter, + "Limit": limit, + "StartKey": startKey, + "Order": order, + }) + + ef := eventsFetcher{ + log: logger, + totalSize: totalSize, + checkpoint: &checkpoint, + foundStart: foundStart, + dates: dates, + left: left, + fromUTC: fromUTC, + toUTC: toUTC, + tableName: l.Tablename, + api: l.svc, + forward: forward, + indexName: indexName, + filter: filter, } - // This is the main query loop, here we send individual queries for each date and - // we stop if we hit `limit` or process all dates, whichever comes first. -dateLoop: - for i, date := range dates { - checkpoint.Date = date + filterExpr := getExprFilter(filter) - attributes := map[string]interface{}{ - ":date": date, - ":start": fromUTC.Unix(), - ":end": toUTC.Unix(), - } - - for i, eventType := range filter.eventTypes { - attributes[fmt.Sprintf(":eventType%d", i)] = eventType - } - for k, v := range filter.condParams.attrValues { - attributes[k] = v - } - - attributeValues, err := dynamodbattribute.MarshalMap(attributes) + var values []event + if fromUTC.IsZero() && sessionID != "" { + values, err = ef.QueryBySessionIDIndex(sessionID, filterExpr) if err != nil { return nil, "", trace.Wrap(err) } - - for { - input := dynamodb.QueryInput{ - KeyConditionExpression: aws.String(query), - TableName: aws.String(l.Tablename), - ExpressionAttributeNames: attributeNames, - ExpressionAttributeValues: attributeValues, - IndexName: aws.String(indexTimeSearchV2), - ExclusiveStartKey: checkpoint.Iterator, - Limit: aws.Int64(left), - FilterExpression: filterExpr, - ScanIndexForward: aws.Bool(forward), - } - - start := time.Now() - out, err := l.svc.Query(&input) - if err != nil { - return nil, "", trace.Wrap(err) - } - g.WithFields(log.Fields{"duration": time.Since(start), "items": len(out.Items), "forward": forward, "iterator": checkpoint.Iterator}).Debugf("Query completed.") - oldIterator := checkpoint.Iterator - checkpoint.Iterator = out.LastEvaluatedKey - - for _, item := range out.Items { - var e event - if err := dynamodbattribute.UnmarshalMap(item, &e); err != nil { - return nil, "", trace.WrapWithMessage(err, "failed to unmarshal event") - } - data, err := json.Marshal(e.FieldsMap) - if err != nil { - return nil, "", trace.Wrap(err) - } - - if !foundStart { - key, err := getSubPageCheckpoint(&e) - if err != nil { - return nil, "", trace.Wrap(err) - } - - if key != checkpoint.EventKey { - continue - } - - foundStart = true - } - - // Because this may break on non page boundaries an additional - // checkpoint is needed for sub-page breaks. - if totalSize+len(data) >= events.MaxEventBytesInResponse { - hasLeft = i+1 != len(dates) || len(checkpoint.Iterator) != 0 - - key, err := getSubPageCheckpoint(&e) - if err != nil { - return nil, "", trace.Wrap(err) - } - checkpoint.EventKey = key - - // We need to reset the iterator so we get the previous page again. - checkpoint.Iterator = oldIterator - break dateLoop - } - - totalSize += len(data) - values = append(values, e) - left-- - - if left == 0 { - hasLeft = i+1 != len(dates) || len(checkpoint.Iterator) != 0 - checkpoint.EventKey = "" - break dateLoop - } - } - - if len(checkpoint.Iterator) == 0 { - continue dateLoop - } + } else { + values, err = ef.QueryByDateIndex(filterExpr) + if err != nil { + return nil, "", trace.Wrap(err) } } var lastKey []byte - var err error - - if hasLeft { + if ef.hasLeft { lastKey, err = json.Marshal(&checkpoint) if err != nil { return nil, "", trace.Wrap(err) @@ -812,6 +724,34 @@ dateLoop: return values, string(lastKey), nil } +func getCheckpointFromStartKey(startKey string) (checkpointKey, error) { + var checkpoint checkpointKey + if startKey == "" { + return checkpoint, nil + } + // If a checkpoint key is provided, unmarshal it so we can work with it's parts. + if err := json.Unmarshal([]byte(startKey), &checkpoint); err != nil { + return checkpoint, trace.Wrap(err) + } + return checkpoint, nil +} + +func getExprFilter(filter searchEventsFilter) *string { + var filterConds []string + if len(filter.eventTypes) > 0 { + typeList := eventFilterList(len(filter.eventTypes)) + filterConds = append(filterConds, fmt.Sprintf("EventType IN %s", typeList)) + } + if filter.condExpr != "" { + filterConds = append(filterConds, filter.condExpr) + } + var filterExpr *string + if len(filterConds) > 0 { + filterExpr = aws.String(strings.Join(filterConds, " AND ")) + } + return filterExpr +} + func getSubPageCheckpoint(e *event) (string, error) { data, err := utils.FastMarshal(e) if err != nil { @@ -824,7 +764,7 @@ func getSubPageCheckpoint(e *event) (string, error) { // SearchSessionEvents returns session related events only. This is used to // find completed session. -func (l *Log) SearchSessionEvents(fromUTC time.Time, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr) ([]apievents.AuditEvent, string, error) { +func (l *Log) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) { filter := searchEventsFilter{eventTypes: []string{events.SessionEndEvent, events.WindowsDesktopSessionEndEvent}} if cond != nil { params := condFilterParams{attrValues: make(map[string]interface{}), attrNames: make(map[string]string)} @@ -835,7 +775,7 @@ func (l *Log) SearchSessionEvents(fromUTC time.Time, toUTC time.Time, limit int, filter.condExpr = expr filter.condParams = params } - return l.searchEventsWithFilter(fromUTC, toUTC, apidefaults.Namespace, limit, order, startKey, filter) + return l.searchEventsWithFilter(fromUTC, toUTC, apidefaults.Namespace, limit, order, startKey, filter, sessionID) } type searchEventsFilter struct { @@ -1150,3 +1090,210 @@ func (l *Log) StreamSessionEvents(ctx context.Context, sessionID session.ID, sta e <- trace.NotImplemented("not implemented") return c, e } + +type query interface { + Query(input *dynamodb.QueryInput) (*dynamodb.QueryOutput, error) +} + +type eventsFetcher struct { + log *log.Entry + api query + + totalSize int + hasLeft bool + checkpoint *checkpointKey + foundStart bool + dates []string + left int64 + + fromUTC time.Time + toUTC time.Time + tableName string + forward bool + indexName *string + filter searchEventsFilter +} + +func (l *eventsFetcher) processQueryOutput(output *dynamodb.QueryOutput, hasLeftFun func() bool) ([]event, bool, error) { + var out []event + oldIterator := l.checkpoint.Iterator + l.checkpoint.Iterator = output.LastEvaluatedKey + + for _, item := range output.Items { + var e event + if err := dynamodbattribute.UnmarshalMap(item, &e); err != nil { + return nil, false, trace.WrapWithMessage(err, "failed to unmarshal event") + } + data, err := json.Marshal(e.FieldsMap) + if err != nil { + return nil, false, trace.Wrap(err) + } + if !l.foundStart { + key, err := getSubPageCheckpoint(&e) + if err != nil { + return nil, false, trace.Wrap(err) + } + + if key != l.checkpoint.EventKey { + continue + } + l.foundStart = true + } + // Because this may break on non page boundaries an additional + // checkpoint is needed for sub-page breaks. + if l.totalSize+len(data) >= events.MaxEventBytesInResponse { + hf := false + if hasLeftFun != nil { + hf = hasLeftFun() + } + l.hasLeft = hf || len(l.checkpoint.Iterator) != 0 + + key, err := getSubPageCheckpoint(&e) + if err != nil { + return nil, false, trace.Wrap(err) + } + l.checkpoint.EventKey = key + + // We need to reset the iterator so we get the previous page again. + l.checkpoint.Iterator = oldIterator + return out, true, nil + } + l.totalSize += len(data) + out = append(out, e) + l.left-- + + if l.left == 0 { + hf := false + if hasLeftFun != nil { + hf = hasLeftFun() + } + l.hasLeft = hf || len(l.checkpoint.Iterator) != 0 + l.checkpoint.EventKey = "" + return out, true, nil + } + } + return out, false, nil +} + +func (l *eventsFetcher) QueryByDateIndex(filterExpr *string) (values []event, err error) { + query := "CreatedAtDate = :date AND CreatedAt BETWEEN :start and :end" + var attributeNames map[string]*string + if len(l.filter.condParams.attrNames) > 0 { + attributeNames = aws.StringMap(l.filter.condParams.attrNames) + } + +dateLoop: + for i, date := range l.dates { + l.checkpoint.Date = date + + attributes := map[string]interface{}{ + ":date": date, + ":start": l.fromUTC.Unix(), + ":end": l.toUTC.Unix(), + } + for i, eventType := range l.filter.eventTypes { + attributes[fmt.Sprintf(":eventType%d", i)] = eventType + } + for k, v := range l.filter.condParams.attrValues { + attributes[k] = v + } + attributeValues, err := dynamodbattribute.MarshalMap(attributes) + if err != nil { + return nil, trace.Wrap(err) + } + for { + input := dynamodb.QueryInput{ + KeyConditionExpression: aws.String(query), + TableName: aws.String(l.tableName), + ExpressionAttributeNames: attributeNames, + ExpressionAttributeValues: attributeValues, + IndexName: aws.String(indexTimeSearchV2), + ExclusiveStartKey: l.checkpoint.Iterator, + Limit: aws.Int64(l.left), + FilterExpression: filterExpr, + ScanIndexForward: aws.Bool(l.forward), + } + start := time.Now() + out, err := l.api.Query(&input) + if err != nil { + return nil, trace.Wrap(err) + } + l.log.WithFields(log.Fields{ + "duration": time.Since(start), + "items": len(out.Items), + "forward": l.forward, + "iterator": l.checkpoint.Iterator, + }).Debugf("Query completed.") + + hasLeft := func() bool { + return i+1 != len(l.dates) + } + result, limitReached, err := l.processQueryOutput(out, hasLeft) + if err != nil { + return nil, trace.Wrap(err) + } + values = append(values, result...) + if limitReached { + return values, nil + } + if len(l.checkpoint.Iterator) == 0 { + continue dateLoop + } + } + } + return values, nil +} + +func (l *eventsFetcher) QueryBySessionIDIndex(sessionID string, filterExpr *string) (values []event, err error) { + query := "SessionID = :id" + var attributeNames map[string]*string + if len(l.filter.condParams.attrNames) > 0 { + attributeNames = aws.StringMap(l.filter.condParams.attrNames) + } + + attributes := map[string]interface{}{ + ":id": sessionID, + } + for i, eventType := range l.filter.eventTypes { + attributes[fmt.Sprintf(":eventType%d", i)] = eventType + } + for k, v := range l.filter.condParams.attrValues { + attributes[k] = v + } + attributeValues, err := dynamodbattribute.MarshalMap(attributes) + if err != nil { + return nil, trace.Wrap(err) + } + input := dynamodb.QueryInput{ + KeyConditionExpression: aws.String(query), + TableName: aws.String(l.tableName), + ExpressionAttributeNames: attributeNames, + ExpressionAttributeValues: attributeValues, + IndexName: nil, // Use primary SessionID index. + ExclusiveStartKey: l.checkpoint.Iterator, + Limit: aws.Int64(l.left), + FilterExpression: filterExpr, + ScanIndexForward: aws.Bool(l.forward), + } + start := time.Now() + out, err := l.api.Query(&input) + if err != nil { + return nil, trace.Wrap(err) + } + l.log.WithFields(log.Fields{ + "duration": time.Since(start), + "items": len(out.Items), + "forward": l.forward, + "iterator": l.checkpoint.Iterator, + }).Debugf("Query completed.") + + result, limitReached, err := l.processQueryOutput(out, nil) + if err != nil { + return nil, trace.Wrap(err) + } + values = append(values, result...) + if limitReached { + return values, nil + } + return values, nil +} diff --git a/lib/events/dynamoevents/dynamoevents_test.go b/lib/events/dynamoevents/dynamoevents_test.go index c61db1f3ca4f5..2d803df0e843b 100644 --- a/lib/events/dynamoevents/dynamoevents_test.go +++ b/lib/events/dynamoevents/dynamoevents_test.go @@ -280,7 +280,11 @@ func (s *DynamoeventsLargeTableSuite) TestEmitAuditEventForLargeEvents(c *check. Path: strings.Repeat("A", maxItemSize), } err = s.Log.EmitAuditEvent(ctx, appReqEvent) - c.Check(trace.Unwrap(err), check.FitsTypeOf, errAWSValidation) + c.Assert(err, check.NotNil) +} + +func (s *DynamoeventsSuite) TestSearchSessionEvensBySessionID(c *check.C) { + s.SearchSessionEvensBySessionID(c) } func TestConfig_SetFromURL(t *testing.T) { diff --git a/lib/events/filelog.go b/lib/events/filelog.go index 78b4766c7b2df..f6ebaef77c9e6 100644 --- a/lib/events/filelog.go +++ b/lib/events/filelog.go @@ -362,7 +362,7 @@ func getCheckpointFromEvent(event apievents.AuditEvent) (string, error) { return event.GetID(), nil } -func (l *FileLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr) ([]apievents.AuditEvent, string, error) { +func (l *FileLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) { l.Debugf("SearchSessionEvents(%v, %v, order=%v, limit=%v, cond=%q)", fromUTC, toUTC, order, limit, cond) filter := searchEventsFilter{eventTypes: []string{SessionEndEvent, WindowsDesktopSessionEndEvent}} if cond != nil { diff --git a/lib/events/filelog_test.go b/lib/events/filelog_test.go index 3e3fd5aeb9a3b..5e3915bd82022 100644 --- a/lib/events/filelog_test.go +++ b/lib/events/filelog_test.go @@ -111,12 +111,7 @@ func TestSearchSessionEvents(t *testing.T) { })) clock.Advance(1 * time.Minute) - result, _, err := log.SearchSessionEvents(start, clock.Now(), - 10, // limit - types.EventOrderAscending, - "", // startKey - nil, // cond - ) + result, _, err := log.SearchSessionEvents(start, clock.Now(), 10, types.EventOrderAscending, "", nil, "") require.NoError(t, err) require.Len(t, result, 1) require.Equal(t, result[0].GetType(), SessionEndEvent) @@ -132,12 +127,7 @@ func TestSearchSessionEvents(t *testing.T) { })) clock.Advance(1 * time.Minute) - result, _, err = log.SearchSessionEvents(start, clock.Now(), - 10, // limit - types.EventOrderAscending, - "", // startKey - nil, // cond - ) + result, _, err = log.SearchSessionEvents(start, clock.Now(), 10, types.EventOrderAscending, "", nil, "") require.NoError(t, err) require.Len(t, result, 1) require.Equal(t, result[0].GetType(), SessionEndEvent) @@ -153,12 +143,7 @@ func TestSearchSessionEvents(t *testing.T) { })) clock.Advance(1 * time.Minute) - result, _, err = log.SearchSessionEvents(start, clock.Now(), - 10, // limit - types.EventOrderAscending, - "", // startKey - nil, // cond - ) + result, _, err = log.SearchSessionEvents(start, clock.Now(), 10, types.EventOrderAscending, "", nil, "") require.NoError(t, err) require.Len(t, result, 2) require.Equal(t, result[0].GetType(), SessionEndEvent) diff --git a/lib/events/firestoreevents/firestoreevents.go b/lib/events/firestoreevents/firestoreevents.go index bc88f9430e344..e9e0133ef4161 100644 --- a/lib/events/firestoreevents/firestoreevents.go +++ b/lib/events/firestoreevents/firestoreevents.go @@ -23,30 +23,26 @@ import ( "strconv" "time" - "google.golang.org/genproto/googleapis/firestore/admin/v1" - - "github.com/gravitational/teleport/api/types" - apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/backend" + "cloud.google.com/go/firestore" + apiv1 "cloud.google.com/go/firestore/apiv1/admin" + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "google.golang.org/genproto/googleapis/firestore/admin/v1" "github.com/gravitational/teleport" apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" apiutils "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/lib/backend" firestorebk "github.com/gravitational/teleport/lib/backend/firestore" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" - - "cloud.google.com/go/firestore" - - apiv1 "cloud.google.com/go/firestore/apiv1/admin" - - "github.com/google/uuid" - "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" - log "github.com/sirupsen/logrus" ) var ( @@ -397,17 +393,17 @@ func (l *Log) GetSessionEvents(namespace string, sid session.ID, after int, inlc // // This function may never return more than 1 MiB of event data. func (l *Log) SearchEvents(fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) { - return l.searchEventsWithFilter(fromUTC, toUTC, namespace, limit, order, startKey, searchEventsFilter{eventTypes: eventTypes}) + return l.searchEventsWithFilter(fromUTC, toUTC, namespace, limit, order, startKey, searchEventsFilter{eventTypes: eventTypes}, "") } -func (l *Log) searchEventsWithFilter(fromUTC, toUTC time.Time, namespace string, limit int, order types.EventOrder, startKey string, filter searchEventsFilter) ([]apievents.AuditEvent, string, error) { +func (l *Log) searchEventsWithFilter(fromUTC, toUTC time.Time, namespace string, limit int, order types.EventOrder, startKey string, filter searchEventsFilter, sessionID string) ([]apievents.AuditEvent, string, error) { var eventsArr []apievents.AuditEvent var estimatedSize int checkpoint := startKey left := limit for { - gotEvents, withSize, withCheckpoint, err := l.searchEventsOnce(fromUTC, toUTC, namespace, left, order, checkpoint, filter, events.MaxEventBytesInResponse-estimatedSize) + gotEvents, withSize, withCheckpoint, err := l.searchEventsOnce(fromUTC, toUTC, namespace, left, order, checkpoint, filter, events.MaxEventBytesInResponse-estimatedSize, sessionID) if nil != err { return nil, "", trace.Wrap(err) } @@ -425,7 +421,7 @@ func (l *Log) searchEventsWithFilter(fromUTC, toUTC time.Time, namespace string, return eventsArr, checkpoint, nil } -func (l *Log) searchEventsOnce(fromUTC, toUTC time.Time, namespace string, limit int, order types.EventOrder, startKey string, filter searchEventsFilter, spaceRemaining int) ([]apievents.AuditEvent, int, string, error) { +func (l *Log) searchEventsOnce(fromUTC, toUTC time.Time, namespace string, limit int, order types.EventOrder, startKey string, filter searchEventsFilter, spaceRemaining int, sessionID string) ([]apievents.AuditEvent, int, string, error) { g := l.WithFields(log.Fields{"From": fromUTC, "To": toUTC, "Namespace": namespace, "Filter": filter, "Limit": limit, "StartKey": startKey}) var lastKey int64 @@ -469,6 +465,9 @@ func (l *Log) searchEventsOnce(fromUTC, toUTC time.Time, namespace string, limit if len(filter.eventTypes) > 0 { query = query.Where(eventTypeDocProperty, "in", filter.eventTypes) } + if sessionID != "" { + query = query.Where(sessionIDDocProperty, "==", sessionID) + } start := time.Now() docSnaps, err := query.Documents(l.svcContext).GetAll() @@ -548,7 +547,7 @@ func (l *Log) searchEventsOnce(fromUTC, toUTC time.Time, namespace string, limit // SearchSessionEvents returns session related events only. This is used to // find completed sessions. -func (l *Log) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr) ([]apievents.AuditEvent, string, error) { +func (l *Log) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) { filter := searchEventsFilter{eventTypes: []string{events.SessionEndEvent, events.WindowsDesktopSessionEndEvent}} if cond != nil { condFn, err := utils.ToFieldsCondition(cond) @@ -557,7 +556,7 @@ func (l *Log) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order typ } filter.condition = condFn } - return l.searchEventsWithFilter(fromUTC, toUTC, apidefaults.Namespace, limit, order, startKey, filter) + return l.searchEventsWithFilter(fromUTC, toUTC, apidefaults.Namespace, limit, order, startKey, filter, sessionID) } type searchEventsFilter struct { diff --git a/lib/events/firestoreevents/firestoreevents_test.go b/lib/events/firestoreevents/firestoreevents_test.go index 93992435bc0ab..44557adb42a05 100644 --- a/lib/events/firestoreevents/firestoreevents_test.go +++ b/lib/events/firestoreevents/firestoreevents_test.go @@ -24,6 +24,7 @@ import ( "github.com/jonboulle/clockwork" "gopkg.in/check.v1" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events/test" "github.com/gravitational/teleport/lib/utils" ) @@ -59,6 +60,8 @@ func (s *FirestoreeventsSuite) SetUpSuite(c *check.C) { config.Clock = fakeClock config.UIDGenerator = utils.NewFakeUID() + config.RetryPeriod = defaults.HighResPollingPeriod + config.PurgeExpiredDocumentsPollInterval = time.Second log, err := New(config) @@ -107,3 +110,7 @@ func (s *FirestoreeventsSuite) TestSessionEventsCRUD(c *check.C) { func (s *FirestoreeventsSuite) TestPagination(c *check.C) { s.EventPagination(c) } + +func (s *FirestoreeventsSuite) TestSearchSessionEvensBySessionID(c *check.C) { + s.SearchSessionEvensBySessionID(c) +} diff --git a/lib/events/multilog.go b/lib/events/multilog.go index 40f9178761a42..2acc76ac2c3b4 100644 --- a/lib/events/multilog.go +++ b/lib/events/multilog.go @@ -116,9 +116,9 @@ func (m *MultiLog) SearchEvents(fromUTC, toUTC time.Time, namespace string, even // // Event types to filter can be specified and pagination is handled by an iterator key that allows // a query to be resumed. -func (m *MultiLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr) (events []apievents.AuditEvent, lastKey string, err error) { +func (m *MultiLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) (events []apievents.AuditEvent, lastKey string, err error) { for _, log := range m.loggers { - events, lastKey, err = log.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond) + events, lastKey, err = log.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond, sessionID) if !trace.IsNotImplemented(err) { return events, lastKey, err } diff --git a/lib/events/test/suite.go b/lib/events/test/suite.go index 52ca9bb958dcb..9fe5a16afffb4 100644 --- a/lib/events/test/suite.go +++ b/lib/events/test/suite.go @@ -20,11 +20,17 @@ package test import ( "bytes" "context" + "fmt" "io" "os" "testing" "time" + "github.com/google/uuid" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + "gopkg.in/check.v1" + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" @@ -32,10 +38,6 @@ import ( "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" - - "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/require" - "gopkg.in/check.v1" ) // UploadDownload tests uploads and downloads @@ -234,7 +236,7 @@ func (s *EventsSuite) SessionEventsCRUD(c *check.C) { c.Assert(historyEvents[0].GetString(events.EventType), check.Equals, events.SessionStartEvent) c.Assert(historyEvents[1].GetString(events.EventType), check.Equals, events.SessionEndEvent) - history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(2*time.Hour), 100, types.EventOrderAscending, "", nil) + history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(2*time.Hour), 100, types.EventOrderAscending, "", nil, "") c.Assert(err, check.IsNil) c.Assert(history, check.HasLen, 1) @@ -245,15 +247,56 @@ func (s *EventsSuite) SessionEventsCRUD(c *check.C) { }} } - history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(2*time.Hour), 100, types.EventOrderAscending, "", withParticipant("alice")) + history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(2*time.Hour), 100, types.EventOrderAscending, "", withParticipant("alice"), "") c.Assert(err, check.IsNil) c.Assert(history, check.HasLen, 1) - history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(2*time.Hour), 100, types.EventOrderAscending, "", withParticipant("cecile")) + history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(2*time.Hour), 100, types.EventOrderAscending, "", withParticipant("cecile"), "") c.Assert(err, check.IsNil) c.Assert(history, check.HasLen, 0) - history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(time.Hour-time.Second), 100, types.EventOrderAscending, "", nil) + history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(time.Hour-time.Second), 100, types.EventOrderAscending, "", nil, "") c.Assert(err, check.IsNil) c.Assert(history, check.HasLen, 0) } + +func (s *EventsSuite) SearchSessionEvensBySessionID(c *check.C) { + now := time.Now().UTC() + firstID := uuid.New().String() + secondID := uuid.New().String() + thirdID := uuid.New().String() + for i, id := range []string{firstID, secondID, thirdID} { + event := &apievents.WindowsDesktopSessionEnd{ + Metadata: apievents.Metadata{ + ID: fmt.Sprintf("eventID%d", i), + Type: events.WindowsDesktopSessionEndEvent, + Code: events.DesktopSessionEndCode, + Time: now.Add(time.Duration(i) * time.Second), + }, + SessionMetadata: apievents.SessionMetadata{ + SessionID: id, + }, + } + err := s.Log.EmitAuditEvent(context.Background(), event) + c.Assert(err, check.IsNil) + } + from := time.Time{} + to := now.Add(10 * time.Second) + + done := make(chan struct{}) + go func() { + defer close(done) + events, _, err := s.Log.SearchSessionEvents(from, to, 1000, types.EventOrderDescending, "", nil, secondID) + c.Assert(err, check.IsNil) + c.Assert(events, check.HasLen, 1) + e, ok := events[0].(*apievents.WindowsDesktopSessionEnd) + c.Assert(ok, check.Equals, true) + c.Assert(e.GetSessionID(), check.Equals, secondID) + }() + + select { + case <-time.After(time.Second * 10): + c.Fatalf("Search event query timeout") + case <-done: + } +} diff --git a/lib/events/writer.go b/lib/events/writer.go index d910a07a4e06b..bb32d210d4f63 100644 --- a/lib/events/writer.go +++ b/lib/events/writer.go @@ -90,7 +90,7 @@ func (w *WriterLog) SearchEvents(fromUTC, toUTC time.Time, namespace string, eve // // Event types to filter can be specified and pagination is handled by an iterator key that allows // a query to be resumed. -func (w *WriterLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr) (events []apievents.AuditEvent, lastKey string, err error) { +func (w *WriterLog) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) (events []apievents.AuditEvent, lastKey string, err error) { return nil, "", trace.NotImplemented("not implemented") } diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index f639e328f6021..2545b3e241d01 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -2227,7 +2227,7 @@ func (h *Handler) clusterSearchEvents(w http.ResponseWriter, r *http.Request, p // If no order is provided it defaults to descending. func (h *Handler) clusterSearchSessionEvents(w http.ResponseWriter, r *http.Request, p httprouter.Params, ctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { searchSessionEvents := func(clt auth.ClientI, from, to time.Time, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) { - return clt.SearchSessionEvents(from, to, limit, order, startKey, nil) + return clt.SearchSessionEvents(from, to, limit, order, startKey, nil, "") } return clusterEventsList(ctx, site, r.URL.Query(), searchSessionEvents) }