Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[METAL] Fix the rest memory leaks in Metal runtime #8175

Merged
merged 3 commits into from
Jun 4, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,34 @@

#include "../workspace_pool.h"

#define AUTORELEASEPOOL tvm::runtime::metal::AutoReleasePoolWrapper::GetInstance() << [&]()
echuraev marked this conversation as resolved.
Show resolved Hide resolved

namespace tvm {
namespace runtime {
namespace metal {
class AutoReleasePoolWrapper {
public:
static AutoReleasePoolWrapper& GetInstance();
template <typename T>
void operator<<(const T& f) {
std::exception_ptr eptr;
@autoreleasepool {
try {
f();
} catch (...) {
eptr = std::current_exception();
}
}
if (eptr) std::rethrow_exception(eptr);
}

private:
AutoReleasePoolWrapper() = default;
~AutoReleasePoolWrapper() = default;
AutoReleasePoolWrapper(const AutoReleasePoolWrapper&) = delete;
AutoReleasePoolWrapper& operator=(const AutoReleasePoolWrapper&) = delete;
};

/*!
* \brief Structure for error handling in queues
*/
Expand Down
40 changes: 22 additions & 18 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,20 @@
namespace runtime {
namespace metal {

AutoReleasePoolWrapper& AutoReleasePoolWrapper::GetInstance() {
static AutoReleasePoolWrapper instance;
return instance;
}

MetalWorkspace* MetalWorkspace::Global() {
@autoreleasepool {
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static MetalWorkspace* inst = new MetalWorkspace();
return inst;
}
// NOTE: explicitly use new to avoid exit-time destruction of global state
// Global state will be recycled by OS as the process exits.
static MetalWorkspace* inst = new MetalWorkspace();
return inst;
}

void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) {
@autoreleasepool {
AUTORELEASEPOOL {
this->Init();
size_t index = static_cast<size_t>(dev.device_id);
if (kind == kExist) {
Expand Down Expand Up @@ -80,7 +83,7 @@
case kDriverVersion:
return;
}
}
};
}

static const char* kDummyKernel = R"A0B0(
Expand Down Expand Up @@ -161,7 +164,8 @@ int GetWarpSize(id<MTLDevice> dev) {

void* MetalWorkspace::AllocDataSpace(Device device, size_t nbytes, size_t alignment,
DLDataType type_hint) {
@autoreleasepool {
id<MTLBuffer> buf;
AUTORELEASEPOOL {
this->Init();
id<MTLDevice> dev = GetDevice(device);
// GPU memory only
Expand All @@ -173,20 +177,20 @@ int GetWarpSize(id<MTLDevice> dev) {
storage_mode = MTLResourceStorageModeManaged;
#endif
*/
id<MTLBuffer> buf = [dev newBufferWithLength:nbytes options:storage_mode];
buf = [dev newBufferWithLength:nbytes options:storage_mode];
ICHECK(buf != nil);
return (void*)(buf);
}
};
return (void*)(buf);
}

void MetalWorkspace::FreeDataSpace(Device dev, void* ptr) {
@autoreleasepool {
AUTORELEASEPOOL {
// MTLBuffer PurgeableState should be set to empty before manual
// release in order to prevent memory leak
[(id<MTLBuffer>)ptr setPurgeableState:MTLPurgeableStateEmpty];
// release the ptr.
CFRelease(ptr);
}
};
}

Stream* GetStream(TVMStreamHandle stream, int device_id) {
Expand All @@ -199,7 +203,7 @@ int GetWarpSize(id<MTLDevice> dev) {
void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, Device dev_from, Device dev_to,
DLDataType type_hint, TVMStreamHandle stream) {
@autoreleasepool {
AUTORELEASEPOOL {
this->Init();
Device dev = dev_from;
Stream* s = GetStream(stream, dev.device_id);
Expand Down Expand Up @@ -261,7 +265,7 @@ int GetWarpSize(id<MTLDevice> dev) {
LOG(FATAL) << "Expect copy from/to Metal or between Metal"
<< ", from=" << from_dev_type << ", to=" << to_dev_type;
}
}
};
}

TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
Expand All @@ -276,7 +280,7 @@ int GetWarpSize(id<MTLDevice> dev) {
}

void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
@autoreleasepool {
AUTORELEASEPOOL {
Stream* s = GetStream(stream, dev.device_id);
// commit an empty command buffer and wait until it completes.
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
Expand All @@ -285,7 +289,7 @@ int GetWarpSize(id<MTLDevice> dev) {
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned!";
}
}
};
}

void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
Expand Down
26 changes: 16 additions & 10 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ void Init(MetalModuleNode* m, ObjectPtr<Object> sptr, const std::string& func_na
}
// invoke the function with void arguments
void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
@autoreleasepool {
AUTORELEASEPOOL {
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int device_id = t->device.device_id;
auto stream = static_cast<metal::Stream*>(t->stream[device_id]);
Expand Down Expand Up @@ -223,7 +223,7 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
[encoder dispatchThreadgroups:dimGrid threadsPerThreadgroup:dimBlock];
[encoder endEncoding];
[cb commit];
}
};
}

private:
Expand All @@ -248,27 +248,33 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons

PackedFunc MetalModuleNode::GetFunction(const std::string& name,
const ObjectPtr<Object>& sptr_to_self) {
@autoreleasepool {
PackedFunc pf;
AUTORELEASEPOOL {
ICHECK_EQ(sptr_to_self.get(), this);
ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main";
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
if (it == fmap_.end()) {
pf = PackedFunc();
return;
}
const FunctionInfo& info = it->second;
MetalWrappedFunc f;
size_t num_buffer_args = NumBufferArgs(info.arg_types);
f.Init(this, sptr_to_self, name, num_buffer_args, info.arg_types.size() - num_buffer_args,
info.thread_axis_tags);
return PackFuncNonBufferArg(f, info.arg_types);
}
pf = PackFuncNonBufferArg(f, info.arg_types);
};
return pf;
}

Module MetalModuleCreate(std::string data, std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap, std::string source) {
@autoreleasepool {
ObjectPtr<Object> n;
AUTORELEASEPOOL {
metal::MetalWorkspace::Global()->Init();
auto n = make_object<MetalModuleNode>(data, fmt, fmap, source);
return Module(n);
}
n = make_object<MetalModuleNode>(data, fmt, fmap, source);
};
return Module(n);
}

// Load module from module.
Expand Down