From 8246972d4dd86de564efe6a417aa00f9b11fbf0e Mon Sep 17 00:00:00 2001 From: Zhi <5145158+zhiics@users.noreply.github.com> Date: Sun, 12 Apr 2020 09:12:23 -0700 Subject: [PATCH] Remove PrimExpr from String (#5311) --- include/tvm/ir/expr.h | 6 ---- src/ir/expr.cc | 3 -- src/target/target.cc | 2 +- src/tir/ir/stmt.cc | 43 ++++++++++++++++------------- topi/include/topi/contrib/cublas.h | 4 +-- topi/include/topi/contrib/rocblas.h | 2 +- 6 files changed, 28 insertions(+), 32 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 4e0a301156a3..859a134cd5aa 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -108,12 +108,6 @@ class PrimExpr : public BaseExpr { */ TVM_DLL PrimExpr(float value); // NOLINT(*) - /*! - * \brief construct from runtime String. - * \param value The value to be constructed. - */ - TVM_DLL PrimExpr(runtime::String value); // NOLINT(*) - /*! \return the data type of this expression. */ DataType dtype() const { return static_cast(get())->dtype; diff --git a/src/ir/expr.cc b/src/ir/expr.cc index e08d832cabc9..7272213ad406 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -40,9 +40,6 @@ PrimExpr::PrimExpr(int32_t value) PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} -PrimExpr::PrimExpr(runtime::String value) - : PrimExpr(tir::StringImmNode::make(value)) {} - PrimExpr PrimExpr::FromObject_(ObjectRef ref) { using runtime::ObjectTypeChecker; if (auto* ptr = ref.as()) { diff --git a/src/target/target.cc b/src/target/target.cc index 61d5f6fe79ad..50856d62af30 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -137,7 +137,7 @@ Target CreateTarget(const std::string& target_name, } else if (target_name == "hybrid") { t->device_type = kDLCPU; } else if (target_name == "hexagon") { - t->keys_array.push_back(runtime::String("hexagon")); + t->keys_array.push_back("hexagon"); t->device_type = kDLHexagon; } else { LOG(ERROR) << "Unknown target name " << target_name; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 64e7ef572673..705fe7bdf26e 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -58,7 +58,6 @@ Stmt AttrStmtNode::make(ObjectRef node, TVM_REGISTER_GLOBAL("tir.AttrStmt") .set_body_typed(AttrStmtNode::make); - Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { CHECK(condition.defined()); CHECK(message.dtype() == DataType::Int(32) || @@ -74,8 +73,14 @@ Stmt AssertStmtNode::make(PrimExpr condition, PrimExpr message, Stmt body) { } TVM_REGISTER_GLOBAL("tir.AssertStmt") -.set_body_typed(AssertStmtNode::make); - +.set_body_typed([](PrimExpr condition, ObjectRef message, Stmt body) { + if (const auto* str = message.as()) { + auto msg = StringImmNode::make(str->data); + return AssertStmtNode::make(condition, msg, body); + } else { + return AssertStmtNode::make(condition, Downcast(message), body); + } +}); Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) { CHECK(body.defined()); @@ -92,11 +97,11 @@ TVM_REGISTER_GLOBAL("tir.ProducerConsumer") Stmt ForNode::make(Var loop_var, - PrimExpr min, - PrimExpr extent, - ForType for_type, - DeviceAPI device_api, - Stmt body) { + PrimExpr min, + PrimExpr extent, + ForType for_type, + DeviceAPI device_api, + Stmt body) { CHECK(min.defined()); CHECK(extent.defined()); CHECK(min.dtype().is_scalar()); @@ -119,11 +124,11 @@ TVM_REGISTER_GLOBAL("tir.For") Var loop_var, PrimExpr min, PrimExpr extent, int for_type, int device_api, Stmt body) { return ForNode::make(loop_var, - min, - extent, - static_cast(for_type), - static_cast(device_api), - body); + min, + extent, + static_cast(for_type), + static_cast(device_api), + body); }); @@ -176,12 +181,12 @@ TVM_REGISTER_GLOBAL("tir.Provide") Stmt AllocateNode::make(Var buffer_var, - DataType dtype, - Array extents, - PrimExpr condition, - Stmt body, - PrimExpr new_expr, - std::string free_function) { + DataType dtype, + Array extents, + PrimExpr condition, + Stmt body, + PrimExpr new_expr, + std::string free_function) { for (size_t i = 0; i < extents.size(); ++i) { CHECK(extents[i].defined()); CHECK(extents[i].dtype().is_scalar()); diff --git a/topi/include/topi/contrib/cublas.h b/topi/include/topi/contrib/cublas.h index ee18deae0781..f2ed029f5b33 100644 --- a/topi/include/topi/contrib/cublas.h +++ b/topi/include/topi/contrib/cublas.h @@ -53,7 +53,7 @@ inline Tensor cublas_matmul(const Tensor& lhs, { { n, m } }, { lhs->dtype }, { lhs, rhs }, [&](Array ins, Array outs) { return call_packed({ - runtime::String("tvm.contrib.cublas.matmul"), + StringImmNode::make("tvm.contrib.cublas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), @@ -85,7 +85,7 @@ inline Tensor cublas_batch_matmul(const Tensor& lhs, { { b, n, m } }, { lhs->dtype }, { lhs, rhs }, [&](Array ins, Array outs) { return call_packed({ - runtime::String("tvm.contrib.cublas.batch_matmul"), + StringImmNode::make("tvm.contrib.cublas.batch_matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]), diff --git a/topi/include/topi/contrib/rocblas.h b/topi/include/topi/contrib/rocblas.h index 9fe1825fe65e..f0bf92678f9a 100644 --- a/topi/include/topi/contrib/rocblas.h +++ b/topi/include/topi/contrib/rocblas.h @@ -52,7 +52,7 @@ inline Tensor rocblas_matmul(const Tensor& lhs, { { n, m } }, { lhs->dtype }, { lhs, rhs }, [&](Array ins, Array outs) { return call_packed({ - runtime::String("tvm.contrib.rocblas.matmul"), + StringImmNode::make("tvm.contrib.rocblas.matmul"), pack_buffer(ins[0]), pack_buffer(ins[1]), pack_buffer(outs[0]),