Skip to content

Commit

Permalink
chore: refactor filter logic for deal list
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobheun committed Jan 23, 2023
1 parent 674caa9 commit ed18007
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 44 deletions.
51 changes: 33 additions & 18 deletions db/deals.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ type dealAccessor struct {
def map[string]fielddef.FieldDefinition
}

type FilterOptions struct {
Checkpoint *string
IsOffline *bool
TransferType *string
VerifiedDeal *bool
}

func (d *DealsDB) newDealDef(deal *types.ProviderDealState) *dealAccessor {
return newDealAccessor(d.db, deal)
}
Expand Down Expand Up @@ -229,7 +236,7 @@ func (d *DealsDB) BySignedProposalCID(ctx context.Context, proposalCid cid.Cid)
return d.scanRow(row)
}

func (d *DealsDB) Count(ctx context.Context, query string, filter map[string]interface{}) (int, error) {
func (d *DealsDB) Count(ctx context.Context, query string, filter *FilterOptions) (int, error) {
whereArgs := []interface{}{}
where := "SELECT count(*) FROM Deals"
if query != "" {
Expand All @@ -238,8 +245,8 @@ func (d *DealsDB) Count(ctx context.Context, query string, filter map[string]int
whereArgs = append(whereArgs, searchArgs...)
}

if len(filter) > 0 {
filterWhere, filterArgs := withSearchFilter(filter)
if filter != nil {
filterWhere, filterArgs := withSearchFilter(*filter)

if query != "" {
where += " AND "
Expand All @@ -264,7 +271,7 @@ func (d *DealsDB) ListCompleted(ctx context.Context) ([]*types.ProviderDealState
return d.list(ctx, 0, 0, "Checkpoint = ?", dealcheckpoints.Complete.String())
}

func (d *DealsDB) List(ctx context.Context, query string, filter map[string]interface{}, cursor *graphql.ID, offset int, limit int) ([]*types.ProviderDealState, error) {
func (d *DealsDB) List(ctx context.Context, query string, filter *FilterOptions, cursor *graphql.ID, offset int, limit int) ([]*types.ProviderDealState, error) {
where := ""
whereArgs := []interface{}{}

Expand All @@ -284,33 +291,41 @@ func (d *DealsDB) List(ctx context.Context, query string, filter map[string]inte
whereArgs = append(whereArgs, searchArgs...)
}

if len(filter) > 0 {
if filter != nil {
if where != "" {
where += " AND "
}

filterWhere, filterArgs := withSearchFilter(filter)
filterWhere, filterArgs := withSearchFilter(*filter)
where += filterWhere
whereArgs = append(whereArgs, filterArgs...)
}

return d.list(ctx, offset, limit, where, whereArgs...)
}

var filterFields = []string{"Checkpoint", "IsOffline", "TransferType", "VerifiedDeal"}

func withSearchFilter(filter map[string]interface{}) (string, []interface{}) {
func withSearchFilter(filter FilterOptions) (string, []interface{}) {
whereArgs := []interface{}{}

statements := []string{}
for _, filterField := range filterFields {
// If the filterField is in the filter and it's not empty, append
value, ok := filter[filterField]
if ok && value != nil {
st := filterField + " = ?"
statements = append(statements, st)
whereArgs = append(whereArgs, value)
}

if filter.Checkpoint != nil {
statements = append(statements, "Checkpoint = ?")
whereArgs = append(whereArgs, *filter.Checkpoint)
}

if filter.IsOffline != nil {
statements = append(statements, "IsOffline = ?")
whereArgs = append(whereArgs, *filter.IsOffline)
}

if filter.TransferType != nil {
statements = append(statements, "TransferType = ?")
whereArgs = append(whereArgs, *filter.TransferType)
}

if filter.VerifiedDeal != nil {
statements = append(statements, "VerifiedDeal = ?")
whereArgs = append(whereArgs, *filter.VerifiedDeal)
}

if len(statements) == 0 {
Expand Down
50 changes: 37 additions & 13 deletions db/deals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,29 @@ import (
"github.com/stretchr/testify/require"
)

func ToFilterOptions(filters map[string]interface{}) *FilterOptions {
filter := &FilterOptions{}

cp, ok := filters["Checkpoint"].(string)
if ok {
filter.Checkpoint = &cp
}
io, ok := filters["IsOffline"].(bool)
if ok {
filter.IsOffline = &io
}
tt, ok := filters["TransferType"].(string)
if ok {
filter.TransferType = &tt
}
vd, ok := filters["VerifiedDeal"].(bool)
if ok {
filter.VerifiedDeal = &vd
}

return filter
}

func TestDealsDB(t *testing.T) {
req := require.New(t)
ctx := context.Background()
Expand Down Expand Up @@ -123,7 +146,7 @@ func TestDealsDBSearch(t *testing.T) {
tcs := []struct {
name string
value string
filter map[string]interface{}
filter *FilterOptions
count int
}{{
name: "search error",
Expand Down Expand Up @@ -183,24 +206,24 @@ func TestDealsDBSearch(t *testing.T) {
}, {
name: "filter out isOffline",
value: "",
filter: map[string]interface{}{
filter: ToFilterOptions(map[string]interface{}{
"IsOffline": false,
},
}),
count: 0,
}, {
name: "filter isOffline",
value: "",
filter: map[string]interface{}{
filter: ToFilterOptions(map[string]interface{}{
"IsOffline": true,
},
}),
count: 5,
}, {
name: "filter isOffline and IndexedAndAnnounced (in sealing)",
value: "",
filter: map[string]interface{}{
filter: ToFilterOptions(map[string]interface{}{
"IsOffline": true,
"Checkpoint": dealcheckpoints.IndexedAndAnnounced,
},
"Checkpoint": dealcheckpoints.IndexedAndAnnounced.String(),
}),
count: 0,
}}
for _, tc := range tcs {
Expand All @@ -224,24 +247,25 @@ func TestDealsDBSearch(t *testing.T) {

func TestWithSearchFilter(t *testing.T) {
req := require.New(t)
filter := map[string]interface{}{

fo := ToFilterOptions(map[string]interface{}{
"Checkpoint": "Accepted",
"IsOffline": true,
"NotAValidFilter": 123,
}

where, whereArgs := withSearchFilter(filter)
})
where, whereArgs := withSearchFilter(*fo)
expectedArgs := []interface{}{
"Accepted",
true,
}
req.Equal("(Checkpoint = ? AND IsOffline = ?)", where)
req.Equal(expectedArgs, whereArgs)

where, whereArgs = withSearchFilter(map[string]interface{}{
fo = ToFilterOptions(map[string]interface{}{
"IsOffline": nil,
"NotAValidFilter": nil,
})
where, whereArgs = withSearchFilter(*fo)

req.Equal("", where)
req.Equal(0, len(whereArgs))
Expand Down
19 changes: 6 additions & 13 deletions gql/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,11 @@ func (r *resolver) Deals(ctx context.Context, args dealsArgs) (*dealListResolver
query = *args.Query.Value
}

filter := map[string]interface{}{}
if args.Filter.Checkpoint.String != "" {
filter["Checkpoint"] = args.Filter.Checkpoint.String
}
if args.Filter.IsOffline.Set && args.Filter.IsOffline.Value != nil {
filter["IsOffline"] = args.Filter.IsOffline.Value
}
if args.Filter.TransferType.Set && args.Filter.TransferType.Value != nil {
filter["TransferType"] = args.Filter.TransferType.Value
}
if args.Filter.VerifiedDeal.Set && args.Filter.VerifiedDeal.Value != nil {
filter["VerifiedDeal"] = args.Filter.VerifiedDeal.Value
filter := &db.FilterOptions{
Checkpoint: &args.Filter.Checkpoint.String,
IsOffline: args.Filter.IsOffline.Value,
TransferType: args.Filter.TransferType.Value,
VerifiedDeal: args.Filter.VerifiedDeal.Value,
}

deals, count, more, err := r.dealList(ctx, query, filter, args.Cursor, offset, limit)
Expand Down Expand Up @@ -329,7 +322,7 @@ func (r *resolver) dealsByPublishCID(ctx context.Context, publishCid cid.Cid) ([
return deals, nil
}

func (r *resolver) dealList(ctx context.Context, query string, filter map[string]interface{}, cursor *graphql.ID, offset int, limit int) ([]types.ProviderDealState, int, bool, error) {
func (r *resolver) dealList(ctx context.Context, query string, filter *db.FilterOptions, cursor *graphql.ID, offset int, limit int) ([]types.ProviderDealState, int, bool, error) {
// Fetch one extra deal so that we can check if there are more deals
// beyond the limit
deals, err := r.dealsDB.List(ctx, query, filter, cursor, offset, limit+1)
Expand Down

0 comments on commit ed18007

Please sign in to comment.