Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix possible panic for unhandled template write errors #205

Merged
merged 3 commits into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions command/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/avenga/couper/config"
"github.com/avenga/couper/config/env"
"github.com/avenga/couper/config/runtime"
"github.com/avenga/couper/errors"
"github.com/avenga/couper/server"
"github.com/sirupsen/logrus"
)
Expand Down Expand Up @@ -72,6 +73,7 @@ func (r *Run) Execute(args Args, config *config.Couper, logEntry *logrus.Entry)
if err != nil {
return err
}
errors.SetLogger(logEntry)

serverList, listenCmdShutdown := server.NewServerList(r.context, config.Context, logEntry, config.Settings, &timings, srvConf)
for _, srv := range serverList {
Expand Down
4 changes: 2 additions & 2 deletions config/runtime/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func NewServerConfiguration(
)

for _, srvConf := range conf.Servers {
serverOptions, err := server.NewServerOptions(srvConf)
serverOptions, err := server.NewServerOptions(srvConf, log)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -211,7 +211,7 @@ func NewServerConfiguration(
var errTpl *errors.Template

if endpointConf.ErrorFile != "" {
errTpl, err = errors.NewTemplateFromFile(endpointConf.ErrorFile)
errTpl, err = errors.NewTemplateFromFile(endpointConf.ErrorFile, log)
if err != nil {
return nil, err
}
Expand Down
10 changes: 6 additions & 4 deletions config/runtime/server/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package server
import (
"path"

"github.com/sirupsen/logrus"

"github.com/avenga/couper/config"
"github.com/avenga/couper/errors"
"github.com/avenga/couper/utils"
Expand All @@ -23,7 +25,7 @@ type Options struct {
ServerName string
}

func NewServerOptions(conf *config.Server) (*Options, error) {
func NewServerOptions(conf *config.Server, logger *logrus.Entry) (*Options, error) {
options := &Options{
FilesErrTpl: errors.DefaultHTML,
ServerErrTpl: errors.DefaultHTML,
Expand All @@ -36,7 +38,7 @@ func NewServerOptions(conf *config.Server) (*Options, error) {
options.SrvBasePath = path.Join("/", conf.BasePath)

if conf.ErrorFile != "" {
tpl, err := errors.NewTemplateFromFile(conf.ErrorFile)
tpl, err := errors.NewTemplateFromFile(conf.ErrorFile, logger)
if err != nil {
return nil, err
}
Expand All @@ -53,7 +55,7 @@ func NewServerOptions(conf *config.Server) (*Options, error) {
options.APIBasePaths[api] = path.Join(options.SrvBasePath, api.BasePath)

if api.ErrorFile != "" {
tpl, err := errors.NewTemplateFromFile(api.ErrorFile)
tpl, err := errors.NewTemplateFromFile(api.ErrorFile, logger)
if err != nil {
return nil, err
}
Expand All @@ -65,7 +67,7 @@ func NewServerOptions(conf *config.Server) (*Options, error) {

if conf.Files != nil {
if conf.Files.ErrorFile != "" {
tpl, err := errors.NewTemplateFromFile(conf.Files.ErrorFile)
tpl, err := errors.NewTemplateFromFile(conf.Files.ErrorFile, logger)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion config/runtime/server_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestServer_getEndpointsList(t *testing.T) {
},
}

serverOptions, _ := server.NewServerOptions(nil)
serverOptions, _ := server.NewServerOptions(nil, nil)
endpoints, _ := newEndpointMap(srvConf, serverOptions)
if l := len(endpoints); l != 4 {
t.Fatalf("Expected 4 endpointes, given %d", l)
Expand Down
31 changes: 20 additions & 11 deletions errors/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"strings"
"text/template"

"github.com/sirupsen/logrus"

"github.com/avenga/couper/assets"
"github.com/avenga/couper/config/request"
)
Expand All @@ -20,11 +22,11 @@ const HeaderErrorCode = "Couper-Error"

func init() {
var err error
DefaultHTML, err = NewTemplate("text/html", "default.html", assets.Assets.MustOpen("error.html").Bytes())
DefaultHTML, err = NewTemplate("text/html", "default.html", assets.Assets.MustOpen("error.html").Bytes(), nil)
if err != nil {
panic(err)
}
DefaultJSON, err = NewTemplate("application/json", "default.json", assets.Assets.MustOpen("error.json").Bytes())
DefaultJSON, err = NewTemplate("application/json", "default.json", assets.Assets.MustOpen("error.json").Bytes(), nil)
if err != nil {
panic(err)
}
Expand All @@ -35,12 +37,13 @@ type ErrorTemplate interface {
}

type Template struct {
raw []byte
log *logrus.Entry
mime string
raw []byte
tpl *template.Template
}

func NewTemplateFromFile(path string) (*Template, error) {
func NewTemplateFromFile(path string, logger *logrus.Entry) (*Template, error) {
absPath, err := filepath.Abs(path)
if err != nil {
return nil, err
Expand All @@ -56,16 +59,23 @@ func NewTemplateFromFile(path string) (*Template, error) {
}

_, fileName := filepath.Split(path)
return NewTemplate(mime, fileName, tplFile)
return NewTemplate(mime, fileName, tplFile, logger)
}

// SetLogger updates the default templates with the configured "daemon" logger.
func SetLogger(log *logrus.Entry) {
DefaultJSON.log = log
DefaultHTML.log = log
}

func NewTemplate(mime, name string, src []byte) (*Template, error) {
func NewTemplate(mime, name string, src []byte, logger *logrus.Entry) (*Template, error) {
tpl, err := template.New(name).Parse(string(src))
if err != nil {
return nil, err
}

return &Template{
log: logger,
mime: mime,
raw: src,
tpl: tpl,
Expand Down Expand Up @@ -101,22 +111,21 @@ func (t *Template) ServeError(err error) http.Handler {
"path": req.URL.EscapedPath(),
"request_id": escapeValue(t.mime, reqID),
}
err := t.tpl.Execute(rw, data)
tplErr := t.tpl.Execute(rw, data)

// FIXME: If the fallback triggers, maybe we set
// different/double headers on the top of this method
// (recursive call)

// fallback behaviour, execute internal template once
if err != nil && (t != DefaultHTML && t != DefaultJSON) {
if tplErr != nil && (t != DefaultHTML && t != DefaultJSON) {
if !strings.Contains(t.mime, "text/html") {
DefaultJSON.ServeError(errCode).ServeHTTP(rw, req)
return
}
DefaultHTML.ServeError(errCode).ServeHTTP(rw, req)
} else if err != nil {
// FIXME: at least log those errors (maybe netOP, brokenPipe etc)
println(err.Error())
} else if tplErr != nil && t.log != nil {
t.log.WithFields(data).Error(tplErr)
}
})
}
Expand Down
2 changes: 1 addition & 1 deletion handler/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestFile_ServeHTTP(t *testing.T) {
{"not found /w errorFile HEAD", fields{errFile: "testdata/file_err_doc.html"}, httptest.NewRequest(http.MethodHead, "http://domain.test/", nil), http.StatusNotFound},
}

srvOpts, _ := server.NewServerOptions(&config.Server{})
srvOpts, _ := server.NewServerOptions(&config.Server{}, nil)

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion handler/spa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func TestSpa_ServeHTTP(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts, _ := server.NewServerOptions(&config.Server{})
opts, _ := server.NewServerOptions(&config.Server{}, nil)
s, err := handler.NewSpa(path.Join(wd, tt.filePath), opts)
if err != nil {
t.Fatal(err)
Expand Down
3 changes: 1 addition & 2 deletions server/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@ import (
)

func TestMux_FindHandler_PathParamContext(t *testing.T) {

type noContentHandler http.Handler
var noContent noContentHandler = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
rw.WriteHeader(http.StatusNoContent)
})

serverOptions, _ := rs.NewServerOptions(nil)
serverOptions, _ := rs.NewServerOptions(nil, nil)

testOptions := &runtime.MuxOptions{
EndpointRoutes: map[string]http.Handler{
Expand Down