Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[METAL] Fix issue with GPU fails #7819

Merged
merged 6 commits into from
Apr 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 43 additions & 3 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,49 @@ def max_thread_dimensions(self):
"""
return json.loads(self._GetDeviceAttr(self.device_type, self.device_id, 8))

def sync(self):
"""Synchronize until jobs finished at the context."""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, None))
def create_raw_stream(self):
"""Create a new runtime stream at the context.

User should free the stream after use.

Returns
-------
stream : TVMStreamHandle
The created runtime stream.
"""
stream = ctypes.c_void_p()
check_call(_LIB.TVMStreamCreate(self.device_type, self.device_id, ctypes.byref(stream)))
return stream

def free_raw_stream(self, stream):
"""Free a created stream handle.

Parameters
----------
stream : TVMStreamHandle
The stream which should to be released.
"""
check_call(_LIB.TVMStreamFree(self.device_type, self.device_id, stream))

def set_raw_stream(self, stream):
"""Set a created stream handle.

Parameters
----------
stream : TVMStreamHandle
The stream which should to be set to the device.
"""
check_call(_LIB.TVMSetStream(self.device_type, self.device_id, stream))

def sync(self, stream=None):
"""Synchronize until jobs finished at the context.

Parameters
----------
stream : TVMStreamHandle
Jobs in this stream should be finished.
"""
check_call(_LIB.TVMSynchronize(self.device_type, self.device_id, stream))

def __eq__(self, other):
return (
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,8 @@ def _timed_rpc_run(

if error_no == 0:
try:
stream = dev.create_raw_stream()
dev.set_raw_stream(stream)
random_fill = remote.get_function("tvm.contrib.random.random_fill")
assert (
random_fill
Expand Down Expand Up @@ -1108,14 +1110,21 @@ def _timed_rpc_run(
"task_inputs not fully matched, check if there's any unexpected error"
)
dev.sync()

# First run for check that the kernel is correct
func.entry_func(*args)
dev.sync()

costs = time_f(*args).results

# clean up remote files
remote.remove(build_res.filename)
remote.remove(os.path.splitext(build_res.filename)[0] + ".so")
remote.remove("")
dev.free_raw_stream(stream)
# pylint: disable=broad-except
except Exception:
dev.free_raw_stream(stream)
costs = (MAX_FLOAT,)
error_no = MeasureErrorNo.RUNTIME_DEVICE
error_msg = make_traceback_info()
Expand Down
10 changes: 2 additions & 8 deletions src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,17 +190,11 @@ void DeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, s

void DeviceAPI::FreeWorkspace(Device dev, void* ptr) { FreeDataSpace(dev, ptr); }

TVMStreamHandle DeviceAPI::CreateStream(Device dev) {
LOG(FATAL) << "Device does not support stream api.";
return nullptr;
}
TVMStreamHandle DeviceAPI::CreateStream(Device dev) { return nullptr; }

void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {
LOG(FATAL) << "Device does not support stream api.";
}
void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}

void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api.";
}

//--------------------------------------------------------
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/crt/common/crt_runtime_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ int TVMDeviceCopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream
return 0;
}

int TVMStreamCreate(int device_type, int device_id, TVMStreamHandle* out) {
out = NULL;
return 0;
}

int TVMStreamFree(int device_type, int device_id, TVMStreamHandle stream) { return 0; }

int TVMSetStream(int device_type, int device_id, TVMStreamHandle stream) { return 0; }

int TVMSynchronize(int device_type, int device_id, TVMStreamHandle stream) { return 0; }

static TVMMutableFuncRegistry global_func_registry;
Expand Down
45 changes: 36 additions & 9 deletions src/runtime/metal/metal_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,39 @@
namespace tvm {
namespace runtime {
namespace metal {
/*!
* \brief Structure for error handling in queues
*/
class Stream {
public:
explicit Stream(id<MTLDevice> device) : error_happened_(false) {
queue_ = [device newCommandQueue];
}
~Stream() { [queue_ release]; }
id<MTLCommandBuffer> GetCommandBuffer() {
id<MTLCommandBuffer> cb = [queue_ commandBuffer];
[cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
if (buffer.status == MTLCommandBufferStatusError) SetErrorStatus();
}];
return cb;
}
bool HasErrorHappened() { return error_happened_; }

private:
void SetErrorStatus() { error_happened_ = true; }
// Queue
id<MTLCommandQueue> queue_;
// Check if error happened in one previous run
bool error_happened_;
};

/*!
* \brief Process global Metal workspace.
*/
class MetalWorkspace final : public DeviceAPI {
public:
// the devices
std::vector<id<MTLDevice> > devices;
// the queues
std::vector<id<MTLCommandQueue> > queues;
// Warp size constant
std::vector<int> warp_size;
// Whether it is initialized.
Expand All @@ -62,13 +86,6 @@ class MetalWorkspace final : public DeviceAPI {
std::mutex mutex;
// Destructor
~MetalWorkspace();
// Get command queue for given device.
id<MTLCommandQueue> GetCommandQueue(Device dev) {
ICHECK_EQ(dev.device_type, kDLMetal);
ICHECK(dev.device_id >= 0 && static_cast<size_t>(dev.device_id) < queues.size())
<< "Invalid Metal device_id=" << dev.device_id;
return queues[dev.device_id];
}
// Get device for given device
id<MTLDevice> GetDevice(Device dev) {
ICHECK_EQ(dev.device_type, kDLMetal);
Expand All @@ -84,23 +101,33 @@ class MetalWorkspace final : public DeviceAPI {
void GetAttr(Device dev, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final;
void FreeDataSpace(Device dev, void* ptr) final;
TVMStreamHandle CreateStream(Device dev) final;
void FreeStream(Device dev, TVMStreamHandle stream) final;
void StreamSync(Device dev, TVMStreamHandle stream) final;
void SetStream(Device dev, TVMStreamHandle stream) final;
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;

// get the global workspace
static MetalWorkspace* Global();

protected:
void CopyDataFromTo(const void* from, size_t from_size, void* to, size_t to_size, size_t size,
Device dev_from, Device dev_to, DLDataType type_hint,
TVMStreamHandle stream) final;

private:
// Pointers to default allocated streams
std::vector<Stream*> default_streams_;
};

/*! \brief Thread local workspace */
class MetalThreadEntry {
public:
/*! \brief The current device */
Device device;
/*! \brief The current stream */
std::vector<Stream*> stream;
/*! \brief The shared buffer used for copy. */
std::vector<id<MTLBuffer> > temp_buffer_;
/*! \brief workspace pool */
Expand Down
50 changes: 40 additions & 10 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ int GetWarpSize(id<MTLDevice> dev) {
for (auto x : devices) {
[x release];
}
for (auto x : queues) {
[x release];
for (auto x : default_streams_) {
delete x;
}
}

Expand All @@ -136,13 +136,17 @@ int GetWarpSize(id<MTLDevice> dev) {
// on iPhone
id<MTLDevice> d = MTLCreateSystemDefaultDevice();
devices.push_back(d);
queues.push_back([d newCommandQueue]);
Stream* stream = new Stream(d);
MetalThreadEntry::ThreadLocal()->stream.push_back(stream);
default_streams_.push_back(stream);
#else
NSArray<id<MTLDevice> >* devs = MTLCopyAllDevices();
for (size_t i = 0; i < devs.count; ++i) {
id<MTLDevice> d = [devs objectAtIndex:i];
devices.push_back(d);
queues.push_back([d newCommandQueue]);
Stream* stream = new Stream(d);
MetalThreadEntry::ThreadLocal()->stream.push_back(stream);
default_streams_.push_back(stream);
LOG(INFO) << "Intializing Metal device " << i << ", name=" << [d.name UTF8String];
warp_size.push_back(GetWarpSize(d));
}
Expand Down Expand Up @@ -183,16 +187,25 @@ int GetWarpSize(id<MTLDevice> dev) {
}
}

Stream* GetStream(TVMStreamHandle stream, int device_id) {
if (stream != nullptr)
return static_cast<Stream*>(stream);
else
return MetalThreadEntry::ThreadLocal()->stream[device_id];
}

void MetalWorkspace::CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, Device dev_from, Device dev_to,
DLDataType type_hint, TVMStreamHandle stream) {
@autoreleasepool {
this->Init();
ICHECK(stream == nullptr);
Device dev = dev_from;
Stream* s = GetStream(stream, dev.device_id);
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned! Cannot copy data to current stream";
}
if (dev_from.device_type == kDLCPU) dev = dev_to;
id<MTLCommandQueue> queue = GetCommandQueue(dev);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
int from_dev_type = static_cast<int>(dev_from.device_type);
int to_dev_type = static_cast<int>(dev_to.device_type);

Expand Down Expand Up @@ -249,17 +262,34 @@ int GetWarpSize(id<MTLDevice> dev) {
}
}

TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
Stream* stream = new Stream(devices[dev.device_id]);
return static_cast<TVMStreamHandle>(stream);
}

void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
ICHECK(stream != nullptr);
Stream* s = static_cast<Stream*>(stream);
delete s;
}

void MetalWorkspace::StreamSync(Device dev, TVMStreamHandle stream) {
@autoreleasepool {
ICHECK(stream == nullptr);
Stream* s = GetStream(stream, dev.device_id);
// commit an empty command buffer and wait until it completes.
id<MTLCommandQueue> queue = GetCommandQueue(dev);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = s->GetCommandBuffer();
[cb commit];
[cb waitUntilCompleted];
if (s->HasErrorHappened()) {
LOG(FATAL) << "Error! Some problems on GPU happaned!";
}
}
}

void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = static_cast<Stream*>(stream);
}

void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) {
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}
Expand Down
5 changes: 3 additions & 2 deletions src/runtime/metal/metal_module.mm
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,16 @@ void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) cons
@autoreleasepool {
metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
int device_id = t->device.device_id;
auto stream = static_cast<metal::Stream*>(t->stream[device_id]);
if (stream->HasErrorHappened()) return;
if (scache_[device_id] == nil) {
scache_[device_id] = m_->GetPipelineState(device_id, func_name_);
}
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
int blockSize = wl.block_dim(0) * wl.block_dim(1) * wl.block_dim(2);
auto maxTotalThreadsPerThreadgroup = scache_[device_id].maxTotalThreadsPerThreadgroup;
CHECK_LE(blockSize, maxTotalThreadsPerThreadgroup);
id<MTLCommandQueue> queue = w_->GetCommandQueue(t->device);
id<MTLCommandBuffer> cb = [queue commandBuffer];
id<MTLCommandBuffer> cb = stream->GetCommandBuffer();
id<MTLComputeCommandEncoder> encoder = [cb computeCommandEncoder];
[encoder setComputePipelineState:scache_[device_id]];
for (size_t i = 0; i < num_buffer_args_; ++i) {
Expand Down
Loading