1717
1818package org .apache .spark .shuffle .sort
1919
20- import java .io .{ DataInputStream , File , FileInputStream , FileOutputStream }
20+ import java .io ._
2121
2222import org .mockito .{Mock , MockitoAnnotations }
2323import org .mockito .Answers .RETURNS_SMART_NULLS
@@ -26,7 +26,7 @@ import org.mockito.Mockito._
2626import org .mockito .invocation .InvocationOnMock
2727import org .scalatest .BeforeAndAfterEach
2828
29- import org .apache .spark .{SparkConf , SparkFunSuite }
29+ import org .apache .spark .{SparkConf , SparkFunSuite , TaskContext }
3030import org .apache .spark .shuffle .IndexShuffleBlockResolver
3131import org .apache .spark .storage ._
3232import org .apache .spark .util .Utils
@@ -36,6 +36,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
3636
// Mockito-managed collaborators. Each mock uses the RETURNS_SMART_NULLS answer for
// calls that have not been stubbed explicitly.
3737 @ Mock (answer = RETURNS_SMART_NULLS ) private var blockManager : BlockManager = _
3838 @ Mock (answer = RETURNS_SMART_NULLS ) private var diskBlockManager : DiskBlockManager = _
// Mocked task context; the visible setup code installs it via TaskContext.setTaskContext
// so that tests can stub attemptNumber().
39+ @ Mock (answer = RETURNS_SMART_NULLS ) private var taskContext : TaskContext = _
3940
// Parent directory for files handed out by the stubbed diskBlockManager.getFile.
// It is assigned outside the lines visible here.
4041 private var tempDir : File = _
// Base configuration, constructed with loadDefaults = false.
4142 private val conf : SparkConf = new SparkConf (loadDefaults = false )
@@ -48,6 +49,8 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
4849 when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
4950 when(diskBlockManager.getFile(any[BlockId ])).thenAnswer(
5051 (invocation : InvocationOnMock ) => new File (tempDir, invocation.getArguments.head.toString))
52+
53+ TaskContext .setTaskContext(taskContext)
5154 }
5255
5356 override def afterEach (): Unit = {
@@ -155,4 +158,65 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa
155158 indexIn2.close()
156159 }
157160 }
161+
// Verifies that the data file name produced by the resolver ends with the attempt number
// reported by the current (mocked) TaskContext, followed by ".data".
// NOTE(review): this relies on getDataFile reading attemptNumber() exactly once per call,
// as the original test did - confirm against IndexShuffleBlockResolver.
test("get data file should in different task attempts") {
  val resolver = new IndexShuffleBlockResolver(conf, blockManager)
  val shuffleId = 1
  val mapId = 2
  val attempts = 0 to 3

  // Consecutive attemptNumber() calls answer 0, 1, 2, 3 in that order.
  when(taskContext.attemptNumber()).thenReturn(0, Seq(1, 2, 3): _*)

  attempts.foreach { attempt =>
    val fileName = resolver.getDataFile(shuffleId, mapId).getName
    assert(
      fileName.endsWith(s"$attempt.data"),
      s"data file '$fileName' should end with '$attempt.data'")
  }
}
172+
// Verifies that four consecutive task attempts for the same (shuffleId, mapId) place their
// data files in four different configured local directories.
test("different task attempts should be able to choose different local dirs") {
  val localDirSuffixes = 1 to 4
  val dirs = localDirSuffixes.map(suffix => s"$tempDir/test_local$suffix").mkString(",")
  val confClone = conf.clone.set("spark.local.dir", dirs)
  val resolver = new IndexShuffleBlockResolver(confClone, blockManager)

  // Use a real DiskBlockManager (second argument: deleteFilesOnStop) instead of the mock
  // so that the local-dir selection logic is actually exercised.
  val dbm = new DiskBlockManager(confClone, true)
  when(blockManager.diskBlockManager).thenReturn(dbm)

  // Consecutive attemptNumber() calls answer 0, 1, 2, 3 in that order.
  when(taskContext.attemptNumber()).thenReturn(0, Seq(1, 2, 3): _*)

  val dataFiles = localDirSuffixes.map(_ => resolver.getDataFile(1, 2))

  // Recover the single-digit suffix of the "test_local<N>" directory each file landed in.
  val usedLocalDirSuffixes =
    dataFiles.map(_.getAbsolutePath.split("test_local")(1).substring(0, 1).toInt)

  // One file per attempt, and every configured directory used exactly once. This is the
  // same condition the previous multiset check `used.diff(all).isEmpty` expressed, spelled
  // out so a failure shows which directories were actually picked.
  assert(usedLocalDirSuffixes.sorted == localDirSuffixes)
}
186+
// Verifies that when the local directory chosen for attempt 0 is not writable, the commit
// fails with an IOException naming that directory, and that a later attempt (attempt 1)
// can commit successfully and leaves an existing data file behind.
test("new task attempt should be able to success in another available local dir") {
  val localDirSuffixes = 1 to 2
  val dirs = localDirSuffixes.map(suffix => s"$tempDir/test_local$suffix").mkString(",")
  val confClone = conf.clone.set("spark.local.dir", dirs)
  val resolver = new IndexShuffleBlockResolver(confClone, blockManager)

  // Real DiskBlockManager (second argument: deleteFilesOnStop) over the two test dirs.
  val dbm = new DiskBlockManager(confClone, true)
  when(blockManager.diskBlockManager).thenReturn(dbm, dbm)

  val shuffleId = 1
  val mapId = 2
  val lengths = Array[Long](10, 0, 20)

  // Temporary data file whose size (30 bytes) matches the sum of `lengths`.
  val dataTmp = File.createTempFile("shuffle", null, tempDir)
  val out = new FileOutputStream(dataTmp)
  Utils.tryWithSafeFinally {
    out.write(new Array[Byte](30))
  } {
    out.close()
  }

  // Pick the local dir that the attempt-0 index file name hashes to.
  // NOTE(review): assumes DiskBlockManager selects the dir as nonNegativeHash(name) % numDirs.
  val idxName = s"shuffle_${shuffleId}_${mapId}_0.index"
  val localDirIdx = Utils.nonNegativeHash(idxName) % localDirSuffixes.length
  val badDisk = dbm.localDirs(localDirIdx)

  // Simulate a disk error by making the directory read-only. Fail early with a clear
  // message if the platform refuses the permission change.
  assert(badDisk.setWritable(false), s"could not make ${badDisk.getAbsolutePath} read-only")
  try {
    // 1. index -> fail
    // 2. index -> data -> verify data
    // First attemptNumber() call answers 0, every later call answers 1.
    when(taskContext.attemptNumber()).thenReturn(0, Seq(1, 1, 1): _*)
    val e =
      intercept[IOException](resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp))
    assert(e.getMessage.contains(badDisk.getAbsolutePath))

    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)

    val dataFile = resolver.getDataFile(shuffleId, mapId)
    assert(dataFile.exists())
  } finally {
    // Undo the permission change so later cleanup of the temp directory is not blocked
    // by a read-only directory, even when an assertion above fails.
    badDisk.setWritable(true)
  }
}
158222}
0 commit comments