Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 102 additions & 21 deletions protocol/transcode/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,36 @@ func Transcode(mpToJSON bool, base32Encoding, strictJSON bool, in io.Reader, out
}
}

func isSliceOfBytes(a interface{}) bool {
switch v := a.(type) {
case []interface{}:
for _, e := range v {
_, ok := e.([]byte)
if !ok {
return false
}
}
return len(v) > 0 // No need to treat empty slice specially
default:
return false
}
}

func isSliceOfString(a interface{}) bool {
switch v := a.(type) {
case []interface{}:
for _, e := range v {
_, ok := e.(string)
if !ok {
return false
}
}
return len(v) > 0 // No need to treat empty slice specially
default:
return false
}
}

func toJSON(a interface{}, base32Encoding, strictJSON bool) interface{} {
switch v := a.(type) {
case map[interface{}]interface{}:
Expand All @@ -100,16 +130,23 @@ func toJSON(a interface{}, base32Encoding, strictJSON bool) interface{} {
// a []byte, base64-encode the entry and append
// ":b64" to the key (or, if the base32Encoding flag
// is set, base32-encode and append ":b32").
ks, ok1 := k.(string)
eb, ok2 := e.([]byte)
ks, keyIsString := k.(string)
eb, entryIsBytes := e.([]byte)

if ok1 && ok2 {
switch {
case keyIsString && entryIsBytes:
if base32Encoding {
r[fmt.Sprintf("%s:b32", ks)] = base32.StdEncoding.EncodeToString(eb)
} else {
r[fmt.Sprintf("%s:b64", ks)] = base64.StdEncoding.EncodeToString(eb)
}
} else {
case keyIsString && isSliceOfBytes(e):
if base32Encoding {
r[fmt.Sprintf("%s:b32", ks)] = toJSON(e, base32Encoding, strictJSON)
} else {
r[fmt.Sprintf("%s:b64", ks)] = toJSON(e, base32Encoding, strictJSON)
}
default:
if strictJSON {
k = fmt.Sprintf("%v", k)
}
Expand All @@ -133,6 +170,28 @@ func toJSON(a interface{}, base32Encoding, strictJSON bool) interface{} {
}
}

func decodeSliceOfString(a interface{}, decodeFunc func(string) ([]byte, error)) ([][]byte, error) {
v, ok := a.([]interface{})
if !ok {
return nil, fmt.Errorf("expected []interface{} for decodeSliceOfString")
}

var all [][]byte
for _, e := range v {
es, entryIsString := e.(string)
if !entryIsString {
return nil, fmt.Errorf("expected string element in slice")
}
decoded, err := decodeFunc(es)
if err != nil {
return nil, err
}
all = append(all, decoded)
}

return all, nil
}

func fromJSON(a interface{}) interface{} {
switch v := a.(type) {
case map[interface{}]interface{}:
Expand All @@ -142,54 +201,76 @@ func fromJSON(a interface{}) interface{} {
// ":b64", and entry is a string, then base64-decode
// the entry and drop the ":b64" from the key.
// Same for ":b32" and base32-decoding.
ks, ok1 := k.(string)
es, ok2 := e.(string)
ks, keyIsString := k.(string)
es, entryIsString := e.(string)

if ok1 && ok2 && strings.HasSuffix(ks, ":b64") {
switch {
case keyIsString && strings.HasSuffix(ks, ":b64") && entryIsString:
eb, err := base64.StdEncoding.DecodeString(es)
if err != nil {
panic(err)
}

r[ks[:len(ks)-4]] = eb
} else if ok1 && ok2 && strings.HasSuffix(ks, ":b32") {
case keyIsString && strings.HasSuffix(ks, ":b32") && entryIsString:
eb, err := base32.StdEncoding.DecodeString(es)
if err != nil {
panic(err)
}

r[ks[:len(ks)-4]] = eb
} else {
case keyIsString && strings.HasSuffix(ks, ":b64") && isSliceOfString(e):
eb, err := decodeSliceOfString(e, base64.StdEncoding.DecodeString)
if err != nil {
panic(err)
}
r[ks[:len(ks)-4]] = eb
case keyIsString && strings.HasSuffix(ks, ":b32") && isSliceOfString(e):
eb, err := decodeSliceOfString(e, base32.StdEncoding.DecodeString)
if err != nil {
panic(err)
}
r[ks[:len(ks)-4]] = eb
default:
r[fromJSON(k)] = fromJSON(e)
}
}
return r

case map[string]interface{}:
r := make(map[string]interface{})
for k, e := range v {
for ks, e := range v {
// Special case: if key ends in ":b64", and entry
// is a string, then base64-decode the entry and
// drop the ":b64" from the key. Same for ":b32"
// and base32-decoding.
es, ok := e.(string)
es, entryIsString := e.(string)

if ok && strings.HasSuffix(k, ":b64") {
switch {
case strings.HasSuffix(ks, ":b64") && entryIsString:
eb, err := base64.StdEncoding.DecodeString(es)
if err != nil {
panic(err)
}

r[k[:len(k)-4]] = eb
} else if ok && strings.HasSuffix(k, ":b32") {
r[ks[:len(ks)-4]] = eb
case strings.HasSuffix(ks, ":b32") && entryIsString:
eb, err := base32.StdEncoding.DecodeString(es)
if err != nil {
panic(err)
}

r[k[:len(k)-4]] = eb
} else {
r[k] = fromJSON(e)
r[ks[:len(ks)-4]] = eb
case strings.HasSuffix(ks, ":b64") && isSliceOfString(e):
eb, err := decodeSliceOfString(e, base64.StdEncoding.DecodeString)
if err != nil {
panic(err)
}
r[ks[:len(ks)-4]] = eb
case strings.HasSuffix(ks, ":b32") && isSliceOfString(e):
eb, err := decodeSliceOfString(e, base32.StdEncoding.DecodeString)
if err != nil {
panic(err)
}
r[ks[:len(ks)-4]] = eb
default:
r[ks] = fromJSON(e)
}
}
return r
Expand Down
44 changes: 29 additions & 15 deletions protocol/transcode/core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package transcode

import (
"encoding/base32"
"encoding/base64"
"fmt"
"io"
"testing"
Expand Down Expand Up @@ -58,7 +59,8 @@ func testIdempotentRoundtrip(t *testing.T, mpdata []byte) {
res, err := io.ReadAll(p3out)

require.NoError(t, err)
require.Equal(t, mpdata, res)
require.Equal(t, mpdata, res,
"%v != %v", base64.StdEncoding.EncodeToString(mpdata), base64.StdEncoding.EncodeToString(res))
}

type objectType int
Expand Down Expand Up @@ -117,27 +119,39 @@ func randomObjectOfType(randtype uint64, width int, depth int) interface{} {
return base32.StdEncoding.EncodeToString(buf[:])
case objectArray:
var arr [2]interface{}
for i := 0; i < len(arr); i++ {
if crypto.RandUint64()%2 == 0 { // half the time, make the slice a uniform type
t := crypto.RandUint64()
if t%uint64(objectTypeMax) == uint64(objectBytes) {
// We cannot cleanly pass through an array of
// binary blobs.
t++
for i := range arr {
arr[i] = randomObjectOfType(t, width, depth-1)
}
} else {
for i := range arr {
t := crypto.RandUint64()
if t%uint64(objectTypeMax) == uint64(objectBytes) {
// We cannot cleanly handle binary blobs unless the entire array is.
t++
}
arr[i] = randomObjectOfType(t, width, depth-1)
}
arr[i] = randomObjectOfType(t, width, depth-1)
}
return arr
case objectSlice:
slice := make([]interface{}, 0)
sz := crypto.RandUint64() % uint64(width)
for i := 0; i < int(sz); i++ {
if crypto.RandUint64()%2 == 0 { // half the time, make the slice a uniform type
t := crypto.RandUint64()
if t%uint64(objectTypeMax) == uint64(objectBytes) {
// We cannot cleanly pass through an array of
// binary blobs.
t++
for range sz {
slice = append(slice, randomObjectOfType(t, width, depth-1))
}
} else {
for range sz {
t := crypto.RandUint64()
if t%uint64(objectTypeMax) == uint64(objectBytes) {
// We cannot cleanly handle binary blobs unless the entire slice is.
t++
}
slice = append(slice, randomObjectOfType(t, width, depth-1))
}
slice = append(slice, randomObjectOfType(t, width, depth-1))
}
return slice
case objectMap:
Expand Down Expand Up @@ -172,7 +186,7 @@ func TestIdempotence(t *testing.T) {
}

for i := 0; i < niter; i++ {
o := randomMap(6, 3)
o := randomMap(i%7, i%3)
testIdempotentRoundtrip(t, protocol.EncodeReflect(o))
}
}
Expand All @@ -189,7 +203,7 @@ func TestIdempotenceMultiobject(t *testing.T) {
nobj := crypto.RandUint64() % 8
buf := []byte{}
for j := 0; j < int(nobj); j++ {
buf = append(buf, protocol.EncodeReflect(randomMap(6, 3))...)
buf = append(buf, protocol.EncodeReflect(randomMap(i%7, i%3))...)
}
testIdempotentRoundtrip(t, buf)
}
Expand Down
Loading