[FLINK-35351][checkpoint] Fix failure during restore from unaligned checkpoint with custom partitioner

Co-authored-by: Andrey Gaskov <31715230+empathy87@users.noreply.github.com>
2 people authored and pnowojski committed May 31, 2024
1 parent e095657 commit d6d4090
Showing 5 changed files with 165 additions and 5 deletions.
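
The scenario being fixed, as a minimal sketch (not code from this commit; the job shape, parallelisms, and checkpoint interval are assumptions): a custom partitioner puts the FULL SubtaskStateMapper on the downstream input gate, and restoring an unaligned checkpoint at a different parallelism previously failed to redistribute the captured in-flight channel state.

import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

public class CustomPartitionerRepro {
    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(3);
        env.enableCheckpointing(100);
        // unaligned checkpoints persist in-flight records as channel state
        env.getCheckpointConfig().enableUnalignedCheckpoints();

        env.fromData("0 a", "1 b", "2 c")
                .partitionCustom(
                        new Partitioner<String>() {
                            @Override
                            public int partition(String key, int numPartitions) {
                                return Integer.parseInt(key) % numPartitions;
                            }
                        },
                        str -> str.split(" ")[0])
                .print();

        // take an unaligned checkpoint, then restart the job from it with
        // parallelism 2: the custom-partitioned edge uses SubtaskStateMapper.FULL,
        // so the input channel state must be redistributed, not mapped one-to-one
        env.execute("custom-partitioner-repro");
    }
}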
flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperation.java

@@ -27,8 +27,10 @@
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
+import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.JobEdge;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
 import org.apache.flink.runtime.state.AbstractChannelStateHandle;
@@ -421,7 +423,31 @@ public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment)
                 stateAssignment.oldState.get(stateAssignment.inputOperatorID);
         final List<List<InputChannelStateHandle>> inputOperatorState =
                 splitBySubtasks(inputState, OperatorSubtaskState::getInputChannelState);
-        if (inputState.getParallelism() == executionJobVertex.getParallelism()) {
+
+        boolean hasAnyFullMapper =
+                executionJobVertex.getJobVertex().getInputs().stream()
+                        .map(JobEdge::getDownstreamSubtaskStateMapper)
+                        .anyMatch(m -> m.equals(SubtaskStateMapper.FULL));
+        boolean hasAnyPreviousOperatorChanged =
+                executionJobVertex.getInputs().stream()
+                        .map(IntermediateResult::getProducer)
+                        .map(vertexAssignments::get)
+                        .anyMatch(
+                                taskStateAssignment -> {
+                                    final int oldParallelism =
+                                            stateAssignment
+                                                    .oldState
+                                                    .get(stateAssignment.inputOperatorID)
+                                                    .getParallelism();
+                                    return oldParallelism
+                                            != taskStateAssignment.executionJobVertex
+                                                    .getParallelism();
+                                });
+
+        // rescaling is needed if any input operator's parallelism changed and any
+        // input uses the FULL subtask state mapper
+        if (inputState.getParallelism() == executionJobVertex.getParallelism()
+                && !(hasAnyFullMapper && hasAnyPreviousOperatorChanged)) {
             stateAssignment.inputChannelStates.putAll(
                     toInstanceMap(stateAssignment.inputOperatorID, inputOperatorState));
             return;
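Background for the new condition, as a rough sketch (this assumes the SubtaskStateMapper.getOldSubtasks(newSubtaskIndex, oldParallelism, newParallelism) signature, which lists the old subtasks a new subtask restores from): FULL maps every new subtask to all old subtasks, so the one-to-one fast path is only safe when no upstream parallelism changed.

import java.util.Arrays;

import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;

public class FullMapperDemo {
    public static void main(String[] args) {
        int oldParallelism = 3;
        int newParallelism = 2;
        for (int newSubtask = 0; newSubtask < newParallelism; newSubtask++) {
            // FULL assigns the channel state of every old subtask to each new subtask
            int[] oldSubtasks =
                    SubtaskStateMapper.FULL.getOldSubtasks(
                            newSubtask, oldParallelism, newParallelism);
            System.out.println(newSubtask + " <- " + Arrays.toString(oldSubtasks));
        }
        // expected (under the assumed API): 0 <- [0, 1, 2] and 1 <- [0, 1, 2]
    }
}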
flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateAssignmentOperationTest.java

@@ -82,6 +82,7 @@
 import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewOperatorStateHandle;
 import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewResultSubpartitionStateHandle;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ARBITRARY;
+import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.FULL;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.RANGE;
 import static org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper.ROUND_ROBIN;
 import static org.apache.flink.util.Preconditions.checkArgument;
@@ -561,6 +562,44 @@ private InflightDataGateOrPartitionRescalingDescriptor gate(
                 oldIndices, rescaleMapping, ambiguousSubtaskIndexes, mappingType);
     }

+    @Test
+    public void testChannelStateAssignmentTwoGatesPartiallyDownscaling()
+            throws JobException, JobExecutionException {
+        JobVertex upstream1 = createJobVertex(new OperatorID(), 2);
+        JobVertex upstream2 = createJobVertex(new OperatorID(), 2);
+        JobVertex downstream = createJobVertex(new OperatorID(), 3);
+        List<OperatorID> operatorIds =
+                Stream.of(upstream1, upstream2, downstream)
+                        .map(v -> v.getOperatorIDs().get(0).getGeneratedOperatorID())
+                        .collect(Collectors.toList());
+        Map<OperatorID, OperatorState> states = buildOperatorStates(operatorIds, 3);
+
+        connectVertices(upstream1, downstream, ARBITRARY, FULL);
+        connectVertices(upstream2, downstream, ROUND_ROBIN, ROUND_ROBIN);
+
+        Map<OperatorID, ExecutionJobVertex> vertices =
+                toExecutionVertices(upstream1, upstream2, downstream);
+
+        new StateAssignmentOperation(0, new HashSet<>(vertices.values()), states, false)
+                .assignStates();
+
+        assertThat(
+                        getAssignedState(vertices.get(operatorIds.get(2)), operatorIds.get(2), 0)
+                                .getInputChannelState()
+                                .size())
+                .isEqualTo(6);
+        assertThat(
+                        getAssignedState(vertices.get(operatorIds.get(2)), operatorIds.get(2), 1)
+                                .getInputChannelState()
+                                .size())
+                .isEqualTo(6);
+        assertThat(
+                        getAssignedState(vertices.get(operatorIds.get(2)), operatorIds.get(2), 2)
+                                .getInputChannelState()
+                                .size())
+                .isEqualTo(6);
+    }
+
     @Test
     void testChannelStateAssignmentDownscaling() throws JobException, JobExecutionException {
         List<OperatorID> operatorIds = buildOperatorIds(2);
flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java

@@ -145,7 +145,7 @@ private static ByteStreamStateHandle cloneByteStreamStateHandle(
     public static InputChannelStateHandle createNewInputChannelStateHandle(
             int numNamedStates, Random random) {
         return new InputChannelStateHandle(
-                new InputChannelInfo(random.nextInt(), random.nextInt()),
+                new InputChannelInfo(0, random.nextInt()),
                 createStreamStateHandle(numNamedStates, random),
                 genOffsets(numNamedStates, random));
     }
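A note on the change above, under the assumed constructor shape InputChannelInfo(gateIdx, inputChannelIdx): the redistribution path now resolves channel state per input gate, so a dummy handle with a random gate index would reference a gate the test topology does not have, while gate 0 always exists.

import java.util.Random;

import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;

public class DummyGateIndex {
    public static void main(String[] args) {
        Random random = new Random();
        // sketch of the assumption behind the change above
        InputChannelInfo valid = new InputChannelInfo(0, random.nextInt()); // gate 0 exists
        InputChannelInfo dangling = new InputChannelInfo(random.nextInt(), 0); // likely no such gate
        System.out.println(valid + " vs " + dangling);
    }
}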
flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointRescaleITCase.java

@@ -23,6 +23,7 @@
 import org.apache.flink.api.common.accumulators.LongCounter;
 import org.apache.flink.api.common.functions.FilterFunction;
 import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.functions.Partitioner;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
@@ -44,6 +45,7 @@
 import org.apache.flink.streaming.api.functions.co.CoMapFunction;
 import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
 import org.apache.flink.streaming.api.functions.co.KeyedCoProcessFunction;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.util.Collector;

 import org.apache.commons.lang3.ArrayUtils;
@@ -52,8 +54,10 @@
 import org.junit.runners.Parameterized;

 import java.io.File;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.BitSet;
+import java.util.Collection;
 import java.util.Collections;

 import static org.apache.flink.api.common.eventtime.WatermarkStrategy.noWatermarks;
@@ -329,6 +333,44 @@ public void create(

                 addFailingSink(joined, minCheckpoints, slotSharing);
             }
-        };
+        },
+        CUSTOM_PARTITIONER {
+            final int sinkParallelism = 3;
+            final int numberElements = 1000;
+
+            @Override
+            public void create(
+                    StreamExecutionEnvironment environment,
+                    int minCheckpoints,
+                    boolean slotSharing,
+                    int expectedFailuresUntilSourceFinishes,
+                    long sourceSleepMs) {
+                int parallelism = environment.getParallelism();
+                environment
+                        .fromData(generateStrings(numberElements / parallelism, sinkParallelism))
+                        .name("source")
+                        .setParallelism(parallelism)
+                        .partitionCustom(new StringPartitioner(), str -> str.split(" ")[0])
+                        .addSink(new StringSink(numberElements / sinkParallelism))
+                        .name("sink")
+                        .setParallelism(sinkParallelism);
+            }
+
+            private Collection<String> generateStrings(
+                    int producePerPartition, int partitionCount) {
+                Collection<String> list = new ArrayList<>();
+                for (int i = 0; i < producePerPartition; i++) {
+                    for (int partition = 0; partition < partitionCount; partition++) {
+                        list.add(buildString(partition, i));
+                    }
+                }
+                return list;
+            }
+
+            private String buildString(int partition, int index) {
+                String longStr = new String(new char[3713]).replace('\0', '\uFFFF');
+                return partition + " " + index + " " + longStr;
+            }
+        };

         void addFailingSink(
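Why buildString pads each record with 3713 chars (our reading, stated as an assumption): multi-KiB records fill the network buffers quickly, so the unaligned checkpoint reliably captures in-flight data that must survive the rescale. A back-of-envelope check of the record size used above:

public class RecordSizeCheck {
    public static void main(String[] args) {
        // same construction as buildString(...) in the topology above
        String longStr = new String(new char[3713]).replace('\0', '\uFFFF');
        String record = 1 + " " + 42 + " " + longStr;
        System.out.println(record.length()); // 3718 chars, i.e. > 7 KiB as UTF-16
    }
}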
@@ -485,6 +527,7 @@ public static Object[][] getScaleFactors() {
         // captured in-flight records, see FLINK-31963.
         Object[][] parameters =
                 new Object[][] {
+                    new Object[] {"downscale", Topology.CUSTOM_PARTITIONER, 3, 2, 0L},
                     new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 12, 7, 0L},
                     new Object[] {"upscale", Topology.KEYED_DIFFERENT_PARALLELISM, 7, 12, 0L},
                     new Object[] {"downscale", Topology.KEYED_DIFFERENT_PARALLELISM, 5, 3, 5L},
@@ -561,6 +604,7 @@ public UnalignedCheckpointRescaleITCase(

     @Test
     public void shouldRescaleUnalignedCheckpoint() throws Exception {
+        StringSink.failed = false;
         final UnalignedSettings prescaleSettings =
                 new UnalignedSettings(topology)
                         .setParallelism(oldParallelism)
@@ -585,8 +629,12 @@ protected void checkCounters(JobExecutionResult result) {
                 "NUM_OUTPUTS = NUM_INPUTS",
                 result.<Long>getAccumulatorResult(NUM_OUTPUTS),
                 equalTo(result.getAccumulatorResult(NUM_INPUTS)));
-        collector.checkThat(
-                "NUM_DUPLICATES", result.<Long>getAccumulatorResult(NUM_DUPLICATES), equalTo(0L));
+        if (!topology.equals(Topology.CUSTOM_PARTITIONER)) {
+            collector.checkThat(
+                    "NUM_DUPLICATES",
+                    result.<Long>getAccumulatorResult(NUM_DUPLICATES),
+                    equalTo(0L));
+        }
     }

     /**
@@ -705,4 +753,51 @@ public Long map2(Long value) throws Exception {
             return checkHeader(value);
         }
     }
+
+    private static class StringPartitioner implements Partitioner<String> {
+        @Override
+        public int partition(String key, int numPartitions) {
+            return Integer.parseInt(key) % numPartitions;
+        }
+    }
+
+    private static class StringSink implements SinkFunction<String>, CheckpointedFunction {
+
+        static volatile boolean failed = false;
+
+        int checkpointConsumed = 0;
+
+        int recordsConsumed = 0;
+
+        final int numberElements;
+
+        public StringSink(int numberElements) {
+            this.numberElements = numberElements;
+        }
+
+        @Override
+        public void invoke(String value, Context ctx) throws Exception {
+            if (!failed && checkpointConsumed > 1) {
+                failed = true;
+                throw new TestException("FAIL");
+            }
+            recordsConsumed++;
+            if (!failed && recordsConsumed == (numberElements / 3)) {
+                Thread.sleep(1000);
+            }
+            if (recordsConsumed == (numberElements - 1)) {
+                Thread.sleep(1000);
+            }
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) {
+            checkpointConsumed++;
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) {
+            // do nothing
+        }
+    }
 }
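Why checkCounters() above skips the NUM_DUPLICATES assertion for CUSTOM_PARTITIONER (a sketch of our reading of the fix, not commit code): with FULL mapping every rescaled subtask replays the in-flight records of every old input channel, so a record captured once in the checkpoint can be delivered to more than one new subtask.

import java.util.Arrays;
import java.util.List;

public class FullMapperDuplicates {
    public static void main(String[] args) {
        List<String> capturedInFlight = Arrays.asList("record-1", "record-2");
        int newParallelism = 2;
        int deliveries = 0;
        for (int subtask = 0; subtask < newParallelism; subtask++) {
            // every new subtask restores the full captured set
            deliveries += capturedInFlight.size();
        }
        // 4 deliveries of 2 distinct records: 2 duplicates after rescaling
        System.out.println("deliveries=" + deliveries + ", distinct=" + capturedInFlight.size());
    }
}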
flink-tests/src/test/java/org/apache/flink/test/checkpointing/UnalignedCheckpointTestBase.java

@@ -1136,7 +1136,7 @@ protected static long checkHeader(long value) {
         return value;
     }

-    private static class TestException extends Exception {
+    static class TestException extends Exception {
         public TestException(String s) {
             super(s);
         }