Skip to content
10 changes: 2 additions & 8 deletions pkg/async-gateway/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,14 @@ func (e *Endpoint) CreateWorkload(w http.ResponseWriter, r *http.Request) {
return
}

contentType := r.Header.Get("Content-Type")
if contentType == "" {
respondPlainText(w, http.StatusBadRequest, "error: missing Content-Type key in request header")
return
}

body := r.Body
defer func() {
_ = r.Body.Close()
}()

log := e.logger.With(zap.String("id", requestID), zap.String("contentType", contentType))
log := e.logger.With(zap.String("id", requestID))

id, err := e.service.CreateWorkload(requestID, body, contentType)
id, err := e.service.CreateWorkload(requestID, body, r.Header)
if err != nil {
respondPlainText(w, http.StatusInternalServerError, fmt.Sprintf("error: %v", err))
logErrorWithTelemetry(log, errors.Wrap(err, "failed to create workload"))
Expand Down
29 changes: 22 additions & 7 deletions pkg/async-gateway/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@ limitations under the License.
package gateway

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"

"github.com/cortexlabs/cortex/pkg/lib/errors"
"github.com/cortexlabs/cortex/pkg/types/async"
"go.uber.org/zap"
)

// Service provides an interface to the async-gateway business logic
type Service interface {
CreateWorkload(id string, payload io.Reader, contentType string) (string, error)
CreateWorkload(id string, payload io.Reader, headers http.Header) (string, error)
GetWorkload(id string) (GetWorkloadResponse, error)
}

Expand All @@ -52,25 +55,37 @@ func NewService(clusterUID, apiName string, queue Queue, storage Storage, logger
}

