diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMatrixMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMatrixMemoryManager.java index 47a8391c346..f69d3405720 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMatrixMemoryManager.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMatrixMemoryManager.java @@ -18,7 +18,9 @@ */ package org.apache.sysml.runtime.instructions.gpu.context; +import java.util.Collections; import java.util.HashSet; +import java.util.List; import java.util.Set; import java.util.stream.Collectors; @@ -43,8 +45,7 @@ public GPUMatrixMemoryManager(GPUMemoryManager gpuManager) { void addGPUObject(GPUObject gpuObj) { gpuObjects.add(gpuObj); } - - + /** * Get list of all Pointers in a GPUObject * @param gObj gpu object @@ -81,6 +82,20 @@ else if(sparsePtr.val != null) * so that an extraneous host to dev transfer can be avoided */ HashSet gpuObjects = new HashSet<>(); + + /** + * Return the set of GPU objects that share at least one pointer with the given set + * @param pointers A set of pointers + * @return A set of GPU objects corresponding to any of these pointers + */ + Set getGpuObjects(Set pointers) { + Set gObjs = new HashSet<>(); + for (GPUObject g : gpuObjects) { + if (!Collections.disjoint(getPointers(g), pointers)) + gObjs.add(g); + } + return gObjs; + } /** * Return all pointers in the first section @@ -94,10 +109,14 @@ Set getPointers() { * Get pointers from the first memory sections "Matrix Memory" * @param locked return locked pointers if true * @param dirty return dirty pointers if true + * @param isCleanupEnabled additionally return the pointers of every GPU object whose mat.isCleanupEnabled() equals this value, regardless of its locked/dirty state (the two conditions are OR-ed) * @return set of pointers */ - Set getPointers(boolean locked, boolean dirty) { - return gpuObjects.stream().filter(gObj -> gObj.isLocked() == locked && gObj.isDirty() == dirty).flatMap(gObj -> getPointers(gObj).stream()).collect(Collectors.toSet()); + Set 
getPointers(boolean locked, boolean dirty, boolean isCleanupEnabled) { + return gpuObjects.stream().filter( + gObj -> (gObj.isLocked() == locked && gObj.isDirty() == dirty) || + (gObj.mat.isCleanupEnabled() == isCleanupEnabled)).flatMap( + gObj -> getPointers(gObj).stream()).collect(Collectors.toSet()); } /** diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java index 6a04d97c6a6..d403ca78515 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUMemoryManager.java @@ -429,7 +429,6 @@ void guardedCudaFree(Pointer toFree) { else { throw new RuntimeException("ERROR : Internal state corrupted, attempting to free an unaccounted pointer:" + toFree); } - } /** @@ -521,11 +520,19 @@ private Set nonIn(Set superset, Set subset) { */ public void clearTemporaryMemory() { // To record the cuda block sizes needed by allocatedGPUObjects, others are cleared up. 
- Set unlockedDirtyPointers = matrixMemoryManager.getPointers(false, true); + Set unlockedDirtyPointers = matrixMemoryManager.getPointers(false, true, false); Set temporaryPointers = nonIn(allPointers.keySet(), unlockedDirtyPointers); - for(Pointer tmpPtr : temporaryPointers) { + for (Pointer tmpPtr : temporaryPointers) { guardedCudaFree(tmpPtr); } + + // Also set the pointer(s) to null in the GPU objects that referenced the freed pointers (to avoid double freeing them), then pass each such object to removeGPUObject + Set gObjs = matrixMemoryManager.getGpuObjects(temporaryPointers); + for (GPUObject g : gObjs) { + g.jcudaDenseMatrixPtr = null; + g.jcudaSparseMatrixPtr = null; + removeGPUObject(g); + } } /** diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java index b95d47144dc..edcd683a4a4 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/context/GPUObject.java @@ -68,7 +68,7 @@ public class GPUObject { /** * Pointer to the underlying sparse matrix block on GPU */ - private CSRPointer jcudaSparseMatrixPtr = null; + CSRPointer jcudaSparseMatrixPtr = null; /** * whether the block attached to this {@link GPUContext} is dirty on the device and needs to be copied back to host