Skip to content

Commit

Permalink
Add protoc-gen-go-restate (#18)
Browse files Browse the repository at this point in the history
* Add protoc-gen-go-restate

* Split out the selector proto

* Rename handler and service type options

* Add to readme re options
  • Loading branch information
jackkleeman authored Aug 12, 2024
1 parent adae520 commit 8eb988a
Show file tree
Hide file tree
Showing 45 changed files with 2,476 additions and 1,282 deletions.
12 changes: 8 additions & 4 deletions buf.gen.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
version: v1
version: v2
managed:
enabled: true
go_package_prefix:
default: github.com/restatedev/sdk-go/generated
override:
- file_option: go_package_prefix
value: github.com/restatedev/sdk-go/generated
plugins:
- plugin: go
- remote: buf.build/protocolbuffers/go:v1.34.2
out: generated
opt: paths=source_relative
inputs:
- module: buf.build/restatedev/service-protocol
- directory: proto
8 changes: 1 addition & 7 deletions buf.lock
Original file line number Diff line number Diff line change
@@ -1,8 +1,2 @@
# Generated by buf. DO NOT EDIT.
version: v1
deps:
- remote: buf.build
owner: restatedev
repository: proto
commit: 6ea2d15aed8f408590a1465844df5a8e
digest: shake256:e6599809ff13490a631f87d1a4b13ef1886d1bd1c0aa001ccb92806c0acc373d047a6ead761f8a21dfbd57a4fd9acd5915a52e47bd5b4e4a02dd1766f78511b3
version: v2
7 changes: 6 additions & 1 deletion buf.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
version: v1
version: v2
modules:
- path: proto
name: buf.build/restatedev/sdk-go
excludes:
- proto/dev/restate/sdk/go
breaking:
use:
- FILE
Expand Down
5 changes: 0 additions & 5 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,6 @@ type CallClient interface {
RequestFuture(input any) (ResponseFuture, error)
// Request makes a call and blocks on getting the response which is stored in output
Request(input any, output any) error
SendClient
}

// SendClient allows for one-way invocations to a particular service/key/method tuple.
type SendClient interface {
// Send makes a one-way call which is executed in the background
Send(input any, delay time.Duration) error
}
Expand Down
79 changes: 63 additions & 16 deletions encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"reflect"

"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
)

Expand All @@ -20,6 +21,10 @@ var (
// ProtoCodec marshals proto.Message and unmarshals into proto.Message or pointers to types that implement proto.Message
// In handlers, it uses a content-type of application/proto
ProtoCodec PayloadCodec = protoCodec{}
// ProtoJSONCodec marshals proto.Message and unmarshals into proto.Message or pointers to types that implement proto.Message
// It uses the protojson package to marshal and unmarshal
// In handlers, it uses a content-type of application/json
ProtoJSONCodec PayloadCodec = protoJSONCodec{}
// JSONCodec marshals any json.Marshallable type and unmarshals into any json.Unmarshallable type
// In handlers, it uses a content-type of application/json
JSONCodec PayloadCodec = jsonCodec{}
Expand Down Expand Up @@ -188,23 +193,11 @@ func (p protoCodec) Unmarshal(data []byte, input any) (err error) {
// called with a *Message
return proto.Unmarshal(data, input)
default:
// we must support being called with a **Message where *Message is nil because this is the result of new(I) where I is a proto.Message
// and calling with new(I) is really the only generic approach.
value := reflect.ValueOf(input)
if value.Kind() != reflect.Pointer || value.IsNil() || value.Elem().Kind() != reflect.Pointer {
return fmt.Errorf("ProtoCodec.Unmarshal called with neither a proto.Message nor a non-nil pointer to a type that implements proto.Message.")
}
elem := value.Elem() // hopefully a *Message
if elem.IsNil() {
// allocate a &Message and swap this in
elem.Set(reflect.New(elem.Type().Elem()))
}
switch elemI := elem.Interface().(type) {
case proto.Message:
return proto.Unmarshal(data, elemI)
default:
return fmt.Errorf("ProtoCodec.Unmarshal called with neither a proto.Message nor a non-nil pointer to a type that implements proto.Message.")
msg, err := allocateProtoMessage("ProtoCodec", input)
if err != nil {
return err
}
return proto.Unmarshal(data, msg)
}
}

Expand All @@ -216,3 +209,57 @@ func (p protoCodec) Marshal(output any) (data []byte, err error) {
return nil, fmt.Errorf("ProtoCodec.Marshal called with a type that is not a proto.Message")
}
}

