From ed180079ad63fd2cccd0446b0b08f11f29efb22f Mon Sep 17 00:00:00 2001 From: Jacob Heun Date: Mon, 23 Jan 2023 20:21:14 +0100 Subject: [PATCH] chore: refactor filter logic for deal list --- db/deals.go | 51 +++++++++++++++++++++++++++++++----------------- db/deals_test.go | 50 +++++++++++++++++++++++++++++++++++------------ gql/resolver.go | 19 ++++++------------ 3 files changed, 76 insertions(+), 44 deletions(-) diff --git a/db/deals.go b/db/deals.go index f8c3e5493..cd0f2d8c8 100644 --- a/db/deals.go +++ b/db/deals.go @@ -39,6 +39,13 @@ type dealAccessor struct { def map[string]fielddef.FieldDefinition } +type FilterOptions struct { + Checkpoint *string + IsOffline *bool + TransferType *string + VerifiedDeal *bool +} + func (d *DealsDB) newDealDef(deal *types.ProviderDealState) *dealAccessor { return newDealAccessor(d.db, deal) } @@ -229,7 +236,7 @@ func (d *DealsDB) BySignedProposalCID(ctx context.Context, proposalCid cid.Cid) return d.scanRow(row) } -func (d *DealsDB) Count(ctx context.Context, query string, filter map[string]interface{}) (int, error) { +func (d *DealsDB) Count(ctx context.Context, query string, filter *FilterOptions) (int, error) { whereArgs := []interface{}{} where := "SELECT count(*) FROM Deals" if query != "" { @@ -238,8 +245,8 @@ func (d *DealsDB) Count(ctx context.Context, query string, filter map[string]int whereArgs = append(whereArgs, searchArgs...) } - if len(filter) > 0 { - filterWhere, filterArgs := withSearchFilter(filter) + if filter != nil { + filterWhere, filterArgs := withSearchFilter(*filter) if query != "" { where += " AND " @@ -264,7 +271,7 @@ func (d *DealsDB) ListCompleted(ctx context.Context) ([]*types.ProviderDealState return d.list(ctx, 0, 0, "Checkpoint = ?", dealcheckpoints.Complete.String()) } -func (d *DealsDB) List(ctx context.Context, query string, filter map[string]interface{}, cursor *graphql.ID, offset int, limit int) ([]*types.ProviderDealState, error) { +func (d *DealsDB) List(ctx context.Context, query string, filter *FilterOptions, cursor *graphql.ID, offset int, limit int) ([]*types.ProviderDealState, error) { where := "" whereArgs := []interface{}{} @@ -284,12 +291,12 @@ func (d *DealsDB) List(ctx context.Context, query string, filter map[string]inte whereArgs = append(whereArgs, searchArgs...) } - if len(filter) > 0 { + if filter != nil { if where != "" { where += " AND " } - filterWhere, filterArgs := withSearchFilter(filter) + filterWhere, filterArgs := withSearchFilter(*filter) where += filterWhere whereArgs = append(whereArgs, filterArgs...) } @@ -297,20 +304,28 @@ func (d *DealsDB) List(ctx context.Context, query string, filter map[string]inte return d.list(ctx, offset, limit, where, whereArgs...) } -var filterFields = []string{"Checkpoint", "IsOffline", "TransferType", "VerifiedDeal"} - -func withSearchFilter(filter map[string]interface{}) (string, []interface{}) { +func withSearchFilter(filter FilterOptions) (string, []interface{}) { whereArgs := []interface{}{} - statements := []string{} - for _, filterField := range filterFields { - // If the filterField is in the filter and it's not empty, append - value, ok := filter[filterField] - if ok && value != nil { - st := filterField + " = ?" - statements = append(statements, st) - whereArgs = append(whereArgs, value) - } + + if filter.Checkpoint != nil { + statements = append(statements, "Checkpoint = ?") + whereArgs = append(whereArgs, *filter.Checkpoint) + } + + if filter.IsOffline != nil { + statements = append(statements, "IsOffline = ?") + whereArgs = append(whereArgs, *filter.IsOffline) + } + + if filter.TransferType != nil { + statements = append(statements, "TransferType = ?") + whereArgs = append(whereArgs, *filter.TransferType) + } + + if filter.VerifiedDeal != nil { + statements = append(statements, "VerifiedDeal = ?") + whereArgs = append(whereArgs, *filter.VerifiedDeal) } if len(statements) == 0 { diff --git a/db/deals_test.go b/db/deals_test.go index 3a2569ad1..57594ac0b 100644 --- a/db/deals_test.go +++ b/db/deals_test.go @@ -13,6 +13,29 @@ import ( "github.com/stretchr/testify/require" ) +func ToFilterOptions(filters map[string]interface{}) *FilterOptions { + filter := &FilterOptions{} + + cp, ok := filters["Checkpoint"].(string) + if ok { + filter.Checkpoint = &cp + } + io, ok := filters["IsOffline"].(bool) + if ok { + filter.IsOffline = &io + } + tt, ok := filters["TransferType"].(string) + if ok { + filter.TransferType = &tt + } + vd, ok := filters["VerifiedDeal"].(bool) + if ok { + filter.VerifiedDeal = &vd + } + + return filter +} + func TestDealsDB(t *testing.T) { req := require.New(t) ctx := context.Background() @@ -123,7 +146,7 @@ func TestDealsDBSearch(t *testing.T) { tcs := []struct { name string value string - filter map[string]interface{} + filter *FilterOptions count int }{{ name: "search error", @@ -183,24 +206,24 @@ func TestDealsDBSearch(t *testing.T) { }, { name: "filter out isOffline", value: "", - filter: map[string]interface{}{ + filter: ToFilterOptions(map[string]interface{}{ "IsOffline": false, - }, + }), count: 0, }, { name: "filter isOffline", value: "", - filter: map[string]interface{}{ + filter: ToFilterOptions(map[string]interface{}{ "IsOffline": true, - }, + }), count: 5, }, { name: "filter isOffline and IndexedAndAnnounced (in sealing)", value: "", - filter: map[string]interface{}{ + filter: ToFilterOptions(map[string]interface{}{ "IsOffline": true, - "Checkpoint": dealcheckpoints.IndexedAndAnnounced, - }, + "Checkpoint": dealcheckpoints.IndexedAndAnnounced.String(), + }), count: 0, }} for _, tc := range tcs { @@ -224,13 +247,13 @@ func TestDealsDBSearch(t *testing.T) { func TestWithSearchFilter(t *testing.T) { req := require.New(t) - filter := map[string]interface{}{ + + fo := ToFilterOptions(map[string]interface{}{ "Checkpoint": "Accepted", "IsOffline": true, "NotAValidFilter": 123, - } - - where, whereArgs := withSearchFilter(filter) + }) + where, whereArgs := withSearchFilter(*fo) expectedArgs := []interface{}{ "Accepted", true, @@ -238,10 +261,11 @@ func TestWithSearchFilter(t *testing.T) { req.Equal("(Checkpoint = ? AND IsOffline = ?)", where) req.Equal(expectedArgs, whereArgs) - where, whereArgs = withSearchFilter(map[string]interface{}{ + fo = ToFilterOptions(map[string]interface{}{ "IsOffline": nil, "NotAValidFilter": nil, }) + where, whereArgs = withSearchFilter(*fo) req.Equal("", where) req.Equal(0, len(whereArgs)) diff --git a/gql/resolver.go b/gql/resolver.go index cea761aed..f09c9540f 100644 --- a/gql/resolver.go +++ b/gql/resolver.go @@ -143,18 +143,11 @@ func (r *resolver) Deals(ctx context.Context, args dealsArgs) (*dealListResolver query = *args.Query.Value } - filter := map[string]interface{}{} - if args.Filter.Checkpoint.String != "" { - filter["Checkpoint"] = args.Filter.Checkpoint.String - } - if args.Filter.IsOffline.Set && args.Filter.IsOffline.Value != nil { - filter["IsOffline"] = args.Filter.IsOffline.Value - } - if args.Filter.TransferType.Set && args.Filter.TransferType.Value != nil { - filter["TransferType"] = args.Filter.TransferType.Value - } - if args.Filter.VerifiedDeal.Set && args.Filter.VerifiedDeal.Value != nil { - filter["VerifiedDeal"] = args.Filter.VerifiedDeal.Value + filter := &db.FilterOptions{ + Checkpoint: &args.Filter.Checkpoint.String, + IsOffline: args.Filter.IsOffline.Value, + TransferType: args.Filter.TransferType.Value, + VerifiedDeal: args.Filter.VerifiedDeal.Value, } deals, count, more, err := r.dealList(ctx, query, filter, args.Cursor, offset, limit) @@ -329,7 +322,7 @@ func (r *resolver) dealsByPublishCID(ctx context.Context, publishCid cid.Cid) ([ return deals, nil } -func (r *resolver) dealList(ctx context.Context, query string, filter map[string]interface{}, cursor *graphql.ID, offset int, limit int) ([]types.ProviderDealState, int, bool, error) { +func (r *resolver) dealList(ctx context.Context, query string, filter *db.FilterOptions, cursor *graphql.ID, offset int, limit int) ([]types.ProviderDealState, int, bool, error) { // Fetch one extra deal so that we can check if there are more deals // beyond the limit deals, err := r.dealsDB.List(ctx, query, filter, cursor, offset, limit+1)