diff --git a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java index b476859d8..b15053fae 100644 --- a/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java +++ b/src/main/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientInternal.java @@ -226,7 +226,9 @@ public class SnowflakeStreamingIngestClientInternal implements SnowflakeStrea String.format("%s_%s", this.name, System.currentTimeMillis())); logger.logInfo("Using {} for authorization", this.requestBuilder.getAuthType()); + } + if (this.requestBuilder != null) { // Setup client telemetries if needed this.setupMetricsForClient(); } diff --git a/src/test/java/net/snowflake/ingest/TestUtils.java b/src/test/java/net/snowflake/ingest/TestUtils.java index ba14ff610..f512dca86 100644 --- a/src/test/java/net/snowflake/ingest/TestUtils.java +++ b/src/test/java/net/snowflake/ingest/TestUtils.java @@ -25,7 +25,9 @@ import java.security.KeyFactory; import java.security.KeyPair; import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; +import java.security.spec.InvalidKeySpecException; import java.security.spec.PKCS8EncodedKeySpec; import java.sql.Connection; import java.sql.DriverManager; @@ -102,7 +104,7 @@ public class TestUtils { * * @throws IOException if can't read profile */ - private static void init() throws Exception { + private static void init() throws NoSuchAlgorithmException, InvalidKeySpecException, IOException { String testProfilePath = getTestProfilePath(); Path path = Paths.get(testProfilePath); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/ChannelCacheTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/ChannelCacheTest.java index 72e4f8c2f..ef3684774 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/ChannelCacheTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/ChannelCacheTest.java @@ -6,9 +6,12 @@ import static java.time.ZoneOffset.UTC; +import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; +import net.snowflake.ingest.connection.RequestBuilder; import net.snowflake.ingest.streaming.OpenChannelRequest; import org.junit.Assert; import org.junit.Before; @@ -39,7 +42,12 @@ public static Object[] isIcebergMode() { @Before public void setup() { cache = new ChannelCache<>(); - client = new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode); + CloseableHttpClient httpClient = MockSnowflakeServiceClient.createHttpClient(); + RequestBuilder requestBuilder = MockSnowflakeServiceClient.createRequestBuilder(httpClient); + client = + new SnowflakeStreamingIngestClientInternal<>( + "client", null, null, httpClient, isIcebergMode, true, requestBuilder, new HashMap<>()); + channel1 = new SnowflakeStreamingIngestChannelInternal<>( "channel1", diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java index b46f8c800..b5ed0ba96 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/FlushServiceTest.java @@ -97,6 +97,7 @@ private abstract static class TestContext implements AutoCloseable { FlushService flushService; IStorageManager storageManager; InternalStage storage; + ExternalVolume extVolume; ParameterProvider parameterProvider; RegisterService registerService; @@ -104,6 +105,7 @@ private abstract static class TestContext implements AutoCloseable { TestContext() { storage = Mockito.mock(InternalStage.class); + extVolume = Mockito.mock(ExternalVolume.class); parameterProvider = new ParameterProvider(isIcebergMode); InternalParameterProvider internalParameterProvider = new InternalParameterProvider(isIcebergMode); @@ -113,9 +115,12 @@ private abstract static class TestContext implements AutoCloseable { storageManager = Mockito.spy( isIcebergMode - ? new ExternalVolumeManager(true, "role", "client", null) + ? new ExternalVolumeManager( + true, "role", "client", MockSnowflakeServiceClient.create()) : new InternalStageManager(true, "role", "client", null)); - Mockito.doReturn(storage).when(storageManager).getStorage(ArgumentMatchers.any()); + Mockito.doReturn(isIcebergMode ? extVolume : storage) + .when(storageManager) + .getStorage(ArgumentMatchers.any()); Mockito.when(storageManager.getClientPrefix()).thenReturn("client_prefix"); Mockito.when(client.getParameterProvider()) .thenAnswer((Answer) (i) -> parameterProvider); @@ -425,6 +430,7 @@ private static ColumnMetadata createLargeTestTextColumn(String name) { @Test public void testGetFilePath() { + // SNOW-1490151 Iceberg testing gaps if (isIcebergMode) { // TODO: SNOW-1502887 Blob path generation for iceberg table return; @@ -623,6 +629,7 @@ public void testBlobCreation() throws Exception { FlushService flushService = testContext.flushService; // Force = true flushes + // SNOW-1490151 Iceberg testing gaps if (!isIcebergMode) { flushService.flush(true).get(); Mockito.verify(flushService, Mockito.atLeast(2)) @@ -674,6 +681,7 @@ public void testBlobSplitDueToDifferentSchema() throws Exception { FlushService flushService = testContext.flushService; + // SNOW-1490151 Iceberg testing gaps if (!isIcebergMode) { // Force = true flushes flushService.flush(true).get(); @@ -711,6 +719,7 @@ public void testBlobSplitDueToChunkSizeLimit() throws Exception { FlushService flushService = testContext.flushService; + // SNOW-1490151 Iceberg testing gaps if (!isIcebergMode) { // Force = true flushes flushService.flush(true).get(); @@ -721,6 +730,7 @@ public void testBlobSplitDueToChunkSizeLimit() throws Exception { @Test public void testBlobSplitDueToNumberOfChunks() throws Exception { + // SNOW-1490151 Iceberg testing gaps if (isIcebergMode) { return; } @@ -799,6 +809,7 @@ public void testBlobSplitDueToNumberOfChunksWithLeftoverChannels() throws Except channel3.setupSchema(Collections.singletonList(createLargeTestTextColumn("C1"))); channel3.insertRow(Collections.singletonMap("C1", 0), ""); + // SNOW-1490151 Iceberg testing gaps if (isIcebergMode) { return; } diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java index 37e41df5a..ebe6742a2 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/InsertRowsBenchmarkTest.java @@ -10,6 +10,8 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.TimeUnit; +import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; +import net.snowflake.ingest.connection.RequestBuilder; import net.snowflake.ingest.streaming.InsertValidationResponse; import net.snowflake.ingest.streaming.OpenChannelRequest; import net.snowflake.ingest.utils.Utils; @@ -50,9 +52,19 @@ public static Object[] isIcebergMode() { @Setup(Level.Trial) public void setUpBeforeAll() { // SNOW-1490151: Testing gaps + CloseableHttpClient httpClient = MockSnowflakeServiceClient.createHttpClient(); + RequestBuilder requestBuilder = MockSnowflakeServiceClient.createRequestBuilder(httpClient); client = - new SnowflakeStreamingIngestClientInternal( - "client_PARQUET", isIcebergMode); + new SnowflakeStreamingIngestClientInternal<>( + "client_PARQUET", + null, + null, + httpClient, + isIcebergMode, + true, + requestBuilder, + new HashMap<>()); + channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/MockSnowflakeServiceClient.java b/src/test/java/net/snowflake/ingest/streaming/internal/MockSnowflakeServiceClient.java index 26a1921ab..5f8243299 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/MockSnowflakeServiceClient.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/MockSnowflakeServiceClient.java @@ -9,149 +9,231 @@ import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import net.snowflake.client.jdbc.internal.apache.commons.io.IOUtils; import net.snowflake.client.jdbc.internal.apache.http.HttpEntity; import net.snowflake.client.jdbc.internal.apache.http.HttpStatus; import net.snowflake.client.jdbc.internal.apache.http.HttpVersion; import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse; +import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpPost; import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpUriRequest; import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; import net.snowflake.client.jdbc.internal.apache.http.message.BasicStatusLine; +import net.snowflake.client.log.SFLogger; +import net.snowflake.client.log.SFLoggerFactory; import net.snowflake.ingest.TestUtils; import net.snowflake.ingest.connection.RequestBuilder; +import org.apache.commons.lang3.tuple.Pair; import org.mockito.Mockito; import org.mockito.stubbing.Answer; public class MockSnowflakeServiceClient { private static final ObjectMapper objectMapper = new ObjectMapper(); + private static final SFLogger LOGGER = + SFLoggerFactory.getLogger(MockSnowflakeServiceClient.class); + + public static class ApiOverride { + private final Map>>> + apiOverrides = new HashMap<>(); + + public void addMapOverride( + String uriPath, Function>> override) { + apiOverrides.put(uriPath.toUpperCase(), override); + } + + public void addSerializedJsonOverride( + String uriPath, Function> override) { + addMapOverride( + uriPath, + request -> { + Pair pair = override.apply(request); + try { + Map map = objectMapper.readValue(pair.getRight(), HashMap.class); + return Pair.of(pair.getLeft(), map); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + } public static SnowflakeServiceClient create() { - RequestBuilder requestBuilder = null; try { - requestBuilder = new RequestBuilder("test_host", "test_name", TestUtils.getKeyPair()); - return new SnowflakeServiceClient(createHttpClient(), requestBuilder); + CloseableHttpClient httpClient = createHttpClient(new ApiOverride()); + RequestBuilder requestBuilder = createRequestBuilder(httpClient); + return new SnowflakeServiceClient(httpClient, requestBuilder); } catch (Exception e) { throw new RuntimeException(e); } } - private static CloseableHttpClient createHttpClient() throws IOException { - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - - Mockito.doAnswer( - (Answer) - invocation -> { - HttpUriRequest request = invocation.getArgument(0); - switch (request.getURI().getPath()) { - case CLIENT_CONFIGURE_ENDPOINT: - Map clientConfigresponseMap = new HashMap<>(); - clientConfigresponseMap.put("prefix", "test_prefix"); - clientConfigresponseMap.put("status_code", 0L); - clientConfigresponseMap.put("message", "OK"); - clientConfigresponseMap.put("stage_location", getStageLocationMap()); - clientConfigresponseMap.put("deployment_id", 123L); - return buildStreamingIngestResponse(clientConfigresponseMap); - case CHANNEL_CONFIGURE_ENDPOINT: - Map channelConfigResponseMap = new HashMap<>(); - channelConfigResponseMap.put("status_code", 0L); - channelConfigResponseMap.put("message", "OK"); - channelConfigResponseMap.put("stage_location", getStageLocationMap()); - return buildStreamingIngestResponse(channelConfigResponseMap); - case OPEN_CHANNEL_ENDPOINT: - List> tableColumnsLists = new ArrayList<>(); - Map tableColumnMap = new HashMap<>(); - tableColumnMap.put("byteLength", 123L); - tableColumnMap.put("length", 0L); - tableColumnMap.put("logicalType", "test_logical_type"); - tableColumnMap.put("name", "test_column"); - tableColumnMap.put("nullable", true); - tableColumnMap.put("precision", 0L); - tableColumnMap.put("scale", 0L); - tableColumnMap.put("type", "test_type"); - tableColumnMap.put("ordinal", 0L); - tableColumnsLists.add(tableColumnMap); - Map openChannelResponseMap = new HashMap<>(); - openChannelResponseMap.put("status_code", 0L); - openChannelResponseMap.put("message", "OK"); - openChannelResponseMap.put("database", "test_db"); - openChannelResponseMap.put("schema", "test_schema"); - openChannelResponseMap.put("table", "test_table"); - openChannelResponseMap.put("channel", "test_channel"); - openChannelResponseMap.put("client_sequencer", 123L); - openChannelResponseMap.put("row_sequencer", 123L); - openChannelResponseMap.put("offset_token", "test_offset_token"); - openChannelResponseMap.put("table_columns", tableColumnsLists); - openChannelResponseMap.put("encryption_key", "test_encryption_key"); - openChannelResponseMap.put("encryption_key_id", 123L); - openChannelResponseMap.put("iceberg_location", getStageLocationMap()); - return buildStreamingIngestResponse(openChannelResponseMap); - case DROP_CHANNEL_ENDPOINT: - Map dropChannelResponseMap = new HashMap<>(); - dropChannelResponseMap.put("status_code", 0L); - dropChannelResponseMap.put("message", "OK"); - dropChannelResponseMap.put("database", "test_db"); - dropChannelResponseMap.put("schema", "test_schema"); - dropChannelResponseMap.put("table", "test_table"); - dropChannelResponseMap.put("channel", "test_channel"); - return buildStreamingIngestResponse(dropChannelResponseMap); - case CHANNEL_STATUS_ENDPOINT: - List> channelStatusList = new ArrayList<>(); - Map channelStatusMap = new HashMap<>(); - channelStatusMap.put("status_code", 0L); - channelStatusMap.put("persisted_row_sequencer", 123L); - channelStatusMap.put("persisted_client_sequencer", 123L); - channelStatusMap.put("persisted_offset_token", "test_offset_token"); - Map channelStatusResponseMap = new HashMap<>(); - channelStatusResponseMap.put("status_code", 0L); - channelStatusResponseMap.put("message", "OK"); - channelStatusResponseMap.put("channels", channelStatusList); - return buildStreamingIngestResponse(channelStatusResponseMap); - case REGISTER_BLOB_ENDPOINT: - List> channelList = new ArrayList<>(); - Map channelMap = new HashMap<>(); - channelMap.put("status_code", 0L); - channelMap.put("message", "OK"); - channelMap.put("channel", "test_channel"); - channelMap.put("client_sequencer", 123L); - channelList.add(channelMap); - List> chunkList = new ArrayList<>(); - Map chunkMap = new HashMap<>(); - chunkMap.put("channels", channelList); - chunkMap.put("database", "test_db"); - chunkMap.put("schema", "test_schema"); - chunkMap.put("table", "test_table"); - chunkList.add(chunkMap); - List> blobsList = new ArrayList<>(); - Map blobMap = new HashMap<>(); - blobMap.put("chunks", chunkList); - blobsList.add(blobMap); - Map registerBlobResponseMap = new HashMap<>(); - registerBlobResponseMap.put("status_code", 0L); - registerBlobResponseMap.put("message", "OK"); - registerBlobResponseMap.put("blobs", blobsList); - return buildStreamingIngestResponse(registerBlobResponseMap); - default: - assert false; - } - return null; - }) - .when(httpClient) - .execute(Mockito.any()); - - return httpClient; + public static RequestBuilder createRequestBuilder(CloseableHttpClient httpClient) { + try { + return new RequestBuilder( + "test_host", + "test_name", + TestUtils.getKeyPair(), + "https", + "snowflakecomputing.com", + 443, + null, + null, + httpClient, + "mock_client"); + } catch (Exception e) { + throw new RuntimeException(e); + } } - private static CloseableHttpResponse buildStreamingIngestResponse(Map payload) - throws IOException { - CloseableHttpResponse response = Mockito.mock(CloseableHttpResponse.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); + public static CloseableHttpClient createHttpClient() { + return createHttpClient(new ApiOverride()); + } + + public static CloseableHttpClient createHttpClient(ApiOverride apiOverride) { + try { + CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); + + Mockito.doAnswer( + (Answer) + invocation -> { + HttpUriRequest request = invocation.getArgument(0); + if (request.getMethod().equals(HttpPost.METHOD_NAME)) { + LOGGER.debug( + request.toString() + + IOUtils.toString( + ((HttpPost) request).getEntity().getContent(), + StandardCharsets.UTF_8)); + } + + String path = request.getURI().getPath(); + if (apiOverride.apiOverrides.containsKey(path.toUpperCase())) { + Pair> responsePayload = + apiOverride.apiOverrides.get(path.toUpperCase()).apply(request); + return buildStreamingIngestResponse( + responsePayload.getLeft(), responsePayload.getRight()); + } + switch (path) { + case "/telemetry/send/sessionless": + return buildStreamingIngestResponse(HttpStatus.SC_OK, new HashMap<>()); + case CLIENT_CONFIGURE_ENDPOINT: + Map clientConfigresponseMap = new HashMap<>(); + clientConfigresponseMap.put("prefix", "test_prefix"); + clientConfigresponseMap.put("status_code", 0L); + clientConfigresponseMap.put("message", "OK"); + clientConfigresponseMap.put("stage_location", getStageLocationMap()); + clientConfigresponseMap.put("deployment_id", 123L); + return buildStreamingIngestResponse( + HttpStatus.SC_OK, clientConfigresponseMap); + case CHANNEL_CONFIGURE_ENDPOINT: + Map channelConfigResponseMap = new HashMap<>(); + channelConfigResponseMap.put("status_code", 0L); + channelConfigResponseMap.put("message", "OK"); + channelConfigResponseMap.put("stage_location", getStageLocationMap()); + return buildStreamingIngestResponse( + HttpStatus.SC_OK, channelConfigResponseMap); + case OPEN_CHANNEL_ENDPOINT: + List> tableColumnsLists = new ArrayList<>(); + Map tableColumnMap = new HashMap<>(); + tableColumnMap.put("byteLength", 123L); + tableColumnMap.put("length", 0L); + tableColumnMap.put("logicalType", "test_logical_type"); + tableColumnMap.put("name", "test_column"); + tableColumnMap.put("nullable", true); + tableColumnMap.put("precision", 0L); + tableColumnMap.put("scale", 0L); + tableColumnMap.put("type", "test_type"); + tableColumnMap.put("ordinal", 0L); + tableColumnsLists.add(tableColumnMap); + Map openChannelResponseMap = new HashMap<>(); + openChannelResponseMap.put("status_code", 0L); + openChannelResponseMap.put("message", "OK"); + openChannelResponseMap.put("database", "test_db"); + openChannelResponseMap.put("schema", "test_schema"); + openChannelResponseMap.put("table", "test_table"); + openChannelResponseMap.put("channel", "test_channel"); + openChannelResponseMap.put("client_sequencer", 123L); + openChannelResponseMap.put("row_sequencer", 123L); + openChannelResponseMap.put("offset_token", "test_offset_token"); + openChannelResponseMap.put("table_columns", tableColumnsLists); + openChannelResponseMap.put("encryption_key", "test_encryption_key"); + openChannelResponseMap.put("encryption_key_id", 123L); + openChannelResponseMap.put("iceberg_location", getStageLocationMap()); + return buildStreamingIngestResponse( + HttpStatus.SC_OK, openChannelResponseMap); + case DROP_CHANNEL_ENDPOINT: + Map dropChannelResponseMap = new HashMap<>(); + dropChannelResponseMap.put("status_code", 0L); + dropChannelResponseMap.put("message", "OK"); + dropChannelResponseMap.put("database", "test_db"); + dropChannelResponseMap.put("schema", "test_schema"); + dropChannelResponseMap.put("table", "test_table"); + dropChannelResponseMap.put("channel", "test_channel"); + return buildStreamingIngestResponse( + HttpStatus.SC_OK, dropChannelResponseMap); + case CHANNEL_STATUS_ENDPOINT: + List> channelStatusList = new ArrayList<>(); + Map channelStatusMap = new HashMap<>(); + channelStatusMap.put("status_code", 0L); + channelStatusMap.put("persisted_row_sequencer", 123L); + channelStatusMap.put("persisted_client_sequencer", 123L); + channelStatusMap.put("persisted_offset_token", "test_offset_token"); + Map channelStatusResponseMap = new HashMap<>(); + channelStatusResponseMap.put("status_code", 0L); + channelStatusResponseMap.put("message", "OK"); + channelStatusResponseMap.put("channels", channelStatusList); + return buildStreamingIngestResponse( + HttpStatus.SC_OK, channelStatusResponseMap); + case REGISTER_BLOB_ENDPOINT: + List> channelList = new ArrayList<>(); + Map channelMap = new HashMap<>(); + channelMap.put("status_code", 0L); + channelMap.put("message", "OK"); + channelMap.put("channel", "test_channel"); + channelMap.put("client_sequencer", 123L); + channelList.add(channelMap); + List> chunkList = new ArrayList<>(); + Map chunkMap = new HashMap<>(); + chunkMap.put("channels", channelList); + chunkMap.put("database", "test_db"); + chunkMap.put("schema", "test_schema"); + chunkMap.put("table", "test_table"); + chunkList.add(chunkMap); + List> blobsList = new ArrayList<>(); + Map blobMap = new HashMap<>(); + blobMap.put("chunks", chunkList); + blobsList.add(blobMap); + Map registerBlobResponseMap = new HashMap<>(); + registerBlobResponseMap.put("status_code", 0L); + registerBlobResponseMap.put("message", "OK"); + registerBlobResponseMap.put("blobs", blobsList); + return buildStreamingIngestResponse( + HttpStatus.SC_OK, registerBlobResponseMap); + default: + assert false; + } + return null; + }) + .when(httpClient) + .execute(Mockito.any()); + + return httpClient; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + private static CloseableHttpResponse buildStreamingIngestResponse( + int statusCode, Map payload) throws IOException { + CloseableHttpResponse response = Mockito.mock(CloseableHttpResponse.class); Mockito.when(response.getStatusLine()) - .thenReturn(new BasicStatusLine(HttpVersion.HTTP_1_1, HttpStatus.SC_OK, "OK")); + .thenReturn( + new BasicStatusLine(HttpVersion.HTTP_1_1, statusCode, String.valueOf(statusCode))); + HttpEntity httpEntity = Mockito.mock(HttpEntity.class); Mockito.when(response.getEntity()).thenReturn(httpEntity); Mockito.when(httpEntity.getContent()) .thenReturn(IOUtils.toInputStream(objectMapper.writeValueAsString(payload))); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/RegisterServiceTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/RegisterServiceTest.java index 9d7b95008..f4c152baa 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/RegisterServiceTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/RegisterServiceTest.java @@ -9,13 +9,18 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import net.snowflake.client.jdbc.internal.apache.http.impl.client.CloseableHttpClient; +import net.snowflake.ingest.connection.RequestBuilder; import net.snowflake.ingest.utils.Pair; +import org.junit.After; import org.junit.Assert; +import org.junit.Before; import org.junit.Ignore; import org.junit.Test; import org.junit.runner.RunWith; @@ -30,6 +35,22 @@ public static Object[] isIcebergMode() { @Parameterized.Parameter public boolean isIcebergMode; + private SnowflakeStreamingIngestClientInternal client; + + @Before + public void setup() { + CloseableHttpClient httpClient = MockSnowflakeServiceClient.createHttpClient(); + RequestBuilder requestBuilder = MockSnowflakeServiceClient.createRequestBuilder(httpClient); + client = + new SnowflakeStreamingIngestClientInternal<>( + "client", null, null, httpClient, isIcebergMode, true, requestBuilder, new HashMap<>()); + } + + @After + public void teardown() throws Exception { + client.close(); + } + @Test public void testRegisterService() throws ExecutionException, InterruptedException { RegisterService rs = new RegisterService<>(null, true); @@ -57,8 +78,6 @@ public void testRegisterService() throws ExecutionException, InterruptedExceptio */ @Test public void testRegisterServiceTimeoutException() throws Exception { - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode); RegisterService rs = new RegisterService<>(client, true); Pair, CompletableFuture> blobFuture1 = @@ -85,8 +104,6 @@ public void testRegisterServiceTimeoutException() throws Exception { @Ignore @Test public void testRegisterServiceTimeoutException_testRetries() throws Exception { - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode); RegisterService rs = new RegisterService<>(client, true); Pair, CompletableFuture> blobFuture1 = @@ -119,8 +136,6 @@ public void testRegisterServiceTimeoutException_testRetries() throws Exception { @Test public void testRegisterServiceNonTimeoutException() { - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode); RegisterService rs = new RegisterService<>(client, true); CompletableFuture future = new CompletableFuture<>(); diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java index d576c64fd..b3e2d23d6 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestChannelTest.java @@ -31,6 +31,7 @@ import net.snowflake.client.jdbc.internal.apache.commons.io.IOUtils; import net.snowflake.client.jdbc.internal.apache.http.HttpEntity; import net.snowflake.client.jdbc.internal.apache.http.HttpHeaders; +import net.snowflake.client.jdbc.internal.apache.http.HttpStatus; import net.snowflake.client.jdbc.internal.apache.http.StatusLine; import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse; import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpPost; @@ -46,7 +47,10 @@ import net.snowflake.ingest.utils.SFException; import net.snowflake.ingest.utils.SnowflakeURL; import net.snowflake.ingest.utils.Utils; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.After; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -81,6 +85,27 @@ public static Object[] isIcebergMode() { @Parameterized.Parameter public boolean isIcebergMode; + private SnowflakeStreamingIngestClientInternal client; + private MockSnowflakeServiceClient.ApiOverride apiOverride; + + @Before + public void setup() { + apiOverride = new MockSnowflakeServiceClient.ApiOverride(); + CloseableHttpClient httpClient = MockSnowflakeServiceClient.createHttpClient(apiOverride); + RequestBuilder requestBuilder = MockSnowflakeServiceClient.createRequestBuilder(httpClient); + client = + new SnowflakeStreamingIngestClientInternal<>( + "client", null, null, httpClient, isIcebergMode, true, requestBuilder, new HashMap<>()); + + // some tests assume client is a mock object, just do it for everyone. + client = Mockito.spy(client); + } + + @After + public void teardown() throws Exception { + client.close(); + } + @Test public void testChannelFactoryNullFields() { String name = "CHANNEL"; @@ -89,8 +114,6 @@ public void testChannelFactoryNullFields() { String tableName = "TABLE"; long channelSequencer = 0L; long rowSequencer = 0L; - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode); Object[] fields = new Object[] { @@ -131,9 +154,6 @@ public void testChannelFactorySuccess() { Long channelSequencer = 0L; long rowSequencer = 0L; - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode); - SnowflakeStreamingIngestChannelInternal channel = SnowflakeStreamingIngestChannelFactory.builder(name) .setDBName(dbName) @@ -166,8 +186,6 @@ public void testChannelFactorySuccess() { @Test public void testChannelValid() { - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -216,8 +234,6 @@ public void testChannelValid() { @Test public void testChannelClose() { - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -342,29 +358,9 @@ public void testOpenChannelPostRequest() throws Exception { @Test public void testOpenChannelErrorResponse() throws Exception { - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(500); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); - String responseString = "testOpenChannelErrorResponse"; - Mockito.when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(responseString)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); + apiOverride.addMapOverride( + OPEN_CHANNEL_ENDPOINT, + req -> Pair.of(HttpStatus.SC_INTERNAL_SERVER_ERROR, new HashMap<>())); OpenChannelRequest request = OpenChannelRequest.builder("CHANNEL") @@ -413,29 +409,8 @@ public void testOpenChannelSnowflakeInternalErrorResponse() throws Exception { + " \"nullable\" : true\n" + " } ]\n" + "}"; - - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - Mockito.when(statusLine.getStatusCode()).thenReturn(200); - Mockito.when(httpResponse.getStatusLine()).thenReturn(statusLine); - Mockito.when(httpResponse.getEntity()).thenReturn(httpEntity); - Mockito.when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); - Mockito.when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); + apiOverride.addSerializedJsonOverride( + OPEN_CHANNEL_ENDPOINT, req -> Pair.of(HttpStatus.SC_OK, response)); OpenChannelRequest request = OpenChannelRequest.builder("CHANNEL") @@ -557,10 +532,6 @@ public void testOpenChannelSuccessResponse() throws Exception { @Test public void testInsertRow() { - SnowflakeStreamingIngestClientInternal client; - client = - new SnowflakeStreamingIngestClientInternal( - "client_PARQUET", isIcebergMode); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -646,10 +617,6 @@ public void testInsertTooLargeRow() { Map row = new HashMap<>(); schema.forEach(x -> row.put(x.getName(), byteArrayOneMb)); - SnowflakeStreamingIngestClientInternal client; - client = - new SnowflakeStreamingIngestClientInternal("test_client", isIcebergMode); - // Test channel with on error CONTINUE SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( @@ -731,8 +698,6 @@ public void testInsertRowThrottling() { final MockedMemoryInfoProvider memoryInfoProvider = new MockedMemoryInfoProvider(); memoryInfoProvider.maxMemory = maxMemory; - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -777,8 +742,6 @@ public void testInsertRowThrottling() { @Test public void testFlush() throws Exception { - SnowflakeStreamingIngestClientInternal client = - Mockito.spy(new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode)); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -813,8 +776,6 @@ public void testFlush() throws Exception { @Test public void testClose() throws Exception { - SnowflakeStreamingIngestClientInternal client = - Mockito.spy(new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode)); SnowflakeStreamingIngestChannel channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -847,8 +808,6 @@ public void testClose() throws Exception { @Test public void testDropOnClose() throws Exception { - SnowflakeStreamingIngestClientInternal client = - Mockito.spy(new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode)); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -884,8 +843,6 @@ public void testDropOnClose() throws Exception { @Test public void testDropOnCloseInvalidChannel() throws Exception { - SnowflakeStreamingIngestClientInternal client = - Mockito.spy(new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode)); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", @@ -917,8 +874,6 @@ public void testDropOnCloseInvalidChannel() throws Exception { @Test public void testGetLatestCommittedOffsetToken() { String offsetToken = "10"; - SnowflakeStreamingIngestClientInternal client = - Mockito.spy(new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode)); SnowflakeStreamingIngestChannel channel = new SnowflakeStreamingIngestChannelInternal<>( "channel", diff --git a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java index 1ce310706..a24ceb2a9 100644 --- a/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java +++ b/src/test/java/net/snowflake/ingest/streaming/internal/SnowflakeStreamingIngestClientTest.java @@ -24,10 +24,10 @@ import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; import java.io.StringWriter; -import java.nio.charset.Charset; import java.security.KeyPair; import java.security.PrivateKey; import java.time.ZoneOffset; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -42,6 +42,7 @@ import net.snowflake.client.jdbc.internal.apache.commons.io.IOUtils; import net.snowflake.client.jdbc.internal.apache.http.HttpEntity; import net.snowflake.client.jdbc.internal.apache.http.HttpHeaders; +import net.snowflake.client.jdbc.internal.apache.http.HttpStatus; import net.snowflake.client.jdbc.internal.apache.http.StatusLine; import net.snowflake.client.jdbc.internal.apache.http.client.methods.CloseableHttpResponse; import net.snowflake.client.jdbc.internal.apache.http.client.methods.HttpPost; @@ -55,11 +56,11 @@ import net.snowflake.ingest.streaming.SnowflakeStreamingIngestClientFactory; import net.snowflake.ingest.utils.Constants; import net.snowflake.ingest.utils.ErrorCode; -import net.snowflake.ingest.utils.Pair; import net.snowflake.ingest.utils.ParameterProvider; import net.snowflake.ingest.utils.SFException; import net.snowflake.ingest.utils.SnowflakeURL; import net.snowflake.ingest.utils.Utils; +import org.apache.commons.lang3.tuple.Pair; import org.bouncycastle.asn1.nist.NISTObjectIdentifiers; import org.bouncycastle.openssl.jcajce.JcaPEMWriter; import org.bouncycastle.operator.OperatorCreationException; @@ -93,6 +94,10 @@ public static Object[] isIcebergMode() { @Parameterized.Parameter public boolean isIcebergMode; + SnowflakeStreamingIngestClientInternal client; + private MockSnowflakeServiceClient.ApiOverride apiOverride; + RequestBuilder requestBuilder; + @Before public void setup() throws Exception { objectMapper.setVisibility(PropertyAccessor.GETTER, JsonAutoDetect.Visibility.ANY); @@ -103,19 +108,13 @@ public void setup() throws Exception { prop.put(PRIVATE_KEY, TestUtils.getPrivateKey()); prop.put(ROLE, TestUtils.getRole()); - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - RequestBuilder requestBuilder = - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); - SnowflakeStreamingIngestClientInternal client = + apiOverride = new MockSnowflakeServiceClient.ApiOverride(); + CloseableHttpClient httpClient = MockSnowflakeServiceClient.createHttpClient(apiOverride); + requestBuilder = Mockito.spy(MockSnowflakeServiceClient.createRequestBuilder(httpClient)); + client = new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); + "client", null, null, httpClient, isIcebergMode, true, requestBuilder, new HashMap<>()); + channel1 = new SnowflakeStreamingIngestChannelInternal<>( "channel1", @@ -358,29 +357,8 @@ public void testGetChannelsStatusWithRequest() throws Exception { response.setChannels(Collections.singletonList(channelStatus)); String responseString = objectMapper.writeValueAsString(response); - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - when(statusLine.getStatusCode()).thenReturn(200); - when(httpResponse.getStatusLine()).thenReturn(statusLine); - when(httpResponse.getEntity()).thenReturn(httpEntity); - when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(responseString)); - when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - Mockito.spy( - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair())); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); + apiOverride.addSerializedJsonOverride( + CHANNEL_STATUS_ENDPOINT, request -> Pair.of(HttpStatus.SC_OK, responseString)); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( @@ -416,31 +394,8 @@ public void testDropChannel() throws Exception { response.setStatusCode(RESPONSE_SUCCESS); response.setMessage("dropped"); String responseString = objectMapper.writeValueAsString(response); - - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - when(statusLine.getStatusCode()).thenReturn(200); - when(httpResponse.getStatusLine()).thenReturn(statusLine); - when(httpResponse.getEntity()).thenReturn(httpEntity); - when(httpEntity.getContent()) - .thenReturn(IOUtils.toInputStream(responseString, Charset.defaultCharset())); - when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - Mockito.spy( - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair())); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); + apiOverride.addSerializedJsonOverride( + DROP_CHANNEL_ENDPOINT, request -> Pair.of(HttpStatus.SC_OK, responseString)); DropChannelRequest request = DropChannelRequest.builder("channel") @@ -463,30 +418,9 @@ public void testGetChannelsStatusWithRequestError() throws Exception { response.setMessage("honk"); response.setChannels(new ArrayList<>()); String responseString = objectMapper.writeValueAsString(response); - - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - when(statusLine.getStatusCode()).thenReturn(500); - when(httpResponse.getStatusLine()).thenReturn(statusLine); - when(httpResponse.getEntity()).thenReturn(httpEntity); - when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(responseString)); - when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - Mockito.spy( - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair())); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); + apiOverride.addSerializedJsonOverride( + CHANNEL_STATUS_ENDPOINT, + request -> Pair.of(HttpStatus.SC_INTERNAL_SERVER_ERROR, responseString)); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( @@ -527,10 +461,10 @@ public void testRegisterBlobRequestCreationSuccess() throws Exception { KeyPair keyPair = Utils.createKeyPairFromPrivateKey( (PrivateKey) prop.get(SFSessionProperty.PRIVATE_KEY.getPropertyKey())); + CloseableHttpClient httpClient = MockSnowflakeServiceClient.createHttpClient(); RequestBuilder requestBuilder = - new RequestBuilder(url, prop.get(USER).toString(), keyPair, null, null); + new RequestBuilder(url, prop.get(USER).toString(), keyPair, httpClient, null); - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); SnowflakeStreamingIngestClientInternal client = new SnowflakeStreamingIngestClientInternal<>( "client", @@ -714,28 +648,14 @@ private Pair, Set> getRetryBlobMetadata( badChunkRegisterStatus.setTableName(chunkMetadata1.getTableName()); badChunkRegisterStatus.setChannelsStatus(channelRegisterStatuses); badChunks.add(badChunkRegisterStatus); - return new Pair<>(blobs, badChunks); + return Pair.of(blobs, badChunks); } @Test public void testGetRetryBlobs() throws Exception { - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - RequestBuilder requestBuilder = - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); - - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); Pair, Set> testData = getRetryBlobMetadata(); - List blobs = testData.getFirst(); - Set badChunks = testData.getSecond(); + List blobs = testData.getLeft(); + Set badChunks = testData.getRight(); List result = client.getRetryBlobs(badChunks, blobs); Assert.assertEquals(1, result.size()); Assert.assertEquals("path1", result.get(0).getPath()); @@ -752,29 +672,9 @@ public void testGetRetryBlobs() throws Exception { @Test public void testRegisterBlobErrorResponse() throws Exception { - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - when(statusLine.getStatusCode()).thenReturn(500); - when(httpResponse.getStatusLine()).thenReturn(statusLine); - when(httpResponse.getEntity()).thenReturn(httpEntity); - String response = "testRegisterBlobErrorResponse"; - when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); - when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); + apiOverride.addMapOverride( + REGISTER_BLOB_ENDPOINT, + request -> Pair.of(HttpStatus.SC_INTERNAL_SERVER_ERROR, new HashMap<>())); try { List blobs = @@ -801,29 +701,8 @@ public void testRegisterBlobSnowflakeInternalErrorResponse() throws Exception { + " } ]\n" + "}"; - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - when(statusLine.getStatusCode()).thenReturn(200); - when(httpResponse.getStatusLine()).thenReturn(statusLine); - when(httpResponse.getEntity()).thenReturn(httpEntity); - when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); - when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); - + apiOverride.addSerializedJsonOverride( + REGISTER_BLOB_ENDPOINT, request -> Pair.of(HttpStatus.SC_OK, response)); try { List blobs = Collections.singletonList(new BlobMetadata("path", "md5", new ArrayList<>(), null)); @@ -858,29 +737,8 @@ public void testRegisterBlobSuccessResponse() throws Exception { + " } ]\n" + "}"; - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - when(statusLine.getStatusCode()).thenReturn(200); - when(httpResponse.getStatusLine()).thenReturn(statusLine); - when(httpResponse.getEntity()).thenReturn(httpEntity); - when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); - when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); - + apiOverride.addSerializedJsonOverride( + REGISTER_BLOB_ENDPOINT, request -> Pair.of(HttpStatus.SC_OK, response)); List blobs = Collections.singletonList(new BlobMetadata("path", "md5", new ArrayList<>(), null)); client.registerBlobs(blobs); @@ -889,8 +747,8 @@ public void testRegisterBlobSuccessResponse() throws Exception { @Test public void testRegisterBlobsRetries() throws Exception { Pair, Set> testData = getRetryBlobMetadata(); - List blobs = testData.getFirst(); - Set badChunks = testData.getSecond(); + List blobs = testData.getLeft(); + Set badChunks = testData.getRight(); ChunkRegisterStatus goodChunkRegisterStatus = new ChunkRegisterStatus(); goodChunkRegisterStatus.setDBName(blobs.get(0).getChunks().get(0).getDBName()); @@ -936,35 +794,13 @@ public void testRegisterBlobsRetries() throws Exception { String responseString = objectMapper.writeValueAsString(initialResponse); String retryResponseString = objectMapper.writeValueAsString(retryResponse); - - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - when(statusLine.getStatusCode()).thenReturn(200); - when(httpResponse.getStatusLine()).thenReturn(statusLine); - when(httpResponse.getEntity()).thenReturn(httpEntity); - when(httpEntity.getContent()) - .thenReturn( - IOUtils.toInputStream(responseString), - IOUtils.toInputStream(retryResponseString), - IOUtils.toInputStream(retryResponseString), - IOUtils.toInputStream(retryResponseString)); - when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - Mockito.spy( - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair())); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); + ArrayDeque responses = new ArrayDeque<>(); + responses.offer(responseString); + responses.offer(retryResponseString); + responses.offer(retryResponseString); + responses.offer(retryResponseString); + apiOverride.addSerializedJsonOverride( + REGISTER_BLOB_ENDPOINT, request -> Pair.of(HttpStatus.SC_OK, responses.poll())); client.getChannelCache().addChannel(channel1); client.getChannelCache().addChannel(channel2); @@ -979,23 +815,6 @@ public void testRegisterBlobsRetries() throws Exception { @Test public void testRegisterBlobChunkLimit() throws Exception { - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - RequestBuilder requestBuilder = - Mockito.spy( - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair())); - - SnowflakeStreamingIngestClientInternal client = - Mockito.spy( - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null)); - assertEquals(0, client.partitionBlobListForRegistrationRequest(new ArrayList<>()).size()); assertEquals( 1, client.partitionBlobListForRegistrationRequest(createTestBlobMetadata(1)).size()); @@ -1067,8 +886,8 @@ private List createTestBlobMetadata(int... numbersOfChunks) { @Test public void testRegisterBlobsRetriesSucceeds() throws Exception { Pair, Set> testData = getRetryBlobMetadata(); - List blobs = testData.getFirst(); - Set badChunks = testData.getSecond(); + List blobs = testData.getLeft(); + Set badChunks = testData.getRight(); ChunkRegisterStatus goodChunkRegisterStatus = new ChunkRegisterStatus(); goodChunkRegisterStatus.setDBName(blobs.get(0).getChunks().get(0).getDBName()); @@ -1221,29 +1040,8 @@ public void testRegisterBlobResponseWithInvalidChannel() throws Exception { channel2Name, channel2Sequencer); - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - when(statusLine.getStatusCode()).thenReturn(200); - when(httpResponse.getStatusLine()).thenReturn(statusLine); - when(httpResponse.getEntity()).thenReturn(httpEntity); - when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(response)); - when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair()); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); - + apiOverride.addSerializedJsonOverride( + REGISTER_BLOB_ENDPOINT, request -> Pair.of(HttpStatus.SC_OK, response)); SnowflakeStreamingIngestChannelInternal channel1 = new SnowflakeStreamingIngestChannelInternal<>( channel1Name, @@ -1289,15 +1087,6 @@ public void testRegisterBlobResponseWithInvalidChannel() throws Exception { @Test public void testFlush() throws Exception { - SnowflakeStreamingIngestClientInternal client = - Mockito.spy(new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode)); - ChannelsStatusResponse response = new ChannelsStatusResponse(); - response.setStatusCode(0L); - response.setMessage("Success"); - response.setChannels(new ArrayList<>()); - - Mockito.doReturn(response).when(client).getChannelsStatus(Mockito.any()); - client.flush(false).get(); // Calling flush on closed client should fail @@ -1311,15 +1100,6 @@ public void testFlush() throws Exception { @Test public void testClose() throws Exception { - SnowflakeStreamingIngestClientInternal client = - Mockito.spy(new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode)); - ChannelsStatusResponse response = new ChannelsStatusResponse(); - response.setStatusCode(0L); - response.setMessage("Success"); - response.setChannels(new ArrayList<>()); - - Mockito.doReturn(response).when(client).getChannelsStatus(Mockito.any()); - Assert.assertFalse(client.isClosed()); client.close(); Assert.assertTrue(client.isClosed()); @@ -1345,8 +1125,7 @@ public void testClose() throws Exception { @Test public void testCloseWithError() throws Exception { - SnowflakeStreamingIngestClientInternal client = - Mockito.spy(new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode)); + SnowflakeStreamingIngestClientInternal client = Mockito.spy(this.client); CompletableFuture future = new CompletableFuture<>(); future.completeExceptionally(new Exception("Simulating Error")); @@ -1383,8 +1162,6 @@ public void testCloseWithError() throws Exception { @Test public void testVerifyChannelsAreFullyCommittedSuccess() throws Exception { - SnowflakeStreamingIngestClientInternal client = - Mockito.spy(new SnowflakeStreamingIngestClientInternal<>("client", isIcebergMode)); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>( "channel1", @@ -1409,8 +1186,10 @@ public void testVerifyChannelsAreFullyCommittedSuccess() throws Exception { channelStatus.setStatusCode(26L); channelStatus.setPersistedOffsetToken("0"); response.setChannels(Collections.singletonList(channelStatus)); + String responseString = objectMapper.writeValueAsString(response); - Mockito.doReturn(response).when(client).getChannelsStatus(Mockito.any()); + apiOverride.addSerializedJsonOverride( + CHANNEL_STATUS_ENDPOINT, request -> Pair.of(HttpStatus.SC_OK, responseString)); client.close(); } @@ -1450,29 +1229,8 @@ public void testGetLatestCommittedOffsetTokens() throws Exception { response.setChannels(Collections.singletonList(channelStatus)); String responseString = objectMapper.writeValueAsString(response); - CloseableHttpClient httpClient = Mockito.mock(CloseableHttpClient.class); - CloseableHttpResponse httpResponse = Mockito.mock(CloseableHttpResponse.class); - StatusLine statusLine = Mockito.mock(StatusLine.class); - HttpEntity httpEntity = Mockito.mock(HttpEntity.class); - when(statusLine.getStatusCode()).thenReturn(200); - when(httpResponse.getStatusLine()).thenReturn(statusLine); - when(httpResponse.getEntity()).thenReturn(httpEntity); - when(httpEntity.getContent()).thenReturn(IOUtils.toInputStream(responseString)); - when(httpClient.execute(Mockito.any())).thenReturn(httpResponse); - - RequestBuilder requestBuilder = - Mockito.spy( - new RequestBuilder(TestUtils.getHost(), TestUtils.getUser(), TestUtils.getKeyPair())); - SnowflakeStreamingIngestClientInternal client = - new SnowflakeStreamingIngestClientInternal<>( - "client", - new SnowflakeURL("snowflake.dev.local:8082"), - null, - httpClient, - isIcebergMode, - true, - requestBuilder, - null); + apiOverride.addSerializedJsonOverride( + CHANNEL_STATUS_ENDPOINT, request -> Pair.of(HttpStatus.SC_OK, responseString)); SnowflakeStreamingIngestChannelInternal channel = new SnowflakeStreamingIngestChannelInternal<>(