// CreateWorkload enqueues an async workload request and uploads the request payload to S3
func (s *service) CreateWorkload(id string, payload io.Reader, contentType string) (string, error) {
func (s *service) CreateWorkload(id string, payload io.Reader, headers http.Header) (string, error) {
prefix := async.StoragePath(s.clusterUID, s.apiName)
log := s.logger.With(zap.String("id", id), zap.String("contentType", contentType))
log := s.logger.With(zap.String("id", id))

buf := &bytes.Buffer{}
if err := json.NewEncoder(buf).Encode(headers); err != nil {
return "", errors.Wrap(err, "failed to dump headers")
}

headersPath := async.HeadersPath(prefix, id)
log.Debugw("uploading headers", zap.String("path", headersPath))
if err := s.storage.Upload(headersPath, buf, "application/json"); err != nil {
return "", errors.Wrap(err, "failed to upload headers")
}

contentType := headers.Get("Content-Type")
payloadPath := async.PayloadPath(prefix, id)
log.Debug("uploading payload", zap.String("path", payloadPath))
log.Debugw("uploading payload", zap.String("path", payloadPath))
if err := s.storage.Upload(payloadPath, payload, contentType); err != nil {
return "", err
return "", errors.Wrap(err, "failed to upload payload")
}

log.Debug("sending message to queue")
if err := s.queue.SendMessage(id, id); err != nil {
return "", err
return "", errors.Wrap(err, "failed to send message to queue")
}

statusPath := fmt.Sprintf("%s/%s/status/%s", prefix, id, async.StatusInQueue)
log.Debug(fmt.Sprintf("setting status to %s", async.StatusInQueue))
if err := s.storage.Upload(statusPath, strings.NewReader(""), "text/plain"); err != nil {
return "", err
return "", errors.Wrap(err, "failed to upload workload status")
}

return id, nil
Expand Down
51 changes: 30 additions & 21 deletions pkg/dequeuer/async_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ type AsyncMessageHandlerConfig struct {
TargetURL string
}

type userPayload struct {
Body io.ReadCloser
ContentType string
}

func NewAsyncMessageHandler(config AsyncMessageHandlerConfig, awsClient *awslib.Client, eventHandler RequestEventHandler, logger *zap.SugaredLogger) *AsyncMessageHandler {
return &AsyncMessageHandler{
config: config,
Expand Down Expand Up @@ -104,9 +99,21 @@ func (h *AsyncMessageHandler) handleMessage(requestID string) error {
}
return errors.Wrap(err, "failed to get payload")
}
defer h.deletePayload(requestID)
defer func() {
h.deletePayload(requestID)
_ = payload.Close()
}()

result, err := h.submitRequest(payload, requestID)
headers, err := h.getHeaders(requestID)
if err != nil {
updateStatusErr := h.updateStatus(requestID, async.StatusFailed)
if updateStatusErr != nil {
h.log.Errorw("failed to update status after failure to get headers", "id", requestID, "error", updateStatusErr)
}
return errors.Wrap(err, "failed to get payload")
}

result, err := h.submitRequest(payload, headers, requestID)
if err != nil {
h.log.Errorw("failed to submit request to user container", "id", requestID, "error", err)
updateStatusErr := h.updateStatus(requestID, async.StatusFailed)
Expand Down Expand Up @@ -138,7 +145,7 @@ func (h *AsyncMessageHandler) updateStatus(requestID string, status async.Status
return h.aws.UploadStringToS3("", h.config.Bucket, key)
}

func (h *AsyncMessageHandler) getPayload(requestID string) (*userPayload, error) {
func (h *AsyncMessageHandler) getPayload(requestID string) (io.ReadCloser, error) {
key := async.PayloadPath(h.storagePath, requestID)
output, err := h.aws.S3().GetObject(
&s3.GetObjectInput{
Expand All @@ -149,16 +156,7 @@ func (h *AsyncMessageHandler) getPayload(requestID string) (*userPayload, error)
if err != nil {
return nil, errors.WithStack(err)
}

contentType := "application/octet-stream"
if output.ContentType != nil {
contentType = *output.ContentType
}

return &userPayload{
Body: output.Body,
ContentType: contentType,
}, nil
return output.Body, nil
}

func (h *AsyncMessageHandler) deletePayload(requestID string) {
Expand All @@ -170,13 +168,13 @@ func (h *AsyncMessageHandler) deletePayload(requestID string) {
}
}

func (h *AsyncMessageHandler) submitRequest(payload *userPayload, requestID string) (interface{}, error) {
req, err := http.NewRequest(http.MethodPost, h.config.TargetURL, payload.Body)
func (h *AsyncMessageHandler) submitRequest(payload io.Reader, headers http.Header, requestID string) (interface{}, error) {
req, err := http.NewRequest(http.MethodPost, h.config.TargetURL, payload)
if err != nil {
return nil, errors.WithStack(err)
}

req.Header.Set("Content-Type", payload.ContentType)
req.Header = headers
req.Header.Set(CortexRequestIDHeader, requestID)

startTime := time.Now()
Expand Down Expand Up @@ -216,3 +214,14 @@ func (h *AsyncMessageHandler) uploadResult(requestID string, result interface{})
key := async.ResultPath(h.storagePath, requestID)
return h.aws.UploadJSONToS3(result, h.config.Bucket, key)
}

func (h *AsyncMessageHandler) getHeaders(requestID string) (http.Header, error) {
key := async.HeadersPath(h.storagePath, requestID)

var headers http.Header
if err := h.aws.ReadJSONFromS3(&headers, h.config.Bucket, key); err != nil {
return nil, err
}

return headers, nil
}
5 changes: 4 additions & 1 deletion pkg/dequeuer/async_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ func TestAsyncMessageHandler_Handle(t *testing.T) {
})
require.NoError(t, err)

err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, fmt.Sprintf("%s/%s/payload", asyncHandler.storagePath, requestID))
err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, async.PayloadPath(asyncHandler.storagePath, requestID))
require.NoError(t, err)

err = awsClient.UploadStringToS3("{}", asyncHandler.config.Bucket, async.HeadersPath(asyncHandler.storagePath, requestID))
require.NoError(t, err)

err = asyncHandler.Handle(&sqs.Message{
Expand Down
4 changes: 4 additions & 0 deletions pkg/types/async/s3_paths.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ func PayloadPath(storagePath string, requestID string) string {
return fmt.Sprintf("%s/%s/payload", storagePath, requestID)
}

func HeadersPath(storagePath string, requestID string) string {
return fmt.Sprintf("%s/%s/headers.json", storagePath, requestID)
}

func ResultPath(storagePath string, requestID string) string {
return fmt.Sprintf("%s/%s/result.json", storagePath, requestID)
}
Expand Down
2 changes: 1 addition & 1 deletion test/e2e/tests/aws/test_autoscaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


@pytest.mark.usefixtures("client")
@pytest.mark.parametrize("apis", TEST_APIS)
@pytest.mark.parametrize("apis", TEST_APIS, ids=[api["primary"] for api in TEST_APIS])
def test_autoscaling(printer: Callable, config: Dict, client: cx.Client, apis: Dict[str, Any]):
skip_autoscaling_test = config["global"].get("skip_autoscaling", False)
if skip_autoscaling_test:
Expand Down
4 changes: 1 addition & 3 deletions test/e2e/tests/aws/test_realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,8 @@ def test_realtime_api(printer: Callable, config: Dict, client: cx.Client, api: D


@pytest.mark.usefixtures("client")
@pytest.mark.parametrize("api", TEST_APIS_ARM)
@pytest.mark.parametrize("api", TEST_APIS_ARM, ids=[api["name"] for api in TEST_APIS_ARM])
def test_realtime_api_arm(printer: Callable, config: Dict, client: cx.Client, api: Dict[str, str]):

printer(f"testing {api['name']}")
e2e.tests.test_realtime_api(
printer=printer,
client=client,
Expand Down