From 45e3ac8a24c91b3bde356c95d07c024666010459 Mon Sep 17 00:00:00 2001 From: Hackerwins Date: Tue, 3 May 2022 19:57:30 +0900 Subject: [PATCH] Clean up codes - Extract Paging - Move project logic to projects - Remove unused codes --- .golangci.yml | 11 ++-- api/types/paging.go | 25 ++++++++ api/types/project.go | 19 ++++++ api/types/project_test.go | 51 ++++++++++++++++ server/admin/server.go | 8 ++- server/backend/db/client_info.go | 3 + server/backend/db/db.go | 8 +-- server/backend/db/doc_info.go | 28 +++++++-- server/backend/db/memory/db.go | 20 +++--- server/backend/db/memory/db_test.go | 24 ++++++-- server/backend/db/mongo/client.go | 24 ++++---- server/backend/db/project_info.go | 19 ------ server/backend/db/project_info_test.go | 25 -------- server/documents/documents.go | 6 +- server/projects/context.go | 36 +++++++++++ server/projects/projects.go | 31 ++++++---- server/rpc/auth/auth.go | 29 +++++---- server/rpc/auth/webhook.go | 13 ++-- .../interceptors/{metadata.go => context.go} | 61 ++++++++++++------- server/rpc/interceptors/default.go | 4 +- server/rpc/{auth => metadata}/context.go | 23 +++---- server/rpc/server.go | 11 +--- server/server.go | 12 ---- 23 files changed, 302 insertions(+), 189 deletions(-) create mode 100644 api/types/paging.go create mode 100644 api/types/project_test.go create mode 100644 server/projects/context.go rename server/rpc/interceptors/{metadata.go => context.go} (53%) rename server/rpc/{auth => metadata}/context.go (61%) diff --git a/.golangci.yml b/.golangci.yml index 6287b95d3..404327330 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,7 +1,3 @@ -linters-settings: - goimports: - local-prefixes: github.com/yorkie-team/yorkie - linters: enable: - bodyclose @@ -16,5 +12,12 @@ linters: - misspell - nakedret +linters-settings: + goimports: + local-prefixes: github.com/yorkie-team/yorkie + gosec: + excludes: + - G107 # Potential HTTP request made with variable url + issues: exclude-use-default: false diff --git a/api/types/paging.go b/api/types/paging.go new file mode 100644 index 000000000..9dd2711a0 --- /dev/null +++ b/api/types/paging.go @@ -0,0 +1,25 @@ +/* + * Copyright 2022 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package types + +// Paging is the paging information for the document. +type Paging struct { + PreviousID ID + PageSize int + IsForward bool +} diff --git a/api/types/project.go b/api/types/project.go index 76d7ddcc4..3ead43b0a 100644 --- a/api/types/project.go +++ b/api/types/project.go @@ -43,3 +43,22 @@ type Project struct { // CreatedAt is the time when the project was created. CreatedAt time.Time `json:"created_at"` } + +// RequireAuth returns whether the given method requires authorization. +func (p *Project) RequireAuth(method Method) bool { + if len(p.AuthWebhookURL) == 0 { + return false + } + + if len(p.AuthWebhookMethods) == 0 { + return true + } + + for _, m := range p.AuthWebhookMethods { + if Method(m) == method { + return true + } + } + + return false +} diff --git a/api/types/project_test.go b/api/types/project_test.go new file mode 100644 index 000000000..86841ea8e --- /dev/null +++ b/api/types/project_test.go @@ -0,0 +1,51 @@ +/* + * Copyright 2022 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package types_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/yorkie-team/yorkie/api/types" +) + +func TestProjectInfo(t *testing.T) { + t.Run("require auth test", func(t *testing.T) { + // 1. Specify which methods to allow + info := &types.Project{ + AuthWebhookURL: "ValidWebhookURL", + AuthWebhookMethods: []string{string(types.ActivateClient)}, + } + assert.True(t, info.RequireAuth(types.ActivateClient)) + assert.False(t, info.RequireAuth(types.DetachDocument)) + + // 2. Allow all + info2 := &types.Project{ + AuthWebhookURL: "ValidWebhookURL", + AuthWebhookMethods: []string{}, + } + assert.True(t, info2.RequireAuth(types.ActivateClient)) + assert.True(t, info2.RequireAuth(types.DetachDocument)) + + // 3. Empty webhook URL + info3 := &types.Project{ + AuthWebhookURL: "", + } + assert.False(t, info3.RequireAuth(types.ActivateClient)) + }) +} diff --git a/server/admin/server.go b/server/admin/server.go index cf886b279..18008033f 100644 --- a/server/admin/server.go +++ b/server/admin/server.go @@ -183,9 +183,11 @@ func (s *Server) ListDocuments( docs, err := documents.ListDocumentSummaries( ctx, s.backend, - types.ID(req.PreviousId), - int(req.PageSize), - req.IsForward, + types.Paging{ + PreviousID: types.ID(req.PreviousId), + PageSize: int(req.PageSize), + IsForward: req.IsForward, + }, ) if err != nil { return nil, err diff --git a/server/backend/db/client_info.go b/server/backend/db/client_info.go index f09203623..c8be99183 100644 --- a/server/backend/db/client_info.go +++ b/server/backend/db/client_info.go @@ -56,6 +56,9 @@ type ClientInfo struct { // ID is the unique ID of the client. ID types.ID `bson:"_id"` + // ProjectID is the ID of the project the client belongs to. + ProjectID types.ID `bson:"project_id"` + // Key is the key of the client. It is used to identify the client by users. Key string `bson:"key"` diff --git a/server/backend/db/db.go b/server/backend/db/db.go index 17e082152..6d0198203 100644 --- a/server/backend/db/db.go +++ b/server/backend/db/db.go @@ -139,11 +139,9 @@ type DB interface { serverSeq uint64, ) error - // FindDocInfosByPreviousIDAndPageSize returns the documentInfos of the given previousID and pageSize. - FindDocInfosByPreviousIDAndPageSize( + // FindDocInfosByPaging returns the documentInfos of the given paging. + FindDocInfosByPaging( ctx context.Context, - previousID types.ID, - pageSize int, - isForward bool, + paging types.Paging, ) ([]*DocInfo, error) } diff --git a/server/backend/db/doc_info.go b/server/backend/db/doc_info.go index d55e11a7d..c63b147ad 100644 --- a/server/backend/db/doc_info.go +++ b/server/backend/db/doc_info.go @@ -25,13 +25,29 @@ import ( // DocInfo is a structure representing information of the document. type DocInfo struct { - ID types.ID `bson:"_id"` - Key key.Key `bson:"key"` - ServerSeq uint64 `bson:"server_seq"` - Owner types.ID `bson:"owner"` - CreatedAt time.Time `bson:"created_at"` + // ID is the unique ID of the document. + ID types.ID `bson:"_id"` + + // ProjectID is the ID of the project that the document belongs to. + ProjectID types.ID `bson:"project_id"` + + // Key is the key of the document. + Key key.Key `bson:"key"` + + // ServerSeq is the sequence number of the last change of the document on the server. + ServerSeq uint64 `bson:"server_seq"` + + // Owner is the owner(ID of the client) of the document. + Owner types.ID `bson:"owner"` + + // CreatedAt is the time when the document is created. + CreatedAt time.Time `bson:"created_at"` + + // AccessedAt is the time when the document is accessed. AccessedAt time.Time `bson:"accessed_at"` - UpdatedAt time.Time `bson:"updated_at"` + + // UpdatedAt is the time when the document is updated. + UpdatedAt time.Time `bson:"updated_at"` } // IncreaseServerSeq increases server sequence of the document. diff --git a/server/backend/db/memory/db.go b/server/backend/db/memory/db.go index f4f44c148..8bba39e7c 100644 --- a/server/backend/db/memory/db.go +++ b/server/backend/db/memory/db.go @@ -655,26 +655,24 @@ func (d *DB) UpdateSyncedSeq( return nil } -// FindDocInfosByPreviousIDAndPageSize returns the docInfos of the given previousID and pageSize. -func (d *DB) FindDocInfosByPreviousIDAndPageSize( +// FindDocInfosByPaging returns the documentInfos of the given paging. +func (d *DB) FindDocInfosByPaging( ctx context.Context, - previousID types.ID, - pageSize int, - isForward bool, + paging types.Paging, ) ([]*db.DocInfo, error) { txn := d.db.Txn(false) defer txn.Abort() var iterator memdb.ResultIterator var err error - if isForward { + if paging.IsForward { iterator, err = txn.LowerBound( tblDocuments, "id", - previousID.String(), + paging.PreviousID.String(), ) } else { - if previousID == "" { + if paging.PreviousID == "" { iterator, err = txn.GetReverse( tblDocuments, "id", @@ -683,7 +681,7 @@ func (d *DB) FindDocInfosByPreviousIDAndPageSize( iterator, err = txn.ReverseLowerBound( tblDocuments, "id", - previousID.String(), + paging.PreviousID.String(), ) } } @@ -695,11 +693,11 @@ func (d *DB) FindDocInfosByPreviousIDAndPageSize( var docInfos []*db.DocInfo for raw := iterator.Next(); raw != nil; raw = iterator.Next() { info := raw.(*db.DocInfo) - if len(docInfos) >= pageSize { + if len(docInfos) >= paging.PageSize { break } - if info.ID != previousID { + if info.ID != paging.PreviousID { docInfos = append(docInfos, info) } } diff --git a/server/backend/db/memory/db_test.go b/server/backend/db/memory/db_test.go index 3c47b6337..e05522e80 100644 --- a/server/backend/db/memory/db_test.go +++ b/server/backend/db/memory/db_test.go @@ -215,27 +215,41 @@ func TestDB(t *testing.T) { } // initial page, previousID is empty - infos, err := localDB.FindDocInfosByPreviousIDAndPageSize(ctx, "", pageSize, false) + infos, err := localDB.FindDocInfosByPaging(ctx, types.Paging{PageSize: pageSize}) assert.NoError(t, err) assertKeys([]key.Key{"8", "7", "6", "5", "4"}, infos) // backward - infos, err = localDB.FindDocInfosByPreviousIDAndPageSize(ctx, infos[len(infos)-1].ID, pageSize, false) + infos, err = localDB.FindDocInfosByPaging(ctx, types.Paging{ + PreviousID: infos[len(infos)-1].ID, + PageSize: pageSize, + }) assert.NoError(t, err) assertKeys([]key.Key{"3", "2", "1", "0"}, infos) // backward again - emptyInfos, err := localDB.FindDocInfosByPreviousIDAndPageSize(ctx, infos[len(infos)-1].ID, pageSize, false) + emptyInfos, err := localDB.FindDocInfosByPaging(ctx, types.Paging{ + PreviousID: infos[len(infos)-1].ID, + PageSize: pageSize, + }) assert.NoError(t, err) assertKeys(nil, emptyInfos) // forward - infos, err = localDB.FindDocInfosByPreviousIDAndPageSize(ctx, infos[0].ID, pageSize, true) + infos, err = localDB.FindDocInfosByPaging(ctx, types.Paging{ + PreviousID: infos[0].ID, + PageSize: pageSize, + IsForward: true, + }) assert.NoError(t, err) assertKeys([]key.Key{"4", "5", "6", "7", "8"}, infos) // forward again - emptyInfos, err = localDB.FindDocInfosByPreviousIDAndPageSize(ctx, infos[len(infos)-1].ID, pageSize, true) + emptyInfos, err = localDB.FindDocInfosByPaging(ctx, types.Paging{ + PreviousID: infos[len(infos)-1].ID, + PageSize: pageSize, + IsForward: true, + }) assert.NoError(t, err) assertKeys(nil, emptyInfos) }) diff --git a/server/backend/db/mongo/client.go b/server/backend/db/mongo/client.go index 40207c490..f2b1bbf00 100644 --- a/server/backend/db/mongo/client.go +++ b/server/backend/db/mongo/client.go @@ -669,31 +669,29 @@ func (c *Client) UpdateAndFindMinSyncedTicket( ), nil } -// FindDocInfosByPreviousIDAndPageSize returns the docInfos of the given previousID and pageSize. -func (c *Client) FindDocInfosByPreviousIDAndPageSize( +// FindDocInfosByPaging returns the docInfos of the given paging. +func (c *Client) FindDocInfosByPaging( ctx context.Context, - previousID types.ID, - pageSize int, - isForward bool, + paging types.Paging, ) ([]*db.DocInfo, error) { filter := bson.M{} - if previousID != "" { - encodedPreviousID, err := encodeID(previousID) + if paging.PreviousID != "" { + encodedPreviousID, err := encodeID(paging.PreviousID) if err != nil { return nil, err } - key := "$lt" - if isForward { - key = "$gt" + k := "$lt" + if paging.IsForward { + k = "$gt" } filter = bson.M{ - "_id": bson.M{key: encodedPreviousID}, + "_id": bson.M{k: encodedPreviousID}, } } - opts := options.Find().SetLimit(int64(pageSize)) - if !isForward { + opts := options.Find().SetLimit(int64(paging.PageSize)) + if !paging.IsForward { opts = opts.SetSort(map[string]int{"_id": -1}) } diff --git a/server/backend/db/project_info.go b/server/backend/db/project_info.go index 33f6e8b89..17482a3a3 100644 --- a/server/backend/db/project_info.go +++ b/server/backend/db/project_info.go @@ -105,25 +105,6 @@ func (i *ProjectInfo) Validate() error { return nil } -// RequireAuth returns whether the given method requires authorization. -func (i *ProjectInfo) RequireAuth(method types.Method) bool { - if len(i.AuthWebhookURL) == 0 { - return false - } - - if len(i.AuthWebhookMethods) == 0 { - return true - } - - for _, m := range i.AuthWebhookMethods { - if types.Method(m) == method { - return true - } - } - - return false -} - // ToProject converts the ProjectInfo to the Project. func (i *ProjectInfo) ToProject() *types.Project { return &types.Project{ diff --git a/server/backend/db/project_info_test.go b/server/backend/db/project_info_test.go index c227c7a9b..02ccd31c1 100644 --- a/server/backend/db/project_info_test.go +++ b/server/backend/db/project_info_test.go @@ -21,35 +21,10 @@ import ( "github.com/stretchr/testify/assert" - "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/server/backend/db" ) func TestProjectInfo(t *testing.T) { - t.Run("require auth test", func(t *testing.T) { - // 1. Specify which methods to allow - info := &db.ProjectInfo{ - AuthWebhookURL: "ValidWebhookURL", - AuthWebhookMethods: []string{string(types.ActivateClient)}, - } - assert.True(t, info.RequireAuth(types.ActivateClient)) - assert.False(t, info.RequireAuth(types.DetachDocument)) - - // 2. Allow all - info2 := &db.ProjectInfo{ - AuthWebhookURL: "ValidWebhookURL", - AuthWebhookMethods: []string{}, - } - assert.True(t, info2.RequireAuth(types.ActivateClient)) - assert.True(t, info2.RequireAuth(types.DetachDocument)) - - // 3. Empty webhook URL - info3 := &db.ProjectInfo{ - AuthWebhookURL: "", - } - assert.False(t, info3.RequireAuth(types.ActivateClient)) - }) - t.Run("validation test", func(t *testing.T) { conf := &db.ProjectInfo{ AuthWebhookMethods: []string{"ActivateClient"}, diff --git a/server/documents/documents.go b/server/documents/documents.go index 6de45fc8c..825862425 100644 --- a/server/documents/documents.go +++ b/server/documents/documents.go @@ -28,11 +28,9 @@ import ( func ListDocumentSummaries( ctx context.Context, be *backend.Backend, - previousID types.ID, - pageSize int, - isForward bool, + paging types.Paging, ) ([]*types.DocumentSummary, error) { - docInfo, err := be.DB.FindDocInfosByPreviousIDAndPageSize(ctx, previousID, pageSize, isForward) + docInfo, err := be.DB.FindDocInfosByPaging(ctx, paging) if err != nil { return nil, err } diff --git a/server/projects/context.go b/server/projects/context.go new file mode 100644 index 000000000..0d267d9c5 --- /dev/null +++ b/server/projects/context.go @@ -0,0 +1,36 @@ +/* + * Copyright 2022 The Yorkie Authors. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package projects + +import ( + "context" + + "github.com/yorkie-team/yorkie/api/types" +) + +// projectKey is the key for the context.Context. +type projectKey struct{} + +// From returns the project from the context. +func From(ctx context.Context) *types.Project { + return ctx.Value(projectKey{}).(*types.Project) +} + +// With creates a new context with the given Project. +func With(ctx context.Context, project *types.Project) context.Context { + return context.WithValue(ctx, projectKey{}, project) +} diff --git a/server/projects/projects.go b/server/projects/projects.go index 6cd85e074..7c00181c6 100644 --- a/server/projects/projects.go +++ b/server/projects/projects.go @@ -24,19 +24,6 @@ import ( "github.com/yorkie-team/yorkie/server/backend/db" ) -// FindProjectByPublicKey finds the project by public key. -func FindProjectByPublicKey( - ctx context.Context, - be *backend.Backend, - publicKey string, -) (*types.Project, error) { - info, err := be.DB.FindProjectInfoByPublicKey(ctx, publicKey) - if err != nil { - return nil, err - } - return info.ToProject(), nil -} - // CreateProject creates a project. func CreateProject( ctx context.Context, @@ -69,6 +56,24 @@ func ListProjects( return projects, nil } +// GetProjectFromAPIKey returns a project from an API key. +func GetProjectFromAPIKey(ctx context.Context, be *backend.Backend, apiKey string) (*types.Project, error) { + if apiKey == "" { + info, err := be.DB.EnsureDefaultProjectInfo(ctx) + if err != nil { + return nil, err + } + return info.ToProject(), nil + } + + info, err := be.DB.FindProjectInfoByPublicKey(ctx, apiKey) + if err != nil { + return nil, err + } + + return info.ToProject(), nil +} + // UpdateProject updates a project. func UpdateProject( ctx context.Context, diff --git a/server/rpc/auth/auth.go b/server/rpc/auth/auth.go index 51d78b50b..fe17bcaf3 100644 --- a/server/rpc/auth/auth.go +++ b/server/rpc/auth/auth.go @@ -22,7 +22,8 @@ import ( "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/pkg/document/change" "github.com/yorkie-team/yorkie/server/backend" - "github.com/yorkie-team/yorkie/server/backend/db" + "github.com/yorkie-team/yorkie/server/projects" + "github.com/yorkie-team/yorkie/server/rpc/metadata" ) // AccessAttributes returns an array of AccessAttribute from the given pack. @@ -42,20 +43,18 @@ func AccessAttributes(pack *change.Pack) []types.AccessAttribute { // VerifyAccess verifies the given access. func VerifyAccess(ctx context.Context, be *backend.Backend, accessInfo *types.AccessInfo) error { - md := MetadataFromCtx(ctx) - - // TODO(hackerwins): Improve the performance of this function. - // Consider using a cache to store the projectInfo. - var projectInfo *db.ProjectInfo - var err error - if md.APIKey == "" { - projectInfo, err = be.DB.EnsureDefaultProjectInfo(ctx) - } else { - projectInfo, err = be.DB.FindProjectInfoByPublicKey(ctx, md.APIKey) - } - if err != nil { - return err + md := metadata.From(ctx) + project := projects.From(ctx) + + if !project.RequireAuth(accessInfo.Method) { + return nil } - return verifyAccess(ctx, be, projectInfo, accessInfo, md) + return verifyAccess( + ctx, + be, + project.AuthWebhookURL, + md.Authorization, + accessInfo, + ) } diff --git a/server/rpc/auth/webhook.go b/server/rpc/auth/webhook.go index d8151f976..2b1aaf742 100644 --- a/server/rpc/auth/webhook.go +++ b/server/rpc/auth/webhook.go @@ -29,7 +29,6 @@ import ( "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/server/backend" - "github.com/yorkie-team/yorkie/server/backend/db" "github.com/yorkie-team/yorkie/server/logging" ) @@ -48,16 +47,12 @@ var ( func verifyAccess( ctx context.Context, be *backend.Backend, - projectInfo *db.ProjectInfo, + authWebhookURL string, + token string, accessInfo *types.AccessInfo, - md Metadata, ) error { - if !projectInfo.RequireAuth(accessInfo.Method) { - return nil - } - reqBody, err := json.Marshal(types.AuthWebhookRequest{ - Token: md.Authorization, + Token: token, Method: accessInfo.Method, Attributes: accessInfo.Attributes, }) @@ -77,7 +72,7 @@ func verifyAccess( var authResp *types.AuthWebhookResponse if err := withExponentialBackoff(ctx, be.Config, func() (int, error) { resp, err := http.Post( - projectInfo.AuthWebhookURL, + authWebhookURL, "application/json", bytes.NewBuffer(reqBody), ) diff --git a/server/rpc/interceptors/metadata.go b/server/rpc/interceptors/context.go similarity index 53% rename from server/rpc/interceptors/metadata.go rename to server/rpc/interceptors/context.go index b72888162..1a8438ac4 100644 --- a/server/rpc/interceptors/metadata.go +++ b/server/rpc/interceptors/context.go @@ -23,27 +23,28 @@ import ( grpcmiddleware "github.com/grpc-ecosystem/go-grpc-middleware" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/metadata" + grpcmetadata "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "github.com/yorkie-team/yorkie/server/backend" - "github.com/yorkie-team/yorkie/server/rpc/auth" + "github.com/yorkie-team/yorkie/server/projects" + "github.com/yorkie-team/yorkie/server/rpc/metadata" ) -// MetadataInterceptor is an interceptor for extracting metadata from gRPC context. -type MetadataInterceptor struct { +// ContextInterceptor is an interceptor for building additional context. +type ContextInterceptor struct { backend *backend.Backend } -// NewMetadataInterceptor creates a new instance of MetadataInterceptor. -func NewMetadataInterceptor(be *backend.Backend) *MetadataInterceptor { - return &MetadataInterceptor{ +// NewContextInterceptor creates a new instance of ContextInterceptor. +func NewContextInterceptor(be *backend.Backend) *ContextInterceptor { + return &ContextInterceptor{ backend: be, } } -// Unary creates a unary server interceptor for authorization. -func (i *MetadataInterceptor) Unary() grpc.UnaryServerInterceptor { +// Unary creates a unary server interceptor for building additional context. +func (i *ContextInterceptor) Unary() grpc.UnaryServerInterceptor { return func( ctx context.Context, req interface{}, @@ -51,19 +52,20 @@ func (i *MetadataInterceptor) Unary() grpc.UnaryServerInterceptor { handler grpc.UnaryHandler, ) (resp interface{}, err error) { if isRPCService(info.FullMethod) { - md, err := i.extractMetadata(ctx) + ctx, err := i.buildContext(ctx) if err != nil { return nil, err } - return handler(auth.CtxWithMetadata(ctx, md), req) + + return handler(ctx, req) } return handler(ctx, req) } } -// Stream creates a stream server interceptor for authorization. -func (i *MetadataInterceptor) Stream() grpc.StreamServerInterceptor { +// Stream creates a stream server interceptor for building additional context. +func (i *ContextInterceptor) Stream() grpc.StreamServerInterceptor { return func( srv interface{}, ss grpc.ServerStream, @@ -71,12 +73,15 @@ func (i *MetadataInterceptor) Stream() grpc.StreamServerInterceptor { handler grpc.StreamHandler, ) error { if isRPCService(info.FullMethod) { - md, err := i.extractMetadata(ss.Context()) + ctx := ss.Context() + + ctx, err := i.buildContext(ctx) if err != nil { return err } + wrapped := grpcmiddleware.WrapServerStream(ss) - wrapped.WrappedContext = auth.CtxWithMetadata(ss.Context(), md) + wrapped.WrappedContext = ctx return handler(srv, wrapped) } @@ -88,18 +93,20 @@ func isRPCService(method string) bool { return strings.HasPrefix(method, "/api.Yorkie/") } -func (i *MetadataInterceptor) extractMetadata(ctx context.Context) (auth.Metadata, error) { - md := auth.Metadata{} - data, ok := metadata.FromIncomingContext(ctx) +// buildContext builds a context data for RPC. It includes the metadata of the +// request and the project information. +func (i *ContextInterceptor) buildContext(ctx context.Context) (context.Context, error) { + // 01. building metadata + md := metadata.Metadata{} + data, ok := grpcmetadata.FromIncomingContext(ctx) if !ok { - return md, status.Errorf(codes.Unauthenticated, "metadata is not provided") + return nil, status.Errorf(codes.Unauthenticated, "metadata is not provided") } apiKey := data["x-api-key"] if len(apiKey) == 0 && !i.backend.Config.UseDefaultProject { - return md, status.Errorf(codes.Unauthenticated, "api key is not provided") + return nil, status.Errorf(codes.Unauthenticated, "api key is not provided") } - if len(apiKey) > 0 { md.APIKey = apiKey[0] } @@ -108,6 +115,16 @@ func (i *MetadataInterceptor) extractMetadata(ctx context.Context) (auth.Metadat if len(authorization) > 0 { md.Authorization = authorization[0] } + ctx = metadata.With(ctx, md) + + // 02. building project + // TODO(hackerwins): Improve the performance of this function. + // Consider using a cache to store the info. + project, err := projects.GetProjectFromAPIKey(ctx, i.backend, md.APIKey) + if err != nil { + return nil, toStatusError(err) + } + ctx = projects.With(ctx, project) - return md, nil + return ctx, nil } diff --git a/server/rpc/interceptors/default.go b/server/rpc/interceptors/default.go index 3821b4096..4d0dd36da 100644 --- a/server/rpc/interceptors/default.go +++ b/server/rpc/interceptors/default.go @@ -52,7 +52,7 @@ func (i *DefaultInterceptor) Unary() grpc.UnaryServerInterceptor { if gotime.Since(start) > 100*gotime.Millisecond { reqLogger.Infof("RPC : %q %s", info.FullMethod, gotime.Since(start)) } - return resp, err + return resp, nil } } @@ -74,6 +74,6 @@ func (i *DefaultInterceptor) Stream() grpc.StreamServerInterceptor { } reqLogger.Infof("RPC : stream %q %s", info.FullMethod, gotime.Since(start)) - return err + return nil } } diff --git a/server/rpc/auth/context.go b/server/rpc/metadata/context.go similarity index 61% rename from server/rpc/auth/context.go rename to server/rpc/metadata/context.go index 7c07e25b5..56c75c408 100644 --- a/server/rpc/auth/context.go +++ b/server/rpc/metadata/context.go @@ -1,5 +1,5 @@ /* - * Copyright 2021 The Yorkie Authors. All rights reserved. + * Copyright 2022 The Yorkie Authors. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,17 +14,14 @@ * limitations under the License. */ -package auth +package metadata import ( "context" ) -// key is the key for the context.Context. -type key int - -// metadataKey Key = 0 -const metadataKey key = 0 +// metadataKey is the key for the context.Context. +type metadataKey struct{} // Metadata represents the metadata of the request. type Metadata struct { @@ -35,12 +32,12 @@ type Metadata struct { Authorization string } -// MetadataFromCtx returns the metadata from the given context. -func MetadataFromCtx(ctx context.Context) Metadata { - return ctx.Value(metadataKey).(Metadata) +// From returns the metadata from the given context. +func From(ctx context.Context) Metadata { + return ctx.Value(metadataKey{}).(Metadata) } -// CtxWithMetadata creates a new context with the given Metadata. -func CtxWithMetadata(ctx context.Context, md Metadata) context.Context { - return context.WithValue(ctx, metadataKey, md) +// With creates a new context with the given Metadata. +func With(ctx context.Context, md Metadata) context.Context { + return context.WithValue(ctx, metadataKey{}, md) } diff --git a/server/rpc/server.go b/server/rpc/server.go index eba1c4281..8dba83ce5 100644 --- a/server/rpc/server.go +++ b/server/rpc/server.go @@ -44,20 +44,20 @@ type Server struct { // NewServer creates a new instance of Server. func NewServer(conf *Config, be *backend.Backend) (*Server, error) { loggingInterceptor := interceptors.NewLoggingInterceptor() - metadataInterceptor := interceptors.NewMetadataInterceptor(be) + contextInterceptor := interceptors.NewContextInterceptor(be) defaultInterceptor := interceptors.NewDefaultInterceptor() opts := []grpc.ServerOption{ grpc.UnaryInterceptor(grpcmiddleware.ChainUnaryServer( loggingInterceptor.Unary(), be.Metrics.ServerMetrics().UnaryServerInterceptor(), - metadataInterceptor.Unary(), + contextInterceptor.Unary(), defaultInterceptor.Unary(), )), grpc.StreamInterceptor(grpcmiddleware.ChainStreamServer( loggingInterceptor.Stream(), be.Metrics.ServerMetrics().StreamServerInterceptor(), - metadataInterceptor.Stream(), + contextInterceptor.Stream(), defaultInterceptor.Stream(), )), } @@ -105,11 +105,6 @@ func (s *Server) Shutdown(graceful bool) { } } -// GRPCServer returns the gRPC server. -func (s *Server) GRPCServer() *grpc.Server { - return s.grpcServer -} - func (s *Server) listenAndServeGRPC() error { lis, err := net.Listen("tcp", fmt.Sprintf(":%d", s.conf.Port)) if err != nil { diff --git a/server/server.go b/server/server.go index fdfb7faae..21934695a 100644 --- a/server/server.go +++ b/server/server.go @@ -17,10 +17,8 @@ package server import ( - "context" gosync "sync" - "github.com/yorkie-team/yorkie/api/types" "github.com/yorkie-team/yorkie/server/admin" "github.com/yorkie-team/yorkie/server/backend" "github.com/yorkie-team/yorkie/server/backend/sync" @@ -149,16 +147,6 @@ func (r *Yorkie) AdminAddr() string { return r.conf.AdminAddr() } -// DefaultProject returns the default project. -func (r *Yorkie) DefaultProject() (*types.Project, error) { - info, err := r.backend.DB.EnsureDefaultProjectInfo(context.Background()) - if err != nil { - return nil, err - } - - return info.ToProject(), nil -} - // Members returns the members of this cluster. func (r *Yorkie) Members() map[string]*sync.ServerInfo { return r.backend.Members()