type protoJSONCodec struct{}

func (j protoJSONCodec) InputPayload(_ any) *InputPayload {
return &InputPayload{Required: true, ContentType: proto.String("application/json")}
}

func (j protoJSONCodec) OutputPayload(_ any) *OutputPayload {
return &OutputPayload{ContentType: proto.String("application/json")}
}

func (j protoJSONCodec) Unmarshal(data []byte, input any) (err error) {
switch input := input.(type) {
case proto.Message:
// called with a *Message
return protojson.Unmarshal(data, input)
default:
msg, err := allocateProtoMessage("ProtoJSONCodec", input)
if err != nil {
return err
}
return protojson.Unmarshal(data, msg)
}
}

func (j protoJSONCodec) Marshal(output any) ([]byte, error) {
switch output := output.(type) {
case proto.Message:
return protojson.Marshal(output)
default:
return nil, fmt.Errorf("ProtoJSONCodec.Marshal called with a type that is not a proto.Message")
}
}

// we must support being called with a **Message where *Message is nil because this is the result of new(I) where I is a proto.Message
// new(I) is really the only generic approach for allocating. Hitting this code path is meaningfully slower
// for protobuf decoding, but the effect is minimal for protojson
func allocateProtoMessage(codecName string, input any) (proto.Message, error) {
value := reflect.ValueOf(input)
if value.Kind() != reflect.Pointer || value.IsNil() || value.Elem().Kind() != reflect.Pointer {
return nil, fmt.Errorf("%s.Unmarshal called with neither a proto.Message nor a non-nil pointer to a type that implements proto.Message.", codecName)
}
elem := value.Elem() // hopefully a *Message
if elem.IsNil() {
// allocate a &Message and swap this in
elem.Set(reflect.New(elem.Type().Elem()))
}
switch elemI := elem.Interface().(type) {
case proto.Message:
return elemI, nil
default:
return nil, fmt.Errorf("%s.Unmarshal called with neither a proto.Message nor a non-nil pointer to a type that implements proto.Message.", codecName)
}
}
59 changes: 55 additions & 4 deletions encoding/encoding_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package encoding

import (
"encoding/base64"
"testing"

"github.com/restatedev/sdk-go/generated/proto/protocol"
protocol "github.com/restatedev/sdk-go/generated/dev/restate/service"
)

