diff --git a/lib/backend/firestore/firestorebk.go b/lib/backend/firestore/firestorebk.go index a051544f64870..3c6ee5fae0430 100644 --- a/lib/backend/firestore/firestorebk.go +++ b/lib/backend/firestore/firestorebk.go @@ -23,6 +23,7 @@ import ( "cloud.google.com/go/firestore" apiv1 "cloud.google.com/go/firestore/apiv1/admin" + "github.com/gravitational/trace/trail" "google.golang.org/api/option" adminpb "google.golang.org/genproto/googleapis/firestore/admin/v1" "google.golang.org/grpc" @@ -201,6 +202,8 @@ const ( idDocProperty = "id" // timeInBetweenIndexCreationStatusChecks timeInBetweenIndexCreationStatusChecks = time.Second * 10 + // commitLimit is the maximum number of writes per commit + commitLimit = 500 ) // GetName is a part of backend API and it returns Firestore backend type @@ -411,23 +414,12 @@ func (b *Backend) GetRange(ctx context.Context, startKey []byte, endKey []byte, // DeleteRange deletes range of items with keys between startKey and endKey func (b *Backend) DeleteRange(ctx context.Context, startKey, endKey []byte) error { - docSnaps, err := b.getRangeDocs(ctx, startKey, endKey, backend.DefaultRangeLimit) + docs, err := b.getRangeDocs(ctx, startKey, endKey, backend.DefaultRangeLimit) if err != nil { return trace.Wrap(err) } - if len(docSnaps) == 0 { - // Nothing to delete. - return nil - } - batch := b.svc.Batch() - for _, docSnap := range docSnaps { - batch.Delete(docSnap.Ref) - } - _, err = batch.Commit(ctx) - if err != nil { - return ConvertGRPCError(err) - } - return nil + + return trace.Wrap(b.deleteDocuments(docs)) } // Get returns a single item or not found error @@ -706,23 +698,37 @@ func (b *Backend) purgeExpiredDocuments() error { return b.clientContext.Err() case <-t.C: expiryTime := b.clock.Now().UTC().Unix() - numDeleted := 0 - batch := b.svc.Batch() - docs, _ := b.svc.Collection(b.CollectionName).Where(expiresDocProperty, "<=", expiryTime).Documents(b.clientContext).GetAll() - for _, doc := range docs { - batch.Delete(doc.Ref) - numDeleted++ + docs, err := b.svc.Collection(b.CollectionName).Where(expiresDocProperty, "<=", expiryTime).Documents(b.clientContext).GetAll() + if err != nil { + b.Logger.WithError(trail.FromGRPC(err)).Warn("Failed to get expired documents") + continue } - if numDeleted > 0 { - _, err := batch.Commit(b.clientContext) - if err != nil { - return ConvertGRPCError(err) - } + + if err := b.deleteDocuments(docs); err != nil { + return trace.Wrap(err) } } } } +// deleteDocuments removes documents from firestore in batches to stay within the +// firestore write limits +func (b *Backend) deleteDocuments(docs []*firestore.DocumentSnapshot) error { + for i := 0; i < len(docs); i += commitLimit { + batch := b.svc.Batch() + + for j := 0; j < commitLimit && i+j < len(docs); j++ { + batch.Delete(docs[i+j].Ref) + } + + if _, err := batch.Commit(b.clientContext); err != nil { + return ConvertGRPCError(err) + } + } + + return nil +} + // ConvertGRPCError converts GRPC errors func ConvertGRPCError(err error, args ...interface{}) error { if err == nil { diff --git a/lib/backend/firestore/firestorebk_test.go b/lib/backend/firestore/firestorebk_test.go index 27adec9616838..e3e6e2b113b8a 100644 --- a/lib/backend/firestore/firestorebk_test.go +++ b/lib/backend/firestore/firestorebk_test.go @@ -16,20 +16,32 @@ package firestore import ( "context" + "errors" + "fmt" "net" "os" + "strings" "testing" "time" + "cloud.google.com/go/firestore" "github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/backend/test" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" - "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" + "google.golang.org/api/option" adminpb "google.golang.org/genproto/googleapis/firestore/admin/v1" + firestorepb "google.golang.org/genproto/googleapis/firestore/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestMain(m *testing.M) { @@ -76,7 +88,7 @@ func ensureEmulatorRunning(t *testing.T, cfg map[string]interface{}) { if err != nil { t.Skip("Firestore emulator is not running, start it with: gcloud beta emulators firestore start --host-port=localhost:8618") } - con.Close() + require.NoError(t, con.Close()) } func TestFirestoreDB(t *testing.T) { @@ -118,7 +130,7 @@ func newBackend(t *testing.T, cfg map[string]interface{}) *Backend { uut, err := New(context.Background(), cfg, Options{Clock: clock}) require.NoError(t, err) - t.Cleanup(func() { uut.Close() }) + t.Cleanup(func() { require.NoError(t, uut.Close()) }) return uut } @@ -169,3 +181,138 @@ func TestReadLegacyRecord(t *testing.T) { require.Equal(t, item.ID, got.ID) require.Equal(t, item.Expires, got.Expires) } + +type mockFirestoreServer struct { + // Embed for forward compatibility. + // Tests will keep working if more methods are added + // in the future. + firestorepb.FirestoreServer + + reqs []proto.Message + + // If set, Commit returns this error. + commitErr error +} + +func (s *mockFirestoreServer) Commit(ctx context.Context, req *firestorepb.CommitRequest) (*firestorepb.CommitResponse, error) { + md, _ := metadata.FromIncomingContext(ctx) + if xg := md["x-goog-api-client"]; len(xg) == 0 || !strings.Contains(xg[0], "gl-go/") { + return nil, fmt.Errorf("x-goog-api-client = %v, expected gl-go key", xg) + } + + if len(req.Writes) > commitLimit { + return nil, status.Errorf(codes.InvalidArgument, "too many writes in a transaction") + } + + s.reqs = append(s.reqs, req) + if s.commitErr != nil { + return nil, s.commitErr + } + return &firestorepb.CommitResponse{ + WriteResults: []*firestorepb.WriteResult{{ + UpdateTime: timestamppb.Now(), + }}, + }, nil +} + +func TestDeleteDocuments(t *testing.T) { + t.Parallel() + cases := []struct { + name string + assertion require.ErrorAssertionFunc + responseErr error + commitErr error + documents int + }{ + { + name: "failed to commit", + assertion: require.Error, + commitErr: errors.New("failed to commit documents"), + documents: 1, + }, + { + name: "commit less than limit", + assertion: require.NoError, + documents: commitLimit - 123, + }, + { + name: "commit limit", + assertion: require.NoError, + documents: commitLimit, + }, + { + name: "commit more than limit", + assertion: require.NoError, + documents: (commitLimit * 3) + 173, + }, + } + + for _, tt := range cases { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + docs := make([]*firestore.DocumentSnapshot, 0, tt.documents) + for i := 0; i < tt.documents; i++ { + docs = append(docs, &firestore.DocumentSnapshot{ + Ref: &firestore.DocumentRef{ + Path: fmt.Sprintf("projects/test-project/databases/test-db/documents/test/%d", i+1), + }, + CreateTime: time.Now(), + UpdateTime: time.Now(), + }) + } + + mockFirestore := &mockFirestoreServer{ + commitErr: tt.commitErr, + } + srv := grpc.NewServer() + firestorepb.RegisterFirestoreServer(srv, mockFirestore) + + lis, err := net.Listen("tcp", "localhost:0") + require.NoError(t, err) + go func() { require.NoError(t, srv.Serve(lis)) }() + t.Cleanup(srv.Stop) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + + client, err := firestore.NewClient(ctx, "test-project", option.WithGRPCConn(conn)) + require.NoError(t, err) + + b := &Backend{ + svc: client, + Entry: utils.NewLoggerForTests().WithFields(logrus.Fields{trace.Component: BackendName}), + clock: clockwork.NewFakeClock(), + clientContext: ctx, + clientCancel: cancel, + backendConfig: backendConfig{ + Config: Config{ + CollectionName: "test-collection", + }, + }, + } + + tt.assertion(t, b.deleteDocuments(docs)) + + if tt.documents == 0 { + return + } + + var committed int + for _, req := range mockFirestore.reqs { + switch r := req.(type) { + case *firestorepb.CommitRequest: + committed += len(r.Writes) + } + } + + require.Equal(t, tt.documents, committed) + + }) + } + +}