Skip to content

Commit

Permalink
[FSTORE-8][append] Java fix for snowflake connector (#810)
Browse files Browse the repository at this point in the history
  • Loading branch information
moritzmeister authored Oct 3, 2022
1 parent a587bed commit 9b79da4
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 2 deletions.
8 changes: 8 additions & 0 deletions java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
<hoverfly.version>0.12.2</hoverfly.version>
<junit.version>5.9.1</junit.version>
<surefire-plugin.version>2.22.0</surefire-plugin.version>
<mockito.version>4.3.1</mockito.version>
</properties>

<dependencies>
Expand Down Expand Up @@ -272,6 +273,13 @@
<version>${spark.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ public Object read(String query, String dataFormat, Map<String, String> options,
throws FeatureStoreException, IOException {
Map<String, String> readOptions = sparkOptions();
if (!Strings.isNullOrEmpty(query)) {
// if a table is also specified, override it and use the explicit query instead
readOptions.remove(Constants.SNOWFLAKE_TABLE);
readOptions.put("query", query);
}
return SparkEngine.getInstance().read(this, Constants.SNOWFLAKE_FORMAT, readOptions, null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ public static synchronized SparkEngine getInstance() {
return INSTANCE;
}

// for testing
public static void setInstance(SparkEngine sparkEngine) {
INSTANCE = sparkEngine;
}

@Getter
private SparkSession sparkSession;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

package com.logicalclocks.hsfs;

import com.logicalclocks.hsfs.engine.SparkEngine;
import com.logicalclocks.hsfs.util.Constants;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
Expand All @@ -29,6 +31,7 @@
import java.util.Base64;
import java.util.Map;


public class TestStorageConnector {

@Test
Expand All @@ -46,6 +49,27 @@ public void testBigQueryCredentialsBase64Encoded(@TempDir Path tempDir) throws I

// Assert
Assertions.assertEquals(credentials,
new String(Base64.getDecoder().decode(sparkOptions.get(Constants.BIGQ_CREDENTIALS)), StandardCharsets.UTF_8));
new String(Base64.getDecoder().decode(sparkOptions.get(Constants.BIGQ_CREDENTIALS)), StandardCharsets.UTF_8));
}

@Test
public void testSnowflakeConnector_read() throws Exception {
// Arrange
StorageConnector.SnowflakeConnector snowflakeConnector = new StorageConnector.SnowflakeConnector();
snowflakeConnector.setTable(Constants.SNOWFLAKE_TABLE);

SparkEngine sparkEngine = Mockito.mock(SparkEngine.class);
SparkEngine.setInstance(sparkEngine);

ArgumentCaptor<Map> mapArg = ArgumentCaptor.forClass(Map.class);
String query = "select * from dbtable";

// Act
snowflakeConnector.read(query, null, null, null);
Mockito.verify(sparkEngine).read(Mockito.any(), Mockito.any(), mapArg.capture(), Mockito.any());

// Assert
Assertions.assertFalse(mapArg.getValue().containsKey(Constants.SNOWFLAKE_TABLE));
Assertions.assertEquals(query, mapArg.getValue().get("query"));
}
}

0 comments on commit 9b79da4

Please sign in to comment.