diff --git a/internal/luai/decode.go b/internal/luai/decode.go index 7d9cf657..d9d4b763 100644 --- a/internal/luai/decode.go +++ b/internal/luai/decode.go @@ -15,11 +15,8 @@ import ( // indirect walks down v allocating pointers as needed, // until it gets to a non-pointer. -// If it encounters an Unmarshaler, indirect stops and returns that. -// If decodingNull is true, indirect stops at the first settable pointer so it -// can be set to nil. -func indirect(v reflect.Value, decodingNull bool) reflect.Value { - // Issue #24153 indicates that it is generally not a guaranteed property +func indirect(v reflect.Value) reflect.Value { + // Issue https://github.com/golang/go/issues/24153 indicates that it is generally not a guaranteed property // that you may round-trip a reflect.Value by calling Value.Addr().Elem() // and expect the value to still be settable for values derived from // unexported embedded struct fields. @@ -45,7 +42,7 @@ func indirect(v reflect.Value, decodingNull bool) reflect.Value { // usefully addressable. if v.Kind() == reflect.Interface && !v.IsNil() { e := v.Elem() - if e.Kind() == reflect.Pointer && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Pointer) { + if e.Kind() == reflect.Pointer && !e.IsNil() && (e.Elem().Kind() == reflect.Pointer) { haveAddr = false v = e continue @@ -56,10 +53,6 @@ func indirect(v reflect.Value, decodingNull bool) reflect.Value { break } - if decodingNull && v.CanSet() { - break - } - // Prevent infinite loop if v is an interface pointing to its own address: // var v interface{} // v = &v @@ -82,7 +75,7 @@ func indirect(v reflect.Value, decodingNull bool) reflect.Value { } func storeLiteral(value reflect.Value, lvalue lua.LValue) { - value = indirect(value, false) + value = indirect(value) switch lvalue.Type() { case lua.LTString: value.SetString(lvalue.String()) @@ -104,7 +97,8 @@ func objectInterface(lvalue *lua.LTable) any { func valueInterface(lvalue lua.LValue) any { switch lvalue.Type() { case lua.LTTable: - isArray := lvalue.(*lua.LTable).RawGetInt(0) != lua.LNil + isArray := lvalue.(*lua.LTable).RawGetInt(1) != lua.LNil + logger.Infof("isArray: %v\n", isArray) if isArray { return arrayInterface(lvalue.(*lua.LTable)) } @@ -117,7 +111,6 @@ func valueInterface(lvalue lua.LValue) any { return bool(lvalue.(lua.LBool)) } return nil - } func arrayInterface(lvalue *lua.LTable) any { @@ -134,7 +127,7 @@ func unmarshalWorker(value lua.LValue, reflected reflect.Value) error { switch value.Type() { case lua.LTTable: - reflected = indirect(reflected, false) + reflected = indirect(reflected) tagMap := make(map[string]int) switch reflected.Kind() { @@ -154,7 +147,7 @@ func unmarshalWorker(value lua.LValue, reflected reflect.Value) error { reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: default: - return errors.New("luai: unsupported map key type " + keyType.String()) + return errors.New("unmarshal: unsupported map key type " + keyType.String()) } if reflected.IsNil() { @@ -202,7 +195,7 @@ func unmarshalWorker(value lua.LValue, reflected reflect.Value) error { kv = reflect.New(keyType).Elem() kv.SetUint(n) default: - panic("luai: Unexpected key type") // should never occur + panic("unmarshal: Unexpected key type") // should never occur } if kv.IsValid() { reflected.SetMapIndex(kv, subv) diff --git a/internal/luai/encoding_test.go b/internal/luai/encoding_test.go index 10ec3bc4..520c2b9b 100644 --- a/internal/luai/encoding_test.go +++ b/internal/luai/encoding_test.go @@ -57,31 +57,14 @@ func TestRegular(t *testing.T) { table := _table.(*lua.LTable) - field1 := table.RawGetString("Field1") - if field1.Type() != lua.LTString { - t.Errorf("expected string, got %s", field1.Type()) - } - - if field1.String() != "test" { - t.Errorf("expected 'test', got '%s'", field1.String()) - } - - field2 := table.RawGetString("Field2") - if field2.Type() != lua.LTNumber { - t.Errorf("expected number, got %s", field2.Type()) - } - - if field2.String() != "1" { - t.Errorf("expected '1', got '%s'", field2.String()) - } - - field3 := table.RawGetString("Field3") - if field3.Type() != lua.LTBool { - t.Errorf("expected bool, got %s", field3.Type()) - } + luaVm.SetGlobal("table", table) - if field3.String() != "true" { - t.Errorf("expected 'true', got '%s'", field3.String()) + if err := luaVm.DoString(` + assert(table.Field1 == "test") + assert(table.Field2 == 1) + assert(table.Field3 == true) + `); err != nil { + t.Fatal(err) } struct2 := testStruct{} @@ -118,31 +101,13 @@ func TestTag(t *testing.T) { table := _table.(*lua.LTable) - field1 := table.RawGetString("field1") - if field1.Type() != lua.LTString { - t.Errorf("expected string, got %s", field1.Type()) - } - - if field1.String() != "test" { - t.Errorf("expected 'test', got '%s'", field1.String()) - } - - field2 := table.RawGetString("field2") - if field2.Type() != lua.LTNumber { - t.Errorf("expected number, got %s", field2.Type()) - } - - if field2.String() != "1" { - t.Errorf("expected '1', got '%s'", field2.String()) - } - - field3 := table.RawGetString("field3") - if field3.Type() != lua.LTBool { - t.Errorf("expected bool, got %s", field3.Type()) - } - - if field3.String() != "true" { - t.Errorf("expected 'true', got '%s'", field3.String()) + luaVm.SetGlobal("table", table) + if err := luaVm.DoString(` + assert(table.field1 == "test") + assert(table.field2 == 1) + assert(table.field3 == true) + `); err != nil { + t.Fatal(err) } struct2 := testStructTag{} @@ -211,14 +176,9 @@ func TestMapAndSlice(t *testing.T) { } fmt.Printf("m2: %+v\n", m2) - if m2["key1"] != "value1" { - t.Errorf("expected value1, got %v", m2["key1"]) - } - if m2["key2"] != 2 { - t.Errorf("expected 2, got %v", m2["key2"]) - } - if m2["key3"] != true { - t.Errorf("expected true, got %v", m2["key3"]) + + if !reflect.DeepEqual(m, m2) { + t.Errorf("expected %+v, got %+v", m, m2) } // Test case for slice @@ -231,15 +191,20 @@ func TestMapAndSlice(t *testing.T) { fmt.Printf("s2: %+v\n", s2) - if s2[0] != "value1" { - t.Errorf("expected value1, got %v", s2[0]) + if !reflect.DeepEqual(s, s2) { + t.Errorf("expected %+v, got %+v", s, s2) } - if s2[1] != 2 { - t.Errorf("expected 2, got %v", s2[1]) + + var s3 any + err = Unmarshal(slice, &s3) + if err != nil { + t.Fatalf("unmarshal slice failed: %v", err) } - if s2[2] != true { - t.Errorf("expected true, got %v", s2[2]) + + if !reflect.DeepEqual(s, s3) { + t.Errorf("expected %+v, got %+v", s, s3) } + } type complexStruct struct {