diff --git a/types.go b/types.go index 78817b7..4665f23 100644 --- a/types.go +++ b/types.go @@ -104,21 +104,27 @@ func Register(v interface{}, args ...string) { } // TypeURL returns the type url for a registered type. -func TypeURL(v interface{}) (string, error) { +func TypeURL(v interface{}) (u string, err error) { + u, _, err = typeURL(v) + return +} + +// typeURL returns the type url and whether it was registered for json +func typeURL(v interface{}) (string, bool, error) { mu.RLock() u, ok := registry[tryDereference(v)] mu.RUnlock() if !ok { switch t := v.(type) { case proto.Message: - return string(t.ProtoReflect().Descriptor().FullName()), nil + return string(t.ProtoReflect().Descriptor().FullName()), ok, nil case gogoproto.Message: - return gogoproto.MessageName(t), nil + return gogoproto.MessageName(t), ok, nil default: - return "", fmt.Errorf("type %s: %w", reflect.TypeOf(v), ErrNotFound) + return "", ok, fmt.Errorf("type %s: %w", reflect.TypeOf(v), ErrNotFound) } } - return u, nil + return u, ok, nil } // Is returns true if the type of the Any is the same as v. @@ -140,36 +146,35 @@ func Is(any Any, v interface{}) bool { // returned verbatim. If it is of type proto.Message, it will be marshaled as a // protocol buffer. Otherwise, the object will be marshaled to json. func MarshalAny(v interface{}) (Any, error) { - var marshal func(v interface{}) ([]byte, error) - switch t := v.(type) { - case Any: - // avoid reserializing the type if we have an any. + if t, ok := v.(Any); ok { return t, nil - case proto.Message: - marshal = func(v interface{}) ([]byte, error) { - return proto.Marshal(t) - } - case gogoproto.Message: - marshal = func(v interface{}) ([]byte, error) { - return gogoproto.Marshal(t) - } - default: - marshal = json.Marshal } - url, err := TypeURL(v) + url, isJSON, err := typeURL(v) if err != nil { return nil, err } - - data, err := marshal(v) + a := anyType{ + typeURL: url, + } + if !isJSON { + switch t := v.(type) { + case proto.Message: + a.value, err = proto.Marshal(t) + case gogoproto.Message: + a.value, err = gogoproto.Marshal(t) + default: + isJSON = true + } + } + if isJSON { + a.value, err = json.Marshal(v) + } if err != nil { return nil, err } - return &anyType{ - typeURL: url, - value: data, - }, nil + + return &a, nil } // UnmarshalAny unmarshals the any type into a concrete type. diff --git a/types_test.go b/types_test.go index f6c001a..17a5d49 100644 --- a/types_test.go +++ b/types_test.go @@ -232,3 +232,29 @@ func TestProtoFallback(t *testing.T) { t.Fatalf("expected %+v but got %+v", expected, ts.AsTime()) } } + +func TestManualProtoRegistration(t *testing.T) { + ts := ×tamppb.Timestamp{} + defer func() { + mu.Lock() + delete(registry, tryDereference(ts)) + mu.Unlock() + }() + Register(ts, t.Name()) + + expected := time.Now() + a, err := MarshalAny(timestamppb.New(expected)) + if err != nil { + t.Fatal(err) + } + + i, err := UnmarshalAny(a) + if err != nil { + t.Fatal(err) + } + + actual := i.(*timestamppb.Timestamp).AsTime() + if !actual.Equal(expected) { + t.Fatalf("unexpected time %s, expected %s", actual, expected) + } +}