diff --git a/context.go b/context.go new file mode 100644 index 0000000..28bfe54 --- /dev/null +++ b/context.go @@ -0,0 +1,140 @@ +package restate + +import ( + "context" + "log/slog" + "time" + + "github.com/restatedev/sdk-go/internal/futures" + "github.com/restatedev/sdk-go/internal/rand" +) + +type Context interface { + RunContext + + // Rand returns a random source which will give deterministic results for a given invocation + // The source wraps the stdlib rand.Rand but with some extra helper methods + // This source is not safe for use inside .Run() + Rand() *rand.Rand + + // Sleep for the duration d + Sleep(d time.Duration) + // After is an alternative to Context.Sleep which allows you to complete other tasks concurrently + // with the sleep. This is particularly useful when combined with Context.Select to race between + // the sleep and other Selectable operations. + After(d time.Duration) After + + // Service gets a Service accessor by name where service + // must be another service known by restate runtime + Service(service, method string) CallClient[[]byte, []byte] + + // Object gets a Object accessor by name where object + // must be another object known by restate runtime and + // key is any string representing the key for the object + Object(object, key, method string) CallClient[[]byte, []byte] + + // Run runs the function (fn), storing final results (including terminal errors) + // durably in the journal, or otherwise for transient errors stopping execution + // so Restate can retry the invocation. Replays will produce the same value, so + // all non-deterministic operations (eg, generating a unique ID) *must* happen + // inside Run blocks. + // Note: use the RunAs helper function to serialise non-[]byte return values + Run(fn func(RunContext) ([]byte, error)) ([]byte, error) + + // Awakeable returns a Restate awakeable; a 'promise' to a future + // value or error, that can be resolved or rejected by other services. + // Note: use the AwakeableAs helper function to deserialise the []byte value + Awakeable() Awakeable[[]byte] + // ResolveAwakeable allows an awakeable (not necessarily from this service) to be + // resolved with a particular value. + // Note: use the ResolveAwakeableAs helper function to provide a value to be serialised + ResolveAwakeable(id string, value []byte) + // ResolveAwakeable allows an awakeable (not necessarily from this service) to be + // rejected with a particular error. + RejectAwakeable(id string, reason error) + + // Select returns an iterator over blocking Restate operations (sleep, call, awakeable) + // which allows you to safely run them in parallel. The Selector will store the order + // that things complete in durably inside Restate, so that on replay the same order + // can be used. This avoids non-determinism. It is *not* safe to use goroutines or channels + // outside of Context.Run functions, as they do not behave deterministically. + Select(futs ...futures.Selectable) Selector +} + +type CallClient[I any, O any] interface { + // RequestFuture makes a call and returns a handle on a future response + RequestFuture(input I) (ResponseFuture[O], error) + // Request makes a call and blocks on getting the response + Request(input I) (O, error) + SendClient[I] +} + +type SendClient[I any] interface { + // Send makes a one-way call which is executed in the background + Send(input I, delay time.Duration) error +} + +type ResponseFuture[O any] interface { + // Response blocks on the response to the call + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + Response() (O, error) + futures.Selectable +} + +// Selector is an iterator over a list of blocking Restate operations that are running +// in the background. +type Selector interface { + // Remaining returns whether there are still operations that haven't been returned by Select(). + // There will always be exactly the same number of results as there were operations + // given to Context.Select + Remaining() bool + // Select blocks on the next completed operation + Select() futures.Selectable +} + +// RunContext methods are the only methods safe to call from inside a .Run() +type RunContext interface { + context.Context + + // Log obtains a handle on a slog.Logger which already has some useful fields (invocationID and method) + // By default, this logger will not output messages if the invocation is currently replaying + // The log handler can be set with `.WithLogger()` on the server object + Log() *slog.Logger +} + +// After is a handle on a Sleep operation which allows you to do other work concurrently +// with the sleep. +type After interface { + // Done blocks waiting on the remaining duration of the sleep. + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + Done() + futures.Selectable +} + +type ObjectContext interface { + Context + KeyValueStore + // Key retrieves the key for this virtual object invocation. This is a no-op and is + // always safe to call. + Key() string +} + +type KeyValueStore interface { + // Set sets a byte array against a key + // Note: Use SetAs helper function to seamlessly store + // a value of specific type. + Set(key string, value []byte) + // Get gets value (bytes array) associated with key + // If key does not exist, this function return a nil bytes array + // Note: Use GetAs helper function to seamlessly get value + // as specific type. + Get(key string) []byte + // Clear deletes a key + Clear(key string) + // ClearAll drops all stored state associated with key + ClearAll() + // Keys returns a list of all associated key + Keys() []string +} diff --git a/encoding/encoding.go b/encoding/encoding.go index c7dd5eb..bbe3549 100644 --- a/encoding/encoding.go +++ b/encoding/encoding.go @@ -2,10 +2,24 @@ package encoding import ( "encoding/json" + "fmt" + "reflect" "google.golang.org/protobuf/proto" ) +type Void struct{} + +type Codec interface { + Marshal(v any) ([]byte, error) + Unmarshal(data []byte, v any) error +} + +type PayloadCodec interface { + Codec + InputPayload() *InputPayload + OutputPayload() *OutputPayload +} type InputPayload struct { Required bool `json:"required"` ContentType *string `json:"contentType,omitempty"` @@ -18,53 +32,195 @@ type OutputPayload struct { JsonSchema interface{} `json:"jsonSchema,omitempty"` } -type JSONDecoder[I any] struct{} +type VoidCodec struct{} -func (j JSONDecoder[I]) InputPayload() *InputPayload { - return &InputPayload{Required: true, ContentType: proto.String("application/json")} +var _ PayloadCodec = VoidCodec{} + +func (j VoidCodec) InputPayload() *InputPayload { + return &InputPayload{} } -func (j JSONDecoder[I]) Decode(data []byte) (input I, err error) { - err = json.Unmarshal(data, &input) - return +func (j VoidCodec) OutputPayload() *OutputPayload { + return &OutputPayload{} } -type JSONEncoder[O any] struct{} +func (j VoidCodec) Unmarshal(data []byte, input any) (err error) { + return nil +} -func (j JSONEncoder[O]) OutputPayload() *OutputPayload { - return &OutputPayload{ContentType: proto.String("application/json")} +func (j VoidCodec) Marshal(output any) ([]byte, error) { + return nil, nil } -func (j JSONEncoder[O]) Encode(output O) ([]byte, error) { - return json.Marshal(output) +type PairCodec struct { + Input PayloadCodec + Output PayloadCodec } -type MessagePointer[I any] interface { - proto.Message - *I +var _ PayloadCodec = PairCodec{} + +func (w PairCodec) InputPayload() *InputPayload { + return w.Input.InputPayload() } -type ProtoDecoder[I any, IP MessagePointer[I]] struct{} +func (w PairCodec) OutputPayload() *OutputPayload { + return w.Output.OutputPayload() +} -func (p ProtoDecoder[I, IP]) InputPayload() *InputPayload { - return &InputPayload{Required: true, ContentType: proto.String("application/proto")} +func (w PairCodec) Unmarshal(data []byte, v any) error { + return w.Input.Unmarshal(data, v) } -func (p ProtoDecoder[I, IP]) Decode(data []byte) (input IP, err error) { - // Unmarshal expects a non-nil pointer to a proto.Message implementing struct - // hence we must have a type parameter for the struct itself (I) and here we allocate - // a non-nil pointer of type IP - input = IP(new(I)) - err = proto.Unmarshal(data, input) - return +func (w PairCodec) Marshal(v any) ([]byte, error) { + return w.Output.Marshal(v) } -type ProtoEncoder[O proto.Message] struct{} +func MergeCodec(base, overlay PayloadCodec) PayloadCodec { + switch { + case base == nil && overlay == nil: + return nil + case base == nil: + return overlay + case overlay == nil: + return base + } + + basePair, baseOk := base.(PairCodec) + overlayPair, overlayOk := overlay.(PairCodec) + + switch { + case baseOk && overlayOk: + return PairCodec{ + Input: MergeCodec(basePair.Input, overlayPair.Input), + Output: MergeCodec(basePair.Output, overlayPair.Output), + } + case baseOk: + return PairCodec{ + Input: MergeCodec(basePair.Input, overlay), + Output: MergeCodec(basePair.Output, overlay), + } + case overlayOk: + return PairCodec{ + Input: MergeCodec(base, overlayPair.Input), + Output: MergeCodec(base, overlayPair.Output), + } + default: + // just two non-pairs; keep base + return base + } +} + +func PartialVoidCodec[I any, O any]() PayloadCodec { + var input I + var output O + _, inputVoid := any(input).(Void) + _, outputVoid := any(output).(Void) + switch { + case inputVoid && outputVoid: + return VoidCodec{} + case inputVoid: + return PairCodec{Input: VoidCodec{}, Output: nil} + case outputVoid: + return PairCodec{Input: nil, Output: VoidCodec{}} + default: + return nil + } +} + +type BinaryCodec struct{} + +var _ PayloadCodec = BinaryCodec{} + +func (j BinaryCodec) InputPayload() *InputPayload { + return &InputPayload{Required: true, ContentType: proto.String("application/octet-stream")} +} + +func (j BinaryCodec) OutputPayload() *OutputPayload { + return &OutputPayload{ContentType: proto.String("application/octet-stream")} +} + +func (j BinaryCodec) Unmarshal(data []byte, input any) (err error) { + switch input := input.(type) { + case *[]byte: + *input = data + return nil + default: + return fmt.Errorf("BinaryCodec.Unmarshal called with a type that is not *[]byte") + } +} -func (p ProtoEncoder[O]) OutputPayload() *OutputPayload { +func (j BinaryCodec) Marshal(output any) ([]byte, error) { + switch output := output.(type) { + case []byte: + return output, nil + default: + return nil, fmt.Errorf("BinaryCodec.Marshal called with a type that is not []byte") + } +} + +type JSONCodec struct{} + +var _ PayloadCodec = JSONCodec{} + +func (j JSONCodec) InputPayload() *InputPayload { + return &InputPayload{Required: true, ContentType: proto.String("application/json")} +} + +func (j JSONCodec) OutputPayload() *OutputPayload { + return &OutputPayload{ContentType: proto.String("application/json")} +} + +func (j JSONCodec) Unmarshal(data []byte, input any) (err error) { + return json.Unmarshal(data, &input) +} + +func (j JSONCodec) Marshal(output any) ([]byte, error) { + return json.Marshal(output) +} + +type ProtoCodec struct{} + +var _ PayloadCodec = ProtoCodec{} + +func (p ProtoCodec) InputPayload() *InputPayload { + return &InputPayload{Required: true, ContentType: proto.String("application/proto")} +} + +func (p ProtoCodec) OutputPayload() *OutputPayload { return &OutputPayload{ContentType: proto.String("application/proto")} } -func (p ProtoEncoder[O]) Encode(output O) ([]byte, error) { - return proto.Marshal(output) +func (p ProtoCodec) Unmarshal(data []byte, input any) (err error) { + switch input := input.(type) { + case proto.Message: + // called with a *Message + return proto.Unmarshal(data, input) + default: + // we must support being called with a **Message where *Message is nil because this is the result of new(I) where I is a proto.Message + // and calling with new(I) is really the only generic approach. + value := reflect.ValueOf(input) + if value.Kind() != reflect.Pointer || value.IsNil() || value.Elem().Kind() != reflect.Pointer { + return fmt.Errorf("ProtoCodec.Unmarshal called with neither a proto.Message nor a non-nil pointer to a type that implements proto.Message.") + } + elem := value.Elem() // hopefully a *Message + if elem.IsNil() { + // allocate a &Message and swap this in + elem.Set(reflect.New(elem.Type().Elem())) + } + switch elemI := elem.Interface().(type) { + case proto.Message: + return proto.Unmarshal(data, elemI) + default: + return fmt.Errorf("ProtoCodec.Unmarshal called with neither a proto.Message nor a non-nil pointer to a type that implements proto.Message.") + } + } +} + +func (p ProtoCodec) Marshal(output any) (data []byte, err error) { + switch output := output.(type) { + case proto.Message: + return proto.Marshal(output) + default: + return nil, fmt.Errorf("ProtoCodec.Marshal called with a type that is not a proto.Message") + } } diff --git a/encoding/encoding_test.go b/encoding/encoding_test.go new file mode 100644 index 0000000..37026d1 --- /dev/null +++ b/encoding/encoding_test.go @@ -0,0 +1,73 @@ +package encoding + +import ( + "testing" + + "github.com/restatedev/sdk-go/generated/proto/protocol" +) + +func willPanic(t *testing.T, do func()) { + defer func() { + switch recover() { + case nil: + t.Fatalf("expected panic but didn't find one") + default: + return + } + }() + + do() +} + +func willSucceed(t *testing.T, err error) { + if err != nil { + t.Fatal(err) + } +} + +func checkMessage(t *testing.T, msg *protocol.AwakeableEntryMessage) { + if msg.Name != "foobar" { + t.Fatalf("unexpected msg.Name: %s", msg.Name) + } +} + +func TestProto(t *testing.T) { + p := ProtoCodec{} + + _, err := p.Marshal(protocol.AwakeableEntryMessage{Name: "foobar"}) + if err == nil { + t.Fatalf("expected error when marshaling non-pointer proto Message") + } + + bytes, err := p.Marshal(&protocol.AwakeableEntryMessage{Name: "foobar"}) + if err != nil { + t.Fatal(err) + } + + { + msg := &protocol.AwakeableEntryMessage{} + willSucceed(t, p.Unmarshal(bytes, msg)) + checkMessage(t, msg) + } + + { + inner := &protocol.AwakeableEntryMessage{} + msg := &inner + willSucceed(t, p.Unmarshal(bytes, msg)) + checkMessage(t, *msg) + } + + { + msg := new(*protocol.AwakeableEntryMessage) + willSucceed(t, p.Unmarshal(bytes, msg)) + checkMessage(t, *msg) + } + + { + var msg *protocol.AwakeableEntryMessage + willPanic(t, func() { + p.Unmarshal(bytes, msg) + }) + } + +} diff --git a/example/checkout.go b/example/checkout.go index d6cd8a4..f222540 100644 --- a/example/checkout.go +++ b/example/checkout.go @@ -19,7 +19,7 @@ type PaymentResponse struct { type checkout struct{} -func (c *checkout) Name() string { +func (c *checkout) ServiceName() string { return CheckoutServiceName } diff --git a/example/main.go b/example/main.go index 12574af..d1c9127 100644 --- a/example/main.go +++ b/example/main.go @@ -14,7 +14,10 @@ func main() { server := server.NewRestate(). Bind(restate.Object(&userSession{})). Bind(restate.Object(&ticketService{})). - Bind(restate.Service(&checkout{})) + Bind(restate.Service(&checkout{})). + Bind(restate.NewServiceRouter("health").Handler("ping", restate.NewServiceHandler(func(restate.Context, struct{}) (restate.Void, error) { + return restate.Void{}, nil + }))) if err := server.Start(context.Background(), ":9080"); err != nil { slog.Error("application exited unexpectedly", "err", err.Error()) diff --git a/example/ticket_service.go b/example/ticket_service.go index aa98ae8..de9a0aa 100644 --- a/example/ticket_service.go +++ b/example/ticket_service.go @@ -18,7 +18,7 @@ const TicketServiceName = "TicketService" type ticketService struct{} -func (t *ticketService) Name() string { return TicketServiceName } +func (t *ticketService) ServiceName() string { return TicketServiceName } func (t *ticketService) Reserve(ctx restate.ObjectContext, _ restate.Void) (bool, error) { status, err := restate.GetAs[TicketStatus](ctx, "status") diff --git a/example/user_session.go b/example/user_session.go index 54b2a86..cb82dbf 100644 --- a/example/user_session.go +++ b/example/user_session.go @@ -12,15 +12,15 @@ const UserSessionServiceName = "UserSession" type userSession struct{} -func (u *userSession) Name() string { +func (u *userSession) ServiceName() string { return UserSessionServiceName } func (u *userSession) AddTicket(ctx restate.ObjectContext, ticketId string) (bool, error) { userId := ctx.Key() - var success bool - if err := ctx.Object(TicketServiceName, ticketId).Method("Reserve").Request(userId).Response(&success); err != nil { + success, err := restate.CallAs[bool](ctx.Object(TicketServiceName, ticketId, "Reserve")).Request(userId) + if err != nil { return false, err } @@ -41,7 +41,7 @@ func (u *userSession) AddTicket(ctx restate.ObjectContext, ticketId string) (boo return false, err } - if err := ctx.ObjectSend(UserSessionServiceName, ticketId, 15*time.Minute).Method("ExpireTicket").Request(ticketId); err != nil { + if err := restate.SendAs(ctx.Object(UserSessionServiceName, ticketId, "ExpireTicket")).Send(ticketId, 15*time.Minute); err != nil { return false, err } @@ -70,7 +70,7 @@ func (u *userSession) ExpireTicket(ctx restate.ObjectContext, ticketId string) ( return void, err } - return void, ctx.ObjectSend(TicketServiceName, ticketId, 0).Method("Unreserve").Request(nil) + return void, ctx.Object(TicketServiceName, ticketId, "Unreserve").Send(nil, 0) } func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, error) { @@ -86,19 +86,34 @@ func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, return false, nil } - var response PaymentResponse - if err := ctx.Object(CheckoutServiceName, ""). - Method("Payment"). - Request(PaymentRequest{UserID: userId, Tickets: tickets}). - Response(&response); err != nil { + timeout := ctx.After(time.Minute) + + request, err := restate.CallAs[PaymentResponse](ctx.Object(CheckoutServiceName, "Payment", "")). + RequestFuture(PaymentRequest{UserID: userId, Tickets: tickets}) + if err != nil { + return false, err + } + + // race between the request and the timeout + switch ctx.Select(timeout, request).Select() { + case request: + // happy path + case timeout: + // we could choose to fail here with terminal error, but we'd also have to refund the payment! + ctx.Log().Warn("slow payment") + } + + // block on the eventual response + response, err := request.Response() + if err != nil { return false, err } ctx.Log().Info("payment details", "id", response.ID, "price", response.Price) for _, ticket := range tickets { - call := ctx.ObjectSend(TicketServiceName, ticket, 0).Method("MarkAsSold") - if err := call.Request(nil); err != nil { + call := ctx.Object(TicketServiceName, ticket, "MarkAsSold") + if err := call.Send(nil, 0); err != nil { return false, err } } diff --git a/facilitators.go b/facilitators.go new file mode 100644 index 0000000..5d8cecb --- /dev/null +++ b/facilitators.go @@ -0,0 +1,312 @@ +package restate + +import ( + "time" + + "github.com/restatedev/sdk-go/encoding" + "github.com/restatedev/sdk-go/internal/futures" +) + +type getOptions struct { + codec encoding.Codec +} + +type GetOption interface { + beforeGet(*getOptions) +} + +// GetAs helper function to get a key as specific type. Note that +// if there is no associated value with key, an error ErrKeyNotFound is +// returned +// it does encoding/decoding of bytes, defaulting to json codec +func GetAs[T any](ctx ObjectContext, key string, options ...GetOption) (output T, err error) { + opts := getOptions{} + for _, opt := range options { + opt.beforeGet(&opts) + } + if opts.codec == nil { + opts.codec = encoding.JSONCodec{} + } + + bytes := ctx.Get(key) + if bytes == nil { + // key does not exist. + return output, ErrKeyNotFound + } + + return output, opts.codec.Unmarshal(bytes, &output) +} + +type setOptions struct { + codec encoding.Codec +} + +type SetOption interface { + beforeSet(*setOptions) +} + +// SetAs helper function to set a key value with a generic type T. +// it does encoding/decoding of bytes automatically, defaulting to json codec +func SetAs(ctx ObjectContext, key string, value any, options ...SetOption) error { + opts := setOptions{} + for _, opt := range options { + opt.beforeSet(&opts) + } + if opts.codec == nil { + opts.codec = encoding.JSONCodec{} + } + + bytes, err := opts.codec.Marshal(value) + if err != nil { + return err + } + + ctx.Set(key, bytes) + return nil +} + +type runOptions struct { + codec encoding.Codec +} + +type RunOption interface { + beforeRun(*runOptions) +} + +// RunAs helper function runs a run function with specific concrete type as a result +// it does encoding/decoding of bytes automatically, defaulting to json codec +func RunAs[T any](ctx Context, fn func(RunContext) (T, error), options ...RunOption) (output T, err error) { + opts := runOptions{} + for _, opt := range options { + opt.beforeRun(&opts) + } + if opts.codec == nil { + opts.codec = encoding.JSONCodec{} + } + + bytes, err := ctx.Run(func(ctx RunContext) ([]byte, error) { + out, err := fn(ctx) + if err != nil { + return nil, err + } + + bytes, err := opts.codec.Marshal(out) + // todo: should this be terminal + return bytes, TerminalError(err) + }) + + if err != nil { + return output, err + } + + return output, TerminalError(opts.codec.Unmarshal(bytes, &output)) +} + +// Awakeable is the Go representation of a Restate awakeable; a 'promise' to a future +// value or error, that can be resolved or rejected by other services. +type Awakeable[T any] interface { + // Id returns the awakeable ID, which can be stored or sent to a another service + Id() string + // Result blocks on receiving the result of the awakeable, returning the value it was + // resolved with or the error it was rejected with. + // It is *not* safe to call this in a goroutine - use Context.Select if you + // want to wait on multiple results at once. + Result() (T, error) + futures.Selectable +} + +type decodingAwakeable[T any] struct { + Awakeable[[]byte] + opts awakeableOptions +} + +func (d decodingAwakeable[T]) Id() string { return d.Awakeable.Id() } +func (d decodingAwakeable[T]) Result() (out T, err error) { + bytes, err := d.Awakeable.Result() + if err != nil { + return out, err + } + if err := d.opts.codec.Unmarshal(bytes, &out); err != nil { + return out, err + } + return +} + +type awakeableOptions struct { + codec encoding.Codec +} + +type AwakeableOption interface { + beforeAwakeable(*awakeableOptions) +} + +// AwakeableAs helper function to treat awakeable values as a particular type. +// Bytes are deserialised using JSON by default +func AwakeableAs[T any](ctx Context, options ...AwakeableOption) Awakeable[T] { + opts := awakeableOptions{} + for _, opt := range options { + opt.beforeAwakeable(&opts) + } + if opts.codec == nil { + opts.codec = encoding.JSONCodec{} + } + return decodingAwakeable[T]{ctx.Awakeable(), opts} +} + +type resolveAwakeableOptions struct { + codec encoding.Codec +} + +type ResolveAwakeableOption interface { + beforeResolveAwakeable(*resolveAwakeableOptions) +} + +// ResolveAwakeableAs helper function to resolve an awakeable with a particular type +// The type will be serialised to bytes, defaulting to JSON +func ResolveAwakeableAs[T any](ctx Context, id string, value T, options ...ResolveAwakeableOption) error { + opts := resolveAwakeableOptions{} + for _, opt := range options { + opt.beforeResolveAwakeable(&opts) + } + if opts.codec == nil { + opts.codec = encoding.JSONCodec{} + } + bytes, err := opts.codec.Marshal(value) + if err != nil { + return TerminalError(err) + } + ctx.ResolveAwakeable(id, bytes) + return nil +} + +type callOptions struct { + codec encoding.Codec +} + +type CallOption interface { + beforeCall(*callOptions) +} + +type codecCallClient[O any] struct { + client CallClient[[]byte, []byte] + options callOptions +} + +func (c codecCallClient[O]) RequestFuture(input any) (ResponseFuture[O], error) { + bytes, err := c.options.codec.Marshal(input) + if err != nil { + return nil, TerminalError(err) + } + fut, err := c.client.RequestFuture(bytes) + if err != nil { + return nil, err + } + return decodingResponseFuture[O]{fut, c.options}, nil +} + +func (c codecCallClient[O]) Request(input any) (output O, err error) { + fut, err := c.RequestFuture(input) + if err != nil { + return output, err + } + return fut.Response() +} + +func (c codecCallClient[O]) Send(input any, delay time.Duration) error { + bytes, err := c.options.codec.Marshal(input) + if err != nil { + return TerminalError(err) + } + return c.client.Send(bytes, delay) +} + +// CallClientAs helper function to use a codec for encoding and decoding, defaulting to JSON +func CallAs[O any](client CallClient[[]byte, []byte], options ...CallOption) CallClient[any, O] { + opts := callOptions{} + for _, opt := range options { + opt.beforeCall(&opts) + } + if opts.codec == nil { + opts.codec = encoding.JSONCodec{} + } + return codecCallClient[O]{client, opts} +} + +func SendAs(client CallClient[[]byte, []byte], options ...CallOption) SendClient[any] { + opts := callOptions{} + for _, opt := range options { + opt.beforeCall(&opts) + } + if opts.codec == nil { + opts.codec = encoding.JSONCodec{} + } + return codecCallClient[struct{}]{client, opts} +} + +// // ResponseFutureAs helper function to receive JSON without immediately blocking +// func ResponseFutureAs[O any](responseFuture ResponseFuture[[]byte], options ...CallOption) ResponseFuture[O] { +// opts := callOptions{} +// for _, opt := range options { +// opt.beforeCall(&opts) +// } +// if opts.codec == nil { +// opts.codec = encoding.JSONCodec{} +// } +// return decodingResponseFuture[O]{responseFuture, opts} +// } + +type decodingResponseFuture[O any] struct { + ResponseFuture[[]byte] + options callOptions +} + +func (d decodingResponseFuture[O]) Response() (output O, err error) { + bytes, err := d.ResponseFuture.Response() + if err != nil { + return output, err + } + + return output, d.options.codec.Unmarshal(bytes, &output) +} + +// // CallAsFuture helper function to send JSON and allow receiving JSON later +// func CallAsFuture[O any, I any](client CallClient[[]byte, []byte], input I) (ResponseFuture[O], error) { +// var bytes []byte +// switch any(input).(type) { +// case Void: +// default: +// var err error +// bytes, err = json.Marshal(input) +// if err != nil { +// return nil, err +// } +// } + +// return ResponseFutureAs[O](client.Request(bytes)), nil +// } + +// type codecSendClient struct { +// client SendClient[[]byte] +// options callOptions +// } + +// func (c codecSendClient) Request(input any) error { +// bytes, err := c.options.codec.Marshal(input) +// if err != nil { +// return TerminalError(err) +// } +// return c.client.Request(bytes) +// } + +// // CallClientAs helper function to use a codec for encoding, defaulting to JSON +// func SendClientAs(client SendClient[[]byte], options ...CallOption) SendClient[any] { +// opts := callOptions{} +// for _, opt := range options { +// opt.beforeCall(&opts) +// } +// if opts.codec == nil { +// opts.codec = encoding.JSONCodec{} +// } + +// return codecSendClient{client, opts} +// } diff --git a/handler.go b/handler.go index af66b28..66a28e9 100644 --- a/handler.go +++ b/handler.go @@ -1,78 +1,75 @@ package restate import ( - "encoding/json" "fmt" + "net/http" "github.com/restatedev/sdk-go/encoding" ) -// Void is a placeholder used usually for functions that their signature require that -// you accept an input or return an output but the function implementation does not -// require them -type Void struct{} +// Void is a placeholder to signify 'no value' where a type is otherwise needed. It can be used in several contexts: +// 1. Input types for handlers - the request payload codec will default to a encoding.VoidCodec which will reject input at the ingress +// 2. Output types for handlers - the response payload codec will default to a encoding.VoidCodec which will send no bytes and set no content-type +type Void = encoding.Void -type VoidDecoder struct{} +type ObjectHandler interface { + Call(ctx ObjectContext, request []byte) (output []byte, err error) + getOptions() *objectHandlerOptions + Handler +} -func (v VoidDecoder) InputPayload() *encoding.InputPayload { - return &encoding.InputPayload{} +type ServiceHandler interface { + Call(ctx Context, request []byte) (output []byte, err error) + getOptions() *serviceHandlerOptions + Handler } -func (v VoidDecoder) Decode(data []byte) (input Void, err error) { - if len(data) > 0 { - err = fmt.Errorf("restate.Void decoder expects no request data") - } - return +type Handler interface { + sealed() + InputPayload() *encoding.InputPayload + OutputPayload() *encoding.OutputPayload } -type VoidEncoder struct{} +// ServiceHandlerFn signature of service (unkeyed) handler function +type ServiceHandlerFn[I any, O any] func(ctx Context, input I) (output O, err error) -func (v VoidEncoder) OutputPayload() *encoding.OutputPayload { - return &encoding.OutputPayload{} -} +// ObjectHandlerFn signature for object (keyed) handler function +type ObjectHandlerFn[I any, O any] func(ctx ObjectContext, input I) (output O, err error) -func (v VoidEncoder) Encode(output Void) ([]byte, error) { - return nil, nil +type serviceHandlerOptions struct { + codec encoding.PayloadCodec } type serviceHandler[I any, O any] struct { fn ServiceHandlerFn[I, O] - decoder Decoder[I] - encoder Encoder[O] + options serviceHandlerOptions } -// NewJSONServiceHandler create a new handler for a service using JSON encoding -func NewJSONServiceHandler[I any, O any](fn ServiceHandlerFn[I, O]) *serviceHandler[I, O] { - return &serviceHandler[I, O]{ - fn: fn, - decoder: encoding.JSONDecoder[I]{}, - encoder: encoding.JSONEncoder[O]{}, - } -} +var _ ServiceHandler = (*serviceHandler[struct{}, struct{}])(nil) -// NewProtoServiceHandler create a new handler for a service using protobuf encoding -// Input and output type must both be pointers that satisfy proto.Message -func NewProtoServiceHandler[I any, O any, IP encoding.MessagePointer[I], OP encoding.MessagePointer[O]](fn ServiceHandlerFn[IP, OP]) *serviceHandler[IP, OP] { - return &serviceHandler[IP, OP]{ - fn: fn, - decoder: encoding.ProtoDecoder[I, IP]{}, - encoder: encoding.ProtoEncoder[OP]{}, - } +type ServiceHandlerOption interface { + beforeServiceHandler(*serviceHandlerOptions) } -// NewServiceHandlerWithEncoders create a new handler for a service using a custom encoder/decoder implementation -func NewServiceHandlerWithEncoders[I any, O any](fn ServiceHandlerFn[I, O], decoder Decoder[I], encoder Encoder[O]) *serviceHandler[I, O] { +// NewServiceHandler create a new handler for a service, defaulting to JSON encoding +func NewServiceHandler[I any, O any](fn ServiceHandlerFn[I, O], options ...ServiceHandlerOption) *serviceHandler[I, O] { + opts := serviceHandlerOptions{} + for _, opt := range options { + opt.beforeServiceHandler(&opts) + } + if opts.codec == nil { + opts.codec = encoding.PartialVoidCodec[I, O]() + } return &serviceHandler[I, O]{ fn: fn, - decoder: decoder, - encoder: encoder, + options: opts, } } func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { - input, err := h.decoder.Decode(bytes) - if err != nil { - return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err)) + var input I + if err := h.options.codec.Unmarshal(bytes, &input); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } output, err := h.fn( @@ -83,7 +80,7 @@ func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { return nil, err } - bytes, err = h.encoder.Encode(output) + bytes, err = h.options.codec.Marshal(output) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -92,44 +89,62 @@ func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { } func (h *serviceHandler[I, O]) InputPayload() *encoding.InputPayload { - return h.decoder.InputPayload() + return h.options.codec.InputPayload() } func (h *serviceHandler[I, O]) OutputPayload() *encoding.OutputPayload { - return h.encoder.OutputPayload() + return h.options.codec.OutputPayload() +} + +func (h *serviceHandler[I, O]) getOptions() *serviceHandlerOptions { + return &h.options } func (h *serviceHandler[I, O]) sealed() {} +type objectHandlerOptions struct { + codec encoding.PayloadCodec +} + type objectHandler[I any, O any] struct { - fn ObjectHandlerFn[I, O] + fn ObjectHandlerFn[I, O] + options objectHandlerOptions +} + +var _ ObjectHandler = (*objectHandler[struct{}, struct{}])(nil) + +type ObjectHandlerOption interface { + beforeObjectHandler(*objectHandlerOptions) } -func NewObjectHandler[I any, O any](fn ObjectHandlerFn[I, O]) *objectHandler[I, O] { +func NewObjectHandler[I any, O any](fn ObjectHandlerFn[I, O], options ...ObjectHandlerOption) *objectHandler[I, O] { + opts := objectHandlerOptions{} + for _, opt := range options { + opt.beforeObjectHandler(&opts) + } + if opts.codec == nil { + opts.codec = encoding.PartialVoidCodec[I, O]() + } return &objectHandler[I, O]{ fn: fn, } } func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { - input := new(I) - - if len(bytes) > 0 { - // use the zero value if there is no input data at all - if err := json.Unmarshal(bytes, input); err != nil { - return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) - } + var input I + if err := h.options.codec.Unmarshal(bytes, &input); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } output, err := h.fn( ctx, - *input, + input, ) if err != nil { return nil, err } - bytes, err = json.Marshal(output) + bytes, err = h.options.codec.Marshal(output) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -137,4 +152,16 @@ func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, err return bytes, nil } +func (h *objectHandler[I, O]) InputPayload() *encoding.InputPayload { + return h.options.codec.InputPayload() +} + +func (h *objectHandler[I, O]) OutputPayload() *encoding.OutputPayload { + return h.options.codec.OutputPayload() +} + +func (h *objectHandler[I, O]) getOptions() *objectHandlerOptions { + return &h.options +} + func (h *objectHandler[I, O]) sealed() {} diff --git a/internal/futures/futures.go b/internal/futures/futures.go index 148f078..62eeea3 100644 --- a/internal/futures/futures.go +++ b/internal/futures/futures.go @@ -4,7 +4,6 @@ import ( "context" "encoding/base64" "encoding/binary" - "encoding/json" "fmt" "github.com/restatedev/sdk-go/generated/proto/protocol" @@ -32,8 +31,8 @@ func (a *After) Done() { a.entry.Await(a.suspensionCtx, a.entryIndex) } -func (a *After) getEntry() (wire.CompleteableMessage, uint32, error) { - return a.entry, a.entryIndex, nil +func (a *After) getEntry() (wire.CompleteableMessage, uint32) { + return a.entry, a.entryIndex } const AWAKEABLE_IDENTIFIER_PREFIX = "prom_1" @@ -62,8 +61,8 @@ func (c *Awakeable) Result() ([]byte, error) { return nil, fmt.Errorf("unexpected result in completed awakeable entry: %v", c.entry.Result) } } -func (c *Awakeable) getEntry() (wire.CompleteableMessage, uint32, error) { - return c.entry, c.entryIndex, nil +func (c *Awakeable) getEntry() (wire.CompleteableMessage, uint32) { + return c.entry, c.entryIndex } func awakeableID(invocationID []byte, entryIndex uint32) string { @@ -74,46 +73,30 @@ func awakeableID(invocationID []byte, entryIndex uint32) string { } type ResponseFuture struct { - suspensionCtx context.Context - err error - entry *wire.CallEntryMessage - entryIndex uint32 -} - -func NewResponseFuture(suspensionCtx context.Context, entry *wire.CallEntryMessage, entryIndex uint32) *ResponseFuture { - return &ResponseFuture{suspensionCtx, nil, entry, entryIndex} + suspensionCtx context.Context + entry *wire.CallEntryMessage + entryIndex uint32 + newProtocolViolation func(error) any } -func NewFailedResponseFuture(err error) *ResponseFuture { - return &ResponseFuture{nil, err, nil, 0} +func NewResponseFuture(suspensionCtx context.Context, entry *wire.CallEntryMessage, entryIndex uint32, newProtocolViolation func(error) any) *ResponseFuture { + return &ResponseFuture{suspensionCtx, entry, entryIndex, newProtocolViolation} } -func (r *ResponseFuture) Response(output any) error { - if r.err != nil { - return r.err - } - +func (r *ResponseFuture) Response() ([]byte, error) { r.entry.Await(r.suspensionCtx, r.entryIndex) var bytes []byte switch result := r.entry.Result.(type) { case *protocol.CallEntryMessage_Failure: - return errors.ErrorFromFailure(result.Failure) + return nil, errors.ErrorFromFailure(result.Failure) case *protocol.CallEntryMessage_Value: - bytes = result.Value + return bytes, nil default: - return errors.NewTerminalError(fmt.Errorf("sync call had invalid result: %v", r.entry.Result), 571) - + panic(r.newProtocolViolation(fmt.Errorf("call entry had invalid result: %v", r.entry.Result))) } - - if err := json.Unmarshal(bytes, output); err != nil { - // TODO: is this should be a terminal error or not? - return errors.NewTerminalError(fmt.Errorf("failed to decode response (%s): %w", string(bytes), err)) - } - - return nil } -func (r *ResponseFuture) getEntry() (wire.CompleteableMessage, uint32, error) { - return r.entry, r.entryIndex, r.err +func (r *ResponseFuture) getEntry() (wire.CompleteableMessage, uint32) { + return r.entry, r.entryIndex } diff --git a/internal/futures/select.go b/internal/futures/select.go index fc3674d..314cb45 100644 --- a/internal/futures/select.go +++ b/internal/futures/select.go @@ -9,7 +9,7 @@ import ( ) type Selectable interface { - getEntry() (wire.CompleteableMessage, uint32, error) + getEntry() (wire.CompleteableMessage, uint32) } type Selector struct { @@ -56,10 +56,7 @@ func (s *Selector) Take(winningEntryIndex uint32) Selectable { if selectable == nil { return nil } - entry, _, err := selectable.getEntry() - if err != nil { - return nil - } + entry, _ := selectable.getEntry() if !entry.Completed() { return nil } @@ -81,19 +78,16 @@ func (s *Selector) Indexes() []uint32 { return indexes } -func Select(suspensionCtx context.Context, futs ...Selectable) (*Selector, error) { +func Select(suspensionCtx context.Context, futs ...Selectable) *Selector { s := &Selector{ suspensionCtx: suspensionCtx, indexedFuts: make(map[uint32]Selectable, len(futs)), indexedChans: make(map[uint32]<-chan struct{}, len(futs)), } for i := range futs { - entry, entryIndex, err := futs[i].getEntry() - if err != nil { - return nil, err - } + entry, entryIndex := futs[i].getEntry() s.indexedFuts[entryIndex] = futs[i] s.indexedChans[entryIndex] = entry.Done() } - return s, nil + return s } diff --git a/internal/state/call.go b/internal/state/call.go index a0e6281..690f681 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -2,7 +2,6 @@ package state import ( "bytes" - "encoding/json" "time" restate "github.com/restatedev/sdk-go" @@ -11,45 +10,6 @@ import ( "github.com/restatedev/sdk-go/internal/wire" ) -var ( - _ restate.ServiceClient = (*serviceProxy)(nil) - _ restate.ServiceSendClient = (*serviceSendProxy)(nil) - _ restate.CallClient = (*serviceCall)(nil) - _ restate.SendClient = (*serviceSend)(nil) -) - -type serviceProxy struct { - machine *Machine - service string - key string -} - -func (c *serviceProxy) Method(fn string) restate.CallClient { - return &serviceCall{ - machine: c.machine, - service: c.service, - key: c.key, - method: fn, - } -} - -type serviceSendProxy struct { - machine *Machine - service string - key string - delay time.Duration -} - -func (c *serviceSendProxy) Method(fn string) restate.SendClient { - return &serviceSend{ - machine: c.machine, - service: c.service, - key: c.key, - method: fn, - delay: c.delay, - } -} - type serviceCall struct { machine *Machine service string @@ -57,37 +17,26 @@ type serviceCall struct { method string } -// Do makes a call and wait for the response -func (c *serviceCall) Request(input any) restate.ResponseFuture { - if entry, entryIndex, err := c.machine.doDynCall(c.service, c.key, c.method, input); err != nil { - return futures.NewFailedResponseFuture(err) - } else { - return futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex) - } -} - -type serviceSend struct { - machine *Machine - service string - key string - method string - - delay time.Duration -} +// RequestFuture makes a call and returns a handle on the response +func (c *serviceCall) RequestFuture(input []byte) (restate.ResponseFuture[[]byte], error) { + entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, input) -// Send runs a call in the background after delay duration -func (c *serviceSend) Request(input any) error { - return c.machine.sendCall(c.service, c.key, c.method, input, c.delay) + return futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex, func(err error) any { return c.machine.newProtocolViolation(entry, err) }), nil } -func (m *Machine) doDynCall(service, key, method string, input any) (*wire.CallEntryMessage, uint32, error) { - params, err := json.Marshal(input) +// Request makes a call and blocks on the response +func (c *serviceCall) Request(input []byte) ([]byte, error) { + fut, err := c.RequestFuture(input) if err != nil { - return nil, 0, err + return nil, err } + return fut.Response() +} - entry, entryIndex := m.doCall(service, key, method, params) - return entry, entryIndex, nil +// Send runs a call in the background after delay duration +func (c *serviceCall) Send(input []byte, delay time.Duration) error { + c.machine.sendCall(c.service, c.key, c.method, input, delay) + return nil } func (m *Machine) doCall(service, key, method string, params []byte) (*wire.CallEntryMessage, uint32) { @@ -129,24 +78,19 @@ func (m *Machine) _doCall(service, key, method string, params []byte) *wire.Call return msg } -func (m *Machine) sendCall(service, key, method string, body any, delay time.Duration) error { - params, err := json.Marshal(body) - if err != nil { - return err - } - +func (m *Machine) sendCall(service, key, method string, body []byte, delay time.Duration) { _, _ = replayOrNew( m, func(entry *wire.OneWayCallEntryMessage) restate.Void { if entry.ServiceName != service || entry.Key != key || entry.HandlerName != method || - !bytes.Equal(entry.Parameter, params) { + !bytes.Equal(entry.Parameter, body) { panic(m.newEntryMismatch(&wire.OneWayCallEntryMessage{ OneWayCallEntryMessage: protocol.OneWayCallEntryMessage{ ServiceName: service, HandlerName: method, - Parameter: params, + Parameter: body, Key: key, }, }, entry)) @@ -155,12 +99,10 @@ func (m *Machine) sendCall(service, key, method string, body any, delay time.Dur return restate.Void{} }, func() restate.Void { - m._sendCall(service, key, method, params, delay) + m._sendCall(service, key, method, body, delay) return restate.Void{} }, ) - - return nil } func (c *Machine) _sendCall(service, key, method string, params []byte, delay time.Duration) { diff --git a/internal/state/select.go b/internal/state/select.go index 1ea21e1..1be3f4c 100644 --- a/internal/state/select.go +++ b/internal/state/select.go @@ -13,12 +13,9 @@ type selector struct { inner *futures.Selector } -func (m *Machine) selector(futs ...futures.Selectable) (*selector, error) { - inner, err := futures.Select(m.suspensionCtx, futs...) - if err != nil { - return nil, err - } - return &selector{m, inner}, nil +func (m *Machine) selector(futs ...futures.Selectable) *selector { + inner := futures.Select(m.suspensionCtx, futs...) + return &selector{m, inner} } func (s *selector) Select() futures.Selectable { diff --git a/internal/state/state.go b/internal/state/state.go index b270988..3970c2e 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -66,11 +66,11 @@ func (c *Context) ClearAll() { } -func (c *Context) Get(key string) ([]byte, error) { +func (c *Context) Get(key string) []byte { return c.machine.get(key) } -func (c *Context) Keys() ([]string, error) { +func (c *Context) Keys() []string { return c.machine.keys() } @@ -82,37 +82,41 @@ func (c *Context) After(d time.Duration) restate.After { return c.machine.after(d) } -func (c *Context) Service(service string) restate.ServiceClient { - return &serviceProxy{ +func (c *Context) Service(service, method string) restate.CallClient[[]byte, []byte] { + return &serviceCall{ machine: c.machine, service: service, + method: method, } } -func (c *Context) ServiceSend(service string, delay time.Duration) restate.ServiceSendClient { - return &serviceSendProxy{ - machine: c.machine, - service: service, - delay: delay, - } -} - -func (c *Context) Object(service, key string) restate.ServiceClient { - return &serviceProxy{ +// func (c *Context) ServiceSend(service, method string, delay time.Duration) restate.SendClient[[]byte] { +// return &serviceSend{ +// machine: c.machine, +// service: service, +// method: method, +// delay: delay, +// } +// } + +func (c *Context) Object(service, key, method string) restate.CallClient[[]byte, []byte] { + return &serviceCall{ machine: c.machine, service: service, key: key, + method: method, } } -func (c *Context) ObjectSend(service, key string, delay time.Duration) restate.ServiceSendClient { - return &serviceSendProxy{ - machine: c.machine, - service: service, - key: key, - delay: delay, - } -} +// func (c *Context) ObjectSend(service, key, method string, delay time.Duration) restate.SendClient[[]byte] { +// return &serviceSend{ +// machine: c.machine, +// service: service, +// method: method, +// key: key, +// delay: delay, +// } +// } func (c *Context) Run(fn func(ctx restate.RunContext) ([]byte, error)) ([]byte, error) { return c.machine.run(fn) @@ -130,7 +134,7 @@ func (c *Context) RejectAwakeable(id string, reason error) { c.machine.rejectAwakeable(id, reason) } -func (c *Context) Selector(futs ...futures.Selectable) (restate.Selector, error) { +func (c *Context) Select(futs ...futures.Selectable) restate.Selector { return c.machine.selector(futs...) } @@ -239,6 +243,19 @@ func (m *Machine) invoke(ctx *Context, input []byte, outputSeen bool) error { case nil: // nothing to do, just exit return + case *protocolViolation: + m.log.LogAttrs(m.ctx, slog.LevelError, "Protocol violation", log.Error(typ.err)) + + if err := m.protocol.Write(wire.ErrorMessageType, &wire.ErrorMessage{ + ErrorMessage: protocol.ErrorMessage{ + Code: uint32(errors.ErrProtocolViolation), + Message: fmt.Sprintf("Protocol violation: %v", typ.err), + RelatedEntryIndex: &typ.entryIndex, + RelatedEntryType: wire.MessageType(typ.entry).UInt32(), + }, + }); err != nil { + m.log.LogAttrs(m.ctx, slog.LevelError, "Error sending failure message", log.Error(err)) + } case *entryMismatch: expected, _ := json.Marshal(typ.expectedEntry) actual, _ := json.Marshal(typ.actualEntry) diff --git a/internal/state/sys.go b/internal/state/sys.go index f5f33a1..0409049 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -29,6 +29,18 @@ func (m *Machine) newEntryMismatch(expectedEntry wire.Message, actualEntry wire. return e } +type protocolViolation struct { + entryIndex uint32 + entry wire.Message + err error +} + +func (m *Machine) newProtocolViolation(entry wire.Message, err error) *protocolViolation { + e := &protocolViolation{m.entryIndex, entry, err} + m.failure = e + return e +} + func (m *Machine) set(key string, value []byte) { _, _ = replayOrNew( m, @@ -113,7 +125,7 @@ func (m *Machine) _clearAll() { ) } -func (m *Machine) get(key string) ([]byte, error) { +func (m *Machine) get(key string) []byte { entry, entryIndex := replayOrNew( m, func(entry *wire.GetStateEntryMessage) *wire.GetStateEntryMessage { @@ -133,18 +145,13 @@ func (m *Machine) get(key string) ([]byte, error) { switch value := entry.Result.(type) { case *protocol.GetStateEntryMessage_Empty: - return nil, nil - case *protocol.GetStateEntryMessage_Failure: - // the get state entry message is not failable so this should - // never happen - // TODO terminal? - return nil, fmt.Errorf("[%d] %s", value.Failure.Code, value.Failure.Message) + return nil case *protocol.GetStateEntryMessage_Value: m.current[key] = value.Value - return value.Value, nil + return value.Value + default: + panic(m.newProtocolViolation(entry, fmt.Errorf("get state entry had invalid result: %v", entry.Result))) } - - return nil, restate.TerminalError(fmt.Errorf("get state had invalid result: %v", entry.Result), errors.ErrProtocolViolation) } func (m *Machine) _get(key string) *wire.GetStateEntryMessage { @@ -184,7 +191,7 @@ func (m *Machine) _get(key string) *wire.GetStateEntryMessage { return msg } -func (m *Machine) keys() ([]string, error) { +func (m *Machine) keys() []string { entry, entryIndex := replayOrNew( m, func(entry *wire.GetStateKeysEntryMessage) *wire.GetStateKeysEntryMessage { @@ -196,20 +203,16 @@ func (m *Machine) keys() ([]string, error) { entry.Await(m.suspensionCtx, entryIndex) switch value := entry.Result.(type) { - case *protocol.GetStateKeysEntryMessage_Failure: - // the get state entry message is not failable so this should - // never happen - return nil, fmt.Errorf("[%d] %s", value.Failure.Code, value.Failure.Message) case *protocol.GetStateKeysEntryMessage_Value: values := make([]string, 0, len(value.Value.Keys)) for _, key := range value.Value.Keys { values = append(values, string(key)) } - return values, nil + return values + default: + panic(m.newProtocolViolation(entry, fmt.Errorf("get state keys entry had invalid result: %v", entry.Result))) } - - return nil, nil } func (m *Machine) _keys() *wire.GetStateKeysEntryMessage { @@ -297,9 +300,9 @@ func (m *Machine) run(fn func(restate.RunContext) ([]byte, error)) ([]byte, erro case nil: // Empty result is valid return nil, nil + default: + panic(m.newProtocolViolation(entry, fmt.Errorf("run entry had invalid result: %v", entry.Result))) } - - return nil, restate.TerminalError(fmt.Errorf("run entry had invalid result: %v", entry.Result), errors.ErrProtocolViolation) } type runContext struct { diff --git a/options.go b/options.go new file mode 100644 index 0000000..dcdd34d --- /dev/null +++ b/options.go @@ -0,0 +1,48 @@ +package restate + +import "github.com/restatedev/sdk-go/encoding" + +type withCodec struct { + codec encoding.Codec +} + +var _ GetOption = withCodec{} +var _ SetOption = withCodec{} +var _ RunOption = withCodec{} +var _ AwakeableOption = withCodec{} +var _ ResolveAwakeableOption = withCodec{} +var _ CallOption = withCodec{} + +func (w withCodec) beforeGet(opts *getOptions) { opts.codec = w.codec } +func (w withCodec) beforeSet(opts *setOptions) { opts.codec = w.codec } +func (w withCodec) beforeRun(opts *runOptions) { opts.codec = w.codec } +func (w withCodec) beforeAwakeable(opts *awakeableOptions) { opts.codec = w.codec } +func (w withCodec) beforeResolveAwakeable(opts *resolveAwakeableOptions) { opts.codec = w.codec } +func (w withCodec) beforeCall(opts *callOptions) { opts.codec = w.codec } + +func WithCodec(codec encoding.Codec) withCodec { + return withCodec{codec} +} + +type withPayloadCodec struct { + withCodec + codec encoding.PayloadCodec +} + +var _ ServiceHandlerOption = withPayloadCodec{} +var _ ServiceRouterOption = withPayloadCodec{} +var _ ObjectHandlerOption = withPayloadCodec{} +var _ ObjectRouterOption = withPayloadCodec{} + +func (w withPayloadCodec) beforeServiceHandler(opts *serviceHandlerOptions) { opts.codec = w.codec } +func (w withPayloadCodec) beforeObjectHandler(opts *objectHandlerOptions) { opts.codec = w.codec } +func (w withPayloadCodec) beforeServiceRouter(opts *serviceRouterOptions) { + opts.defaultCodec = w.codec +} +func (w withPayloadCodec) beforeObjectRouter(opts *objectRouterOptions) { opts.defaultCodec = w.codec } + +func WithPayloadCodec(codec encoding.PayloadCodec) withPayloadCodec { + return withPayloadCodec{withCodec{codec}, codec} +} + +var WithProto = WithPayloadCodec(encoding.ProtoCodec{}) diff --git a/reflect.go b/reflect.go index 6469f0a..8dd146c 100644 --- a/reflect.go +++ b/reflect.go @@ -1,12 +1,11 @@ package restate import ( - "encoding/json" "fmt" + "net/http" "reflect" "github.com/restatedev/sdk-go/encoding" - "google.golang.org/protobuf/proto" ) type serviceNamer interface { @@ -16,19 +15,19 @@ type serviceNamer interface { var ( typeOfContext = reflect.TypeOf((*Context)(nil)).Elem() typeOfObjectContext = reflect.TypeOf((*ObjectContext)(nil)).Elem() - typeOfVoid = reflect.TypeOf((*Void)(nil)) - typeOfError = reflect.TypeOf((*error)(nil)) + typeOfVoid = reflect.TypeOf((*Void)(nil)).Elem() + typeOfError = reflect.TypeOf((*error)(nil)).Elem() ) // Object converts a struct with methods into a Virtual Object where each correctly-typed // and exported method of the struct will become a handler on the Object. The Object name defaults // to the name of the struct, but this can be overidden by providing a `ServiceName() string` method. // The handler name is the name of the method. Handler methods should be of the type `ObjectHandlerFn[I, O]`. -// Input types I will be deserialised from JSON except when they are restate.Void, -// in which case no input bytes or content type may be sent. Output types O will be serialised -// to JSON except when they are restate.Void, in which case no data will be sent and no content type -// set. -func Object(object any) *ObjectRouter { +// Input types I will be deserialised with the provided codec (defaults to JSON) except when they are restate.Void, +// in which case no input bytes or content type may be sent. +// Output types O will be serialised with the provided codec (defaults to JSON) except when they are restate.Void, +// in which case no data will be sent and no content type set. +func Object(object any, options ...ObjectRouterOption) *ObjectRouter { typ := reflect.TypeOf(object) val := reflect.ValueOf(object) var name string @@ -37,7 +36,7 @@ func Object(object any) *ObjectRouter { } else { name = reflect.Indirect(val).Type().Name() } - router := NewObjectRouter(name) + router := NewObjectRouter(name, options...) for m := 0; m < typ.NumMethod(); m++ { method := typ.Method(m) @@ -66,12 +65,28 @@ func Object(object any) *ObjectRouter { continue } + input := mtype.In(2) + output := mtype.Out(0) + + var codec encoding.PayloadCodec + switch { + case input == typeOfVoid && output == typeOfVoid: + codec = encoding.VoidCodec{} + case input == typeOfVoid: + codec = encoding.PairCodec{Input: encoding.VoidCodec{}, Output: nil} + case output == typeOfVoid: + codec = encoding.PairCodec{Input: nil, Output: encoding.VoidCodec{}} + default: + codec = nil + } + router.Handler(mname, &objectReflectHandler{ + objectHandlerOptions{codec}, reflectHandler{ fn: method.Func, receiver: val, - input: mtype.In(2), - output: mtype.Out(0), + input: input, + output: output, }, }) } @@ -83,11 +98,11 @@ func Object(object any) *ObjectRouter { // and exported method of the struct will become a handler on the Service. The Service name defaults // to the name of the struct, but this can be overidden by providing a `ServiceName() string` method. // The handler name is the name of the method. Handler methods should be of the type `ServiceHandlerFn[I, O]`. -// Input types I will be deserialised from JSON except when they are restate.Void, -// in which case no input bytes or content type may be sent. Output types O will be serialised -// to JSON except when they are restate.Void, in which case no data will be sent and no content type -// set. -func Service(service any) *ServiceRouter { +// Input types I will be deserialised with the provided codec (defaults to JSON) except when they are restate.Void, +// in which case no input bytes or content type may be sent. +// Output types O will be serialised with the provided codec (defaults to JSON) except when they are restate.Void, +// in which case no data will be sent and no content type set. +func Service(service any, options ...ServiceRouterOption) *ServiceRouter { typ := reflect.TypeOf(service) val := reflect.ValueOf(service) var name string @@ -96,7 +111,7 @@ func Service(service any) *ServiceRouter { } else { name = reflect.Indirect(val).Type().Name() } - router := NewServiceRouter(name) + router := NewServiceRouter(name, options...) for m := 0; m < typ.NumMethod(); m++ { method := typ.Method(m) @@ -127,12 +142,28 @@ func Service(service any) *ServiceRouter { continue } + input := mtype.In(2) + output := mtype.Out(0) + + var codec encoding.PayloadCodec + switch { + case input == typeOfVoid && output == typeOfVoid: + codec = encoding.VoidCodec{} + case input == typeOfVoid: + codec = encoding.PairCodec{Input: encoding.VoidCodec{}, Output: nil} + case output == typeOfVoid: + codec = encoding.PairCodec{Input: nil, Output: encoding.VoidCodec{}} + default: + codec = nil + } + router.Handler(mname, &serviceReflectHandler{ + serviceHandlerOptions{codec: codec}, reflectHandler{ fn: method.Func, receiver: val, - input: mtype.In(2), - output: mtype.Out(0), + input: input, + output: output, }, }) } @@ -147,42 +178,20 @@ type reflectHandler struct { output reflect.Type } -func (h *reflectHandler) InputPayload() *encoding.InputPayload { - if h.input == typeOfVoid { - return &encoding.InputPayload{} - } else { - return &encoding.InputPayload{ - Required: true, - ContentType: proto.String("application/json"), - } - } -} - -func (h *reflectHandler) OutputPayload() *encoding.OutputPayload { - if h.output == typeOfVoid { - return &encoding.OutputPayload{} - } else { - return &encoding.OutputPayload{ - ContentType: proto.String("application/json"), - } - } -} - func (h *reflectHandler) sealed() {} type objectReflectHandler struct { + options objectHandlerOptions reflectHandler } -var _ Handler = (*objectReflectHandler)(nil) +var _ ObjectHandler = (*objectReflectHandler)(nil) func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { input := reflect.New(h.input) - if h.input != typeOfVoid { - if err := json.Unmarshal(bytes, input.Interface()); err != nil { - return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err)) - } + if err := h.options.codec.Unmarshal(bytes, input.Interface()); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } // we are sure about the fn signature so it's safe to do this @@ -198,11 +207,7 @@ func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, er return nil, errI.(error) } - if h.output == typeOfVoid { - return nil, nil - } - - bytes, err := json.Marshal(outI) + bytes, err := h.options.codec.Marshal(outI) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -210,17 +215,30 @@ func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, er return bytes, nil } +func (h *objectReflectHandler) getOptions() *objectHandlerOptions { + return &h.options +} + +func (h *objectReflectHandler) InputPayload() *encoding.InputPayload { + return h.options.codec.InputPayload() +} + +func (h *objectReflectHandler) OutputPayload() *encoding.OutputPayload { + return h.options.codec.OutputPayload() +} + type serviceReflectHandler struct { + options serviceHandlerOptions reflectHandler } -var _ Handler = (*serviceReflectHandler)(nil) +var _ ServiceHandler = (*serviceReflectHandler)(nil) func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) { input := reflect.New(h.input) - if err := json.Unmarshal(bytes, input.Interface()); err != nil { - return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) + if err := h.options.codec.Unmarshal(bytes, input.Interface()); err != nil { + return nil, TerminalError(fmt.Errorf("request could not be decoded into handler input type: %w", err), http.StatusBadRequest) } // we are sure about the fn signature so it's safe to do this @@ -236,10 +254,22 @@ func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) return nil, errI.(error) } - bytes, err := json.Marshal(outI) + bytes, err := h.options.codec.Marshal(outI) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } return bytes, nil } + +func (h *serviceReflectHandler) getOptions() *serviceHandlerOptions { + return &h.options +} + +func (h *serviceReflectHandler) InputPayload() *encoding.InputPayload { + return h.options.codec.InputPayload() +} + +func (h *serviceReflectHandler) OutputPayload() *encoding.OutputPayload { + return h.options.codec.OutputPayload() +} diff --git a/router.go b/router.go index 436f44d..1663fe0 100644 --- a/router.go +++ b/router.go @@ -1,126 +1,16 @@ package restate import ( - "context" - "encoding/json" "fmt" - "log/slog" - "time" "github.com/restatedev/sdk-go/encoding" "github.com/restatedev/sdk-go/internal" - "github.com/restatedev/sdk-go/internal/futures" - "github.com/restatedev/sdk-go/internal/rand" ) var ( ErrKeyNotFound = fmt.Errorf("key not found") ) -type CallClient interface { - // Request makes a call and returns a handle on a future response - Request(input any) ResponseFuture -} - -type SendClient interface { - // Send makes a call in the background (doesn't wait for response) - Request(input any) error -} - -type ResponseFuture interface { - // Response blocks on the response to the call and unmarshals it into output - // It is *not* safe to call this in a goroutine - use Context.Selector if you - // want to wait on multiple results at once. - Response(output any) error - futures.Selectable -} - -type ServiceClient interface { - // Method creates a call to method with name - Method(method string) CallClient -} - -type ServiceSendClient interface { - // Method creates a call to method with name - Method(method string) SendClient -} - -type Selector interface { - Remaining() bool - Select() futures.Selectable -} - -type Context interface { - RunContext - - // Rand returns a random source which will give deterministic results for a given invocation - // The source wraps the stdlib rand.Rand but with some extra helper methods - // This source is not safe for use inside .Run() - Rand() *rand.Rand - - // Sleep for the duration d - Sleep(d time.Duration) - // After is an alternative to Context.Sleep which allows you to complete other tasks concurrently - // with the sleep. This is particularly useful when combined with Context.Selector to race between - // the sleep and other Selectable operations. - After(d time.Duration) After - - // Service gets a Service accessor by name where service - // must be another service known by restate runtime - Service(service string) ServiceClient - // Service gets a Service send accessor by name where service - // must be another service known by restate runtime - // and delay is the duration with which to delay requests - ServiceSend(service string, delay time.Duration) ServiceSendClient - - // Object gets a Object accessor by name where object - // must be another object known by restate runtime and - // key is any string representing the key for the object - Object(object, key string) ServiceClient - // Object gets a Object accessor by name where object - // must be another object known by restate runtime, - // key is any string representing the key for the object, - // and delay is the duration with which to delay requests - ObjectSend(object, key string, delay time.Duration) ServiceSendClient - - // Run runs the function (fn), storing final results (including terminal errors) - // durably in the journal, or otherwise for transient errors stopping execution - // so Restate can retry the invocation. Replays will produce the same value, so - // all non-deterministic operations (eg, generating a unique ID) *must* happen - // inside Run blocks. - // Note: use the RunAs helper function to serialise non-[]byte return values - Run(fn func(RunContext) ([]byte, error)) ([]byte, error) - - // Awakeable returns a Restate awakeable; a 'promise' to a future - // value or error, that can be resolved or rejected by other services. - // Note: use the AwakeableAs helper function to deserialise the []byte value - Awakeable() Awakeable[[]byte] - // ResolveAwakeable allows an awakeable (not necessarily from this service) to be - // resolved with a particular value. - // Note: use the ResolveAwakeableAs helper function to provide a value to be serialised - ResolveAwakeable(id string, value []byte) - // ResolveAwakeable allows an awakeable (not necessarily from this service) to be - // rejected with a particular error. - RejectAwakeable(id string, reason error) - - // Selector returns an iterator over blocking Restate operations (sleep, call, awakeable) - // which allows you to safely run them in parallel. The Selector will store the order - // that things complete in durably inside Restate, so that on replay the same order - // can be used. This avoids non-determinism. It is *not* safe to use goroutines or channels - // outside of Context.Run functions, as they do not behave deterministically. - Selector(futs ...futures.Selectable) (Selector, error) -} - -// RunContext methods are the only methods safe to call from inside a .Run() -type RunContext interface { - context.Context - - // Log obtains a handle on a slog.Logger which already has some useful fields (invocationID and method) - // By default, this logger will not output messages if the invocation is currently replaying - // The log handler can be set with `.WithLogger()` on the server object - Log() *slog.Logger -} - // Router interface type Router interface { Name() string @@ -129,85 +19,36 @@ type Router interface { Handlers() map[string]Handler } -type ObjectHandler interface { - Call(ctx ObjectContext, request []byte) (output []byte, err error) - Handler -} - -type ServiceHandler interface { - Call(ctx Context, request []byte) (output []byte, err error) - Handler -} - -type Handler interface { - sealed() - InputPayload() *encoding.InputPayload - OutputPayload() *encoding.OutputPayload -} - -type ServiceType string - -const ( - ServiceType_VIRTUAL_OBJECT ServiceType = "VIRTUAL_OBJECT" - ServiceType_SERVICE ServiceType = "SERVICE" -) - -type KeyValueStore interface { - // Set sets key value to bytes array. You can - // Note: Use SetAs helper function to seamlessly store - // a value of specific type. - Set(key string, value []byte) - // Get gets value (bytes array) associated with key - // If key does not exist, this function return a nil bytes array - // and a nil error - // Note: Use GetAs helper function to seamlessly get value - // as specific type. - Get(key string) ([]byte, error) - // Clear deletes a key - Clear(key string) - // ClearAll drops all stored state associated with key - ClearAll() - // Keys returns a list of all associated key - Keys() ([]string, error) -} - -type ObjectContext interface { - Context - KeyValueStore - // Key retrieves the key for this virtual object invocation. This is a no-op and is - // always safe to call. - Key() string +type serviceRouterOptions struct { + defaultCodec encoding.PayloadCodec } -// ServiceHandlerFn signature of service (unkeyed) handler function -type ServiceHandlerFn[I any, O any] func(ctx Context, input I) (output O, err error) - -// ObjectHandlerFn signature for object (keyed) handler function -type ObjectHandlerFn[I any, O any] func(ctx ObjectContext, input I) (output O, err error) - -type Decoder[I any] interface { - InputPayload() *encoding.InputPayload - Decode(data []byte) (input I, err error) -} - -type Encoder[O any] interface { - OutputPayload() *encoding.OutputPayload - Encode(output O) ([]byte, error) +type ServiceRouterOption interface { + beforeServiceRouter(*serviceRouterOptions) } // ServiceRouter implements Router type ServiceRouter struct { name string handlers map[string]Handler + options serviceRouterOptions } var _ Router = &ServiceRouter{} // NewServiceRouter creates a new ServiceRouter -func NewServiceRouter(name string) *ServiceRouter { +func NewServiceRouter(name string, options ...ServiceRouterOption) *ServiceRouter { + opts := serviceRouterOptions{} + for _, opt := range options { + opt.beforeServiceRouter(&opts) + } + if opts.defaultCodec == nil { + opts.defaultCodec = encoding.JSONCodec{} + } return &ServiceRouter{ name: name, handlers: make(map[string]Handler), + options: opts, } } @@ -217,6 +58,7 @@ func (r *ServiceRouter) Name() string { // Handler registers a new handler by name func (r *ServiceRouter) Handler(name string, handler ServiceHandler) *ServiceRouter { + handler.getOptions().codec = encoding.MergeCodec(handler.getOptions().codec, r.options.defaultCodec) r.handlers[name] = handler return r } @@ -229,18 +71,35 @@ func (r *ServiceRouter) Type() internal.ServiceType { return internal.ServiceType_SERVICE } +type objectRouterOptions struct { + defaultCodec encoding.PayloadCodec +} + +type ObjectRouterOption interface { + beforeObjectRouter(*objectRouterOptions) +} + // ObjectRouter type ObjectRouter struct { name string handlers map[string]Handler + options objectRouterOptions } var _ Router = &ObjectRouter{} -func NewObjectRouter(name string) *ObjectRouter { +func NewObjectRouter(name string, options ...ObjectRouterOption) *ObjectRouter { + opts := objectRouterOptions{} + for _, opt := range options { + opt.beforeObjectRouter(&opts) + } + if opts.defaultCodec == nil { + opts.defaultCodec = encoding.JSONCodec{} + } return &ObjectRouter{ name: name, handlers: make(map[string]Handler), + options: opts, } } @@ -249,6 +108,7 @@ func (r *ObjectRouter) Name() string { } func (r *ObjectRouter) Handler(name string, handler ObjectHandler) *ObjectRouter { + handler.getOptions().codec = encoding.MergeCodec(handler.getOptions().codec, r.options.defaultCodec) r.handlers[name] = handler return r } @@ -260,113 +120,3 @@ func (r *ObjectRouter) Handlers() map[string]Handler { func (r *ObjectRouter) Type() internal.ServiceType { return internal.ServiceType_VIRTUAL_OBJECT } - -// GetAs helper function to get a key as specific type. Note that -// if there is no associated value with key, an error ErrKeyNotFound is -// returned -// it does encoding/decoding of bytes automatically using json -func GetAs[T any](ctx ObjectContext, key string) (output T, err error) { - bytes, err := ctx.Get(key) - if err != nil { - return output, err - } - - if bytes == nil { - // key does not exit. - return output, ErrKeyNotFound - } - - err = json.Unmarshal(bytes, &output) - - return -} - -// SetAs helper function to set a key value with a generic type T. -// it does encoding/decoding of bytes automatically using json -func SetAs[T any](ctx ObjectContext, key string, value T) error { - bytes, err := json.Marshal(value) - if err != nil { - return err - } - - ctx.Set(key, bytes) - return nil -} - -// RunAs helper function runs a run function with specific concrete type as a result -// it does encoding/decoding of bytes automatically using json -func RunAs[T any](ctx Context, fn func(RunContext) (T, error)) (output T, err error) { - bytes, err := ctx.Run(func(ctx RunContext) ([]byte, error) { - out, err := fn(ctx) - if err != nil { - return nil, err - } - - bytes, err := json.Marshal(out) - return bytes, TerminalError(err) - }) - - if err != nil { - return output, err - } - - err = json.Unmarshal(bytes, &output) - - return output, TerminalError(err) -} - -// Awakeable is the Go representation of a Restate awakeable; a 'promise' to a future -// value or error, that can be resolved or rejected by other services. -type Awakeable[T any] interface { - // Id returns the awakeable ID, which can be stored or sent to a another service - Id() string - // Result blocks on receiving the result of the awakeable, returning the value it was - // resolved with or the error it was rejected with. - // It is *not* safe to call this in a goroutine - use Context.Selector if you - // want to wait on multiple results at once. - Result() (T, error) - futures.Selectable -} - -type decodingAwakeable[T any] struct { - Awakeable[[]byte] -} - -func (d decodingAwakeable[T]) Id() string { return d.Awakeable.Id() } -func (d decodingAwakeable[T]) Result() (out T, err error) { - bytes, err := d.Awakeable.Result() - if err != nil { - return out, err - } - if err := json.Unmarshal(bytes, &out); err != nil { - return out, err - } - return -} - -// AwakeableAs helper function to treat awakeable values as a particular type. -// Bytes are deserialised as JSON -func AwakeableAs[T any](ctx Context) Awakeable[T] { - return decodingAwakeable[T]{Awakeable: ctx.Awakeable()} -} - -// ResolveAwakeableAs helper function to resolve an awakeable with a particular type -// The type will be serialised to bytes using JSON -func ResolveAwakeableAs[T any](ctx Context, id string, value T) error { - bytes, err := json.Marshal(value) - if err != nil { - return TerminalError(err) - } - ctx.ResolveAwakeable(id, bytes) - return nil -} - -// After is a handle on a Sleep operation which allows you to do other work concurrently -// with the sleep. -type After interface { - // Done blocks waiting on the remaining duration of the sleep. - // It is *not* safe to call this in a goroutine - use Context.Selector if you - // want to wait on multiple results at once. - Done() - futures.Selectable -}