From c817481f68782414c031de37d9a50b5f20cb77d7 Mon Sep 17 00:00:00 2001 From: yk-eukarya <81808708+yk-eukarya@users.noreply.github.com> Date: Fri, 22 Nov 2024 10:22:11 +0100 Subject: [PATCH] chore(server): item import performance enhancement (#1319) --- server/internal/infrastructure/memory/item.go | 36 +++++--- .../infrastructure/memory/memorygit/memory.go | 12 +++ .../internal/infrastructure/memory/thread.go | 14 ++++ server/internal/infrastructure/mongo/item.go | 15 ++++ .../infrastructure/mongo/mongodoc/thread.go | 14 ++++ .../mongo/mongogit/collection.go | 21 +++++ .../internal/infrastructure/mongo/thread.go | 14 ++++ server/internal/usecase/interactor/item.go | 84 ++++++++++++------- server/internal/usecase/repo/Item.go | 1 + server/internal/usecase/repo/thread.go | 1 + server/pkg/item/id.go | 1 + server/pkg/item/list.go | 21 +++-- server/pkg/thread/list.go | 9 ++ 13 files changed, 196 insertions(+), 47 deletions(-) diff --git a/server/internal/infrastructure/memory/item.go b/server/internal/infrastructure/memory/item.go index ece73d183b..2639e7658f 100644 --- a/server/internal/infrastructure/memory/item.go +++ b/server/internal/infrastructure/memory/item.go @@ -24,7 +24,13 @@ type Item struct { err error } -func (r *Item) FindByAssets(ctx context.Context, list id.AssetIDList, ref *version.Ref) (item.VersionedList, error) { +func NewItem() repo.Item { + return &Item{ + data: memorygit.NewVersionedSyncMap[item.ID, *item.Item](), + } +} + +func (r *Item) FindByAssets(_ context.Context, list id.AssetIDList, ref *version.Ref) (item.VersionedList, error) { if r.err != nil { return nil, r.err } @@ -41,12 +47,6 @@ func (r *Item) FindByAssets(ctx context.Context, list id.AssetIDList, ref *versi return res, nil } -func NewItem() repo.Item { - return &Item{ - data: memorygit.NewVersionedSyncMap[item.ID, *item.Item](), - } -} - func (r *Item) Filtered(filter repo.ProjectFilter) repo.Item { return &Item{ data: r.data, @@ -109,7 +109,7 @@ func (r *Item) FindByIDs(_ context.Context, list id.ItemIDList, ref *version.Ref return r.data.LoadAll(list, lo.ToPtr(ref.OrLatest().OrVersion())), nil } -func (r *Item) FindVersionByID(ctx context.Context, itemID id.ItemID, ver version.VersionOrRef) (item.Versioned, error) { +func (r *Item) FindVersionByID(_ context.Context, itemID id.ItemID, ver version.VersionOrRef) (item.Versioned, error) { if r.err != nil { return nil, r.err } @@ -133,7 +133,7 @@ func (r *Item) FindAllVersionsByID(_ context.Context, id id.ItemID) (item.Versio }), nil } -func (r *Item) FindAllVersionsByIDs(ctx context.Context, ids id.ItemIDList) (item.VersionedList, error) { +func (r *Item) FindAllVersionsByIDs(_ context.Context, ids id.ItemIDList) (item.VersionedList, error) { if r.err != nil { return nil, r.err } @@ -145,7 +145,7 @@ func (r *Item) FindAllVersionsByIDs(ctx context.Context, ids id.ItemIDList) (ite }), nil } -func (r *Item) LastModifiedByModel(ctx context.Context, modelID id.ModelID) (time.Time, error) { +func (r *Item) LastModifiedByModel(_ context.Context, modelID id.ModelID) (time.Time, error) { if r.err != nil { return time.Time{}, r.err } @@ -175,7 +175,21 @@ func (r *Item) Save(_ context.Context, t *item.Item) error { return nil } -func (r *Item) UpdateRef(ctx context.Context, item id.ItemID, ref version.Ref, vr *version.VersionOrRef) error { +func (r *Item) SaveAll(_ context.Context, il item.List) error { + if r.err != nil { + return r.err + } + + for _, t := range il { + if !r.f.CanWrite(t.Project()) { + return repo.ErrOperationDenied + } + } + r.data.SaveAll(il.IDs(), il, nil) + return nil +} + +func (r *Item) UpdateRef(_ context.Context, item id.ItemID, ref version.Ref, vr *version.VersionOrRef) error { if r.err != nil { return r.err } diff --git a/server/internal/infrastructure/memory/memorygit/memory.go b/server/internal/infrastructure/memory/memorygit/memory.go index 94140e9e94..7fd172d12c 100644 --- a/server/internal/infrastructure/memory/memorygit/memory.go +++ b/server/internal/infrastructure/memory/memorygit/memory.go @@ -75,6 +75,18 @@ func (m *VersionedSyncMap[K, V]) SaveOne(key K, value V, parent *version.Version } } +func (m *VersionedSyncMap[K, V]) SaveAll(key []K, value []V, parent []*version.VersionOrRef) { + if len(key) != len(value) || (parent != nil && len(key) != len(parent)) { + return + } + if len(key) == 0 { + return + } + for i := 0; i < len(key); i++ { + m.SaveOne(key[i], value[i], parent[i]) + } +} + func (m *VersionedSyncMap[K, V]) UpdateRef(key K, ref version.Ref, vr *version.VersionOrRef) { m.Range(func(k K, v *version.Values[V]) bool { if k == key { diff --git a/server/internal/infrastructure/memory/thread.go b/server/internal/infrastructure/memory/thread.go index 7926788614..ed70b72d44 100644 --- a/server/internal/infrastructure/memory/thread.go +++ b/server/internal/infrastructure/memory/thread.go @@ -35,6 +35,20 @@ func (r *Thread) Save(_ context.Context, th *thread.Thread) error { return nil } +func (r *Thread) SaveAll(_ context.Context, th thread.List) error { + if r.err != nil { + return r.err + } + + for _, t := range th { + if !r.f.CanWrite(t.Workspace()) { + return repo.ErrOperationDenied + } + } + r.data.StoreAll(th.ToMap()) + return nil +} + func (r *Thread) Filtered(f repo.WorkspaceFilter) repo.Thread { return &Thread{ data: r.data, diff --git a/server/internal/infrastructure/mongo/item.go b/server/internal/infrastructure/mongo/item.go index be799d3dc7..4a952113c8 100644 --- a/server/internal/infrastructure/mongo/item.go +++ b/server/internal/infrastructure/mongo/item.go @@ -13,6 +13,7 @@ import ( "github.com/reearth/reearthx/mongox" "github.com/reearth/reearthx/rerror" "github.com/reearth/reearthx/usecasex" + "github.com/samber/lo" "go.mongodb.org/mongo-driver/bson" ) @@ -228,6 +229,20 @@ func (r *Item) Save(ctx context.Context, item *item.Item) error { return r.client.SaveOne(ctx, id, doc, nil) } +func (r *Item) SaveAll(ctx context.Context, items item.List) error { + if len(items) == 0 { + return nil + } + + for _, itm := range items { + if !r.f.CanWrite(itm.Project()) { + return repo.ErrOperationDenied + } + } + docs, ids := mongodoc.NewItems(items) + return r.client.SaveAll(ctx, ids, lo.ToAnySlice(docs), nil) +} + func (r *Item) UpdateRef(ctx context.Context, item id.ItemID, ref version.Ref, vr *version.VersionOrRef) error { return r.client.UpdateRef(ctx, item.String(), ref, vr) } diff --git a/server/internal/infrastructure/mongo/mongodoc/thread.go b/server/internal/infrastructure/mongo/mongodoc/thread.go index 72c2388b91..afcdb8cf6c 100644 --- a/server/internal/infrastructure/mongo/mongodoc/thread.go +++ b/server/internal/infrastructure/mongo/mongodoc/thread.go @@ -40,6 +40,20 @@ func NewThread(a *thread.Thread) (*ThreadDocument, string) { return thd, id } +func NewThreads(a thread.List) ([]ThreadDocument, []string) { + res := make([]ThreadDocument, 0, len(a)) + ids := make([]string, 0, len(a)) + for _, th := range a { + if th == nil { + continue + } + thDoc, thId := NewThread(th) + res = append(res, *thDoc) + ids = append(ids, thId) + } + return res, ids +} + func (d *ThreadDocument) Model() (*thread.Thread, error) { thid, err := id.ThreadIDFrom(d.ID) if err != nil { diff --git a/server/internal/infrastructure/mongo/mongogit/collection.go b/server/internal/infrastructure/mongo/mongogit/collection.go index 39e9e9f1fc..9edb602ce0 100644 --- a/server/internal/infrastructure/mongo/mongogit/collection.go +++ b/server/internal/infrastructure/mongo/mongogit/collection.go @@ -105,6 +105,27 @@ func (c *Collection) SaveOne(ctx context.Context, id string, d any, parent *vers return nil } +func (c *Collection) SaveAll(ctx context.Context, ids []string, docs []any, parents []*version.VersionOrRef) error { + // TODO: optimize to use bulk write + if len(ids) != len(docs) || (parents != nil && len(ids) != len(parents)) { + return rerror.ErrInvalidParams + } + if len(ids) == 0 { + return nil + } + for i := 0; i < len(ids); i++ { + var parent *version.VersionOrRef = nil + if parents != nil { + parent = parents[i] + } + err := c.SaveOne(ctx, ids[i], docs[i], parent) + if err != nil { + return err + } + } + return nil +} + func (c *Collection) UpdateRef(ctx context.Context, id string, ref version.Ref, dest *version.VersionOrRef) error { if _, err := c.client.Client().UpdateMany(ctx, bson.M{ "id": id, diff --git a/server/internal/infrastructure/mongo/thread.go b/server/internal/infrastructure/mongo/thread.go index 4b765aa2ec..ba9ad55229 100644 --- a/server/internal/infrastructure/mongo/thread.go +++ b/server/internal/infrastructure/mongo/thread.go @@ -9,6 +9,7 @@ import ( "github.com/reearth/reearth-cms/server/pkg/thread" "github.com/reearth/reearthx/mongox" "github.com/reearth/reearthx/rerror" + "github.com/samber/lo" "go.mongodb.org/mongo-driver/bson" ) @@ -39,6 +40,19 @@ func (r *ThreadRepo) Save(ctx context.Context, thread *thread.Thread) error { return r.client.SaveOne(ctx, id, doc) } +func (r *ThreadRepo) SaveAll(ctx context.Context, threads thread.List) error { + if len(threads) == 0 { + return nil + } + for _, t := range threads { + if !r.f.CanWrite(t.Workspace()) { + return repo.ErrOperationDenied + } + } + docs, ids := mongodoc.NewThreads(threads) + return r.client.SaveAll(ctx, ids, lo.ToAnySlice(docs)) +} + func (r *ThreadRepo) Filtered(f repo.WorkspaceFilter) repo.Thread { return &ThreadRepo{ client: r.client, diff --git a/server/internal/usecase/interactor/item.go b/server/internal/usecase/interactor/item.go index 204970ecd4..a1f7f1a992 100644 --- a/server/internal/usecase/interactor/item.go +++ b/server/internal/usecase/interactor/item.go @@ -366,16 +366,41 @@ func (i Item) Import(ctx context.Context, param interfaces.ImportItemsParam, ope } res := NewImportRes() + prj, err := i.repos.Project.FindByID(ctx, s.Project()) + if err != nil { + return interfaces.ImportItemsResponse{}, err + } + m, err := i.repos.Model.FindByID(ctx, param.ModelID) if err != nil { return interfaces.ImportItemsResponse{}, err } + itemsIds := lo.FilterMap(param.Items, func(i interfaces.ImportItemParam, _ int) (item.ID, bool) { + if i.ItemId != nil { + return *i.ItemId, true + } + return item.ID{}, false + }) + oldItems, err := i.repos.Item.FindByIDs(ctx, itemsIds, nil) + if err != nil { + return interfaces.ImportItemsResponse{}, err + } + isMetadata := false if m.Metadata() != nil && s.ID() == *m.Metadata() { isMetadata = true } + threadsToSave := thread.List{} + itemsToSave := item.List{} + + type itemChanges struct { + oldFields item.Fields + action interfaces.ImportStrategyType + } + itemsEvent := map[item.ID]itemChanges{} + // update schema if needed if param.MutateSchema && len(param.Fields) > 0 { for _, fieldParam := range param.Fields { @@ -410,10 +435,7 @@ func (i Item) Import(ctx context.Context, param interfaces.ImportItemsParam, ope var oldItem *item.Item if itemParam.ItemId != nil { - itm, err := i.repos.Item.FindByID(ctx, *itemParam.ItemId, nil) - if err != nil && !errors.Is(err, rerror.ErrNotFound) { - return interfaces.ImportItemsResponse{}, err - } + itm := oldItems.Item(*itemParam.ItemId) oldItem = itm.Value() } @@ -454,9 +476,7 @@ func (i Item) Import(ctx context.Context, param interfaces.ImportItemsParam, ope if err != nil { return interfaces.ImportItemsResponse{}, err } - if err := i.repos.Thread.Save(ctx, th); err != nil { - return interfaces.ImportItemsResponse{}, err - } + threadsToSave = append(threadsToSave, th) ib := item.New(). NewID(). @@ -508,9 +528,7 @@ func (i Item) Import(ctx context.Context, param interfaces.ImportItemsParam, ope return interfaces.ImportItemsResponse{}, interfaces.ErrMetadataMismatch } mi.Value().SetOriginalItem(it.ID()) - if err := i.repos.Item.Save(ctx, mi.Value()); err != nil { - return interfaces.ImportItemsResponse{}, err - } + itemsToSave = append(itemsToSave, mi.Value()) } modelSchemaFields, otherFields := filterFieldParamsBySchema(itemParam.Fields, s) @@ -527,7 +545,7 @@ func (i Item) Import(ctx context.Context, param interfaces.ImportItemsParam, ope oldFields := it.Fields() it.UpdateFields(fields) - groupFields, groupSchemas, err := i.handleGroupFields(ctx, otherFields, s, m.ID(), it.Fields()) + groupFields, _, err := i.handleGroupFields(ctx, otherFields, s, m.ID(), it.Fields()) if err != nil { return interfaces.ImportItemsResponse{}, err } @@ -538,33 +556,42 @@ func (i Item) Import(ctx context.Context, param interfaces.ImportItemsParam, ope return interfaces.ImportItemsResponse{}, err } - if err := i.repos.Item.Save(ctx, it); err != nil { - return interfaces.ImportItemsResponse{}, err - } + itemsToSave = append(itemsToSave, it) if isMetadata { continue } - - vi, err := i.repos.Item.FindByID(ctx, it.ID(), nil) - if err != nil { - return interfaces.ImportItemsResponse{}, err + itemsEvent[it.ID()] = itemChanges{ + oldFields: oldFields, + action: action, } + } - refItems, err := i.getReferencedItems(ctx, it.Fields()) - if err != nil { - return interfaces.ImportItemsResponse{}, err - } + if err := i.repos.Thread.SaveAll(ctx, threadsToSave); err != nil { + return interfaces.ImportItemsResponse{}, err + } - prj, err := i.repos.Project.FindByID(ctx, s.Project()) + if err := i.repos.Item.SaveAll(ctx, itemsToSave); err != nil { + return interfaces.ImportItemsResponse{}, err + } + + // TODO: create ItemsImported event + items, err := i.repos.Item.FindByIDs(ctx, lo.Keys(itemsEvent), nil) + if err != nil { + return interfaces.ImportItemsResponse{}, err + } + + for k, changes := range itemsEvent { + vi := items.Item(k) + it := vi.Value() + + refItems, err := i.getReferencedItems(ctx, it.Fields()) if err != nil { return interfaces.ImportItemsResponse{}, err } - // TODO: check if event creation is transactional - // A: in future create ItemsImported event var eType event.Type - if action == interfaces.ImportStrategyTypeInsert { + if changes.action == interfaces.ImportStrategyTypeInsert { eType = event.ItemCreate res.ItemInserted() } else { @@ -580,9 +607,9 @@ func (i Item) Import(ctx context.Context, param interfaces.ImportItemsParam, ope Item: vi.Value(), Model: m, Schema: s, - GroupSchemas: groupSchemas, + GroupSchemas: param.SP.GroupSchemas(), ReferencedItems: refItems, - Changes: item.CompareFields(it.Fields(), oldFields), + Changes: item.CompareFields(it.Fields(), changes.oldFields), }, Operator: operator.Operator(), }); err != nil { @@ -1021,6 +1048,7 @@ func (i Item) getItemCorrespondingItems(ctx context.Context, fr schema.FieldRefe } func (i Item) handleGroupFields(ctx context.Context, params []interfaces.ItemFieldParam, s *schema.Schema, mId id.ModelID, itemFields item.Fields) (item.Fields, schema.List, error) { + // TODO: use schema package to enhance performance var res item.Fields var groupSchemas schema.List for _, field := range itemFields.FieldsByType(value.TypeGroup) { diff --git a/server/internal/usecase/repo/Item.go b/server/internal/usecase/repo/Item.go index 5f73544691..f99726c4b3 100644 --- a/server/internal/usecase/repo/Item.go +++ b/server/internal/usecase/repo/Item.go @@ -32,6 +32,7 @@ type Item interface { FindByModelAndValue(context.Context, id.ModelID, []FieldAndValue, *version.Ref) (item.VersionedList, error) IsArchived(context.Context, id.ItemID) (bool, error) Save(context.Context, *item.Item) error + SaveAll(context.Context, item.List) error UpdateRef(context.Context, id.ItemID, version.Ref, *version.VersionOrRef) error Remove(context.Context, id.ItemID) error Archive(context.Context, id.ItemID, id.ProjectID, bool) error diff --git a/server/internal/usecase/repo/thread.go b/server/internal/usecase/repo/thread.go index 1b5b5bfc9e..41be400364 100644 --- a/server/internal/usecase/repo/thread.go +++ b/server/internal/usecase/repo/thread.go @@ -15,6 +15,7 @@ var ( type Thread interface { Save(context.Context, *thread.Thread) error + SaveAll(context.Context, thread.List) error Filtered(filter WorkspaceFilter) Thread FindByID(ctx context.Context, id id.ThreadID) (*thread.Thread, error) FindByIDs(context.Context, id.ThreadIDList) ([]*thread.Thread, error) diff --git a/server/pkg/item/id.go b/server/pkg/item/id.go index 747159b984..0155f4abc0 100644 --- a/server/pkg/item/id.go +++ b/server/pkg/item/id.go @@ -6,6 +6,7 @@ import ( ) type ID = id.ItemID +type IDList = id.ItemIDList type ProjectID = id.ProjectID type SchemaID = id.SchemaID type FieldID = id.FieldID diff --git a/server/pkg/item/list.go b/server/pkg/item/list.go index 4ade23db1a..a1e2c81143 100644 --- a/server/pkg/item/list.go +++ b/server/pkg/item/list.go @@ -1,34 +1,39 @@ package item import ( - "github.com/reearth/reearth-cms/server/pkg/id" "github.com/reearth/reearth-cms/server/pkg/version" "github.com/samber/lo" ) type List []*Item -func (l List) ItemsByField(fid id.FieldID, value any) List { +func (l List) ItemsByField(fid FieldID, value any) List { return lo.Filter(l, func(i *Item, _ int) bool { return i.HasField(fid, value) }) } -func (l List) FilterFields(lids id.FieldIDList) List { +func (l List) FilterFields(lids FieldIDList) List { return lo.Map(l, func(i *Item, _ int) *Item { return i.FilterFields(lids) }) } -func (l List) Item(iID id.ItemID) (*Item, bool) { +func (l List) Item(iID ID) (*Item, bool) { return lo.Find(l, func(i *Item) bool { return i.ID() == iID }) } +func (l List) IDs() IDList { + return lo.Map(l, func(i *Item, _ int) ID { + return i.ID() + }) +} + type VersionedList []Versioned -func (l VersionedList) FilterFields(fields id.FieldIDList) VersionedList { +func (l VersionedList) FilterFields(fields FieldIDList) VersionedList { return lo.Map(l, func(a Versioned, _ int) Versioned { return version.ValueFrom(a, a.Value().FilterFields(fields)) }) @@ -41,7 +46,7 @@ func (l VersionedList) Unwrap() List { return version.UnwrapValues(l) } -func (l VersionedList) Item(iid id.ItemID) Versioned { +func (l VersionedList) Item(iid ID) Versioned { if l == nil { return nil } @@ -53,8 +58,8 @@ func (l VersionedList) Item(iid id.ItemID) Versioned { return nil } -func (l VersionedList) ToMap() map[id.ItemID]*version.Value[*Item] { - m := make(map[id.ItemID]*version.Value[*Item], len(l)) +func (l VersionedList) ToMap() map[ID]*version.Value[*Item] { + m := make(map[ID]*version.Value[*Item], len(l)) for _, i := range l { m[i.Value().ID()] = i } diff --git a/server/pkg/thread/list.go b/server/pkg/thread/list.go index 7c8d6fe8a8..05faf33a2d 100644 --- a/server/pkg/thread/list.go +++ b/server/pkg/thread/list.go @@ -2,6 +2,7 @@ package thread import ( "github.com/reearth/reearthx/util" + "github.com/samber/lo" "golang.org/x/exp/slices" ) @@ -18,3 +19,11 @@ func (l List) SortByID() List { func (l List) Clone() List { return util.Map(l, func(th *Thread) *Thread { return th.Clone() }) } + +func (l List) IDs() []ID { + return util.Map(l, func(th *Thread) ID { return th.ID() }) +} + +func (l List) ToMap() map[ID]*Thread { + return lo.SliceToMap(l, func(th *Thread) (ID, *Thread) { return th.ID(), th }) +}