Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support init script at JdbcDatabaseContainer level #575

Merged
merged 5 commits into from
Nov 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@

package org.testcontainers.ext;

import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testcontainers.delegate.DatabaseDelegate;

import javax.script.ScriptException;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.LinkedList;
import java.util.List;

Expand Down Expand Up @@ -210,6 +214,30 @@ public static boolean containsSqlScriptDelimiters(String script, String delim) {
return false;
}

/**
* Load script from classpath and apply it to the given database
*
* @param databaseDelegate database delegate for script execution
* @param initScriptPath the resource to load the init script from
*/
public static void runInitScript(DatabaseDelegate databaseDelegate, String initScriptPath) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method can also be used in ContainerDatabaseDriver.runInitScriptIfRequired(...) to replace almost all of the method's code, but I didn't do it to have diff smaller.

try {
URL resource = ScriptUtils.class.getClassLoader().getResource(initScriptPath);
if (resource == null) {
LOGGER.warn("Could not load classpath init script: {}", initScriptPath);
throw new ScriptLoadException("Could not load classpath init script: " + initScriptPath + ". Resource not found.");
}
String scripts = IOUtils.toString(resource, StandardCharsets.UTF_8);
executeDatabaseScript(databaseDelegate, initScriptPath, scripts);
} catch (IOException e) {
LOGGER.warn("Could not load classpath init script: {}", initScriptPath);
throw new ScriptLoadException("Could not load classpath init script: " + initScriptPath, e);
} catch (ScriptException e) {
LOGGER.error("Error while executing init script: {}", initScriptPath, e);
throw new UncategorizedScriptException("Error while executing init script: " + initScriptPath, e);
}
}

public static void executeDatabaseScript(DatabaseDelegate databaseDelegate, String scriptPath, String script) throws ScriptException {
executeDatabaseScript(databaseDelegate, scriptPath, script, false, false, DEFAULT_COMMENT_PREFIX, DEFAULT_STATEMENT_SEPARATOR, DEFAULT_BLOCK_COMMENT_START_DELIMITER, DEFAULT_BLOCK_COMMENT_END_DELIMITER);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,20 @@ public void testMySQL8() throws SQLException {
}
}

@Test
public void testExplicitInitScript() throws SQLException {
try (MySQLContainer container = (MySQLContainer) new MySQLContainer()
.withInitScript("somepath/init_mysql.sql")
.withLogConsumer(new Slf4jLogConsumer(logger))) {
container.start();

ResultSet resultSet = performQuery(container, "SELECT foo FROM bar");
String firstColumnValue = resultSet.getString(1);

assertEquals("Value from init script should equal real value", "hello world", firstColumnValue);
}
}

@Test
public void testEmptyPasswordWithNonRootUser() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import com.zaxxer.hikari.HikariConfig;
import com.zaxxer.hikari.HikariDataSource;
import org.junit.Rule;
import org.junit.Test;
import org.testcontainers.containers.JdbcDatabaseContainer;
import org.testcontainers.containers.PostgreSQLContainer;

import java.sql.ResultSet;
Expand All @@ -17,23 +17,42 @@
*/
public class SimplePostgreSQLTest {

@Rule
public PostgreSQLContainer postgres = new PostgreSQLContainer();

@Test
public void testSimple() throws SQLException {
try (PostgreSQLContainer postgres = new PostgreSQLContainer<>()) {
postgres.start();

ResultSet resultSet = performQuery(postgres, "SELECT 1");

int resultSetInt = resultSet.getInt(1);
assertEquals("A basic SELECT query succeeds", 1, resultSetInt);
}
}

@Test
public void testExplicitInitScript() throws SQLException {
try (PostgreSQLContainer postgres = new PostgreSQLContainer<>()
.withInitScript("somepath/init_postgresql.sql")) {
postgres.start();

ResultSet resultSet = performQuery(postgres, "SELECT foo FROM bar");

String firstColumnValue = resultSet.getString(1);
assertEquals("Value from init script should equal real value", "hello world", firstColumnValue);
}
}

private ResultSet performQuery(JdbcDatabaseContainer container, String sql) throws SQLException {
HikariConfig hikariConfig = new HikariConfig();
hikariConfig.setJdbcUrl(postgres.getJdbcUrl());
hikariConfig.setUsername(postgres.getUsername());
hikariConfig.setPassword(postgres.getPassword());
hikariConfig.setJdbcUrl(container.getJdbcUrl());
hikariConfig.setUsername(container.getUsername());
hikariConfig.setPassword(container.getPassword());

HikariDataSource ds = new HikariDataSource(hikariConfig);
Statement statement = ds.getConnection().createStatement();
statement.execute("SELECT 1");
statement.execute(sql);
ResultSet resultSet = statement.getResultSet();

resultSet.next();
int resultSetInt = resultSet.getInt(1);
assertEquals("A basic SELECT query succeeds", 1, resultSetInt);
return resultSet;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
CREATE TABLE bar (
foo VARCHAR(255)
);

INSERT INTO bar (foo) VALUES ('hello world');
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
package org.testcontainers.containers;

import lombok.NonNull;
import com.github.dockerjava.api.command.InspectContainerResponse;
import org.jetbrains.annotations.NotNull;
import org.rnorth.ducttape.ratelimits.RateLimiter;
import org.rnorth.ducttape.ratelimits.RateLimiterBuilder;
import org.rnorth.ducttape.unreliables.Unreliables;
import org.testcontainers.containers.traits.LinkableContainer;
import org.testcontainers.delegate.DatabaseDelegate;
import org.testcontainers.ext.ScriptUtils;
import org.testcontainers.jdbc.JdbcDatabaseDelegate;
import org.testcontainers.utility.MountableFile;

import java.sql.Connection;
Expand All @@ -26,6 +30,7 @@ public abstract class JdbcDatabaseContainer<SELF extends JdbcDatabaseContainer<S

private static final Object DRIVER_LOAD_MUTEX = new Object();
private Driver driver;
private String initScriptPath;
protected Map<String, String> parameters = new HashMap<>();

private static final RateLimiter DB_CONNECT_RATE_LIMIT = RateLimiterBuilder.newBuilder()
Expand Down Expand Up @@ -111,6 +116,11 @@ public SELF withConnectTimeoutSeconds(int connectTimeoutSeconds) {
return self();
}

public SELF withInitScript(String initScriptPath) {
this.initScriptPath = initScriptPath;
return self();
}

@Override
protected void waitUntilContainerStarted() {
// Repeatedly try and open a connection to the DB and execute a test query
Expand All @@ -135,6 +145,11 @@ protected void waitUntilContainerStarted() {
});
}

@Override
protected void containerIsStarted(InspectContainerResponse containerInfo) {
runInitScriptIfRequired();
}

/**
* Obtain an instance of the correct JDBC driver for this particular database container type
*
Expand Down Expand Up @@ -202,6 +217,15 @@ protected void optionallyMapResourceParameterAsVolume(@NotNull String paramName,
}
}

/**
* Load init script content and apply it to the database if initScriptPath is set
*/
protected void runInitScriptIfRequired() {
if (initScriptPath != null) {
ScriptUtils.runInitScript(getDatabaseDelegate(), initScriptPath);
}
}

public void setParameters(Map<String, String> parameters) {
this.parameters = parameters;
}
Expand All @@ -228,4 +252,8 @@ protected int getStartupTimeoutSeconds() {
protected int getConnectTimeoutSeconds() {
return connectTimeoutSeconds;
}

protected DatabaseDelegate getDatabaseDelegate() {
return new JdbcDatabaseDelegate(this, "");
}
}