Skip to content

Commit

Permalink
framework/controller: add initial support for setting cookies in a co…
Browse files Browse the repository at this point in the history
…ntroller
  • Loading branch information
matthewmueller committed Jul 2, 2022
1 parent 3e2a5a9 commit 1a75ede
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 136 deletions.
2 changes: 2 additions & 0 deletions framework/controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
// Embed templates
"context"
_ "embed"
"fmt"

"github.com/livebud/bud/internal/gotemplate"
"github.com/livebud/bud/package/di"
Expand Down Expand Up @@ -44,6 +45,7 @@ func (g *Generator) GenerateFile(ctx context.Context, fsys overlay.F, file *over
if err != nil {
return err
}
fmt.Println(string(code))
file.Data = code
return nil
}
47 changes: 20 additions & 27 deletions framework/controller/controller.gotext
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@ type {{ $.Pascal }}{{$action.Pascal}}Action struct {
{{- if $action.View }}
View view.Server
{{- end }}
{{- if $action.Context }}
{{- range $field := $action.Context.Fields }}
{{- if $field.Hoisted }}
{{$field.Name}} {{$field.Type}}
{{- end }}
{{- with $provider := $action.Provider }}
{{- range $param := $provider.Hoisted }}
{{$param.Key}} {{$param.FullType}}
{{- end }}
{{- end }}
}
Expand All @@ -54,11 +52,11 @@ func ({{$action.Short}} *{{ $.Pascal }}{{$action.Pascal}}Action) Method() string

// ServeHTTP fn
func ({{$action.Short}} *{{ $.Pascal }}{{$action.Pascal}}Action) ServeHTTP(w http.ResponseWriter, r *http.Request) {
{{$action.Short}}.handler(r).ServeHTTP(w, r)
{{$action.Short}}.handler(w, r).ServeHTTP(w, r)
}

