Skip to content

Commit

Permalink
fix(spanner): decode PROTO to custom type variant of base type (#11007)
Browse files Browse the repository at this point in the history
  • Loading branch information
harshachinta authored Oct 18, 2024
1 parent a273aab commit 5e363a3
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 4 deletions.
8 changes: 4 additions & 4 deletions spanner/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -2718,7 +2718,7 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb
result = &NullString{x, !isNull}
}
case spannerTypeByteArray:
if code != sppb.TypeCode_BYTES {
if code != sppb.TypeCode_BYTES && code != sppb.TypeCode_PROTO {
return errTypeMismatch(code, acode, ptr)
}
if isNull {
Expand All @@ -2735,7 +2735,7 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb
}
result = y
case spannerTypeNonNullInt64, spannerTypeNullInt64:
if code != sppb.TypeCode_INT64 {
if code != sppb.TypeCode_INT64 && code != sppb.TypeCode_ENUM {
return errTypeMismatch(code, acode, ptr)
}
if isNull {
Expand Down Expand Up @@ -2913,7 +2913,7 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb
}
result = y
case spannerTypeArrayOfByteArray:
if acode != sppb.TypeCode_BYTES {
if acode != sppb.TypeCode_BYTES && acode != sppb.TypeCode_PROTO {
return errTypeMismatch(code, acode, ptr)
}
if isNull {
Expand All @@ -2930,7 +2930,7 @@ func (dsc decodableSpannerType) decodeValueToCustomType(v *proto3.Value, t *sppb
}
result = y
case spannerTypeArrayOfNonNullInt64, spannerTypeArrayOfNullInt64:
if acode != sppb.TypeCode_INT64 {
if acode != sppb.TypeCode_INT64 && acode != sppb.TypeCode_ENUM {
return errTypeMismatch(code, acode, ptr)
}
if isNull {
Expand Down
72 changes: 72 additions & 0 deletions spanner/value_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package spanner

import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -3289,3 +3290,74 @@ func TestNullJson(t *testing.T) {
t.Fatalf("expected null, got %s", v)
}
}

// Test decode for PROTO type when custom type is a variant of a base type
func TestDecodeProtoUsingBaseVariant(t *testing.T) {
// nullBytes is custom type from []byte base type.
type nullBytes []byte

var b []byte
var nb nullBytes

gcv := &GenericColumnValue{
Type: &sppb.Type{
Code: sppb.TypeCode_PROTO,
ProtoTypeFqn: "examples.ProtoType",
},
Value: structpb.NewStringValue("Zm9vCg=="),
}
if err := gcv.Decode(&nb); err != nil {
t.Error(err)
}
if err := gcv.Decode(&b); err != nil {
t.Error(err)
}

// Convert []byte and nullBytes to base64 encoding and then compare the contents.
if !testutil.Equal(base64.StdEncoding.EncodeToString(b), base64.StdEncoding.EncodeToString(nb)) {
t.Errorf("%s: got %+v, want %+v", "Test PROTO decode to []byte custom type", nb, b)
}
}

// Test decode for PROTO type when custom type is a variant of a base type
func TestDecodeProtoArrayUsingBaseVariant(t *testing.T) {
// nullBytes is custom type from []byte base type.
type nullBytes [][]byte

var b [][]byte
var nb nullBytes

gcv := &GenericColumnValue{
Type: &sppb.Type{
Code: sppb.TypeCode_ARRAY,
ArrayElementType: &sppb.Type{
Code: sppb.TypeCode_PROTO,
ProtoTypeFqn: "examples.ProtoType",
},
},
Value: structpb.NewListValue(
&structpb.ListValue{
Values: []*structpb.Value{
structpb.NewStringValue("Zm9vCg=="),
},
}),
}
if err := gcv.Decode(&nb); err != nil {
t.Error(err)
}
if err := gcv.Decode(&b); err != nil {
t.Error(err)
}

if len(b) != 1 {
t.Errorf("Expected length to be 1")
}

if len(nb) != 1 {
t.Errorf("Expected length to be 1")
}
// Convert to base64 encoding and then compare the contents.
if !testutil.Equal(base64.StdEncoding.EncodeToString(b[0]), base64.StdEncoding.EncodeToString(nb[0])) {
t.Errorf("%s: got %+v, want %+v", "Test PROTO decode to [][]byte custom type", nb, b)
}
}

0 comments on commit 5e363a3

Please sign in to comment.