Skip to content

Commit

Permalink
fix: unmarshal to interface is not work
Browse files Browse the repository at this point in the history
  • Loading branch information
bytemain committed Mar 6, 2024
1 parent a9f4f03 commit 882bb9a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 79 deletions.
25 changes: 9 additions & 16 deletions internal/luai/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@ import (

// indirect walks down v allocating pointers as needed,
// until it gets to a non-pointer.
// If it encounters an Unmarshaler, indirect stops and returns that.
// If decodingNull is true, indirect stops at the first settable pointer so it
// can be set to nil.
func indirect(v reflect.Value, decodingNull bool) reflect.Value {
// Issue #24153 indicates that it is generally not a guaranteed property
func indirect(v reflect.Value) reflect.Value {
// Issue https://github.com/golang/go/issues/24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
// unexported embedded struct fields.
Expand All @@ -45,7 +42,7 @@ func indirect(v reflect.Value, decodingNull bool) reflect.Value {
// usefully addressable.
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Pointer && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Pointer) {
if e.Kind() == reflect.Pointer && !e.IsNil() && (e.Elem().Kind() == reflect.Pointer) {
haveAddr = false
v = e
continue
Expand All @@ -56,10 +53,6 @@ func indirect(v reflect.Value, decodingNull bool) reflect.Value {
break
}

if decodingNull && v.CanSet() {
break
}

// Prevent infinite loop if v is an interface pointing to its own address:
// var v interface{}
// v = &v
Expand All @@ -82,7 +75,7 @@ func indirect(v reflect.Value, decodingNull bool) reflect.Value {
}

func storeLiteral(value reflect.Value, lvalue lua.LValue) {
value = indirect(value, false)
value = indirect(value)
switch lvalue.Type() {
case lua.LTString:
value.SetString(lvalue.String())
Expand All @@ -104,7 +97,8 @@ func objectInterface(lvalue *lua.LTable) any {
func valueInterface(lvalue lua.LValue) any {
switch lvalue.Type() {
case lua.LTTable:
isArray := lvalue.(*lua.LTable).RawGetInt(0) != lua.LNil
isArray := lvalue.(*lua.LTable).RawGetInt(1) != lua.LNil
logger.Infof("isArray: %v\n", isArray)
if isArray {
return arrayInterface(lvalue.(*lua.LTable))
}
Expand All @@ -117,7 +111,6 @@ func valueInterface(lvalue lua.LValue) any {
return bool(lvalue.(lua.LBool))
}
return nil

}

func arrayInterface(lvalue *lua.LTable) any {
Expand All @@ -134,7 +127,7 @@ func unmarshalWorker(value lua.LValue, reflected reflect.Value) error {

switch value.Type() {
case lua.LTTable:
reflected = indirect(reflected, false)
reflected = indirect(reflected)
tagMap := make(map[string]int)

switch reflected.Kind() {
Expand All @@ -154,7 +147,7 @@ func unmarshalWorker(value lua.LValue, reflected reflect.Value) error {
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
default:
return errors.New("luai: unsupported map key type " + keyType.String())
return errors.New("unmarshal: unsupported map key type " + keyType.String())
}

if reflected.IsNil() {
Expand Down Expand Up @@ -202,7 +195,7 @@ func unmarshalWorker(value lua.LValue, reflected reflect.Value) error {
kv = reflect.New(keyType).Elem()
kv.SetUint(n)
default:
panic("luai: Unexpected key type") // should never occur
panic("unmarshal: Unexpected key type") // should never occur
}
if kv.IsValid() {
reflected.SetMapIndex(kv, subv)
Expand Down
91 changes: 28 additions & 63 deletions internal/luai/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,31 +57,14 @@ func TestRegular(t *testing.T) {

table := _table.(*lua.LTable)

field1 := table.RawGetString("Field1")
if field1.Type() != lua.LTString {
t.Errorf("expected string, got %s", field1.Type())
}

if field1.String() != "test" {
t.Errorf("expected 'test', got '%s'", field1.String())
}

field2 := table.RawGetString("Field2")
if field2.Type() != lua.LTNumber {
t.Errorf("expected number, got %s", field2.Type())
}

if field2.String() != "1" {
t.Errorf("expected '1', got '%s'", field2.String())
}

field3 := table.RawGetString("Field3")
if field3.Type() != lua.LTBool {
t.Errorf("expected bool, got %s", field3.Type())
}
luaVm.SetGlobal("table", table)

if field3.String() != "true" {
t.Errorf("expected 'true', got '%s'", field3.String())
if err := luaVm.DoString(`
assert(table.Field1 == "test")
assert(table.Field2 == 1)
assert(table.Field3 == true)
`); err != nil {
t.Fatal(err)
}

struct2 := testStruct{}
Expand Down Expand Up @@ -118,31 +101,13 @@ func TestTag(t *testing.T) {

table := _table.(*lua.LTable)

field1 := table.RawGetString("field1")
if field1.Type() != lua.LTString {
t.Errorf("expected string, got %s", field1.Type())
}

if field1.String() != "test" {
t.Errorf("expected 'test', got '%s'", field1.String())
}

field2 := table.RawGetString("field2")
if field2.Type() != lua.LTNumber {
t.Errorf("expected number, got %s", field2.Type())
}

if field2.String() != "1" {
t.Errorf("expected '1', got '%s'", field2.String())
}

field3 := table.RawGetString("field3")
if field3.Type() != lua.LTBool {
t.Errorf("expected bool, got %s", field3.Type())
}

if field3.String() != "true" {
t.Errorf("expected 'true', got '%s'", field3.String())
luaVm.SetGlobal("table", table)
if err := luaVm.DoString(`
assert(table.field1 == "test")
assert(table.field2 == 1)
assert(table.field3 == true)
`); err != nil {
t.Fatal(err)
}

struct2 := testStructTag{}
Expand Down Expand Up @@ -211,14 +176,9 @@ func TestMapAndSlice(t *testing.T) {
}

fmt.Printf("m2: %+v\n", m2)
if m2["key1"] != "value1" {
t.Errorf("expected value1, got %v", m2["key1"])
}
if m2["key2"] != 2 {
t.Errorf("expected 2, got %v", m2["key2"])
}
if m2["key3"] != true {
t.Errorf("expected true, got %v", m2["key3"])

if !reflect.DeepEqual(m, m2) {
t.Errorf("expected %+v, got %+v", m, m2)
}

// Test case for slice
Expand All @@ -231,15 +191,20 @@ func TestMapAndSlice(t *testing.T) {

fmt.Printf("s2: %+v\n", s2)

if s2[0] != "value1" {
t.Errorf("expected value1, got %v", s2[0])
if !reflect.DeepEqual(s, s2) {
t.Errorf("expected %+v, got %+v", s, s2)
}
if s2[1] != 2 {
t.Errorf("expected 2, got %v", s2[1])

var s3 any
err = Unmarshal(slice, &s3)
if err != nil {
t.Fatalf("unmarshal slice failed: %v", err)
}
if s2[2] != true {
t.Errorf("expected true, got %v", s2[2])

if !reflect.DeepEqual(s, s3) {
t.Errorf("expected %+v, got %+v", s, s3)
}

}

type complexStruct struct {
Expand Down

0 comments on commit 882bb9a

Please sign in to comment.