Skip to content

Commit

Permalink
chore: enforce max limit for webhook (#4975)
Browse files Browse the repository at this point in the history
  • Loading branch information
achettyiitr authored Aug 6, 2024
1 parent 77b75fb commit 377887a
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 9 deletions.
3 changes: 3 additions & 0 deletions gateway/webhook/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (

"github.com/rudderlabs/rudder-go-kit/logger"
"github.com/rudderlabs/rudder-go-kit/stats"

gwstats "github.com/rudderlabs/rudder-server/gateway/internal/stats"
gwtypes "github.com/rudderlabs/rudder-server/gateway/internal/types"
"github.com/rudderlabs/rudder-server/gateway/webhook/model"
Expand Down Expand Up @@ -56,6 +57,8 @@ func Setup(gwHandle Gateway, transformerFeaturesService transformer.FeaturesServ
maxTransformerProcess := config.GetIntVar(64, 1, "Gateway.webhook.maxTransformerProcess")
// Parse all query params from sources mentioned in this list
webhook.config.sourceListForParsingParams = config.GetStringSliceVar([]string{"Shopify", "adjust"}, "Gateway.webhook.sourceListForParsingParams")
// Maximum request size to gateway
webhook.config.maxReqSize = config.GetReloadableIntVar(4000, 1024, "Gateway.maxReqSizeInKB")

webhook.config.forwardGetRequestForSrcMap = lo.SliceToMap(
config.GetStringSliceVar([]string{"adjust"}, "Gateway.webhook.forwardGetRequestForSrcs"),
Expand Down
8 changes: 8 additions & 0 deletions gateway/webhook/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ type HandleT struct {
backgroundCancel context.CancelFunc

config struct {
maxReqSize config.ValueLoader[int]
webhookBatchTimeout config.ValueLoader[time.Duration]
maxWebhookBatchSize config.ValueLoader[int]
sourceListForParsingParams []string
Expand Down Expand Up @@ -334,6 +335,13 @@ func (bt *batchWebhookTransformerT) batchTransformLoop() {
req.done <- transformerResponse{Err: response.GetStatus(response.InvalidJSON)}
continue
}
if len(body) > bt.webhook.config.maxReqSize.Load() {
req.done <- transformerResponse{
StatusCode: response.GetErrorStatusCode(response.RequestBodyTooLarge),
Err: response.GetStatus(response.RequestBodyTooLarge),
}
continue
}

payload, err := sourceTransformAdapter.getTransformerEvent(req.authContext, body)
if err != nil {
Expand Down
62 changes: 53 additions & 9 deletions gateway/webhook/webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -15,13 +16,16 @@ import (

"go.uber.org/mock/gomock"

"github.com/rudderlabs/rudder-go-kit/bytesize"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/rudderlabs/rudder-go-kit/config"
"github.com/rudderlabs/rudder-go-kit/logger"
"github.com/rudderlabs/rudder-go-kit/stats"
"github.com/rudderlabs/rudder-go-kit/stats/memstats"

gwStats "github.com/rudderlabs/rudder-server/gateway/internal/stats"
gwtypes "github.com/rudderlabs/rudder-server/gateway/internal/types"
mockWebhook "github.com/rudderlabs/rudder-server/gateway/mocks"
Expand Down Expand Up @@ -61,11 +65,11 @@ type mockSourceTransformAdapter struct {
url string
}

func (v0 *mockSourceTransformAdapter) getTransformerEvent(authCtx *gwtypes.AuthRequestContext, body []byte) ([]byte, error) {
func (v0 *mockSourceTransformAdapter) getTransformerEvent(_ *gwtypes.AuthRequestContext, body []byte) ([]byte, error) {
return body, nil
}

func (v0 *mockSourceTransformAdapter) getTransformerURL(sourceType string) (string, error) {
func (v0 *mockSourceTransformAdapter) getTransformerURL(string) (string, error) {
return v0.url, nil
}

Expand All @@ -77,13 +81,53 @@ func getMockSourceTransformAdapterFunc(url string) func(ctx context.Context) (so
}
}

func TestWebhookMaxRequestSize(t *testing.T) {
initWebhook()

ctrl := gomock.NewController(t)

mockGW := mockWebhook.NewMockGateway(ctrl)
mockGW.EXPECT().TrackRequestMetrics(gomock.Any()).Times(1)
mockGW.EXPECT().NewSourceStat(gomock.Any(), gomock.Any()).Return(&gwStats.SourceStat{}).Times(1)

mockTransformerFeaturesService := mock_features.NewMockFeaturesService(ctrl)

maxReqSizeInKB := 1

webhookHandler := Setup(mockGW, mockTransformerFeaturesService, stats.NOP, func(bt *batchWebhookTransformerT) {
bt.sourceTransformAdapter = func(ctx context.Context) (sourceTransformAdapter, error) {
return &mockSourceTransformAdapter{}, nil
}
})
webhookHandler.config.maxReqSize = config.SingleValueLoader(maxReqSizeInKB)
t.Cleanup(func() {
_ = webhookHandler.Shutdown()
})

webhookHandler.Register(sourceDefName)

payload := fmt.Sprintf(`{"hello":"world", "data": %q}`, strings.Repeat("a", 2*maxReqSizeInKB*int(bytesize.KB)))
require.Greater(t, len(payload), maxReqSizeInKB*int(bytesize.KB))

req := httptest.NewRequest(http.MethodPost, "/v1/webhook", bytes.NewBufferString(payload))
resp := httptest.NewRecorder()

reqCtx := context.WithValue(req.Context(), gwtypes.CtxParamCallType, "webhook")
reqCtx = context.WithValue(reqCtx, gwtypes.CtxParamAuthRequestContext, &gwtypes.AuthRequestContext{
SourceDefName: sourceDefName,
})

webhookHandler.RequestHandler(resp, req.WithContext(reqCtx))
require.Equal(t, http.StatusRequestEntityTooLarge, resp.Result().StatusCode)
}

func TestWebhookBlockTillFeaturesAreFetched(t *testing.T) {
initWebhook()
ctrl := gomock.NewController(t)
mockGW := mockWebhook.NewMockGateway(ctrl)
mockTransformerFeaturesService := mock_features.NewMockFeaturesService(ctrl)
mockTransformerFeaturesService.EXPECT().Wait().Return(make(chan struct{})).Times(1)
webhookHandler := Setup(mockGW, mockTransformerFeaturesService, stats.Default)
webhookHandler := Setup(mockGW, mockTransformerFeaturesService, stats.NOP)

mockGW.EXPECT().TrackRequestMetrics(gomock.Any()).Times(1)
mockGW.EXPECT().NewSourceStat(gomock.Any(), gomock.Any()).Return(&gwStats.SourceStat{}).Times(1)
Expand Down Expand Up @@ -112,7 +156,7 @@ func TestWebhookRequestHandlerWithTransformerBatchGeneralError(t *testing.T) {
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, sampleError, http.StatusBadRequest)
}))
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) {
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) {
bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL)
})

Expand Down Expand Up @@ -157,7 +201,7 @@ func TestWebhookRequestHandlerWithTransformerBatchPayloadLengthMismatchError(t *
respBody, _ := json.Marshal(responses)
_, _ = w.Write(respBody)
}))
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) {
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) {
bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL)
})

