diff --git a/docker-compose-ci.yml b/docker-compose-ci.yml index ae2bfef1..9fbda290 100644 --- a/docker-compose-ci.yml +++ b/docker-compose-ci.yml @@ -17,6 +17,20 @@ services: ports: - "64790:3306" + postgres-multi-query: + image: postgres:latest + environment: + POSTGRES_DB: kestra + POSTGRES_USER: postgres + POSTGRES_PASSWORD: pg_passwd + healthcheck: + test: ["CMD-SHELL", "pg_isready -d $${POSTGRES_DB} -U $${POSTGRES_USER}"] + interval: 30s + timeout: 10s + retries: 10 + ports: + - "56983:5432" + postgres: image: bitnami/postgresql:latest environment: diff --git a/plugin-jdbc-mysql/src/test/resources/scripts/mysql_queries.sql b/plugin-jdbc-mysql/src/test/resources/scripts/mysql_queries.sql index 83abbd3d..d286d180 100644 --- a/plugin-jdbc-mysql/src/test/resources/scripts/mysql_queries.sql +++ b/plugin-jdbc-mysql/src/test/resources/scripts/mysql_queries.sql @@ -40,6 +40,6 @@ DROP TABLE IF EXISTS test_transaction; CREATE TABLE test_transaction ( id MEDIUMINT NOT NULL AUTO_INCREMENT, - name CHAR(30) NOT NULL, + name VARCHAR(30) NOT NULL, PRIMARY KEY (id) ); \ No newline at end of file diff --git a/plugin-jdbc-postgres/src/main/java/io/kestra/plugin/jdbc/postgresql/Queries.java b/plugin-jdbc-postgres/src/main/java/io/kestra/plugin/jdbc/postgresql/Queries.java new file mode 100644 index 00000000..349088d7 --- /dev/null +++ b/plugin-jdbc-postgres/src/main/java/io/kestra/plugin/jdbc/postgresql/Queries.java @@ -0,0 +1,77 @@ +package io.kestra.plugin.jdbc.postgresql; + +import io.kestra.core.models.annotations.Example; +import io.kestra.core.models.annotations.Plugin; +import io.kestra.core.models.tasks.RunnableTask; +import io.kestra.core.runners.RunContext; +import io.kestra.plugin.jdbc.AbstractCellConverter; +import io.kestra.plugin.jdbc.AbstractJdbcQueries; +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.*; +import lombok.experimental.SuperBuilder; + +import java.sql.DriverManager; +import java.sql.SQLException; +import java.time.ZoneId; +import java.util.Properties; + + +@SuperBuilder +@ToString +@EqualsAndHashCode +@Getter +@NoArgsConstructor +@Schema( + title = "Preform multiple queries on a PostgreSQL server." +) +@Plugin( + examples = { + @Example( + full = true, + title = "Execute a query and fetch results in a task.", + code = """ + id: postgres_query + namespace: company.team + + tasks: + - id: fetch + type: io.kestra.plugin.jdbc.postgresql.Queries + url: jdbc:postgresql://127.0.0.1:56982/ + username: pg_user + password: pg_password + sql: | + SELECT firstName, lastName FROM employee; + SELECT brand FROM laptop; + fetchType: FETCH + """ + ) + } +) +public class Queries extends AbstractJdbcQueries implements RunnableTask, PostgresConnectionInterface { + @Builder.Default + protected Boolean ssl = false; + protected SslMode sslMode; + protected String sslRootCert; + protected String sslCert; + protected String sslKey; + protected String sslKeyPassword; + + @Override + public Properties connectionProperties(RunContext runContext) throws Exception { + Properties properties = super.connectionProperties(runContext); + PostgresService.handleSsl(properties, runContext, this); + + return properties; + } + + @Override + protected AbstractCellConverter getCellConverter(ZoneId zoneId) { + return new PostgresCellConverter(zoneId); + } + + @Override + public void registerDriver() throws SQLException { + DriverManager.registerDriver(new org.postgresql.Driver()); + } + +} diff --git a/plugin-jdbc-postgres/src/test/java/io/kestra/plugin/jdbc/postgresql/QueriesPostgresTest.java b/plugin-jdbc-postgres/src/test/java/io/kestra/plugin/jdbc/postgresql/QueriesPostgresTest.java new file mode 100644 index 00000000..571daf06 --- /dev/null +++ b/plugin-jdbc-postgres/src/test/java/io/kestra/plugin/jdbc/postgresql/QueriesPostgresTest.java @@ -0,0 +1,314 @@ +package io.kestra.plugin.jdbc.postgresql; + +import io.kestra.core.junit.annotations.KestraTest; +import io.kestra.core.models.property.Property; +import io.kestra.core.runners.RunContext; +import io.kestra.plugin.jdbc.AbstractJdbcQueries; +import io.kestra.plugin.jdbc.AbstractRdbmsTest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.FileNotFoundException; +import java.io.IOException; +import java.net.URISyntaxException; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Properties; + +import static io.kestra.core.models.tasks.common.FetchType.FETCH; +import static io.kestra.core.models.tasks.common.FetchType.FETCH_ONE; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@KestraTest +public class QueriesPostgresTest extends AbstractRdbmsTest { + + @Test + void testMultiSelect() throws Exception { + RunContext runContext = runContextFactory.of(Collections.emptyMap()); + + Queries taskGet = Queries.builder() + .url(getUrl()) + .username(getUsername()) + .password(getPassword()) + .fetchType(FETCH) + .timeZoneId("Europe/Paris") + .sql(""" + SELECT first_name, last_name FROM employee; + SELECT brand FROM laptop; + """) + .build(); + + AbstractJdbcQueries.MultiQueryOutput runOutput = taskGet.run(runContext); + assertThat(runOutput.getOutputs().size(), is(2)); + assertThat(runOutput.getOutputs().get(0), notNullValue()); + assertThat(runOutput.getOutputs().get(1), notNullValue()); + } + + @Test + void testMultiSelectWithParameters() throws Exception { + RunContext runContext = runContextFactory.of(Collections.emptyMap()); + + Map parameters = Map.of( + "age", 40, + "brand", "Apple", + "cpu_frequency", 1.5 + ); + + Queries taskGet = Queries.builder() + .url(getUrl()) + .username(getUsername()) + .password(getPassword()) + .fetchType(FETCH) + .timeZoneId("Europe/Paris") + .sql(""" + SELECT first_name, last_name, age FROM employee where age > :age and age < :age + 10; + SELECT brand, model FROM laptop where brand = :brand and cpu_frequency > :cpu_frequency; + """) + .parameters(Property.of(parameters)) + .build(); + + AbstractJdbcQueries.MultiQueryOutput runOutput = taskGet.run(runContext); + assertThat(runOutput.getOutputs().size(), is(2)); + + List> employees = runOutput.getOutputs().getFirst().getRows(); + assertThat("employees", employees, notNullValue()); + assertThat("employees", employees.size(), is(1)); + assertThat("employee selected", employees.getFirst().get("age"), is(45)); + assertThat("employee selected", employees.getFirst().get("first_name"), is("John")); + assertThat("employee selected", employees.getFirst().get("last_name"), is("Doe")); + + List>laptops = runOutput.getOutputs().getLast().getRows(); + assertThat("laptops", laptops, notNullValue()); + assertThat("laptops", laptops.size(), is(1)); + assertThat("selected laptop", laptops.getFirst().get("brand"), is("Apple")); + } + + @Test + void testMultiQueriesOnlySelectOutputs() throws Exception { + RunContext runContext = runContextFactory.of(Collections.emptyMap()); + + Queries taskGet = Queries.builder() + .url(getUrl()) + .username(getUsername()) + .password(getPassword()) + .fetchType(FETCH_ONE) + .timeZoneId("Europe/Paris") + .sql(""" + DROP TABLE IF EXISTS animals; + CREATE TABLE animals ( + id SERIAL, + name VARCHAR(30) NOT NULL, + PRIMARY KEY (id) + ); + INSERT INTO animals (name) VALUES ('cat'),('dog'); + SELECT COUNT(id) as animals_count FROM animals; + INSERT INTO animals (name) VALUES ('ostrich'),('snake'),('whale'); + SELECT COUNT(id) as animals_count FROM animals; + """) + .build(); + + AbstractJdbcQueries.MultiQueryOutput runOutput = taskGet.run(runContext); + assertThat(runOutput.getOutputs().size(), is(2)); + assertThat(runOutput.getOutputs().getFirst().getRow().get("animals_count"), is(2L)); + assertThat(runOutput.getOutputs().getLast().getRow().get("animals_count"), is(5L)); + } + + @Test + void testMultiQueriesTransactionalShouldRollback() throws Exception { + long expectedUpdateNumber = 1L; + RunContext runContext = runContextFactory.of(Collections.emptyMap()); + + //Queries should pass in a transaction + Queries queriesPass = Queries.builder() + .url(getUrl()) + .username(getUsername()) + .password(getPassword()) + .fetchType(FETCH_ONE) + .timeZoneId("Europe/Paris") + .sql(""" + INSERT INTO test_transaction (name) VALUES ('test_1'); + SELECT COUNT(id) as transaction_count FROM test_transaction; + """) + .build(); + + AbstractJdbcQueries.MultiQueryOutput runOutput = queriesPass.run(runContext); + assertThat(runOutput.getOutputs().size(), is(1)); + assertThat(runOutput.getOutputs().getFirst().getRow().get("transaction_count"), is(expectedUpdateNumber)); + + //Queries should fail due to bad sql + Queries queriesFail = Queries.builder() + .url(getUrl()) + .username(getUsername()) + .password(getPassword()) + .fetchType(FETCH_ONE) + .timeZoneId("Europe/Paris") + .sql(""" + INSERT INTO test_transaction (name) VALUES ('test_2'); + INSERT INTO test_transaction (name) VALUES (1000f); + """) //Try inserting before failing + .build(); + + assertThrows(Exception.class, () -> queriesFail.run(runContext)); + + //Final query to verify the amount of updated rows + Queries verifyQuery = Queries.builder() + .url(getUrl()) + .username(getUsername()) + .password(getPassword()) + .fetchType(FETCH_ONE) + .timeZoneId("Europe/Paris") + .sql(""" + SELECT COUNT(id) as transaction_count FROM test_transaction; + """) //Try inserting before failing + .build(); + + AbstractJdbcQueries.MultiQueryOutput verifyOutput = verifyQuery.run(runContext); + assertThat(verifyOutput.getOutputs().size(), is(1)); + assertThat(verifyOutput.getOutputs().getFirst().getRow().get("transaction_count"), is(expectedUpdateNumber)); + } + + @Test + void testMultiQueriesNonTransactionalShouldNotRollback() throws Exception { + RunContext runContext = runContextFactory.of(Collections.emptyMap()); + + //Queries should pass in a transaction + Queries queriesFail = Queries.builder() + .url(getUrl()) + .username(getUsername()) + .password(getPassword()) + .fetchType(FETCH_ONE) + .timeZoneId("Europe/Paris") + .transaction(Property.of(false)) //No rollback on failure + .sql(""" + INSERT INTO test_transaction (name) VALUES ('test_no_rollback_success_1'); + INSERT INTO test_transaction (name) VALUES ('test_no_rollback_success_2'); + INSERT INTO test_transaction (name) VALUES (10f); + INSERT INTO test_transaction (name) VALUES ('test_no_rollback_fail_1'); + INSERT INTO test_transaction (name) VALUES ('test_no_rollback_fail_2'); + """) //Expect failure with 2 inserts + .build(); + + assertThrows(Exception.class, () -> queriesFail.run(runContext)); + + //Final query to verify the amount of updated rows + Queries verifyQuery = Queries.builder() + .url(getUrl()) + .username(getUsername()) + .password(getPassword()) + .fetchType(FETCH) + .timeZoneId("Europe/Paris") + .sql(""" + SELECT name FROM test_transaction; + """) + .build(); + + AbstractJdbcQueries.MultiQueryOutput verifyOutput = verifyQuery.run(runContext); + List names = verifyOutput.getOutputs().getFirst().getRows() + .stream().map(m -> (String) m.get("name")) + .filter(name -> name.startsWith("test_no_rollback")) + .toList(); + + assertThat(names.size(), is(2)); + assertThat(names, containsInAnyOrder("test_no_rollback_success_1", "test_no_rollback_success_2")); + } + + @Override + protected String getUrl() { + return "jdbc:postgresql://127.0.0.1:56983/"; + } + + @Override + protected String getUsername() { + return TestUtils.username(); + } + + @Override + protected String getPassword() { + return TestUtils.password(); + } + + @Override + protected Connection getConnection() throws SQLException { + Properties props = new Properties(); + props.put("jdbc.url", getUrl()); + props.put("user", getUsername()); + props.put("password", getPassword()); + + try { + PostgresService.handleSsl(props, runContextFactory.of(), new PostgresConnection()); + } catch (Exception e) { + throw new RuntimeException(e); + } + + return DriverManager.getConnection(props.getProperty("jdbc.url"), props); + } + + public static class PostgresConnection implements PostgresConnectionInterface { + @Override + public String getUrl() { + return "jdbc:postgresql://127.0.0.1:56983/"; + } + + @Override + public String getUsername() { + return TestUtils.username(); + } + + @Override + public String getPassword() { + return TestUtils.password(); + } + + @Override + public Boolean getSsl() { + return null; + } + + @Override + public SslMode getSslMode() { + return null; + } + + @Override + public String getSslRootCert() { + return null; + } + + @Override + public String getSslCert() { + return null; + } + + @Override + public String getSslKey() { + return null; + } + + @Override + public String getSslKeyPassword() { + return null; + } + + @Override + public void registerDriver() throws SQLException { + + } + } + + @Override + protected void initDatabase() throws SQLException, FileNotFoundException, URISyntaxException { + executeSqlScript("scripts/postgres_queries.sql"); + } + + @Override + @BeforeEach + public void init() throws IOException, URISyntaxException, SQLException { + initDatabase(); + } +} diff --git a/plugin-jdbc-postgres/src/test/resources/scripts/postgres_queries.sql b/plugin-jdbc-postgres/src/test/resources/scripts/postgres_queries.sql new file mode 100644 index 00000000..c4c96d43 --- /dev/null +++ b/plugin-jdbc-postgres/src/test/resources/scripts/postgres_queries.sql @@ -0,0 +1,43 @@ +DROP TABLE IF EXISTS employee; + +CREATE TABLE employee ( + employee_id SERIAL, + first_name VARCHAR(30), + last_name VARCHAR(30), + age SMALLINT, + PRIMARY KEY (employee_id) +); + +INSERT INTO employee (first_name, last_name, age) +VALUES + ( 'John', 'Doe', 45), + ( 'Bryan', 'Grant', 33), + ( 'Jude', 'Philips', 25), + ( 'Michael', 'Page', 62); + +DROP TABLE IF EXISTS laptop; + +CREATE TABLE laptop +( + laptop_id SERIAL, + brand VARCHAR(30), + model VARCHAR(30), + cpu_frequency REAL, + PRIMARY KEY (laptop_id) +); +INSERT INTO laptop (brand, model, cpu_frequency) +VALUES + ('Apple', 'MacBookPro M1 13', 2.2), + ('Apple', 'MacBookPro M3 16', 1.5), + ('LG', 'Gram', 1.95), + ('Lenovo', 'ThinkPad', 1.05); + + +/* Table for testing transactionnal queries */ +DROP TABLE IF EXISTS test_transaction; +CREATE TABLE test_transaction +( + id SERIAL, + name VARCHAR(30) NOT NULL, + PRIMARY KEY (id) +); \ No newline at end of file diff --git a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java index 5e645ed1..d7c7fb84 100644 --- a/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java +++ b/plugin-jdbc/src/main/java/io/kestra/plugin/jdbc/AbstractJdbcQueries.java @@ -9,6 +9,7 @@ import org.slf4j.Logger; import java.sql.*; +import java.util.Collections; import java.util.LinkedList; import java.util.List; import java.util.Map; @@ -35,47 +36,46 @@ public AbstractJdbcQueries.MultiQueryOutput run(RunContext runContext) throws Ex Connection conn = null; PreparedStatement stmt = null; Savepoint savepoint = null; + long totalSize = 0L; + List outputList = new LinkedList<>(); + try { //Create connection in not autocommit mode to enable rollback on error conn = this.connection(runContext); - if(isTransactional) { - conn.setAutoCommit(false); - savepoint = conn.setSavepoint(); - } + conn.setAutoCommit(false); + savepoint = conn.setSavepoint(); String sqlRendered = runContext.render(this.sql, this.additionalVars); - - stmt = createPreparedStatementAndPopulateParameters(runContext, conn, sqlRendered); - - stmt.setFetchSize(this.getFetchSize()); - - logger.debug("Starting query: {}", sqlRendered); - - boolean hasMoreResult = stmt.execute(); - if(isTransactional) { + String[] queries = isTransactional ? new String[]{sqlRendered} : sqlRendered.split("(?<='\\);)"); + + for(String query : queries) { + //Create statement, execute + stmt = createPreparedStatementAndPopulateParameters(runContext, conn, query); + stmt.setFetchSize(this.getFetchSize()); + logger.debug("Starting query: {}", query); + boolean hasMoreResult = stmt.execute(); conn.commit(); - } - //Create Outputs - List outputList = new LinkedList<>(); - long totalSize = 0L; - while (hasMoreResult || stmt.getUpdateCount() != -1) { - try(ResultSet rs = stmt.getResultSet()) { - //When sql is not a select statement skip output creation - if(rs != null) { - AbstractJdbcQuery.Output.OutputBuilder output = AbstractJdbcQuery.Output.builder(); - totalSize += populateOutputFromResultSet(runContext, stmt, rs, output, cellConverter, conn); - outputList.add(output.build()); + //Create Outputs + while (hasMoreResult || stmt.getUpdateCount() != -1) { + try(ResultSet rs = stmt.getResultSet()) { + //When sql is not a select statement skip output creation + if(rs != null) { + AbstractJdbcQuery.Output.OutputBuilder output = AbstractJdbcQuery.Output.builder(); + totalSize += populateOutputFromResultSet(runContext, stmt, rs, output, cellConverter, conn); + outputList.add(output.build()); + } } + hasMoreResult = stmt.getMoreResults(); } - hasMoreResult = stmt.getMoreResults(); } + conn.commit(); runContext.metric(Counter.of("fetch.size", totalSize, this.tags())); return MultiQueryOutput.builder().outputs(outputList).build(); } catch (Exception e) { - if(conn != null && savepoint != null) { + if(isTransactional && conn != null && savepoint != null) { conn.rollback(savepoint); } throw new RuntimeException(e);