Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error messages for memory verifier and gpu memory verifier #6281

Merged
merged 2 commits into from
Aug 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 85 additions & 27 deletions src/tir/analysis/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@ namespace tir {

class GPUCodeVerifier : public StmtExprVisitor {
public:
bool Verify(Stmt stmt, int64_t max_local_memory_per_block, int64_t max_shared_memory_per_block,
int64_t max_threads_per_block, int64_t max_thread_x, int64_t max_thread_y,
int64_t max_thread_z, int64_t max_vthread, int64_t max_vector_bytes) {
std::vector<String> Verify(Stmt stmt, int64_t max_local_memory_per_block,
int64_t max_shared_memory_per_block, int64_t max_threads_per_block,
int64_t max_thread_x, int64_t max_thread_y, int64_t max_thread_z,
int64_t max_vthread, int64_t max_vector_bytes) {
max_local_memory_per_block_ = static_cast<size_t>(max_local_memory_per_block);
max_shared_memory_per_block_ = static_cast<size_t>(max_shared_memory_per_block);
max_threads_per_block_ = static_cast<size_t>(max_threads_per_block);
Expand All @@ -52,7 +53,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
// TODO(jcf94): Add support of detecting CUDA Misaligned Address error
this->VisitStmt(stmt);

return valid_;
return errors_;
}

void VisitStmt_(const AllocateNode* op) final {
Expand All @@ -66,7 +67,13 @@ class GPUCodeVerifier : public StmtExprVisitor {
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
if (op->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
<< op->dtype.bytes() << ") for dtype " << op->dtype
<< " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
errors_.push_back(s.str());
}
}
}

Expand Down Expand Up @@ -98,27 +105,39 @@ class GPUCodeVerifier : public StmtExprVisitor {
visited_threads_.insert(name);
thread_per_block_ *= length;

auto err = [this](std::string id, size_t ext, size_t m) {
if (ext > m) {
std::stringstream s;
s << "Extent of " << id << " (" << ext << ") is greater than maximum allowed (" << m
<< ");";
errors_.push_back(s.str());
}
};

if (name == "threadIdx.x") {
valid_ &= length <= max_thread_x_;
err("threadIdx.x", length, max_thread_x_);
thread_x_extent_ = length;
} else if (name == "threadIdx.y") {
valid_ &= length <= max_thread_y_;
err("threadIdx.y", length, max_thread_y_);
thread_y_extent_ = length;
} else if (name == "threadIdx.z") {
valid_ &= length <= max_thread_z_;
err("threadIdx.z", length, max_thread_z_);
thread_z_extent_ = length;
} else if (name == "vthread") {
valid_ &= length <= max_vthread_;
err("vthread", length, max_vthread_);
}
} else {
// the thread should be bound to axes with the same length
if (name == "threadIdx.x") {
valid_ &= length == thread_x_extent_;
} else if (name == "threadIdx.y") {
valid_ &= length == thread_y_extent_;
} else if (name == "threadIdx.z") {
valid_ &= length == thread_z_extent_;
}
auto err = [this, name](std::string id, size_t ext, size_t m) {
if (name == id && ext != m) {
std::stringstream s;
s << "Extent of " << id << " (" << ext << ") does not match the bound " << m;
errors_.push_back(s.str());
}
};
err("threadIdx.x", length, thread_x_extent_);
err("threadIdx.y", length, thread_y_extent_);
err("threadIdx.z", length, thread_z_extent_);
}
}

Expand All @@ -128,10 +147,17 @@ class GPUCodeVerifier : public StmtExprVisitor {

if (nest_level_ == 0) {
// exit a kernel, check the validity
valid_ &= thread_per_block_ <= max_threads_per_block_;

valid_ &= local_memory_per_block_ <= max_local_memory_per_block_;
valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_;
auto err = [this](std::string id, size_t num, size_t m) {
if (num > m) {
std::stringstream s;
s << "Used " << id << " (" << num << ") is greater than the allowed maximum (" << m
<< ")";
errors_.push_back(s.str());
}
};
err("threads per block", thread_per_block_, max_threads_per_block_);
err("local memory per block", local_memory_per_block_, max_local_memory_per_block_);
err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_);
}
} else {
StmtVisitor::VisitStmt_(op);
Expand All @@ -143,23 +169,41 @@ class GPUCodeVerifier : public StmtExprVisitor {
const auto* extent = op->extent.as<IntImmNode>();
CHECK(extent);

valid_ &= static_cast<size_t>(extent->value) <= max_vthread_;
size_t num_vthread = static_cast<size_t>(extent->value);
if (num_vthread > max_vthread_) {
std::stringstream s;
s << "Number of vthreads (" << num_vthread << ") is greater than the allowed maximum ("
<< max_vthread_ << ")";
errors_.push_back(s.str());
}
}

StmtVisitor::VisitStmt_(op);
}

void VisitExpr_(const LoadNode* op) {
if (op->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) <= max_vector_bytes_;
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
<< op->dtype.bytes() << ") for dtype " << op->dtype
<< " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
errors_.push_back(s.str());
}
}
ExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const StoreNode* op) {
if (op->index->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
max_vector_bytes_;
if (static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) >
max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->index->dtype.lanes() << ") times number of bytes ("
<< op->index->dtype.bytes() << ") for dtype " << op->index->dtype
<< " is greater than the maximum number of vector bytes (" << max_vector_bytes_ << ")";
errors_.push_back(s.str());
}
}
StmtVisitor::VisitStmt_(op);
}
Expand All @@ -183,7 +227,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
size_t max_thread_x_, max_thread_y_, max_thread_z_, max_vthread_;
size_t max_vector_bytes_;

