/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.netty.util;

import java.io.IOException;
import java.nio.channels.WritableByteChannel;

import io.netty.channel.FileRegion;

import org.apache.uniffle.common.netty.protocol.AbstractFileRegion;

/**
 * A {@link FileRegion} that stitches several child regions together and exposes them as a single
 * contiguous region. Reference-count operations (retain/release/touch) are propagated to every
 * child so that the children's lifecycles stay in sync with the composite.
 */
public class CompositeFileRegion extends AbstractFileRegion {
  private final FileRegion[] regions;
  // Sum of all child counts, computed once at construction.
  private long totalSize = 0;
  // Total number of bytes written to a target channel so far, across all transferTo calls.
  private long bytesTransferred = 0;

  public CompositeFileRegion(FileRegion... regions) {
    this.regions = regions;
    for (FileRegion region : regions) {
      totalSize += region.count();
    }
  }

  @Override
  public long position() {
    // NOTE(review): FileRegion#position() is normally the fixed start offset of the region (0 for
    // a composite); returning the running transferred count mirrors the original implementation.
    // Confirm no caller relies on the standard contract before changing this.
    return bytesTransferred;
  }

  @Override
  public long count() {
    return totalSize;
  }

  /**
   * Transfers bytes starting at {@code position} (an offset relative to the whole composite) to
   * {@code target}, walking the children in order.
   *
   * @param target channel to write to
   * @param position starting offset within the composite; must be non-negative
   * @return the number of bytes written by this call (possibly fewer than requested if the
   *     channel accepts a partial write)
   * @throws IOException if a child transfer fails
   */
  @Override
  public long transferTo(WritableByteChannel target, long position) throws IOException {
    if (position < 0) {
      throw new IllegalArgumentException("position (" + position + ") must not be negative");
    }
    long totalBytesTransferred = 0;

    for (FileRegion region : regions) {
      if (position >= region.count()) {
        // The requested offset lies entirely beyond this child; skip it.
        position -= region.count();
      } else {
        long currentBytesTransferred = region.transferTo(target, position);
        totalBytesTransferred += currentBytesTransferred;
        bytesTransferred += currentBytesTransferred;

        if (currentBytesTransferred < region.count() - position) {
          // The channel accepted fewer bytes than remained in this child: stop here and let the
          // caller retry from the updated position.
          break;
        }
        // Subsequent children are read from their beginning.
        position = 0;
      }
    }

    return totalBytesTransferred;
  }

  @Override
  public long transferred() {
    return bytesTransferred;
  }

  @Override
  public AbstractFileRegion retain() {
    super.retain();
    for (FileRegion region : regions) {
      region.retain();
    }
    return this;
  }

  @Override
  public AbstractFileRegion retain(int increment) {
    super.retain(increment);
    for (FileRegion region : regions) {
      region.retain(increment);
    }
    return this;
  }

  @Override
  public boolean release() {
    boolean released = super.release();
    for (FileRegion region : regions) {
      if (!region.release()) {
        released = false;
      }
    }
    return released;
  }

  @Override
  public boolean release(int decrement) {
    boolean released = super.release(decrement);
    for (FileRegion region : regions) {
      if (!region.release(decrement)) {
        released = false;
      }
    }
    return released;
  }

  @Override
  protected void deallocate() {
    // Intentionally empty. Child regions are released in release()/release(int) together with
    // this composite, so their own reference counting triggers their deallocation. The previous
    // revision additionally invoked child.deallocate() here, which deallocated children a second
    // time (or while other holders still retained them) — a double-free.
  }

  @Override
  public AbstractFileRegion touch() {
    super.touch();
    for (FileRegion region : regions) {
      region.touch();
    }
    return this;
  }

