From acd26a473cf2633f7bc6529c65cbfe7dc3a6c3a3 Mon Sep 17 00:00:00 2001 From: "Ariel Shaqed (Scolnicov)" Date: Mon, 12 Apr 2021 13:34:25 +0300 Subject: [PATCH 1/4] [bug] Use DB tag to look up "next" field in auth pagination Tested by fixing the test (to look at >1 word), and also by paging with amount 2 and examining the `next_offset` field. Fixes #1748. --- pkg/auth/errors.go | 1 + pkg/auth/page.go | 1 - pkg/auth/service.go | 20 +++++++++++++++++++- pkg/auth/service_test.go | 8 ++++---- 4 files changed, 24 insertions(+), 6 deletions(-) delete mode 100644 pkg/auth/page.go diff --git a/pkg/auth/errors.go b/pkg/auth/errors.go index 30a046aef5d..8d1c027768b 100644 --- a/pkg/auth/errors.go +++ b/pkg/auth/errors.go @@ -11,4 +11,5 @@ var ( ErrAlreadyExists = db.ErrAlreadyExists ErrInvalidArn = errors.New("invalid ARN") ErrInsufficientPermissions = errors.New("insufficient permissions") + ErrNoField = errors.New("no field tagged in struct") ) diff --git a/pkg/auth/page.go b/pkg/auth/page.go deleted file mode 100644 index 8832b06d188..00000000000 --- a/pkg/auth/page.go +++ /dev/null @@ -1 +0,0 @@ -package auth diff --git a/pkg/auth/service.go b/pkg/auth/service.go index 4dbcdab6fe9..bf4f68c01b9 100644 --- a/pkg/auth/service.go +++ b/pkg/auth/service.go @@ -81,8 +81,25 @@ type Service interface { var psql = sq.StatementBuilder.PlaceholderFormat(sq.Dollar) +// fieldNameByTag returns the name of the field of t that is tagged tag on key, or an empty string. +func fieldByTag(t reflect.Type, key, tag string) string { + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if l, ok := field.Tag.Lookup(key); ok { + if l == tag { + return field.Name + } + } + } + return "" +} + func ListPaged(ctx context.Context, db db.Querier, retType reflect.Type, params *model.PaginationParams, tokenColumnName string, queryBuilder sq.SelectBuilder) (*reflect.Value, *model.Paginator, error) { ptrType := reflect.PtrTo(retType) + tokenField := fieldByTag(retType, "db", tokenColumnName) + if tokenField == "" { + return nil, nil, fmt.Errorf("[I] no field %s: %w", tokenColumnName, ErrNoField) + } slice := reflect.MakeSlice(reflect.SliceOf(ptrType), 0, 0) queryBuilder = queryBuilder.OrderBy(tokenColumnName) if params != nil { @@ -112,7 +129,8 @@ func ListPaged(ctx context.Context, db db.Querier, retType reflect.Type, params // we have more pages slice = slice.Slice(0, params.Amount) p.Amount = params.Amount - p.NextPageToken = slice.Index(slice.Len() - 1).Elem().FieldByName(tokenColumnName).String() + lastElem := slice.Index(slice.Len() - 1).Elem() + p.NextPageToken = lastElem.FieldByName(tokenField).String() return &slice, p, nil } p.Amount = slice.Len() diff --git a/pkg/auth/service_test.go b/pkg/auth/service_test.go index e92cf909f7a..6040f216798 100644 --- a/pkg/auth/service_test.go +++ b/pkg/auth/service_test.go @@ -106,8 +106,8 @@ func TestDBAuthService_ListPaged(t *testing.T) { ctx := context.Background() const chars = "abcdefghijklmnopqrstuvwxyz" adb, _ := testutil.GetDB(t, databaseURI) - type row struct{ A string } - if _, err := adb.Exec(ctx, `CREATE TABLE test_pages (a text PRIMARY KEY)`); err != nil { + type row struct{ TheKey string `db:"the_key"` } + if _, err := adb.Exec(ctx, `CREATE TABLE test_pages (the_key text PRIMARY KEY)`); err != nil { t.Fatalf("CREATE TABLE test_pages: %s", err) } insert := psql.Insert("test_pages") @@ -131,7 +131,7 @@ func TestDBAuthService_ListPaged(t *testing.T) { got := "" for { values, paginator, err := auth.ListPaged(ctx, - adb, reflect.TypeOf(row{}), pagination, "A", psql.Select("a").From("test_pages")) + adb, reflect.TypeOf(row{}), pagination, "the_key", psql.Select("the_key").From("test_pages")) if err != nil { t.Errorf("ListPaged: %s", err) break @@ -141,7 +141,7 @@ func TestDBAuthService_ListPaged(t *testing.T) { } letters := values.Interface().([]*row) for _, c := range letters { - got = got + c.A + got = got + c.TheKey } if paginator.NextPageToken == "" { if size > 0 && len(letters) > size { From a5d3461d34df1f7b35813b4379b36be9cb5fbbd5 Mon Sep 17 00:00:00 2001 From: "Ariel Shaqed (Scolnicov)" Date: Mon, 12 Apr 2021 14:02:53 +0300 Subject: [PATCH 2/4] [bug] Return HasMore field from auth listings Previously was always `false`, so paging was never flagged. Tested by listing: ``` go run ./cmd/lakectl -c ~/.lakectl.nessie.yaml auth users list --amount 2 --after five +---------+-------------------------------+ | USER ID | CREATION DATE | +---------+-------------------------------+ | four | 2021-04-12 13:29:20 +0300 IDT | | nessie | 2021-04-11 17:07:59 +0300 IDT | +---------+-------------------------------+ for more results run with --amount 2 --after "nessie" ``` The line "for more results..." never used to appear! --- pkg/api/controller.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pkg/api/controller.go b/pkg/api/controller.go index aa4c15fba57..85d83b1f7d1 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -211,6 +211,7 @@ func (c *Controller) ListGroups(w http.ResponseWriter, r *http.Request, params L response := GroupList{ Results: make([]Group, 0, len(groups)), Pagination: Pagination{ + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -324,6 +325,7 @@ func (c *Controller) ListGroupMembers(w http.ResponseWriter, r *http.Request, gr response := UserList{ Results: make([]User, 0, len(users)), Pagination: Pagination{ + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -397,6 +399,7 @@ func (c *Controller) ListGroupPolicies(w http.ResponseWriter, r *http.Request, g response := PolicyList{ Results: make([]Policy, 0, len(policies)), Pagination: Pagination{ + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -485,6 +488,7 @@ func (c *Controller) ListPolicies(w http.ResponseWriter, r *http.Request, params response := PolicyList{ Results: make([]Policy, 0, len(policies)), Pagination: Pagination{ + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -632,6 +636,7 @@ func (c *Controller) ListUsers(w http.ResponseWriter, r *http.Request, params Li response := UserList{ Results: make([]User, 0, len(users)), Pagination: Pagination{ + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -742,6 +747,7 @@ func (c *Controller) ListUserCredentials(w http.ResponseWriter, r *http.Request, response := CredentialsList{ Results: make([]Credentials, 0, len(credentials)), Pagination: Pagination{ + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -850,6 +856,7 @@ func (c *Controller) ListUserGroups(w http.ResponseWriter, r *http.Request, user response := GroupList{ Results: make([]Group, 0, len(groups)), Pagination: Pagination{ + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -892,6 +899,7 @@ func (c *Controller) ListUserPolicies(w http.ResponseWriter, r *http.Request, us response := PolicyList{ Pagination: Pagination{ + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, From 452f8d4d8a16d4f7dfa6dc0ce3bc3e6e6b7b1961 Mon Sep 17 00:00:00 2001 From: "Ariel Shaqed (Scolnicov)" Date: Mon, 12 Apr 2021 14:03:53 +0300 Subject: [PATCH 3/4] [bug] Cap maximal number amount in auth listings --- pkg/auth/service.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pkg/auth/service.go b/pkg/auth/service.go index bf4f68c01b9..59eaa5b0a6d 100644 --- a/pkg/auth/service.go +++ b/pkg/auth/service.go @@ -94,6 +94,8 @@ func fieldByTag(t reflect.Type, key, tag string) string { return "" } +const maxPage = 1000 + func ListPaged(ctx context.Context, db db.Querier, retType reflect.Type, params *model.PaginationParams, tokenColumnName string, queryBuilder sq.SelectBuilder) (*reflect.Value, *model.Paginator, error) { ptrType := reflect.PtrTo(retType) tokenField := fieldByTag(retType, "db", tokenColumnName) @@ -102,12 +104,19 @@ func ListPaged(ctx context.Context, db db.Querier, retType reflect.Type, params } slice := reflect.MakeSlice(reflect.SliceOf(ptrType), 0, 0) queryBuilder = queryBuilder.OrderBy(tokenColumnName) + amount := 0 if params != nil { queryBuilder = queryBuilder.Where(sq.Gt{tokenColumnName: params.After}) if params.Amount >= 0 { - queryBuilder = queryBuilder.Limit(uint64(params.Amount) + 1) + amount = params.Amount + 1 } } + if amount > maxPage { + amount = maxPage + } + if amount > 0 { + queryBuilder = queryBuilder.Limit(uint64(amount)) + } query, args, err := queryBuilder.ToSql() if err != nil { return nil, nil, fmt.Errorf("convert to SQL: %w", err) From ddd6fdfc14cdc387656b4fc46d88d80c97c402fc Mon Sep 17 00:00:00 2001 From: "Ariel Shaqed (Scolnicov)" Date: Mon, 12 Apr 2021 14:39:37 +0300 Subject: [PATCH 4/4] Use unadorned column names auth paginator cannot handle table-decorated column names, it looks for the a field with an exact tag name. Avoid unnecessary adornment of the table from which a column comes. It is (by definition!) not needed to describe the column, which is how the unadorned `db:...` tag works. --- pkg/api/controller.go | 16 ++++++++-------- pkg/auth/service.go | 6 +++--- pkg/auth/service_test.go | 4 +++- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/pkg/api/controller.go b/pkg/api/controller.go index 85d83b1f7d1..acffb96bfd8 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -211,7 +211,7 @@ func (c *Controller) ListGroups(w http.ResponseWriter, r *http.Request, params L response := GroupList{ Results: make([]Group, 0, len(groups)), Pagination: Pagination{ - HasMore: paginator.NextPageToken != "", + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -325,7 +325,7 @@ func (c *Controller) ListGroupMembers(w http.ResponseWriter, r *http.Request, gr response := UserList{ Results: make([]User, 0, len(users)), Pagination: Pagination{ - HasMore: paginator.NextPageToken != "", + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -399,7 +399,7 @@ func (c *Controller) ListGroupPolicies(w http.ResponseWriter, r *http.Request, g response := PolicyList{ Results: make([]Policy, 0, len(policies)), Pagination: Pagination{ - HasMore: paginator.NextPageToken != "", + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -488,7 +488,7 @@ func (c *Controller) ListPolicies(w http.ResponseWriter, r *http.Request, params response := PolicyList{ Results: make([]Policy, 0, len(policies)), Pagination: Pagination{ - HasMore: paginator.NextPageToken != "", + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -636,7 +636,7 @@ func (c *Controller) ListUsers(w http.ResponseWriter, r *http.Request, params Li response := UserList{ Results: make([]User, 0, len(users)), Pagination: Pagination{ - HasMore: paginator.NextPageToken != "", + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -747,7 +747,7 @@ func (c *Controller) ListUserCredentials(w http.ResponseWriter, r *http.Request, response := CredentialsList{ Results: make([]Credentials, 0, len(credentials)), Pagination: Pagination{ - HasMore: paginator.NextPageToken != "", + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -856,7 +856,7 @@ func (c *Controller) ListUserGroups(w http.ResponseWriter, r *http.Request, user response := GroupList{ Results: make([]Group, 0, len(groups)), Pagination: Pagination{ - HasMore: paginator.NextPageToken != "", + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, @@ -899,7 +899,7 @@ func (c *Controller) ListUserPolicies(w http.ResponseWriter, r *http.Request, us response := PolicyList{ Pagination: Pagination{ - HasMore: paginator.NextPageToken != "", + HasMore: paginator.NextPageToken != "", NextOffset: paginator.NextPageToken, Results: paginator.Amount, }, diff --git a/pkg/auth/service.go b/pkg/auth/service.go index 59eaa5b0a6d..f47c5df7520 100644 --- a/pkg/auth/service.go +++ b/pkg/auth/service.go @@ -300,7 +300,7 @@ func (s *DBAuthService) ListUsers(ctx context.Context, params *model.PaginationP func (s *DBAuthService) ListUserCredentials(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Credential, *model.Paginator, error) { var credential model.Credential - slice, paginator, err := ListPaged(ctx, s.db, reflect.TypeOf(credential), params, "auth_credentials.access_key_id", psql.Select("auth_credentials.*"). + slice, paginator, err := ListPaged(ctx, s.db, reflect.TypeOf(credential), params, "access_key_id", psql.Select("auth_credentials.*"). From("auth_credentials"). Join("auth_users ON (auth_credentials.user_id = auth_users.id)"). Where(sq.Eq{"auth_users.display_name": username})) @@ -353,7 +353,7 @@ func (s *DBAuthService) DetachPolicyFromUser(ctx context.Context, policyDisplayN func (s *DBAuthService) ListUserPolicies(ctx context.Context, username string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { var policy model.Policy - slice, paginator, err := ListPaged(ctx, s.db, reflect.TypeOf(policy), params, "auth_policies.display_name", psql.Select("auth_policies.*"). + slice, paginator, err := ListPaged(ctx, s.db, reflect.TypeOf(policy), params, "display_name", psql.Select("auth_policies.*"). From("auth_policies"). Join("auth_user_policies ON (auth_policies.id = auth_user_policies.policy_id)"). Join("auth_users ON (auth_user_policies.user_id = auth_users.id)"). @@ -412,7 +412,7 @@ func (s *DBAuthService) ListEffectivePolicies(ctx context.Context, username stri func (s *DBAuthService) ListGroupPolicies(ctx context.Context, groupDisplayName string, params *model.PaginationParams) ([]*model.Policy, *model.Paginator, error) { var policy model.Policy - slice, paginator, err := ListPaged(ctx, s.db, reflect.TypeOf(policy), params, "auth_policies.display_name", + slice, paginator, err := ListPaged(ctx, s.db, reflect.TypeOf(policy), params, "display_name", psql.Select("auth_policies.*"). From("auth_policies"). Join("auth_group_policies ON (auth_policies.id = auth_group_policies.policy_id)"). diff --git a/pkg/auth/service_test.go b/pkg/auth/service_test.go index 6040f216798..b286663129f 100644 --- a/pkg/auth/service_test.go +++ b/pkg/auth/service_test.go @@ -106,7 +106,9 @@ func TestDBAuthService_ListPaged(t *testing.T) { ctx := context.Background() const chars = "abcdefghijklmnopqrstuvwxyz" adb, _ := testutil.GetDB(t, databaseURI) - type row struct{ TheKey string `db:"the_key"` } + type row struct { + TheKey string `db:"the_key"` + } if _, err := adb.Exec(ctx, `CREATE TABLE test_pages (the_key text PRIMARY KEY)`); err != nil { t.Fatalf("CREATE TABLE test_pages: %s", err) }