diff --git a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java index 26191be26d5..a712965a440 100644 --- a/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java +++ b/common/src/main/java/org/apache/celeborn/common/network/client/TransportClientFactory.java @@ -251,7 +251,15 @@ public void initChannel(SocketChannel ch) { // Connect to the remote server long preConnect = System.nanoTime(); ChannelFuture cf = bootstrap.connect(address); - if (!cf.await(connectTimeoutMs)) { + if (connectTimeoutMs <= 0) { + cf.await(); + assert cf.isDone(); + if (cf.isCancelled()) { + throw new IOException(String.format("Connecting to %s cancelled", address)); + } else if (!cf.isSuccess()) { + throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); + } + } else if (!cf.await(connectTimeoutMs)) { throw new CelebornIOException( String.format("Connecting to %s timed out (%s ms)", address, connectTimeoutMs)); } else if (cf.cause() != null) { diff --git a/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java b/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java index 26a9b4885fe..b77a9c7d0c7 100644 --- a/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java +++ b/common/src/test/java/org/apache/celeborn/common/network/TransportClientFactorySuiteJ.java @@ -211,4 +211,31 @@ public void closeFactoryBeforeCreateClient() throws IOException, InterruptedExce factory.close(); factory.createClient(getLocalHost(), server1.getPort()); } + + @Test + public void unlimitedConnectionAndCreationTimeouts() throws IOException, InterruptedException { + CelebornConf _conf = new CelebornConf(); + _conf.set("celeborn.shuffle.io.connectTimeout", "-1"); + _conf.set("celeborn.shuffle.io.connectionTimeout", "-1"); + TransportConf conf = new TransportConf(TEST_MODULE, _conf); + try (TransportContext ctx = new TransportContext(conf, new BaseMessageHandler(), true); + TransportClientFactory factory = ctx.createClientFactory()) { + TransportClient c1 = factory.createClient(getLocalHost(), server1.getPort()); + assertTrue(c1.isActive()); + long expiredTime = System.currentTimeMillis() + 5000; + while (c1.isActive() && System.currentTimeMillis() < expiredTime) { + Thread.sleep(10); + } + assertTrue(c1.isActive()); + // When connectionTimeout is unlimited, the connection shall be able to fail when the server + // is not reachable. + TransportServer server = ctx.createServer(); + int unreachablePort = server.getPort(); + JavaUtils.closeQuietly(server); + IOException exception = + assertThrows( + IOException.class, () -> factory.createClient(getLocalHost(), unreachablePort)); + assertNotEquals(exception.getCause(), null); + } + } }