  @Override
  public AbstractFileRegion touch(Object hint) {
    super.touch(hint);
    for (FileRegion region : regions) {
      region.touch(hint);
    }
    return this;
  }
}
!= null ? Unpooled.wrappedBuffer(data) : Unpooled.EMPTY_BUFFER); - this.dataFileLen = dataFileLen; + this( + new NettyManagedBuffer(data != null ? Unpooled.wrappedBuffer(data) : Unpooled.EMPTY_BUFFER), + dataFileLen, + null, + DEFAULT_STORAGE_IDS); } public ShuffleIndexResult(ManagedBuffer buffer, long dataFileLen, String dataFileName) { + this(buffer, dataFileLen, dataFileName, DEFAULT_STORAGE_IDS); + } + + public ShuffleIndexResult( + ManagedBuffer buffer, long dataFileLen, String dataFileName, int storageId) { + this(buffer, dataFileLen, dataFileName, new int[] {storageId}); + } + + public ShuffleIndexResult( + ManagedBuffer buffer, long dataFileLen, String dataFileName, int[] storageIds) { this.buffer = buffer; this.dataFileLen = dataFileLen; this.dataFileName = dataFileName; + this.storageIds = storageIds; } public byte[] getData() { @@ -99,4 +114,8 @@ public ManagedBuffer getManagedBuffer() { public String getDataFileName() { return dataFileName; } + + public int[] getStorageIds() { + return storageIds; + } } diff --git a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java index cd40024826..26c782cf9f 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/MessageEncoder.java @@ -86,7 +86,8 @@ public void encode(ChannelHandlerContext ctx, Message in, List out) thro header.writeInt(bodyLength); in.encode(header); if (header.writableBytes() != 0) { - throw new RssException("header's writable bytes should be 0"); + throw new RssException( + "header's writable bytes should be 0, but it is " + header.writableBytes()); } if (body != null) { diff --git a/common/src/main/java/org/apache/uniffle/common/netty/buffer/MultiFileSegmentManagedBuffer.java b/common/src/main/java/org/apache/uniffle/common/netty/buffer/MultiFileSegmentManagedBuffer.java new file mode 100644 index 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.uniffle.common.netty.buffer;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.FileRegion;
import io.netty.util.CompositeFileRegion;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** A wrapper of multiple {@link FileSegmentManagedBuffer}, used to combine shuffle index files. */
public class MultiFileSegmentManagedBuffer extends ManagedBuffer {

  private static final Logger LOG = LoggerFactory.getLogger(MultiFileSegmentManagedBuffer.class);
  private final List<ManagedBuffer> managedBuffers;

  public MultiFileSegmentManagedBuffer(List<ManagedBuffer> managedBuffers) {
    this.managedBuffers = managedBuffers;
  }

  /** @return the total size in bytes across all wrapped buffers */
  @Override
  public int size() {
    return managedBuffers.stream().mapToInt(ManagedBuffer::size).sum();
  }

  @Override
  public ByteBuf byteBuf() {
    return Unpooled.wrappedBuffer(this.nioByteBuffer());
  }

  /** Copies every wrapped buffer, in order, into one newly allocated contiguous ByteBuffer. */
  @Override
  public ByteBuffer nioByteBuffer() {
    ByteBuffer merged = ByteBuffer.allocate(size());
    for (ManagedBuffer managedBuffer : managedBuffers) {
      ByteBuffer buffer = managedBuffer.nioByteBuffer();
      // slice() so the source buffer's position/limit are left untouched.
      merged.put(buffer.slice());
    }
    merged.flip();
    return merged;
  }

  // The wrapped file-segment buffers are not reference counted, so retain/release are no-ops.
  @Override
  public ManagedBuffer retain() {
    return this;
  }

  @Override
  public ManagedBuffer release() {
    return this;
  }

  /**
   * Converts every wrapped buffer into a {@link FileRegion} and combines them into one
   * {@link CompositeFileRegion} for zero-copy transfer.
   */
  @Override
  public Object convertToNetty() {
    List<FileRegion> fileRegions = new ArrayList<>(managedBuffers.size());
    for (ManagedBuffer managedBuffer : managedBuffers) {
      Object object = managedBuffer.convertToNetty();
      if (object instanceof FileRegion) {
        fileRegions.add((FileRegion) object);
      } else {
        // Previously such buffers were dropped silently, truncating the transferred data with
        // no trace in the logs; surface the problem so it can be diagnosed.
        LOG.warn(
            "Expected a FileRegion but got {}; this buffer will be skipped during transfer",
            object == null ? "null" : object.getClass().getName());
      }
    }
    return new CompositeFileRegion(fileRegions.toArray(new FileRegion[0]));
  }
}
+42,30 @@ public GetLocalShuffleDataRequest( long offset, int length, long timestamp) { + this( + requestId, + appId, + shuffleId, + partitionId, + partitionNumPerRange, + partitionNum, + offset, + length, + -1, + timestamp); + } + + protected GetLocalShuffleDataRequest( + long requestId, + String appId, + int shuffleId, + int partitionId, + int partitionNumPerRange, + int partitionNum, + long offset, + int length, + int storageId, + long timestamp) { super(requestId); this.appId = appId; this.shuffleId = shuffleId; @@ -49,6 +74,7 @@ public GetLocalShuffleDataRequest( this.partitionNum = partitionNum; this.offset = offset; this.length = length; + this.storageId = storageId; this.timestamp = timestamp; } @@ -132,6 +158,10 @@ public long getTimestamp() { return timestamp; } + public int getStorageId() { + return storageId; + } + @Override public String getOperationType() { return "getLocalShuffleData"; diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataV2Request.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataV2Request.java new file mode 100644 index 0000000000..8dea653ed5 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleDataV2Request.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.netty.protocol; + +import io.netty.buffer.ByteBuf; + +import org.apache.uniffle.common.util.ByteBufUtils; + +public class GetLocalShuffleDataV2Request extends GetLocalShuffleDataRequest { + + public GetLocalShuffleDataV2Request( + long requestId, + String appId, + int shuffleId, + int partitionId, + int partitionNumPerRange, + int partitionNum, + long offset, + int length, + int storageId, + long timestamp) { + super( + requestId, + appId, + shuffleId, + partitionId, + partitionNumPerRange, + partitionNum, + offset, + length, + storageId, + timestamp); + } + + @Override + public Type type() { + return Type.GET_LOCAL_SHUFFLE_DATA_V2_REQUEST; + } + + @Override + public int encodedLength() { + // add int type storageId to encoded length + return super.encodedLength() + Integer.BYTES; + } + + @Override + public void encode(ByteBuf buf) { + super.encode(buf); + buf.writeInt(getStorageId()); + } + + public static GetLocalShuffleDataV2Request decode(ByteBuf byteBuf) { + long requestId = byteBuf.readLong(); + String appId = ByteBufUtils.readLengthAndString(byteBuf); + int shuffleId = byteBuf.readInt(); + int partitionId = byteBuf.readInt(); + int partitionNumPerRange = byteBuf.readInt(); + int partitionNum = byteBuf.readInt(); + long offset = byteBuf.readLong(); + int length = byteBuf.readInt(); + long timestamp = byteBuf.readLong(); + int storageId = byteBuf.readInt(); + return new GetLocalShuffleDataV2Request( + requestId, + appId, + shuffleId, + partitionId, + partitionNumPerRange, + partitionNum, + offset, + 
length, + storageId, + timestamp); + } + + @Override + public String getOperationType() { + return "getLocalShuffleDataV2"; + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java index f97373805b..455dda4b15 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexResponse.java @@ -24,6 +24,7 @@ import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; import org.apache.uniffle.common.rpc.StatusCode; import org.apache.uniffle.common.util.ByteBufUtils; +import org.apache.uniffle.common.util.Constants; public class GetLocalShuffleIndexResponse extends RpcResponse { @@ -92,4 +93,8 @@ public Type type() { public long getFileLength() { return fileLength; } + + public int[] getStorageIds() { + return Constants.EMPTY_INT_ARRAY; + } } diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexV2Response.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexV2Response.java new file mode 100644 index 0000000000..b08a5ca8e2 --- /dev/null +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/GetLocalShuffleIndexV2Response.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.common.netty.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.uniffle.common.netty.buffer.ManagedBuffer; +import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; +import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.util.ByteBufUtils; +import org.apache.uniffle.common.util.Constants; + +public class GetLocalShuffleIndexV2Response extends GetLocalShuffleIndexResponse { + + private final int[] storageIds; + + public GetLocalShuffleIndexV2Response( + long requestId, StatusCode statusCode, String retMessage, byte[] indexData, long fileLength) { + this( + requestId, + statusCode, + retMessage, + indexData == null ? 
Unpooled.EMPTY_BUFFER : Unpooled.wrappedBuffer(indexData), + fileLength); + } + + public GetLocalShuffleIndexV2Response( + long requestId, + StatusCode statusCode, + String retMessage, + ByteBuf indexData, + long fileLength) { + this( + requestId, + statusCode, + retMessage, + new NettyManagedBuffer(indexData), + fileLength, + Constants.EMPTY_INT_ARRAY); + } + + public GetLocalShuffleIndexV2Response( + long requestId, + StatusCode statusCode, + String retMessage, + ManagedBuffer managedBuffer, + long fileLength, + int[] storageIds) { + super(requestId, statusCode, retMessage, managedBuffer, fileLength); + this.storageIds = storageIds; + } + + @Override + public int encodedLength() { + // super encodedLength + 4(storageIds.length) + 4 * storageIds.length + return super.encodedLength() + Integer.BYTES + Integer.BYTES * storageIds.length; + } + + @Override + public void encode(ByteBuf buf) { + super.encode(buf); + buf.writeInt(storageIds.length); + for (int storageId : storageIds) { + buf.writeInt(storageId); + } + } + + public static GetLocalShuffleIndexV2Response decode(ByteBuf byteBuf, boolean decodeBody) { + long requestId = byteBuf.readLong(); + StatusCode statusCode = StatusCode.fromCode(byteBuf.readInt()); + String retMessage = ByteBufUtils.readLengthAndString(byteBuf); + long fileLength = byteBuf.readLong(); + int[] storageIds = new int[byteBuf.readInt()]; + for (int i = 0; i < storageIds.length; i++) { + storageIds[i] = byteBuf.readInt(); + } + if (decodeBody) { + NettyManagedBuffer nettyManagedBuffer = new NettyManagedBuffer(byteBuf); + return new GetLocalShuffleIndexV2Response( + requestId, statusCode, retMessage, nettyManagedBuffer, fileLength, storageIds); + } else { + return new GetLocalShuffleIndexV2Response( + requestId, + statusCode, + retMessage, + NettyManagedBuffer.EMPTY_BUFFER, + fileLength, + storageIds); + } + } + + @Override + public Type type() { + return Type.GET_LOCAL_SHUFFLE_INDEX_V2_RESPONSE; + } + + @Override + public int[] 
getStorageIds() { + return storageIds; + } +} diff --git a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java index e7f18ab2f1..2ad0b0d776 100644 --- a/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java +++ b/common/src/main/java/org/apache/uniffle/common/netty/protocol/Message.java @@ -64,7 +64,10 @@ public enum Type implements Encodable { GET_SHUFFLE_RESULT_FOR_MULTI_PART_RESPONSE(19), REQUIRE_BUFFER_RESPONSE(20), GET_SORTED_SHUFFLE_DATA_REQUEST(21), - GET_SORTED_SHUFFLE_DATA_RESPONSE(22); + GET_SORTED_SHUFFLE_DATA_RESPONSE(22), + GET_LOCAL_SHUFFLE_INDEX_V2_RESPONSE(23), + GET_LOCAL_SHUFFLE_DATA_V2_REQUEST(24), + ; private final byte id; @@ -138,6 +141,10 @@ public static Type decode(ByteBuf buf) { return GET_SORTED_SHUFFLE_DATA_REQUEST; case 22: return GET_SORTED_SHUFFLE_DATA_RESPONSE; + case 23: + return GET_LOCAL_SHUFFLE_INDEX_V2_RESPONSE; + case 24: + return GET_LOCAL_SHUFFLE_DATA_V2_REQUEST; case -1: throw new IllegalArgumentException("User type messages cannot be decoded."); default: @@ -154,12 +161,16 @@ public static Message decode(Type msgType, ByteBuf in) { return SendShuffleDataRequestV1.decode(in); case GET_LOCAL_SHUFFLE_DATA_REQUEST: return GetLocalShuffleDataRequest.decode(in); + case GET_LOCAL_SHUFFLE_DATA_V2_REQUEST: + return GetLocalShuffleDataV2Request.decode(in); case GET_LOCAL_SHUFFLE_DATA_RESPONSE: return GetLocalShuffleDataResponse.decode(in, true); case GET_LOCAL_SHUFFLE_INDEX_REQUEST: return GetLocalShuffleIndexRequest.decode(in); case GET_LOCAL_SHUFFLE_INDEX_RESPONSE: return GetLocalShuffleIndexResponse.decode(in, true); + case GET_LOCAL_SHUFFLE_INDEX_V2_RESPONSE: + return GetLocalShuffleIndexV2Response.decode(in, true); case GET_MEMORY_SHUFFLE_DATA_REQUEST: return GetMemoryShuffleDataRequest.decode(in); case GET_MEMORY_SHUFFLE_DATA_RESPONSE: diff --git 
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.uniffle.common.segment;

import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.util.List;

import com.google.common.base.Predicate;
import com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.uniffle.common.BufferSegment;
import org.apache.uniffle.common.ShuffleDataSegment;
import org.apache.uniffle.common.ShuffleIndexResult;
import org.apache.uniffle.common.exception.RssException;

/**
 * Base class for segment splitters: walks the shuffle index records and groups blocks into
 * {@link ShuffleDataSegment}s bounded by {@code readBufferSize}, an optional task-attempt filter,
 * and storage-id changes. A block offset smaller than its predecessor marks the boundary between
 * two physical data files, advancing to the next storage id.
 */
public abstract class AbstractSegmentSplitter implements SegmentSplitter {
  protected static final Logger LOGGER = LoggerFactory.getLogger(AbstractSegmentSplitter.class);

  protected int readBufferSize;

  public AbstractSegmentSplitter(int readBufferSize) {
    this.readBufferSize = readBufferSize;
  }

  /**
   * Splits index data into read segments.
   *
   * @param shuffleIndexResult index data plus the data-file length and per-file storage ids
   * @param taskFilter accepts task attempts whose blocks should be kept; {@code null} keeps all
   * @return the parsed segments; empty when the index is empty
   * @throws RssException if the index data is truncated mid-record
   */
  protected List<ShuffleDataSegment> splitCommon(
      ShuffleIndexResult shuffleIndexResult, Predicate<Long> taskFilter) {
    if (shuffleIndexResult == null || shuffleIndexResult.isEmpty()) {
      return Lists.newArrayList();
    }

    ByteBuffer indexData = shuffleIndexResult.getIndexData();
    long dataFileLen = shuffleIndexResult.getDataFileLen();
    int[] storageIds = shuffleIndexResult.getStorageIds();

    List<BufferSegment> bufferSegments = Lists.newArrayList();
    List<ShuffleDataSegment> dataFileSegments = Lists.newArrayList();
    int bufferOffset = 0;
    long fileOffset = -1;
    long totalLength = 0;

    int storageIndex = 0;
    long preOffset = -1;
    int preStorageId = -1;
    int currentStorageId = 0;

    while (indexData.hasRemaining()) {
      try {
        final long offset = indexData.getLong();
        final int length = indexData.getInt();
        final int uncompressLength = indexData.getInt();
        final long crc = indexData.getLong();
        final long blockId = indexData.getLong();
        final long taskAttemptId = indexData.getLong();

        if (storageIds.length == 0) {
          // No storage info available (e.g. an old server): mark every block with -1.
          currentStorageId = -1;
        } else {
          if (preOffset > offset) {
            // Offsets grow monotonically within one data file, so a smaller offset means the
            // index switched to the next file / storage.
            storageIndex++;
            if (storageIndex >= storageIds.length) {
              // BUGFIX: the previous revision logged this warning and then still indexed past
              // the end of the array, throwing ArrayIndexOutOfBoundsException. Clamp to the
              // last known storage id instead.
              LOGGER.warn("storageIds length {} is not enough.", storageIds.length);
              storageIndex = storageIds.length - 1;
            }
          }
          currentStorageId = storageIds[storageIndex];
        }
        preOffset = offset;

        totalLength += length;

        // While the server is flushing, the index may describe more bytes than the data file
        // currently holds; stop early to avoid EOFException on read.
        if (dataFileLen != -1 && totalLength > dataFileLen) {
          LOGGER.info(
              "Abort inconsistent data, the data length: {}(bytes) recorded in index file is greater than "
                  + "the real data file length: {}(bytes). Block id: {}"
                  + "This may happen when the data is flushing, please ignore.",
              totalLength,
              dataFileLen,
              blockId);
          break;
        }

        boolean storageChanged = preStorageId != -1 && currentStorageId != preStorageId;

        // Flush the pending segment when it is full, when the storage changes, or when the
        // current block is filtered out (a segment must hold contiguous accepted blocks).
        if (bufferOffset >= readBufferSize
            || storageChanged
            || (taskFilter != null && !taskFilter.test(taskAttemptId))) {
          if (bufferOffset > 0) {
            ShuffleDataSegment sds =
                new ShuffleDataSegment(fileOffset, bufferOffset, preStorageId, bufferSegments);
            dataFileSegments.add(sds);
            bufferSegments = Lists.newArrayList();
            bufferOffset = 0;
            fileOffset = -1;
          }
        }

        if (taskFilter == null || taskFilter.test(taskAttemptId)) {
          if (fileOffset == -1) {
            // First accepted block of a fresh segment anchors the segment's file offset.
            fileOffset = offset;
          }
          bufferSegments.add(
              new BufferSegment(
                  blockId, bufferOffset, length, uncompressLength, crc, taskAttemptId));
          preStorageId = currentStorageId;
          bufferOffset += length;
        }
      } catch (BufferUnderflowException ue) {
        throw new RssException("Read index data under flow", ue);
      }
    }

    // Flush whatever remains after the last index record.
    if (bufferOffset > 0) {
      ShuffleDataSegment sds =
          new ShuffleDataSegment(fileOffset, bufferOffset, currentStorageId, bufferSegments);
      dataFileSegments.add(sds);
    }

    return dataFileSegments;
  }
}
package org.apache.uniffle.common.segment;

import java.util.List;

import org.apache.uniffle.common.ShuffleDataSegment;
import org.apache.uniffle.common.ShuffleIndexResult;

/**
 * Splits shuffle index data into segments bounded only by {@code readBufferSize}; every block is
 * accepted regardless of which task attempt produced it.
 */
public class FixedSizeSegmentSplitter extends AbstractSegmentSplitter {

  public FixedSizeSegmentSplitter(int readBufferSize) {
    super(readBufferSize);
  }

  @Override
  public List<ShuffleDataSegment> split(ShuffleIndexResult shuffleIndexResult) {
    // A null filter means "accept every task attempt".
    return splitCommon(shuffleIndexResult, null);
  }
}

Last but not least, this split strategy depends on LOCAL_ORDER of index file, which must be * guaranteed by the shuffle server. */ -public class LocalOrderSegmentSplitter implements SegmentSplitter { - private static final Logger LOGGER = LoggerFactory.getLogger(LocalOrderSegmentSplitter.class); - +public class LocalOrderSegmentSplitter extends AbstractSegmentSplitter { private Roaring64NavigableMap expectTaskIds; - private int readBufferSize; public LocalOrderSegmentSplitter(Roaring64NavigableMap expectTaskIds, int readBufferSize) { + super(readBufferSize); this.expectTaskIds = expectTaskIds; - this.readBufferSize = readBufferSize; } @Override public List split(ShuffleIndexResult shuffleIndexResult) { - if (shuffleIndexResult == null || shuffleIndexResult.isEmpty()) { - return Lists.newArrayList(); - } - - ByteBuffer indexData = shuffleIndexResult.getIndexData(); - long dataFileLen = shuffleIndexResult.getDataFileLen(); - - List bufferSegments = Lists.newArrayList(); - - List dataFileSegments = Lists.newArrayList(); - int bufferOffset = 0; - long fileOffset = -1; - long totalLen = 0; - - long lastExpectedBlockIndex = -1; - - List indexTaskIds = new ArrayList<>(); - - /** - * One ShuffleDataSegment should meet following requirements: - * - *

1. taskId in [startMapIndex, endMapIndex) taskIds bitmap. Attention: the index in the - * range is not the map task id, which means the required task ids are not continuous. 2. - * ShuffleDataSegment size should < readBufferSize 3. Single shuffleDataSegment's blocks should - * be continuous - */ - int index = 0; - while (indexData.hasRemaining()) { - try { - long offset = indexData.getLong(); - int length = indexData.getInt(); - int uncompressLength = indexData.getInt(); - long crc = indexData.getLong(); - long blockId = indexData.getLong(); - long taskAttemptId = indexData.getLong(); - - totalLen += length; - indexTaskIds.add(taskAttemptId); - - // If ShuffleServer is flushing the file at this time, the length in the index file record - // may be greater - // than the length in the actual data file, and it needs to be returned at this time to - // avoid EOFException - if (dataFileLen != -1 && totalLen > dataFileLen) { - LOGGER.info( - "Abort inconsistent data, the data length: {}(bytes) recorded in index file is greater than " - + "the real data file length: {}(bytes). Block id: {}. This should not happen. 
" - + "This may happen when the data is flushing, please ignore.", - totalLen, - dataFileLen, - blockId); - break; - } - - boolean conditionOfDiscontinuousBlocks = - lastExpectedBlockIndex != -1 - && bufferSegments.size() > 0 - && expectTaskIds.contains(taskAttemptId) - && index - lastExpectedBlockIndex != 1; - - boolean conditionOfLimitedBufferSize = bufferOffset >= readBufferSize; - - if (conditionOfDiscontinuousBlocks || conditionOfLimitedBufferSize) { - ShuffleDataSegment sds = new ShuffleDataSegment(fileOffset, bufferOffset, bufferSegments); - dataFileSegments.add(sds); - bufferSegments = Lists.newArrayList(); - bufferOffset = 0; - fileOffset = -1; - } - - if (expectTaskIds.contains(taskAttemptId)) { - if (fileOffset == -1) { - fileOffset = offset; - } - bufferSegments.add( - new BufferSegment( - blockId, bufferOffset, length, uncompressLength, crc, taskAttemptId)); - bufferOffset += length; - lastExpectedBlockIndex = index; - } - index++; - } catch (BufferUnderflowException ue) { - throw new RssException("Read index data under flow", ue); - } - } - - if (bufferOffset > 0) { - ShuffleDataSegment sds = new ShuffleDataSegment(fileOffset, bufferOffset, bufferSegments); - dataFileSegments.add(sds); - } - - if (LOGGER.isDebugEnabled()) { - LOGGER.debug( - "Index file task-ids sequence: {}, expected task-ids: {}", - indexTaskIds, - getExpectedTaskIds(expectTaskIds)); - } - return dataFileSegments; - } - - private List getExpectedTaskIds(Roaring64NavigableMap expectTaskIds) { - List taskIds = new ArrayList<>(); - expectTaskIds.forEach(value -> taskIds.add(value)); - return taskIds; + return splitCommon(shuffleIndexResult, taskAttemptId -> expectTaskIds.contains(taskAttemptId)); } } diff --git a/common/src/main/java/org/apache/uniffle/common/util/Constants.java b/common/src/main/java/org/apache/uniffle/common/util/Constants.java index 79ceb2f10f..c96eedc1d6 100644 --- a/common/src/main/java/org/apache/uniffle/common/util/Constants.java +++ 
b/common/src/main/java/org/apache/uniffle/common/util/Constants.java @@ -93,4 +93,5 @@ private Constants() {} public static final String DATE_PATTERN = "yyyy-MM-dd HH:mm:ss"; public static final String SPARK_RSS_CONFIG_PREFIX = "spark."; + public static final int[] EMPTY_INT_ARRAY = new int[0]; } diff --git a/common/src/test/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitterTest.java b/common/src/test/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitterTest.java index b681f63ca4..6834f8b154 100644 --- a/common/src/test/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitterTest.java +++ b/common/src/test/java/org/apache/uniffle/common/segment/FixedSizeSegmentSplitterTest.java @@ -20,13 +20,16 @@ import java.nio.ByteBuffer; import java.util.List; +import io.netty.buffer.Unpooled; import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; import org.apache.uniffle.common.ShuffleDataSegment; import org.apache.uniffle.common.ShuffleIndexResult; +import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; import static org.apache.uniffle.common.segment.LocalOrderSegmentSplitterTest.generateData; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -90,4 +93,52 @@ public void testSplit() { assertTrue(e.getMessage().contains("Read index data under flow")); } } + + @Test + @DisplayName("Test splitting with storage ID changes") + void testSplitContainsStorageId() { + SegmentSplitter splitter = new FixedSizeSegmentSplitter(50); + int[] storageIds = new int[] {1, 2, 3}; + byte[] data0 = + generateData(Pair.of(32, 0), Pair.of(16, 0), Pair.of(10, 0), Pair.of(32, 6), Pair.of(6, 0)); + byte[] data1 = + generateData(Pair.of(32, 1), Pair.of(16, 0), Pair.of(10, 0), Pair.of(32, 6), Pair.of(6, 0)); + byte[] data2 = + generateData(Pair.of(32, 1), Pair.of(16, 0), 
Pair.of(10, 0), Pair.of(32, 6), Pair.of(6, 0)); + + ByteBuffer dataCombined = + ByteBuffer.allocate(data0.length + data1.length + data2.length) + .put(data0) + .put(data1) + .put(data2); + dataCombined.flip(); + List shuffleDataSegments = + splitter.split( + new ShuffleIndexResult( + new NettyManagedBuffer(Unpooled.wrappedBuffer(dataCombined)), -1L, "", storageIds)); + assertEquals(6, shuffleDataSegments.size(), "Expected 6 segments"); + assertSegment(shuffleDataSegments.get(0), 0, 58, 3, storageIds[0]); + // split while previous data segments over read buffer size + assertSegment(shuffleDataSegments.get(1), 58, 38, 2, storageIds[0]); + // split while storage id changed, which offset less than previous offset + assertSegment(shuffleDataSegments.get(2), 0, 58, 3, storageIds[1]); + // split while previous data segments over read buffer size + assertSegment(shuffleDataSegments.get(3), 58, 38, 2, storageIds[1]); + // split while storage id changed, which offset less than previous offset + assertSegment(shuffleDataSegments.get(4), 0, 58, 3, storageIds[2]); + // split while previous data segments over read buffer size + assertSegment(shuffleDataSegments.get(5), 58, 38, 2, storageIds[2]); + } + + private void assertSegment( + ShuffleDataSegment segment, + int expectedOffset, + int expectedLength, + int expectedSize, + int expectedStorageId) { + assertEquals(expectedOffset, segment.getOffset(), "Incorrect offset"); + assertEquals(expectedLength, segment.getLength(), "Incorrect length"); + assertEquals(expectedSize, segment.getBufferSegments().size(), "Incorrect buffer segment size"); + assertEquals(expectedStorageId, segment.getStorageId(), "Incorrect storage ID"); + } } diff --git a/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java b/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java index 9e0ff1e5c7..3ee5815f4f 100644 --- 
a/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java +++ b/common/src/test/java/org/apache/uniffle/common/segment/LocalOrderSegmentSplitterTest.java @@ -20,7 +20,9 @@ import java.nio.ByteBuffer; import java.util.List; +import io.netty.buffer.Unpooled; import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -29,6 +31,7 @@ import org.apache.uniffle.common.BufferSegment; import org.apache.uniffle.common.ShuffleDataSegment; import org.apache.uniffle.common.ShuffleIndexResult; +import org.apache.uniffle.common.netty.buffer.NettyManagedBuffer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -375,4 +378,58 @@ public static byte[] generateData(Pair... configEntries) { } return byteBuffer.array(); } + + @Test + @DisplayName("Test splitting with storage ID changes") + void testSplitContainsStorageId() { + Roaring64NavigableMap taskIds = Roaring64NavigableMap.bitmapOf(1); + SegmentSplitter splitter = new LocalOrderSegmentSplitter(taskIds, 50); + // generate 3 batch segments with storage ID 1, 2, 3 to simulate 3 flush event write 3 index + // file + int[] storageIds = new int[] {1, 2, 3}; + byte[] data0 = + generateData(Pair.of(32, 0), Pair.of(16, 0), Pair.of(10, 0), Pair.of(32, 6), Pair.of(6, 0)); + byte[] data1 = + generateData(Pair.of(32, 1), Pair.of(26, 1), Pair.of(10, 1), Pair.of(32, 0), Pair.of(6, 1)); + byte[] data2 = + generateData(Pair.of(32, 1), Pair.of(16, 0), Pair.of(10, 0), Pair.of(32, 6), Pair.of(6, 0)); + ByteBuffer dataCombined = + ByteBuffer.allocate(data0.length + data1.length + data2.length) + .put(data0) + .put(data1) + .put(data2); + dataCombined.flip(); + List shuffleDataSegments = + splitter.split( + new ShuffleIndexResult( + new 
NettyManagedBuffer(Unpooled.wrappedBuffer(dataCombined)), -1L, "", storageIds)); + assertEquals(4, shuffleDataSegments.size(), "Expected 4 segments"); + assertEquals( + 5, + shuffleDataSegments.stream().mapToInt(s -> s.getBufferSegments().size()).sum(), + "Incorrect total size of buffer segments"); + // First data segment come from the 0,1 part of data1 since first data array contains none with + // taskId 1. + assertSegment(shuffleDataSegments.get(0), 0, 58, 2, storageIds[1]); + // This data segment come from the No.3 part of data1 since No.4 part is not belong to taskId 1, + // close this data segment cause discontinuous block. + assertSegment(shuffleDataSegments.get(1), 58, 10, 1, storageIds[1]); + // This data segment come from the No.5 part of data1 since storage id changed. + assertSegment(shuffleDataSegments.get(2), 100, 6, 1, storageIds[1]); + // This data segment come from the No.1 part of data2 since other parts are not belong to taskId + // 1. + assertSegment(shuffleDataSegments.get(3), 0, 32, 1, storageIds[2]); + } + + private void assertSegment( + ShuffleDataSegment segment, + int expectedOffset, + int expectedLength, + int expectedSize, + int expectedStorageId) { + assertEquals(expectedOffset, segment.getOffset(), "Incorrect offset"); + assertEquals(expectedLength, segment.getLength(), "Incorrect length"); + assertEquals(expectedSize, segment.getBufferSegments().size(), "Incorrect buffer segment size"); + assertEquals(expectedStorageId, segment.getStorageId(), "Incorrect storage ID"); + } } diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java index 85499f3b0a..2ab9f7b8fd 100644 --- a/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java +++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed.java @@ -37,6 +37,7 @@ import io.netty.buffer.ByteBuf; import org.apache.hadoop.io.IntWritable; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.io.TempDir; @@ -68,9 +69,11 @@ import org.apache.uniffle.server.ShuffleServer; import org.apache.uniffle.server.ShuffleServerConf; import org.apache.uniffle.server.buffer.ShuffleBufferType; +import org.apache.uniffle.server.storage.MultiPartLocalStorageManager; import org.apache.uniffle.storage.util.StorageType; import static org.apache.uniffle.coordinator.CoordinatorConf.COORDINATOR_DYNAMIC_CLIENT_CONF_ENABLED; +import static org.apache.uniffle.server.ShuffleServerConf.SERVER_LOCAL_STORAGE_MANAGER_CLASS; import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MEMORY_SHUFFLE_HIGHWATERMARK_PERCENTAGE; import static org.apache.uniffle.server.ShuffleServerConf.SERVER_MEMORY_SHUFFLE_LOWWATERMARK_PERCENTAGE; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -88,8 +91,13 @@ public class RemoteMergeShuffleWithRssClientTestWhenShuffleFlushed extends Shuff public static void setupServers(@TempDir File tmpDir) throws Exception { CoordinatorConf coordinatorConf = getCoordinatorConf(); coordinatorConf.setBoolean(COORDINATOR_DYNAMIC_CLIENT_CONF_ENABLED, false); - createCoordinatorServer(coordinatorConf); ShuffleServerConf shuffleServerConf = getShuffleServerConf(ServerType.GRPC_NETTY); + Assumptions.assumeTrue( + !shuffleServerConf + .get(SERVER_LOCAL_STORAGE_MANAGER_CLASS) + .equals(MultiPartLocalStorageManager.class.getName()), + MultiPartLocalStorageManager.class.getName() + " is not working with remote merge feature"); + createCoordinatorServer(coordinatorConf); shuffleServerConf.set(ShuffleServerConf.SERVER_MERGE_ENABLE, true); 
shuffleServerConf.set(ShuffleServerConf.SERVER_MERGE_DEFAULT_MERGED_BLOCK_SIZE, "1k"); shuffleServerConf.set( diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkClientWithLocalForMultiPartLocalStorageManagerTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkClientWithLocalForMultiPartLocalStorageManagerTest.java new file mode 100644 index 0000000000..bebee47eaf --- /dev/null +++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkClientWithLocalForMultiPartLocalStorageManagerTest.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.test; + +import java.io.File; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.roaringbitmap.longlong.Roaring64NavigableMap; + +import org.apache.uniffle.client.factory.ShuffleClientFactory; +import org.apache.uniffle.client.impl.ShuffleReadClientImpl; +import org.apache.uniffle.client.impl.grpc.ShuffleServerGrpcClient; +import org.apache.uniffle.client.impl.grpc.ShuffleServerGrpcNettyClient; +import org.apache.uniffle.client.request.RssFinishShuffleRequest; +import org.apache.uniffle.client.request.RssRegisterShuffleRequest; +import org.apache.uniffle.client.request.RssSendCommitRequest; +import org.apache.uniffle.client.request.RssSendShuffleDataRequest; +import org.apache.uniffle.common.ClientType; +import org.apache.uniffle.common.PartitionRange; +import org.apache.uniffle.common.ShuffleBlockInfo; +import org.apache.uniffle.common.ShuffleServerInfo; +import org.apache.uniffle.common.rpc.ServerType; +import org.apache.uniffle.coordinator.CoordinatorConf; +import org.apache.uniffle.server.ShuffleDataReadEvent; +import org.apache.uniffle.server.ShuffleServer; +import org.apache.uniffle.server.ShuffleServerConf; +import org.apache.uniffle.server.storage.MultiPartLocalStorageManager; +import org.apache.uniffle.storage.common.LocalStorage; +import org.apache.uniffle.storage.util.StorageType; + +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +public class SparkClientWithLocalForMultiPartLocalStorageManagerTest extends ShuffleReadWriteBase { + + private static File 
GRPC_DATA_DIR1; + private static File GRPC_DATA_DIR2; + private static File NETTY_DATA_DIR1; + private static File NETTY_DATA_DIR2; + private ShuffleServerGrpcClient grpcShuffleServerClient; + private ShuffleServerGrpcNettyClient nettyShuffleServerClient; + private static ShuffleServerConf grpcShuffleServerConfig; + private static ShuffleServerConf nettyShuffleServerConfig; + + @BeforeAll + public static void setupServers(@TempDir File tmpDir) throws Exception { + CoordinatorConf coordinatorConf = getCoordinatorConf(); + createCoordinatorServer(coordinatorConf); + + GRPC_DATA_DIR1 = new File(tmpDir, "data1"); + GRPC_DATA_DIR2 = new File(tmpDir, "data2"); + String grpcBasePath = GRPC_DATA_DIR1.getAbsolutePath() + "," + GRPC_DATA_DIR2.getAbsolutePath(); + ShuffleServerConf grpcShuffleServerConf = buildShuffleServerConf(grpcBasePath, ServerType.GRPC); + grpcShuffleServerConf.set( + ShuffleServerConf.SERVER_LOCAL_STORAGE_MANAGER_CLASS, + MultiPartLocalStorageManager.class.getName()); + createShuffleServer(grpcShuffleServerConf); + + NETTY_DATA_DIR1 = new File(tmpDir, "netty_data1"); + NETTY_DATA_DIR2 = new File(tmpDir, "netty_data2"); + String nettyBasePath = + NETTY_DATA_DIR1.getAbsolutePath() + "," + NETTY_DATA_DIR2.getAbsolutePath(); + ShuffleServerConf nettyShuffleServerConf = + buildShuffleServerConf(nettyBasePath, ServerType.GRPC_NETTY); + nettyShuffleServerConf.set( + ShuffleServerConf.SERVER_LOCAL_STORAGE_MANAGER_CLASS, + MultiPartLocalStorageManager.class.getName()); + createShuffleServer(nettyShuffleServerConf); + + startServers(); + + grpcShuffleServerConfig = grpcShuffleServerConf; + nettyShuffleServerConfig = nettyShuffleServerConf; + } + + private static ShuffleServerConf buildShuffleServerConf(String basePath, ServerType serverType) + throws Exception { + ShuffleServerConf shuffleServerConf = getShuffleServerConf(serverType); + shuffleServerConf.setString("rss.storage.type", StorageType.LOCALFILE.name()); + 
shuffleServerConf.setString("rss.storage.basePath", basePath); + return shuffleServerConf; + } + + @BeforeEach + public void createClient() throws Exception { + grpcShuffleServerClient = + new ShuffleServerGrpcClient( + LOCALHOST, grpcShuffleServerConfig.getInteger(ShuffleServerConf.RPC_SERVER_PORT)); + nettyShuffleServerClient = + new ShuffleServerGrpcNettyClient( + LOCALHOST, + nettyShuffleServerConfig.getInteger(ShuffleServerConf.RPC_SERVER_PORT), + nettyShuffleServerConfig.getInteger(ShuffleServerConf.NETTY_SERVER_PORT)); + } + + @AfterEach + public void closeClient() { + grpcShuffleServerClient.close(); + nettyShuffleServerClient.close(); + } + + private ShuffleClientFactory.ReadClientBuilder baseReadBuilder(boolean isNettyMode) { + List shuffleServerInfo = + isNettyMode + ? Lists.newArrayList( + new ShuffleServerInfo( + LOCALHOST, + nettyShuffleServerConfig.getInteger(ShuffleServerConf.RPC_SERVER_PORT), + nettyShuffleServerConfig.getInteger(ShuffleServerConf.NETTY_SERVER_PORT))) + : Lists.newArrayList( + new ShuffleServerInfo( + LOCALHOST, + grpcShuffleServerConfig.getInteger(ShuffleServerConf.RPC_SERVER_PORT))); + return ShuffleClientFactory.newReadBuilder() + .clientType(isNettyMode ? 
ClientType.GRPC_NETTY : ClientType.GRPC) + .storageType(StorageType.LOCALFILE.name()) + .shuffleId(0) + .partitionId(0) + .indexReadLimit(100) + .partitionNumPerRange(1) + .partitionNum(10) + .readBufferSize(1000) + .shuffleServerInfoList(shuffleServerInfo); + } + + private static Stream isNettyModeProvider() { + return Stream.of(Arguments.of(true), Arguments.of(false)); + } + + @ParameterizedTest + @MethodSource("isNettyModeProvider") + public void testClientRemoteReadFromMultipleDisk(boolean isNettyMode) { + String testAppId = "testClientRemoteReadFromMultipleDisk_appId"; + registerApp(testAppId, Lists.newArrayList(new PartitionRange(0, 0)), isNettyMode); + + // Send shuffle data + Map expectedData = Maps.newHashMap(); + Roaring64NavigableMap blockIdBitmap = Roaring64NavigableMap.bitmapOf(); + + List blocks = + createShuffleBlockList(0, 0, 0, 5, 30, blockIdBitmap, expectedData, mockSSI); + sendTestData(testAppId, blocks, isNettyMode); + + List shuffleServers = isNettyMode ? nettyShuffleServers : grpcShuffleServers; + // Mark one storage reaching high watermark, it should switch another storage for next writing + ShuffleServer shuffleServer = shuffleServers.get(0); + ShuffleDataReadEvent readEvent = new ShuffleDataReadEvent(testAppId, 0, 0, 0, 0); + LocalStorage storage1 = + (LocalStorage) shuffleServer.getStorageManager().selectStorage(readEvent); + storage1.getMetaData().setSize(20 * 1024 * 1024); + + blocks = createShuffleBlockList(0, 0, 0, 3, 25, blockIdBitmap, expectedData, mockSSI); + sendTestData(testAppId, blocks, isNettyMode); + + readEvent = new ShuffleDataReadEvent(testAppId, 0, 0, 0, 1); + LocalStorage storage2 = + (LocalStorage) shuffleServer.getStorageManager().selectStorage(readEvent); + assertNotEquals(storage1, storage2); + + Roaring64NavigableMap taskIdBitmap = Roaring64NavigableMap.bitmapOf(0); + // unexpected taskAttemptId should be filtered + ShuffleReadClientImpl readClient = + baseReadBuilder(isNettyMode) + .appId(testAppId) + 
.blockIdBitmap(blockIdBitmap) + .taskIdBitmap(taskIdBitmap) + .build(); + + validateResult(readClient, expectedData); + readClient.checkProcessedBlockIds(); + readClient.close(); + } + + protected void registerApp( + String testAppId, List partitionRanges, boolean isNettyMode) { + ShuffleServerGrpcClient shuffleServerClient = + isNettyMode ? nettyShuffleServerClient : grpcShuffleServerClient; + RssRegisterShuffleRequest rrsr = + new RssRegisterShuffleRequest(testAppId, 0, partitionRanges, ""); + shuffleServerClient.registerShuffle(rrsr); + } + + protected void sendTestData( + String testAppId, List blocks, boolean isNettyMode) { + ShuffleServerGrpcClient shuffleServerClient = + isNettyMode ? nettyShuffleServerClient : grpcShuffleServerClient; + Map> partitionToBlocks = Maps.newHashMap(); + partitionToBlocks.put(0, blocks); + + Map>> shuffleToBlocks = Maps.newHashMap(); + shuffleToBlocks.put(0, partitionToBlocks); + + RssSendShuffleDataRequest rssdr = + new RssSendShuffleDataRequest(testAppId, 3, 1000, shuffleToBlocks); + shuffleServerClient.sendShuffleData(rssdr); + RssSendCommitRequest rscr = new RssSendCommitRequest(testAppId, 0); + shuffleServerClient.sendCommit(rscr); + RssFinishShuffleRequest rfsr = new RssFinishShuffleRequest(testAppId, 0); + shuffleServerClient.finishShuffle(rfsr); + } +} diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java index c7471991a1..56f02721da 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcClient.java @@ -923,6 +923,7 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request .setOffset(request.getOffset()) .setLength(request.getLength()) .setTimestamp(start) + .setStorageId(request.getStorageId()) .build(); 
String requestInfo = "appId[" @@ -931,6 +932,8 @@ public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request + request.getShuffleId() + "], partitionId[" + request.getPartitionId() + + "], storageId[" + + request.getStorageId() + "]"; int retry = 0; GetLocalShuffleDataResponse rpcResponse; @@ -1016,7 +1019,8 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ StatusCode.SUCCESS, new NettyManagedBuffer( Unpooled.wrappedBuffer(rpcResponse.getIndexData().toByteArray())), - rpcResponse.getDataFileLen()); + rpcResponse.getDataFileLen(), + rpcResponse.getStorageIdsList().stream().mapToInt(Integer::intValue).toArray()); break; default: diff --git a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java index a89ea9c4c6..fbc4e363bc 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/impl/grpc/ShuffleServerGrpcNettyClient.java @@ -52,6 +52,7 @@ import org.apache.uniffle.common.netty.client.TransportContext; import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataRequest; import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataResponse; +import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataV2Request; import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexRequest; import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse; import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest; @@ -350,7 +351,8 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ return new RssGetShuffleIndexResponse( StatusCode.SUCCESS, getLocalShuffleIndexResponse.body(), - getLocalShuffleIndexResponse.getFileLength()); + getLocalShuffleIndexResponse.getFileLength(), + 
getLocalShuffleIndexResponse.getStorageIds()); default: String msg = "Can't get shuffle index from " @@ -369,17 +371,30 @@ public RssGetShuffleIndexResponse getShuffleIndex(RssGetShuffleIndexRequest requ @Override public RssGetShuffleDataResponse getShuffleData(RssGetShuffleDataRequest request) { TransportClient transportClient = getTransportClient(); + // Construct old version or v2 get shuffle data request to compatible with old server GetLocalShuffleDataRequest getLocalShuffleIndexRequest = - new GetLocalShuffleDataRequest( - requestId(), - request.getAppId(), - request.getShuffleId(), - request.getPartitionId(), - request.getPartitionNumPerRange(), - request.getPartitionNum(), - request.getOffset(), - request.getLength(), - System.currentTimeMillis()); + request.storageIdSpecified() + ? new GetLocalShuffleDataV2Request( + requestId(), + request.getAppId(), + request.getShuffleId(), + request.getPartitionId(), + request.getPartitionNumPerRange(), + request.getPartitionNum(), + request.getOffset(), + request.getLength(), + request.getStorageId(), + System.currentTimeMillis()) + : new GetLocalShuffleDataRequest( + requestId(), + request.getAppId(), + request.getShuffleId(), + request.getPartitionId(), + request.getPartitionNumPerRange(), + request.getPartitionNum(), + request.getOffset(), + request.getLength(), + System.currentTimeMillis()); String requestInfo = "appId[" + request.getAppId() diff --git a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java index 5801922171..c245e48b7c 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/request/RssGetShuffleDataRequest.java @@ -28,6 +28,7 @@ public class RssGetShuffleDataRequest extends RetryableRequest { private final int partitionNum; private final long offset; 
private final int length; + private final int storageId; public RssGetShuffleDataRequest( String appId, @@ -37,6 +38,7 @@ public RssGetShuffleDataRequest( int partitionNum, long offset, int length, + int storageId, int retryMax, long retryIntervalMax) { this.appId = appId; @@ -46,6 +48,7 @@ public RssGetShuffleDataRequest( this.partitionNum = partitionNum; this.offset = offset; this.length = length; + this.storageId = storageId; this.retryMax = retryMax; this.retryIntervalMax = retryIntervalMax; } @@ -59,7 +62,17 @@ public RssGetShuffleDataRequest( int partitionNum, long offset, int length) { - this(appId, shuffleId, partitionId, partitionNumPerRange, partitionNum, offset, length, 1, 0); + this( + appId, + shuffleId, + partitionId, + partitionNumPerRange, + partitionNum, + offset, + length, + -1, + 1, + 0); } public String getAppId() { @@ -90,6 +103,14 @@ public int getLength() { return length; } + public int getStorageId() { + return storageId; + } + + public boolean storageIdSpecified() { + return storageId != -1; + } + @Override public String operationType() { return "GetShuffleData"; diff --git a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java index 4d3667ab17..7c42ac8137 100644 --- a/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java +++ b/internal-client/src/main/java/org/apache/uniffle/client/response/RssGetShuffleIndexResponse.java @@ -20,13 +20,19 @@ import org.apache.uniffle.common.ShuffleIndexResult; import org.apache.uniffle.common.netty.buffer.ManagedBuffer; import org.apache.uniffle.common.rpc.StatusCode; +import org.apache.uniffle.common.util.Constants; public class RssGetShuffleIndexResponse extends ClientResponse { private final ShuffleIndexResult shuffleIndexResult; - public RssGetShuffleIndexResponse(StatusCode statusCode, ManagedBuffer data, long 
dataFileLen) { + public RssGetShuffleIndexResponse( + StatusCode statusCode, ManagedBuffer data, long dataFileLen, int[] storageIds) { super(statusCode); - this.shuffleIndexResult = new ShuffleIndexResult(data, dataFileLen, null); + this.shuffleIndexResult = new ShuffleIndexResult(data, dataFileLen, null, storageIds); + } + + public RssGetShuffleIndexResponse(StatusCode statusCode, ManagedBuffer data, long dataFileLen) { + this(statusCode, data, dataFileLen, Constants.EMPTY_INT_ARRAY); } public ShuffleIndexResult getShuffleIndexResult() { diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto index e883eddc43..7209d85bb4 100644 --- a/proto/src/main/proto/Rss.proto +++ b/proto/src/main/proto/Rss.proto @@ -86,6 +86,7 @@ message GetLocalShuffleDataRequest { int64 offset = 6; int32 length = 7; int64 timestamp = 8; + int32 storageId = 9; } message GetLocalShuffleDataResponse { @@ -124,6 +125,7 @@ message GetLocalShuffleIndexResponse { StatusCode status = 2; string retMsg = 3; int64 dataFileLen = 4; + repeated int32 storageIds = 5; } message ReportShuffleResultRequest { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java b/server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java index 11a8aeff71..cc121c47d5 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java @@ -23,6 +23,7 @@ public class ShuffleDataReadEvent { private int shuffleId; private int partitionId; private int startPartition; + private int storageId; public ShuffleDataReadEvent( String appId, int shuffleId, int partitionId, int startPartitionOfRange) { @@ -32,6 +33,12 @@ public ShuffleDataReadEvent( this.startPartition = startPartitionOfRange; } + public ShuffleDataReadEvent( + String appId, int shuffleId, int partitionId, int startPartitionOfRange, int storageId) { + this(appId, shuffleId, partitionId, startPartitionOfRange); + 
this.storageId = storageId; + } + public String getAppId() { return appId; } @@ -47,4 +54,8 @@ public int getPartitionId() { public int getStartPartition() { return startPartition; } + + public int getStorageId() { + return storageId; + } } diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java index fd2f5802c2..2ec535e74c 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerConf.java @@ -26,6 +26,7 @@ import org.apache.uniffle.common.config.ConfigUtils; import org.apache.uniffle.common.config.RssBaseConf; import org.apache.uniffle.server.buffer.ShuffleBufferType; +import org.apache.uniffle.server.storage.LocalStorageManager; public class ShuffleServerConf extends RssBaseConf { @@ -744,6 +745,12 @@ public class ShuffleServerConf extends RssBaseConf { .defaultValue(false) .withDescription("Whether to enable app detail log"); + public static final ConfigOption SERVER_LOCAL_STORAGE_MANAGER_CLASS = + ConfigOptions.key("rss.server.localStorageManagerClass") + .stringType() + .defaultValue(LocalStorageManager.class.getName()) + .withDescription("The class of local storage manager implementation"); + public ShuffleServerConf() {} public ShuffleServerConf(String fileName) { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java index 0c38960d00..79206492aa 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -1083,6 +1084,7 @@ public void getLocalShuffleData( int partitionNum 
= request.getPartitionNum(); long offset = request.getOffset(); int length = request.getLength(); + int storageId = request.getStorageId(); auditContext.withAppId(appId).withShuffleId(shuffleId); auditContext.withArgs( @@ -1095,7 +1097,9 @@ public void getLocalShuffleData( + ", offset=" + offset + ", length=" - + length); + + length + + ", storageId=" + + storageId); StatusCode status = verifyRequest(appId); if (status != StatusCode.SUCCESS) { @@ -1144,7 +1148,8 @@ public void getLocalShuffleData( Storage storage = shuffleServer .getStorageManager() - .selectStorage(new ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0])); + .selectStorage( + new ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0], storageId)); if (storage != null) { storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId)); } @@ -1163,7 +1168,8 @@ public void getLocalShuffleData( partitionNum, storageType, offset, - length); + length, + storageId); long readTime = System.currentTimeMillis() - start; ShuffleServerMetrics.counterTotalReadTime.inc(readTime); ShuffleServerMetrics.counterTotalReadDataSize.inc(sdr.getDataLength()); @@ -1300,6 +1306,10 @@ public void getLocalShuffleIndex( builder.setIndexData(UnsafeByteOperations.unsafeWrap(data)); builder.setDataFileLen(shuffleIndexResult.getDataFileLen()); + builder.addAllStorageIds( + Arrays.stream(shuffleIndexResult.getStorageIds()) + .boxed() + .collect(Collectors.toList())); auditContext.withReturnValue("len=" + shuffleIndexResult.getDataFileLen()); reply = builder.build(); } catch (FileNotFoundException indexFileNotFoundException) { diff --git a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java index 6211e66735..d9a06da1e2 100644 --- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java +++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java @@ -789,7 +789,8 @@ public ShuffleDataResult 
getShuffleData( int partitionNum, String storageType, long offset, - int length) { + int length, + int storageId) { refreshAppId(appId); CreateShuffleReadHandlerRequest request = new CreateShuffleReadHandlerRequest(); @@ -804,12 +805,24 @@ public ShuffleDataResult getShuffleData( ShuffleStorageUtils.getPartitionRange(partitionId, partitionNumPerRange, partitionNum); Storage storage = storageManager.selectStorage( - new ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0])); + new ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0], storageId)); if (storage == null) { throw new FileNotFoundException("No such data stored in current storage manager."); } - return storage.getOrCreateReadHandler(request).getShuffleData(offset, length); + // only one partition part in one storage + try { + return storage.getOrCreateReadHandler(request).getShuffleData(offset, length); + } catch (FileNotFoundException e) { + LOG.warn( + "shuffle file not found {}-{}-{} in {}", + appId, + shuffleId, + partitionId, + storage.getStoragePath(), + e); + throw e; + } } public ShuffleIndexResult getShuffleIndex( @@ -831,12 +844,16 @@ public ShuffleIndexResult getShuffleIndex( int[] range = ShuffleStorageUtils.getPartitionRange(partitionId, partitionNumPerRange, partitionNum); Storage storage = - storageManager.selectStorage( + storageManager.selectStorageById( new ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0])); if (storage == null) { throw new FileNotFoundException("No such data in current storage manager."); } - return storage.getOrCreateReadHandler(request).getShuffleIndex(); + ShuffleIndexResult result = storage.getOrCreateReadHandler(request).getShuffleIndex(); + if (result == null) { + throw new FileNotFoundException("No such data in current storage manager."); + } + return result; } public void checkResourceStatus() { diff --git a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java 
b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java index b5137d818c..ab6ce2834e 100644 --- a/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java +++ b/server/src/main/java/org/apache/uniffle/server/netty/ShuffleServerNettyHandler.java @@ -53,6 +53,7 @@ import org.apache.uniffle.common.netty.protocol.GetLocalShuffleDataResponse; import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexRequest; import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexResponse; +import org.apache.uniffle.common.netty.protocol.GetLocalShuffleIndexV2Response; import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataRequest; import org.apache.uniffle.common.netty.protocol.GetMemoryShuffleDataResponse; import org.apache.uniffle.common.netty.protocol.GetSortedShuffleDataRequest; @@ -552,7 +553,7 @@ public void handleGetLocalShuffleIndexRequest( } String msg = "OK"; - GetLocalShuffleIndexResponse response; + GetLocalShuffleIndexV2Response response; int[] range = ShuffleStorageUtils.getPartitionRange(partitionId, partitionNumPerRange, partitionNum); Storage storage = @@ -587,8 +588,13 @@ public void handleGetLocalShuffleIndexRequest( auditContext.withStatusCode(status); auditContext.withReturnValue("len=" + data.size()); response = - new GetLocalShuffleIndexResponse( - req.getRequestId(), status, msg, data, shuffleIndexResult.getDataFileLen()); + new GetLocalShuffleIndexV2Response( + req.getRequestId(), + status, + msg, + data, + shuffleIndexResult.getDataFileLen(), + shuffleIndexResult.getStorageIds()); ReleaseMemoryAndRecordReadTimeListener listener = new ReleaseMemoryAndRecordReadTimeListener( start, assumedFileSize, data.size(), requestInfo, req, response, client); @@ -605,7 +611,7 @@ public void handleGetLocalShuffleIndexRequest( requestInfo, indexFileNotFoundException); response = - new GetLocalShuffleIndexResponse( + new GetLocalShuffleIndexV2Response( req.getRequestId(), status, msg, 
Unpooled.EMPTY_BUFFER, 0L); } catch (Exception e) { shuffleServer.getShuffleBufferManager().releaseReadMemory(assumedFileSize); @@ -616,7 +622,7 @@ public void handleGetLocalShuffleIndexRequest( msg = "Error happened when get shuffle index for " + requestInfo + ", " + e.getMessage(); LOG.error(msg, e); response = - new GetLocalShuffleIndexResponse( + new GetLocalShuffleIndexV2Response( req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER, 0L); } } else { @@ -624,7 +630,7 @@ public void handleGetLocalShuffleIndexRequest( msg = "Can't require memory to get shuffle index"; LOG.warn("{} for {}", msg, requestInfo); response = - new GetLocalShuffleIndexResponse( + new GetLocalShuffleIndexV2Response( req.getRequestId(), status, msg, Unpooled.EMPTY_BUFFER, 0L); } auditContext.withStatusCode(response.getStatusCode()); @@ -642,6 +648,7 @@ public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDat int partitionNum = req.getPartitionNum(); long offset = req.getOffset(); int length = req.getLength(); + int storageId = req.getStorageId(); auditContext.withAppId(appId); auditContext.withShuffleId(shuffleId); auditContext.withArgs( @@ -656,7 +663,9 @@ public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDat + ", offset=" + offset + ", length=" - + length); + + length + + ", storageId=" + + storageId); StatusCode status = verifyRequest(appId); if (status != StatusCode.SUCCESS) { auditContext.withStatusCode(status); @@ -699,7 +708,9 @@ public void handleGetLocalShuffleData(TransportClient client, GetLocalShuffleDat Storage storage = shuffleServer .getStorageManager() - .selectStorage(new ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0])); + .selectStorage( + new ShuffleDataReadEvent( + appId, shuffleId, partitionId, range[0], req.getStorageId())); if (storage != null) { storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId)); } @@ -719,7 +730,8 @@ public void handleGetLocalShuffleData(TransportClient client, 
GetLocalShuffleDat partitionNum, storageType, offset, - length); + length, + storageId); ShuffleServerMetrics.counterTotalReadDataSize.inc(sdr.getDataLength()); ShuffleServerMetrics.counterTotalReadLocalDataFileSize.inc(sdr.getDataLength()); ShuffleServerMetrics.gaugeReadLocalDataFileThreadNum.inc(); diff --git a/server/src/main/java/org/apache/uniffle/server/storage/HybridStorageManager.java b/server/src/main/java/org/apache/uniffle/server/storage/HybridStorageManager.java index 05b83e558f..29693ad18c 100644 --- a/server/src/main/java/org/apache/uniffle/server/storage/HybridStorageManager.java +++ b/server/src/main/java/org/apache/uniffle/server/storage/HybridStorageManager.java @@ -46,7 +46,7 @@ public class HybridStorageManager implements StorageManager { private final StorageManagerSelector storageManagerSelector; HybridStorageManager(ShuffleServerConf conf) { - warmStorageManager = new LocalStorageManager(conf); + warmStorageManager = LocalStorageManagerFactory.get(conf); coldStorageManager = new HadoopStorageManager(conf); try { @@ -115,6 +115,11 @@ public Storage selectStorage(ShuffleDataReadEvent event) { return warmStorageManager.selectStorage(event); } + @Override + public Storage selectStorageById(ShuffleDataReadEvent event) { + return warmStorageManager.selectStorageById(event); + } + @Override public void updateWriteMetrics(ShuffleDataFlushEvent event, long writeTime) { throw new UnsupportedOperationException(); diff --git a/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java b/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java index c33c17f3bf..ed1d30f9e5 100644 --- a/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java +++ b/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java @@ -94,7 +94,7 @@ public class LocalStorageManager extends SingleStorageManager { private boolean isStorageAuditLogEnabled; @VisibleForTesting - 
LocalStorageManager(ShuffleServerConf conf) { + public LocalStorageManager(ShuffleServerConf conf) { super(conf); storageBasePaths = RssUtils.getConfiguredLocalDirs(conf); if (CollectionUtils.isEmpty(storageBasePaths)) { @@ -136,6 +136,7 @@ public class LocalStorageManager extends SingleStorageManager { .ratio(ratio) .lowWaterMarkOfWrite(lowWaterMarkOfWrite) .highWaterMarkOfWrite(highWaterMarkOfWrite) + .setId(idx) .localStorageMedia(storageType); if (isDiskCapacityWatermarkCheckEnabled) { builder.enableDiskCapacityWatermarkCheck(); diff --git a/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManagerFactory.java b/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManagerFactory.java new file mode 100644 index 0000000000..acdcd4377d --- /dev/null +++ b/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManagerFactory.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.uniffle.server.storage; + +import org.apache.commons.lang3.StringUtils; + +import org.apache.uniffle.common.util.RssUtils; +import org.apache.uniffle.server.ShuffleServerConf; + +public class LocalStorageManagerFactory { + public static LocalStorageManager get(ShuffleServerConf conf) { + String className = conf.get(ShuffleServerConf.SERVER_LOCAL_STORAGE_MANAGER_CLASS); + if (StringUtils.isEmpty(className)) { + throw new IllegalStateException( + "Configuration error: " + + ShuffleServerConf.SERVER_LOCAL_STORAGE_MANAGER_CLASS.toString() + + " should not set to empty"); + } + + try { + return (LocalStorageManager) + RssUtils.getConstructor(className, ShuffleServerConf.class).newInstance(conf); + } catch (Exception e) { + throw new IllegalStateException( + "Configuration error: " + + ShuffleServerConf.SERVER_LOCAL_STORAGE_MANAGER_CLASS.toString() + + " is failed to create instance of " + + className, + e); + } + } +} diff --git a/server/src/main/java/org/apache/uniffle/server/storage/MultiPartLocalStorageManager.java b/server/src/main/java/org/apache/uniffle/server/storage/MultiPartLocalStorageManager.java new file mode 100644 index 0000000000..5bca9a9e1f --- /dev/null +++ b/server/src/main/java/org/apache/uniffle/server/storage/MultiPartLocalStorageManager.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.server.storage; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.stream.Collectors; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.server.ShuffleDataFlushEvent; +import org.apache.uniffle.server.ShuffleDataReadEvent; +import org.apache.uniffle.server.ShuffleServerConf; +import org.apache.uniffle.storage.common.CompositeReadingViewStorage; +import org.apache.uniffle.storage.common.LocalStorage; +import org.apache.uniffle.storage.common.Storage; +import org.apache.uniffle.storage.util.ShuffleStorageUtils; + +public class MultiPartLocalStorageManager extends LocalStorageManager { + private static final Logger LOG = LoggerFactory.getLogger(MultiPartLocalStorageManager.class); + // id -> storage + private final Map idToStorages; + + private final CompositeReadingViewStorage compositeStorage; + + public MultiPartLocalStorageManager(ShuffleServerConf conf) { + super(conf); + idToStorages = new ConcurrentSkipListMap<>(); + for (LocalStorage storage : getStorages()) { + idToStorages.put(storage.getId(), storage); + } + + compositeStorage = new CompositeReadingViewStorage(getStorages()); + } + + @Override + public Storage selectStorage(ShuffleDataFlushEvent event) { + if (getStorages().size() == 1) { + if (event.getUnderStorage() == null) { + event.setUnderStorage(getStorages().get(0)); + } + return getStorages().get(0); + } + String appId = event.getAppId(); + int shuffleId = event.getShuffleId(); + int partitionId 
= event.getStartPartition(); + + // TODO(baoloongmao): extend to support select storage by free space + // eventId is a non-negative long. + LocalStorage storage = getStorages().get((int) (event.getEventId() % getStorages().size())); + if (storage != null) { + if (storage.isCorrupted()) { + if (storage.containsWriteHandler(appId, shuffleId, partitionId)) { + LOG.error( + "LocalStorage: {} is corrupted. Switching another storage for event: {}, some data will be lost", + storage.getBasePath(), + event); + } + } else { + if (event.getUnderStorage() == null) { + event.setUnderStorage(storage); + } + return storage; + } + } + + // TODO(baoloongmao): update health storages and store it as member of this class. + List candidates = + getStorages().stream() + .filter(x -> x.canWrite() && !x.isCorrupted()) + .collect(Collectors.toList()); + + if (candidates.size() == 0) { + return null; + } + final LocalStorage selectedStorage = + candidates.get( + ShuffleStorageUtils.getStorageIndex(candidates.size(), appId, shuffleId, partitionId)); + if (storage == null || storage.isCorrupted() || event.getUnderStorage() == null) { + event.setUnderStorage(selectedStorage); + return selectedStorage; + } + return storage; + } + + @Override + public Storage selectStorage(ShuffleDataReadEvent event) { + if (getStorages().size() == 1) { + return getStorages().get(0); + } + + int storageId = event.getStorageId(); + // TODO(baoloongmao): check AOOB exception + return idToStorages.get(storageId); + } + + @Override + public Storage selectStorageById(ShuffleDataReadEvent event) { + return compositeStorage; + } +} diff --git a/server/src/main/java/org/apache/uniffle/server/storage/StorageManager.java b/server/src/main/java/org/apache/uniffle/server/storage/StorageManager.java index 70425a22d8..4d8ca68226 100644 --- a/server/src/main/java/org/apache/uniffle/server/storage/StorageManager.java +++ b/server/src/main/java/org/apache/uniffle/server/storage/StorageManager.java @@ -36,6 +36,10 @@ public 
interface StorageManager { Storage selectStorage(ShuffleDataReadEvent event); + default Storage selectStorageById(ShuffleDataReadEvent event) { + return selectStorage(event); + } + boolean write(Storage storage, ShuffleWriteHandler handler, ShuffleDataFlushEvent event); void updateWriteMetrics(ShuffleDataFlushEvent event, long writeTime); diff --git a/server/src/main/java/org/apache/uniffle/server/storage/StorageManagerFactory.java b/server/src/main/java/org/apache/uniffle/server/storage/StorageManagerFactory.java index f21d6f6fe2..e99434b0db 100644 --- a/server/src/main/java/org/apache/uniffle/server/storage/StorageManagerFactory.java +++ b/server/src/main/java/org/apache/uniffle/server/storage/StorageManagerFactory.java @@ -33,7 +33,7 @@ public static StorageManagerFactory getInstance() { public StorageManager createStorageManager(ShuffleServerConf conf) { StorageType type = StorageType.valueOf(conf.get(ShuffleServerConf.RSS_STORAGE_TYPE).name()); if (StorageType.LOCALFILE.equals(type) || StorageType.MEMORY_LOCALFILE.equals(type)) { - return new LocalStorageManager(conf); + return LocalStorageManagerFactory.get(conf); } else if (StorageType.HDFS.equals(type) || StorageType.MEMORY_HDFS.equals(type)) { return new HadoopStorageManager(conf); } else if (StorageType.LOCALFILE_HDFS.equals(type) diff --git a/storage/src/main/java/org/apache/uniffle/storage/common/CompositeReadingViewStorage.java b/storage/src/main/java/org/apache/uniffle/storage/common/CompositeReadingViewStorage.java new file mode 100644 index 0000000000..3d32e7922c --- /dev/null +++ b/storage/src/main/java/org/apache/uniffle/storage/common/CompositeReadingViewStorage.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.storage.common; + +import java.util.ArrayList; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.common.exception.FileNotFoundException; +import org.apache.uniffle.storage.handler.api.ServerReadHandler; +import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler; +import org.apache.uniffle.storage.handler.impl.CompositeLocalFileServerReadHandler; +import org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest; +import org.apache.uniffle.storage.request.CreateShuffleWriteHandlerRequest; + +public class CompositeReadingViewStorage extends AbstractStorage { + + private static final Logger LOG = LoggerFactory.getLogger(CompositeReadingViewStorage.class); + private final List localStorages; + + public CompositeReadingViewStorage(List localStorages) { + super(); + this.localStorages = localStorages; + } + + @Override + ShuffleWriteHandler newWriteHandler(CreateShuffleWriteHandlerRequest request) { + return null; + } + + public ServerReadHandler getOrCreateReadHandler(CreateShuffleReadHandlerRequest request) { + // Do not cache it since this class is just a wrapper + return newReadHandler(request); + } + + @Override + protected ServerReadHandler newReadHandler(CreateShuffleReadHandlerRequest request) { + List handlers = new ArrayList<>(); + for (LocalStorage storage : 
localStorages) { + try { + handlers.add(storage.getOrCreateReadHandler(request)); + } catch (FileNotFoundException e) { + // ignore it + } catch (Exception e) { + LOG.error("Failed to create read handler for storage: " + storage, e); + } + } + return new CompositeLocalFileServerReadHandler(handlers); + } + + @Override + public boolean canWrite() { + return false; + } + + @Override + public void updateWriteMetrics(StorageWriteMetrics metrics) {} + + @Override + public void updateReadMetrics(StorageReadMetrics metrics) {} + + @Override + public void createMetadataIfNotExist(String shuffleKey) {} + + @Override + public String getStoragePath() { + return null; + } + + @Override + public String getStorageHost() { + return null; + } +} diff --git a/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java b/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java index e748608ded..a87fd1d87f 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java +++ b/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java @@ -46,6 +46,7 @@ public class LocalStorage extends AbstractStorage { public static final String STORAGE_HOST = "local"; private final long diskCapacity; + private final int id; private volatile long diskAvailableBytes; private volatile long serviceUsedBytes; // for test cases @@ -68,6 +69,7 @@ private LocalStorage(Builder builder) { this.capacity = builder.capacity; this.media = builder.media; this.enableDiskCapacityCheck = builder.enableDiskCapacityWatermarkCheck; + this.id = builder.id; File baseFolder = new File(basePath); try { @@ -149,7 +151,8 @@ protected ServerReadHandler newReadHandler(CreateShuffleReadHandlerRequest reque request.getPartitionId(), request.getPartitionNumPerRange(), request.getPartitionNum(), - basePath); + basePath, + id); } // only for tests. 
@@ -282,6 +285,10 @@ public void markSpaceFull() { isSpaceEnough = false; } + public int getId() { + return id; + } + public static class Builder { private long capacity; private double ratio; @@ -290,6 +297,7 @@ public static class Builder { private String basePath; private StorageMedia media; private boolean enableDiskCapacityWatermarkCheck; + private int id; private Builder() {} @@ -328,6 +336,11 @@ public Builder enableDiskCapacityWatermarkCheck() { return this; } + public Builder setId(int id) { + this.id = id; + return this; + } + public LocalStorage build() { return new LocalStorage(this); } diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/api/ServerReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/api/ServerReadHandler.java index b16a4d327e..267cd30305 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/api/ServerReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/api/ServerReadHandler.java @@ -25,4 +25,6 @@ public interface ServerReadHandler { ShuffleDataResult getShuffleData(long offset, int length); ShuffleIndexResult getShuffleIndex(); + + int getStorageId(); } diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/CompositeLocalFileServerReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/CompositeLocalFileServerReadHandler.java new file mode 100644 index 0000000000..deb9953b10 --- /dev/null +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/CompositeLocalFileServerReadHandler.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.uniffle.storage.handler.impl; + +import java.util.ArrayList; +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.uniffle.common.ShuffleDataResult; +import org.apache.uniffle.common.ShuffleIndexResult; +import org.apache.uniffle.common.netty.buffer.ManagedBuffer; +import org.apache.uniffle.common.netty.buffer.MultiFileSegmentManagedBuffer; +import org.apache.uniffle.storage.handler.api.ServerReadHandler; + +public class CompositeLocalFileServerReadHandler implements ServerReadHandler { + + private static final Logger LOG = + LoggerFactory.getLogger(CompositeLocalFileServerReadHandler.class); + private final List handlers; + + public CompositeLocalFileServerReadHandler(List handlers) { + this.handlers = handlers; + } + + @Override + public ShuffleDataResult getShuffleData(long offset, int length) { + return null; + } + + @Override + public ShuffleIndexResult getShuffleIndex() { + if (handlers.size() == 0) { + // caller should handle the null return + return null; + } + int[] storageIds = new int[handlers.size()]; + List managedBuffers = new ArrayList<>(handlers.size()); + String dataFileName = ""; + long length = 0; + for (int i = 0; i < handlers.size(); i++) { + ServerReadHandler handler = handlers.get(i); + storageIds[i] = handler.getStorageId(); + ShuffleIndexResult result = handler.getShuffleIndex(); + length += result.getDataFileLen(); + managedBuffers.add(result.getManagedBuffer()); + if (i == 0) { + // Use the first data file name as the data file name of the combined 
result. + // TODO: This cannot work for remote merge feature. + dataFileName = result.getDataFileName(); + } + } + MultiFileSegmentManagedBuffer mergedManagedBuffer = + new MultiFileSegmentManagedBuffer(managedBuffers); + return new ShuffleIndexResult(mergedManagedBuffer, length, dataFileName, storageIds); + } + + @Override + public int getStorageId() { + return 0; + } +} diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java index 56c68d0623..4772dcf21c 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java @@ -157,6 +157,7 @@ public ShuffleDataResult readShuffleData(ShuffleDataSegment shuffleDataSegment) partitionNum, shuffleDataSegment.getOffset(), expectedLength, + shuffleDataSegment.getStorageId(), retryMax, retryIntervalMax); try { diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java index 4eb7e2d523..f9da2a6e89 100644 --- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java +++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileServerReadHandler.java @@ -36,6 +36,7 @@ public class LocalFileServerReadHandler implements ServerReadHandler { private static final Logger LOG = LoggerFactory.getLogger(LocalFileServerReadHandler.class); + private final int storageId; private String indexFileName = ""; private String dataFileName = ""; private String appId; @@ -48,13 +49,25 @@ public LocalFileServerReadHandler( int partitionId, int partitionNumPerRange, int partitionNum, - String path) { + String path, + int storageId) { this.appId = appId; this.shuffleId = 
shuffleId; this.partitionId = partitionId; + this.storageId = storageId; init(appId, shuffleId, partitionId, partitionNumPerRange, partitionNum, path); } + public LocalFileServerReadHandler( + String appId, + int shuffleId, + int partitionId, + int partitionNumPerRange, + int partitionNum, + String path) { + this(appId, shuffleId, partitionId, partitionNumPerRange, partitionNum, path, 0); + } + private void init( String appId, int shuffleId, @@ -150,7 +163,7 @@ public ShuffleIndexResult getShuffleIndex() { // get dataFileSize for read segment generation in DataSkippableReadHandler#readShuffleData long dataFileSize = new File(dataFileName).length(); return new ShuffleIndexResult( - new FileSegmentManagedBuffer(indexFile, 0, len), dataFileSize, dataFileName); + new FileSegmentManagedBuffer(indexFile, 0, len), dataFileSize, dataFileName, storageId); } public String getDataFileName() { @@ -160,4 +173,9 @@ public String getDataFileName() { public String getIndexFileName() { return indexFileName; } + + @Override + public int getStorageId() { + return storageId; + } }