diff --git a/protocol/transcode/core.go b/protocol/transcode/core.go index 11ddb6a604..993131f932 100644 --- a/protocol/transcode/core.go +++ b/protocol/transcode/core.go @@ -91,6 +91,36 @@ func Transcode(mpToJSON bool, base32Encoding, strictJSON bool, in io.Reader, out } } +func isSliceOfBytes(a interface{}) bool { + switch v := a.(type) { + case []interface{}: + for _, e := range v { + _, ok := e.([]byte) + if !ok { + return false + } + } + return len(v) > 0 // No need to treat empty slice specially + default: + return false + } +} + +func isSliceOfString(a interface{}) bool { + switch v := a.(type) { + case []interface{}: + for _, e := range v { + _, ok := e.(string) + if !ok { + return false + } + } + return len(v) > 0 // No need to treat empty slice specially + default: + return false + } +} + func toJSON(a interface{}, base32Encoding, strictJSON bool) interface{} { switch v := a.(type) { case map[interface{}]interface{}: @@ -100,16 +130,23 @@ func toJSON(a interface{}, base32Encoding, strictJSON bool) interface{} { // a []byte, base64-encode the entry and append // ":b64" to the key (or, if the base32Encoding flag // is set, base32-encode and append ":b32"). - ks, ok1 := k.(string) - eb, ok2 := e.([]byte) + ks, keyIsString := k.(string) + eb, entryIsBytes := e.([]byte) - if ok1 && ok2 { + switch { + case keyIsString && entryIsBytes: if base32Encoding { r[fmt.Sprintf("%s:b32", ks)] = base32.StdEncoding.EncodeToString(eb) } else { r[fmt.Sprintf("%s:b64", ks)] = base64.StdEncoding.EncodeToString(eb) } - } else { + case keyIsString && isSliceOfBytes(e): + if base32Encoding { + r[fmt.Sprintf("%s:b32", ks)] = toJSON(e, base32Encoding, strictJSON) + } else { + r[fmt.Sprintf("%s:b64", ks)] = toJSON(e, base32Encoding, strictJSON) + } + default: if strictJSON { k = fmt.Sprintf("%v", k) } @@ -133,6 +170,28 @@ func toJSON(a interface{}, base32Encoding, strictJSON bool) interface{} { } } +func decodeSliceOfString(a interface{}, decodeFunc func(string) ([]byte, error)) ([][]byte, error) { + v, ok := a.([]interface{}) + if !ok { + return nil, fmt.Errorf("expected []interface{} for decodeSliceOfString") + } + + var all [][]byte + for _, e := range v { + es, entryIsString := e.(string) + if !entryIsString { + return nil, fmt.Errorf("expected string element in slice") + } + decoded, err := decodeFunc(es) + if err != nil { + return nil, err + } + all = append(all, decoded) + } + + return all, nil +} + func fromJSON(a interface{}) interface{} { switch v := a.(type) { case map[interface{}]interface{}: @@ -142,24 +201,35 @@ func fromJSON(a interface{}) interface{} { // ":b64", and entry is a string, then base64-decode // the entry and drop the ":b64" from the key. // Same for ":b32" and base32-decoding. - ks, ok1 := k.(string) - es, ok2 := e.(string) + ks, keyIsString := k.(string) + es, entryIsString := e.(string) - if ok1 && ok2 && strings.HasSuffix(ks, ":b64") { + switch { + case keyIsString && strings.HasSuffix(ks, ":b64") && entryIsString: eb, err := base64.StdEncoding.DecodeString(es) if err != nil { panic(err) } - r[ks[:len(ks)-4]] = eb - } else if ok1 && ok2 && strings.HasSuffix(ks, ":b32") { + case keyIsString && strings.HasSuffix(ks, ":b32") && entryIsString: eb, err := base32.StdEncoding.DecodeString(es) if err != nil { panic(err) } - r[ks[:len(ks)-4]] = eb - } else { + case keyIsString && strings.HasSuffix(ks, ":b64") && isSliceOfString(e): + eb, err := decodeSliceOfString(e, base64.StdEncoding.DecodeString) + if err != nil { + panic(err) + } + r[ks[:len(ks)-4]] = eb + case keyIsString && strings.HasSuffix(ks, ":b32") && isSliceOfString(e): + eb, err := decodeSliceOfString(e, base32.StdEncoding.DecodeString) + if err != nil { + panic(err) + } + r[ks[:len(ks)-4]] = eb + default: r[fromJSON(k)] = fromJSON(e) } } @@ -167,29 +237,40 @@ func fromJSON(a interface{}) interface{} { case map[string]interface{}: r := make(map[string]interface{}) - for k, e := range v { + for ks, e := range v { // Special case: if key ends in ":b64", and entry // is a string, then base64-decode the entry and // drop the ":b64" from the key. Same for ":b32" // and base32-decoding. - es, ok := e.(string) + es, entryIsString := e.(string) - if ok && strings.HasSuffix(k, ":b64") { + switch { + case strings.HasSuffix(ks, ":b64") && entryIsString: eb, err := base64.StdEncoding.DecodeString(es) if err != nil { panic(err) } - - r[k[:len(k)-4]] = eb - } else if ok && strings.HasSuffix(k, ":b32") { + r[ks[:len(ks)-4]] = eb + case strings.HasSuffix(ks, ":b32") && entryIsString: eb, err := base32.StdEncoding.DecodeString(es) if err != nil { panic(err) } - - r[k[:len(k)-4]] = eb - } else { - r[k] = fromJSON(e) + r[ks[:len(ks)-4]] = eb + case strings.HasSuffix(ks, ":b64") && isSliceOfString(e): + eb, err := decodeSliceOfString(e, base64.StdEncoding.DecodeString) + if err != nil { + panic(err) + } + r[ks[:len(ks)-4]] = eb + case strings.HasSuffix(ks, ":b32") && isSliceOfString(e): + eb, err := decodeSliceOfString(e, base32.StdEncoding.DecodeString) + if err != nil { + panic(err) + } + r[ks[:len(ks)-4]] = eb + default: + r[ks] = fromJSON(e) } } return r diff --git a/protocol/transcode/core_test.go b/protocol/transcode/core_test.go index 8dec7b1ba5..72ceee9c6f 100644 --- a/protocol/transcode/core_test.go +++ b/protocol/transcode/core_test.go @@ -18,6 +18,7 @@ package transcode import ( "encoding/base32" + "encoding/base64" "fmt" "io" "testing" @@ -58,7 +59,8 @@ func testIdempotentRoundtrip(t *testing.T, mpdata []byte) { res, err := io.ReadAll(p3out) require.NoError(t, err) - require.Equal(t, mpdata, res) + require.Equal(t, mpdata, res, + "%v != %v", base64.StdEncoding.EncodeToString(mpdata), base64.StdEncoding.EncodeToString(res)) } type objectType int @@ -117,27 +119,39 @@ func randomObjectOfType(randtype uint64, width int, depth int) interface{} { return base32.StdEncoding.EncodeToString(buf[:]) case objectArray: var arr [2]interface{} - for i := 0; i < len(arr); i++ { + if crypto.RandUint64()%2 == 0 { // half the time, make the slice a uniform type t := crypto.RandUint64() - if t%uint64(objectTypeMax) == uint64(objectBytes) { - // We cannot cleanly pass through an array of - // binary blobs. - t++ + for i := range arr { + arr[i] = randomObjectOfType(t, width, depth-1) + } + } else { + for i := range arr { + t := crypto.RandUint64() + if t%uint64(objectTypeMax) == uint64(objectBytes) { + // We cannot cleanly handle binary blobs unless the entire array is. + t++ + } + arr[i] = randomObjectOfType(t, width, depth-1) } - arr[i] = randomObjectOfType(t, width, depth-1) } return arr case objectSlice: slice := make([]interface{}, 0) sz := crypto.RandUint64() % uint64(width) - for i := 0; i < int(sz); i++ { + if crypto.RandUint64()%2 == 0 { // half the time, make the slice a uniform type t := crypto.RandUint64() - if t%uint64(objectTypeMax) == uint64(objectBytes) { - // We cannot cleanly pass through an array of - // binary blobs. - t++ + for range sz { + slice = append(slice, randomObjectOfType(t, width, depth-1)) + } + } else { + for range sz { + t := crypto.RandUint64() + if t%uint64(objectTypeMax) == uint64(objectBytes) { + // We cannot cleanly handle binary blobs unless the entire slice is. + t++ + } + slice = append(slice, randomObjectOfType(t, width, depth-1)) } - slice = append(slice, randomObjectOfType(t, width, depth-1)) } return slice case objectMap: @@ -172,7 +186,7 @@ func TestIdempotence(t *testing.T) { } for i := 0; i < niter; i++ { - o := randomMap(6, 3) + o := randomMap(i%7, i%3) testIdempotentRoundtrip(t, protocol.EncodeReflect(o)) } } @@ -189,7 +203,7 @@ func TestIdempotenceMultiobject(t *testing.T) { nobj := crypto.RandUint64() % 8 buf := []byte{} for j := 0; j < int(nobj); j++ { - buf = append(buf, protocol.EncodeReflect(randomMap(6, 3))...) + buf = append(buf, protocol.EncodeReflect(randomMap(i%7, i%3))...) } testIdempotentRoundtrip(t, buf) }