diff --git a/context.go b/context.go index 4027314..3bf3d65 100644 --- a/context.go +++ b/context.go @@ -6,6 +6,7 @@ import ( "time" "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/options" "github.com/restatedev/sdk-go/internal/rand" ) @@ -24,33 +25,29 @@ type Context interface { // the sleep and other Selectable operations. After(d time.Duration) After - // Service gets a Service accessor by name where service - // must be another service known by restate runtime - // Note: use the CallAs and SendAs helper functions to send and receive serialised values - Service(service, method string) CallClient[[]byte, []byte] + // Service gets a Service accessor by service and method name + // Note: use the CallAs helper function to deserialise return values + Service(service, method string, opts ...options.CallOption) CallClient - // Object gets a Object accessor by name where object - // must be another object known by restate runtime and - // key is any string representing the key for the object - // Note: use the CallAs and SendAs helper functions to send and receive serialised values - Object(object, key, method string) CallClient[[]byte, []byte] + // Object gets a Object accessor by name, key and method name + // Note: use the CallAs helper function to receive serialised values + Object(object, key, method string, opts ...options.CallOption) CallClient // Run runs the function (fn), storing final results (including terminal errors) // durably in the journal, or otherwise for transient errors stopping execution // so Restate can retry the invocation. Replays will produce the same value, so // all non-deterministic operations (eg, generating a unique ID) *must* happen // inside Run blocks. - // Note: use the RunAs helper function to serialise non-[]byte return values - Run(fn func(RunContext) ([]byte, error)) ([]byte, error) + // Note: use the RunAs helper function to get typed output values instead of providing an output pointer + Run(fn func(RunContext) (any, error), output any, opts ...options.RunOption) error // Awakeable returns a Restate awakeable; a 'promise' to a future // value or error, that can be resolved or rejected by other services. - // Note: use the AwakeableAs helper function to deserialise the []byte value - Awakeable() Awakeable[[]byte] + // Note: use the AwakeableAs helper function to avoid having to pass a output pointer to Awakeable.Result() + Awakeable(options ...options.AwakeableOption) Awakeable // ResolveAwakeable allows an awakeable (not necessarily from this service) to be // resolved with a particular value. - // Note: use the ResolveAwakeableAs helper function to provide a value to be serialised - ResolveAwakeable(id string, value []byte) + ResolveAwakeable(id string, value any, options ...options.ResolveAwakeableOption) error // ResolveAwakeable allows an awakeable (not necessarily from this service) to be // rejected with a particular error. RejectAwakeable(id string, reason error) @@ -63,24 +60,38 @@ type Context interface { Select(futs ...futures.Selectable) Selector } -type CallClient[I any, O any] interface { +// Awakeable is the Go representation of a Restate awakeable; a 'promise' to a future +// value or error, that can be resolved or rejected by other services. +type Awakeable interface { + // Id returns the awakeable ID, which can be stored or sent to a another service + Id() string + // Result blocks on receiving the result of the awakeable, storing the value it was + // resolved with in output or otherwise returning the error it was rejected with. + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + // Note: use the AwakeableAs helper function to avoid having to pass a output pointer + Result(output any) error + futures.Selectable +} + +type CallClient interface { // RequestFuture makes a call and returns a handle on a future response - RequestFuture(input I) (ResponseFuture[O], error) - // Request makes a call and blocks on getting the response - Request(input I) (O, error) - SendClient[I] + RequestFuture(input any) (ResponseFuture, error) + // Request makes a call and blocks on getting the response which is stored in output + Request(input any, output any) error + SendClient } -type SendClient[I any] interface { +type SendClient interface { // Send makes a one-way call which is executed in the background - Send(input I, delay time.Duration) error + Send(input any, delay time.Duration) error } -type ResponseFuture[O any] interface { - // Response blocks on the response to the call +type ResponseFuture interface { + // Response blocks on the response to the call and stores it in output, or returns the associated error // It is *not* safe to call this in a goroutine - use Context.Select if you // want to wait on multiple results at once. - Response() (O, error) + Response(output any) error futures.Selectable } @@ -119,32 +130,28 @@ type ObjectContext interface { Context KeyValueReader KeyValueWriter - // Key retrieves the key for this virtual object invocation. This is a no-op and is - // always safe to call. - Key() string } type ObjectSharedContext interface { Context KeyValueReader - // Key retrieves the key for this virtual object invocation. This is a no-op and is - // always safe to call. - Key() string } type KeyValueReader interface { - // Get gets value (bytes array) associated with key - // If key does not exist, this function return a nil bytes array - // Note: Use GetAs helper function to read serialised values - Get(key string) []byte + // Get gets value associated with key and stores it in value + // If key does not exist, this function returns ErrKeyNotFound + // Note: Use GetAs generic helper function to avoid passing in a value pointer + Get(key string, value any, options ...options.GetOption) error // Keys returns a list of all associated key Keys() []string + // Key retrieves the key for this virtual object invocation. This is a no-op and is + // always safe to call. + Key() string } type KeyValueWriter interface { - // Set sets a byte array against a key - // Note: Use SetAs helper function to store serialised values - Set(key string, value []byte) + // Set sets a value against a key, using the provided codec (defaults to JSON) + Set(key string, value any, options ...options.SetOption) error // Clear deletes a key Clear(key string) // ClearAll drops all stored state associated with key diff --git a/encoding/encoding.go b/encoding/encoding.go index bbe3549..12286f1 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -8,6 +8,14 @@ import ( "google.golang.org/protobuf/proto" ) +var ( + BinaryCodec PayloadCodec = binaryCodec{} + VoidCodec PayloadCodec = voidCodec{} + ProtoCodec PayloadCodec = protoCodec{} + JSONCodec PayloadCodec = jsonCodec{} + _ PayloadCodec = PairCodec{} +) + type Void struct{} type Codec interface { @@ -32,23 +40,21 @@ type OutputPayload struct { JsonSchema interface{} `json:"jsonSchema,omitempty"` } -type VoidCodec struct{} +type voidCodec struct{} -var _ PayloadCodec = VoidCodec{} - -func (j VoidCodec) InputPayload() *InputPayload { +func (j voidCodec) InputPayload() *InputPayload { return &InputPayload{} } -func (j VoidCodec) OutputPayload() *OutputPayload { +func (j voidCodec) OutputPayload() *OutputPayload { return &OutputPayload{} } -func (j VoidCodec) Unmarshal(data []byte, input any) (err error) { +func (j voidCodec) Unmarshal(data []byte, input any) (err error) { return nil } -func (j VoidCodec) Marshal(output any) ([]byte, error) { +func (j voidCodec) Marshal(output any) ([]byte, error) { return nil, nil } @@ -57,8 +63,6 @@ type PairCodec struct { Output PayloadCodec } -var _ PayloadCodec = PairCodec{} - func (w PairCodec) InputPayload() *InputPayload { return w.Input.InputPayload() } @@ -117,29 +121,27 @@ func PartialVoidCodec[I any, O any]() PayloadCodec { _, outputVoid := any(output).(Void) switch { case inputVoid && outputVoid: - return VoidCodec{} + return VoidCodec case inputVoid: - return PairCodec{Input: VoidCodec{}, Output: nil} + return PairCodec{Input: VoidCodec, Output: nil} case outputVoid: - return PairCodec{Input: nil, Output: VoidCodec{}} + return PairCodec{Input: nil, Output: VoidCodec} default: return nil } } -type BinaryCodec struct{} +type binaryCodec struct{} -var _ PayloadCodec = BinaryCodec{} - -func (j BinaryCodec) InputPayload() *InputPayload { +func (j binaryCodec) InputPayload() *InputPayload { return &InputPayload{Required: true, ContentType: proto.String("application/octet-stream")} } -func (j BinaryCodec) OutputPayload() *OutputPayload { +func (j binaryCodec) OutputPayload() *OutputPayload { return &OutputPayload{ContentType: proto.String("application/octet-stream")} } -func (j BinaryCodec) Unmarshal(data []byte, input any) (err error) { +func (j binaryCodec) Unmarshal(data []byte, input any) (err error) { switch input := input.(type) { case *[]byte: *input = data @@ -149,7 +151,7 @@ func (j BinaryCodec) Unmarshal(data []byte, input any) (err error) { } } -func (j BinaryCodec) Marshal(output any) ([]byte, error) { +func (j binaryCodec) Marshal(output any) ([]byte, error) { switch output := output.(type) { case []byte: return output, nil @@ -158,39 +160,35 @@ func (j BinaryCodec) Marshal(output any) ([]byte, error) { } } -type JSONCodec struct{} - -var _ PayloadCodec = JSONCodec{} +type jsonCodec struct{} -func (j JSONCodec) InputPayload() *InputPayload { +func (j jsonCodec) InputPayload() *InputPayload { return &InputPayload{Required: true, ContentType: proto.String("application/json")} } -func (j JSONCodec) OutputPayload() *OutputPayload { +func (j jsonCodec) OutputPayload() *OutputPayload { return &OutputPayload{ContentType: proto.String("application/json")} } -func (j JSONCodec) Unmarshal(data []byte, input any) (err error) { +func (j jsonCodec) Unmarshal(data []byte, input any) (err error) { return json.Unmarshal(data, &input) } -func (j JSONCodec) Marshal(output any) ([]byte, error) { +func (j jsonCodec) Marshal(output any) ([]byte, error) { return json.Marshal(output) } -type ProtoCodec struct{} - -var _ PayloadCodec = ProtoCodec{} +type protoCodec struct{} -func (p ProtoCodec) InputPayload() *InputPayload { +func (p protoCodec) InputPayload() *InputPayload { return &InputPayload{Required: true, ContentType: proto.String("application/proto")} } -func (p ProtoCodec) OutputPayload() *OutputPayload { +func (p protoCodec) OutputPayload() *OutputPayload { return &OutputPayload{ContentType: proto.String("application/proto")} } -func (p ProtoCodec) Unmarshal(data []byte, input any) (err error) { +func (p protoCodec) Unmarshal(data []byte, input any) (err error) { switch input := input.(type) { case proto.Message: // called with a *Message @@ -216,7 +214,7 @@ func (p ProtoCodec) Unmarshal(data []byte, input any) (err error) { } } -func (p ProtoCodec) Marshal(output any) (data []byte, err error) { +func (p protoCodec) Marshal(output any) (data []byte, err error) { switch output := output.(type) { case proto.Message: return proto.Marshal(output) diff --git a/encoding/encoding_test.go b/encoding/encoding_test.go index 37026d1..d3bbbae 100644 --- a/encoding/encoding_test.go +++ b/encoding/encoding_test.go @@ -32,7 +32,7 @@ func checkMessage(t *testing.T, msg *protocol.AwakeableEntryMessage) { } func TestProto(t *testing.T) { - p := ProtoCodec{} + p := ProtoCodec _, err := p.Marshal(protocol.AwakeableEntryMessage{Name: "foobar"}) if err == nil { diff --git a/error.go b/error.go index e6b1683..0d69679 100644 --- a/error.go +++ b/error.go @@ -6,6 +6,10 @@ import ( "github.com/restatedev/sdk-go/internal/errors" ) +var ( + ErrKeyNotFound = errors.ErrKeyNotFound +) + // WithErrorCode returns an error with specific func WithErrorCode(err error, code errors.Code) error { if err == nil { diff --git a/example/main.go b/example/main.go index d8cb761..f6422f3 100644 --- a/example/main.go +++ b/example/main.go @@ -16,27 +16,9 @@ func main() { Bind(restate.Object(&userSession{})). Bind(restate.Object(&ticketService{})). Bind(restate.Service(&checkout{})). - // Or registered explicitly - Bind(restate.NewServiceRouter("health").Handler("ping", restate.NewServiceHandler( - func(restate.Context, struct{}) (restate.Void, error) { - return restate.Void{}, nil - }))). - Bind(restate.NewObjectRouter("counter").Handler("add", restate.NewObjectHandler( - func(ctx restate.ObjectContext, delta int) (int, error) { - count, err := restate.GetAs[int](ctx, "counter") - if err != nil && err != restate.ErrKeyNotFound { - return 0, err - } - count += delta - if err := restate.SetAs(ctx, "counter", count); err != nil { - return 0, err - } - - return count, nil - })).Handler("get", restate.NewObjectSharedHandler( - func(ctx restate.ObjectSharedContext, input restate.Void) (int, error) { - return restate.GetAs[int](ctx, "counter") - }))) + // Or created and registered explicitly + Bind(health). + Bind(bigCounter) if err := server.Start(context.Background(), ":9080"); err != nil { slog.Error("application exited unexpectedly", "err", err.Error()) diff --git a/example/ticket_service.go b/example/ticket_service.go index 513ddde..1544157 100644 --- a/example/ticket_service.go +++ b/example/ticket_service.go @@ -27,7 +27,7 @@ func (t *ticketService) Reserve(ctx restate.ObjectContext, _ restate.Void) (bool } if status == TicketAvailable { - return true, restate.SetAs(ctx, "status", TicketReserved) + return true, ctx.Set("status", TicketReserved) } return false, nil @@ -59,7 +59,7 @@ func (t *ticketService) MarkAsSold(ctx restate.ObjectContext, _ restate.Void) (v } if status == TicketReserved { - return void, restate.SetAs(ctx, "status", TicketSold) + return void, ctx.Set("status", TicketSold) } return void, nil diff --git a/example/user_session.go b/example/user_session.go index c7c48d7..6919b50 100644 --- a/example/user_session.go +++ b/example/user_session.go @@ -37,11 +37,11 @@ func (u *userSession) AddTicket(ctx restate.ObjectContext, ticketId string) (boo tickets = append(tickets, ticketId) - if err := restate.SetAs(ctx, "tickets", tickets); err != nil { + if err := ctx.Set("tickets", tickets); err != nil { return false, err } - if err := restate.SendAs(ctx.Object(UserSessionServiceName, ticketId, "ExpireTicket")).Send(ticketId, 15*time.Minute); err != nil { + if err := ctx.Object(UserSessionServiceName, ticketId, "ExpireTicket").Send(ticketId, 15*time.Minute); err != nil { return false, err } @@ -66,7 +66,7 @@ func (u *userSession) ExpireTicket(ctx restate.ObjectContext, ticketId string) ( return void, nil } - if err := restate.SetAs(ctx, "tickets", tickets); err != nil { + if err := ctx.Set("tickets", tickets); err != nil { return void, err } diff --git a/example/utils.go b/example/utils.go new file mode 100644 index 0000000..6d05c7d --- /dev/null +++ b/example/utils.go @@ -0,0 +1,45 @@ +package main + +import ( + "fmt" + "math/big" + + restate "github.com/restatedev/sdk-go" +) + +var health = restate. + NewServiceRouter("health"). + Handler("ping", restate.NewServiceHandler( + func(restate.Context, struct{}) (restate.Void, error) { + return restate.Void{}, nil + })) + +var bigCounter = restate. + NewObjectRouter("bigCounter"). + Handler("add", restate.NewObjectHandler( + func(ctx restate.ObjectContext, deltaText string) (string, error) { + delta, ok := big.NewInt(0).SetString(deltaText, 10) + if !ok { + return "", restate.TerminalError(fmt.Errorf("input must be a valid integer string: %s", deltaText)) + } + + bytes, err := restate.GetAs[[]byte](ctx, "counter", restate.WithBinary) + if err != nil && err != restate.ErrKeyNotFound { + return "", err + } + newCount := big.NewInt(0).Add(big.NewInt(0).SetBytes(bytes), delta) + if err := ctx.Set("counter", newCount.Bytes(), restate.WithBinary); err != nil { + return "", err + } + + return newCount.String(), nil + })). + Handler("get", restate.NewObjectSharedHandler( + func(ctx restate.ObjectSharedContext, input restate.Void) (string, error) { + bytes, err := restate.GetAs[[]byte](ctx, "counter", restate.WithBinary) + if err != nil { + return "", err + } + + return big.NewInt(0).SetBytes(bytes).String(), err + })) diff --git a/facilitators.go b/facilitators.go index 2fba819..6af240c 100644 --- a/facilitators.go +++ b/facilitators.go @@ -1,274 +1,101 @@ package restate import ( - "fmt" - "time" - - "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/options" ) -type getOptions struct { - codec encoding.Codec -} - -type GetOption interface { - beforeGet(*getOptions) -} - -// GetAs helper function to get a key as specific type. Note that -// if there is no associated value with key, an error ErrKeyNotFound is -// returned -// it does encoding/decoding of bytes, defaulting to json codec -func GetAs[T any](ctx ObjectSharedContext, key string, options ...GetOption) (output T, err error) { - opts := getOptions{} - for _, opt := range options { - opt.beforeGet(&opts) - } - if opts.codec == nil { - opts.codec = encoding.JSONCodec{} - } - - bytes := ctx.Get(key) - if bytes == nil { - // key does not exist. - return output, ErrKeyNotFound - } - - if err := opts.codec.Unmarshal(bytes, &output); err != nil { - return output, TerminalError(fmt.Errorf("failed to unmarshal Get state into T: %w", err)) - } - - return output, nil -} - -type setOptions struct { - codec encoding.Codec -} - -type SetOption interface { - beforeSet(*setOptions) -} - -// SetAs helper function to set a key value with a generic type T. -// it does encoding/decoding of bytes automatically, defaulting to json codec -func SetAs(ctx ObjectContext, key string, value any, options ...SetOption) error { - opts := setOptions{} - for _, opt := range options { - opt.beforeSet(&opts) - } - if opts.codec == nil { - opts.codec = encoding.JSONCodec{} - } - - bytes, err := opts.codec.Marshal(value) - if err != nil { - return TerminalError(fmt.Errorf("failed to marshal Set value: %w", err)) - } - - ctx.Set(key, bytes) - return nil -} - -type runOptions struct { - codec encoding.Codec -} - -type RunOption interface { - beforeRun(*runOptions) +// GetAs helper function to get a key, returning a typed response instead of accepting a pointer. +// If there is no associated value with key, an error ErrKeyNotFound is returned +func GetAs[T any](ctx ObjectSharedContext, key string, options ...options.GetOption) (output T, err error) { + err = ctx.Get(key, &output, options...) + return } -// RunAs helper function runs a run function with specific concrete type as a result -// it does encoding/decoding of bytes automatically, defaulting to json codec -func RunAs[T any](ctx Context, fn func(RunContext) (T, error), options ...RunOption) (output T, err error) { - opts := runOptions{} - for _, opt := range options { - opt.beforeRun(&opts) - } - if opts.codec == nil { - opts.codec = encoding.JSONCodec{} - } - - bytes, err := ctx.Run(func(ctx RunContext) ([]byte, error) { - out, err := fn(ctx) - if err != nil { - return nil, err - } +// RunAs helper function runs a Run function, returning a typed response instead of accepting a pointer +func RunAs[T any](ctx Context, fn func(RunContext) (T, error), options ...options.RunOption) (output T, err error) { + err = ctx.Run(func(ctx RunContext) (any, error) { + return fn(ctx) + }, &output, options...) - bytes, err := opts.codec.Marshal(out) - if err != nil { - return nil, TerminalError(fmt.Errorf("failed to marshal Run output: %w", err)) - } - return bytes, nil - }) - - if err != nil { - return output, err - } - - if err := opts.codec.Unmarshal(bytes, &output); err != nil { - return output, TerminalError(fmt.Errorf("failed to unmarshal Run output into T: %w", err)) - } - - return output, nil + return } -// Awakeable is the Go representation of a Restate awakeable; a 'promise' to a future -// value or error, that can be resolved or rejected by other services. -type Awakeable[T any] interface { +// TypedAwakeable is an extension of Awakeable which returns typed responses instead of accepting a pointer +type TypedAwakeable[T any] interface { // Id returns the awakeable ID, which can be stored or sent to a another service Id() string - // Result blocks on receiving the result of the awakeable, returning the value it was - // resolved with or the error it was rejected with. + // Result blocks on receiving the result of the awakeable, storing the value it was + // resolved with in output or otherwise returning the error it was rejected with. // It is *not* safe to call this in a goroutine - use Context.Select if you // want to wait on multiple results at once. Result() (T, error) futures.Selectable } -type decodingAwakeable[T any] struct { - Awakeable[[]byte] - opts awakeableOptions +type typedAwakeable[T any] struct { + Awakeable } -func (d decodingAwakeable[T]) Id() string { return d.Awakeable.Id() } -func (d decodingAwakeable[T]) Result() (out T, err error) { - bytes, err := d.Awakeable.Result() - if err != nil { - return out, err - } - if err := d.opts.codec.Unmarshal(bytes, &out); err != nil { - return out, TerminalError(fmt.Errorf("failed to unmarshal Awakeable result into T: %w", err)) - } +func (t typedAwakeable[T]) Result() (output T, err error) { + err = t.Awakeable.Result(&output) return } -type awakeableOptions struct { - codec encoding.Codec -} - -type AwakeableOption interface { - beforeAwakeable(*awakeableOptions) -} - -// AwakeableAs helper function to treat awakeable values as a particular type. -// Bytes are deserialised using JSON by default -func AwakeableAs[T any](ctx Context, options ...AwakeableOption) Awakeable[T] { - opts := awakeableOptions{} - for _, opt := range options { - opt.beforeAwakeable(&opts) - } - if opts.codec == nil { - opts.codec = encoding.JSONCodec{} - } - return decodingAwakeable[T]{ctx.Awakeable(), opts} -} - -type resolveAwakeableOptions struct { - codec encoding.Codec -} - -type ResolveAwakeableOption interface { - beforeResolveAwakeable(*resolveAwakeableOptions) -} - -// ResolveAwakeableAs helper function to resolve an awakeable with a particular type -// The type will be serialised to bytes, defaulting to JSON -func ResolveAwakeableAs(ctx Context, id string, value any, options ...ResolveAwakeableOption) error { - opts := resolveAwakeableOptions{} - for _, opt := range options { - opt.beforeResolveAwakeable(&opts) - } - if opts.codec == nil { - opts.codec = encoding.JSONCodec{} - } - bytes, err := opts.codec.Marshal(value) - if err != nil { - return TerminalError(fmt.Errorf("failed to marshal ResolveAwakeable value: %w", err)) - } - ctx.ResolveAwakeable(id, bytes) - return nil -} - -type callOptions struct { - codec encoding.Codec -} - -type CallOption interface { - beforeCall(*callOptions) +// AwakeableAs helper function to treat awakeable results as a particular type. +func AwakeableAs[T any](ctx Context, options ...options.AwakeableOption) TypedAwakeable[T] { + return typedAwakeable[T]{ctx.Awakeable(options...)} } -type codecCallClient[O any] struct { - client CallClient[[]byte, []byte] - options callOptions +// TypedCallClient is an extension of CallClient which returns typed responses instead of accepting a pointer +type TypedCallClient[O any] interface { + // RequestFuture makes a call and returns a handle on a future response + RequestFuture(input any) (TypedResponseFuture[O], error) + // Request makes a call and blocks on getting the response + Request(input any) (O, error) + SendClient } -func (c codecCallClient[O]) RequestFuture(input any) (ResponseFuture[O], error) { - bytes, err := c.options.codec.Marshal(input) - if err != nil { - return nil, TerminalError(fmt.Errorf("failed to marshal RequestFuture input: %w", err)) - } - fut, err := c.client.RequestFuture(bytes) - if err != nil { - return nil, err - } - return decodingResponseFuture[O]{fut, c.options}, nil +type typedCallClient[O any] struct { + CallClient } -func (c codecCallClient[O]) Request(input any) (output O, err error) { - fut, err := c.RequestFuture(input) +func (t typedCallClient[O]) Request(input any) (output O, err error) { + fut, err := t.CallClient.RequestFuture(input) if err != nil { return output, err } - return fut.Response() + err = fut.Response(&output) + return } -func (c codecCallClient[O]) Send(input any, delay time.Duration) error { - bytes, err := c.options.codec.Marshal(input) +func (t typedCallClient[O]) RequestFuture(input any) (TypedResponseFuture[O], error) { + fut, err := t.CallClient.RequestFuture(input) if err != nil { - return TerminalError(fmt.Errorf("failed to marshal Send input: %w", err)) + return nil, err } - return c.client.Send(bytes, delay) + return typedResponseFuture[O]{fut}, nil } -// CallAs helper function to use a codec for encoding and decoding, defaulting to JSON -func CallAs[O any](client CallClient[[]byte, []byte], options ...CallOption) CallClient[any, O] { - opts := callOptions{} - for _, opt := range options { - opt.beforeCall(&opts) - } - if opts.codec == nil { - opts.codec = encoding.JSONCodec{} - } - return codecCallClient[O]{client, opts} +// TypedResponseFuture is an extension of ResponseFuture which returns typed responses instead of accepting a pointer +type TypedResponseFuture[O any] interface { + // Response blocks on the response to the call and returns it or the associated error + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + Response() (O, error) + futures.Selectable } -// SendAs helper function to use a codec for encoding .Send request parameters, defaulting to JSON -func SendAs(client CallClient[[]byte, []byte], options ...CallOption) SendClient[any] { - opts := callOptions{} - for _, opt := range options { - opt.beforeCall(&opts) - } - if opts.codec == nil { - opts.codec = encoding.JSONCodec{} - } - return codecCallClient[struct{}]{client, opts} +type typedResponseFuture[O any] struct { + ResponseFuture } -type decodingResponseFuture[O any] struct { - ResponseFuture[[]byte] - options callOptions +func (t typedResponseFuture[O]) Response() (output O, err error) { + err = t.ResponseFuture.Response(&output) + return } -func (d decodingResponseFuture[O]) Response() (output O, err error) { - bytes, err := d.ResponseFuture.Response() - if err != nil { - return output, err - } - - if err := d.options.codec.Unmarshal(bytes, &output); err != nil { - return output, TerminalError(fmt.Errorf("failed to unmarshal Call response into O: %w", err)) - } - - return output, nil +// CallAs helper function to get typed responses instead of passing in a pointer +func CallAs[O any](client CallClient) TypedCallClient[O] { + return typedCallClient[O]{client} } diff --git a/internal/errors/error.go b/internal/errors/error.go index 1cd0ca5..f0ef08f 100644 --- a/internal/errors/error.go +++ b/internal/errors/error.go @@ -13,6 +13,10 @@ const ( ErrProtocolViolation Code = 571 ) +var ( + ErrKeyNotFound = NewTerminalError(fmt.Errorf("key not found"), 404) +) + type CodeError struct { Code Code Inner error diff --git a/internal/options/options.go b/internal/options/options.go new file mode 100644 index 0000000..f3d2299 --- /dev/null +++ b/internal/options/options.go @@ -0,0 +1,51 @@ +package options + +import "github.com/restatedev/sdk-go/encoding" + +type AwakeableOptions struct { + Codec encoding.Codec +} + +type AwakeableOption interface { + BeforeAwakeable(*AwakeableOptions) +} + +type ResolveAwakeableOptions struct { + Codec encoding.Codec +} + +type ResolveAwakeableOption interface { + BeforeResolveAwakeable(*ResolveAwakeableOptions) +} + +type GetOptions struct { + Codec encoding.Codec +} + +type GetOption interface { + BeforeGet(*GetOptions) +} + +type SetOptions struct { + Codec encoding.Codec +} + +type SetOption interface { + BeforeSet(*SetOptions) +} + +type CallOptions struct { + Codec encoding.Codec +} + +type CallOption interface { + BeforeCall(*CallOptions) +} + +type RunOptions struct { + Codec encoding.Codec +} + +type RunOption interface { + BeforeRun(*RunOptions) +} diff --git a/internal/state/awakeable.go b/internal/state/awakeable.go index 05b575d..85f6671 100644 --- a/internal/state/awakeable.go +++ b/internal/state/awakeable.go @@ -9,7 +9,7 @@ import ( "github.com/restatedev/sdk-go/internal/wire" ) -func (c *Machine) awakeable() restate.Awakeable[[]byte] { +func (c *Machine) awakeable() *futures.Awakeable { entry, entryIndex := replayOrNew( c, func(entry *wire.AwakeableEntryMessage) *wire.AwakeableEntryMessage { diff --git a/internal/state/call.go b/internal/state/call.go index 690f681..8c9e5cf 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -2,15 +2,19 @@ package state import ( "bytes" + "fmt" "time" restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/generated/proto/protocol" + "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/options" "github.com/restatedev/sdk-go/internal/wire" ) type serviceCall struct { + options options.CallOptions machine *Machine service string key string @@ -18,24 +22,53 @@ type serviceCall struct { } // RequestFuture makes a call and returns a handle on the response -func (c *serviceCall) RequestFuture(input []byte) (restate.ResponseFuture[[]byte], error) { - entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, input) +func (c *serviceCall) RequestFuture(input any) (restate.ResponseFuture, error) { + bytes, err := c.options.Codec.Marshal(input) + if err != nil { + return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal RequestFuture input: %w", err)) + } + entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, bytes) + + return decodingResponseFuture{ + futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex, func(err error) any { return c.machine.newProtocolViolation(entry, err) }), + c.options, + }, nil +} + +type decodingResponseFuture struct { + *futures.ResponseFuture + options options.CallOptions +} + +func (d decodingResponseFuture) Response(output any) (err error) { + bytes, err := d.ResponseFuture.Response() + if err != nil { + return err + } + + if err := d.options.Codec.Unmarshal(bytes, output); err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Call response into O: %w", err)) + } - return futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex, func(err error) any { return c.machine.newProtocolViolation(entry, err) }), nil + return nil } // Request makes a call and blocks on the response -func (c *serviceCall) Request(input []byte) ([]byte, error) { +func (c *serviceCall) Request(input any, output any) error { fut, err := c.RequestFuture(input) if err != nil { - return nil, err + return err } - return fut.Response() + return fut.Response(output) } // Send runs a call in the background after delay duration -func (c *serviceCall) Send(input []byte, delay time.Duration) error { - c.machine.sendCall(c.service, c.key, c.method, input, delay) +func (c *serviceCall) Send(input any, delay time.Duration) error { + bytes, err := c.options.Codec.Marshal(input) + if err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to marshal Send input: %w", err)) + } + c.machine.sendCall(c.service, c.key, c.method, bytes, delay) return nil } diff --git a/internal/state/state.go b/internal/state/state.go index 3bc05ad..8d6f3cf 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -13,10 +13,12 @@ import ( "time" restate "github.com/restatedev/sdk-go" + "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/generated/proto/protocol" "github.com/restatedev/sdk-go/internal/errors" "github.com/restatedev/sdk-go/internal/futures" "github.com/restatedev/sdk-go/internal/log" + "github.com/restatedev/sdk-go/internal/options" "github.com/restatedev/sdk-go/internal/rand" "github.com/restatedev/sdk-go/internal/wire" "github.com/restatedev/sdk-go/rcontext" @@ -30,10 +32,6 @@ var ( ErrInvalidVersion = fmt.Errorf("invalid version number") ) -var ( - _ restate.Context = (*Context)(nil) -) - type Context struct { context.Context userLogger *slog.Logger @@ -41,7 +39,9 @@ type Context struct { } var _ restate.ObjectContext = &Context{} +var _ restate.ObjectSharedContext = &Context{} var _ restate.Context = &Context{} +var _ restate.RunContext = &Context{} func (c *Context) Log() *slog.Logger { return c.machine.userLog @@ -51,8 +51,22 @@ func (c *Context) Rand() *rand.Rand { return c.machine.rand } -func (c *Context) Set(key string, value []byte) { - c.machine.set(key, value) +func (c *Context) Set(key string, value any, opts ...options.SetOption) error { + o := options.SetOptions{} + for _, opt := range opts { + opt.BeforeSet(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + + bytes, err := o.Codec.Marshal(value) + if err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to marshal Set value: %w", err)) + } + + c.machine.set(key, bytes) + return nil } func (c *Context) Clear(key string) { @@ -66,8 +80,25 @@ func (c *Context) ClearAll() { } -func (c *Context) Get(key string) []byte { - return c.machine.get(key) +func (c *Context) Get(key string, output any, opts ...options.GetOption) error { + o := options.GetOptions{} + for _, opt := range opts { + opt.BeforeGet(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + + bytes := c.machine.get(key) + if len(bytes) == 0 { + return errors.ErrKeyNotFound + } + + if err := o.Codec.Unmarshal(bytes, output); err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Get state into output: %w", err)) + } + + return nil } func (c *Context) Keys() []string { @@ -82,16 +113,34 @@ func (c *Context) After(d time.Duration) restate.After { return c.machine.after(d) } -func (c *Context) Service(service, method string) restate.CallClient[[]byte, []byte] { +func (c *Context) Service(service, method string, opts ...options.CallOption) restate.CallClient { + o := options.CallOptions{} + for _, opt := range opts { + opt.BeforeCall(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + return &serviceCall{ + options: o, machine: c.machine, service: service, method: method, } } -func (c *Context) Object(service, key, method string) restate.CallClient[[]byte, []byte] { +func (c *Context) Object(service, key, method string, opts ...options.CallOption) restate.CallClient { + o := options.CallOptions{} + for _, opt := range opts { + opt.BeforeCall(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + return &serviceCall{ + options: o, machine: c.machine, service: service, key: key, @@ -99,16 +148,89 @@ func (c *Context) Object(service, key, method string) restate.CallClient[[]byte, } } -func (c *Context) Run(fn func(ctx restate.RunContext) ([]byte, error)) ([]byte, error) { - return c.machine.run(fn) +func (c *Context) Run(fn func(ctx restate.RunContext) (any, error), output any, opts ...options.RunOption) error { + o := options.RunOptions{} + for _, opt := range opts { + opt.BeforeRun(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + + bytes, err := c.machine.run(func(ctx restate.RunContext) ([]byte, error) { + output, err := fn(ctx) + if err != nil { + return nil, err + } + + bytes, err := o.Codec.Marshal(output) + if err != nil { + return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal Run output: %w", err)) + } + + return bytes, nil + }) + if err != nil { + return err + } + + if err := o.Codec.Unmarshal(bytes, output); err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Run output: %w", err)) + } + + return nil +} + +type awakeableOptions struct { + codec encoding.Codec } -func (c *Context) Awakeable() restate.Awakeable[[]byte] { - return c.machine.awakeable() +type AwakeableOption interface { + beforeAwakeable(*awakeableOptions) } -func (c *Context) ResolveAwakeable(id string, value []byte) { - c.machine.resolveAwakeable(id, value) +func (c *Context) Awakeable(opts ...options.AwakeableOption) restate.Awakeable { + o := options.AwakeableOptions{} + for _, opt := range opts { + opt.BeforeAwakeable(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + return decodingAwakeable{c.machine.awakeable(), o.Codec} +} + +type decodingAwakeable struct { + *futures.Awakeable + codec encoding.Codec +} + +func (d decodingAwakeable) Id() string { return d.Awakeable.Id() } +func (d decodingAwakeable) Result(output any) (err error) { + bytes, err := d.Awakeable.Result() + if err != nil { + return err + } + if err := d.codec.Unmarshal(bytes, output); err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to unmarshal Awakeable result into output: %w", err)) + } + return +} + +func (c *Context) ResolveAwakeable(id string, value any, opts ...options.ResolveAwakeableOption) error { + o := options.ResolveAwakeableOptions{} + for _, opt := range opts { + opt.BeforeResolveAwakeable(&o) + } + if o.Codec == nil { + o.Codec = encoding.JSONCodec + } + bytes, err := o.Codec.Marshal(value) + if err != nil { + return errors.NewTerminalError(fmt.Errorf("failed to marshal ResolveAwakeable value: %w", err)) + } + c.machine.resolveAwakeable(id, bytes) + return nil } func (c *Context) RejectAwakeable(id string, reason error) { diff --git a/options.go b/options.go index dcdd34d..58c2ae6 100644 --- a/options.go +++ b/options.go @@ -1,24 +1,29 @@ package restate -import "github.com/restatedev/sdk-go/encoding" +import ( + "github.com/restatedev/sdk-go/encoding" + "github.com/restatedev/sdk-go/internal/options" +) type withCodec struct { codec encoding.Codec } -var _ GetOption = withCodec{} -var _ SetOption = withCodec{} -var _ RunOption = withCodec{} -var _ AwakeableOption = withCodec{} -var _ ResolveAwakeableOption = withCodec{} -var _ CallOption = withCodec{} - -func (w withCodec) beforeGet(opts *getOptions) { opts.codec = w.codec } -func (w withCodec) beforeSet(opts *setOptions) { opts.codec = w.codec } -func (w withCodec) beforeRun(opts *runOptions) { opts.codec = w.codec } -func (w withCodec) beforeAwakeable(opts *awakeableOptions) { opts.codec = w.codec } -func (w withCodec) beforeResolveAwakeable(opts *resolveAwakeableOptions) { opts.codec = w.codec } -func (w withCodec) beforeCall(opts *callOptions) { opts.codec = w.codec } +var _ options.GetOption = withCodec{} +var _ options.SetOption = withCodec{} +var _ options.RunOption = withCodec{} +var _ options.AwakeableOption = withCodec{} +var _ options.ResolveAwakeableOption = withCodec{} +var _ options.CallOption = withCodec{} + +func (w withCodec) BeforeGet(opts *options.GetOptions) { opts.Codec = w.codec } +func (w withCodec) BeforeSet(opts *options.SetOptions) { opts.Codec = w.codec } +func (w withCodec) BeforeRun(opts *options.RunOptions) { opts.Codec = w.codec } +func (w withCodec) BeforeAwakeable(opts *options.AwakeableOptions) { opts.Codec = w.codec } +func (w withCodec) BeforeResolveAwakeable(opts *options.ResolveAwakeableOptions) { + opts.Codec = w.codec +} +func (w withCodec) BeforeCall(opts *options.CallOptions) { opts.Codec = w.codec } func WithCodec(codec encoding.Codec) withCodec { return withCodec{codec} @@ -45,4 +50,6 @@ func WithPayloadCodec(codec encoding.PayloadCodec) withPayloadCodec { return withPayloadCodec{withCodec{codec}, codec} } -var WithProto = WithPayloadCodec(encoding.ProtoCodec{}) +var WithProto = WithPayloadCodec(encoding.ProtoCodec) +var WithBinary = WithPayloadCodec(encoding.BinaryCodec) +var WithJSON = WithPayloadCodec(encoding.JSONCodec) diff --git a/reflect.go b/reflect.go index f83e251..24c5b4e 100644 --- a/reflect.go +++ b/reflect.go @@ -81,11 +81,11 @@ func Object(object any, options ...ObjectRouterOption) *ObjectRouter { var codec encoding.PayloadCodec switch { case input == typeOfVoid && output == typeOfVoid: - codec = encoding.VoidCodec{} + codec = encoding.VoidCodec case input == typeOfVoid: - codec = encoding.PairCodec{Input: encoding.VoidCodec{}, Output: nil} + codec = encoding.PairCodec{Input: encoding.VoidCodec, Output: nil} case output == typeOfVoid: - codec = encoding.PairCodec{Input: nil, Output: encoding.VoidCodec{}} + codec = encoding.PairCodec{Input: nil, Output: encoding.VoidCodec} default: codec = nil } @@ -159,11 +159,11 @@ func Service(service any, options ...ServiceRouterOption) *ServiceRouter { var codec encoding.PayloadCodec switch { case input == typeOfVoid && output == typeOfVoid: - codec = encoding.VoidCodec{} + codec = encoding.VoidCodec case input == typeOfVoid: - codec = encoding.PairCodec{Input: encoding.VoidCodec{}, Output: nil} + codec = encoding.PairCodec{Input: encoding.VoidCodec, Output: nil} case output == typeOfVoid: - codec = encoding.PairCodec{Input: nil, Output: encoding.VoidCodec{}} + codec = encoding.PairCodec{Input: nil, Output: encoding.VoidCodec} default: codec = nil } diff --git a/router.go b/router.go index 3659ca2..9cd52b4 100644 --- a/router.go +++ b/router.go @@ -1,17 +1,10 @@ package restate import ( - "fmt" - "net/http" - "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/internal" ) -var ( - ErrKeyNotFound = TerminalError(fmt.Errorf("key not found"), http.StatusNotFound) -) - // Router interface type Router interface { Name() string @@ -44,7 +37,7 @@ func NewServiceRouter(name string, options ...ServiceRouterOption) *ServiceRoute opt.beforeServiceRouter(&opts) } if opts.defaultCodec == nil { - opts.defaultCodec = encoding.JSONCodec{} + opts.defaultCodec = encoding.JSONCodec } return &ServiceRouter{ name: name, @@ -95,7 +88,7 @@ func NewObjectRouter(name string, options ...ObjectRouterOption) *ObjectRouter { opt.beforeObjectRouter(&opts) } if opts.defaultCodec == nil { - opts.defaultCodec = encoding.JSONCodec{} + opts.defaultCodec = encoding.JSONCodec } return &ObjectRouter{ name: name,