func willPanic(t *testing.T, do func()) {
Expand Down Expand Up @@ -73,9 +74,10 @@ func TestProto(t *testing.T) {

func TestVoid(t *testing.T) {
codecs := map[string]Codec{
"json": JSONCodec,
"proto": ProtoCodec,
"binary": BinaryCodec,
"json": JSONCodec,
"proto": ProtoCodec,
"protojson": ProtoJSONCodec,
"binary": BinaryCodec,
}
for name, codec := range codecs {
t.Run(name, func(t *testing.T) {
Expand All @@ -98,3 +100,52 @@ func TestVoid(t *testing.T) {
})
}
}

func BenchmarkProto(b *testing.B) {
// protoscope -s <(echo '1: {4 5 6 7}') | base64
data, err := base64.StdEncoding.DecodeString("CgQEBQYH")
if err != nil {
b.Fatal(err)
}
benchmarkProto(b, ProtoCodec, data)
}

func BenchmarkProtoJSON(b *testing.B) {
benchmarkProto(b, ProtoJSONCodec, []byte(`{"entryIndexes": [1,2,3]}`))
}

func benchmarkProto(b *testing.B, codec Codec, data []byte) {
b.Run("non-nil proto.Message", func(b *testing.B) {
for n := 0; n < b.N; n++ {
a := new(protocol.SuspensionMessage)
if err := codec.Unmarshal(data, a); err != nil {
b.Fatal(err)
}
}
})

b.Run("non-nil pointer to non-nil proto.Message", func(b *testing.B) {
for n := 0; n < b.N; n++ {
a := new(protocol.SuspensionMessage)
if err := codec.Unmarshal(data, &a); err != nil {
b.Fatal(err)
}
}
})

b.Run("non-nil pointer to nil proto.Message", func(b *testing.B) {
for n := 0; n < b.N; n++ {
var a *protocol.SuspensionMessage
if err := codec.Unmarshal(data, &a); err != nil {
b.Fatal(err)
}
}
})
}

func BenchmarkAllocateProtoMessage(b *testing.B) {
for n := 0; n < b.N; n++ {
var a *protocol.SuspensionMessage
allocateProtoMessage("", &a)
}
}
46 changes: 0 additions & 46 deletions example/utils.go

This file was deleted.

17 changes: 17 additions & 0 deletions examples/codegen/buf.gen.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
version: v2
managed:
enabled: true
plugins:
- remote: buf.build/protocolbuffers/go:v1.34.2
out: .
opt: paths=source_relative
- local:
- docker
- run
- --pull=always
- -i
- ghcr.io/restatedev/protoc-gen-go-restate:latest
out: .
opt: paths=source_relative
inputs:
- directory: .
6 changes: 6 additions & 0 deletions examples/codegen/buf.lock
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Generated by buf. DO NOT EDIT.
version: v2
deps:
- name: buf.build/restatedev/sdk-go
commit: 9ea0b54286dd4f35b0cb96ecdf09b402
digest: b5:822b9362e943c827c36e44b0db519542259439382f94817989349d0ee590617ba70e35975840c5d96ceff278254806435e7d570db81548f9703c00b01eec398e
9 changes: 9 additions & 0 deletions examples/codegen/buf.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
version: v2
lint:
use:
- DEFAULT
breaking:
use:
- FILE
deps:
- buf.build/restatedev/sdk-go
67 changes: 67 additions & 0 deletions examples/codegen/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package main

import (
"context"
"errors"
"fmt"
"log/slog"
"os"

restate "github.com/restatedev/sdk-go"
helloworld "github.com/restatedev/sdk-go/examples/codegen/proto"
"github.com/restatedev/sdk-go/server"
)

type greeter struct {
helloworld.UnimplementedGreeterServer
}

func (greeter) SayHello(ctx restate.Context, req *helloworld.HelloRequest) (*helloworld.HelloResponse, error) {
counter := helloworld.NewCounterClient(ctx, req.Name)
count, err := counter.Add().
Request(&helloworld.AddRequest{Delta: 1})
if err != nil {
return nil, err
}
return &helloworld.HelloResponse{
Message: fmt.Sprintf("Hello, %s! Call number: %d", req.Name, count.Value),
}, nil
}

type counter struct {
helloworld.UnimplementedCounterServer
}

func (c counter) Add(ctx restate.ObjectContext, req *helloworld.AddRequest) (*helloworld.GetResponse, error) {
count, err := restate.GetAs[int64](ctx, "counter")
if err != nil && !errors.Is(err, restate.ErrKeyNotFound) {
return nil, err
}

count += 1
if err := ctx.Set("counter", count); err != nil {
return nil, err
}

return &helloworld.GetResponse{Value: count}, nil
}

func (c counter) Get(ctx restate.ObjectSharedContext, _ *helloworld.GetRequest) (*helloworld.GetResponse, error) {
count, err := restate.GetAs[int64](ctx, "counter")
if err != nil && !errors.Is(err, restate.ErrKeyNotFound) {
return nil, err
}

return &helloworld.GetResponse{Value: count}, nil
}

func main() {
server := server.NewRestate().
Bind(helloworld.NewGreeterServer(greeter{})).
Bind(helloworld.NewCounterServer(counter{}))

if err := server.Start(context.Background(), ":9080"); err != nil {
slog.Error("application exited unexpectedly", "err", err.Error())
os.Exit(1)
}
}
Loading

0 comments on commit 8eb988a

Please sign in to comment.