Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CODEGEN] Fix code generation bugs for C/CUDA & Improve VerifyGPUCode pass #6041

Merged
merged 3 commits into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -629,12 +629,12 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
this->PrintExpr(op->args[0], os);
os << " == NULL)";
} else if (op->op.same_as(builtin::reinterpret())) {
// generate (*( TYPE *)(&(ARG)))
int ssa_scope = BeginScope();
std::string rhs = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
os << "(*(";
this->PrintType(op->dtype, os);
os << " *)(&(";
this->PrintExpr(op->args[0], os);
os << ")))";
os << " *)(&(" << rhs << ")))";
EndScope(ssa_scope);
} else if (op->op.same_as(builtin::isnan())) {
os << "(";
this->PrintExpr(op->args[0], os);
Expand Down Expand Up @@ -720,14 +720,15 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
} else {
CHECK(is_one(op->predicate)) << "Predicated store is not supported";
arith::PVar<PrimExpr> base;

// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!
Do we need to add some UTs to cover these cases?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test case added

int vec_scope = BeginScope();

if (arith::ramp(base, 1, t.lanes()).Match(op->index)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base.Eval(), value);
} else {
// The assignment below introduces side-effect, and the resulting value cannot
// be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();

// store elements seperately
std::string index = SSAGetID(PrintExpr(op->index), op->index.dtype());
std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
Expand All @@ -754,8 +755,8 @@ void CodeGenC::VisitStmt_(const StoreNode* op) {
PrintVecElemLoad(value, op->value.dtype(), i, stream);
stream << ";\n";
}
EndScope(vec_scope);
}
EndScope(vec_scope);
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/tir/analysis/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,18 @@ class GPUCodeVerifier : public StmtExprVisitor {
ExprVisitor::VisitExpr_(op);
}

void VisitStmt_(const StoreNode* op) {
// Currently not able to check out: If the index expression failed
// to be simplified to a RampNode
if (op->index->IsInstance<RampNode>()) {
if (op->index->dtype.lanes() > 1) {
valid_ &= static_cast<size_t>(op->index->dtype.lanes() * op->index->dtype.bytes()) <=
max_vector_bytes_;
}
}
StmtVisitor::VisitStmt_(op);
}

private:
int nest_level_{0};

Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_target_codegen_c_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_reinterpret():
nn = 1024
n = tvm.runtime.convert(nn)
A = te.placeholder((n,), name='A', dtype="int32")
B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", A(*i)), name='B')
B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin("float32", "tir.reinterpret", 2 + A(*i)), name='B')
s = te.create_schedule(B.op)

def check_c():
Expand All @@ -114,7 +114,7 @@ def check_c():
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fadd(a, b)
tvm.testing.assert_allclose(
b.asnumpy(), a.asnumpy().view('float32'))
b.asnumpy(), (2 + a.asnumpy()).view('float32'))
check_c()


Expand Down