Skip to content

Commit

Permalink
Support passing headers in calls (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackkleeman authored Aug 9, 2024
1 parent adafb6e commit adae520
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 12 deletions.
3 changes: 2 additions & 1 deletion internal/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ type SetOption interface {
}

type CallOptions struct {
Codec encoding.Codec
Codec encoding.Codec
Headers map[string]string
}

type CallOption interface {
Expand Down
57 changes: 49 additions & 8 deletions internal/state/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package state

import (
"bytes"
"cmp"
"fmt"
"slices"
"time"

restate "github.com/restatedev/sdk-go"
Expand All @@ -29,7 +31,7 @@ func (c *serviceCall) RequestFuture(input any) (restate.ResponseFuture, error) {
return nil, errors.NewTerminalError(fmt.Errorf("failed to marshal RequestFuture input: %w", err))
}

entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, bytes)
entry, entryIndex := c.machine.doCall(c.service, c.key, c.method, c.options.Headers, bytes)

return decodingResponseFuture{
futures.NewResponseFuture(c.machine.suspensionCtx, entry, entryIndex, func(err error) any { return c.machine.newProtocolViolation(entry, err) }),
Expand Down Expand Up @@ -70,22 +72,26 @@ func (c *serviceCall) Send(input any, delay time.Duration) error {
if err != nil {
return errors.NewTerminalError(fmt.Errorf("failed to marshal Send input: %w", err))
}
c.machine.sendCall(c.service, c.key, c.method, bytes, delay)
c.machine.sendCall(c.service, c.key, c.method, c.options.Headers, bytes, delay)
return nil
}

func (m *Machine) doCall(service, key, method string, params []byte) (*wire.CallEntryMessage, uint32) {
func (m *Machine) doCall(service, key, method string, headersMap map[string]string, params []byte) (*wire.CallEntryMessage, uint32) {
headers := headersToProto(headersMap)

entry, entryIndex := replayOrNew(
m,
func(entry *wire.CallEntryMessage) *wire.CallEntryMessage {
if entry.ServiceName != service ||
entry.Key != key ||
entry.HandlerName != method ||
!headersEqual(entry.Headers, headers) ||
!bytes.Equal(entry.Parameter, params) {
panic(m.newEntryMismatch(&wire.CallEntryMessage{
CallEntryMessage: protocol.CallEntryMessage{
ServiceName: service,
HandlerName: method,
Headers: headers,
Parameter: params,
Key: key,
},
Expand All @@ -94,17 +100,18 @@ func (m *Machine) doCall(service, key, method string, params []byte) (*wire.Call

return entry
}, func() *wire.CallEntryMessage {
return m._doCall(service, key, method, params)
return m._doCall(service, key, method, headers, params)
})
return entry, entryIndex
}

func (m *Machine) _doCall(service, key, method string, params []byte) *wire.CallEntryMessage {
func (m *Machine) _doCall(service, key, method string, headers []*protocol.Header, params []byte) *wire.CallEntryMessage {
msg := &wire.CallEntryMessage{
CallEntryMessage: protocol.CallEntryMessage{
ServiceName: service,
HandlerName: method,
Parameter: params,
Headers: headers,
Key: key,
},
}
Expand All @@ -113,18 +120,51 @@ func (m *Machine) _doCall(service, key, method string, params []byte) *wire.Call
return msg
}

func (m *Machine) sendCall(service, key, method string, body []byte, delay time.Duration) {
func headersEqual(left, right []*protocol.Header) bool {
if len(left) != len(right) {
return false
}
for i := range left {
if left[i].Key != right[i].Key || left[i].Value != right[i].Value {
return false
}
}
return true
}

func headersToProto(headers map[string]string) []*protocol.Header {
if len(headers) == 0 {
return nil
}

h := make([]*protocol.Header, 0, len(headers))
for k, v := range headers {
h = append(h, &protocol.Header{Key: k, Value: v})
}

slices.SortFunc(h, func(a, b *protocol.Header) int {
return cmp.Compare(a.Key, b.Key)
})

return h
}

func (m *Machine) sendCall(service, key, method string, headersMap map[string]string, body []byte, delay time.Duration) {
headers := headersToProto(headersMap)

_, _ = replayOrNew(
m,
func(entry *wire.OneWayCallEntryMessage) restate.Void {
if entry.ServiceName != service ||
entry.Key != key ||
entry.HandlerName != method ||
!headersEqual(entry.Headers, headers) ||
!bytes.Equal(entry.Parameter, body) {
panic(m.newEntryMismatch(&wire.OneWayCallEntryMessage{
OneWayCallEntryMessage: protocol.OneWayCallEntryMessage{
ServiceName: service,
HandlerName: method,
Headers: headers,
Parameter: body,
Key: key,
},
Expand All @@ -134,13 +174,13 @@ func (m *Machine) sendCall(service, key, method string, body []byte, delay time.
return restate.Void{}
},
func() restate.Void {
m._sendCall(service, key, method, body, delay)
m._sendCall(service, key, method, headers, body, delay)
return restate.Void{}
},
)
}

func (c *Machine) _sendCall(service, key, method string, params []byte, delay time.Duration) {
func (c *Machine) _sendCall(service, key, method string, headers []*protocol.Header, params []byte, delay time.Duration) {
var invokeTime uint64
if delay != 0 {
invokeTime = uint64(time.Now().Add(delay).UnixMilli())
Expand All @@ -150,6 +190,7 @@ func (c *Machine) _sendCall(service, key, method string, params []byte, delay ti
OneWayCallEntryMessage: protocol.OneWayCallEntryMessage{
ServiceName: service,
HandlerName: method,
Headers: headers,
Parameter: params,
Key: key,
InvokeTime: invokeTime,
Expand Down
4 changes: 2 additions & 2 deletions internal/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,8 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {
return err
}

if _, ok := msg.(*wire.InputEntryMessage); !ok {
inputMsg, ok := msg.(*wire.InputEntryMessage)
if !ok {
return wire.ErrUnexpectedMessage
}

Expand Down Expand Up @@ -587,7 +588,6 @@ func (m *Machine) process(ctx *Context, start *wire.StartMessage) error {

go m.handleCompletionsAcks()

inputMsg := msg.(*wire.InputEntryMessage)
m.request.Body = inputMsg.GetValue()

if len(inputMsg.GetHeaders()) > 0 {
Expand Down
17 changes: 16 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func (w withPayloadCodec) BeforeObject(opts *options.ObjectOptions) {
opts.DefaultCodec = w.codec
}

// withPayloadCodec is an option that can be provided to handler/service options
// WithPayloadCodec is an option that can be provided to handler/service options
// in order to specify a custom [encoding.PayloadCodec] with which to (de)serialise and
// set content-types instead of the default of JSON.
//
Expand All @@ -73,3 +73,18 @@ var WithBinary = WithPayloadCodec(encoding.BinaryCodec)

// WithJSON is an option to specify the use of [encoding.JsonCodec] for (de)serialisation
var WithJSON = WithPayloadCodec(encoding.JSONCodec)

type withHeaders struct {
headers map[string]string
}

var _ options.CallOption = withHeaders{}

func (w withHeaders) BeforeCall(opts *options.CallOptions) {
opts.Headers = w.headers
}

// WithHeaders is an option to specify outgoing headers when making a call
func WithHeaders(headers map[string]string) withHeaders {
return withHeaders{headers}
}

0 comments on commit adae520

Please sign in to comment.