diff --git a/graphql/e2e/auth/auth_test.go b/graphql/e2e/auth/auth_test.go index 914613159e5..9f415696676 100644 --- a/graphql/e2e/auth/auth_test.go +++ b/graphql/e2e/auth/auth_test.go @@ -21,6 +21,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "net/http" "os" "testing" @@ -112,20 +113,59 @@ type uidResult struct { } } -func getJWT(t *testing.T, user, role string) string { - metaInfo.AuthVars = map[string]interface{}{ - "USER": user, - "ROLE": role, +func getJWT(t *testing.T, user, role string) http.Header { + metaInfo.AuthVars = map[string]interface{}{} + if user != "" { + metaInfo.AuthVars["USER"] = user + } + + if role != "" { + metaInfo.AuthVars["ROLE"] = role } jwtToken, err := metaInfo.GetSignedToken("./sample_private_key.pem") require.NoError(t, err) - return jwtToken + + h := make(http.Header) + h.Add(metaInfo.Header, jwtToken) + return h } func TestOrRBACFilter(t *testing.T) { - t.Skip() - testCases := []TestCase{} + testCases := []TestCase{{ + user: "user1", + role: "ADMIN", + result: `{ + "queryProject": [ + { + "name": "Project1" + }, + { + "name": "Project2" + } + ] + }`, + }, { + user: "user1", + role: "USER", + result: `{ + "queryProject": [ + { + "name": "Project1" + } + ] + }`, + }, { + user: "user4", + role: "USER", + result: `{ + "queryProject": [ + { + "name": "Project2" + } + ] + }`, + }} query := ` query { @@ -138,10 +178,9 @@ func TestOrRBACFilter(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.role+tcase.user, func(t *testing.T) { getUserParams := &common.GraphQLParams{ - Headers: map[string][]string{}, + Headers: getJWT(t, tcase.user, tcase.role), Query: query, } - getUserParams.Headers.Add(metaInfo.Header, getJWT(t, tcase.user, tcase.role)) gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) require.Nil(t, gqlResponse.Errors) @@ -166,11 +205,10 @@ func getColID(t *testing.T, tcase TestCase) string { } getUserParams := &common.GraphQLParams{ - Headers: map[string][]string{}, + Headers: getJWT(t, tcase.user, tcase.role), Query: query, Variables: map[string]interface{}{"name": tcase.name}, } - getUserParams.Headers.Add(metaInfo.Header, getJWT(t, tcase.user, tcase.role)) gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) require.Nil(t, gqlResponse.Errors) @@ -220,11 +258,10 @@ func TestRootGetFilter(t *testing.T) { for _, tcase := range tcases { t.Run(tcase.role+tcase.user, func(t *testing.T) { getUserParams := &common.GraphQLParams{ - Headers: map[string][]string{}, + Headers: getJWT(t, tcase.user, tcase.role), Query: query, Variables: map[string]interface{}{"id": tcase.name}, } - getUserParams.Headers.Add(metaInfo.Header, getJWT(t, tcase.user, tcase.role)) gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) require.Nil(t, gqlResponse.Errors) @@ -266,11 +303,10 @@ func TestDeepFilter(t *testing.T) { for _, tcase := range tcases { t.Run(tcase.role+tcase.user, func(t *testing.T) { getUserParams := &common.GraphQLParams{ - Headers: map[string][]string{}, + Headers: getJWT(t, tcase.user, tcase.role), Query: query, Variables: map[string]interface{}{"name": tcase.name}, } - getUserParams.Headers.Add(metaInfo.Header, getJWT(t, tcase.user, tcase.role)) gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) require.Nil(t, gqlResponse.Errors) @@ -298,16 +334,46 @@ func TestRootFilter(t *testing.T) { queryColumn(order: {asc: name}) { name } + }` + + for _, tcase := range testCases { + t.Run(tcase.role+tcase.user, func(t *testing.T) { + getUserParams := &common.GraphQLParams{ + Headers: getJWT(t, tcase.user, tcase.role), + Query: query, + } + + gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) + require.Nil(t, gqlResponse.Errors) + + require.JSONEq(t, string(gqlResponse.Data), tcase.result) + }) + } +} + +func TestDeepRBACValue(t *testing.T) { + testCases := []TestCase{ + {user: "user1", role: "USER", result: `{"queryUser": [{"username": "user1", "issues":[]}]}`}, + {user: "user1", role: "ADMIN", result: `{"queryUser":[{"username":"user1","issues":[{"msg":"Issue1"}]}]}`}, } + + query := ` +{ + queryUser (filter:{username:{eq:"user1"}}) { + username + issues { + msg + } + } +} ` for _, tcase := range testCases { t.Run(tcase.role+tcase.user, func(t *testing.T) { getUserParams := &common.GraphQLParams{ - Headers: map[string][]string{}, + Headers: getJWT(t, tcase.user, tcase.role), Query: query, } - getUserParams.Headers.Add(metaInfo.Header, getJWT(t, tcase.user, tcase.role)) gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) require.Nil(t, gqlResponse.Errors) @@ -318,8 +384,11 @@ func TestRootFilter(t *testing.T) { } func TestRBACFilter(t *testing.T) { - t.Skip() - testCases := []TestCase{} + testCases := []TestCase{ + {role: "USER", result: `{"queryLog": []}`}, + {result: `{"queryLog": []}`}, + {role: "ADMIN", result: `{"queryLog": [{"logs": "Log1"},{"logs": "Log2"}]}`}} + query := ` query { queryLog (order: {asc: logs}) { @@ -331,10 +400,9 @@ func TestRBACFilter(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.role+tcase.user, func(t *testing.T) { getUserParams := &common.GraphQLParams{ - Headers: map[string][]string{}, + Headers: getJWT(t, tcase.user, tcase.role), Query: query, } - getUserParams.Headers.Add(metaInfo.Header, getJWT(t, tcase.user, tcase.role)) gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) require.Nil(t, gqlResponse.Errors) @@ -345,8 +413,19 @@ func TestRBACFilter(t *testing.T) { } func TestAndRBACFilter(t *testing.T) { - t.Skip() - testCases := []TestCase{} + testCases := []TestCase{{ + user: "user1", + role: "USER", + result: `{"queryIssue": []}`, + }, { + user: "user2", + role: "USER", + result: `{"queryIssue": []}`, + }, { + user: "user2", + role: "ADMIN", + result: `{"queryIssue": [{"msg": "Issue2"}]}`, + }} query := ` query { queryIssue (order: {asc: msg}) { @@ -358,10 +437,9 @@ func TestAndRBACFilter(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.role+tcase.user, func(t *testing.T) { getUserParams := &common.GraphQLParams{ - Headers: map[string][]string{}, + Headers: getJWT(t, tcase.user, tcase.role), Query: query, } - getUserParams.Headers.Add(metaInfo.Header, getJWT(t, tcase.user, tcase.role)) gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) require.Nil(t, gqlResponse.Errors) @@ -369,7 +447,6 @@ func TestAndRBACFilter(t *testing.T) { require.JSONEq(t, string(gqlResponse.Data), tcase.result) }) } - } func TestNestedFilter(t *testing.T) { @@ -456,10 +533,9 @@ func TestNestedFilter(t *testing.T) { for _, tcase := range testCases { t.Run(tcase.role+tcase.user, func(t *testing.T) { getUserParams := &common.GraphQLParams{ - Headers: map[string][]string{}, + Headers: getJWT(t, tcase.user, tcase.role), Query: query, } - getUserParams.Headers.Add(metaInfo.Header, getJWT(t, tcase.user, tcase.role)) gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) require.Nil(t, gqlResponse.Errors) @@ -504,13 +580,12 @@ func TestDeleteAuthRule(t *testing.T) { for _, tcase := range testCases { getUserParams := &common.GraphQLParams{ - Headers: map[string][]string{}, + Headers: getJWT(t, tcase.user, tcase.role), Query: query, Variables: map[string]interface{}{ "filter": tcase.filter, }, } - getUserParams.Headers.Add(metaInfo.Header, getJWT(t, tcase.user, tcase.role)) gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) require.Nil(t, gqlResponse.Errors) @@ -540,7 +615,7 @@ func AddDeleteDeepAuthTestData(t *testing.T) { require.NoError(t, err) userQuery := `{ - query(func: type(User)) @filter(eq(User.username, "user1") or eq(User.username, "user3") or + query(func: type(User)) @filter(eq(User.username, "user1") or eq(User.username, "user3") or eq(User.username, "user5") ) { uid } }` @@ -615,13 +690,12 @@ func TestDeleteDeepAuthRule(t *testing.T) { for _, tcase := range testCases { getUserParams := &common.GraphQLParams{ - Headers: map[string][]string{}, + Headers: getJWT(t, tcase.user, tcase.role), Query: query, Variables: map[string]interface{}{ "filter": tcase.filter, }, } - getUserParams.Headers.Add(metaInfo.Header, getJWT(t, tcase.user, tcase.role)) gqlResponse := getUserParams.ExecuteAsPost(t, graphqlURL) require.Nil(t, gqlResponse.Errors) diff --git a/graphql/e2e/auth/schema.graphql b/graphql/e2e/auth/schema.graphql index 318d250d1bb..5e4466e986c 100644 --- a/graphql/e2e/auth/schema.graphql +++ b/graphql/e2e/auth/schema.graphql @@ -22,6 +22,7 @@ type User @auth( disabled: Boolean tickets: [Ticket] @hasInverse(field: assignedTo) secrets: [UserSecret] + issues: [Issue] } type UserSecret @auth( @@ -131,8 +132,19 @@ type Log @auth( logs: String } +type ComplexLog @auth( + query: { and : [ + { rule: "{$ROLE: { eq: \"ADMIN\" }}" }, + { not : { rule: "{$ROLE: { eq: \"USER\" }}" }} + ]} +) { + id: ID! + logs: String +} + type Project @auth( - query: { rule: """query($USER: String!) { + query: { or: [ + { rule: """query($USER: String!) { queryProject { roles(filter: { permission: { eq: VIEW } }) { assignedTo(filter: { username: { eq: $USER } }) { @@ -141,6 +153,8 @@ type Project @auth( } } }""" }, + { rule: "{$ROLE: { eq: \"ADMIN\" }}" } + ]} ) { projID: ID! name: String! diff --git a/graphql/e2e/auth/test_data.json b/graphql/e2e/auth/test_data.json index d0c0298e080..d58a5448cb9 100644 --- a/graphql/e2e/auth/test_data.json +++ b/graphql/e2e/auth/test_data.json @@ -5,7 +5,8 @@ "User.age": 10, "User.username": "user1", "User.isPublic": true, - "User.disabled": false + "User.disabled": false, + "User.issues": [{"uid": "_:issue1"}] }, { "uid": "_:user2", @@ -13,7 +14,8 @@ "User.age": 11, "User.username": "user2", "User.isPublic": true, - "User.disabled": true + "User.disabled": true, + "User.issues": [{"uid": "_:issue2"}] }, { "uid": "_:user3", diff --git a/graphql/resolve/auth_query_test.yaml b/graphql/resolve/auth_query_test.yaml index cd475887c18..e3672dda4c9 100644 --- a/graphql/resolve/auth_query_test.yaml +++ b/graphql/resolve/auth_query_test.yaml @@ -52,6 +52,182 @@ UserSecret2 as var(func: uid(UserSecret1)) @filter(eq(UserSecret.ownedBy, "user1")) @cascade } +- name: "Deep RBAC rules true" + gqlquery: | + query { + queryUser { + issues { + id + } + } + } + role: "ADMIN" + dgquery: |- + query { + queryUser(func: type(User)) { + issues : User.issues @filter(uid(Issue1)) { + id : uid + } + dgraph.uid : uid + } + Issue1 as var(func: type(Issue)) @cascade { + owner : Issue.owner @filter(eq(User.username, "user1")) + dgraph.uid : uid + } + } + +- name: "Deep RBAC rules false" + gqlquery: | + query { + queryUser { + username + issues { + id + } + } + } + role: "USER" + dgquery: |- + query { + queryUser(func: type(User)) { + username : User.username + dgraph.uid : uid + } + } + + +- name: "Auth with top level AND rbac true" + gqlquery: | + query { + queryIssue { + msg + } + } + role: "ADMIN" + dgquery: |- + query { + queryIssue(func: uid(Issue1)) @filter(uid(Issue2)) { + msg : Issue.msg + dgraph.uid : uid + } + Issue1 as var(func: type(Issue)) + Issue2 as var(func: uid(Issue1)) @cascade { + owner : Issue.owner @filter(eq(User.username, "user1")) + dgraph.uid : uid + } + } + +- name: "Auth with complex rbac rules, true" + gqlquery: | + query { + queryComplexLog { + logs + } + } + role: "ADMIN" + dgquery: |- + query { + queryComplexLog(func: type(ComplexLog)) { + logs : ComplexLog.logs + dgraph.uid : uid + } + } + +- name: "Auth with complex rbac rules, false" + gqlquery: | + query { + queryComplexLog { + logs + } + } + role: "USER" + dgquery: |- + query { + queryComplexLog() + } + +- name: "Auth with top level rbac true" + gqlquery: | + query { + queryLog { + logs + } + } + role: "ADMIN" + dgquery: |- + query { + queryLog(func: type(Log)) { + logs : Log.logs + dgraph.uid : uid + } + } + +- name: "Auth with top level rbac false" + gqlquery: | + query { + queryLog { + logs + } + } + role: "USER" + dgquery: |- + query { + queryLog() + } + +- name: "Auth with top level AND rbac false" + gqlquery: | + query { + queryIssue { + msg + } + } + role: "USER" + dgquery: |- + query { + queryIssue() + } + + +- name: "Auth with top level OR rbac true" + gqlquery: | + query { + queryProject { + name + } + } + role: "ADMIN" + dgquery: |- + query { + queryProject(func: type(Project)) { + name : Project.name + dgraph.uid : uid + } + } + +- name: "Auth with top level OR rbac false" + gqlquery: | + query { + queryProject { + name + } + } + role: "USER" + dgquery: |- + query { + queryProject(func: uid(Project1)) @filter(uid(Project2)) { + name : Project.name + dgraph.uid : uid + } + Project1 as var(func: type(Project)) + Project2 as var(func: uid(Project1)) @cascade { + roles : Project.roles @filter(eq(Role.permission, "VIEW")) { + assignedTo : Role.assignedTo @filter(eq(User.username, "user1")) + dgraph.uid : uid + } + dgraph.uid : uid + } + } - name: "Auth with top level filter : query, filter and order" gqlquery: | diff --git a/graphql/resolve/auth_test.go b/graphql/resolve/auth_test.go index cd4e61774f4..895c3940e41 100644 --- a/graphql/resolve/auth_test.go +++ b/graphql/resolve/auth_test.go @@ -154,6 +154,7 @@ func queryRewriting(t *testing.T, sch string, authMeta *testutil.AuthMeta) { authMeta.AuthVars = map[string]interface{}{ "USER": "user1", + "ROLE": tcase.Role, } ctx, err := authMeta.AddClaimsToContext(context.Background()) require.NoError(t, err) diff --git a/graphql/resolve/query_rewriter.go b/graphql/resolve/query_rewriter.go index 48079b4b65a..fca755d77a4 100644 --- a/graphql/resolve/query_rewriter.go +++ b/graphql/resolve/query_rewriter.go @@ -192,12 +192,19 @@ func addUID(dgQuery *gql.GraphQuery) { } func rewriteAsQueryByIds(field schema.Field, uids []uint64, authRw *authRewriter) *gql.GraphQuery { + rbac := authRw.evaluateRBAC(field) dgQuery := &gql.GraphQuery{ Attr: field.ResponseName(), - Func: &gql.Function{ - Name: "uid", - UID: uids, - }, + } + + if rbac == schema.Negative { + dgQuery.Attr = dgQuery.Attr + "()" + return dgQuery + } + + dgQuery.Func = &gql.Function{ + Name: "uid", + UID: uids, } if ids := idFilter(field, field.Type().IDField()); ids != nil { @@ -208,7 +215,10 @@ func rewriteAsQueryByIds(field schema.Field, uids []uint64, authRw *authRewriter selectionAuth := addSelectionSetFrom(dgQuery, field, authRw) addUID(dgQuery) - dgQuery = authRw.addAuthQueries(field.Type(), dgQuery) + if rbac == schema.Uncertain { + dgQuery = authRw.addAuthQueries(field.Type(), dgQuery) + } + if len(selectionAuth) > 0 { dgQuery = &gql.GraphQuery{Children: append([]*gql.GraphQuery{dgQuery}, selectionAuth...)} } @@ -232,6 +242,10 @@ func rewriteAsGet( auth *authRewriter) *gql.GraphQuery { var dgQuery *gql.GraphQuery + rbac := auth.evaluateRBAC(field) + if rbac == schema.Negative { + return &gql.GraphQuery{} + } if xid == nil { dgQuery = rewriteAsQueryByIds(field, []uint64{uid}, auth) @@ -280,7 +294,10 @@ func rewriteAsGet( addUID(dgQuery) addTypeFilter(dgQuery, field.Type()) - dgQuery = auth.addAuthQueries(field.Type(), dgQuery) + if rbac == schema.Uncertain { + dgQuery = auth.addAuthQueries(field.Type(), dgQuery) + } + if len(selectionAuth) > 0 { dgQuery = &gql.GraphQuery{Children: append([]*gql.GraphQuery{dgQuery}, selectionAuth...)} } @@ -289,10 +306,16 @@ func rewriteAsGet( } func rewriteAsQuery(field schema.Field, authRw *authRewriter) *gql.GraphQuery { + rbac := authRw.evaluateRBAC(field) dgQuery := &gql.GraphQuery{ Attr: field.ResponseName(), } + if rbac == schema.Negative { + dgQuery.Attr = dgQuery.Attr + "()" + return dgQuery + } + if authRw != nil && authRw.isWritingAuth && authRw.varName != "" { // When rewriting auth rules, they always start like // Todo2 as var(func: uid(Todo1)) @cascade { @@ -317,7 +340,9 @@ func rewriteAsQuery(field schema.Field, authRw *authRewriter) *gql.GraphQuery { selectionAuth := addSelectionSetFrom(dgQuery, field, authRw) addUID(dgQuery) - dgQuery = authRw.addAuthQueries(field.Type(), dgQuery) + if rbac == schema.Uncertain { + dgQuery = authRw.addAuthQueries(field.Type(), dgQuery) + } if len(selectionAuth) > 0 { dgQuery = &gql.GraphQuery{Children: append([]*gql.GraphQuery{dgQuery}, selectionAuth...)} @@ -408,6 +433,16 @@ func (authRw *authRewriter) rewriteAuthQueries(typ schema.Type) ([]*gql.GraphQue }).rewriteRuleNode(typ, authRw.selector(typ)) } +func (authRw *authRewriter) evaluateRBAC(f schema.Field) schema.RuleResult { + if authRw == nil || authRw.isWritingAuth { + return schema.Uncertain + } + + typ := f.Type() + rn := authRw.selector(typ) + return rn.EvaluateRBACRules(authRw.authVariables) +} + func (authRw *authRewriter) rewriteRuleNode( typ schema.Type, rn *schema.RuleNode) ([]*gql.GraphQuery, *gql.FilterTree) { @@ -425,7 +460,9 @@ func (authRw *authRewriter) rewriteRuleNode( for _, orRn := range rns { q, f := authRw.rewriteRuleNode(typ, orRn) qrys = append(qrys, q...) - filts = append(filts, f) + if f != nil { + filts = append(filts, f) + } } return qrys, filts } @@ -433,18 +470,33 @@ func (authRw *authRewriter) rewriteRuleNode( switch { case len(rn.And) > 0: qrys, filts := nodeList(typ, rn.And) + if len(filts) == 0 { + return qrys, nil + } + if len(filts) == 1 { + return qrys, filts[0] + } return qrys, &gql.FilterTree{ Op: "and", Child: filts, } case len(rn.Or) > 0: qrys, filts := nodeList(typ, rn.Or) + if len(filts) == 0 { + return qrys, nil + } + if len(filts) == 1 { + return qrys, filts[0] + } return qrys, &gql.FilterTree{ Op: "or", Child: filts, } case rn.Not != nil: qrys, filter := authRw.rewriteRuleNode(typ, rn.Not) + if filter == nil { + return qrys, nil + } return qrys, &gql.FilterTree{ Op: "not", Child: []*gql.FilterTree{filter}, @@ -562,15 +614,23 @@ func addSelectionSetFrom( addFilter(child, f.Type(), filter) addOrder(child, f) addPagination(child, f) + rbac := auth.evaluateRBAC(f) selectionAuth := addSelectionSetFrom(child, f, auth) addedFields[f.Name()] = true - q.Children = append(q.Children, child) + + if rbac == schema.Positive || rbac == schema.Uncertain { + q.Children = append(q.Children, child) + } + + if rbac != schema.Uncertain { + continue + } fieldAuth, authFilter := auth.rewriteAuthQueries(f.Type()) authQueries = append(authQueries, selectionAuth...) authQueries = append(authQueries, fieldAuth...) - if len(fieldAuth) > 0 { + if authFilter != nil { if child.Filter == nil { child.Filter = authFilter } else { diff --git a/graphql/schema/auth.go b/graphql/schema/auth.go index e63847ed09f..cef660c1117 100644 --- a/graphql/schema/auth.go +++ b/graphql/schema/auth.go @@ -51,6 +51,74 @@ type AuthContainer struct { Delete *RuleNode } +type RuleResult int + +const ( + Uncertain RuleResult = iota + Positive + Negative +) + +func (rq *RBACQuery) EvaluateRBACRule(av map[string]interface{}) RuleResult { + if rq.Operator == "eq" { + if av[rq.Variable] == rq.Operand { + return Positive + } + } + return Negative +} + +func (node *RuleNode) EvaluateRBACRules(av map[string]interface{}) RuleResult { + if node == nil { + return Uncertain + } + + hasUncertain := false + for _, rule := range node.Or { + val := rule.EvaluateRBACRules(av) + if val == Positive { + return Positive + } else if val == Uncertain { + hasUncertain = true + } + } + + if len(node.Or) > 0 && !hasUncertain { + return Negative + } + + for _, rule := range node.And { + val := rule.EvaluateRBACRules(av) + if val == Negative { + return Negative + } else if val == Uncertain { + hasUncertain = true + } + } + + if len(node.And) > 0 && !hasUncertain { + return Positive + } + + if node.Not != nil && node.Not.RBACRule != nil { + switch node.Not.EvaluateRBACRules(av) { + case Uncertain: + return Uncertain + case Positive: + return Negative + case Negative: + return Positive + } + + } + + if node.RBACRule != nil { + return node.RBACRule.EvaluateRBACRule(av) + } + + return Uncertain +} + type TypeAuth struct { Rules *AuthContainer Fields map[string]*AuthContainer