From e25c08a01bc0b424edcf5e010aa4099c0797020e Mon Sep 17 00:00:00 2001 From: Annie Fu <16651409+anniefu@users.noreply.github.com> Date: Wed, 9 Feb 2022 18:18:20 -0800 Subject: [PATCH] fix: return generic error message when function panics (#111) * fix: return generic error message when function panics Previously, the framework returned the panic message and full stack trace in the function response body if there was a panic. Now, the framework returns a generic error message that does not leak function details in the response body, and logs the panic message and full stack trace instead. --- funcframework/framework.go | 28 ++++---- funcframework/framework_test.go | 109 ++++++++++++++++++++++++++------ 2 files changed, 107 insertions(+), 30 deletions(-) diff --git a/funcframework/framework.go b/funcframework/framework.go index ba29ecd..29ecb9e 100644 --- a/funcframework/framework.go +++ b/funcframework/framework.go @@ -35,28 +35,32 @@ const ( functionStatusHeader = "X-Google-Status" crashStatus = "crash" errorStatus = "error" + panicMessageTmpl = "A panic occurred during %s. Please see logs for more details." ) var ( handler http.Handler ) -func recoverPanic(msg string) { +// recoverPanic recovers from a panic in a consistent manner. panicSrc should +// describe what was happening when the panic was encountered, for example +// "user function execution". w is an http.ResponseWriter to write a generic +// response body to that does not expose the details of the panic; w can be +// nil to skip this. +func recoverPanic(w http.ResponseWriter, panicSrc string) { if r := recover(); r != nil { - fmt.Fprintf(os.Stderr, "%s: %v\n\n%s", msg, r, debug.Stack()) - } -} - -func recoverPanicHTTP(w http.ResponseWriter, msg string) { - if r := recover(); r != nil { - writeHTTPErrorResponse(w, http.StatusInternalServerError, crashStatus, fmt.Sprintf("%s: %v\n\n%s", msg, r, debug.Stack())) + genericMsg := fmt.Sprintf(panicMessageTmpl, panicSrc) + fmt.Fprintf(os.Stderr, fmt.Sprintf("%s\npanic message: %v\nstack trace: %s", genericMsg, r, debug.Stack())) + if w != nil { + writeHTTPErrorResponse(w, http.StatusInternalServerError, crashStatus, genericMsg) + } } } // RegisterHTTPFunction registers fn as an HTTP function. // Maintained for backward compatibility. Please use RegisterHTTPFunctionContext instead. func RegisterHTTPFunction(path string, fn interface{}) { - defer recoverPanic("Registration panic") + defer recoverPanic(nil, "function registration") fnHTTP, ok := fn.(func(http.ResponseWriter, *http.Request)) if !ok { @@ -72,8 +76,8 @@ func RegisterHTTPFunction(path string, fn interface{}) { // RegisterEventFunction registers fn as an event function. // Maintained for backward compatibility. Please use RegisterEventFunctionContext instead. func RegisterEventFunction(path string, fn interface{}) { - defer recoverPanic("Registration panic") ctx := context.Background() + defer recoverPanic(nil, "function registration") if err := RegisterEventFunctionContext(ctx, path, fn); err != nil { panic(fmt.Sprintf("unexpected error in RegisterEventFunctionContext: %v", err)) } @@ -149,7 +153,7 @@ func wrapHTTPFunction(path string, fn func(http.ResponseWriter, *http.Request)) defer fmt.Println() defer fmt.Fprintln(os.Stderr) } - defer recoverPanicHTTP(w, "Function panic") + defer recoverPanic(w, "user function execution") fn(w, r) }) return h, nil @@ -167,7 +171,6 @@ func wrapEventFunction(path string, fn interface{}) (http.Handler, error) { defer fmt.Println() defer fmt.Fprintln(os.Stderr) } - defer recoverPanicHTTP(w, "Function panic") if shouldConvertCloudEventToBackgroundRequest(r) { if err := convertCloudEventToBackgroundRequest(r); err != nil { @@ -238,6 +241,7 @@ func runUserFunctionWithContext(ctx context.Context, w http.ResponseWriter, r *h return } + defer recoverPanic(w, "user function execution") userFunErr := reflect.ValueOf(fn).Call([]reflect.Value{ reflect.ValueOf(ctx), argVal.Elem(), diff --git a/funcframework/framework_test.go b/funcframework/framework_test.go index 4cea99f..9da6432 100644 --- a/funcframework/framework_test.go +++ b/funcframework/framework_test.go @@ -23,6 +23,7 @@ import ( "net/http" "net/http/httptest" "os" + "strings" "testing" "github.com/GoogleCloudPlatform/functions-framework-go/functions" @@ -32,29 +33,61 @@ import ( ) func TestHTTPFunction(t *testing.T) { - h, err := wrapHTTPFunction("/", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "Hello World!") - }) - if err != nil { - t.Fatalf("registerHTTPFunction(): %v", err) + tests := []struct { + name string + fn func(w http.ResponseWriter, r *http.Request) + wantStatus int // defaults to http.StatusOK + wantResp string + }{ + { + name: "helloworld", + fn: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "Hello World!") + }, + wantResp: "Hello World!", + }, + { + name: "panic in function", + fn: func(w http.ResponseWriter, r *http.Request) { + panic("intentional panic for test") + }, + wantStatus: http.StatusInternalServerError, + wantResp: fmt.Sprintf(panicMessageTmpl, "user function execution"), + }, } - srv := httptest.NewServer(h) - defer srv.Close() + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + h, err := wrapHTTPFunction("/", tc.fn) + if err != nil { + t.Fatalf("registerHTTPFunction(): %v", err) + } - resp, err := http.Get(srv.URL) - if err != nil { - t.Fatalf("http.Get: %v", err) - } + srv := httptest.NewServer(h) + defer srv.Close() - defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - t.Fatalf("ioutil.ReadAll: %v", err) - } + resp, err := http.Get(srv.URL) + if err != nil { + t.Fatalf("http.Get: %v", err) + } + + if tc.wantStatus == 0 { + tc.wantStatus = http.StatusOK + } + if resp.StatusCode != tc.wantStatus { + t.Errorf("TestHTTPFunction status code: got %d, want: %d", resp.StatusCode, tc.wantStatus) + } + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ioutil.ReadAll: %v", err) + } - if got, want := string(body), "Hello World!"; got != want { - t.Fatalf("TestHTTPFunction: got %v; want %v", got, want) + if got := strings.TrimSpace(string(body)); got != tc.wantResp { + t.Errorf("TestHTTPFunction: got %q; want: %q", got, tc.wantResp) + } + }) } } @@ -75,6 +108,7 @@ func TestEventFunction(t *testing.T) { status int header string ceHeaders map[string]string + wantResp string }{ { name: "valid function", @@ -109,6 +143,16 @@ func TestEventFunction(t *testing.T) { status: http.StatusInternalServerError, header: "error", }, + { + name: "panicking function", + body: []byte(`{"id": 12345,"name": "custom"}`), + fn: func(c context.Context, s customStruct) error { + panic("intential panic for test") + }, + status: http.StatusInternalServerError, + header: "crash", + wantResp: fmt.Sprintf(panicMessageTmpl, "user function execution"), + }, { name: "pubsub event", body: []byte(`{ @@ -257,6 +301,16 @@ func TestEventFunction(t *testing.T) { continue } + if tc.wantResp != "" { + gotBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read got request body: %v", err) + } + if strings.TrimSpace(string(gotBody)) != tc.wantResp { + t.Errorf("TestCloudEventFunction(%s): response body = %q, want %q on error status code %d.", tc.name, string(gotBody), tc.wantResp, tc.status) + } + } + if resp.StatusCode != tc.status { t.Errorf("TestEventFunction(%s): response status = %v, want %v", tc.name, resp.StatusCode, tc.status) continue @@ -406,6 +460,17 @@ func TestCloudEventFunction(t *testing.T) { }, status: http.StatusOK, }, + { + name: "panic returns 500", + body: cloudeventsJSON, + fn: func(ctx context.Context, e cloudevents.Event) error { + panic("intentional panic for test") + }, + status: http.StatusInternalServerError, + ceHeaders: map[string]string{ + "Content-Type": "application/cloudevents+json", + }, + }, } for _, tc := range tests { @@ -432,6 +497,14 @@ func TestCloudEventFunction(t *testing.T) { continue } + gotBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read got request body: %v", err) + } + if resp.StatusCode != http.StatusOK && string(gotBody) != "" { + t.Errorf("TestCloudEventFunction(%s): response body = %q, want %q on error status code %d.", tc.name, gotBody, "", tc.status) + } + if resp.StatusCode != tc.status { gotBody, err := ioutil.ReadAll(resp.Body) if err != nil {