Expand Down Expand Up @@ -200,7 +244,7 @@ func TestWebhookRequestHandlerWithTransformerRequestError(t *testing.T) {
respBody, _ := json.Marshal(responses)
_, _ = w.Write(respBody)
}))
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) {
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) {
bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL)
})

Expand Down Expand Up @@ -243,7 +287,7 @@ func TestWebhookRequestHandlerWithOutputToSource(t *testing.T) {
respBody, _ := json.Marshal(responses)
_, _ = w.Write(respBody)
}))
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) {
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) {
bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL)
})
mockGW.EXPECT().TrackRequestMetrics("").Times(1)
Expand Down Expand Up @@ -285,7 +329,7 @@ func TestWebhookRequestHandlerWithOutputToGateway(t *testing.T) {
respBody, _ := json.Marshal(responses)
_, _ = w.Write(respBody)
}))
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) {
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) {
bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL)
})
mockGW.EXPECT().TrackRequestMetrics("").Times(1)
Expand Down Expand Up @@ -332,7 +376,7 @@ func TestWebhookRequestHandlerWithOutputToGatewayAndSource(t *testing.T) {
respBody, _ := json.Marshal(responses)
_, _ = w.Write(respBody)
}))
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.Default, func(bt *batchWebhookTransformerT) {
webhookHandler := Setup(mockGW, transformer.NewNoOpService(), stats.NOP, func(bt *batchWebhookTransformerT) {
bt.sourceTransformAdapter = getMockSourceTransformAdapterFunc(transformerServer.URL)
})
mockGW.EXPECT().TrackRequestMetrics("").Times(1)
Expand Down

0 comments on commit 377887a

Please sign in to comment.