Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Verif fixes, simplify void #12

Merged
merged 2 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 86 additions & 82 deletions encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,101 @@ import (
"google.golang.org/protobuf/proto"
)

// Void is a placeholder to signify 'no value' where a type is otherwise needed
type Void struct{}

var (
// BinaryCodec marshals []byte and unmarshals into *[]byte
// In handlers, it uses a content type of application/octet-stream
BinaryCodec PayloadCodec = binaryCodec{}
VoidCodec PayloadCodec = voidCodec{}
ProtoCodec PayloadCodec = protoCodec{}
JSONCodec PayloadCodec = jsonCodec{}
_ PayloadCodec = PairCodec{}
// VoidCodec marshals anything into []byte(nil) and skips unmarshaling
// In handlers, it requires that there is no input content-type and does not set an output content-type
VoidCodec PayloadCodec = voidCodec{}
// ProtoCodec marshals proto.Message and unmarshals into proto.Message or pointers to types that implement proto.Message
// In handlers, it uses a content-type of application/proto
ProtoCodec PayloadCodec = protoCodec{}
// JSONCodec marshals any json.Marshallable type and unmarshals into any json.Unmarshallable type
// In handlers, it uses a content-type of application/json
JSONCodec PayloadCodec = jsonCodec{}
_ RestateMarshaler = Void{}
_ RestateUnmarshaler = Void{}
_ RestateUnmarshaler = &Void{}
)

type Void struct{}
func (v Void) RestateUnmarshal(codec Codec, data []byte) error {
return nil
}

func (v Void) InputPayload(codec Codec) *InputPayload {
return &InputPayload{}
}

func (v Void) RestateMarshal(codec Codec) ([]byte, error) {
return nil, nil
}

func (v Void) OutputPayload(codec Codec) *OutputPayload {
return &OutputPayload{}
}

type RestateUnmarshaler interface {
RestateUnmarshal(codec Codec, data []byte) error
InputPayload(codec Codec) *InputPayload
}

func InputPayloadFor(codec PayloadCodec, i any) *InputPayload {
ru, ok := i.(RestateUnmarshaler)
if ok {
return ru.InputPayload(codec)
}
return codec.InputPayload()
}

func OutputPayloadFor(codec PayloadCodec, o any) *OutputPayload {
ru, ok := o.(RestateMarshaler)
if ok {
return ru.OutputPayload(codec)
}
return codec.OutputPayload()
}

func RestateMarshalerFor[O any]() (RestateMarshaler, bool) {
var o O
ru, ok := any(o).(RestateMarshaler)
return ru, ok
}

// RestateMarshaler can be implemented by types that want to control their own marshaling
type RestateMarshaler interface {
RestateMarshal(codec Codec) ([]byte, error)
OutputPayload(codec Codec) *OutputPayload
}

type Codec interface {
Marshal(v any) ([]byte, error)
Unmarshal(data []byte, v any) error
}

func Marshal(codec Codec, v any) ([]byte, error) {
if marshaler, ok := v.(RestateMarshaler); ok {
return marshaler.RestateMarshal(codec)
}
return codec.Marshal(v)
}

func Unmarshal(codec Codec, data []byte, v any) error {
if marshaler, ok := v.(RestateUnmarshaler); ok {
return marshaler.RestateUnmarshal(codec, data)
}
return codec.Unmarshal(data, v)
}

type PayloadCodec interface {
Codec
InputPayload() *InputPayload
OutputPayload() *OutputPayload
Codec
}

