Skip to content
This repository has been archived by the owner on Dec 22, 2023. It is now read-only.

Commit

Permalink
Record acl for transient fields
Browse files Browse the repository at this point in the history
refs #496
  • Loading branch information
cheungpat committed Dec 8, 2017
2 parents 107ea4e + 3ca6882 commit 6d2305d
Show file tree
Hide file tree
Showing 19 changed files with 358 additions and 189 deletions.
10 changes: 5 additions & 5 deletions pkg/server/handler/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ func MakeEqualPredicateAssertion(key string, value string) func(predicate *skydb
}
}

func MakeUsernameEmailQueryAssertion(username string, email string) func(query *skydb.Query) {
return func(query *skydb.Query) {
func MakeUsernameEmailQueryAssertion(username string, email string) func(query *skydb.Query, accessControlOptions *skydb.AccessControlOptions) {
return func(query *skydb.Query, accessControlOptions *skydb.AccessControlOptions) {
So(query.Type, ShouldEqual, "user")

predicate := query.Predicate
Expand Down Expand Up @@ -474,7 +474,7 @@ func TestLoginHandler(t *testing.T) {
conn.CreateAuth(&authinfo)

db.EXPECT().
Query(gomock.Any()).
Query(gomock.Any(), gomock.Any()).
Do(MakeUsernameEmailQueryAssertion("john.doe", "")).
Return(skydb.NewRows(skydb.NewMemoryRows([]skydb.Record{skydb.Record{
ID: skydb.NewRecordID("user", authinfo.ID),
Expand Down Expand Up @@ -536,7 +536,7 @@ func TestLoginHandler(t *testing.T) {
conn.CreateAuth(&authinfo)

db.EXPECT().
Query(gomock.Any()).
Query(gomock.Any(), gomock.Any()).
Do(MakeUsernameEmailQueryAssertion("john.doe", "")).
Return(skydb.NewRows(skydb.NewMemoryRows([]skydb.Record{skydb.Record{
ID: skydb.NewRecordID("user", authinfo.ID),
Expand Down Expand Up @@ -564,7 +564,7 @@ func TestLoginHandler(t *testing.T) {

Convey("login user not found", func() {
db.EXPECT().
Query(gomock.Any()).
Query(gomock.Any(), gomock.Any()).
Do(MakeUsernameEmailQueryAssertion("john.doe", "")).
Return(skydb.NewRows(skydb.NewMemoryRows([]skydb.Record{})), nil).
AnyTimes()
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/handler/authutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (f *UserAuthFetcher) FetchUser(authData skydb.AuthData) (user skydb.Record,
query := f.buildAuthDataQuery(authData)

var results *skydb.Rows
results, err = f.Database.Query(&query)
results, err = f.Database.Query(&query, &skydb.AccessControlOptions{})
if err != nil {
return
}
Expand Down
21 changes: 9 additions & 12 deletions pkg/server/handler/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,12 +569,9 @@ func (h *RecordQueryHandler) Handle(payload *router.Payload, response *router.Re
return
}

if payload.AuthInfo != nil {
p.Query.ViewAsUser = payload.AuthInfo
}

if payload.HasMasterKey() {
p.Query.BypassAccessControl = true
accessControlOptions := &skydb.AccessControlOptions{
ViewAsUser: payload.AuthInfo,
BypassAccessControl: payload.HasMasterKey(),
}

fieldACL := func() skydb.FieldACL {
Expand All @@ -585,11 +582,11 @@ func (h *RecordQueryHandler) Handle(payload *router.Payload, response *router.Re
return acl
}()

if !p.Query.BypassAccessControl {
if !accessControlOptions.BypassAccessControl {
visitor := &queryAccessVisitor{
FieldACL: fieldACL,
RecordType: p.Query.Type,
AuthInfo: p.Query.ViewAsUser,
AuthInfo: accessControlOptions.ViewAsUser,
ExpressionACLChecker: ExpressionACLChecker{
FieldACL: fieldACL,
RecordType: p.Query.Type,
Expand All @@ -606,7 +603,7 @@ func (h *RecordQueryHandler) Handle(payload *router.Payload, response *router.Re

db := payload.Database

results, err := db.Query(&p.Query)
results, err := db.Query(&p.Query, accessControlOptions)
if err != nil {
response.Err = skyerr.MakeError(err)
return
Expand All @@ -629,13 +626,13 @@ func (h *RecordQueryHandler) Handle(payload *router.Payload, response *router.Re
// so we replace them with some complete assets.
recordutil.MakeAssetsComplete(db, payload.DBConn, records)

eagerRecords := recordutil.DoQueryEager(db, recordutil.EagerIDs(db, records, p.Query))
eagerRecords := recordutil.DoQueryEager(db, recordutil.EagerIDs(db, records, p.Query), accessControlOptions)

recordResultFilter, err := recordutil.NewRecordResultFilter(
payload.DBConn,
h.AssetStore,
payload.AuthInfo,
p.Query.BypassAccessControl,
accessControlOptions.BypassAccessControl,
)
if err != nil {
response.Err = skyerr.MakeError(err)
Expand All @@ -657,7 +654,7 @@ func (h *RecordQueryHandler) Handle(payload *router.Payload, response *router.Re

response.Result = output

resultInfo, err := recordutil.QueryResultInfo(db, &p.Query, results)
resultInfo, err := recordutil.QueryResultInfo(db, &p.Query, accessControlOptions, results)
if err != nil {
response.Err = skyerr.MakeError(err)
return
Expand Down
133 changes: 116 additions & 17 deletions pkg/server/handler/record_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -917,8 +917,9 @@ func TestRecordSaveNoExtendIfRecordMalformed(t *testing.T) {
}

type queryDatabase struct {
lastquery *skydb.Query
databaseID string
lastquery *skydb.Query
lastAccessControlOptions *skydb.AccessControlOptions
databaseID string
skydb.Database
}

Expand All @@ -931,13 +932,15 @@ func (db *queryDatabase) ID() string {
return db.databaseID
}

func (db *queryDatabase) QueryCount(query *skydb.Query) (uint64, error) {
func (db *queryDatabase) QueryCount(query *skydb.Query, accessControlOptions *skydb.AccessControlOptions) (uint64, error) {
db.lastquery = query
db.lastAccessControlOptions = accessControlOptions
return 0, nil
}

func (db *queryDatabase) Query(query *skydb.Query) (*skydb.Rows, error) {
func (db *queryDatabase) Query(query *skydb.Query, accessControlOptions *skydb.AccessControlOptions) (*skydb.Rows, error) {
db.lastquery = query
db.lastAccessControlOptions = accessControlOptions
return skydb.EmptyRows, nil
}

Expand All @@ -957,11 +960,11 @@ func (db *queryResultsDatabase) ID() string {
return db.databaseID
}

func (db *queryResultsDatabase) QueryCount(query *skydb.Query) (uint64, error) {
func (db *queryResultsDatabase) QueryCount(query *skydb.Query, accessControlOptions *skydb.AccessControlOptions) (uint64, error) {
return uint64(len(db.records)), nil
}

func (db *queryResultsDatabase) Query(query *skydb.Query) (*skydb.Rows, error) {
func (db *queryResultsDatabase) Query(query *skydb.Query, accessControlOptions *skydb.AccessControlOptions) (*skydb.Rows, error) {
return skydb.NewRows(skydb.NewMemoryRows(db.records)), nil
}

Expand Down Expand Up @@ -1060,7 +1063,9 @@ func TestRecordQuery(t *testing.T) {

So(response.Err, ShouldBeNil)
So(db.lastquery, ShouldResemble, &skydb.Query{
Type: "note",
Type: "note",
})
So(db.lastAccessControlOptions, ShouldResemble, &skydb.AccessControlOptions{
ViewAsUser: &authInfo,
})
})
Expand All @@ -1085,7 +1090,9 @@ func TestRecordQuery(t *testing.T) {

So(response.Err, ShouldBeNil)
So(db.lastquery, ShouldResemble, &skydb.Query{
Type: "note",
Type: "note",
})
So(db.lastAccessControlOptions, ShouldResemble, &skydb.AccessControlOptions{
ViewAsUser: &authInfo,
BypassAccessControl: true,
})
Expand Down Expand Up @@ -1689,11 +1696,11 @@ func (db *singleRecordDatabase) Save(record *skydb.Record) error {
return nil
}

func (db *singleRecordDatabase) QueryCount(query *skydb.Query) (uint64, error) {
func (db *singleRecordDatabase) QueryCount(query *skydb.Query, accessControlOptions *skydb.AccessControlOptions) (uint64, error) {
return uint64(1), nil
}

func (db *singleRecordDatabase) Query(query *skydb.Query) (*skydb.Rows, error) {
func (db *singleRecordDatabase) Query(query *skydb.Query, accessControlOptions *skydb.AccessControlOptions) (*skydb.Rows, error) {
return skydb.NewRows(skydb.NewMemoryRows([]skydb.Record{db.record})), nil
}

Expand Down Expand Up @@ -2085,6 +2092,7 @@ type referencedRecordDatabase struct {
category skydb.Record
city skydb.Record
user skydb.Record
secret skydb.Record
databaseID string
skydb.Database
}
Expand Down Expand Up @@ -2114,18 +2122,37 @@ func (db *referencedRecordDatabase) Get(id skydb.RecordID, record *skydb.Record)
return nil
}

func (db *referencedRecordDatabase) GetByIDs(ids []skydb.RecordID) (*skydb.Rows, error) {
func (db *referencedRecordDatabase) GetByIDs(ids []skydb.RecordID, accessControlOptions *skydb.AccessControlOptions) (*skydb.Rows, error) {
records := []skydb.Record{}
for _, id := range ids {
var record *skydb.Record
switch id.String() {
case "note/note1":
records = append(records, db.note)
record = &db.note
case "category/important":
records = append(records, db.category)
record = &db.category
case "city/beautiful":
records = append(records, db.city)
record = &db.city
case "user/ownerID":
records = append(records, db.user)
record = &db.user
case "secret/secretID":
record = &db.secret
}

// mock the acl query
// it will only consider direct record acl entry
if record != nil {
if record.ACL == nil || len(record.ACL) == 0 {
records = append(records, *record)
continue
}
for _, aclEntry := range record.ACL {
if aclEntry.Relation == "$direct" &&
aclEntry.UserID == accessControlOptions.ViewAsUser.ID {
records = append(records, db.secret)
continue
}
}
}
}
return skydb.NewRows(skydb.NewMemoryRows(records)), nil
Expand All @@ -2135,11 +2162,11 @@ func (db *referencedRecordDatabase) Save(record *skydb.Record) error {
return nil
}

func (db *referencedRecordDatabase) QueryCount(query *skydb.Query) (uint64, error) {
func (db *referencedRecordDatabase) QueryCount(query *skydb.Query, accessControlOptions *skydb.AccessControlOptions) (uint64, error) {
return uint64(1), nil
}

func (db *referencedRecordDatabase) Query(query *skydb.Query) (*skydb.Rows, error) {
func (db *referencedRecordDatabase) Query(query *skydb.Query, accessControlOptions *skydb.AccessControlOptions) (*skydb.Rows, error) {
return skydb.NewRows(skydb.NewMemoryRows([]skydb.Record{db.note})), nil
}

Expand Down Expand Up @@ -2187,6 +2214,7 @@ func TestRecordQueryWithEagerLoad(t *testing.T) {
Data: map[string]interface{}{
"category": skydb.NewReference("category", "important"),
"city": skydb.NewReference("city", "beautiful"),
"secret": skydb.NewReference("secret", "secretID"),
},
},
category: skydb.Record{
Expand All @@ -2210,6 +2238,16 @@ func TestRecordQueryWithEagerLoad(t *testing.T) {
"name": "Owner",
},
},
secret: skydb.Record{
ID: skydb.NewRecordID("secret", "secretID"),
OwnerID: "ownerID",
Data: map[string]interface{}{
"content": "Secret of the note",
},
ACL: skydb.RecordACL{
skydb.NewRecordACLEntryDirect("ownerID", skydb.WriteLevel),
},
},
}
conn := skydbtest.NewMapConn()

Expand All @@ -2232,6 +2270,7 @@ func TestRecordQueryWithEagerLoad(t *testing.T) {
"_ownerID": "ownerID",
"category": {"$id":"category/important","$type":"ref"},
"city": {"$id":"city/beautiful","$type":"ref"},
"secret":{"$id":"secret/secretID","$type":"ref"},
"_transient": {
"category": {"_access":null,"_id":"category/important","_type":"record","_ownerID":"ownerID", "title": "This is important."}
}
Expand All @@ -2256,6 +2295,7 @@ func TestRecordQueryWithEagerLoad(t *testing.T) {
"_ownerID": "ownerID",
"category": {"$id":"category/important","$type":"ref"},
"city": {"$id":"city/beautiful","$type":"ref"},
"secret":{"$id":"secret/secretID","$type":"ref"},
"_transient": {
"category": {"_access":null,"_id":"category/important","_type":"record","_ownerID":"ownerID", "title": "This is important."},
"city": {"_access":null,"_id":"city/beautiful","_type":"record","_ownerID":"ownerID", "name": "This is beautiful."}
Expand All @@ -2278,13 +2318,72 @@ func TestRecordQueryWithEagerLoad(t *testing.T) {
"_ownerID": "ownerID",
"category": {"$id":"category/important","$type":"ref"},
"city": {"$id":"city/beautiful","$type":"ref"},
"secret":{"$id":"secret/secretID","$type":"ref"},
"_transient": {
"user": {"_access":null,"_id":"user/ownerID","_type":"record","_ownerID":"ownerID", "name": "Owner"}
}
}]
}`)
})

Convey("query record with eager load on non public record", func() {
resp := handlertest.NewSingleRouteRouter(&RecordQueryHandler{}, func(p *router.Payload) {
p.Database = db
p.DBConn = skydbtest.NewMapConn()
p.AuthInfo = &skydb.AuthInfo{
ID: "user0",
}
}).POST(`{
"record_type": "note",
"include": {
"secret": {"$type": "keypath", "$val": "secret"}
}
}`)

So(resp.Body.Bytes(), ShouldEqualJSON, `{
"result": [{
"_id": "note/note1",
"_type": "record",
"_access": null,
"_ownerID": "ownerID",
"category": {"$id":"category/important","$type":"ref"},
"city": {"$id":"city/beautiful","$type":"ref"},
"secret":{"$id":"secret/secretID","$type":"ref"},
"_transient": {"secret":null}
}]
}`)
})

Convey("query record with eager load on non public record with permission", func() {
resp := handlertest.NewSingleRouteRouter(&RecordQueryHandler{}, func(p *router.Payload) {
p.Database = db
p.DBConn = skydbtest.NewMapConn()
p.AuthInfo = &skydb.AuthInfo{
ID: "ownerID",
}
}).POST(`{
"record_type": "note",
"include": {
"secret": {"$type": "keypath", "$val": "secret"}
}
}`)

So(resp.Body.Bytes(), ShouldEqualJSON, `{
"result": [{
"_id": "note/note1",
"_type": "record",
"_access": null,
"_ownerID": "ownerID",
"category": {"$id":"category/important","$type":"ref"},
"city": {"$id":"city/beautiful","$type":"ref"},
"secret":{"$id":"secret/secretID","$type":"ref"},
"_transient": {
"secret": {"_access":[{"level":"write","relation":"$direct","user_id":"ownerID"}],"_id":"secret/secretID","_type":"record","_ownerID":"ownerID", "content": "Secret of the note"}
}
}]
}`)
})

})

Convey("Given a referenced record with null reference in DB", t, func() {
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/plugin/transportstate_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 6d2305d

Please sign in to comment.