Skip to content
This repository has been archived by the owner on Aug 16, 2022. It is now read-only.

Commit

Permalink
refactor: remove filter args from repos to prevent implementation err…
Browse files Browse the repository at this point in the history
…ors in the use case layer (#122)
  • Loading branch information
rot1024 authored Mar 15, 2022
1 parent 83a66a4 commit 82cf28c
Show file tree
Hide file tree
Showing 90 changed files with 2,199 additions and 1,639 deletions.
5 changes: 2 additions & 3 deletions internal/adapter/gql/resolver_mutation_tag.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ func (r *mutationResolver) CreateTagGroup(ctx context.Context, input gqlmodel.Cr

func (r *mutationResolver) UpdateTag(ctx context.Context, input gqlmodel.UpdateTagInput) (*gqlmodel.UpdateTagPayload, error) {
tag, err := usecases(ctx).Tag.UpdateTag(ctx, interfaces.UpdateTagParam{
Label: input.Label,
SceneID: id.SceneID(input.SceneID),
TagID: id.TagID(input.TagID),
Label: input.Label,
TagID: id.TagID(input.TagID),
}, getOperator(ctx))
if err != nil {
return nil, err
Expand Down
35 changes: 18 additions & 17 deletions internal/app/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@ func initEcho(ctx context.Context, cfg *ServerConfig) *echo.Echo {
// basic middleware
logger := GetEchoLogger()
e.Logger = logger
e.Use(logger.Hook(), middleware.Recover(), otelecho.Middleware("reearth-backend"))
e.Use(
logger.Hook(),
middleware.Recover(),
otelecho.Middleware("reearth-backend"),
)
origins := allowedOrigins(cfg)
if len(origins) > 0 {
e.Use(
Expand All @@ -41,6 +45,12 @@ func initEcho(ctx context.Context, cfg *ServerConfig) *echo.Echo {
)
}

e.Use(
jwtEchoMiddleware(cfg),
parseJwtMiddleware(),
authMiddleware(cfg),
)

// enable pprof
if e.Debug {
pprofGroup := e.Group("/debug/pprof")
Expand All @@ -65,14 +75,13 @@ func initEcho(ctx context.Context, cfg *ServerConfig) *echo.Echo {
publishedIndexHTML = string(html)
}
}
usecases := interactor.NewContainer(cfg.Repos, cfg.Gateways, interactor.ContainerConfig{

e.Use(UsecaseMiddleware(cfg.Repos, cfg.Gateways, interactor.ContainerConfig{
SignupSecret: cfg.Config.SignupSecret,
PublishedIndexHTML: publishedIndexHTML,
PublishedIndexURL: cfg.Config.Published.IndexURL,
AuthSrvUIDomain: cfg.Config.AuthSrv.UIDomain,
})

e.Use(UsecaseMiddleware(&usecases))
}))

// auth srv
auth := e.Group("")
Expand All @@ -88,15 +97,13 @@ func initEcho(ctx context.Context, cfg *ServerConfig) *echo.Echo {
api.GET("/published/:name", PublishedMetadata())
api.GET("/published_data/:name", PublishedData())

privateApi := api.Group("")
authRequired(privateApi, cfg)
privateApi := api.Group("", AuthRequiredMiddleware())
graphqlAPI(e, privateApi, cfg)
privateAPI(e, privateApi, cfg.Repos)

published := e.Group("/p")
publishedAuth := PublishedAuthMiddleware()
published.GET("/:name/data.json", PublishedData(), publishedAuth)
published.GET("/:name/", PublishedIndex(), publishedAuth)
published := e.Group("/p", PublishedAuthMiddleware())
published.GET("/:name/data.json", PublishedData())
published.GET("/:name/", PublishedIndex())

serveFiles(e, cfg.Gateways.File)
web(e, cfg.Config.Web, cfg.Config.Auth0)
Expand All @@ -121,12 +128,6 @@ func errorHandler(next func(error, echo.Context)) func(error, echo.Context) {
}
}

func authRequired(g *echo.Group, cfg *ServerConfig) {
g.Use(jwtEchoMiddleware(cfg))
g.Use(parseJwtMiddleware())
g.Use(authMiddleware(cfg))
}

func allowedOrigins(cfg *ServerConfig) []string {
if cfg == nil {
return nil
Expand Down
29 changes: 23 additions & 6 deletions internal/app/auth_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,19 @@ func authMiddleware(cfg *ServerConfig) echo.MiddlewareFunc {
}
}

op, err := generateOperator(ctx, cfg, u)
if err != nil {
return err
if sub != "" {
ctx = adapter.AttachSub(ctx, sub)
}

ctx = adapter.AttachSub(ctx, sub)
ctx = adapter.AttachOperator(ctx, op)
ctx = adapter.AttachUser(ctx, u)
if u != nil {
op, err := generateOperator(ctx, cfg, u)
if err != nil {
return err
}

ctx = adapter.AttachUser(ctx, u)
ctx = adapter.AttachOperator(ctx, op)
}

c.SetRequest(req.WithContext(ctx))
return next(c)
Expand Down Expand Up @@ -121,3 +126,15 @@ func addAuth0SubToUser(ctx context.Context, u *user.User, a user.Auth, cfg *Serv
}
return nil
}

func AuthRequiredMiddleware() echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
if adapter.Operator(ctx) == nil {
return echo.ErrUnauthorized
}
return next(c)
}
}
}
10 changes: 5 additions & 5 deletions internal/app/private.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func privateAPI(
if op == nil {
return &echo.HTTPError{Code: http.StatusUnauthorized, Message: ErrOpDenied}
}
scenes := op.AllReadableScenes()
repos := repos.Filtered(repo.TeamFilterFromOperator(op), repo.SceneFilterFromOperator(op))

param := c.Param("param")
params := strings.Split(param, ".")
Expand All @@ -68,7 +68,7 @@ func privateAPI(
return &echo.HTTPError{Code: http.StatusBadRequest, Message: ErrBadID}
}

layer, err := repos.Layer.FindByID(ctx, lid, scenes)
layer, err := repos.Layer.FindByID(ctx, lid)
if err != nil {
if errors.Is(rerror.ErrNotFound, err) {
return &echo.HTTPError{Code: http.StatusNotFound, Message: err}
Expand All @@ -88,11 +88,11 @@ func privateAPI(

ex := &encoding.Exporter{
Merger: &merging.Merger{
LayerLoader: repo.LayerLoaderFrom(repos.Layer, scenes),
PropertyLoader: repo.PropertyLoaderFrom(repos.Property, scenes),
LayerLoader: repo.LayerLoaderFrom(repos.Layer),
PropertyLoader: repo.PropertyLoaderFrom(repos.Property),
},
Sealer: &merging.Sealer{
DatasetGraphLoader: repo.DatasetGraphLoaderFrom(repos.Dataset, scenes),
DatasetGraphLoader: repo.DatasetGraphLoaderFrom(repos.Dataset),
},
Encoder: e,
}
Expand Down
21 changes: 12 additions & 9 deletions internal/app/public_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"testing"

"github.com/labstack/echo/v4"
"github.com/reearth/reearth-backend/internal/adapter"
"github.com/reearth/reearth-backend/internal/usecase/interfaces"
"github.com/reearth/reearth-backend/pkg/rerror"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -70,9 +71,7 @@ func TestPublishedAuthMiddleware(t *testing.T) {
c := e.NewContext(req, res)
c.SetParamNames("name")
c.SetParamValues(tc.PublishedName)
m := UsecaseMiddleware(&interfaces.Container{
Published: &mockPublished{},
})
m := mockPublishedUsecaseMiddleware(false)

err := m(PublishedAuthMiddleware()(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
Expand Down Expand Up @@ -121,9 +120,7 @@ func TestPublishedData(t *testing.T) {
c := e.NewContext(req, res)
c.SetParamNames("name")
c.SetParamValues(tc.PublishedName)
m := UsecaseMiddleware(&interfaces.Container{
Published: &mockPublished{},
})
m := mockPublishedUsecaseMiddleware(false)

err := m(PublishedData())(c)

Expand Down Expand Up @@ -178,9 +175,7 @@ func TestPublishedIndex(t *testing.T) {
c := e.NewContext(req, res)
c.SetParamNames("name")
c.SetParamValues(tc.PublishedName)
m := UsecaseMiddleware(&interfaces.Container{
Published: &mockPublished{EmptyIndex: tc.EmptyIndex},
})
m := mockPublishedUsecaseMiddleware(tc.EmptyIndex)

err := m(PublishedIndex())(c)

Expand All @@ -196,6 +191,14 @@ func TestPublishedIndex(t *testing.T) {
}
}

func mockPublishedUsecaseMiddleware(emptyIndex bool) echo.MiddlewareFunc {
return ContextMiddleware(func(ctx context.Context) context.Context {
return adapter.AttachUsecases(ctx, &interfaces.Container{
Published: &mockPublished{EmptyIndex: emptyIndex},
})
})
}

type mockPublished struct {
interfaces.Published
EmptyIndex bool
Expand Down
34 changes: 27 additions & 7 deletions internal/app/usecase.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,40 @@
package app

import (
"context"

"github.com/labstack/echo/v4"
"github.com/reearth/reearth-backend/internal/adapter"
"github.com/reearth/reearth-backend/internal/usecase/interfaces"
"github.com/reearth/reearth-backend/internal/usecase/gateway"
"github.com/reearth/reearth-backend/internal/usecase/interactor"
"github.com/reearth/reearth-backend/internal/usecase/repo"
)

func UsecaseMiddleware(uc *interfaces.Container) echo.MiddlewareFunc {
func UsecaseMiddleware(r *repo.Container, g *gateway.Container, config interactor.ContainerConfig) echo.MiddlewareFunc {
return ContextMiddleware(func(ctx context.Context) context.Context {
var r2 *repo.Container
if op := adapter.Operator(ctx); op != nil && r != nil {
// apply filters to repos
r3 := r.Filtered(
repo.TeamFilterFromOperator(op),
repo.SceneFilterFromOperator(op),
)
r2 = &r3
} else {
r2 = r
}

uc := interactor.NewContainer(r2, g, config)
ctx = adapter.AttachUsecases(ctx, &uc)
return ctx
})
}

func ContextMiddleware(fn func(ctx context.Context) context.Context) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
req := c.Request()
ctx := req.Context()

ctx = adapter.AttachUsecases(ctx, uc)

c.SetRequest(req.WithContext(ctx))
c.SetRequest(req.WithContext(fn(req.Context())))
return next(c)
}
}
Expand Down
19 changes: 15 additions & 4 deletions internal/infrastructure/adapter/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,20 @@ func NewPlugin(readers []repo.Plugin, writer repo.Plugin) repo.Plugin {
}
}

func (r *pluginRepo) FindByID(ctx context.Context, id id.PluginID, sids []id.SceneID) (*plugin.Plugin, error) {
func (r *pluginRepo) Filtered(f repo.SceneFilter) repo.Plugin {
readers := make([]repo.Plugin, 0, len(r.readers))
for _, r := range r.readers {
readers = append(readers, r.Filtered(f))
}
return &pluginRepo{
readers: readers,
writer: r.writer.Filtered(f),
}
}

func (r *pluginRepo) FindByID(ctx context.Context, id id.PluginID) (*plugin.Plugin, error) {
for _, re := range r.readers {
if res, err := re.FindByID(ctx, id, sids); err != nil {
if res, err := re.FindByID(ctx, id); err != nil {
if errors.Is(err, rerror.ErrNotFound) {
continue
} else {
Expand All @@ -39,10 +50,10 @@ func (r *pluginRepo) FindByID(ctx context.Context, id id.PluginID, sids []id.Sce
return nil, rerror.ErrNotFound
}

func (r *pluginRepo) FindByIDs(ctx context.Context, ids []id.PluginID, sids []id.SceneID) ([]*plugin.Plugin, error) {
func (r *pluginRepo) FindByIDs(ctx context.Context, ids []id.PluginID) ([]*plugin.Plugin, error) {
results := make([]*plugin.Plugin, 0, len(ids))
for _, id := range ids {
res, err := r.FindByID(ctx, id, sids)
res, err := r.FindByID(ctx, id)
if err != nil && err != rerror.ErrNotFound {
return nil, err
}
Expand Down
11 changes: 11 additions & 0 deletions internal/infrastructure/adapter/property_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ func NewPropertySchema(readers []repo.PropertySchema, writer repo.PropertySchema
}
}

func (r *propertySchema) Filtered(f repo.SceneFilter) repo.PropertySchema {
readers := make([]repo.PropertySchema, 0, len(r.readers))
for _, r := range r.readers {
readers = append(readers, r.Filtered(f))
}
return &propertySchema{
readers: readers,
writer: r.writer.Filtered(f),
}
}

func (r *propertySchema) FindByID(ctx context.Context, id id.PropertySchemaID) (*property.Schema, error) {
for _, re := range r.readers {
if res, err := re.FindByID(ctx, id); err != nil {
Expand Down
17 changes: 12 additions & 5 deletions internal/infrastructure/fs/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (

type pluginRepo struct {
fs afero.Fs
f repo.SceneFilter
}

func NewPlugin(fs afero.Fs) repo.Plugin {
Expand All @@ -24,24 +25,30 @@ func NewPlugin(fs afero.Fs) repo.Plugin {
}
}

func (r *pluginRepo) FindByID(ctx context.Context, pid id.PluginID, sids []id.SceneID) (*plugin.Plugin, error) {
func (r *pluginRepo) Filtered(f repo.SceneFilter) repo.Plugin {
return &pluginRepo{
fs: r.fs,
f: f.Clone(),
}
}

func (r *pluginRepo) FindByID(ctx context.Context, pid id.PluginID) (*plugin.Plugin, error) {
m, err := readPluginManifest(r.fs, pid)
if err != nil {
return nil, err
}

sid := m.Plugin.ID().Scene()
if sid != nil && !sid.Contains(sids) {
if s := m.Plugin.ID().Scene(); s != nil && !r.f.CanRead(*s) {
return nil, nil
}

return m.Plugin, nil
}

func (r *pluginRepo) FindByIDs(ctx context.Context, ids []id.PluginID, sids []id.SceneID) ([]*plugin.Plugin, error) {
func (r *pluginRepo) FindByIDs(ctx context.Context, ids []id.PluginID) ([]*plugin.Plugin, error) {
results := make([]*plugin.Plugin, 0, len(ids))
for _, id := range ids {
res, err := r.FindByID(ctx, id, sids)
res, err := r.FindByID(ctx, id)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/infrastructure/fs/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
func TestPlugin(t *testing.T) {
ctx := context.Background()
fs := NewPlugin(mockPluginFS())
p, err := fs.FindByID(ctx, plugin.MustID("testplugin~1.0.0"), nil)
p, err := fs.FindByID(ctx, plugin.MustID("testplugin~1.0.0"))
assert.NoError(t, err)
assert.Equal(t, plugin.New().ID(plugin.MustID("testplugin~1.0.0")).Name(i18n.String{
"en": "testplugin",
Expand Down
Loading

0 comments on commit 82cf28c

Please sign in to comment.