type InputPayload struct {
Required bool `json:"required"`
ContentType *string `json:"contentType,omitempty"`
Expand Down Expand Up @@ -58,87 +133,16 @@ func (j voidCodec) Marshal(output any) ([]byte, error) {
return nil, nil
}

type PairCodec struct {
Input PayloadCodec
Output PayloadCodec
}

func (w PairCodec) InputPayload() *InputPayload {
return w.Input.InputPayload()
}

func (w PairCodec) OutputPayload() *OutputPayload {
return w.Output.OutputPayload()
}

func (w PairCodec) Unmarshal(data []byte, v any) error {
return w.Input.Unmarshal(data, v)
}

func (w PairCodec) Marshal(v any) ([]byte, error) {
return w.Output.Marshal(v)
}

func MergeCodec(base, overlay PayloadCodec) PayloadCodec {
switch {
case base == nil && overlay == nil:
return nil
case base == nil:
return overlay
case overlay == nil:
return base
}

basePair, baseOk := base.(PairCodec)
overlayPair, overlayOk := overlay.(PairCodec)

switch {
case baseOk && overlayOk:
return PairCodec{
Input: MergeCodec(basePair.Input, overlayPair.Input),
Output: MergeCodec(basePair.Output, overlayPair.Output),
}
case baseOk:
return PairCodec{
Input: MergeCodec(basePair.Input, overlay),
Output: MergeCodec(basePair.Output, overlay),
}
case overlayOk:
return PairCodec{
Input: MergeCodec(base, overlayPair.Input),
Output: MergeCodec(base, overlayPair.Output),
}
default:
// just two non-pairs; keep base
return base
}
}

func PartialVoidCodec[I any, O any]() PayloadCodec {
var input I
var output O
_, inputVoid := any(input).(Void)
_, outputVoid := any(output).(Void)
switch {
case inputVoid && outputVoid:
return VoidCodec
case inputVoid:
return PairCodec{Input: VoidCodec, Output: nil}
case outputVoid:
return PairCodec{Input: nil, Output: VoidCodec}
default:
return nil
}
}

type binaryCodec struct{}

func (j binaryCodec) InputPayload() *InputPayload {
return &InputPayload{Required: true, ContentType: proto.String("application/octet-stream")}
// Required false because 0 bytes is a valid input
return &InputPayload{Required: false, ContentType: proto.String("application/octet-stream")}
}

func (j binaryCodec) OutputPayload() *OutputPayload {
return &OutputPayload{ContentType: proto.String("application/octet-stream")}
// SetContentTypeIfEmpty true because 0 bytes is a valid output
return &OutputPayload{ContentType: proto.String("application/octet-stream"), SetContentTypeIfEmpty: true}
}

func (j binaryCodec) Unmarshal(data []byte, input any) (err error) {
Expand Down Expand Up @@ -171,7 +175,7 @@ func (j jsonCodec) OutputPayload() *OutputPayload {
}

func (j jsonCodec) Unmarshal(data []byte, input any) (err error) {
return json.Unmarshal(data, &input)
return json.Unmarshal(data, input)
}

func (j jsonCodec) Marshal(output any) ([]byte, error) {
Expand Down
39 changes: 33 additions & 6 deletions encoding/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,40 +34,67 @@ func checkMessage(t *testing.T, msg *protocol.AwakeableEntryMessage) {
func TestProto(t *testing.T) {
p := ProtoCodec

_, err := p.Marshal(protocol.AwakeableEntryMessage{Name: "foobar"})
_, err := Marshal(p, protocol.AwakeableEntryMessage{Name: "foobar"})
if err == nil {
t.Fatalf("expected error when marshaling non-pointer proto Message")
}

bytes, err := p.Marshal(&protocol.AwakeableEntryMessage{Name: "foobar"})
bytes, err := Marshal(p, &protocol.AwakeableEntryMessage{Name: "foobar"})
if err != nil {
t.Fatal(err)
}

{
msg := &protocol.AwakeableEntryMessage{}
willSucceed(t, p.Unmarshal(bytes, msg))
willSucceed(t, Unmarshal(p, bytes, msg))
checkMessage(t, msg)
}

{
inner := &protocol.AwakeableEntryMessage{}
msg := &inner
willSucceed(t, p.Unmarshal(bytes, msg))
willSucceed(t, Unmarshal(p, bytes, msg))
checkMessage(t, *msg)
}

{
msg := new(*protocol.AwakeableEntryMessage)
willSucceed(t, p.Unmarshal(bytes, msg))
willSucceed(t, Unmarshal(p, bytes, msg))
checkMessage(t, *msg)
}

{
var msg *protocol.AwakeableEntryMessage
willPanic(t, func() {
p.Unmarshal(bytes, msg)
Unmarshal(p, bytes, msg)
})
}
}

func TestVoid(t *testing.T) {
codecs := map[string]Codec{
"json": JSONCodec,
"proto": ProtoCodec,
"binary": BinaryCodec,
}
for name, codec := range codecs {
t.Run(name, func(t *testing.T) {
bytes, err := Marshal(codec, Void{})
if err != nil {
t.Fatal(err)
}

if bytes != nil {
t.Fatalf("expected bytes to be nil, found %v", bytes)
}

if err := Unmarshal(codec, []byte{1, 2, 3}, &Void{}); err != nil {
t.Fatal(err)
}

if err := Unmarshal(codec, []byte{1, 2, 3}, Void{}); err != nil {
t.Fatal(err)
}
})
}
}
3 changes: 2 additions & 1 deletion example/utils.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"errors"
"fmt"
"math/big"

Expand All @@ -24,7 +25,7 @@ var bigCounter = restate.
}

bytes, err := restate.GetAs[[]byte](ctx, "counter", restate.WithBinary)
if err != nil && err != restate.ErrKeyNotFound {
if err != nil && !errors.Is(err, restate.ErrKeyNotFound) {
return "", err
}
newCount := big.NewInt(0).Add(big.NewInt(0).SetBytes(bytes), delta)
Expand Down
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module github.com/restatedev/sdk-go

go 1.21.0

toolchain go1.21.12

require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/google/uuid v1.6.0
Expand Down
Loading
Loading