Skip to content

Commit

Permalink
feat: add dynamic cors support (#6174) (#6270)
Browse files Browse the repository at this point in the history
  • Loading branch information
poonai authored Aug 25, 2020
1 parent 6acaef2 commit 3c19386
Show file tree
Hide file tree
Showing 28 changed files with 610 additions and 57 deletions.
58 changes: 56 additions & 2 deletions dgraph/cmd/alpha/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ import (
"syscall"
"time"

"github.com/dgraph-io/dgraph/graphql/web"

badgerpb "github.com/dgraph-io/badger/v2/pb"
"github.com/dgraph-io/badger/v2/y"
"github.com/dgraph-io/dgo/v200/protos/api"
"github.com/dgraph-io/dgraph/edgraph"
"github.com/dgraph-io/dgraph/ee/enc"
"github.com/dgraph-io/dgraph/graphql/admin"
"github.com/dgraph-io/dgraph/graphql/web"
"github.com/dgraph-io/dgraph/posting"
"github.com/dgraph-io/dgraph/protos/pb"
"github.com/dgraph-io/dgraph/schema"
"github.com/dgraph-io/dgraph/tok"
"github.com/dgraph-io/dgraph/worker"
Expand Down Expand Up @@ -710,7 +711,21 @@ func run() {
// and health check passes
edgraph.ResetAcl()
edgraph.RefreshAcls(aclCloser)
edgraph.ResetCors()
// Update the accepted cors origins.
for {
origins, err := edgraph.GetCorsOrigins(context.TODO())
if err != nil {
glog.Errorf("Error while retriving cors origins: %s", err.Error())
continue
}
x.UpdateCorsOrigins(origins)
break
}
}()
// Listen for any new cors origin update.
corsCloser := y.NewCloser(1)
go listenForCorsUpdate(corsCloser)

// Graphql subscribes to alpha to get schema updates. We need to close that before we
// close alpha. This closer is for closing and waiting that subscription.
Expand All @@ -719,10 +734,49 @@ func run() {
setupServer(adminCloser)
glog.Infoln("GRPC and HTTP stopped.")
aclCloser.SignalAndWait()
corsCloser.SignalAndWait()
worker.BlockingStop()
adminCloser.SignalAndWait()
glog.Info("Disposing server state.")
worker.State.Dispose()
x.RemoveCidFile()
glog.Infoln("Server shutdown. Bye!")
}

// listenForCorsUpdate listen for any cors change and update the accepeted cors.
func listenForCorsUpdate(closer *y.Closer) {
prefix := x.DataKey("dgraph.cors", 0)
// Remove uid from the key, to get the correct prefix
prefix = prefix[:len(prefix)-8]
worker.SubscribeForUpdates([][]byte{prefix}, func(kvs *badgerpb.KVList) {
// Last update contains the latest value. So, taking the last update.
lastIdx := len(kvs.GetKv()) - 1
kv := kvs.GetKv()[lastIdx]
glog.Infof("Updating cors from subscription.")
// Unmarshal the incoming posting list.
pl := &pb.PostingList{}
err := pl.Unmarshal(kv.GetValue())
if err != nil {
glog.Errorf("Unable to unmarshal the posting list for cors update %s", err)
return
}
// Skip if there is no posting. Our all upsert call contains atleast one
// posting.
if len(pl.Postings) == 0 {
return
}
origins := make([]string, 0)
for _, posting := range pl.Postings {
val := strings.TrimSpace(string(posting.Value))
if val == "_STAR_ALL" {
// If the posting list contains __STAR_ALL then it's a delete call.
// we usually do it before updating as part of upsert. So, let's
// ignore this update.
continue
}
origins = append(origins, val)
}
glog.Infof("Updating cors origins: %+v", origins)
x.UpdateCorsOrigins(origins)
}, 1, closer)
}
6 changes: 3 additions & 3 deletions dgraph/cmd/alpha/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ func TestDeletePredicate(t *testing.T) {
testutil.CompareJSON(t, `{"data":{"schema":[`+
`{"predicate":"age","type":"default"},`+
`{"predicate":"name","type":"string","index":true, "tokenizer":["term"]},`+
x.AclPredicates+","+x.GraphqlPredicates+","+
x.AclPredicates+","+x.GraphqlPredicates+","+x.CorsPredicate+","+
`{"predicate":"dgraph.type","type":"string","index":true, "tokenizer":["exact"],
"list":true}],`+x.InitialTypes+`}}`, output)

Expand Down Expand Up @@ -1077,7 +1077,7 @@ func TestListTypeSchemaChange(t *testing.T) {
res, err = runGraphqlQuery(q)
require.NoError(t, err)
testutil.CompareJSON(t, `{"data":{"schema":[`+
x.AclPredicates+","+x.GraphqlPredicates+","+
x.AclPredicates+","+x.GraphqlPredicates+","+x.CorsPredicate+","+
`{"predicate":"occupations","type":"string"},`+
`{"predicate":"dgraph.type", "type":"string", "index":true, "tokenizer": ["exact"],
"list":true}],`+x.InitialTypes+`}}`, res)
Expand Down Expand Up @@ -1325,7 +1325,7 @@ func TestDropAll(t *testing.T) {
require.NoError(t, err)
testutil.CompareJSON(t,
`{"data":{"schema":[`+
x.AclPredicates+","+x.GraphqlPredicates+","+
x.AclPredicates+","+x.GraphqlPredicates+","+x.CorsPredicate+","+
`{"predicate":"dgraph.type", "type":"string", "index":true, "tokenizer":["exact"],
"list":true}],`+x.InitialTypes+`}}`, output)

Expand Down
1 change: 1 addition & 0 deletions dgraph/cmd/bulk/systest/test-bulk-schema.sh
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ EOF
dgraph debug -p out/1/p 2>|/dev/null | grep '{s}' | cut -d' ' -f4 >> all_dbs.out
diff <(LC_ALL=C sort all_dbs.out | uniq -c) - <<EOF
1 dgraph.acl.rule
1 dgraph.cors
1 dgraph.graphql.schema
1 dgraph.graphql.xid
1 dgraph.password
Expand Down
2 changes: 1 addition & 1 deletion dgraph/cmd/live/load-uids/load_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ func TestLiveLoadExportedSchema(t *testing.T) {
require.Nilf(t, resp.Errors, resp.Errors.Error())

// wait a bit to be sure export is complete
time.Sleep(time.Second)
time.Sleep(8 * time.Second)

// copy the export files from docker
exportId, groupId := copyExportToLocalFs(t)
Expand Down
125 changes: 125 additions & 0 deletions edgraph/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ const (
// NoAuthorize is used to indicate that authorization needs to be skipped.
// Used when ACL needs to query information for performing the authorization check.
NoAuthorize
// CorsMutationAllowed is used to indicate that the given request is authorized to do
// cors mutation.
CorsMutationAllowed
)

var (
Expand Down Expand Up @@ -326,6 +329,7 @@ func (s *Server) Alter(ctx context.Context, op *api.Operation) (*api.Payload, er
_, err = UpdateGQLSchema(ctx, "", "")
// recreate the admin account after a drop all operation
ResetAcl()
ResetCors()
return empty, err
}

Expand All @@ -350,6 +354,7 @@ func (s *Server) Alter(ctx context.Context, op *api.Operation) (*api.Payload, er
_, err = UpdateGQLSchema(ctx, graphQLSchema, "")
// recreate the admin account after a drop data operation
ResetAcl()
ResetCors()
return empty, err
}

Expand Down Expand Up @@ -973,6 +978,12 @@ func (s *Server) doQuery(ctx context.Context, req *api.Request, doAuth AuthMode)
return
}
}

if doAuth != CorsMutationAllowed {
if rerr = validateCorsInMutation(ctx, qc); rerr != nil {
return
}
}
// We use defer here because for queries, startTs will be
// assigned in the processQuery function called below.
defer annotateStartTs(qc.span, qc.req.StartTs)
Expand Down Expand Up @@ -1208,6 +1219,29 @@ func authorizeRequest(ctx context.Context, qc *queryContext) error {
return nil
}

// validateCorsInMutation check whether mutation contains cors predication. If it's contain cors
// predicate, we'll throw an error.
func validateCorsInMutation(ctx context.Context, qc *queryContext) error {
validateNquad := func(nquads []*api.NQuad) error {
for _, nquad := range nquads {
if nquad.Predicate != "dgraph.cors" {
continue
}
return errors.New("Mutations are not allowed for the predicate dgraph.cors")
}
return nil
}
for _, gmu := range qc.gmuList {
if err := validateNquad(gmu.Set); err != nil {
return err
}
if err := validateNquad(gmu.Del); err != nil {
return err
}
}
return nil
}

// CommitOrAbort commits or aborts a transaction.
func (s *Server) CommitOrAbort(ctx context.Context, tc *api.TxnContext) (*api.TxnContext, error) {
ctx, span := otrace.StartSpan(ctx, "Server.CommitOrAbort")
Expand Down Expand Up @@ -1502,3 +1536,94 @@ func isDropAll(op *api.Operation) bool {
}
return false
}

// ResetCors make the dgraph to accept all the origins if no origins were given
// by the users.
func ResetCors() {
req := &api.Request{
Query: `query{
cors as var(func: has(dgraph.cors))
}`,
Mutations: []*api.Mutation{
{
Set: []*api.NQuad{
{
Subject: "_:a",
Predicate: "dgraph.cors",
ObjectValue: &api.Value{Val: &api.Value_StrVal{StrVal: "*"}},
},
},
Cond: `@if(eq(len(cors), 0))`,
},
},
CommitNow: true,
}

for {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
if _, err := (&Server{}).doQuery(ctx, req, CorsMutationAllowed); err != nil {
glog.Infof("Unable to upsert cors. Error: %v", err)
time.Sleep(100 * time.Millisecond)
}
break
}
}

func generateNquadsForCors(origins []string) []byte {
out := &bytes.Buffer{}
for _, origin := range origins {
out.Write([]byte(fmt.Sprintf("uid(cors) <dgraph.cors> \"%s\" . \n", origin)))
}
return out.Bytes()
}

// AddCorsOrigins Adds the cors origins to the Dgraph.
func AddCorsOrigins(ctx context.Context, origins []string) error {
req := &api.Request{
Query: `query{
cors as var(func: has(dgraph.cors))
}`,
Mutations: []*api.Mutation{
{
SetNquads: generateNquadsForCors(origins),
Cond: `@if(eq(len(cors), 1))`,
DelNquads: []byte(`uid(cors) <dgraph.cors> * .`),
},
},
CommitNow: true,
}
_, err := (&Server{}).doQuery(ctx, req, CorsMutationAllowed)
return err
}

// GetCorsOrigins retrive all the cors origin from the database.
func GetCorsOrigins(ctx context.Context) ([]string, error) {
req := &api.Request{
Query: `query{
me(func: has(dgraph.cors)){
dgraph.cors
}
}`,
ReadOnly: true,
}
res, err := (&Server{}).doQuery(ctx, req, NoAuthorize)
if err != nil {
return nil, err
}

type corsResponse struct {
Me []struct {
DgraphCors []string `json:"dgraph.cors"`
} `json:"me"`
}
corsRes := &corsResponse{}
if err = json.Unmarshal(res.Json, corsRes); err != nil {
return nil, err
}
if len(corsRes.Me) > 1 {
glog.Errorf("Something went wrong in cors predicate, expected 1 predicate but got %d",
len(corsRes.Me))
}
return corsRes.Me[0].DgraphCors, nil
}
12 changes: 11 additions & 1 deletion ee/acl/acl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1936,7 +1936,17 @@ func TestSchemaQueryWithACL(t *testing.T) {
"predicate": "dgraph.acl.rule",
"type": "uid",
"list": true
},
},
{
"predicate": "dgraph.cors",
"type": "string",
"list": true,
"index": true,
"tokenizer": [
"exact"
],
"upsert": true
},
{
"predicate": "dgraph.graphql.schema",
"type": "string"
Expand Down
Loading

0 comments on commit 3c19386

Please sign in to comment.