From ee5ede87ee0e198440fc327fbd36a70559eeb84c Mon Sep 17 00:00:00 2001 From: Kevin Franklin Kim Date: Thu, 3 Mar 2022 07:47:19 +0100 Subject: [PATCH] feat: move error wrapping to msgpack only --- client.go | 36 +++++++++++++------------- gotsrpc.go | 15 ++++++----- transport.go | 70 ++++++++++++++++++++++++++++----------------------- typereader.go | 1 + 4 files changed, 65 insertions(+), 57 deletions(-) diff --git a/client.go b/client.go index a11664d..3c4ff45 100644 --- a/client.go +++ b/client.go @@ -87,10 +87,10 @@ func (c *bufferedClient) Call(ctx context.Context, url string, endpoint string, } } - // Create request // Create post url postURL := fmt.Sprintf("%s%s/%s", url, endpoint, method) + // Create request request, errRequest := newRequest(ctx, postURL, c.handle.contentType, b, c.headers.Clone()) if errRequest != nil { return NewClientError(errors.Wrap(errRequest, "failed to create request")) @@ -104,34 +104,34 @@ func (c *bufferedClient) Call(ctx context.Context, url string, endpoint string, // Check status if resp.StatusCode != http.StatusOK { - body := "request failed" - if value, err := ioutil.ReadAll(resp.Body); err == nil { - body = string(value) + var msg string + if value, err := ioutil.ReadAll(resp.Body); err != nil { + msg = "failed to read response body: " + err.Error() + } else { + msg = string(value) } - return NewClientError(NewHTTPError(body, resp.StatusCode)) + return NewClientError(NewHTTPError(msg, resp.StatusCode)) } - wrappedReply := make([]interface{}, len(reply)) - for k, v := range reply { - if _, ok := v.(*error); ok { - var e *Error - wrappedReply[k] = e + clientHandle := getHandlerForContentType(resp.Header.Get("Content-Type")) + + wrappedReply := reply + if clientHandle.beforeDecodeReply != nil { + if value, err := clientHandle.beforeDecodeReply(reply); err != nil { + return NewClientError(errors.Wrap(err, "failed to call beforeDecodeReply hook")) } else { - wrappedReply[k] = v + wrappedReply = value } } - responseHandle := getHandlerForContentType(resp.Header.Get("Content-Type")).handle - if err := codec.NewDecoder(resp.Body, responseHandle).Decode(wrappedReply); err != nil { + if err := codec.NewDecoder(resp.Body, clientHandle.handle).Decode(wrappedReply); err != nil { return NewClientError(errors.Wrap(err, "failed to decode response")) } // replace error - for k, v := range wrappedReply { - if x, ok := v.(*Error); ok && x != nil { - if y, ok := reply[k].(*error); ok { - *y = x - } + if clientHandle.afterDecodeReply != nil { + if err := clientHandle.afterDecodeReply(&reply, wrappedReply); err != nil { + return NewClientError(errors.Wrap(err, "failed to call afterDecodeReply hook")) } } diff --git a/gotsrpc.go b/gotsrpc.go index 730cbed..6dc2098 100644 --- a/gotsrpc.go +++ b/gotsrpc.go @@ -11,7 +11,6 @@ import ( "os" "path" "path/filepath" - "reflect" "sort" "strings" "time" @@ -87,16 +86,16 @@ func Reply(response []interface{}, stats *CallStats, r *http.Request, w http.Res writer.Header().Set("Content-Type", clientHandle.contentType) - // transform error type to sth that is transportable - for k, v := range response { - if e, ok := v.(error); ok { - if !reflect.ValueOf(e).IsNil() { - response[k] = NewError(e) - } + if clientHandle.beforeEncodeReply != nil { + if err := clientHandle.beforeEncodeReply(&response); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err.Error()) + http.Error(w, "could not encode data to accepted format", http.StatusInternalServerError) + return } } - if errEncode := codec.NewEncoder(writer, clientHandle.handle).Encode(response); errEncode != nil { + if err := codec.NewEncoder(writer, clientHandle.handle).Encode(response); err != nil { + _, _ = fmt.Fprintln(os.Stderr, err.Error()) http.Error(w, "could not encode data to accepted format", http.StatusInternalServerError) return } diff --git a/transport.go b/transport.go index 01bb03e..f80ca20 100644 --- a/transport.go +++ b/transport.go @@ -15,47 +15,55 @@ const ( ) type clientHandle struct { - handle codec.Handle - contentType string + handle codec.Handle + contentType string + beforeEncodeReply func(*[]interface{}) error + beforeDecodeReply func([]interface{}) ([]interface{}, error) + afterDecodeReply func(*[]interface{}, []interface{}) error } var msgpackClientHandle = &clientHandle{ - handle: &codec.MsgpackHandle{}, contentType: "application/msgpack; charset=utf-8", + handle: &codec.MsgpackHandle{}, + // transform error type to sth that is transportable + beforeEncodeReply: func(resp *[]interface{}) error { + for k, v := range *resp { + if e, ok := v.(error); ok { + if !reflect.ValueOf(e).IsNil() { + (*resp)[k] = NewError(e) + } + } + } + return nil + }, + beforeDecodeReply: func(reply []interface{}) ([]interface{}, error) { + ret := make([]interface{}, len(reply)) + for k, v := range reply { + if _, ok := v.(*error); ok { + var e *Error + ret[k] = e + } else { + ret[k] = v + } + } + return ret, nil + }, + afterDecodeReply: func(reply *[]interface{}, wrappedReply []interface{}) error { + for k, v := range wrappedReply { + if x, ok := v.(*Error); ok && x != nil { + if y, ok := (*reply)[k].(*error); ok { + *y = x + } + } + } + return nil + }, } -//type TimeExt struct{} -// -//func (x TimeExt) WriteExt(v interface{}) []byte { -// b := make([]byte, binary.MaxVarintLen64) -// switch t := v.(type) { -// case time.Time: -// binary.PutVarint(b, t.UnixNano()) -// return b -// case *time.Time: -// binary.PutVarint(b, t.UnixNano()) -// return b -// default: -// panic("Bug") -// } -//} -//func (x TimeExt) ReadExt(dest interface{}, src []byte) { -// tt := dest.(*time.Time) -// r := bytes.NewBuffer(src) -// v, err := binary.ReadVarint(r) -// if err != nil { -// panic("BUG") -// } -// *tt = time.Unix(0, v).UTC() -//} - func init() { mh := new(codec.MsgpackHandle) // use map[string]interface{} instead of map[interface{}]interface{} mh.MapType = reflect.TypeOf(map[string]interface{}(nil)) - //if err := mh.SetBytesExt(reflect.TypeOf(time.Time{}), 1, TimeExt{}); err != nil { - // panic("2") - //} //mh.TimeNotBuiltin = true msgpackClientHandle.handle = mh // attempting to set promoted field in literal will cause a compiler error diff --git a/typereader.go b/typereader.go index 3a981c1..9a30b1a 100644 --- a/typereader.go +++ b/typereader.go @@ -42,6 +42,7 @@ func trace(args ...interface{}) { fmt.Fprintln(os.Stderr, args...) } } + func traceData(args ...interface{}) { if ReaderTrace { for _, arg := range args {