From c3312026c75d0a57e4fef8e27e8ecaef68ddb886 Mon Sep 17 00:00:00 2001 From: Neil Xie Date: Tue, 2 Apr 2024 09:06:12 -0700 Subject: [PATCH] Add more tests for OS client --- .../elasticsearch/client/os2/client_test.go | 320 ++++++++++++++++++ 1 file changed, 320 insertions(+) diff --git a/common/elasticsearch/client/os2/client_test.go b/common/elasticsearch/client/os2/client_test.go index a4e643495ce..e579bf36602 100644 --- a/common/elasticsearch/client/os2/client_test.go +++ b/common/elasticsearch/client/os2/client_test.go @@ -26,6 +26,8 @@ import ( "bytes" "context" "crypto/tls" + "encoding/json" + "fmt" "io" "net/http" "net/http/httptest" @@ -40,6 +42,13 @@ import ( "github.com/uber/cadence/common/log/testlogger" ) +type MockTransport struct{} + +func (m *MockTransport) Perform(req *http.Request) (*http.Response, error) { + // Simulate a network or connection error + return nil, fmt.Errorf("forced connection error") +} + func TestNewClient(t *testing.T) { logger := testlogger.New(t) testServer := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -288,3 +297,314 @@ func TestCloseBody(t *testing.T) { _, err = osResponse.Body.Read(make([]byte, 1)) assert.Error(t, err, "Expected response body to be closed after calling closeBody") } + +func TestPutMapping(t *testing.T) { + testCases := []struct { + name string + handler http.HandlerFunc + index string + body string + expectedErr bool + }{ + { + name: "Successful PutMapping", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }, + index: "testIndex", + body: `{"properties": {"field": {"type": "text"}}}`, + expectedErr: false, + }, + { + name: "Failed PutMapping", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }, + index: "nonExistentIndex", + body: `{"properties": {"field": {"type": "text"}}}`, + expectedErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + + err := os2Client.PutMapping(context.Background(), tc.index, tc.body) + + if tc.expectedErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestPutMappingError(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + } + + os2Client, testServer := getSecureMockOS2Client(t, http.HandlerFunc(handler), true) + defer testServer.Close() + os2Client.client.Transport = &MockTransport{} + err := os2Client.PutMapping(context.Background(), "testIndex", `{"properties": {"field": {"type": "text"}}}`) + assert.Error(t, err) +} + +func TestIsNotFoundError(t *testing.T) { + testCases := []struct { + name string + handler http.HandlerFunc + expected bool + }{ + { + name: "NotFound error", + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": map[string]interface{}{}, + "status": 404, + }) + }), + expected: true, + }, + { + name: "Other error", + handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Bad Request", http.StatusBadRequest) + }), + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + err := os2Client.CreateIndex(context.Background(), "testIndex") + res := os2Client.IsNotFoundError(err) + assert.Equal(t, tc.expected, res) + }) + } +} + +func TestCount(t *testing.T) { + testCases := []struct { + name string + handler http.HandlerFunc + index string + query string + expectedCount int64 + expectError bool + }{ + { + name: "Successful Count", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, `{"count": 42}`) + }, + index: "testIndex", + query: "{}", + expectedCount: 42, + expectError: false, + }, + { + name: "OpenSearch Error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, `{"error": "Internal Server Error"}`) + }, + index: "testIndex", + query: "{}", + expectError: true, + }, + { + name: "Decoding Error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, `{"count": "should be an int64"}`) + }, + index: "testIndex", + query: "{}", + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + + count, err := os2Client.Count(context.Background(), tc.index, tc.query) + + if tc.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.expectedCount, count) + } + }) + } +} + +func TestScroll(t *testing.T) { + testCases := []struct { + name string + scrollID string + handler http.HandlerFunc + expectError bool + expectedScrollID string // Add more fields as needed for assertions + }{ + { + name: "Initial Search Request", + scrollID: "", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"_scroll_id": "scrollID123", "took": 10, "hits": {"total": {"value": 2}, "hits": [{"_source": {"field1": "value1"}}]}}`) + }, + expectError: false, + expectedScrollID: "scrollID123", + }, + { + name: "Subsequent Scroll Request", + scrollID: "existingScrollID", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"_scroll_id": "scrollID456", "took": 5, "hits": {"total": {"value": 1}, "hits": [{"_source": {"field2": "value2"}}]}}`) + }, + expectError: false, + expectedScrollID: "scrollID456", + }, + { + name: "Error Response", + scrollID: "", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, `{"error": "Internal Server Error"}`) + }, + expectError: true, + }, + { + name: "No More Hits", + scrollID: "someScrollID", + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"_scroll_id": "scrollIDNoHits", "took": 5, "hits": {"hits": []}}`) + }, + expectError: false, + expectedScrollID: "scrollIDNoHits", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + + resp, err := os2Client.Scroll(context.Background(), "testIndex", "{}", tc.scrollID) + + if tc.expectError { + assert.Error(t, err) + } else if tc.name == "No More Hits" { + assert.Equal(t, io.EOF, err, "Expected io.EOF error for no more hits") + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, tc.expectedScrollID, resp.ScrollID) + } + }) + } +} + +func TestClearScroll(t *testing.T) { + testCases := []struct { + name string + scrollID string + handler http.HandlerFunc + expectedError bool + }{ + { + name: "Successful Scroll Clear", + scrollID: "testScrollID", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, `{}`) + }, + expectedError: false, + }, + { + name: "OpenSearch Server Error", + scrollID: "testScrollID", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, `{"error": {"root_cause": [{"type": "internal_server_error","reason": "Internal server error"}],"type": "internal_server_error","reason": "Internal server error"}}`) + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + + err := os2Client.ClearScroll(context.Background(), tc.scrollID) + + if tc.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestSearch(t *testing.T) { + testCases := []struct { + name string + index string + body string + handler http.HandlerFunc + expectedError bool + expectedHits int + }{ + { + name: "Successful Search", + index: "testIndex", + body: `{"query": {"match_all": {}}}`, + handler: func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"took": 10, "hits": {"total": {"value": 2}, "hits": [{"_source": {"field": "value"}}, {"_source": {"field": "another value"}}]}}`) + }, + expectedError: false, + expectedHits: 2, + }, + { + name: "OpenSearch Error", + index: "testIndex", + body: `{"query": {"match_all": {}}}`, + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintln(w, `{"error": "Bad request"}`) + }, + expectedError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + os2Client, testServer := getSecureMockOS2Client(t, tc.handler, true) + defer testServer.Close() + + resp, err := os2Client.Search(context.Background(), tc.index, tc.body) + + if tc.expectedError { + assert.Error(t, err) + assert.Nil(t, resp) + } else { + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Len(t, resp.Hits.Hits, tc.expectedHits) + } + }) + } +}