Skip to content

Commit

Permalink
feat(encoding/text): support marshaler & stringer, improve performance
Browse files Browse the repository at this point in the history
  • Loading branch information
abemedia committed Jan 14, 2023
1 parent 0de3fc3 commit 83a2875
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 174 deletions.
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ issues:
linters:
- errcheck
- funlen
- goerr113
- gosec
- path: (.+)_test.go
linters: [govet]
Expand Down
134 changes: 134 additions & 0 deletions encoding/text/decode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package text

import (
"bytes"
"encoding"
"reflect"
"strconv"
"sync"

"github.com/abemedia/go-don"
"github.com/abemedia/go-don/internal/byteconv"
"github.com/valyala/fasthttp"
)

//nolint:cyclop
func decode(ctx *fasthttp.RequestCtx, v interface{}) error {
b := bytes.TrimSpace(ctx.Request.Body())
if len(b) == 0 {
return nil
}

var err error

switch v := v.(type) {
case *string:
*v = byteconv.Btoa(b)
case *[]byte:
*v = b
case *int:
*v, err = strconv.Atoi(byteconv.Btoa(b))
case *int8:
return decodeInt(b, v, 8)
case *int16:
return decodeInt(b, v, 16)
case *int32:
return decodeInt(b, v, 32)
case *int64:
return decodeInt(b, v, 64)
case *uint:
return decodeUint(b, v, 0)
case *uint8:
return decodeUint(b, v, 8)
case *uint16:
return decodeUint(b, v, 16)
case *uint32:
return decodeUint(b, v, 32)
case *uint64:
return decodeUint(b, v, 64)
case *float32:
return decodeFloat(b, v, 32)
case *float64:
return decodeFloat(b, v, 64)
case *bool:
*v, err = strconv.ParseBool(byteconv.Btoa(b))
default:
return unmarshal(b, v)
}

return err
}

func decodeInt[T int | int8 | int16 | int32 | int64](b []byte, v *T, bits int) error {
d, err := strconv.ParseInt(byteconv.Btoa(b), 10, bits)
*v = T(d)
return err
}

func decodeUint[T uint | uint8 | uint16 | uint32 | uint64](b []byte, v *T, bits int) error {
d, err := strconv.ParseUint(byteconv.Btoa(b), 10, bits)
*v = T(d)
return err
}

func decodeFloat[T float32 | float64](b []byte, v *T, bits int) error {
d, err := strconv.ParseFloat(byteconv.Btoa(b), bits)
*v = T(d)
return err
}

func unmarshal(b []byte, v any) error {
val := reflect.ValueOf(v)
typ := val.Type()
if dec, ok := unmarshalers.Load(typ); ok {
return dec.(func([]byte, reflect.Value) error)(b, val)
}
dec, err := newUnmarshaler(typ)
if err != nil {
return err
}
unmarshalers.Store(typ, dec)
return dec(b, val)
}

func newUnmarshaler(typ reflect.Type) (func([]byte, reflect.Value) error, error) {
if typ.Implements(unmarshalerType) {
isPtr := typ.Kind() == reflect.Pointer
typ = typ.Elem()
return func(b []byte, v reflect.Value) error {
if len(b) == 0 {
return nil
}
if isPtr && v.IsNil() {
v.Set(reflect.New(typ))
}
return v.Interface().(encoding.TextUnmarshaler).UnmarshalText(b) //nolint:forcetypeassert
}, nil
}

if typ.Kind() == reflect.Pointer {
typ = typ.Elem()
dec, err := newUnmarshaler(typ)
if err != nil {
return nil, err
}
return func(b []byte, v reflect.Value) error {
if v.IsNil() {
v.Set(reflect.New(typ))
}
return dec(b, v.Elem())
}, nil
}

return nil, don.ErrUnsupportedMediaType
}

//nolint:gochecknoglobals
var (
unmarshalers sync.Map
unmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
)

func init() {
don.RegisterDecoder("text/plain", decode)
}
57 changes: 57 additions & 0 deletions encoding/text/decode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package text_test

import (
"reflect"
"testing"

"github.com/abemedia/go-don/encoding/text"
"github.com/abemedia/go-don/pkg/httptest"
"github.com/google/go-cmp/cmp"
"github.com/valyala/fasthttp"
)

