Skip to content

Commit

Permalink
feat: move error wrapping to msgpack only
Browse files Browse the repository at this point in the history
  • Loading branch information
franklinkim committed Mar 3, 2022
1 parent 9c57488 commit ee5ede8
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 57 deletions.
36 changes: 18 additions & 18 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ func (c *bufferedClient) Call(ctx context.Context, url string, endpoint string,
}
}

// Create request
// Create post url
postURL := fmt.Sprintf("%s%s/%s", url, endpoint, method)

// Create request
request, errRequest := newRequest(ctx, postURL, c.handle.contentType, b, c.headers.Clone())
if errRequest != nil {
return NewClientError(errors.Wrap(errRequest, "failed to create request"))
Expand All @@ -104,34 +104,34 @@ func (c *bufferedClient) Call(ctx context.Context, url string, endpoint string,

// Check status
if resp.StatusCode != http.StatusOK {
body := "request failed"
if value, err := ioutil.ReadAll(resp.Body); err == nil {
body = string(value)
var msg string
if value, err := ioutil.ReadAll(resp.Body); err != nil {
msg = "failed to read response body: " + err.Error()
} else {
msg = string(value)
}
return NewClientError(NewHTTPError(body, resp.StatusCode))
return NewClientError(NewHTTPError(msg, resp.StatusCode))
}

wrappedReply := make([]interface{}, len(reply))
for k, v := range reply {
if _, ok := v.(*error); ok {
var e *Error
wrappedReply[k] = e
clientHandle := getHandlerForContentType(resp.Header.Get("Content-Type"))

wrappedReply := reply
if clientHandle.beforeDecodeReply != nil {
if value, err := clientHandle.beforeDecodeReply(reply); err != nil {
return NewClientError(errors.Wrap(err, "failed to call beforeDecodeReply hook"))
} else {
wrappedReply[k] = v
wrappedReply = value
}
}

responseHandle := getHandlerForContentType(resp.Header.Get("Content-Type")).handle
if err := codec.NewDecoder(resp.Body, responseHandle).Decode(wrappedReply); err != nil {
if err := codec.NewDecoder(resp.Body, clientHandle.handle).Decode(wrappedReply); err != nil {
return NewClientError(errors.Wrap(err, "failed to decode response"))
}

// replace error
for k, v := range wrappedReply {
if x, ok := v.(*Error); ok && x != nil {
if y, ok := reply[k].(*error); ok {
*y = x
}
if clientHandle.afterDecodeReply != nil {
if err := clientHandle.afterDecodeReply(&reply, wrappedReply); err != nil {
return NewClientError(errors.Wrap(err, "failed to call afterDecodeReply hook"))
}
}

Expand Down
15 changes: 7 additions & 8 deletions gotsrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"os"
"path"
"path/filepath"
"reflect"
"sort"
"strings"
"time"
Expand Down Expand Up @@ -87,16 +86,16 @@ func Reply(response []interface{}, stats *CallStats, r *http.Request, w http.Res

writer.Header().Set("Content-Type", clientHandle.contentType)

// transform error type to sth that is transportable
for k, v := range response {
if e, ok := v.(error); ok {
if !reflect.ValueOf(e).IsNil() {
response[k] = NewError(e)
}
if clientHandle.beforeEncodeReply != nil {
if err := clientHandle.beforeEncodeReply(&response); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error())
http.Error(w, "could not encode data to accepted format", http.StatusInternalServerError)
return
}
}

if errEncode := codec.NewEncoder(writer, clientHandle.handle).Encode(response); errEncode != nil {
if err := codec.NewEncoder(writer, clientHandle.handle).Encode(response); err != nil {
_, _ = fmt.Fprintln(os.Stderr, err.Error())
http.Error(w, "could not encode data to accepted format", http.StatusInternalServerError)
return
}
Expand Down
70 changes: 39 additions & 31 deletions transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,47 +15,55 @@ const (
)

type clientHandle struct {
handle codec.Handle
contentType string
handle codec.Handle
contentType string
beforeEncodeReply func(*[]interface{}) error
beforeDecodeReply func([]interface{}) ([]interface{}, error)
afterDecodeReply func(*[]interface{}, []interface{}) error
}

var msgpackClientHandle = &clientHandle{
handle: &codec.MsgpackHandle{},
contentType: "application/msgpack; charset=utf-8",
handle: &codec.MsgpackHandle{},
// transform error type to sth that is transportable
beforeEncodeReply: func(resp *[]interface{}) error {
for k, v := range *resp {
if e, ok := v.(error); ok {
if !reflect.ValueOf(e).IsNil() {
(*resp)[k] = NewError(e)
}
}
}
return nil
},
beforeDecodeReply: func(reply []interface{}) ([]interface{}, error) {
ret := make([]interface{}, len(reply))
for k, v := range reply {
if _, ok := v.(*error); ok {
var e *Error
ret[k] = e
} else {
ret[k] = v
}
}
return ret, nil
},
afterDecodeReply: func(reply *[]interface{}, wrappedReply []interface{}) error {
for k, v := range wrappedReply {
if x, ok := v.(*Error); ok && x != nil {
if y, ok := (*reply)[k].(*error); ok {
*y = x
}
}
}
return nil
},
}

//type TimeExt struct{}
//
//func (x TimeExt) WriteExt(v interface{}) []byte {
// b := make([]byte, binary.MaxVarintLen64)
// switch t := v.(type) {
// case time.Time:
// binary.PutVarint(b, t.UnixNano())
// return b
// case *time.Time:
// binary.PutVarint(b, t.UnixNano())
// return b
// default:
// panic("Bug")
// }
//}
//func (x TimeExt) ReadExt(dest interface{}, src []byte) {
// tt := dest.(*time.Time)
// r := bytes.NewBuffer(src)
// v, err := binary.ReadVarint(r)
// if err != nil {
// panic("BUG")
// }
// *tt = time.Unix(0, v).UTC()
//}

func init() {
mh := new(codec.MsgpackHandle)
// use map[string]interface{} instead of map[interface{}]interface{}
mh.MapType = reflect.TypeOf(map[string]interface{}(nil))
//if err := mh.SetBytesExt(reflect.TypeOf(time.Time{}), 1, TimeExt{}); err != nil {
// panic("2")
//}
//mh.TimeNotBuiltin = true
msgpackClientHandle.handle = mh
// attempting to set promoted field in literal will cause a compiler error
Expand Down
1 change: 1 addition & 0 deletions typereader.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func trace(args ...interface{}) {
fmt.Fprintln(os.Stderr, args...)
}
}

func traceData(args ...interface{}) {
if ReaderTrace {
for _, arg := range args {
Expand Down

0 comments on commit ee5ede8

Please sign in to comment.