From 4030175d0d0fbbba2ce6d2d83da0b4479a3000e9 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 29 Jun 2023 16:05:32 +0800 Subject: [PATCH] [SPARK-44241][Core] Mistakenly set io.connectionTimeout/connectionCreationTimeout to zero or negative will cause incessant executor cons/destructions --- .../client/TransportClientFactory.java | 16 +++++++-- .../spark/network/util/TransportConf.java | 4 +-- .../client/TransportClientFactorySuite.java | 33 ++++++++++++++++--- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 6fb9923cd3d7..3df72e65c2ab 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -245,12 +245,13 @@ TransportClient createClient(InetSocketAddress address) logger.debug("Creating new connection to {}", address); Bootstrap bootstrap = new Bootstrap(); + int connCreateTimeout = conf.connectionCreationTimeoutMs(); bootstrap.group(workerGroup) .channel(socketChannelClass) // Disable Nagle's Algorithm since we don't want packets to wait .option(ChannelOption.TCP_NODELAY, true) .option(ChannelOption.SO_KEEPALIVE, true) - .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionCreationTimeoutMs()) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connCreateTimeout) .option(ChannelOption.ALLOCATOR, pooledAllocator); if (conf.receiveBuf() > 0) { @@ -276,10 +277,19 @@ public void initChannel(SocketChannel ch) { // Connect to the remote server long preConnect = System.nanoTime(); ChannelFuture cf = bootstrap.connect(address); - if (!cf.await(conf.connectionCreationTimeoutMs())) { + + if (connCreateTimeout <= 0) { + cf.awaitUninterruptibly(); + assert cf.isDone(); + if (cf.isCancelled()) { + throw new IOException(String.format("Connecting to %s cancelled", address)); + } else if (!cf.isSuccess()) { + throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); + } + } else if (!cf.await(connCreateTimeout)) { throw new IOException( String.format("Connecting to %s timed out (%s ms)", - address, conf.connectionCreationTimeoutMs())); + address, connCreateTimeout)); } else if (cf.cause() != null) { throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index bbfb99168da2..deac78ffedde 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -103,7 +103,7 @@ public int connectionTimeoutMs() { conf.get("spark.network.timeout", "120s")); long defaultTimeoutMs = JavaUtils.timeStringAsSec( conf.get(SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY, defaultNetworkTimeoutS + "s")) * 1000; - return (int) defaultTimeoutMs; + return defaultTimeoutMs < 0 ? 0 : (int) defaultTimeoutMs; } /** Connect creation timeout in milliseconds. Default 30 secs. */ @@ -111,7 +111,7 @@ public int connectionCreationTimeoutMs() { long connectionTimeoutS = TimeUnit.MILLISECONDS.toSeconds(connectionTimeoutMs()); long defaultTimeoutMs = JavaUtils.timeStringAsSec( conf.get(SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY, connectionTimeoutS + "s")) * 1000; - return (int) defaultTimeoutMs; + return defaultTimeoutMs < 0 ? 0 : (int) defaultTimeoutMs; } /** Number of concurrent connections between two nodes for fetching data. */ diff --git a/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java b/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java index 4ee9a6ed10bf..47b571af83d7 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/client/TransportClientFactorySuite.java @@ -31,10 +31,6 @@ import org.junit.Before; import org.junit.Test; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotSame; -import static org.junit.Assert.assertTrue; - import org.apache.spark.network.TestUtils; import org.apache.spark.network.TransportContext; import org.apache.spark.network.server.NoOpRpcHandler; @@ -45,6 +41,8 @@ import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportConf; +import static org.junit.Assert.*; + public class TransportClientFactorySuite { private TransportConf conf; private TransportContext context; @@ -237,4 +235,31 @@ public void fastFailConnectionInTimeWindow() { Assert.assertThrows("fail this connection directly", IOException.class, () -> factory.createClient(TestUtils.getLocalHost(), unreachablePort, true)); } + + @Test + public void unlimitedConnectionAndCreationTimeouts() throws IOException, InterruptedException { + Map configMap = new HashMap<>(); + configMap.put("spark.shuffle.io.connectionTimeout", "-1"); + configMap.put("spark.shuffle.io.connectionCreationTimeout", "-1"); + TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(configMap)); + RpcHandler rpcHandler = new NoOpRpcHandler(); + try (TransportContext ctx = new TransportContext(conf, rpcHandler, true); + TransportClientFactory factory = ctx.createClientFactory()){ + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertTrue(c1.isActive()); + long expiredTime = System.currentTimeMillis() + 5000; + while (c1.isActive() && System.currentTimeMillis() < expiredTime) { + Thread.sleep(10); + } + assertTrue(c1.isActive()); + // When connectionCreationTimeout is unlimited, the connection shall be able to + // fail when the server is not reachable. + TransportServer server = ctx.createServer(); + int unreachablePort = server.getPort(); + JavaUtils.closeQuietly(server); + IOException exception = Assert.assertThrows(IOException.class, + () -> factory.createClient(TestUtils.getLocalHost(), unreachablePort, true)); + assertNotEquals(exception.getCause(), null); + } + } }