bool valid_{true};
std::vector<String> errors_;

void Reset_() {
visited_local_buffers_.clear();
Expand All @@ -196,7 +240,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}
};

bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
std::vector<String> VerifyGPUCode_(const PrimFunc& func, Map<String, PrimExpr> constraints) {
GPUCodeVerifier verifier;

int64_t max_local_memory_per_block = INT64_MAX;
Expand Down Expand Up @@ -236,6 +280,11 @@ bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
max_vthread, max_vector_bytes);
}

bool VerifyGPUCode(const PrimFunc& func, Map<String, PrimExpr> constraints) {
auto errs = VerifyGPUCode_(func, constraints);
return errs.size() == 0;
}

TVM_REGISTER_GLOBAL("tir.analysis.verify_gpu_code").set_body_typed(VerifyGPUCode);

namespace transform {
Expand All @@ -245,7 +294,16 @@ Pass VerifyGPUCode(Map<String, PrimExpr> constraints) {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
auto func = GetRef<PrimFunc>(n);
CHECK(VerifyGPUCode(func, constraints)) << "RuntimeError: GPU constraint violated" << func;
auto errs = VerifyGPUCode_(func, constraints);
if (errs.size() != 0) {
std::stringstream s;
for (auto& err : errs) {
s << " " << err << std::endl;
}
LOG(FATAL) << "RuntimeError: GPU constraint(s) violated:\n"
<< s.str() << " In function\n"
<< func;
}
}
}
return mod;
Expand Down
56 changes: 30 additions & 26 deletions src/tir/analysis/verify_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,14 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
}

/// Verification result
bool Failed() const { return failure_; }
std::vector<String> Errors() const { return errs_; }

protected:
/// Visitor implementation
//@{
void VisitExpr(const PrimExpr& n) final {
if (Failed()) return;
StmtExprVisitor::VisitExpr(n);
}
void VisitExpr(const PrimExpr& n) final { StmtExprVisitor::VisitExpr(n); }

void VisitStmt(const Stmt& n) final {
if (Failed()) return;
StmtExprVisitor::VisitStmt(n);
}
void VisitStmt(const Stmt& n) final { StmtExprVisitor::VisitStmt(n); }

void VisitStmt_(const LetStmtNode* op) final {
// Book keep definitions
Expand Down Expand Up @@ -139,15 +133,18 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
if (!IsFromFunctionArgs(var.get())) return;

// The verification fails in this case.
SetFailure();
std::stringstream s;
s << "Variable `" << var
<< "` is directly accessed by host memory (it is not contained in a thread environment or in "
"the function arguments.";
errs_.push_back(s.str());
}

/// Status getter/setter
//@{
bool InThreadEnv() const { return in_thread_env_; }
void EnterThreadEnv() { in_thread_env_ = true; }
void ExitThreadEnv() { in_thread_env_ = false; }
void SetFailure() { failure_ = true; }
//@}

/// Check if a given DLDeviceType/TVMDeviceExtType value denotes GPU device.
Expand All @@ -162,7 +159,7 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
/// Status of visitor
//@{
bool in_thread_env_{false};
bool failure_{false}; ///< If the verification fails (i.e. has illegal access)
std::vector<String> errs_;
//@}
tir::PrimFunc func_{nullptr}; ///< Function to be verified.
int dev_type_{kDLCPU}; ///< Device type
Expand All @@ -171,38 +168,45 @@ class MemoryAccessVerifier final : protected StmtExprVisitor {
} // namespace

/// Interface of VerifyMemory pass
bool VerifyMemory(const PrimFunc& func) {
std::vector<String> VerifyMemory_(const PrimFunc& func) {
auto target = func->GetAttr<Target>(tvm::attr::kTarget);
CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute";

if (func->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
CallingConv::kDefault) {
MemoryAccessVerifier v(func, target.value()->kind->device_type);
v.Run();
return !v.Failed();
return v.Errors();
} else {
return true;
return {};
}
}

bool VerifyMemory(const PrimFunc& func) { return VerifyMemory_(func).size() == 0; }

TVM_REGISTER_GLOBAL("tir.analysis.verify_memory").set_body_typed(VerifyMemory);

namespace transform {

Pass VerifyMemory() {
auto pass_func =
[=](IRModule mod, PassContext ctx) {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
auto func = GetRef<PrimFunc>(n);
CHECK(VerifyMemory(func))
<< "RuntimeError: Direct host side access to device memory is detected."
<< " Did you forget to bind?\n"
<< func;
auto pass_func = [=](IRModule mod, PassContext ctx) {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
auto func = GetRef<PrimFunc>(n);
auto errs = VerifyMemory_(func);
if (errs.size() > 0) {
std::stringstream s;
for (auto& err : errs) {
s << " " << err << "\n";
}
LOG(FATAL) << "RuntimeError: Memory verification failed with the following errors:\n"
<< s.str() << " Did you forget to bind?\n"
<< func;
}
return mod;
};
}
}
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 0, "tir.VerifyMemory", {});
}

Expand Down