diff --git a/sdks/java/io/sparkreceiver/build.gradle b/sdks/java/io/sparkreceiver/build.gradle index f226435631e4c..8d4b96f298cdb 100644 --- a/sdks/java/io/sparkreceiver/build.gradle +++ b/sdks/java/io/sparkreceiver/build.gradle @@ -32,8 +32,16 @@ description = "Apache Beam :: SDKs :: Java :: IO :: Spark Receiver" ext.summary = """Apache Beam SDK provides a simple, Java-based interface for streaming integration with CDAP plugins.""" +configurations.all { + exclude group: 'org.slf4j', module: 'slf4j-log4j12' + exclude group: 'org.slf4j', module: 'slf4j-jdk14' + exclude group: 'org.slf4j', module: 'slf4j-simple' +} + dependencies { implementation library.java.commons_lang3 + implementation library.java.joda_time + implementation library.java.slf4j_api implementation library.java.spark_streaming implementation library.java.spark_core implementation library.java.vendored_guava_26_0_jre diff --git a/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java b/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java new file mode 100644 index 0000000000000..c51a5168ce398 --- /dev/null +++ b/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/ReadFromSparkReceiverWithOffsetDoFn.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.sparkreceiver; + +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; + +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeUnit; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.io.range.OffsetRange; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.DoFn.UnboundedPerElement; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.splittabledofn.ManualWatermarkEstimator; +import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; +import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker; +import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator; +import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimators; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.spark.SparkConf; +import org.apache.spark.streaming.receiver.Receiver; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A SplittableDoFn which reads from {@link Receiver} that implements {@link HasOffset}. 
By default,
+ * a {@link WatermarkEstimators.Manual} watermark estimator is used to track the watermark.
+ *
+ * <p>The initial range is {@code [0, Long.MAX_VALUE)}.
+ *
+ * <p>
Resume Processing Every time the sparkConsumer.hasRecords() returns false, {@link + * ReadFromSparkReceiverWithOffsetDoFn} will move to process the next element. + */ +@UnboundedPerElement +class ReadFromSparkReceiverWithOffsetDoFn extends DoFn { + + private static final Logger LOG = + LoggerFactory.getLogger(ReadFromSparkReceiverWithOffsetDoFn.class); + + /** Constant waiting time after the {@link Receiver} starts. Required to prepare for polling */ + private static final int START_POLL_TIMEOUT_MS = 1000; + + private final SerializableFunction> + createWatermarkEstimatorFn; + private final SerializableFunction getOffsetFn; + private final SerializableFunction getTimestampFn; + private final ReceiverBuilder> sparkReceiverBuilder; + + ReadFromSparkReceiverWithOffsetDoFn(SparkReceiverIO.Read transform) { + createWatermarkEstimatorFn = WatermarkEstimators.Manual::new; + + ReceiverBuilder> sparkReceiverBuilder = + transform.getSparkReceiverBuilder(); + checkStateNotNull(sparkReceiverBuilder, "Spark Receiver Builder can't be null!"); + this.sparkReceiverBuilder = sparkReceiverBuilder; + + SerializableFunction getOffsetFn = transform.getGetOffsetFn(); + checkStateNotNull(getOffsetFn, "Get offset fn can't be null!"); + this.getOffsetFn = getOffsetFn; + + SerializableFunction getTimestampFn = transform.getTimestampFn(); + if (getTimestampFn == null) { + getTimestampFn = input -> Instant.now(); + } + this.getTimestampFn = getTimestampFn; + } + + @GetInitialRestriction + public OffsetRange initialRestriction(@Element byte[] element) { + return new OffsetRange(0, Long.MAX_VALUE); + } + + @GetInitialWatermarkEstimatorState + public Instant getInitialWatermarkEstimatorState(@Timestamp Instant currentElementTimestamp) { + return currentElementTimestamp; + } + + @NewWatermarkEstimator + public WatermarkEstimator newWatermarkEstimator( + @WatermarkEstimatorState Instant watermarkEstimatorState) { + return createWatermarkEstimatorFn.apply(ensureTimestampWithinBounds(watermarkEstimatorState)); + } + + @GetSize + public double getSize(@Element byte[] element, @Restriction OffsetRange offsetRange) { + return restrictionTracker(element, offsetRange).getProgress().getWorkRemaining(); + } + + @NewTracker + public OffsetRangeTracker restrictionTracker( + @Element byte[] element, @Restriction OffsetRange restriction) { + return new OffsetRangeTracker(restriction); + } + + @GetRestrictionCoder + public Coder restrictionCoder() { + return new OffsetRange.Coder(); + } + + // Need to do an unchecked cast from Object + // because org.apache.spark.streaming.receiver.ReceiverSupervisor accepts Object in push methods + @SuppressWarnings("unchecked") + private static class SparkConsumerWithOffset implements SparkConsumer { + private final Queue recordsQueue; + private @Nullable Receiver sparkReceiver; + private final Long startOffset; + + SparkConsumerWithOffset(Long startOffset) { + this.startOffset = startOffset; + this.recordsQueue = new ConcurrentLinkedQueue<>(); + } + + @Override + public boolean hasRecords() { + return !recordsQueue.isEmpty(); + } + + @Override + public @Nullable V poll() { + return recordsQueue.poll(); + } + + @Override + public void start(Receiver sparkReceiver) { + this.sparkReceiver = sparkReceiver; + try { + new WrappedSupervisor( + sparkReceiver, + new SparkConf(), + objects -> { + V record = (V) objects[0]; + recordsQueue.offer(record); + return null; + }); + } catch (Exception e) { + LOG.error("Can not init Spark Receiver!", e); + throw new IllegalStateException("Spark Receiver was not 
initialized"); + } + ((HasOffset) sparkReceiver).setStartOffset(startOffset); + sparkReceiver.supervisor().startReceiver(); + try { + TimeUnit.MILLISECONDS.sleep(START_POLL_TIMEOUT_MS); + } catch (InterruptedException e) { + LOG.error("SparkReceiver was interrupted before polling started", e); + throw new IllegalStateException("Spark Receiver was interrupted before polling started"); + } + } + + @Override + public void stop() { + if (sparkReceiver != null) { + sparkReceiver.stop("SparkReceiver is stopped."); + } + recordsQueue.clear(); + } + } + + @ProcessElement + public ProcessContinuation processElement( + @Element byte[] element, + RestrictionTracker tracker, + WatermarkEstimator watermarkEstimator, + OutputReceiver receiver) { + + SparkConsumer sparkConsumer; + Receiver sparkReceiver; + try { + sparkReceiver = sparkReceiverBuilder.build(); + } catch (Exception e) { + LOG.error("Can not build Spark Receiver", e); + throw new IllegalStateException("Spark Receiver was not built!"); + } + sparkConsumer = new SparkConsumerWithOffset<>(tracker.currentRestriction().getFrom()); + sparkConsumer.start(sparkReceiver); + + while (sparkConsumer.hasRecords()) { + V record = sparkConsumer.poll(); + if (record != null) { + Long offset = getOffsetFn.apply(record); + if (!tracker.tryClaim(offset)) { + sparkConsumer.stop(); + LOG.debug("Stop for restriction: {}", tracker.currentRestriction().toString()); + return ProcessContinuation.stop(); + } + Instant currentTimeStamp = getTimestampFn.apply(record); + ((ManualWatermarkEstimator) watermarkEstimator).setWatermark(currentTimeStamp); + receiver.outputWithTimestamp(record, currentTimeStamp); + } + } + sparkConsumer.stop(); + LOG.debug("Resume for restriction: {}", tracker.currentRestriction().toString()); + return ProcessContinuation.resume(); + } + + private static Instant ensureTimestampWithinBounds(Instant timestamp) { + if (timestamp.isBefore(BoundedWindow.TIMESTAMP_MIN_VALUE)) { + timestamp = BoundedWindow.TIMESTAMP_MIN_VALUE; + LOG.debug("Timestamp was before MIN_VALUE({})", BoundedWindow.TIMESTAMP_MIN_VALUE); + } else if (timestamp.isAfter(BoundedWindow.TIMESTAMP_MAX_VALUE)) { + timestamp = BoundedWindow.TIMESTAMP_MAX_VALUE; + LOG.debug("Timestamp was after MAX_VALUE({})", BoundedWindow.TIMESTAMP_MAX_VALUE); + } + return timestamp; + } +} diff --git a/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkConsumer.java b/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkConsumer.java new file mode 100644 index 0000000000000..6d54968e4900f --- /dev/null +++ b/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkConsumer.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.io.sparkreceiver; + +import java.io.Serializable; +import org.apache.spark.streaming.receiver.Receiver; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * Interface for start/stop reading from some Spark {@link Receiver} into some place and poll from + * it. + */ +interface SparkConsumer extends Serializable { + + boolean hasRecords(); + + @Nullable + V poll(); + + void start(Receiver sparkReceiver); + + void stop(); +} diff --git a/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java b/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java new file mode 100644 index 0000000000000..954ce2b836b3f --- /dev/null +++ b/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIO.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.sparkreceiver; + +import static org.apache.beam.sdk.util.Preconditions.checkStateNotNull; +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; + +import com.google.auto.value.AutoValue; +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; +import org.apache.beam.sdk.transforms.Impulse; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.spark.streaming.receiver.Receiver; +import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Streaming sources for Spark {@link Receiver}. + * + *

+ * <h3>Reading using {@link SparkReceiverIO}</h3>
+ *
+ * <p>You will need to pass a {@link ReceiverBuilder} which is responsible for instantiating new
+ * {@link Receiver} objects.
+ *
+ * <p>The {@link Receiver} that will be used should implement the {@link HasOffset} interface. You
+ * will need to pass {@code getOffsetFn}, a {@link SerializableFunction} that defines how to get a
+ * {@code Long offset} from a {@code V record}.
+ *
+ * <p>Optionally, you can pass {@code timestampFn}, a {@link SerializableFunction} that defines
+ * how to get an {@code Instant timestamp} from a {@code V record}.
+ *
+ * <p>Example of {@link SparkReceiverIO#read()} usage:
+ *
+ * <pre>{@code
+ * Pipeline p = ...; // Create pipeline.
+ *
+ * // Create ReceiverBuilder for CustomReceiver
+ * ReceiverBuilder<String, CustomReceiver> receiverBuilder =
+ *     new ReceiverBuilder<>(CustomReceiver.class).withConstructorArgs();
+ *
+ * // Read from CustomReceiver
+ * SparkReceiverIO.Read<String> reader =
+ *     SparkReceiverIO.<String>read()
+ *         .withGetOffsetFn(Long::valueOf)
+ *         .withTimestampFn(Instant::parse)
+ *         .withSparkReceiverBuilder(receiverBuilder);
+ *
+ * p.apply("Spark Receiver Read", reader);
+ * }</pre>
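+ *
+ * <p>For illustration, a receiver that implements {@link HasOffset} could look like the sketch
+ * below ({@code MyReceiver} is a hypothetical name, modeled on the test receiver in this change;
+ * the body of {@code onStart()} is only outlined):
+ *
+ * <pre>{@code
+ * public class MyReceiver extends Receiver<String> implements HasOffset {
+ *   private Long startOffset = 0L;
+ *
+ *   public MyReceiver() {
+ *     super(StorageLevel.MEMORY_AND_DISK_2());
+ *   }
+ *
+ *   public void setStartOffset(Long startOffset) {
+ *     if (startOffset != null) {
+ *       this.startOffset = startOffset;
+ *     }
+ *   }
+ *
+ *   public Long getEndOffset() {
+ *     return Long.MAX_VALUE;
+ *   }
+ *
+ *   public void onStart() {
+ *     // Start a background thread that calls store(record)
+ *     // for every record with offset >= startOffset.
+ *   }
+ *
+ *   public void onStop() {}
+ * }
+ * }</pre>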
+ */ +@Experimental(Kind.SOURCE_SINK) +public class SparkReceiverIO { + + private static final Logger LOG = LoggerFactory.getLogger(SparkReceiverIO.class); + + public static Read read() { + return new AutoValue_SparkReceiverIO_Read.Builder().build(); + } + + /** A {@link PTransform} to read from Spark {@link Receiver}. */ + @AutoValue + @AutoValue.CopyAnnotations + public abstract static class Read extends PTransform> { + + abstract @Nullable ReceiverBuilder> getSparkReceiverBuilder(); + + abstract @Nullable SerializableFunction getGetOffsetFn(); + + abstract @Nullable SerializableFunction getTimestampFn(); + + abstract Builder toBuilder(); + + @AutoValue.Builder + abstract static class Builder { + + abstract Builder setSparkReceiverBuilder( + ReceiverBuilder> sparkReceiverBuilder); + + abstract Builder setGetOffsetFn(SerializableFunction getOffsetFn); + + abstract Builder setTimestampFn(SerializableFunction timestampFn); + + abstract Read build(); + } + + /** Sets {@link ReceiverBuilder} with value and custom Spark {@link Receiver} class. */ + public Read withSparkReceiverBuilder( + ReceiverBuilder> sparkReceiverBuilder) { + checkArgument(sparkReceiverBuilder != null, "Spark receiver builder can not be null"); + return toBuilder().setSparkReceiverBuilder(sparkReceiverBuilder).build(); + } + + /** A function to get offset in order to start {@link Receiver} from it. */ + public Read withGetOffsetFn(SerializableFunction getOffsetFn) { + checkArgument(getOffsetFn != null, "Get offset function can not be null"); + return toBuilder().setGetOffsetFn(getOffsetFn).build(); + } + + /** A function to calculate timestamp for a record. */ + public Read withTimestampFn(SerializableFunction timestampFn) { + checkArgument(timestampFn != null, "Timestamp function can not be null"); + return toBuilder().setTimestampFn(timestampFn).build(); + } + + @Override + public PCollection expand(PBegin input) { + validateTransform(); + return input.apply(new ReadFromSparkReceiverViaSdf<>(this)); + } + + public void validateTransform() { + ReceiverBuilder> sparkReceiverBuilder = getSparkReceiverBuilder(); + checkStateNotNull(sparkReceiverBuilder, "withSparkReceiverBuilder() is required"); + checkStateNotNull(getGetOffsetFn(), "withGetOffsetFn() is required"); + } + } + + static class ReadFromSparkReceiverViaSdf extends PTransform> { + + private final Read sparkReceiverRead; + + ReadFromSparkReceiverViaSdf(Read sparkReceiverRead) { + this.sparkReceiverRead = sparkReceiverRead; + } + + @Override + public PCollection expand(PBegin input) { + final ReceiverBuilder> sparkReceiverBuilder = + sparkReceiverRead.getSparkReceiverBuilder(); + checkStateNotNull(sparkReceiverBuilder, "withSparkReceiverBuilder() is required"); + if (!HasOffset.class.isAssignableFrom(sparkReceiverBuilder.getSparkReceiverClass())) { + throw new UnsupportedOperationException( + String.format( + "Given Spark Receiver class %s doesn't implement HasOffset interface," + + " therefore it is not supported!", + sparkReceiverBuilder.getSparkReceiverClass().getName())); + } else { + LOG.info("{} started reading", ReadFromSparkReceiverWithOffsetDoFn.class.getSimpleName()); + return input + .apply(Impulse.create()) + .apply(ParDo.of(new ReadFromSparkReceiverWithOffsetDoFn<>(sparkReceiverRead))); + // TODO: Split data from SparkReceiver into multiple workers + } + } + } +} diff --git a/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/WrappedSupervisor.java 
b/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/WrappedSupervisor.java index e0a4e64c47b82..7611b42dff00c 100644 --- a/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/WrappedSupervisor.java +++ b/sdks/java/io/sparkreceiver/src/main/java/org/apache/beam/sdk/io/sparkreceiver/WrappedSupervisor.java @@ -25,19 +25,21 @@ import org.apache.spark.streaming.receiver.BlockGeneratorListener; import org.apache.spark.streaming.receiver.Receiver; import org.apache.spark.streaming.receiver.ReceiverSupervisor; +import scala.Function0; import scala.Option; import scala.collection.Iterator; import scala.collection.mutable.ArrayBuffer; /** Wrapper class for {@link ReceiverSupervisor} that doesn't use Spark Environment. */ -@SuppressWarnings("return.type.incompatible") public class WrappedSupervisor extends ReceiverSupervisor { + private final SparkConf sparkConf; private final SerializableFunction storeFn; public WrappedSupervisor( Receiver receiver, SparkConf conf, SerializableFunction storeFn) { super(receiver, conf); + this.sparkConf = conf; this.storeFn = storeFn; } @@ -66,7 +68,11 @@ public void pushArrayBuffer( @Override public BlockGenerator createBlockGenerator(BlockGeneratorListener blockGeneratorListener) { - return null; + return new BlockGenerator( + blockGeneratorListener, + this.streamId(), + this.sparkConf, + BlockGenerator.$lessinit$greater$default$4()); } @Override @@ -86,4 +92,9 @@ public long getCurrentRateLimit() { public boolean isReceiverStopped() { return super.isReceiverStopped(); } + + @Override + public void logInfo(Function0 msg) { + // Do not log with Spark logging + } } diff --git a/sdks/java/io/sparkreceiver/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java b/sdks/java/io/sparkreceiver/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java new file mode 100644 index 0000000000000..6bba7bee9af3b --- /dev/null +++ b/sdks/java/io/sparkreceiver/src/test/java/org/apache/beam/sdk/io/sparkreceiver/CustomReceiverWithOffset.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.sparkreceiver; + +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.receiver.Receiver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Imitation of Spark {@link Receiver} that implements {@link HasOffset} interface. Used to test + * {@link SparkReceiverIO#read()}. 
+ */ +public class CustomReceiverWithOffset extends Receiver implements HasOffset { + + private static final Logger LOG = LoggerFactory.getLogger(CustomReceiverWithOffset.class); + private static final int TIMEOUT_MS = 500; + public static final int RECORDS_COUNT = 20; + + /* + Used in test for imitation of reading with exception + */ + public static boolean shouldFailInTheMiddle = false; + + private Long startOffset; + + CustomReceiverWithOffset() { + super(StorageLevel.MEMORY_AND_DISK_2()); + } + + @Override + public void setStartOffset(Long startOffset) { + if (startOffset != null) { + this.startOffset = startOffset; + } + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public void onStart() { + Executors.newSingleThreadExecutor(new ThreadFactoryBuilder().build()).submit(this::receive); + } + + @Override + public void onStop() {} + + @Override + public Long getEndOffset() { + return Long.MAX_VALUE; + } + + private void receive() { + Long currentOffset = startOffset; + while (!isStopped()) { + if (currentOffset < RECORDS_COUNT) { + if (shouldFailInTheMiddle && currentOffset == RECORDS_COUNT / 2) { + shouldFailInTheMiddle = false; + LOG.debug("Expected fail in the middle of reading"); + throw new IllegalStateException("Expected exception"); + } + store(String.valueOf(currentOffset)); + currentOffset++; + } else { + break; + } + try { + TimeUnit.MILLISECONDS.sleep(TIMEOUT_MS); + } catch (InterruptedException e) { + LOG.error("Interrupted", e); + } + } + } +} diff --git a/sdks/java/io/sparkreceiver/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java b/sdks/java/io/sparkreceiver/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java new file mode 100644 index 0000000000000..e81dca5150e5d --- /dev/null +++ b/sdks/java/io/sparkreceiver/src/test/java/org/apache/beam/sdk/io/sparkreceiver/SparkReceiverIOTest.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.sparkreceiver; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.util.HashSet; +import java.util.Set; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestPipelineOptions; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Test class for {@link SparkReceiverIO}. 
*/ +@RunWith(JUnit4.class) +public class SparkReceiverIOTest { + + public static final TestPipelineOptions OPTIONS = + TestPipeline.testingPipelineOptions().as(TestPipelineOptions.class); + + static { + OPTIONS.setBlockOnRun(false); + } + + @Rule public final transient TestPipeline pipeline = TestPipeline.fromOptions(OPTIONS); + + @Test + public void testReadBuildsCorrectly() { + ReceiverBuilder receiverBuilder = + new ReceiverBuilder<>(CustomReceiverWithOffset.class).withConstructorArgs(); + SerializableFunction offsetFn = Long::valueOf; + SerializableFunction timestampFn = Instant::parse; + + SparkReceiverIO.Read read = + SparkReceiverIO.read() + .withGetOffsetFn(offsetFn) + .withTimestampFn(timestampFn) + .withSparkReceiverBuilder(receiverBuilder); + + assertEquals(offsetFn, read.getGetOffsetFn()); + assertEquals(receiverBuilder, read.getSparkReceiverBuilder()); + } + + @Test + public void testReadObjectCreationFailsIfReceiverBuilderIsNull() { + assertThrows( + IllegalArgumentException.class, + () -> SparkReceiverIO.read().withSparkReceiverBuilder(null)); + } + + @Test + public void testReadObjectCreationFailsIfGetOffsetFnIsNull() { + assertThrows( + IllegalArgumentException.class, () -> SparkReceiverIO.read().withGetOffsetFn(null)); + } + + @Test + public void testReadObjectCreationFailsIfTimestampFnIsNull() { + assertThrows( + IllegalArgumentException.class, () -> SparkReceiverIO.read().withTimestampFn(null)); + } + + @Test + public void testReadValidationFailsMissingReceiverBuilder() { + SparkReceiverIO.Read read = SparkReceiverIO.read(); + assertThrows(IllegalStateException.class, read::validateTransform); + } + + @Test + public void testReadValidationFailsMissingSparkConsumer() { + ReceiverBuilder receiverBuilder = + new ReceiverBuilder<>(CustomReceiverWithOffset.class).withConstructorArgs(); + SparkReceiverIO.Read read = + SparkReceiverIO.read().withSparkReceiverBuilder(receiverBuilder); + assertThrows(IllegalStateException.class, read::validateTransform); + } + + @Test + public void testReadFromCustomReceiverWithOffset() { + CustomReceiverWithOffset.shouldFailInTheMiddle = false; + ReceiverBuilder receiverBuilder = + new ReceiverBuilder<>(CustomReceiverWithOffset.class).withConstructorArgs(); + SparkReceiverIO.Read reader = + SparkReceiverIO.read() + .withGetOffsetFn(Long::valueOf) + .withTimestampFn(Instant::parse) + .withSparkReceiverBuilder(receiverBuilder); + + for (int i = 0; i < CustomReceiverWithOffset.RECORDS_COUNT; i++) { + TestOutputDoFn.EXPECTED_RECORDS.add(String.valueOf(i)); + } + pipeline.apply(reader).setCoder(StringUtf8Coder.of()).apply(ParDo.of(new TestOutputDoFn())); + + pipeline.run().waitUntilFinish(Duration.standardSeconds(15)); + } + + @Test + public void testReadFromCustomReceiverWithOffsetFailsAndReread() { + CustomReceiverWithOffset.shouldFailInTheMiddle = true; + ReceiverBuilder receiverBuilder = + new ReceiverBuilder<>(CustomReceiverWithOffset.class).withConstructorArgs(); + SparkReceiverIO.Read reader = + SparkReceiverIO.read() + .withGetOffsetFn(Long::valueOf) + .withTimestampFn(Instant::parse) + .withSparkReceiverBuilder(receiverBuilder); + + for (int i = 0; i < CustomReceiverWithOffset.RECORDS_COUNT; i++) { + TestOutputDoFn.EXPECTED_RECORDS.add(String.valueOf(i)); + } + pipeline.apply(reader).setCoder(StringUtf8Coder.of()).apply(ParDo.of(new TestOutputDoFn())); + + pipeline.run().waitUntilFinish(Duration.standardSeconds(15)); + + assertEquals(0, TestOutputDoFn.EXPECTED_RECORDS.size()); + } + + /** {@link DoFn} that throws {@code RuntimeException} if 
it receives an unexpected element. */
+  private static class TestOutputDoFn extends DoFn<String, String> {
+    private static final Set<String> EXPECTED_RECORDS = new HashSet<>();
+
+    @ProcessElement
+    public void processElement(@Element String element, OutputReceiver<String> outputReceiver) {
+      if (!EXPECTED_RECORDS.contains(element)) {
+        throw new RuntimeException("Received unexpected element: " + element);
+      } else {
+        EXPECTED_RECORDS.remove(element);
+        outputReceiver.output(element);
+      }
+    }
+  }
+}
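
For reference, a minimal end-to-end sketch of how the pieces in this diff compose. This is not part of the change: it assumes the same package as the test receiver (`CustomReceiverWithOffset` is package-private) and a runner such as the direct runner on the classpath; `SparkReceiverReadSketch` and `LogFn` are illustrative names.

```java
package org.apache.beam.sdk.io.sparkreceiver;

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.joda.time.Duration;
import org.joda.time.Instant;

public class SparkReceiverReadSketch {

  /** Logs every record read from the receiver (illustrative only). */
  static class LogFn extends DoFn<String, String> {
    @ProcessElement
    public void processElement(@Element String element, OutputReceiver<String> out) {
      System.out.println("Read record: " + element);
      out.output(element);
    }
  }

  public static void main(String[] args) {
    PipelineOptions options = PipelineOptionsFactory.fromArgs(args).create();
    Pipeline p = Pipeline.create(options);

    // CustomReceiverWithOffset stores the offsets 0..19 as String records,
    // so a record parses back into its own offset; the timestampFn mirrors
    // the one used in SparkReceiverIOTest.
    ReceiverBuilder<String, CustomReceiverWithOffset> receiverBuilder =
        new ReceiverBuilder<>(CustomReceiverWithOffset.class).withConstructorArgs();

    p.apply(
            "Spark Receiver Read",
            SparkReceiverIO.<String>read()
                .withGetOffsetFn(Long::valueOf)
                .withTimestampFn(Instant::parse)
                .withSparkReceiverBuilder(receiverBuilder))
        .setCoder(StringUtf8Coder.of())
        .apply("Log records", ParDo.of(new LogFn()));

    // The source is unbounded; bound the run for the sketch.
    p.run().waitUntilFinish(Duration.standardSeconds(15));
  }
}
```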