Skip to content

Commit

Permalink
Remove PrimExpr from String (apache#5311)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Apr 17, 2020
1 parent 9938a66 commit 8246972
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 32 deletions.
6 changes: 0 additions & 6 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,6 @@ class PrimExpr : public BaseExpr {
*/
TVM_DLL PrimExpr(float value); // NOLINT(*)

/*!
* \brief construct from runtime String.
* \param value The value to be constructed.
*/
TVM_DLL PrimExpr(runtime::String value); // NOLINT(*)

/*! \return the data type of this expression. */
DataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
Expand Down
3 changes: 0 additions & 3 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ PrimExpr::PrimExpr(int32_t value)
PrimExpr::PrimExpr(float value)
: PrimExpr(FloatImm(DataType::Float(32), value)) {}

PrimExpr::PrimExpr(runtime::String value)
: PrimExpr(tir::StringImmNode::make(value)) {}

PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
using runtime::ObjectTypeChecker;
if (auto* ptr = ref.as<tir::IterVarNode>()) {
Expand Down
2 changes: 1 addition & 1 deletion src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ Target CreateTarget(const std::string& target_name,
} else if (target_name == "hybrid") {
t->device_type = kDLCPU;
} else if (target_name == "hexagon") {
t->keys_array.push_back(runtime::String("hexagon"));
t->keys_array.push_back("hexagon");
t->device_type = kDLHexagon;
} else {
LOG(ERROR) << "Unknown target name " << target_name;
Expand Down
43 changes: 24 additions & 19 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ Stmt AttrStmtNode::make(ObjectRef node,
TVM_REGISTER_GLOBAL("tir.AttrStmt")
.set_body_typed(AttrStmtNode::make);


Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
CHECK(condition.defined());
CHECK(message.dtype() == DataType::Int(32) ||
Expand All @@ -74,8 +73,14 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) {
}

TVM_REGISTER_GLOBAL("tir.AssertStmt")
.set_body_typed(AssertStmtNode::make);

.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) {
if (const auto* str = message.as<StringObj>()) {
auto msg = StringImmNode::make(str->data);
return AssertStmtNode::make(condition, msg, body);
} else {
return AssertStmtNode::make(condition, Downcast<PrimExpr>(message), body);
}
});

Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) {
CHECK(body.defined());
Expand All @@ -92,11 +97,11 @@ TVM_REGISTER_GLOBAL("tir.ProducerConsumer")


Stmt ForNode::make(Var loop_var,
PrimExpr min,
PrimExpr extent,
ForType for_type,
DeviceAPI device_api,
Stmt body) {
PrimExpr min,
PrimExpr extent,
ForType for_type,
DeviceAPI device_api,
Stmt body) {
CHECK(min.defined());
CHECK(extent.defined());
CHECK(min.dtype().is_scalar());
Expand All @@ -119,11 +124,11 @@ TVM_REGISTER_GLOBAL("tir.For")
Var loop_var, PrimExpr min, PrimExpr extent,
int for_type, int device_api, Stmt body) {
return ForNode::make(loop_var,
min,
extent,
static_cast<ForType>(for_type),
static_cast<DeviceAPI>(device_api),
body);
min,
extent,
static_cast<ForType>(for_type),
static_cast<DeviceAPI>(device_api),
body);
});


Expand Down Expand Up @@ -176,12 +181,12 @@ TVM_REGISTER_GLOBAL("tir.Provide")


Stmt AllocateNode::make(Var buffer_var,
DataType dtype,
Array<PrimExpr> extents,
PrimExpr condition,
Stmt body,
PrimExpr new_expr,
std::string free_function) {
DataType dtype,
Array<PrimExpr> extents,
PrimExpr condition,
Stmt body,
PrimExpr new_expr,
std::string free_function) {
for (size_t i = 0; i < extents.size(); ++i) {
CHECK(extents[i].defined());
CHECK(extents[i].dtype().is_scalar());
Expand Down
4 changes: 2 additions & 2 deletions topi/include/topi/contrib/cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs,
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
runtime::String("tvm.contrib.cublas.matmul"),
StringImmNode::make("tvm.contrib.cublas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
Expand Down Expand Up @@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs,
{ { b, n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
runtime::String("tvm.contrib.cublas.batch_matmul"),
StringImmNode::make("tvm.contrib.cublas.batch_matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
Expand Down
2 changes: 1 addition & 1 deletion topi/include/topi/contrib/rocblas.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs,
{ { n, m } }, { lhs->dtype }, { lhs, rhs },
[&](Array<Buffer> ins, Array<Buffer> outs) {
return call_packed({
runtime::String("tvm.contrib.rocblas.matmul"),
StringImmNode::make("tvm.contrib.rocblas.matmul"),
pack_buffer(ins[0]),
pack_buffer(ins[1]),
pack_buffer(outs[0]),
Expand Down

0 comments on commit 8246972

Please sign in to comment.