From aed7ea70dbc8ac6b1474bbae15bb5123c177fd9e Mon Sep 17 00:00:00 2001 From: Malachi Date: Wed, 25 Aug 2021 10:35:12 -0500 Subject: [PATCH] Add device_t::mallocHost (#90) --- src/core/platform.cpp | 8 ++++++++ src/core/platform.hpp | 3 +++ src/elliptic/linearSolver/PGMRES.cpp | 4 +--- src/linAlg/linAlg.cpp | 4 +--- 4 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/core/platform.cpp b/src/core/platform.cpp index 45ea16582..14a58c64e 100644 --- a/src/core/platform.cpp +++ b/src/core/platform.cpp @@ -181,6 +181,14 @@ device_t::buildKernel(const std::string &filename, return _kernel; } occa::memory +device_t::mallocHost(const dlong Nbytes) +{ + occa::properties props; + props["host"] = true; + occa::memory h_scratch = occa::device::malloc(Nbytes, props); + return h_scratch; +} +occa::memory device_t::malloc(const dlong Nbytes, const occa::properties& properties) { return occa::device::malloc(Nbytes, nullptr, properties); diff --git a/src/core/platform.hpp b/src/core/platform.hpp index 87f25184c..491fc53ea 100644 --- a/src/core/platform.hpp +++ b/src/core/platform.hpp @@ -58,6 +58,9 @@ class device_t : public occa::device{ occa::memory malloc(const dlong Nbytes, const occa::properties& properties); occa::memory malloc(const dlong Nwords, const dlong wordSize, occa::memory src); occa::memory malloc(const dlong Nwords, const dlong wordSize); + + occa::memory mallocHost(const dlong Nbytes); + int id() const { return _device_id; } private: dlong bufferSize; diff --git a/src/elliptic/linearSolver/PGMRES.cpp b/src/elliptic/linearSolver/PGMRES.cpp index bbfd868b3..cb27b9cb7 100644 --- a/src/elliptic/linearSolver/PGMRES.cpp +++ b/src/elliptic/linearSolver/PGMRES.cpp @@ -56,9 +56,7 @@ GmresData::GmresData(elliptic_t* elliptic) const dlong Nbytes = restart * Nblock * sizeof(dfloat); //pinned scratch buffer { - occa::properties props = platform->kernelInfo; - props["host"] = true; - h_scratch = platform->device.malloc(Nbytes, props); + h_scratch = platform->device.mallocHost(Nbytes); scratch = (dfloat*) h_scratch.ptr(); } o_scratch = platform->device.malloc(Nbytes); diff --git a/src/linAlg/linAlg.cpp b/src/linAlg/linAlg.cpp index f88c5381f..a8761ef8b 100644 --- a/src/linAlg/linAlg.cpp +++ b/src/linAlg/linAlg.cpp @@ -49,9 +49,7 @@ void linAlg_t::reallocScratch(const dlong Nbytes) if(o_scratch.size()) o_scratch.free(); //pinned scratch buffer { - occa::properties props = kernelInfo; - props["host"] = true; - h_scratch = device.malloc(Nbytes, props); + h_scratch = device.mallocHost(Nbytes); scratch = (dfloat*) h_scratch.ptr(); } o_scratch = device.malloc(Nbytes);