From adae520683bb9ba166e7535bf18ad5c0df8074d9 Mon Sep 17 00:00:00 2001 From: Jack Kleeman Date: Fri, 9 Aug 2024 13:38:17 +0100 Subject: [PATCH] Support passing headers in calls (#17) --- internal/options/options.go | 3 +- internal/state/call.go | 57 +++++++++++++++++++++++++++++++------ internal/state/state.go | 4 +-- options.go | 17 ++++++++++- 4 files changed, 69 insertions(+), 12 deletions(-) diff --git a/internal/options/options.go b/internal/options/options.go index 0aeaa02..f278ae3 100644 --- a/internal/options/options.go +++ b/internal/options/options.go @@ -35,7 +35,8 @@ type SetOption interface { } type CallOptions struct { - Codec encoding.Codec + Codec encoding.Codec + Headers map[string]string } type CallOption interface { diff --git a/internal/state/call.go b/internal/state/call.go index 7a2564e..16e67a9 100644 --- a/internal/state/call.go +++ b/internal/state/call.go @@ -2,7 +2,9 @@ package state import ( "bytes" + "cmp" "fmt" + "slices" "time" restate "github.com/restatedev/sdk-go" @@ -29,7 +31,7 @@ func (c *serviceCall) RequestFuture(input any) (restate.ResponseFuture, error) { return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal RequestFuture input: %w", err)) } - entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, bytes) + entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, c.options.Headers, bytes) return decodingResponseFuture{ futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex, func(err error) any { return c.machine.newProtocolViolation(entry, err) }), @@ -70,22 +72,26 @@ func (c *serviceCall) Send(input any, delay time.Duration) error { if err != nil { return errors.NewTerminalError(fmt.Errorf("failed to marshal Send input: %w", err)) } - c.machine.sendCall(c.service, c.key, c.method, bytes, delay) + c.machine.sendCall(c.service, c.key, c.method, c.options.Headers, bytes, delay) return nil } -func (m *Machine) doCall(service, key, method string, params []byte) (*wire.CallEntryMessage, uint32) { +func (m *Machine) doCall(service, key, method string, headersMap map[string]string, params []byte) (*wire.CallEntryMessage, uint32) { + headers := headersToProto(headersMap) + entry, entryIndex := replayOrNew( m, func(entry *wire.CallEntryMessage) *wire.CallEntryMessage { if entry.ServiceName != service || entry.Key != key || entry.HandlerName != method || + !headersEqual(entry.Headers, headers) || !bytes.Equal(entry.Parameter, params) { panic(m.newEntryMismatch(&wire.CallEntryMessage{ CallEntryMessage: protocol.CallEntryMessage{ ServiceName: service, HandlerName: method, + Headers: headers, Parameter: params, Key: key, }, @@ -94,17 +100,18 @@ func (m *Machine) doCall(service, key, method string, params []byte) (*wire.Call return entry }, func() *wire.CallEntryMessage { - return m._doCall(service, key, method, params) + return m._doCall(service, key, method, headers, params) }) return entry, entryIndex } -func (m *Machine) _doCall(service, key, method string, params []byte) *wire.CallEntryMessage { +func (m *Machine) _doCall(service, key, method string, headers []*protocol.Header, params []byte) *wire.CallEntryMessage { msg := &wire.CallEntryMessage{ CallEntryMessage: protocol.CallEntryMessage{ ServiceName: service, HandlerName: method, Parameter: params, + Headers: headers, Key: key, }, } @@ -113,18 +120,51 @@ func (m *Machine) _doCall(service, key, method string, params []byte) *wire.Call return msg } -func (m *Machine) sendCall(service, key, method string, body []byte, delay time.Duration) { +func headersEqual(left, right []*protocol.Header) bool { + if len(left) != len(right) { + return false + } + for i := range left { + if left[i].Key != right[i].Key || left[i].Value != right[i].Value { + return false + } + } + return true +} + +func headersToProto(headers map[string]string) []*protocol.Header { + if len(headers) == 0 { + return nil + } + + h := make([]*protocol.Header, 0, len(headers)) + for k, v := range headers { + h = append(h, &protocol.Header{Key: k, Value: v}) + } + + slices.SortFunc(h, func(a, b *protocol.Header) int { + return cmp.Compare(a.Key, b.Key) + }) + + return h +} + +func (m *Machine) sendCall(service, key, method string, headersMap map[string]string, body []byte, delay time.Duration) { + headers := headersToProto(headersMap) + _, _ = replayOrNew( m, func(entry *wire.OneWayCallEntryMessage) restate.Void { if entry.ServiceName != service || entry.Key != key || entry.HandlerName != method || + !headersEqual(entry.Headers, headers) || !bytes.Equal(entry.Parameter, body) { panic(m.newEntryMismatch(&wire.OneWayCallEntryMessage{ OneWayCallEntryMessage: protocol.OneWayCallEntryMessage{ ServiceName: service, HandlerName: method, + Headers: headers, Parameter: body, Key: key, }, @@ -134,13 +174,13 @@ func (m *Machine) sendCall(service, key, method string, body []byte, delay time. return restate.Void{} }, func() restate.Void { - m._sendCall(service, key, method, body, delay) + m._sendCall(service, key, method, headers, body, delay) return restate.Void{} }, ) } -func (c *Machine) _sendCall(service, key, method string, params []byte, delay time.Duration) { +func (c *Machine) _sendCall(service, key, method string, headers []*protocol.Header, params []byte, delay time.Duration) { var invokeTime uint64 if delay != 0 { invokeTime = uint64(time.Now().Add(delay).UnixMilli()) @@ -150,6 +190,7 @@ func (c *Machine) _sendCall(service, key, method string, params []byte, delay ti OneWayCallEntryMessage: protocol.OneWayCallEntryMessage{ ServiceName: service, HandlerName: method, + Headers: headers, Parameter: params, Key: key, InvokeTime: invokeTime, diff --git a/internal/state/state.go b/internal/state/state.go index da84468..3f26737 100644 --- a/internal/state/state.go +++ b/internal/state/state.go @@ -553,7 +553,8 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { return err } - if _, ok := msg.(*wire.InputEntryMessage); !ok { + inputMsg, ok := msg.(*wire.InputEntryMessage) + if !ok { return wire.ErrUnexpectedMessage } @@ -587,7 +588,6 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error { go m.handleCompletionsAcks() - inputMsg := msg.(*wire.InputEntryMessage) m.request.Body = inputMsg.GetValue() if len(inputMsg.GetHeaders()) > 0 { diff --git a/options.go b/options.go index 0ba9e17..52a0bf9 100644 --- a/options.go +++ b/options.go @@ -56,7 +56,7 @@ func (w withPayloadCodec) BeforeObject(opts *options.ObjectOptions) { opts.DefaultCodec = w.codec } -// withPayloadCodec is an option that can be provided to handler/service options +// WithPayloadCodec is an option that can be provided to handler/service options // in order to specify a custom [encoding.PayloadCodec] with which to (de)serialise and // set content-types instead of the default of JSON. // @@ -73,3 +73,18 @@ var WithBinary = WithPayloadCodec(encoding.BinaryCodec) // WithJSON is an option to specify the use of [encoding.JsonCodec] for (de)serialisation var WithJSON = WithPayloadCodec(encoding.JSONCodec) + +type withHeaders struct { + headers map[string]string +} + +var _ options.CallOption = withHeaders{} + +func (w withHeaders) BeforeCall(opts *options.CallOptions) { + opts.Headers = w.headers +} + +// WithHeaders is an option to specify outgoing headers when making a call +func WithHeaders(headers map[string]string) withHeaders { + return withHeaders{headers} +}