Skip to content

Commit

Permalink
Fix BulkUpdateDetection Race Condition
Browse files Browse the repository at this point in the history
There are 2 parts to a call to BulkUpdateDetection: the original request, and an async portion that does the bulk of the work. Both end with a log statement but during testing it isn't always clear which part of the call will log first. I've added a WaitSync for the request portion so we can control the timing.

InitMock doesn't have to wait on the WaitSync if the parameters of that test case won't kick off the async portion BUT any test case that does make it to the async portion should wait inside a DoAndReturn of the first possible mocked request (BuildBulkIndexer).

In plain english, once the main request is done it marks the nonAsyncWG as Done while the async portion of the request is waiting to build the bulk indexer. This ensures the async portion begins after the original request ends.
  • Loading branch information
coreyogburn committed Jan 21, 2025
1 parent fbac142 commit 70ddd8b
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions server/detectionshandler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2178,7 +2178,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {
Name string
NewStatus string
ReqBody []byte
InitMock func(*testing.T, *Server, *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster)
InitMock func(*testing.T, *Server, *gomock.Controller, *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster)
Code int
Response any
Logs []EntryMatcher
Expand All @@ -2188,7 +2188,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {
Name: "Sunny Day - IDs",
NewStatus: "enable",
ReqBody: []byte(`{"ids":["123","456","789"]}`),
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) {
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, nonAsyncWG *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) {
mDetStore := srv.Detectionstore.(*servermock.MockDetectionstore)
mAuth := srv.Authorizer.(*rbac.FakeAuthorizer)
mHostAuth := srv.Host.Authorizer.(*rbac.FakeAuthorizer)
Expand Down Expand Up @@ -2225,7 +2225,11 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {

docIndexer := servermock.NewMockBulkIndexer(ctrl)

mDetStore.EXPECT().BuildBulkIndexer(gomock.Any(), gomock.Any()).Return(docIndexer, nil)
mDetStore.EXPECT().BuildBulkIndexer(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, logger *log.Entry) (esutil.BulkIndexer, error) {
nonAsyncWG.Wait()

return docIndexer, nil
})

engElastAlert.EXPECT().ApplyFilters(gomock.Any()).Return(false, nil)
engSuricata.EXPECT().ApplyFilters(gomock.Any()).Return(false, nil)
Expand Down Expand Up @@ -2341,7 +2345,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {
Name: "Sunny Day - Query",
NewStatus: "enable",
ReqBody: []byte(`{"query":"severity: low AND ruleset: ETOPEN"}`),
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) {
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, nonAsyncWG *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) {
mDetStore := srv.Detectionstore.(*servermock.MockDetectionstore)
mAuth := srv.Authorizer.(*rbac.FakeAuthorizer)
mHostAuth := srv.Host.Authorizer.(*rbac.FakeAuthorizer)
Expand Down Expand Up @@ -2380,7 +2384,11 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {

docIndexer := servermock.NewMockBulkIndexer(ctrl)

mDetStore.EXPECT().BuildBulkIndexer(gomock.Any(), gomock.Any()).Return(docIndexer, nil)
mDetStore.EXPECT().BuildBulkIndexer(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, logger *log.Entry) (esutil.BulkIndexer, error) {
nonAsyncWG.Wait()

return docIndexer, nil
})

engElastAlert.EXPECT().ApplyFilters(gomock.Any()).Return(false, nil)
engSuricata.EXPECT().ApplyFilters(gomock.Any()).Return(false, nil)
Expand Down Expand Up @@ -2496,7 +2504,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {
Name: "Cannot Delete Community Rules - Ids",
NewStatus: "delete",
ReqBody: []byte(`{"ids":["123","456","789"]}`),
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) {
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, _ *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) {
mDetStore := srv.Detectionstore.(*servermock.MockDetectionstore)
mAuth := srv.Authorizer.(*rbac.FakeAuthorizer)

Expand Down Expand Up @@ -2530,7 +2538,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {
Name: "Cannot Delete Community Rules - Query",
NewStatus: "delete",
ReqBody: []byte(`{"query":"severity: low AND ruleset: ETOPEN"}`),
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) {
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, _ *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) {
mDetStore := srv.Detectionstore.(*servermock.MockDetectionstore)
mAuth := srv.Authorizer.(*rbac.FakeAuthorizer)

Expand Down Expand Up @@ -2570,7 +2578,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {
Name: "Query Failure",
NewStatus: "enable",
ReqBody: []byte(`{"query":"severity: low AND ruleset: ETOPEN"}`),
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) {
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, _ *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) {
mDetStore := srv.Detectionstore.(*servermock.MockDetectionstore)
mAuth := srv.Authorizer.(*rbac.FakeAuthorizer)

Expand All @@ -2591,7 +2599,7 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {
Name: "Unauthorized",
NewStatus: "disable",
ReqBody: []byte(`{"query":"severity: low AND ruleset: ETOPEN"}`),
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller) (*sync.WaitGroup, *MockBroadcaster) {
InitMock: func(t *testing.T, srv *Server, ctrl *gomock.Controller, _ *sync.WaitGroup) (*sync.WaitGroup, *MockBroadcaster) {
return nil, nil
},
Code: 401,
Expand All @@ -2613,7 +2621,10 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {

h := NewDetectionHandler(srv)

wg, mb := test.InitMock(t, srv, ctrl)
nonAsyncWG := &sync.WaitGroup{}
nonAsyncWG.Add(1)

asyncWG, mb := test.InitMock(t, srv, ctrl, nonAsyncWG)
if mb != nil {
defer mb.Close()
}
Expand All @@ -2632,8 +2643,10 @@ func TestHandlerBulkUpdateDetection(t *testing.T) {
r := httptest.NewRequestWithContext(ctx, "PUT", fmt.Sprintf("/detection/bulk/%s", test.NewStatus), bytes.NewReader(test.ReqBody))

h.BulkUpdateDetection(w, r)
if wg != nil {
wg.Wait()
nonAsyncWG.Done()

if asyncWG != nil {
asyncWG.Wait()
}

assert.Equal(t, test.Code, w.Code)
Expand Down

0 comments on commit 70ddd8b

Please sign in to comment.