func TestDecode(t *testing.T) {
tests := []struct {
in string
want any
}{
{"test\n", "test"},
{"test\n", []byte("test")},
{"5\n", int(5)},
{"5\n", int8(5)},
{"5\n", int16(5)},
{"5\n", int32(5)},
{"5\n", int64(5)},
{"5\n", uint(5)},
{"5\n", uint8(5)},
{"5\n", uint16(5)},
{"5\n", uint32(5)},
{"5\n", uint64(5)},
{"5.1\n", float32(5.1)},
{"5.1\n", float64(5.1)},
{"true\n", true},
{"test\n", unmarshaler{"test"}},
{"test\n", &unmarshaler{"test"}},
}

for _, test := range tests {
ctx := httptest.NewRequest(fasthttp.MethodGet, "/", test.in, nil)
v := reflect.New(reflect.TypeOf(test.want)).Interface()
if err := text.Decode(ctx, v); err != nil {
t.Error(err)
} else {
if diff := cmp.Diff(test.want, reflect.ValueOf(v).Elem().Interface()); diff != "" {
t.Errorf("%T: %s", test.in, diff)
}
}
}
}

type unmarshaler struct {
S string
}

func (m *unmarshaler) UnmarshalText(text []byte) error {
m.S = string(text)
return nil
}
72 changes: 72 additions & 0 deletions encoding/text/encode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package text

import (
"encoding"
"fmt"
"strconv"

"github.com/abemedia/go-don"
"github.com/abemedia/go-don/internal/byteconv"
"github.com/valyala/fasthttp"
)

//nolint:cyclop
func encode(ctx *fasthttp.RequestCtx, v interface{}) error {
var (
b []byte
err error
)

if v != nil {
switch v := v.(type) {
case string:
b = byteconv.Atob(v)
case []byte:
b = v
case int:
b = strconv.AppendInt(ctx.Response.Body(), int64(v), 10)
case int8:
b = strconv.AppendInt(ctx.Response.Body(), int64(v), 10)
case int16:
b = strconv.AppendInt(ctx.Response.Body(), int64(v), 10)
case int32:
b = strconv.AppendInt(ctx.Response.Body(), int64(v), 10)
case int64:
b = strconv.AppendInt(ctx.Response.Body(), v, 10)
case uint:
b = strconv.AppendUint(ctx.Response.Body(), uint64(v), 10)
case uint8:
b = strconv.AppendUint(ctx.Response.Body(), uint64(v), 10)
case uint16:
b = strconv.AppendUint(ctx.Response.Body(), uint64(v), 10)
case uint32:
b = strconv.AppendUint(ctx.Response.Body(), uint64(v), 10)
case uint64:
b = strconv.AppendUint(ctx.Response.Body(), v, 10)
case float32:
b = strconv.AppendFloat(ctx.Response.Body(), float64(v), 'f', -1, 32)
case float64:
b = strconv.AppendFloat(ctx.Response.Body(), v, 'f', -1, 64)
case bool:
b = strconv.AppendBool(ctx.Response.Body(), v)
case error:
b = byteconv.Atob(v.Error())
case encoding.TextMarshaler:
b, err = v.MarshalText()
case fmt.Stringer:
b = append(ctx.Response.Body(), v.String()...)
default:
return don.ErrNotAcceptable
}
}

if len(b) > 0 {
ctx.Response.SetBodyRaw(append(b, '\n'))
}

return err
}

func init() {
don.RegisterEncoder("text/plain", encode)
}
62 changes: 62 additions & 0 deletions encoding/text/encode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package text_test

import (
"errors"
"testing"

"github.com/abemedia/go-don/encoding/text"
"github.com/abemedia/go-don/pkg/httptest"
"github.com/google/go-cmp/cmp"
"github.com/valyala/fasthttp"
)

func TestEncode(t *testing.T) {
tests := []struct {
in any
want string
}{
{"test", "test\n"},
{[]byte("test"), "test\n"},
{int(5), "5\n"},
{int8(5), "5\n"},
{int16(5), "5\n"},
{int32(5), "5\n"},
{int64(5), "5\n"},
{uint(5), "5\n"},
{uint8(5), "5\n"},
{uint16(5), "5\n"},
{uint32(5), "5\n"},
{uint64(5), "5\n"},
{float32(5.1), "5.1\n"},
{float64(5.1), "5.1\n"},
{true, "true\n"},
{errors.New("test"), "test\n"},
{marshaler{}, "test\n"},
{&marshaler{}, "test\n"},
{stringer{}, "test\n"},
{&stringer{}, "test\n"},
}

for _, test := range tests {
ctx := httptest.NewRequest(fasthttp.MethodGet, "/", "", nil)
if err := text.Encode(ctx, test.in); err != nil {
t.Error(err)
continue
}
if diff := cmp.Diff(test.want, string(ctx.Response.Body())); diff != "" {
t.Errorf("%T: %s", test.in, diff)
}
}
}

type marshaler struct{}

func (m marshaler) MarshalText() ([]byte, error) {
return []byte("test"), nil
}

type stringer struct{}

func (m stringer) String() string {
return "test"
}
6 changes: 6 additions & 0 deletions encoding/text/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package text

var (
Encode = encode
Decode = decode
)
Loading

0 comments on commit 83a2875

Please sign in to comment.