Skip to content

Commit

Permalink
proper device query through rocm api
Browse files Browse the repository at this point in the history
  • Loading branch information
petrex committed Nov 12, 2019
1 parent 03a29da commit fcb8bb5
Showing 1 changed file with 59 additions and 55 deletions.
114 changes: 59 additions & 55 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -22,23 +22,21 @@
* \file rocm_device_api.cc
* \brief GPU specific API
*/
#include <tvm/runtime/device_api.h>

#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <hip/hip_runtime_api.h>
#include <hsa/hsa.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include "rocm_common.h"

namespace tvm {
namespace runtime {

class ROCMDeviceAPI final : public DeviceAPI {
public:
/*!
 * \brief Select the HIP device identified by ctx.device_id.
 *
 * Calls hipSetDevice(ctx.device_id) through the ROCM_CALL macro. ROCM_CALL is
 * not defined in this view (presumably it checks the returned HIP status --
 * confirm in rocm_common.h).
 *
 * NOTE(review): this span shows the same method twice, once in the multi-line
 * layout and once in the single-line layout; they look like the before/after
 * sides of a formatting-only change.
 *
 * \param ctx Context whose device_id is passed to hipSetDevice.
 */
void SetDevice(TVMContext ctx) final {
ROCM_CALL(hipSetDevice(ctx.device_id));
}
void SetDevice(TVMContext ctx) final { ROCM_CALL(hipSetDevice(ctx.device_id)); }
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final {
int value = 0;
switch (kind) {
Expand All @@ -54,35 +52,59 @@ class ROCMDeviceAPI final : public DeviceAPI {
break;
}
case kMaxThreadsPerBlock: {
value = 1024;
ROCM_CALL(
hipDeviceGetAttribute(&value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id));
break;
}
case kWarpSize: {
value = 64;
ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, ctx.device_id));
break;
}
case kMaxSharedMemoryPerBlock: {
ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeMaxSharedMemoryPerBlock,
ctx.device_id));
break;
}
case kMaxSharedMemoryPerBlock: return;
case kComputeVersion: {
hipDeviceProp_t prop;
ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id));
*rv = prop.gcnArch;
std::ostringstream os;
ROCM_CALL(
hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id));
os << value << ".";
ROCM_CALL(
hipDeviceGetAttribute(&value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id));
os << value;
*rv = os.str();
return;
}
case kDeviceName:
return;
case kMaxClockRate: {
ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, ctx.device_id));
break;
}
case kMultiProcessorCount: {
ROCM_CALL(
hipDeviceGetAttribute(&value, hipDeviceAttributeMultiprocessorCount, ctx.device_id));
break;
}
case kMaxThreadDimensions: {
int dims[3];
ROCM_CALL(hipDeviceGetAttribute(&dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id));
ROCM_CALL(hipDeviceGetAttribute(&dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id));
ROCM_CALL(hipDeviceGetAttribute(&dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id));

std::stringstream ss;
ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]";
*rv = ss.str();
return;
}
case kDeviceName: return;
case kMaxClockRate: return;
case kMultiProcessorCount: return;
case kMaxThreadDimensions: return;
}
*rv = value;
}
/*!
 * \brief Allocate nbytes with hipMalloc after selecting device ctx.device_id.
 *
 * The requested alignment must evenly divide 256; otherwise the CHECK_EQ
 * fails with the message "ROCM space is aligned at 256 bytes".
 * type_hint is not read anywhere in the visible body.
 *
 * NOTE(review): `256 % alignment` is undefined for alignment == 0; callers
 * presumably never pass 0 -- confirm.
 * NOTE(review): the signature, the CHECK_EQ and the `ret` declaration each
 * appear twice below (multi-line and single-line layout of the same code).
 *
 * \param ctx Context whose device_id is passed to hipSetDevice.
 * \param nbytes Number of bytes passed to hipMalloc.
 * \param alignment Requested alignment; checked against 256.
 * \param type_hint Unused in the visible body.
 * \return The pointer written by hipMalloc.
 */
