Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API Handshake for Connect #83

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ func (h *handler) Connect(ctx context.Context) error {
APIBaseUrl: h.GetAPIBaseURL(),
IsDev: h.isDev(),
DevServerUrl: DevServerURL(),
ConnectUrls: h.ConnectURLs,
InstanceId: h.InstanceId,
BuildId: h.BuildId,
Platform: Ptr(platform()),
Expand Down
51 changes: 29 additions & 22 deletions connect/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ type connectReport struct {
type (
	// connectOpt mutates the connectOpts of a single connect call; connect
	// applies every supplied option, in order, to a zero-valued connectOpts.
	connectOpt func(opts *connectOpts)

	// connectOpts holds the optional settings gathered from connectOpt values.
	connectOpts struct {
		// notifyConnectedChan is presumably populated by
		// withNotifyConnectedChan — NOTE(review): confirm against its body.
		notifyConnectedChan chan struct{}

		// excludeGateways is populated by withExcludeGateways and handed to
		// prepareConnection by connect.
		excludeGateways []string
	}
)

func withNotifyConnectedChan(ch chan struct{}) connectOpt {
Expand All @@ -32,14 +33,20 @@ func withNotifyConnectedChan(ch chan struct{}) connectOpt {
}
}

// withExcludeGateways builds a connectOpt that stores the given values in
// connectOpts.excludeGateways. The slice is stored as-is (not copied), and
// calling it with no arguments stores a nil slice.
func withExcludeGateways(exclude ...string) connectOpt {
	apply := func(target *connectOpts) {
		target.excludeGateways = exclude
	}
	return apply
}

func (h *connectHandler) connect(ctx context.Context, data connectionEstablishData, opts ...connectOpt) {
o := connectOpts{}
for _, opt := range opts {
opt(&o)
}

// Set up connection (including connect handshake protocol)
preparedConn, err := h.prepareConnection(ctx, data)
preparedConn, err := h.prepareConnection(ctx, data, o.excludeGateways)
if err != nil {
h.logger.Error("could not establish connection", "err", err)

Expand All @@ -60,7 +67,7 @@ func (h *connectHandler) connect(ctx context.Context, data connectionEstablishDa
}

// Set up connection lifecycle logic (receiving messages, handling requests, etc.)
err = h.handleConnection(ctx, data, preparedConn.ws, preparedConn.gatewayHost)
err = h.handleConnection(ctx, data, preparedConn.ws, preparedConn.gatewayGroupName)
if err != nil {
h.logger.Error("could not handle connection", "err", err)

Expand Down Expand Up @@ -89,31 +96,33 @@ type connectionEstablishData struct {
}

type connection struct {
ws *websocket.Conn
gatewayHost string
connectionId string
ws *websocket.Conn
gatewayGroupName string
connectionId string
}

func (h *connectHandler) prepareConnection(ctx context.Context, data connectionEstablishData) (connection, error) {
func (h *connectHandler) prepareConnection(ctx context.Context, data connectionEstablishData, excludeGateways []string) (connection, error) {
connectTimeout, cancelConnectTimeout := context.WithTimeout(ctx, 10*time.Second)
defer cancelConnectTimeout()

gatewayHost := h.hostsManager.pickAvailableGateway()
if gatewayHost == "" {
// All gateways have been tried, reset the internal state to retry
h.hostsManager.resetGateways()

return connection{}, reconnectError{fmt.Errorf("no available gateway hosts")}
startRes, err := h.apiClient.start(ctx, data.hashedSigningKey, &connectproto.StartRequest{
ExcludeGateways: excludeGateways,
})
if err != nil {
return connection{}, reconnectError{fmt.Errorf("could not start connection: %w", err)}
}

h.logger.Debug("handshake successful", "gateway_endpoint", startRes.GetGatewayEndpoint(), "gateway_group", startRes.GetGatewayGroup())

gatewayHost := startRes.GetGatewayEndpoint()

// Establish WebSocket connection to one of the gateways
ws, _, err := websocket.Dial(connectTimeout, gatewayHost, &websocket.DialOptions{
Subprotocols: []string{
types.GatewaySubProtocol,
},
})
if err != nil {
h.hostsManager.markUnreachableGateway(gatewayHost)
return connection{}, reconnectError{fmt.Errorf("could not connect to gateway: %w", err)}
}

Expand All @@ -122,19 +131,19 @@ func (h *connectHandler) prepareConnection(ctx context.Context, data connectionE

h.logger.Debug("websocket connection established", "gateway_host", gatewayHost)

err = h.performConnectHandshake(ctx, connectionId.String(), ws, gatewayHost, data)
err = h.performConnectHandshake(ctx, connectionId.String(), ws, startRes, data)
if err != nil {
return connection{}, reconnectError{fmt.Errorf("could not perform connect handshake: %w", err)}
}

return connection{
ws: ws,
gatewayHost: gatewayHost,
connectionId: connectionId.String(),
ws: ws,
gatewayGroupName: startRes.GetGatewayGroup(),
connectionId: connectionId.String(),
}, nil
}

func (h *connectHandler) handleConnection(ctx context.Context, data connectionEstablishData, ws *websocket.Conn, gatewayHost string) error {
func (h *connectHandler) handleConnection(ctx context.Context, data connectionEstablishData, ws *websocket.Conn, gatewayGroupName string) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

Expand Down Expand Up @@ -217,8 +226,6 @@ func (h *connectHandler) handleConnection(ctx context.Context, data connectionEs
// - Worker shutdown, parent context got cancelled
if err := eg.Wait(); err != nil && ctx.Err() == nil {
if errors.Is(err, errGatewayDraining) {
h.hostsManager.markDrainingGateway(gatewayHost)

// Gateway is draining and will not accept new connections.
// We must reconnect to a different gateway, only then can we close the old connection.
waitUntilConnected, doneWaiting := context.WithTimeout(context.Background(), 10*time.Second)
Expand All @@ -232,7 +239,7 @@ func (h *connectHandler) handleConnection(ctx context.Context, data connectionEs
}()

// Establish new connection, notify the routine above when the new connection is established
go h.connect(context.Background(), data, withNotifyConnectedChan(notifyConnectedChan))
go h.connect(context.Background(), data, withNotifyConnectedChan(notifyConnectedChan), withExcludeGateways(gatewayGroupName))

// Wait until the new connection is established before closing the old one
<-waitUntilConnected.Done()
Expand Down Expand Up @@ -303,7 +310,7 @@ func (h *connectHandler) withTemporaryConnection(data connectionEstablishData, h
return fmt.Errorf("could not establish connection after %d attempts", maxAttempts)
}

ws, err := h.prepareConnection(context.Background(), data)
ws, err := h.prepareConnection(context.Background(), data, nil)
if err != nil {
attempts++
continue
Expand Down
59 changes: 0 additions & 59 deletions connect/gateway_hosts.go

This file was deleted.

26 changes: 3 additions & 23 deletions connect/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"log/slog"
"os"
"runtime"
"strings"
"sync"
"time"
)
Expand All @@ -35,6 +34,7 @@ func Connect(ctx context.Context, opts Opts, invoker FunctionInvoker, logger *sl
notifyConnectDoneChan: make(chan connectReport),
notifyConnectedChan: make(chan struct{}),
initiateConnectionChan: make(chan struct{}),
apiClient: newWorkerApiClient(opts.APIBaseUrl, opts.Env),
}

wp := NewWorkerPool(ctx, opts.WorkerConcurrency, ch.processExecutorRequest)
Expand Down Expand Up @@ -71,7 +71,6 @@ type Opts struct {
APIBaseUrl string
IsDev bool
DevServerUrl string
ConnectUrls []string

InstanceId *string
BuildId *string
Expand All @@ -91,8 +90,6 @@ type connectHandler struct {
messageBuffer []*connectproto.ConnectMessage
messageBufferLock sync.Mutex

hostsManager *hostsManager

workerPool *workerPool

// Notify when connect finishes (either with an error or because the context got canceled)
Expand All @@ -103,6 +100,8 @@ type connectHandler struct {

// Channel to imperatively initiate a connection
initiateConnectionChan chan struct{}

apiClient *workerApiClient
}

// authContext is wrapper for information related to authentication
Expand Down Expand Up @@ -132,13 +131,6 @@ func (h *connectHandler) Connect(ctx context.Context) error {
return fmt.Errorf("failed to serialize connect config: %w", err)
}

hosts := h.connectURLs()
if len(hosts) == 0 {
return fmt.Errorf("no connect URLs provided")
}

h.hostsManager = newHostsManager(hosts)

var attempts int

// We construct a connection loop, which will attempt to reconnect on failure
Expand Down Expand Up @@ -305,18 +297,6 @@ func (h *connectHandler) processExecutorRequest(msg workerPoolMsg) {
}
}

// connectURLs resolves the list of connect endpoints from the handler options.
//
// Explicitly configured ConnectUrls win. Otherwise, in dev mode, a single URL
// is derived from DevServerUrl by replacing its first "http" with "ws" and
// appending "/connect". In every other case nil is returned.
func (h *connectHandler) connectURLs() []string {
	switch {
	case len(h.opts.ConnectUrls) > 0:
		return h.opts.ConnectUrls
	case h.opts.IsDev:
		// Only the first occurrence is replaced, so "https" becomes "wss".
		wsBase := strings.Replace(h.opts.DevServerUrl, "http", "ws", 1)
		return []string{fmt.Sprintf("%s/connect", wsBase)}
	default:
		return nil
	}
}

func (h *connectHandler) instanceId() string {
if h.opts.InstanceId != nil {
return *h.opts.InstanceId
Expand Down
14 changes: 3 additions & 11 deletions connect/handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,18 @@ func shouldReconnect(err error) bool {
return errors.Is(err, reconnectError{})
}

func (h *connectHandler) performConnectHandshake(ctx context.Context, connectionId string, ws *websocket.Conn, gatewayHost string, data connectionEstablishData) error {
func (h *connectHandler) performConnectHandshake(ctx context.Context, connectionId string, ws *websocket.Conn, startResponse *connectproto.StartResponse, data connectionEstablishData) error {
// Wait for gateway hello message
{
initialMessageTimeout, cancelInitialTimeout := context.WithTimeout(ctx, 5*time.Second)
defer cancelInitialTimeout()
var helloMessage connectproto.ConnectMessage
err := wsproto.Read(initialMessageTimeout, ws, &helloMessage)
if err != nil {
h.hostsManager.markUnreachableGateway(gatewayHost)
return reconnectError{fmt.Errorf("did not receive gateway hello message: %w", err)}
}

if helloMessage.Kind != connectproto.GatewayMessageType_GATEWAY_HELLO {
h.hostsManager.markUnreachableGateway(gatewayHost)
return reconnectError{fmt.Errorf("expected gateway hello message, got %s", helloMessage.Kind)}
}

Expand All @@ -50,26 +48,20 @@ func (h *connectHandler) performConnectHandshake(ctx context.Context, connection

// Send connect message
{

apiOrigin := h.opts.APIBaseUrl
if h.opts.IsDev {
apiOrigin = h.opts.DevServerUrl
}

data, err := proto.Marshal(&connectproto.WorkerConnectRequestData{
SessionId: &connectproto.SessionIdentifier{
BuildId: h.opts.BuildId,
InstanceId: h.instanceId(),
ConnectionId: connectionId,
},
AuthData: &connectproto.AuthData{
HashedSigningKey: data.hashedSigningKey,
SessionToken: startResponse.GetSessionToken(),
SyncToken: startResponse.GetSyncToken(),
},
AppName: h.opts.AppName,
Config: &connectproto.ConfigDetails{
Capabilities: data.marshaledCapabilities,
Functions: data.marshaledFns,
ApiOrigin: apiOrigin,
},
SystemAttributes: &connectproto.SystemAttributes{
CpuCores: data.numCpuCores,
Expand Down
67 changes: 67 additions & 0 deletions connect/workerapi.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package connect

import (
"bytes"
"context"
"fmt"
"github.com/inngest/inngest/proto/gen/connect/v1"
"google.golang.org/protobuf/proto"
"io"
"net/http"
)

// workerApiClient issues HTTP requests against the API used for the connect
// start handshake.
type workerApiClient struct {
	// apiBaseUrl is prepended to the request path ("/v0/connect/start").
	apiBaseUrl string

	// env, when non-nil, is sent as the X-Inngest-Env request header.
	env *string

	// client performs the HTTP round trip; its zero value is usable.
	client http.Client
}

// newWorkerApiClient returns a client that targets apiBaseUrl. env may be nil;
// both values are stored unchanged and the embedded http.Client is left at its
// zero value.
func newWorkerApiClient(apiBaseUrl string, env *string) *workerApiClient {
	c := new(workerApiClient)
	c.apiBaseUrl = apiBaseUrl
	c.env = env
	return c
}

// start performs the HTTP part of the connect handshake.
//
// It POSTs the protobuf-encoded req to "<apiBaseUrl>/v0/connect/start",
// authenticating with hashedSigningKey as a Bearer token and, when the client
// has an env configured, adding it as the X-Inngest-Env header. On a 200
// response the body is decoded into a StartResponse.
//
// An error is returned when the request cannot be built or sent, when the
// status code is not 200 (the error then carries a bounded excerpt of the
// response body, if any), or when the response cannot be read or decoded.
// Cancellation and deadlines are governed solely by ctx.
func (a *workerApiClient) start(ctx context.Context, hashedSigningKey []byte, req *connect.StartRequest) (*connect.StartResponse, error) {
	// Upper bound on how much of an error response is read into the error message.
	const maxErrorBodyBytes = 1 << 10

	reqBody, err := proto.Marshal(req)
	if err != nil {
		return nil, fmt.Errorf("could not marshal start request: %w", err)
	}

	// The payload is only ever read, so a Reader is sufficient.
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("%s/v0/connect/start", a.apiBaseUrl), bytes.NewReader(reqBody))
	if err != nil {
		return nil, fmt.Errorf("could not create start request: %w", err)
	}

	httpReq.Header.Set("Content-Type", "application/protobuf")
	httpReq.Header.Set("Authorization", fmt.Sprintf("Bearer %s", string(hashedSigningKey)))

	if a.env != nil {
		httpReq.Header.Add("X-Inngest-Env", *a.env)
	}

	httpRes, err := a.client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("could not send start request: %w", err)
	}
	defer httpRes.Body.Close()

	if httpRes.StatusCode != http.StatusOK {
		// Best effort: surface what the server said instead of dropping it.
		// A read failure here is deliberately ignored because the status code
		// alone is still reported below.
		detail, _ := io.ReadAll(io.LimitReader(httpRes.Body, maxErrorBodyBytes))
		if msg := bytes.TrimSpace(detail); len(msg) > 0 {
			// %q keeps binary or multi-line bodies readable on a single line.
			return nil, fmt.Errorf("unexpected status code: %d: %q", httpRes.StatusCode, msg)
		}
		return nil, fmt.Errorf("unexpected status code: %d", httpRes.StatusCode)
	}

	byt, err := io.ReadAll(httpRes.Body)
	if err != nil {
		return nil, fmt.Errorf("could not read start response: %w", err)
	}

	res := &connect.StartResponse{}
	if err := proto.Unmarshal(byt, res); err != nil {
		return nil, fmt.Errorf("could not unmarshal start response: %w", err)
	}

	return res, nil
}
Loading
Loading