diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 2eed979387f3..c55cb01b4688 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -909,6 +909,9 @@ class NDArray { // aux_handles always reflect the correct number of aux data for (size_t i = 0; i < aux_shapes.size(); i++) { CheckAndAllocAuxData(i, aux_shapes[i]); + // this line is needed in case when aux_shapes[i].Size() = 0 + // aux_handles[i] will not be updated and take only default value. + aux_handles[i].ctx = ctx; } if (!delay_alloc) { CheckAndAllocData(storage_shape, dtype); @@ -983,8 +986,8 @@ class NDArray { #endif delay_alloc = false; } else if (shandle.size < dbytes) { - // free storage - Storage::Get()->Free(shandle); + // free storage if necessary and alloc again + if (shandle.size > 0) Storage::Get()->Free(shandle); // init storage shandle = Storage::Get()->Alloc(dbytes, shandle.ctx); #if MXNET_USE_MKLDNN == 1 @@ -1049,14 +1052,12 @@ class NDArray { << "storage type cannot be kDefaultStorage in CheckAndAllocAuxData"; if (aux_handles.size() <= i) { aux_handles.resize(i + 1); - // set context for the newly created aux handle - aux_handles[i].ctx = ctx; } size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]); if (aux_handles[i].size < aux_bytes) { - // free storage - Storage::Get()->Free(aux_handles[i]); - // init storage + // free storage if necessary and alloc again + if (aux_handles[i].size > 0) Storage::Get()->Free(aux_handles[i]); + // init aux storage aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx); } // init shape diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 377bef072b03..367712755483 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -121,9 +121,9 @@ NDArray::Chunk::~Chunk() { CHECK_EQ(mem.mem->GetDataHandle(), mem.h.dptr); } #endif - Storage::Get()->Free(mem.h); + if (mem.h.size > 0) Storage::Get()->Free(mem.h); for (const auto& aux : mem.aux_h) { - Storage::Get()->Free(aux); + if (aux.size > 0) Storage::Get()->Free(aux); } } }, shandle.ctx, var); @@ -134,8 +134,8 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) { << "data is expected to be allocated after aux_data"; auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype); if (shandle.size < dbytes) { - // free storage - Storage::Get()->Free(shandle); + // free storage if necessary and alloc again + if (shandle.size > 0) Storage::Get()->Free(shandle); // init storage shandle = Storage::Get()->Alloc(dbytes, ctx); #if MXNET_USE_MKLDNN == 1 diff --git a/src/storage/pooled_storage_manager.h b/src/storage/pooled_storage_manager.h index 7c4b070afdd2..c407a9f00cb6 100644 --- a/src/storage/pooled_storage_manager.h +++ b/src/storage/pooled_storage_manager.h @@ -155,10 +155,6 @@ void GPUPooledStorageManager::Alloc(Storage::Handle* handle) { } void GPUPooledStorageManager::Free(Storage::Handle handle) { - // Do nothing if dptr is nullptr. Otherwise, nullptr may be reused - // which can cause illegal memory access error. - if (handle.dptr == nullptr) return; - std::lock_guard lock(Storage::Get()->GetMutex(Context::kGPU)); size_t size = RoundAllocSize(handle.size); auto&& reuse_pool = memory_pool_[size]; @@ -316,10 +312,6 @@ void GPUPooledRoundedStorageManager::Alloc(Storage::Handle* handle) { } void GPUPooledRoundedStorageManager::Free(Storage::Handle handle) { - // Do nothing if dptr is nullptr. Otherwise, nullptr may be reused - // which can cause illegal memory access error. - if (handle.dptr == nullptr) return; - std::lock_guard lock(Storage::Get()->GetMutex(Context::kGPU)); int bucket = get_bucket(handle.size); auto&& reuse_pool = memory_pool_[bucket];