Skip to content

Commit

Permalink
[Engine] [Checkpoint] fix Checkpoint can't deserialize with protostuff (
Browse files Browse the repository at this point in the history
  • Loading branch information
Hisoka-X authored Oct 19, 2022
1 parent 6506e30 commit 4cecedb
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.seatunnel.engine.server.checkpoint;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

public class ActionState implements Serializable {

Expand All @@ -31,7 +33,7 @@ public class ActionState implements Serializable {
/**
* The handles to states created by the parallel actions: action index -> action state.
*/
private final ActionSubtaskState[] subtaskStates;
private final List<ActionSubtaskState> subtaskStates;

private ActionSubtaskState coordinatorState;

Expand All @@ -42,15 +44,15 @@ public class ActionState implements Serializable {

public ActionState(String actionId, int parallelism) {
this.actionId = actionId;
this.subtaskStates = new ActionSubtaskState[parallelism];
this.subtaskStates = Arrays.asList(new ActionSubtaskState[parallelism]);
this.parallelism = parallelism;
}

/**
 * Returns the identifier of the action this state belongs to.
 *
 * @return the action id supplied to the constructor
 */
public String getActionId() {
    return this.actionId;
}

public ActionSubtaskState[] getSubtaskStates() {
public List<ActionSubtaskState> getSubtaskStates() {
return subtaskStates;
}

Expand All @@ -67,6 +69,6 @@ public void reportState(int index, ActionSubtaskState state) {
coordinatorState = state;
return;
}
subtaskStates[index] = state;
subtaskStates.set(index, state);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ private void restoreTaskState(TaskLocation taskLocation) {
return;
}
for (int i = tuple.f1(); i < actionState.getParallelism(); i += currentParallelism) {
states.add(actionState.getSubtaskStates()[i]);
states.add(actionState.getSubtaskStates().get(i));
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
import static com.google.common.base.Preconditions.checkNotNull;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

public class TaskStatistics implements Serializable {
/**
* ID of the task the statistics belong to.
*/
private final Long jobVertexId;

private final SubtaskStatistics[] subtaskStats;
private final List<SubtaskStatistics> subtaskStats;

/**
* Marks whether a subtask is complete;
Expand All @@ -42,20 +44,20 @@ public class TaskStatistics implements Serializable {
TaskStatistics(Long jobVertexId, int parallelism) {
this.jobVertexId = checkNotNull(jobVertexId, "JobVertexID");
checkArgument(parallelism > 0, "the parallelism of task <= 0");
this.subtaskStats = new SubtaskStatistics[parallelism];
this.subtaskStats = Arrays.asList(new SubtaskStatistics[parallelism]);
this.subtaskCompleted = new boolean[parallelism];
}

boolean reportSubtaskStatistics(SubtaskStatistics subtask) {
checkNotNull(subtask, "Subtask stats");
int subtaskIndex = subtask.getSubtaskIndex();

if (subtaskIndex < 0 || subtaskIndex >= subtaskStats.length) {
if (subtaskIndex < 0 || subtaskIndex >= subtaskStats.size()) {
return false;
}

if (subtaskStats[subtaskIndex] == null) {
subtaskStats[subtaskIndex] = subtask;
if (subtaskStats.get(subtaskIndex) == null) {
subtaskStats.set(subtaskIndex, subtask);
numAcknowledgedSubtasks++;
return true;
} else {
Expand Down Expand Up @@ -85,7 +87,7 @@ public Long getJobVertexId() {
return jobVertexId;
}

public SubtaskStatistics[] getSubtaskStats() {
public List<SubtaskStatistics> getSubtaskStats() {
return subtaskStats;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.engine.server.checkpoint;

import org.apache.seatunnel.engine.checkpoint.storage.PipelineState;
import org.apache.seatunnel.engine.checkpoint.storage.common.ProtoStuffSerializer;
import org.apache.seatunnel.engine.core.checkpoint.CheckpointType;

import org.apache.commons.io.FileUtils;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.Map;

/**
 * Round-trip test for checkpoint persistence: a {@link CompletedCheckpoint} is serialized with
 * {@link ProtoStuffSerializer}, wrapped in a {@link PipelineState}, written to a local file, read
 * back and deserialized again.
 */
public class StorageTest {

    /**
     * Verifies that a checkpoint holding {@link ActionState} and {@link TaskStatistics} entries
     * survives serialize -> file -> deserialize without the deserializer failing.
     *
     * @throws IOException if the temporary file cannot be created, written or read
     */
    @Test
    public void localFileTest() throws IOException {

        Map<Long, TaskStatistics> taskStatisticsMap = new HashMap<>();
        taskStatisticsMap.put(1L, new TaskStatistics(1L, 32));
        Map<Long, ActionState> actionStateMap = new HashMap<>();
        actionStateMap.put(2L, new ActionState("test", 13));
        CompletedCheckpoint completedCheckpoint = new CompletedCheckpoint(1, 2, 4324,
            Instant.now().toEpochMilli(),
            CheckpointType.COMPLETED_POINT_TYPE,
            Instant.now().toEpochMilli(),
            actionStateMap,
            taskStatisticsMap);

        ProtoStuffSerializer protoStuffSerializer = new ProtoStuffSerializer();
        byte[] data = protoStuffSerializer.serialize(completedCheckpoint);
        PipelineState pipelineState = PipelineState.builder()
            .checkpointId(1)
            .jobId(String.valueOf(1))
            .pipelineId(1)
            .states(data)
            .build();

        byte[] pipeData = protoStuffSerializer.serialize(pipelineState);

        // Use a unique temp file instead of a fixed absolute path so the test is portable
        // across operating systems and does not collide with concurrent runs.
        File file = File.createTempFile("seatunnel-checkpoint", ".data");
        try {
            FileUtils.writeByteArrayToFile(file, pipeData);

            byte[] fileData = FileUtils.readFileToByteArray(file);
            // The file must hand back exactly the bytes that were written.
            Assertions.assertArrayEquals(pipeData, fileData);

            PipelineState state = protoStuffSerializer.deserialize(fileData, PipelineState.class);

            // A separate serializer instance is used for the inner payload, as in the original test.
            CompletedCheckpoint checkpoint =
                new ProtoStuffSerializer().deserialize(state.getStates(), CompletedCheckpoint.class);
            Assertions.assertNotNull(checkpoint);
        } finally {
            // Never leave test artifacts behind, even when an assertion fails.
            FileUtils.deleteQuietly(file);
        }
    }

}

0 comments on commit 4cecedb

Please sign in to comment.