diff --git a/encoding/proto/proto.go b/encoding/proto/proto.go index 7cc1f5211b98..6fff15a715e8 100644 --- a/encoding/proto/proto.go +++ b/encoding/proto/proto.go @@ -39,13 +39,8 @@ func init() { type codec struct{} func (codec) Marshal(v any) ([]byte, error) { - var vv proto.Message - switch v := v.(type) { - case protoadapt.MessageV1: - vv = protoadapt.MessageV2Of(v) - case protoadapt.MessageV2: - vv = v - default: + vv := messageV2Of(v) + if vv == nil { return nil, fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) } @@ -53,17 +48,23 @@ func (codec) Marshal(v any) ([]byte, error) { } func (codec) Unmarshal(data []byte, v any) error { - var vv proto.Message + vv := messageV2Of(v) + if vv == nil { + return fmt.Errorf("failed to marshal, message is %T, want proto.Message", v) + } + + return proto.Unmarshal(data, vv) +} + +func messageV2Of(v any) proto.Message { switch v := v.(type) { case protoadapt.MessageV1: - vv = protoadapt.MessageV2Of(v) + return protoadapt.MessageV2Of(v) case protoadapt.MessageV2: - vv = v - default: - return fmt.Errorf("failed to unmarshal, message is %T, want proto.Message", v) + return v } - return proto.Unmarshal(data, vv) + return nil } func (codec) Name() string {