diff --git a/README.md b/README.md index 2e0c1dc..45b2f78 100644 --- a/README.md +++ b/README.md @@ -5,16 +5,6 @@ [Restate](https://restate.dev/) is a system for easily building resilient applications using *distributed durable async/await*. This repository contains the Restate SDK for writing services in **Golang**. -This SDK is an individual effort to build a golang SDK for restate runtime. The implementation is based on the service protocol documentation found [here](https://github.com/restatedev/service-protocol/blob/main/service-invocation-protocol.md) and a lot of experimentation with the protocol. - -This means that it's not granted that this SDK matches exactly what `restate` has intended but it's a best effort interpretation of the docs - -Since **service discovery** was not documented (or at least I could not find any documentation for it), the implementation is based on reverse engineering the TypeScript SDK. - -This implementation of the SDK **ONLY** supports `dynrpc`. There is noway yet that you can define your service interface with `gRPC` - -Calling other services right now is done completely by name, hence it's not very safe since you can miss up arguments list/type for example but hopefully later on we can generate stubs or use `gRPC` interfaces to define services. - ## Features implemented - [x] Log replay (resume of execution on failure) @@ -22,9 +12,8 @@ Calling other services right now is done completely by name, hence it's not very - [x] Remote service call over restate runtime - [X] Delayed execution of remote services - [X] Sleep -- [x] Side effects - - Implementation might differ from as intended by restate since it's not documented and based on reverse engineering of the TypeScript SDK -- [ ] Awakeable +- [x] Run +- [x] Awakeable ## Basic usage @@ -58,14 +47,14 @@ In yet a third terminal do the following steps - Add tickets to basket ```bash -curl -v localhost:8080/UserSession/addTicket \ +curl -v localhost:8080/UserSession/azmy/AddTicket \ -H 'content-type: application/json' \ - -d '{"key": "azmy", "request": "ticket-1"}' + -d '"ticket-1"' # {"response":true} -curl -v localhost:8080/UserSession/addTicket \ +curl -v localhost:8080/UserSession/azmy/AddTicket \ -H 'content-type: application/json' \ - -d '{"key": "azmy", "request": "ticket-2"}' + -d '"ticket-2"' # {"response":true} ``` @@ -74,8 +63,8 @@ Trying adding the same tickets again should return `false` since they are alread Finally checkout ```bash -curl localhost:8080/UserSession/checkout \ +curl localhost:8080/UserSession/azmy/Checkout \ -H 'content-type: application/json' \ - -d '{"key": "azmy", "request": null}' + -d 'null' #{"response":true} ``` diff --git a/example/checkout.go b/example/checkout.go index 5d18d5d..4ed1cde 100644 --- a/example/checkout.go +++ b/example/checkout.go @@ -1,7 +1,9 @@ package main import ( + "context" "fmt" + "math/rand" "github.com/google/uuid" restate "github.com/restatedev/sdk-go" @@ -18,8 +20,16 @@ type PaymentResponse struct { Price int `json:"price"` } -func payment(ctx restate.Context, request PaymentRequest) (response PaymentResponse, err error) { - uuid, err := restate.SideEffectAs(ctx, func() (string, error) { +type checkout struct{} + +func (c *checkout) Name() string { + return CheckoutServiceName +} + +const CheckoutServiceName = "Checkout" + +func (c *checkout) Payment(ctx restate.Context, request PaymentRequest) (response PaymentResponse, err error) { + uuid, err := restate.RunAs(ctx, func(ctx context.Context) (string, error) { uuid := uuid.New() return uuid.String(), nil }) @@ -35,17 +45,15 @@ func payment(ctx restate.Context, request PaymentRequest) (response PaymentRespo price := len(request.Tickets) * 30 response.Price = price - i := 0 - _, err = restate.SideEffectAs(ctx, func() (bool, error) { + _, err = restate.RunAs(ctx, func(ctx context.Context) (bool, error) { log := log.With().Str("uuid", uuid).Int("price", price).Logger() - if i > 2 { + if rand.Float64() < 0.5 { log.Info().Msg("payment succeeded") return true, nil + } else { + log.Error().Msg("payment failed") + return false, fmt.Errorf("failed to pay") } - - log.Error().Msg("payment failed") - i += 1 - return false, fmt.Errorf("failed to pay") }) if err != nil { @@ -56,8 +64,3 @@ func payment(ctx restate.Context, request PaymentRequest) (response PaymentRespo return response, nil } - -var ( - Checkout = restate.NewServiceRouter(). - Handler("checkout", restate.NewServiceHandler(payment)) -) diff --git a/example/main.go b/example/main.go index 6d09794..c8a5e13 100644 --- a/example/main.go +++ b/example/main.go @@ -4,26 +4,21 @@ import ( "context" "os" + restate "github.com/restatedev/sdk-go" "github.com/restatedev/sdk-go/server" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) -const ( - UserSessionServiceName = "UserSession" - TicketServiceName = "TicketService" - CheckoutServiceName = "Checkout" -) - func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) zerolog.SetGlobalLevel(zerolog.InfoLevel) server := server.NewRestate(). - Bind(UserSessionServiceName, UserSession). - Bind(TicketServiceName, TicketService). - Bind(CheckoutServiceName, Checkout) + Bind(restate.Object(&userSession{})). + Bind(restate.Object(&ticketService{})). + Bind(restate.Service(&checkout{})) if err := server.Start(context.Background(), ":9080"); err != nil { log.Error().Err(err).Msg("application exited unexpectedly") diff --git a/example/ticket_service.go b/example/ticket_service.go index 98c6cf9..8280e88 100644 --- a/example/ticket_service.go +++ b/example/ticket_service.go @@ -15,7 +15,13 @@ const ( TicketSold TicketStatus = 2 ) -func reserve(ctx restate.ObjectContext, _ restate.Void) (bool, error) { +const TicketServiceName = "TicketService" + +type ticketService struct{} + +func (t *ticketService) Name() string { return TicketServiceName } + +func (t *ticketService) Reserve(ctx restate.ObjectContext, _ restate.Void) (bool, error) { status, err := restate.GetAs[TicketStatus](ctx, "status") if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return false, err @@ -28,7 +34,7 @@ func reserve(ctx restate.ObjectContext, _ restate.Void) (bool, error) { return false, nil } -func unreserve(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, err error) { +func (t *ticketService) Unreserve(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, err error) { ticketId := ctx.Key() log.Info().Str("ticket", ticketId).Msg("un-reserving ticket") status, err := restate.GetAs[TicketStatus](ctx, "status") @@ -44,7 +50,7 @@ func unreserve(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, er return void, nil } -func markAsSold(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, err error) { +func (t *ticketService) MarkAsSold(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, err error) { ticketId := ctx.Key() log.Info().Str("ticket", ticketId).Msg("mark ticket as sold") @@ -59,10 +65,3 @@ func markAsSold(ctx restate.ObjectContext, _ restate.Void) (void restate.Void, e return void, nil } - -var ( - TicketService = restate.NewObjectRouter(). - Handler("reserve", restate.NewObjectHandler(reserve)). - Handler("unreserve", restate.NewObjectHandler(unreserve)). - Handler("markAsSold", restate.NewObjectHandler(markAsSold)) -) diff --git a/example/user_session.go b/example/user_session.go index 02e2ae6..d8288d9 100644 --- a/example/user_session.go +++ b/example/user_session.go @@ -9,11 +9,19 @@ import ( "github.com/rs/zerolog/log" ) -func addTicket(ctx restate.ObjectContext, ticketId string) (bool, error) { +const UserSessionServiceName = "UserSession" + +type userSession struct{} + +func (u *userSession) Name() string { + return UserSessionServiceName +} + +func (u *userSession) AddTicket(ctx restate.ObjectContext, ticketId string) (bool, error) { userId := ctx.Key() var success bool - if err := ctx.Object(TicketServiceName, ticketId).Method("reserve").Request(userId).Response(&success); err != nil { + if err := ctx.Object(TicketServiceName, ticketId).Method("Reserve").Request(userId).Response(&success); err != nil { return false, err } @@ -34,14 +42,14 @@ func addTicket(ctx restate.ObjectContext, ticketId string) (bool, error) { return false, err } - if err := ctx.ObjectSend(UserSessionServiceName, ticketId, 15*time.Minute).Method("expireTicket").Request(ticketId); err != nil { + if err := ctx.ObjectSend(UserSessionServiceName, ticketId, 15*time.Minute).Method("ExpireTicket").Request(ticketId); err != nil { return false, err } return true, nil } -func expireTicket(ctx restate.ObjectContext, ticketId string) (void restate.Void, err error) { +func (u *userSession) ExpireTicket(ctx restate.ObjectContext, ticketId string) (void restate.Void, err error) { tickets, err := restate.GetAs[[]string](ctx, "tickets") if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { return void, err @@ -63,10 +71,10 @@ func expireTicket(ctx restate.ObjectContext, ticketId string) (void restate.Void return void, err } - return void, ctx.ObjectSend(TicketServiceName, ticketId, 0).Method("unreserve").Request(nil) + return void, ctx.ObjectSend(TicketServiceName, ticketId, 0).Method("Unreserve").Request(nil) } -func checkout(ctx restate.ObjectContext, _ restate.Void) (bool, error) { +func (u *userSession) Checkout(ctx restate.ObjectContext, _ restate.Void) (bool, error) { userId := ctx.Key() tickets, err := restate.GetAs[[]string](ctx, "tickets") if err != nil && !errors.Is(err, restate.ErrKeyNotFound) { @@ -81,7 +89,7 @@ func checkout(ctx restate.ObjectContext, _ restate.Void) (bool, error) { var response PaymentResponse if err := ctx.Object(CheckoutServiceName, ""). - Method("checkout"). + Method("Payment"). Request(PaymentRequest{UserID: userId, Tickets: tickets}). Response(&response); err != nil { return false, err @@ -90,7 +98,7 @@ func checkout(ctx restate.ObjectContext, _ restate.Void) (bool, error) { log.Info().Str("id", response.ID).Int("price", response.Price).Msg("payment details") for _, ticket := range tickets { - call := ctx.ObjectSend(TicketServiceName, ticket, 0).Method("markAsSold") + call := ctx.ObjectSend(TicketServiceName, ticket, 0).Method("MarkAsSold") if err := call.Request(nil); err != nil { return false, err } @@ -99,10 +107,3 @@ func checkout(ctx restate.ObjectContext, _ restate.Void) (bool, error) { ctx.Clear("tickets") return true, nil } - -var ( - UserSession = restate.NewObjectRouter(). - Handler("addTicket", restate.NewObjectHandler(addTicket)). - Handler("expireTicket", restate.NewObjectHandler(expireTicket)). - Handler("checkout", restate.NewObjectHandler(checkout)) -) diff --git a/handler.go b/handler.go index d6b0220..d27c4a3 100644 --- a/handler.go +++ b/handler.go @@ -3,7 +3,6 @@ package restate import ( "encoding/json" "fmt" - "reflect" ) // Void is a placeholder used usually for functions that their signature require that @@ -19,44 +18,37 @@ func (v *Void) UnmarshalJSON(_ []byte) error { return nil } -type ServiceHandler struct { - fn reflect.Value - input reflect.Type - output reflect.Type +type serviceHandler[I any, O any] struct { + fn ServiceHandlerFn[I, O] } // NewServiceHandler create a new handler for a service -func NewServiceHandler[I any, O any](fn ServiceHandlerFn[I, O]) *ServiceHandler { - return &ServiceHandler{ - fn: reflect.ValueOf(fn), - input: reflect.TypeFor[I](), - output: reflect.TypeFor[O](), +func NewServiceHandler[I any, O any](fn ServiceHandlerFn[I, O]) *serviceHandler[I, O] { + return &serviceHandler[I, O]{ + fn: fn, } } -func (h *ServiceHandler) Call(ctx Context, bytes []byte) ([]byte, error) { - input := reflect.New(h.input) +func (h *serviceHandler[I, O]) Call(ctx Context, bytes []byte) ([]byte, error) { + input := new(I) if len(bytes) > 0 { // use the zero value if there is no input data at all - if err := json.Unmarshal(bytes, input.Interface()); err != nil { + if err := json.Unmarshal(bytes, input); err != nil { return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) } } // we are sure about the fn signature so it's safe to do this - output := h.fn.Call([]reflect.Value{ - reflect.ValueOf(ctx), - input.Elem(), - }) - - outI := output[0].Interface() - errI := output[1].Interface() - if errI != nil { - return nil, errI.(error) + output, err := h.fn( + ctx, + *input, + ) + if err != nil { + return nil, err } - bytes, err := json.Marshal(outI) + bytes, err = json.Marshal(output) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -64,45 +56,37 @@ func (h *ServiceHandler) Call(ctx Context, bytes []byte) ([]byte, error) { return bytes, nil } -func (h *ServiceHandler) sealed() {} +func (h *serviceHandler[I, O]) sealed() {} -type ObjectHandler struct { - fn reflect.Value - input reflect.Type - output reflect.Type +type objectHandler[I any, O any] struct { + fn ObjectHandlerFn[I, O] } -func NewObjectHandler[I any, O any](fn ObjectHandlerFn[I, O]) *ObjectHandler { - return &ObjectHandler{ - fn: reflect.ValueOf(fn), - input: reflect.TypeFor[I](), - output: reflect.TypeFor[O](), +func NewObjectHandler[I any, O any](fn ObjectHandlerFn[I, O]) *objectHandler[I, O] { + return &objectHandler[I, O]{ + fn: fn, } } -func (h *ObjectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { - input := reflect.New(h.input) +func (h *objectHandler[I, O]) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { + input := new(I) if len(bytes) > 0 { // use the zero value if there is no input data at all - if err := json.Unmarshal(bytes, input.Interface()); err != nil { + if err := json.Unmarshal(bytes, input); err != nil { return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) } } - // we are sure about the fn signature so it's safe to do this - output := h.fn.Call([]reflect.Value{ - reflect.ValueOf(ctx), - input.Elem(), - }) - - outI := output[0].Interface() - errI := output[1].Interface() - if errI != nil { - return nil, errI.(error) + output, err := h.fn( + ctx, + *input, + ) + if err != nil { + return nil, err } - bytes, err := json.Marshal(outI) + bytes, err = json.Marshal(output) if err != nil { return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) } @@ -110,4 +94,4 @@ func (h *ObjectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { return bytes, nil } -func (h *ObjectHandler) sealed() {} +func (h *objectHandler[I, O]) sealed() {} diff --git a/internal/state/state.go b/internal/state/state.go index 929f829..4cf43c7 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -103,8 +103,8 @@ func (c *Context) ObjectSend(service, key string, delay time.Duration) restate.S } } -func (c *Context) SideEffect(fn func() ([]byte, error)) ([]byte, error) { - return c.machine.sideEffect(fn) +func (c *Context) Run(fn func(ctx context.Context) ([]byte, error)) ([]byte, error) { + return c.machine.run(fn) } func (c *Context) Awakeable() restate.Awakeable[[]byte] { @@ -266,8 +266,8 @@ The journal entry at position %d was: }) return - case *sideEffectFailure: - m.log.Error().Err(typ.err).Msg("Side effect returned a failure, returning error to Restate") + case *runFailure: + m.log.Error().Err(typ.err).Msg("Run returned a failure, returning error to Restate") if err := m.protocol.Write(&wire.ErrorMessage{ ErrorMessage: protocol.ErrorMessage{ @@ -332,7 +332,15 @@ The journal entry at position %d was: return m.protocol.Write(&wire.EndMessage{}) } - bytes, err := m.handler.Call(ctx, input) + var bytes []byte + var err error + switch handler := m.handler.(type) { + case restate.ObjectHandler: + bytes, err = handler.Call(ctx, input) + case restate.ServiceHandler: + bytes, err = handler.Call(ctx, input) + } + if err != nil { m.log.Error().Err(err).Msg("failure") } diff --git a/internal/state/sys.go b/internal/state/sys.go index 6564eaf..45db016 100644 --- a/internal/state/sys.go +++ b/internal/state/sys.go @@ -2,6 +2,7 @@ package state import ( "bytes" + "context" "fmt" "sort" "time" @@ -243,7 +244,7 @@ func (m *Machine) _keys() *wire.GetStateKeysEntryMessage { return msg } -func (m *Machine) after(d time.Duration) restate.After { +func (m *Machine) after(d time.Duration) *futures.After { entry, entryIndex := replayOrNew( m, func(entry *wire.SleepEntryMessage) *wire.SleepEntryMessage { @@ -273,18 +274,18 @@ func (m *Machine) _sleep(d time.Duration) *wire.SleepEntryMessage { return msg } -func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) { +func (m *Machine) run(fn func(context.Context) ([]byte, error)) ([]byte, error) { entry, entryIndex := replayOrNew( m, func(entry *wire.RunEntryMessage) *wire.RunEntryMessage { return entry }, func() *wire.RunEntryMessage { - return m._sideEffect(fn) + return m._run(fn) }, ) - // side effect must be acknowledged before proceeding + // run entry must be acknowledged before proceeding entry.Await(m.suspensionCtx, entryIndex) switch result := entry.Result.(type) { @@ -297,11 +298,11 @@ func (m *Machine) sideEffect(fn func() ([]byte, error)) ([]byte, error) { return nil, nil } - return nil, restate.TerminalError(fmt.Errorf("side effect entry had invalid result: %v", entry.Result), errors.ErrProtocolViolation) + return nil, restate.TerminalError(fmt.Errorf("run entry had invalid result: %v", entry.Result), errors.ErrProtocolViolation) } -func (m *Machine) _sideEffect(fn func() ([]byte, error)) *wire.RunEntryMessage { - bytes, err := fn() +func (m *Machine) _run(fn func(context.Context) ([]byte, error)) *wire.RunEntryMessage { + bytes, err := fn(m.ctx) if err != nil { if restate.IsTerminalError(err) { @@ -319,7 +320,7 @@ func (m *Machine) _sideEffect(fn func() ([]byte, error)) *wire.RunEntryMessage { return msg } else { - panic(m.newSideEffectFailure(err)) + panic(m.newRunFailure(err)) } } else { msg := &wire.RunEntryMessage{ @@ -335,13 +336,13 @@ func (m *Machine) _sideEffect(fn func() ([]byte, error)) *wire.RunEntryMessage { } } -type sideEffectFailure struct { +type runFailure struct { entryIndex uint32 err error } -func (m *Machine) newSideEffectFailure(err error) *sideEffectFailure { - s := &sideEffectFailure{m.entryIndex, err} +func (m *Machine) newRunFailure(err error) *runFailure { + s := &runFailure{m.entryIndex, err} m.failure = s return s } diff --git a/internal/wire/wire.go b/internal/wire/wire.go index 0eb048e..85967b8 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -374,7 +374,7 @@ var ( RunEntryMessageType: func(header Header, bytes []byte) (Message, error) { msg := &RunEntryMessage{} - // replayed side effects are inherently acked + // replayed run entries are inherently acked msg.Ack() return msg, proto.Unmarshal(bytes, msg) diff --git a/reflect.go b/reflect.go new file mode 100644 index 0000000..93e1e70 --- /dev/null +++ b/reflect.go @@ -0,0 +1,197 @@ +package restate + +import ( + "encoding/json" + "fmt" + "reflect" +) + +type serviceNamer interface { + Name() string +} + +var ( + typeOfContext = reflect.TypeFor[Context]() + typeOfObjectContext = reflect.TypeFor[ObjectContext]() + typeOfError = reflect.TypeFor[error]() +) + +func Object(object any) *ObjectRouter { + typ := reflect.TypeOf(object) + val := reflect.ValueOf(object) + var name string + if sn, ok := object.(serviceNamer); ok { + name = sn.Name() + } else { + name = reflect.Indirect(val).Type().Name() + } + router := NewObjectRouter(name) + + for m := 0; m < typ.NumMethod(); m++ { + method := typ.Method(m) + mtype := method.Type + mname := method.Name + // Method must be exported. + if !method.IsExported() { + continue + } + // Method needs three ins: receiver, ObjectContext, I + if mtype.NumIn() != 3 { + continue + } + + if ctxType := mtype.In(1); ctxType != typeOfObjectContext { + continue + } + + // Method needs two outs: O, and error + if mtype.NumOut() != 2 { + continue + } + + // The second return type of the method must be error. + if returnType := mtype.Out(1); returnType != typeOfError { + continue + } + + router.Handler(mname, &objectReflectHandler{ + fn: method.Func, + receiver: val, + input: mtype.In(2), + output: mtype.Out(0), + }) + } + + return router +} + +func Service(service any) *ServiceRouter { + typ := reflect.TypeOf(service) + val := reflect.ValueOf(service) + var name string + if sn, ok := service.(serviceNamer); ok { + name = sn.Name() + } else { + name = reflect.Indirect(val).Type().Name() + } + router := NewServiceRouter(name) + + for m := 0; m < typ.NumMethod(); m++ { + method := typ.Method(m) + + mtype := method.Type + mname := method.Name + // Method must be exported. + if !method.IsExported() { + continue + } + + // Method needs three ins: receiver, Context, I + if mtype.NumIn() != 3 { + continue + } + + if ctxType := mtype.In(1); ctxType != typeOfContext { + continue + } + + // Method needs two outs: O, and error + if mtype.NumOut() != 2 { + continue + } + + // The second return type of the method must be error. + if returnType := mtype.Out(1); returnType != typeOfError { + continue + } + + router.Handler(mname, &serviceReflectHandler{ + fn: method.Func, + receiver: val, + input: mtype.In(2), + output: mtype.Out(0), + }) + } + + return router +} + +type objectReflectHandler struct { + fn reflect.Value + receiver reflect.Value + input reflect.Type + output reflect.Type +} + +func (h *objectReflectHandler) Call(ctx ObjectContext, bytes []byte) ([]byte, error) { + input := reflect.New(h.input) + + if len(bytes) > 0 { + // use the zero value if there is no input data at all + if err := json.Unmarshal(bytes, input.Interface()); err != nil { + return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) + } + } + + // we are sure about the fn signature so it's safe to do this + output := h.fn.Call([]reflect.Value{ + h.receiver, + reflect.ValueOf(ctx), + input.Elem(), + }) + + outI := output[0].Interface() + errI := output[1].Interface() + if errI != nil { + return nil, errI.(error) + } + + bytes, err := json.Marshal(outI) + if err != nil { + return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) + } + + return bytes, nil +} + +func (h *objectReflectHandler) sealed() {} + +type serviceReflectHandler struct { + fn reflect.Value + receiver reflect.Value + input reflect.Type + output reflect.Type +} + +func (h *serviceReflectHandler) Call(ctx Context, bytes []byte) ([]byte, error) { + input := reflect.New(h.input) + + if len(bytes) > 0 { + // use the zero value if there is no input data at all + if err := json.Unmarshal(bytes, input.Interface()); err != nil { + return nil, TerminalError(fmt.Errorf("request doesn't match handler signature: %w", err)) + } + } + + // we are sure about the fn signature so it's safe to do this + output := h.fn.Call([]reflect.Value{ + h.receiver, + reflect.ValueOf(ctx), + input.Elem(), + }) + + outI := output[0].Interface() + errI := output[1].Interface() + if errI != nil { + return nil, errI.(error) + } + + bytes, err := json.Marshal(outI) + if err != nil { + return nil, TerminalError(fmt.Errorf("failed to serialize output: %w", err)) + } + + return bytes, nil +} + +func (h *serviceReflectHandler) sealed() {} diff --git a/router.go b/router.go index 8a97c4f..7bdffb4 100644 --- a/router.go +++ b/router.go @@ -72,11 +72,11 @@ type Context interface { // and delay is the duration with which to delay requests ObjectSend(object, key string, delay time.Duration) ServiceSendClient - // SideEffects runs the function (fn) until it succeeds or permanently fails. + // Run runs the function (fn) until it succeeds or permanently fails. // this stores the results of the function inside restate runtime so a replay // will produce the same value (think generating a unique id for example) - // Note: use the SideEffectAs helper function - SideEffect(fn func() ([]byte, error)) ([]byte, error) + // Note: use the RunAs helper function + Run(fn func(ctx context.Context) ([]byte, error)) ([]byte, error) Awakeable() Awakeable[[]byte] ResolveAwakeable(id string, value []byte) @@ -87,13 +87,23 @@ type Context interface { // Router interface type Router interface { + Name() string Type() internal.ServiceType // Set of handlers associated with this router Handlers() map[string]Handler } -type Handler interface { +type ObjectHandler interface { + Call(ctx ObjectContext, request []byte) (output []byte, err error) + Handler +} + +type ServiceHandler interface { Call(ctx Context, request []byte) (output []byte, err error) + Handler +} + +type Handler interface { sealed() } @@ -104,25 +114,6 @@ const ( ServiceType_SERVICE ServiceType = "SERVICE" ) -type ObjectHandlerWrapper struct { - h *ObjectHandler -} - -func (o ObjectHandlerWrapper) Call(ctx Context, request []byte) ([]byte, error) { - switch ctx := ctx.(type) { - case ObjectContext: - return o.h.Call(ctx, request) - default: - panic("Object handler called with context that doesn't implement ObjectContext") - } -} - -func (ObjectHandlerWrapper) sealed() {} - -type ServiceHandlerWrapper struct { - h ServiceHandler -} - type KeyValueStore interface { // Set sets key value to bytes array. You can // Note: Use SetAs helper function to seamlessly store @@ -156,20 +147,26 @@ type ObjectHandlerFn[I any, O any] func(ctx ObjectContext, input I) (output O, e // ServiceRouter implements Router type ServiceRouter struct { + name string handlers map[string]Handler } var _ Router = &ServiceRouter{} // NewServiceRouter creates a new ServiceRouter -func NewServiceRouter() *ServiceRouter { +func NewServiceRouter(name string) *ServiceRouter { return &ServiceRouter{ + name: name, handlers: make(map[string]Handler), } } +func (r *ServiceRouter) Name() string { + return r.name +} + // Handler registers a new handler by name -func (r *ServiceRouter) Handler(name string, handler *ServiceHandler) *ServiceRouter { +func (r *ServiceRouter) Handler(name string, handler ServiceHandler) *ServiceRouter { r.handlers[name] = handler return r } @@ -184,19 +181,25 @@ func (r *ServiceRouter) Type() internal.ServiceType { // ObjectRouter type ObjectRouter struct { + name string handlers map[string]Handler } var _ Router = &ObjectRouter{} -func NewObjectRouter() *ObjectRouter { +func NewObjectRouter(name string) *ObjectRouter { return &ObjectRouter{ + name: name, handlers: make(map[string]Handler), } } -func (r *ObjectRouter) Handler(name string, handler *ObjectHandler) *ObjectRouter { - r.handlers[name] = ObjectHandlerWrapper{h: handler} +func (r *ObjectRouter) Name() string { + return r.name +} + +func (r *ObjectRouter) Handler(name string, handler ObjectHandler) *ObjectRouter { + r.handlers[name] = handler return r } @@ -240,11 +243,11 @@ func SetAs[T any](ctx ObjectContext, key string, value T) error { return nil } -// SideEffectAs helper function runs a side effect function with specific concrete type as a result +// RunAs helper function runs a run function with specific concrete type as a result // it does encoding/decoding of bytes automatically using msgpack -func SideEffectAs[T any](ctx Context, fn func() (T, error)) (output T, err error) { - bytes, err := ctx.SideEffect(func() ([]byte, error) { - out, err := fn() +func RunAs[T any](ctx Context, fn func(context.Context) (T, error)) (output T, err error) { + bytes, err := ctx.Run(func(ctx context.Context) ([]byte, error) { + out, err := fn(ctx) if err != nil { return nil, err } diff --git a/server/restate.go b/server/restate.go index c8b0256..6d4d2ff 100644 --- a/server/restate.go +++ b/server/restate.go @@ -50,14 +50,14 @@ func NewRestate() *Restate { } } -func (r *Restate) Bind(name string, router restate.Router) *Restate { - if _, ok := r.routers[name]; ok { +func (r *Restate) Bind(router restate.Router) *Restate { + if _, ok := r.routers[router.Name()]; ok { // panic because this is a programming error // to register multiple router with the same name panic("router with the same name exists") } - r.routers[name] = router + r.routers[router.Name()] = router return r }