Skip to content

Commit

Permalink
Refactor serialisation into the main context methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman committed Jul 16, 2024
1 parent d3251bc commit 0da604b
Show file tree
Hide file tree
Showing 17 changed files with 454 additions and 381 deletions.
83 changes: 45 additions & 38 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/restatedev/sdk-go/internal/futures"
"github.com/restatedev/sdk-go/internal/options"
"github.com/restatedev/sdk-go/internal/rand"
)

Expand All @@ -24,33 +25,29 @@ type Context interface {
// the sleep and other Selectable operations.
After(d time.Duration) After

// Service gets a Service accessor by name where service
// must be another service known by restate runtime
// Note: use the CallAs and SendAs helper functions to send and receive serialised values
Service(service, method string) CallClient[[]byte, []byte]
// Service gets a Service accessor by service and method name
// Note: use the CallAs helper function to deserialise return values
Service(service, method string, opts ...options.CallOption) CallClient

// Object gets a Object accessor by name where object
// must be another object known by restate runtime and
// key is any string representing the key for the object
// Note: use the CallAs and SendAs helper functions to send and receive serialised values
Object(object, key, method string) CallClient[[]byte, []byte]
// Object gets a Object accessor by name, key and method name
// Note: use the CallAs helper function to receive serialised values
Object(object, key, method string, opts ...options.CallOption) CallClient

// Run runs the function (fn), storing final results (including terminal errors)
// durably in the journal, or otherwise for transient errors stopping execution
// so Restate can retry the invocation. Replays will produce the same value, so
// all non-deterministic operations (eg, generating a unique ID) *must* happen
// inside Run blocks.
// Note: use the RunAs helper function to serialise non-[]byte return values
Run(fn func(RunContext) ([]byte, error)) ([]byte, error)
// Note: use the RunAs helper function to get typed output values instead of providing an output pointer
Run(fn func(RunContext) (any, error), output any, opts ...options.RunOption) error

// Awakeable returns a Restate awakeable; a 'promise' to a future
// value or error, that can be resolved or rejected by other services.
// Note: use the AwakeableAs helper function to deserialise the []byte value
Awakeable() Awakeable[[]byte]
// Note: use the AwakeableAs helper function to avoid having to pass a output pointer to Awakeable.Result()
Awakeable(options ...options.AwakeableOption) Awakeable
// ResolveAwakeable allows an awakeable (not necessarily from this service) to be
// resolved with a particular value.
// Note: use the ResolveAwakeableAs helper function to provide a value to be serialised
ResolveAwakeable(id string, value []byte)
ResolveAwakeable(id string, value any, options ...options.ResolveAwakeableOption) error
// ResolveAwakeable allows an awakeable (not necessarily from this service) to be
// rejected with a particular error.
RejectAwakeable(id string, reason error)
Expand All @@ -63,24 +60,38 @@ type Context interface {
Select(futs ...futures.Selectable) Selector
}

type CallClient[I any, O any] interface {
// Awakeable is the Go representation of a Restate awakeable; a 'promise' to a future
// value or error, that can be resolved or rejected by other services.
type Awakeable interface {
// Id returns the awakeable ID, which can be stored or sent to a another service
Id() string
// Result blocks on receiving the result of the awakeable, storing the value it was
// resolved with in output or otherwise returning the error it was rejected with.
// It is *not* safe to call this in a goroutine - use Context.Select if you
// want to wait on multiple results at once.
// Note: use the AwakeableAs helper function to avoid having to pass a output pointer
Result(output any) error
futures.Selectable
}

type CallClient interface {
// RequestFuture makes a call and returns a handle on a future response
RequestFuture(input I) (ResponseFuture[O], error)
// Request makes a call and blocks on getting the response
Request(input I) (O, error)
SendClient[I]
RequestFuture(input any) (ResponseFuture, error)
// Request makes a call and blocks on getting the response which is stored in output
Request(input any, output any) error
SendClient
}

type SendClient[I any] interface {
type SendClient interface {
// Send makes a one-way call which is executed in the background
Send(input I, delay time.Duration) error
Send(input any, delay time.Duration) error
}

type ResponseFuture[O any] interface {
// Response blocks on the response to the call
type ResponseFuture interface {
// Response blocks on the response to the call and stores it in output, or returns the associated error
// It is *not* safe to call this in a goroutine - use Context.Select if you
// want to wait on multiple results at once.
Response() (O, error)
Response(output any) error
futures.Selectable
}

Expand Down Expand Up @@ -119,32 +130,28 @@ type ObjectContext interface {
Context
KeyValueReader
KeyValueWriter
// Key retrieves the key for this virtual object invocation. This is a no-op and is
// always safe to call.
Key() string
}

type ObjectSharedContext interface {
Context
KeyValueReader
// Key retrieves the key for this virtual object invocation. This is a no-op and is
// always safe to call.
Key() string
}

type KeyValueReader interface {
// Get gets value (bytes array) associated with key
// If key does not exist, this function return a nil bytes array
// Note: Use GetAs helper function to read serialised values
Get(key string) []byte
// Get gets value associated with key and stores it in value
// If key does not exist, this function returns ErrKeyNotFound
// Note: Use GetAs generic helper function to avoid passing in a value pointer
Get(key string, value any, options ...options.GetOption) error
// Keys returns a list of all associated key
Keys() []string
// Key retrieves the key for this virtual object invocation. This is a no-op and is
// always safe to call.
Key() string
}

type KeyValueWriter interface {
// Set sets a byte array against a key
// Note: Use SetAs helper function to store serialised values
Set(key string, value []byte)
// Set sets a value against a key, using the provided codec (defaults to JSON)
Set(key string, value any, options ...options.SetOption) error
// Clear deletes a key
Clear(key string)
// ClearAll drops all stored state associated with key
Expand Down
64 changes: 31 additions & 33 deletions encoding/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ import (
"google.golang.org/protobuf/proto"
)

var (
BinaryCodec PayloadCodec = binaryCodec{}
VoidCodec PayloadCodec = voidCodec{}
ProtoCodec PayloadCodec = protoCodec{}
JSONCodec PayloadCodec = jsonCodec{}
_ PayloadCodec = PairCodec{}
)

type Void struct{}

type Codec interface {
Expand All @@ -32,23 +40,21 @@ type OutputPayload struct {
JsonSchema interface{} `json:"jsonSchema,omitempty"`
}

type VoidCodec struct{}
type voidCodec struct{}

var _ PayloadCodec = VoidCodec{}

func (j VoidCodec) InputPayload() *InputPayload {
func (j voidCodec) InputPayload() *InputPayload {
return &InputPayload{}
}

func (j VoidCodec) OutputPayload() *OutputPayload {
func (j voidCodec) OutputPayload() *OutputPayload {
return &OutputPayload{}
}

func (j VoidCodec) Unmarshal(data []byte, input any) (err error) {
func (j voidCodec) Unmarshal(data []byte, input any) (err error) {
return nil
}

func (j VoidCodec) Marshal(output any) ([]byte, error) {
func (j voidCodec) Marshal(output any) ([]byte, error) {
return nil, nil
}

Expand All @@ -57,8 +63,6 @@ type PairCodec struct {
Output PayloadCodec
}

var _ PayloadCodec = PairCodec{}

func (w PairCodec) InputPayload() *InputPayload {
return w.Input.InputPayload()
}
Expand Down Expand Up @@ -117,29 +121,27 @@ func PartialVoidCodec[I any, O any]() PayloadCodec {
_, outputVoid := any(output).(Void)
switch {
case inputVoid && outputVoid:
return VoidCodec{}
return VoidCodec
case inputVoid:
return PairCodec{Input: VoidCodec{}, Output: nil}
return PairCodec{Input: VoidCodec, Output: nil}
case outputVoid:
return PairCodec{Input: nil, Output: VoidCodec{}}
return PairCodec{Input: nil, Output: VoidCodec}
default:
return nil
}
}

type BinaryCodec struct{}
type binaryCodec struct{}

var _ PayloadCodec = BinaryCodec{}

func (j BinaryCodec) InputPayload() *InputPayload {
func (j binaryCodec) InputPayload() *InputPayload {
return &InputPayload{Required: true, ContentType: proto.String("application/octet-stream")}
}

func (j BinaryCodec) OutputPayload() *OutputPayload {
func (j binaryCodec) OutputPayload() *OutputPayload {
return &OutputPayload{ContentType: proto.String("application/octet-stream")}
}

func (j BinaryCodec) Unmarshal(data []byte, input any) (err error) {
func (j binaryCodec) Unmarshal(data []byte, input any) (err error) {
switch input := input.(type) {
case *[]byte:
*input = data
Expand All @@ -149,7 +151,7 @@ func (j BinaryCodec) Unmarshal(data []byte, input any) (err error) {
}
}

func (j BinaryCodec) Marshal(output any) ([]byte, error) {
func (j binaryCodec) Marshal(output any) ([]byte, error) {
switch output := output.(type) {
case []byte:
return output, nil
Expand All @@ -158,39 +160,35 @@ func (j BinaryCodec) Marshal(output any) ([]byte, error) {
}
}

type JSONCodec struct{}

var _ PayloadCodec = JSONCodec{}
type jsonCodec struct{}

func (j JSONCodec) InputPayload() *InputPayload {
func (j jsonCodec) InputPayload() *InputPayload {
return &InputPayload{Required: true, ContentType: proto.String("application/json")}
}

func (j JSONCodec) OutputPayload() *OutputPayload {
func (j jsonCodec) OutputPayload() *OutputPayload {
return &OutputPayload{ContentType: proto.String("application/json")}
}

func (j JSONCodec) Unmarshal(data []byte, input any) (err error) {
func (j jsonCodec) Unmarshal(data []byte, input any) (err error) {
return json.Unmarshal(data, &input)
}

func (j JSONCodec) Marshal(output any) ([]byte, error) {
func (j jsonCodec) Marshal(output any) ([]byte, error) {
return json.Marshal(output)
}

type ProtoCodec struct{}

var _ PayloadCodec = ProtoCodec{}
type protoCodec struct{}

func (p ProtoCodec) InputPayload() *InputPayload {
func (p protoCodec) InputPayload() *InputPayload {
return &InputPayload{Required: true, ContentType: proto.String("application/proto")}
}

func (p ProtoCodec) OutputPayload() *OutputPayload {
func (p protoCodec) OutputPayload() *OutputPayload {
return &OutputPayload{ContentType: proto.String("application/proto")}
}

func (p ProtoCodec) Unmarshal(data []byte, input any) (err error) {
func (p protoCodec) Unmarshal(data []byte, input any) (err error) {
switch input := input.(type) {
case proto.Message:
// called with a *Message
Expand All @@ -216,7 +214,7 @@ func (p ProtoCodec) Unmarshal(data []byte, input any) (err error) {
}
}

func (p ProtoCodec) Marshal(output any) (data []byte, err error) {
func (p protoCodec) Marshal(output any) (data []byte, err error) {
switch output := output.(type) {
case proto.Message:
return proto.Marshal(output)
Expand Down
2 changes: 1 addition & 1 deletion encoding/encoding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func checkMessage(t *testing.T, msg *protocol.AwakeableEntryMessage) {
}

func TestProto(t *testing.T) {
p := ProtoCodec{}
p := ProtoCodec

_, err := p.Marshal(protocol.AwakeableEntryMessage{Name: "foobar"})
if err == nil {
Expand Down
4 changes: 4 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ import (
"github.com/restatedev/sdk-go/internal/errors"
)

var (
ErrKeyNotFound = errors.ErrKeyNotFound
)

// WithErrorCode returns an error with specific
func WithErrorCode(err error, code errors.Code) error {
if err == nil {
Expand Down
24 changes: 3 additions & 21 deletions example/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,9 @@ func main() {
Bind(restate.Object(&userSession{})).
Bind(restate.Object(&ticketService{})).
Bind(restate.Service(&checkout{})).
// Or registered explicitly
Bind(restate.NewServiceRouter("health").Handler("ping", restate.NewServiceHandler(
func(restate.Context, struct{}) (restate.Void, error) {
return restate.Void{}, nil
}))).
Bind(restate.NewObjectRouter("counter").Handler("add", restate.NewObjectHandler(
func(ctx restate.ObjectContext, delta int) (int, error) {
count, err := restate.GetAs[int](ctx, "counter")
if err != nil && err != restate.ErrKeyNotFound {
return 0, err
}
count += delta
if err := restate.SetAs(ctx, "counter", count); err != nil {
return 0, err
}

return count, nil
})).Handler("get", restate.NewObjectSharedHandler(
func(ctx restate.ObjectSharedContext, input restate.Void) (int, error) {
return restate.GetAs[int](ctx, "counter")
})))
// Or created and registered explicitly
Bind(health).
Bind(bigCounter)

if err := server.Start(context.Background(), ":9080"); err != nil {
slog.Error("application exited unexpectedly", "err", err.Error())
Expand Down
4 changes: 2 additions & 2 deletions example/ticket_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (t *ticketService) Reserve(ctx restate.ObjectContext, _ restate.Void) (bool
}

if status == TicketAvailable {
return true, restate.SetAs(ctx, "status", TicketReserved)
return true, ctx.Set("status", TicketReserved)
}

return false, nil
Expand Down Expand Up @@ -59,7 +59,7 @@ func (t *ticketService) MarkAsSold(ctx restate.ObjectContext, _ restate.Void) (v
}

if status == TicketReserved {
return void, restate.SetAs(ctx, "status", TicketSold)
return void, ctx.Set("status", TicketSold)
}

return void, nil
Expand Down
6 changes: 3 additions & 3 deletions example/user_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ func (u *userSession) AddTicket(ctx restate.ObjectContext, ticketId string) (boo

tickets = append(tickets, ticketId)

if err := restate.SetAs(ctx, "tickets", tickets); err != nil {
if err := ctx.Set("tickets", tickets); err != nil {
return false, err
}

if err := restate.SendAs(ctx.Object(UserSessionServiceName, ticketId, "ExpireTicket")).Send(ticketId, 15*time.Minute); err != nil {
if err := ctx.Object(UserSessionServiceName, ticketId, "ExpireTicket").Send(ticketId, 15*time.Minute); err != nil {
return false, err
}

Expand All @@ -66,7 +66,7 @@ func (u *userSession) ExpireTicket(ctx restate.ObjectContext, ticketId string) (
return void, nil
}

if err := restate.SetAs(ctx, "tickets", tickets); err != nil {
if err := ctx.Set("tickets", tickets); err != nil {
return void, err
}

Expand Down
Loading

0 comments on commit 0da604b

Please sign in to comment.