From 7ad79c4c273521046e3e5b736bb7d72c218497d9 Mon Sep 17 00:00:00 2001 From: nitish Date: Wed, 11 Dec 2024 10:26:45 +0530 Subject: [PATCH] added check before setting checkpoint directory --- .../java/zingg/spark/client/SparkClient.java | 13 ++++++++--- .../executor/TestSparkExecutorsCompound.java | 22 +++++-------------- .../core/session/SparkSessionProvider.java | 11 ++++++++-- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/spark/client/src/main/java/zingg/spark/client/SparkClient.java b/spark/client/src/main/java/zingg/spark/client/SparkClient.java index f2ec6e01..14f65969 100644 --- a/spark/client/src/main/java/zingg/spark/client/SparkClient.java +++ b/spark/client/src/main/java/zingg/spark/client/SparkClient.java @@ -1,5 +1,6 @@ package zingg.spark.client; +import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; @@ -79,12 +80,18 @@ public SparkSession getSession() { SparkSession s = SparkSession .builder() .appName("Zingg") - .getOrCreate(); - JavaSparkContext ctx = JavaSparkContext.fromSparkContext(s.sparkContext()); + .getOrCreate(); + SparkContext sparkContext = s.sparkContext(); + if (sparkContext.getCheckpointDir().isEmpty()) { + sparkContext.setCheckpointDir("/tmp/checkpoint"); + } + JavaSparkContext ctx = JavaSparkContext.fromSparkContext(sparkContext); JavaSparkContext.jarOfClass(IZingg.class); LOG.debug("Context " + ctx.toString()); //initHashFns(); - ctx.setCheckpointDir("/tmp/checkpoint"); + if (!ctx.getCheckpointDir().isPresent()) { + ctx.setCheckpointDir(String.valueOf(sparkContext.getCheckpointDir())); + } setSession(s); return s; } diff --git a/spark/core/src/test/java/zingg/spark/core/executor/TestSparkExecutorsCompound.java b/spark/core/src/test/java/zingg/spark/core/executor/TestSparkExecutorsCompound.java index aefb5260..4b101989 100644 --- a/spark/core/src/test/java/zingg/spark/core/executor/TestSparkExecutorsCompound.java +++ b/spark/core/src/test/java/zingg/spark/core/executor/TestSparkExecutorsCompound.java @@ -4,14 +4,13 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataType; -import zingg.common.client.IZingg; +import org.junit.jupiter.api.extension.ExtendWith; import zingg.common.client.ZinggClientException; import zingg.common.client.util.DFObjectUtil; import zingg.common.client.util.IWithSession; @@ -19,10 +18,12 @@ import zingg.common.core.executor.TestExecutorsCompound; import zingg.common.core.executor.TrainMatcher; import zingg.spark.client.util.SparkDFObjectUtil; +import zingg.spark.core.TestSparkBase; import zingg.spark.core.context.ZinggSparkContext; import zingg.spark.core.executor.labeller.ProgrammaticSparkLabeller; import zingg.spark.core.executor.validate.SparkTrainMatchValidator; +@ExtendWith(TestSparkBase.class) public class TestSparkExecutorsCompound extends TestExecutorsCompound,Row,Column,DataType> { protected static final String CONFIG_FILE = "zingg/spark/core/executor/configSparkIntTest.json"; protected static final String TEST_DATA_FILE = "zingg/spark/core/executor/test.csv"; @@ -31,22 +32,11 @@ public class TestSparkExecutorsCompound extends TestExecutorsCompound