diff --git a/src/main/java/org/tarantool/jdbc/SQLConnection.java b/src/main/java/org/tarantool/jdbc/SQLConnection.java index 15f3258a..b61fd74b 100644 --- a/src/main/java/org/tarantool/jdbc/SQLConnection.java +++ b/src/main/java/org/tarantool/jdbc/SQLConnection.java @@ -27,6 +27,7 @@ import java.sql.SQLFeatureNotSupportedException; import java.sql.SQLNonTransientConnectionException; import java.sql.SQLNonTransientException; +import java.sql.SQLPermission; import java.sql.SQLWarning; import java.sql.SQLXML; import java.sql.Savepoint; @@ -42,6 +43,7 @@ import java.util.concurrent.Executor; import java.util.concurrent.Future; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; /** @@ -51,6 +53,9 @@ */ public class SQLConnection implements TarantoolConnection { + private static final SQLPermission CALL_ABORT_PERMISSION = new SQLPermission("callAbort"); + private static final SQLPermission SET_NETWORK_TIMEOUT_PERMISSION = new SQLPermission("setNetworkTimeout"); + private static final int UNSET_HOLDABILITY = 0; private static final String PING_QUERY = "SELECT 1"; @@ -60,6 +65,8 @@ public class SQLConnection implements TarantoolConnection { private DatabaseMetaData cachedMetadata; private int resultSetHoldability = UNSET_HOLDABILITY; + private final AtomicBoolean isClosed = new AtomicBoolean(false); + public SQLConnection(String url, Properties properties) throws SQLException { this.url = url; this.properties = properties; @@ -205,6 +212,12 @@ public boolean getAutoCommit() throws SQLException { @Override public void close() throws SQLException { + if (isClosed.compareAndSet(false, true)) { + closeInternal(); + } + } + + private void closeInternal() { client.close(); } @@ -234,7 +247,7 @@ public void rollback(Savepoint savepoint) throws SQLException { @Override public boolean isClosed() throws SQLException { - return client.isClosed(); + return isClosed.get() || client.isClosed(); } @Override @@ -417,6 +430,7 @@ public void setNetworkTimeout(Executor executor, int milliseconds) throws SQLExc if (milliseconds < 0) { throw new SQLException("Network timeout cannot be negative."); } + SET_NETWORK_TIMEOUT_PERMISSION.checkGuard(this); client.setOperationTimeout(milliseconds); } @@ -515,7 +529,16 @@ public void abort(Executor executor) throws SQLException { if (isClosed()) { return; } - throw new SQLFeatureNotSupportedException(); + if (executor == null) { + throw new SQLNonTransientException( + "Executor cannot be null", + SQLStates.INVALID_PARAMETER_VALUE.getSqlState() + ); + } + CALL_ABORT_PERMISSION.checkGuard(this); + if (isClosed.compareAndSet(false, true)) { + executor.execute(this::closeInternal); + } } @Override diff --git a/src/test/java/org/tarantool/jdbc/JdbcConnectionIT.java b/src/test/java/org/tarantool/jdbc/JdbcConnectionIT.java index afda6205..558610d5 100644 --- a/src/test/java/org/tarantool/jdbc/JdbcConnectionIT.java +++ b/src/test/java/org/tarantool/jdbc/JdbcConnectionIT.java @@ -3,12 +3,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; import static org.tarantool.TestAssumptions.assumeMinimalServerVersion; import org.tarantool.ServerVersion; import org.tarantool.TarantoolTestHelper; +import org.tarantool.util.SQLStates; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.AfterEach; @@ -26,8 +29,14 @@ import java.sql.SQLClientInfoException; import java.sql.SQLException; import java.sql.SQLFeatureNotSupportedException; +import java.sql.SQLNonTransientException; import java.sql.Statement; import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; public class JdbcConnectionIT { @@ -456,5 +465,81 @@ void testSetClientInfoProperties() { assertEquals(ClientInfoStatus.REASON_UNKNOWN_PROPERTY, failedProperties.get(targetProperty)); } + @Test + void testConnectionAbort() throws SQLException { + assertFalse(conn.isClosed()); + try (Statement statement = conn.createStatement()) { + conn.abort(Executors.newSingleThreadExecutor()); + assertTrue(conn.isClosed()); + SQLNonTransientException exception = assertThrows( + SQLNonTransientException.class, + () -> statement.executeQuery("SELECT 1") + ); + assertEquals(exception.getMessage(), "Statement is closed."); + } + } + + @Test + void testOperationInProgressAbort() throws SQLException, ExecutionException, InterruptedException { + testHelper.executeLua("box.internal.sql_create_function('TNT_SLEEP', 'INT'," + + " function(s) require('fiber').sleep(s); return s; end)"); + final ExecutorService executor = Executors.newFixedThreadPool(2); + final int sleepSeconds = 10; + + long startTime = System.currentTimeMillis(); + + Future workerFuture = executor.submit(() -> { + try { + Statement statement = conn.createStatement(); + statement.execute("SELECT tnt_sleep(" + sleepSeconds + ")"); + } catch (SQLException cause) { + return cause; + } + return null; + }); + + Future abortFuture = executor.submit(() -> { + ExecutorService abortExecutor = Executors.newSingleThreadExecutor(); + try { + conn.abort(abortExecutor); + } catch (SQLException cause) { + return cause; + } + abortExecutor.shutdown(); + try { + abortExecutor.awaitTermination(sleepSeconds, TimeUnit.SECONDS); + } catch (InterruptedException ignored) { + } + return null; + }); + + SQLException workerException = workerFuture.get(); + long endTime = System.currentTimeMillis(); + assertNotNull(workerException, "Statement execution should have been aborted, thus throwing an exception"); + + SQLException abortException = abortFuture.get(); + assertNull(abortException, () -> abortException.getMessage()); + + // It is expected to abort the statement as soon as possible. + // If the execution takes time more than 95% of the estimation the aborting fails. + assertTrue((endTime - startTime) < (sleepSeconds * 95 * 10)); + assertTrue(conn.isClosed()); + } + + @Test + void testAlreadyClosedConnectionAbort() throws SQLException { + conn.close(); + try { + conn.abort(Executors.newSingleThreadExecutor()); + } catch (SQLException cause) { + fail("Unexpected error", cause); + } + } + + @Test + void testNullParameterConnectionAbort() { + SQLException exception = assertThrows(SQLException.class, () -> conn.abort(null)); + assertEquals(SQLStates.INVALID_PARAMETER_VALUE.getSqlState(), exception.getSQLState()); + } } diff --git a/src/test/java/org/tarantool/jdbc/JdbcSecurityIT.java b/src/test/java/org/tarantool/jdbc/JdbcSecurityIT.java new file mode 100644 index 00000000..da76840e --- /dev/null +++ b/src/test/java/org/tarantool/jdbc/JdbcSecurityIT.java @@ -0,0 +1,150 @@ +package org.tarantool.jdbc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.tarantool.TestAssumptions.assumeMinimalServerVersion; + +import org.tarantool.ServerVersion; +import org.tarantool.TarantoolTestHelper; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.security.Permission; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.EnumSet; +import java.util.concurrent.Executors; + +public class JdbcSecurityIT { + + private static TarantoolTestHelper testHelper; + + private Connection connection; + private SecurityManager originalSecurityManager; + + @BeforeAll + public static void setupEnv() { + testHelper = new TarantoolTestHelper("jdbc-security-it"); + testHelper.createInstance(); + testHelper.startInstance(); + } + + @AfterAll + public static void teardownEnv() { + testHelper.stopInstance(); + } + + @BeforeEach + public void setUpTest() throws SQLException { + assumeMinimalServerVersion(testHelper.getInstanceVersion(), ServerVersion.V_2_1); + connection = DriverManager.getConnection(SqlTestUtils.makeDefaultJdbcUrl()); + originalSecurityManager = System.getSecurityManager(); + } + + @AfterEach + public void tearDownTest() throws SQLException { + assumeMinimalServerVersion(testHelper.getInstanceVersion(), ServerVersion.V_2_1); + if (connection != null) { + connection.close(); + } + System.setSecurityManager(originalSecurityManager); + } + + @Test + void testDeniedConnectionAbort() { + EnumSet exclusions = EnumSet.of(JdbcPermission.CALL_ABORT); + System.setSecurityManager(new JdbcSecurityManager(true, exclusions)); + + SecurityException securityException = assertThrows( + SecurityException.class, + () -> connection.abort(Executors.newSingleThreadExecutor()) + ); + assertEquals(securityException.getMessage(), "Permission callAbort is not allowed"); + } + + @Test + void testDeniedSetConnectionTimeout() { + EnumSet exclusions = EnumSet.of(JdbcPermission.SET_NETWORK_TIMEOUT); + System.setSecurityManager(new JdbcSecurityManager(true, exclusions)); + + SecurityException securityException = assertThrows( + SecurityException.class, + () -> connection.setNetworkTimeout(Executors.newSingleThreadExecutor(), 1000) + ); + assertEquals(securityException.getMessage(), "Permission setNetworkTimeout is not allowed"); + } + + /** + * Lists permissions supported by JDBC API. + * + *
    + *
  • setLog
  • + *
  • callAbort
  • + *
  • setSyncFactory<
  • + *
  • setNetworkTimeout
  • + *
  • deregisterDriver
  • + *
+ * + * @see java.sql.SQLPermission + */ + private enum JdbcPermission { + SET_LOG("setLog"), + CALL_ABORT("callAbort"), + SET_SYNC_FACTORY("setSyncFactory"), + SET_NETWORK_TIMEOUT("setNetworkTimeout"), + DEREGISTER_DRIVER("deregisterDriver"); + + private final String permissionName; + + JdbcPermission(String permissionName) { + this.permissionName = permissionName; + } + + public String getPermissionName() { + return permissionName; + } + + public static JdbcPermission fromName(String name) { + for (JdbcPermission values : JdbcPermission.values()) { + if (values.permissionName.equals(name)) { + return values; + } + } + return null; + } + } + + private static class JdbcSecurityManager extends SecurityManager { + private final boolean allowAll; + private final EnumSet exclusions; + + /** + * Configures a new {@link SecurityManager} that follows the custom rules. + * + * @param allowAll whether permissions are allowed by default or not + * @param exclusions optional set of exclusions + */ + private JdbcSecurityManager(boolean allowAll, EnumSet exclusions) { + this.exclusions = exclusions; + this.allowAll = allowAll; + } + + @Override + public void checkPermission(Permission permission) { + JdbcPermission jdbcPermission = JdbcPermission.fromName(permission.getName()); + if (jdbcPermission == null) { + return; + } + boolean allowed = allowAll ^ exclusions.contains(jdbcPermission); + if (!allowed) { + throw new SecurityException("Permission " + jdbcPermission.getPermissionName() + " is not allowed"); + } + } + } +} +