diff --git a/.golangci.yml b/.golangci.yml index 7d80377..1f03863 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -54,6 +54,7 @@ issues: linters: - errcheck - funlen + - goerr113 - gosec - path: (.+)_test.go linters: [govet] diff --git a/encoding/text/decode.go b/encoding/text/decode.go new file mode 100644 index 0000000..e7fee11 --- /dev/null +++ b/encoding/text/decode.go @@ -0,0 +1,134 @@ +package text + +import ( + "bytes" + "encoding" + "reflect" + "strconv" + "sync" + + "github.com/abemedia/go-don" + "github.com/abemedia/go-don/internal/byteconv" + "github.com/valyala/fasthttp" +) + +//nolint:cyclop +func decode(ctx *fasthttp.RequestCtx, v interface{}) error { + b := bytes.TrimSpace(ctx.Request.Body()) + if len(b) == 0 { + return nil + } + + var err error + + switch v := v.(type) { + case *string: + *v = byteconv.Btoa(b) + case *[]byte: + *v = b + case *int: + *v, err = strconv.Atoi(byteconv.Btoa(b)) + case *int8: + return decodeInt(b, v, 8) + case *int16: + return decodeInt(b, v, 16) + case *int32: + return decodeInt(b, v, 32) + case *int64: + return decodeInt(b, v, 64) + case *uint: + return decodeUint(b, v, 0) + case *uint8: + return decodeUint(b, v, 8) + case *uint16: + return decodeUint(b, v, 16) + case *uint32: + return decodeUint(b, v, 32) + case *uint64: + return decodeUint(b, v, 64) + case *float32: + return decodeFloat(b, v, 32) + case *float64: + return decodeFloat(b, v, 64) + case *bool: + *v, err = strconv.ParseBool(byteconv.Btoa(b)) + default: + return unmarshal(b, v) + } + + return err +} + +func decodeInt[T int | int8 | int16 | int32 | int64](b []byte, v *T, bits int) error { + d, err := strconv.ParseInt(byteconv.Btoa(b), 10, bits) + *v = T(d) + return err +} + +func decodeUint[T uint | uint8 | uint16 | uint32 | uint64](b []byte, v *T, bits int) error { + d, err := strconv.ParseUint(byteconv.Btoa(b), 10, bits) + *v = T(d) + return err +} + +func decodeFloat[T float32 | float64](b []byte, v *T, bits int) error { + d, err := strconv.ParseFloat(byteconv.Btoa(b), bits) + *v = T(d) + return err +} + +func unmarshal(b []byte, v any) error { + val := reflect.ValueOf(v) + typ := val.Type() + if dec, ok := unmarshalers.Load(typ); ok { + return dec.(func([]byte, reflect.Value) error)(b, val) + } + dec, err := newUnmarshaler(typ) + if err != nil { + return err + } + unmarshalers.Store(typ, dec) + return dec(b, val) +} + +func newUnmarshaler(typ reflect.Type) (func([]byte, reflect.Value) error, error) { + if typ.Implements(unmarshalerType) { + isPtr := typ.Kind() == reflect.Pointer + typ = typ.Elem() + return func(b []byte, v reflect.Value) error { + if len(b) == 0 { + return nil + } + if isPtr && v.IsNil() { + v.Set(reflect.New(typ)) + } + return v.Interface().(encoding.TextUnmarshaler).UnmarshalText(b) //nolint:forcetypeassert + }, nil + } + + if typ.Kind() == reflect.Pointer { + typ = typ.Elem() + dec, err := newUnmarshaler(typ) + if err != nil { + return nil, err + } + return func(b []byte, v reflect.Value) error { + if v.IsNil() { + v.Set(reflect.New(typ)) + } + return dec(b, v.Elem()) + }, nil + } + + return nil, don.ErrUnsupportedMediaType +} + +//nolint:gochecknoglobals +var ( + unmarshalers sync.Map + unmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + +func init() { + don.RegisterDecoder("text/plain", decode) +} diff --git a/encoding/text/decode_test.go b/encoding/text/decode_test.go new file mode 100644 index 0000000..783a6ff --- /dev/null +++ b/encoding/text/decode_test.go @@ -0,0 +1,57 @@ +package text_test + +import ( + "reflect" + "testing" + + "github.com/abemedia/go-don/encoding/text" + "github.com/abemedia/go-don/pkg/httptest" + "github.com/google/go-cmp/cmp" + "github.com/valyala/fasthttp" +) + +func TestDecode(t *testing.T) { + tests := []struct { + in string + want any + }{ + {"test\n", "test"}, + {"test\n", []byte("test")}, + {"5\n", int(5)}, + {"5\n", int8(5)}, + {"5\n", int16(5)}, + {"5\n", int32(5)}, + {"5\n", int64(5)}, + {"5\n", uint(5)}, + {"5\n", uint8(5)}, + {"5\n", uint16(5)}, + {"5\n", uint32(5)}, + {"5\n", uint64(5)}, + {"5.1\n", float32(5.1)}, + {"5.1\n", float64(5.1)}, + {"true\n", true}, + {"test\n", unmarshaler{"test"}}, + {"test\n", &unmarshaler{"test"}}, + } + + for _, test := range tests { + ctx := httptest.NewRequest(fasthttp.MethodGet, "/", test.in, nil) + v := reflect.New(reflect.TypeOf(test.want)).Interface() + if err := text.Decode(ctx, v); err != nil { + t.Error(err) + } else { + if diff := cmp.Diff(test.want, reflect.ValueOf(v).Elem().Interface()); diff != "" { + t.Errorf("%T: %s", test.in, diff) + } + } + } +} + +type unmarshaler struct { + S string +} + +func (m *unmarshaler) UnmarshalText(text []byte) error { + m.S = string(text) + return nil +} diff --git a/encoding/text/encode.go b/encoding/text/encode.go new file mode 100644 index 0000000..05f9de2 --- /dev/null +++ b/encoding/text/encode.go @@ -0,0 +1,72 @@ +package text + +import ( + "encoding" + "fmt" + "strconv" + + "github.com/abemedia/go-don" + "github.com/abemedia/go-don/internal/byteconv" + "github.com/valyala/fasthttp" +) + +//nolint:cyclop +func encode(ctx *fasthttp.RequestCtx, v interface{}) error { + var ( + b []byte + err error + ) + + if v != nil { + switch v := v.(type) { + case string: + b = byteconv.Atob(v) + case []byte: + b = v + case int: + b = strconv.AppendInt(ctx.Response.Body(), int64(v), 10) + case int8: + b = strconv.AppendInt(ctx.Response.Body(), int64(v), 10) + case int16: + b = strconv.AppendInt(ctx.Response.Body(), int64(v), 10) + case int32: + b = strconv.AppendInt(ctx.Response.Body(), int64(v), 10) + case int64: + b = strconv.AppendInt(ctx.Response.Body(), v, 10) + case uint: + b = strconv.AppendUint(ctx.Response.Body(), uint64(v), 10) + case uint8: + b = strconv.AppendUint(ctx.Response.Body(), uint64(v), 10) + case uint16: + b = strconv.AppendUint(ctx.Response.Body(), uint64(v), 10) + case uint32: + b = strconv.AppendUint(ctx.Response.Body(), uint64(v), 10) + case uint64: + b = strconv.AppendUint(ctx.Response.Body(), v, 10) + case float32: + b = strconv.AppendFloat(ctx.Response.Body(), float64(v), 'f', -1, 32) + case float64: + b = strconv.AppendFloat(ctx.Response.Body(), v, 'f', -1, 64) + case bool: + b = strconv.AppendBool(ctx.Response.Body(), v) + case error: + b = byteconv.Atob(v.Error()) + case encoding.TextMarshaler: + b, err = v.MarshalText() + case fmt.Stringer: + b = append(ctx.Response.Body(), v.String()...) + default: + return don.ErrNotAcceptable + } + } + + if len(b) > 0 { + ctx.Response.SetBodyRaw(append(b, '\n')) + } + + return err +} + +func init() { + don.RegisterEncoder("text/plain", encode) +} diff --git a/encoding/text/encode_test.go b/encoding/text/encode_test.go new file mode 100644 index 0000000..a1fe522 --- /dev/null +++ b/encoding/text/encode_test.go @@ -0,0 +1,62 @@ +package text_test + +import ( + "errors" + "testing" + + "github.com/abemedia/go-don/encoding/text" + "github.com/abemedia/go-don/pkg/httptest" + "github.com/google/go-cmp/cmp" + "github.com/valyala/fasthttp" +) + +func TestEncode(t *testing.T) { + tests := []struct { + in any + want string + }{ + {"test", "test\n"}, + {[]byte("test"), "test\n"}, + {int(5), "5\n"}, + {int8(5), "5\n"}, + {int16(5), "5\n"}, + {int32(5), "5\n"}, + {int64(5), "5\n"}, + {uint(5), "5\n"}, + {uint8(5), "5\n"}, + {uint16(5), "5\n"}, + {uint32(5), "5\n"}, + {uint64(5), "5\n"}, + {float32(5.1), "5.1\n"}, + {float64(5.1), "5.1\n"}, + {true, "true\n"}, + {errors.New("test"), "test\n"}, + {marshaler{}, "test\n"}, + {&marshaler{}, "test\n"}, + {stringer{}, "test\n"}, + {&stringer{}, "test\n"}, + } + + for _, test := range tests { + ctx := httptest.NewRequest(fasthttp.MethodGet, "/", "", nil) + if err := text.Encode(ctx, test.in); err != nil { + t.Error(err) + continue + } + if diff := cmp.Diff(test.want, string(ctx.Response.Body())); diff != "" { + t.Errorf("%T: %s", test.in, diff) + } + } +} + +type marshaler struct{} + +func (m marshaler) MarshalText() ([]byte, error) { + return []byte("test"), nil +} + +type stringer struct{} + +func (m stringer) String() string { + return "test" +} diff --git a/encoding/text/export_test.go b/encoding/text/export_test.go new file mode 100644 index 0000000..1aa5334 --- /dev/null +++ b/encoding/text/export_test.go @@ -0,0 +1,6 @@ +package text + +var ( + Encode = encode + Decode = decode +) diff --git a/encoding/text/text.go b/encoding/text/text.go deleted file mode 100644 index 12f276d..0000000 --- a/encoding/text/text.go +++ /dev/null @@ -1,174 +0,0 @@ -package json - -import ( - "fmt" - "strconv" - "unsafe" - - "github.com/abemedia/go-don" - "github.com/valyala/fasthttp" -) - -func b2s(b []byte) string { - return *(*string)(unsafe.Pointer(&b)) -} - -//nolint:cyclop -func decodeText(ctx *fasthttp.RequestCtx, v interface{}) error { - b := ctx.Request.Body() - if len(b) == 0 { - return nil - } - - switch t := v.(type) { - case *string: - *t = b2s(b) - - case *[]byte: - *t = b - - case *int: - d, err := strconv.Atoi(b2s(b)) - if err != nil { - return err - } - - *t = d - - case *int8: - d, err := strconv.ParseInt(b2s(b), 10, 8) - if err != nil { - return err - } - - *t = int8(d) - - case *int16: - d, err := strconv.ParseInt(b2s(b), 10, 16) - if err != nil { - return err - } - - *t = int16(d) - - case *int32: - d, err := strconv.ParseInt(b2s(b), 10, 32) - if err != nil { - return err - } - - *t = int32(d) - - case *int64: - d, err := strconv.ParseInt(b2s(b), 10, 64) - if err != nil { - return err - } - - *t = d - - case *uint: - d, err := strconv.ParseUint(b2s(b), 10, 0) - if err != nil { - return err - } - - *t = uint(d) - - case *uint8: - d, err := strconv.ParseUint(b2s(b), 10, 8) - if err != nil { - return err - } - - *t = uint8(d) - - case *uint16: - d, err := strconv.ParseUint(b2s(b), 10, 16) - if err != nil { - return err - } - - *t = uint16(d) - - case *uint32: - d, err := strconv.ParseUint(b2s(b), 10, 32) - if err != nil { - return err - } - - *t = uint32(d) - - case *uint64: - d, err := strconv.ParseUint(b2s(b), 10, 64) - if err != nil { - return err - } - - *t = d - - case *float32: - d, err := strconv.ParseFloat(b2s(b), 32) - if err != nil { - return err - } - - *t = float32(d) - - case *float64: - d, err := strconv.ParseFloat(b2s(b), 64) - if err != nil { - return err - } - - *t = d - - case *bool: - d, err := strconv.ParseBool(b2s(b)) - if err != nil { - return err - } - - *t = d - - default: - return don.ErrUnsupportedMediaType - } - - return nil -} - -func encodeText(ctx *fasthttp.RequestCtx, v interface{}) error { - if v != nil { - switch v.(type) { - case *string, string, - *[]byte, []byte, - *int, int, - *int8, int8, - *int16, int16, - *int32, int32, - *int64, int64, - *uint, uint, - *uint8, uint8, - *uint16, uint16, - *uint32, uint32, - *uint64, uint64, - *float32, float32, - *float64, float64, - *bool, bool, - error: - - default: - return don.ErrNotAcceptable - } - } - - _, err := fmt.Fprintln(ctx, v) - - return err -} - -func init() { - don.RegisterDecoder("text/plain", decodeText) - don.RegisterEncoder("text/plain", encodeText) -}