// Handler function
func ({{$action.Short}} *{{ $.Pascal }}{{$action.Pascal}}Action) handler(httpRequest *http.Request) http.Handler {
func ({{$action.Short}} *{{ $.Pascal }}{{$action.Pascal}}Action) handler(httpResponse http.ResponseWriter, httpRequest *http.Request) http.Handler {
{{- if $action.Params }}
// Define the input struct
var in {{ $action.Input}}
Expand All @@ -72,35 +70,30 @@ func ({{$action.Short}} *{{ $.Pascal }}{{$action.Pascal}}Action) handler(httpReq
}
}
{{- end }}
{{- if $action.Context }}
{{ $action.Context.Results.List }} := {{ $action.Context.Function }}(
{{- range $field := $action.Context.Fields }}
{{- if $field.Hoisted }}
{{ $action.Short }}.{{ $field.Name }},
{{- else }}
{{ $field.Variable }},
{{- end }}
{{- with $provider := $action.Provider }}
controller, err := {{ $provider.Name }}(
{{- range $param := $provider.Hoisted }}
{{ $action.Short }}.{{ $param.Key }},
{{- end }}
{{- if $provider.Variable "context.Context" }}httpRequest.Context(),{{ end }}
{{- if $provider.Variable "net/http.*Request" }}httpRequest,{{ end }}
{{- if $provider.Variable "net/http.ResponseWriter" }}httpResponse,{{ end }}
)
{{- if $action.Context.Results.Error }}
if {{ $action.Context.Results.Error }} != nil {
{{- end }}
if err != nil {
return &response.Format{
{{- if ne $action.Method "GET" }}
HTML: response.Status(http.StatusSeeOther).RedirectBack(httpRequest.URL.Path),
{{- end }}
JSON: response.Status(500).Set("Content-Type", "application/json").JSON(map[string]string{"error": {{ $action.Context.Results.Error }}.Error()}),
JSON: response.Status(500).Set("Content-Type", "application/json").JSON(map[string]string{"error": err.Error()}),
}
}
{{- end }}
fn := {{$action.Context.Results.Result}}.{{$action.Name}}
{{- else }}
fn := {{$.Name}}.{{$action.Name}}
{{- end }}
handler := controller.{{$action.Name}}
{{- if $action.HandlerFunc }}
return http.HandlerFunc(fn)
return http.HandlerFunc(handler)
{{- else }}
// Call the controller
{{ $action.Results.Set }}fn(
{{ $action.Results.Set }}handler(
{{- range $param := $action.Params }}
{{ $param.Variable }},
{{- end }}
Expand Down Expand Up @@ -152,7 +145,7 @@ func ({{$action.Short}} *{{ $.Pascal }}{{$action.Pascal}}Action) handler(httpReq

{{- template "controller" $.Controller }}

{{- range $context := $.Contexts }}
{{- range $provider := $.Providers }}

{{$context.Code}}
{{ $provider.Function }}
{{- end }}
57 changes: 57 additions & 0 deletions framework/controller/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2265,3 +2265,60 @@ func TestRedirectBack(t *testing.T) {
Location: /10
`))
}

func TestSession(t *testing.T) {
is := is.New(t)
ctx := context.Background()
dir := t.TempDir()
td := testdir.New(dir)
td.Files["logz/logz.go"] = `
package logz
func New() *Logger { return &Logger{} }
type Logger struct {}
func (l *Logger) Info(msg string) {}
`
td.Files["session/session.go"] = `
package session
import "net/http"
import "app.com/logz"
func New(log *logz.Logger, w http.ResponseWriter, r *http.Request) *Session {
return &Session{log, w, r}
}
type Session struct {
log *logz.Logger
w http.ResponseWriter
r *http.Request
}
func (s *Session) Set(key, value string) {
s.log.Info("setting session")
http.SetCookie(s.w, &http.Cookie{Name: key, Value: value })
}
`
td.Files["controller/controller.go"] = `
package controller
import "app.com/session"
type Controller struct {
Session *session.Session
}
func (c *Controller) Create() error {
c.Session.Set("sessionid", "some-key")
return nil
}
`
is.NoErr(td.Write(ctx))
cli := testcli.New(dir)
app, err := cli.Start(ctx, "run")
is.NoErr(err)
defer app.Close()
// Post request
req, err := app.PostRequest("/", nil)
is.NoErr(err)
req.Header.Set("Referer", "/new")
res, err := app.Do(req)
is.NoErr(err)
is.NoErr(res.DiffHeaders(`
HTTP/1.1 302 Found
Location: /
Set-Cookie: sessionid=some-key
`))
}
124 changes: 59 additions & 65 deletions framework/controller/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,35 +30,35 @@ func Load(fsys fs.FS, injector *di.Injector, module *gomod.Module, parser *parse
return nil, fs.ErrNotExist
}
loader := &loader{
fsys: fsys,
contexts: newContextSet(),
imports: imports.New(),
injector: injector,
module: module,
parser: parser,
exist: exist,
fsys: fsys,
providers: newProviderSet(),
imports: imports.New(),
injector: injector,
module: module,
parser: parser,
exist: exist,
}
return loader.Load()
}

// loader struct
type loader struct {
bail.Struct
fsys fs.FS
injector *di.Injector
imports *imports.Set
contexts *contextSet
module *gomod.Module
parser *parser.Parser
exist map[string]bool
fsys fs.FS
injector *di.Injector
imports *imports.Set
providers *providerSet
module *gomod.Module
parser *parser.Parser
exist map[string]bool
}

// load fn
func (l *loader) Load() (state *State, err error) {
defer l.Recover2(&err, "controller: unable to load")
state = new(State)
state.Controller = l.loadController("controller")
state.Contexts = l.contexts.List()
state.Providers = l.providers.List()
state.Imports = l.imports.List()
return state, nil
}
Expand Down Expand Up @@ -179,7 +179,7 @@ func (l *loader) loadAction(controller *Controller, method *parser.Function) *Ac
}
action.RespondJSON = len(action.Results) > 0
action.RespondHTML = l.loadRespondHTML(action.Results)
action.Context = l.loadContext(controller, method)
action.Provider = l.loadProvider(controller, method)
action.Redirect = l.loadActionRedirect(action)
return action
}
Expand Down Expand Up @@ -492,7 +492,7 @@ func (l *loader) loadRespondHTML(results ActionResults) bool {
return false
}

func (l *loader) loadContext(controller *Controller, method *parser.Function) *Context {
func (l *loader) loadProvider(controller *Controller, method *parser.Function) *di.Provider {
recv := method.Receiver()
if recv == nil {
return nil
Expand All @@ -519,9 +519,9 @@ func (l *loader) loadContext(controller *Controller, method *parser.Function) *C
&di.Error{},
},
Params: []di.Dependency{
di.ToType("net/http", "ResponseWriter"),
di.ToType("net/http", "*Request"),
di.ToType("context", "Context"),
di.ToType("net/http", "*Request"),
di.ToType("net/http", "ResponseWriter"),
},
Aliases: di.Aliases{
di.ToType("github.com/livebud/bud/runtime/view", "Renderer"): di.ToType("github.com/livebud/bud/runtime/view", "*Server"),
Expand All @@ -534,32 +534,26 @@ func (l *loader) loadContext(controller *Controller, method *parser.Function) *C
for _, imp := range provider.Imports {
l.imports.AddNamed(imp.Name, imp.Path)
}
// Create the context
context := new(Context)
context.Function = fnName
context.Code = provider.Function()
context.Fields = l.loadContextInputs(provider)
context.Results = l.loadContextResults(provider)
// Add the context to the context set
l.contexts.Add(context)
return context
// Add the context to the provider set
l.providers.Add(provider)
return provider
}

func (l *loader) loadContextInputs(provider *di.Provider) (fields []*ContextField) {
for _, param := range provider.Externals {
fields = append(fields, l.loadContextField(param))
}
return fields
}
// func (l *loader) loadContextInputs(provider *di.Provider) (fields []*ContextField) {
// for _, param := range provider.Externals {
// fields = append(fields, l.loadContextField(param))
// }
// return fields
// }

func (l *loader) loadContextField(param *di.External) *ContextField {
field := new(ContextField)
field.Name = param.Key
field.Variable = param.Variable.Name
field.Hoisted = param.Hoisted
field.Type = param.FullType
return field
}
// func (l *loader) loadContextField(param *di.External) *ContextField {
// field := new(ContextField)
// field.Name = param.Key
// field.Variable = param.Variable.Name
// field.Hoisted = param.Hoisted
// field.Type = param.FullType
// return field
// }

// func (l *loader) loadContextInputName(dataType string) (typeName string) {
// parts := strings.Split(dataType, ".")
Expand All @@ -571,39 +565,39 @@ func (l *loader) loadContextField(param *di.External) *ContextField {
// return strings.TrimLeft(typeName, "[]*")
// }

func (l *loader) loadContextResults(provider *di.Provider) (outputs []*ContextResult) {
for _, result := range provider.Results {
outputs = append(outputs, l.loadContextResult(result))
}
return outputs
}
// func (l *loader) loadContextResults(provider *di.Provider) (outputs []*ContextResult) {
// for _, result := range provider.Results {
// outputs = append(outputs, l.loadContextResult(result))
// }
// return outputs
// }

func (l *loader) loadContextResult(result *di.Variable) *ContextResult {
output := new(ContextResult)
output.Variable = gotext.Camel(result.Name)
return output
}
// func (l *loader) loadContextResult(result *di.Variable) *ContextResult {
// output := new(ContextResult)
// output.Variable = gotext.Camel(result.Name)
// return output
// }

func newContextSet() *contextSet {
return &contextSet{map[string]*Context{}}
func newProviderSet() *providerSet {
return &providerSet{map[string]*di.Provider{}}
}

type contextSet struct {
contextMap map[string]*Context
type providerSet struct {
providerMap map[string]*di.Provider
}

func (c *contextSet) Add(context *Context) {
c.contextMap[context.Function] = context
func (c *providerSet) Add(provider *di.Provider) {
c.providerMap[provider.Name] = provider
}

func (c *contextSet) List() (contexts []*Context) {
for _, context := range c.contextMap {
contexts = append(contexts, context)
func (c *providerSet) List() (providers []*di.Provider) {
for _, provider := range c.providerMap {
providers = append(providers, provider)
}
sort.Slice(contexts, func(i, j int) bool {
return contexts[i].Function < contexts[j].Function
sort.Slice(providers, func(i, j int) bool {
return providers[i].Name < providers[j].Name
})
return contexts
return providers
}

func tagValue(snake string) (out string) {
Expand Down
Loading

0 comments on commit 1a75ede

Please sign in to comment.