void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
TVMType type_hint) final {
void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final {
ROCM_CALL(hipSetDevice(ctx.device_id));
CHECK_EQ(256 % alignment, 0U)
<< "ROCM space is aligned at 256 bytes";
void *ret;
CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes";
void* ret;
ROCM_CALL(hipMalloc(&ret, nbytes));
return ret;
}
Expand All @@ -92,14 +114,8 @@ class ROCMDeviceAPI final : public DeviceAPI {
ROCM_CALL(hipFree(ptr));
}

void CopyDataFromTo(const void* from,
size_t from_offset,
void* to,
size_t to_offset,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
TVMContext ctx_from, TVMContext ctx_to, TVMType type_hint,
TVMStreamHandle stream) final {
hipStream_t hip_stream = static_cast<hipStream_t>(stream);
from = static_cast<const char*>(from) + from_offset;
Expand All @@ -109,9 +125,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
} else {
hipMemcpyPeerAsync(to, ctx_to.device_id,
from, ctx_from.device_id,
size, hip_stream);
hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, hip_stream);
}
} else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
ROCM_CALL(hipSetDevice(ctx_from.device_id));
Expand All @@ -130,8 +144,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
}

/*!
 * \brief Record `stream` in the entry returned by ROCMThreadEntry::ThreadLocal().
 *
 * The handle is converted with static_cast<hipStream_t> and assigned to the
 * entry's `stream` member. ctx is not read in the visible body.
 *
 * NOTE(review): the assignment appears twice below (wrapped and single-line
 * layout of the same statement).
 *
 * \param ctx Unused in the visible body.
 * \param stream Stream handle to store, cast to hipStream_t.
 */
void SetStream(TVMContext ctx, TVMStreamHandle stream) final {
ROCMThreadEntry::ThreadLocal()
->stream = static_cast<hipStream_t>(stream);
ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
}

void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final {
Expand All @@ -143,16 +156,12 @@ class ROCMDeviceAPI final : public DeviceAPI {
}

/*!
 * \brief Access the shared ROCMDeviceAPI instance.
 *
 * The instance is held in a function-local static std::shared_ptr that is
 * initialized with std::make_shared<ROCMDeviceAPI>().
 *
 * NOTE(review): the `inst` declaration appears twice below (wrapped and
 * single-line layout of the same statement).
 *
 * \return Const reference to the static shared_ptr.
 */
static const std::shared_ptr<ROCMDeviceAPI>& Global() {
static std::shared_ptr<ROCMDeviceAPI> inst =
std::make_shared<ROCMDeviceAPI>();
static std::shared_ptr<ROCMDeviceAPI> inst = std::make_shared<ROCMDeviceAPI>();
return inst;
}

private:
static void GPUCopy(const void* from,
void* to,
size_t size,
hipMemcpyKind kind,
static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind,
hipStream_t stream) {
if (stream != 0) {
ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
Expand All @@ -164,19 +173,14 @@ class ROCMDeviceAPI final : public DeviceAPI {

/*! \brief dmlc::ThreadLocalStore specialized for ROCMThreadEntry; ROCMThreadEntry::ThreadLocal() returns its Get(). */
typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;

/*!
 * \brief Construct the entry, initializing `pool` with kDLROCM and the shared
 *  ROCMDeviceAPI instance from ROCMDeviceAPI::Global().
 *
 * The type of `pool` is declared outside this view (presumably a workspace
 * pool from rocm_common.h -- confirm).
 *
 * NOTE(review): the constructor appears twice below (multi-line and
 * single-line layout of the same definition).
 */
ROCMThreadEntry::ROCMThreadEntry()
: pool(kDLROCM, ROCMDeviceAPI::Global()) {
}
ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {}

/*!
 * \brief Return the ROCMThreadEntry pointer provided by ROCMThreadStore::Get().
 *
 * ROCMThreadStore is the dmlc::ThreadLocalStore<ROCMThreadEntry> typedef
 * declared earlier in this file.
 *
 * NOTE(review): the function appears twice below (multi-line and single-line
 * layout of the same definition).
 *
 * \return Pointer obtained from ROCMThreadStore::Get().
 */
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
return ROCMThreadStore::Get();
}
ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); }

/*!
 * \brief Register the global function "device_api.rocm".
 *
 * The registered body ignores `args` and stores the raw DeviceAPI pointer
 * held by ROCMDeviceAPI::Global() into *rv as a void*.
 *
 * NOTE(review): the registration appears twice below (wrapped and
 * single-line layout of the same statement).
 */
TVM_REGISTER_GLOBAL("device_api.rocm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});
TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = ROCMDeviceAPI::Global().get();
*rv = static_cast<void*>(ptr);
});

} // namespace runtime
} // namespace tvm

0 comments on commit fcb8bb5

Please sign in to comment.