Skip to content

Commit

Permalink
fix: ensure correct response headers and status code used on template…
Browse files Browse the repository at this point in the history
… render error (#316)

Co-authored-by: Adrian Hesketh <adrianhesketh@hushmail.com>
  • Loading branch information
oddwhocanfly and a-h authored Dec 17, 2023
1 parent 060219c commit 99fd97f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 23 deletions.
19 changes: 14 additions & 5 deletions runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,27 @@ const componentHandlerErrorMessage = "templ: failed to render template"

// ServeHTTP implements the http.Handler interface.
func (ch ComponentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if ch.Status != 0 {
w.WriteHeader(ch.Status)
}
w.Header().Add("Content-Type", ch.ContentType)
err := ch.Component.Render(r.Context(), w)
// Since the component may error, write to a buffer first.
// This prevents partial responses from being written to the client.
buf := GetBuffer()
defer ReleaseBuffer(buf)
err := ch.Component.Render(r.Context(), buf)
if err != nil {
if ch.ErrorHandler != nil {
w.Header().Set("Content-Type", ch.ContentType)
ch.ErrorHandler(r, err).ServeHTTP(w, r)
return
}
http.Error(w, componentHandlerErrorMessage, http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", ch.ContentType)
if ch.Status != 0 {
w.WriteHeader(ch.Status)
}
// Ignore write error like http.Error() does, because there is
// no way to recover at this point.
_, _ = w.Write(buf.Bytes())
}

// Handler creates a http.Handler that renders the template.
Expand Down
54 changes: 36 additions & 18 deletions runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,32 +382,46 @@ func TestHandler(t *testing.T) {
return nil
})
errorComponent := templ.ComponentFunc(func(ctx context.Context, w io.Writer) error {
if _, err := io.WriteString(w, "Hello"); err != nil {
t.Fatalf("failed to write string: %v", err)
}
return errors.New("handler error")
})

tests := []struct {
name string
input *templ.ComponentHandler
expectedStatus int
expectedBody string
name string
input *templ.ComponentHandler
expectedStatus int
expectedMIMEType string
expectedBody string
}{
{
name: "handlers return OK by default",
input: templ.Handler(hello),
expectedStatus: http.StatusOK,
expectedBody: "Hello",
name: "handlers return OK by default",
input: templ.Handler(hello),
expectedStatus: http.StatusOK,
expectedMIMEType: "text/html",
expectedBody: "Hello",
},
{
name: "handlers can be configured to return an alternative status code",
input: templ.Handler(hello, templ.WithStatus(http.StatusNotFound)),
expectedStatus: http.StatusNotFound,
expectedBody: "Hello",
name: "handlers can be configured to return an alternative status code",
input: templ.Handler(hello, templ.WithStatus(http.StatusNotFound)),
expectedStatus: http.StatusNotFound,
expectedMIMEType: "text/html",
expectedBody: "Hello",
},
{
name: "handlers that fail return a 500 error",
input: templ.Handler(errorComponent),
expectedStatus: http.StatusInternalServerError,
expectedBody: "templ: failed to render template\n",
name: "handlers can be configured to return an alternative status code and content type",
input: templ.Handler(hello, templ.WithStatus(http.StatusOK), templ.WithContentType("text/csv")),
expectedStatus: http.StatusOK,
expectedMIMEType: "text/csv",
expectedBody: "Hello",
},
{
name: "handlers that fail return a 500 error",
input: templ.Handler(errorComponent),
expectedStatus: http.StatusInternalServerError,
expectedMIMEType: "text/plain; charset=utf-8",
expectedBody: "templ: failed to render template\n",
},
{
name: "error handling can be customised",
Expand All @@ -421,8 +435,9 @@ func TestHandler(t *testing.T) {
}
})
})),
expectedStatus: http.StatusBadRequest,
expectedBody: "custom body",
expectedStatus: http.StatusBadRequest,
expectedMIMEType: "text/html",
expectedBody: "custom body",
},
}
for _, tt := range tests {
Expand All @@ -434,6 +449,9 @@ func TestHandler(t *testing.T) {
if got := w.Result().StatusCode; tt.expectedStatus != got {
t.Errorf("expected status %d, got %d", tt.expectedStatus, got)
}
if mimeType := w.Result().Header.Get("Content-Type"); tt.expectedMIMEType != mimeType {
t.Errorf("expected content-type %s, got %s", tt.expectedMIMEType, mimeType)
}
body, err := io.ReadAll(w.Result().Body)
if err != nil {
t.Errorf("failed to read body: %v", err)
Expand Down

0 comments on commit 99fd97f

Please sign in to comment.