Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Refactor][std::string --> String] Relay updated with String #5578

Merged
merged 7 commits into from
May 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/relay/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class IdNode : public Object {
* this only acts as a hint to the user,
* and is not used for equality.
*/
std::string name_hint;
String name_hint;
ANSHUMAN87 marked this conversation as resolved.
Show resolved Hide resolved

void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); }

Expand All @@ -107,7 +107,7 @@ class Id : public ObjectRef {
* \brief The constructor
* \param name_hint The name of the variable.
*/
TVM_DLL explicit Id(std::string name_hint);
TVM_DLL explicit Id(String name_hint);

TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
};
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ class VarNode : public ExprNode {
Type type_annotation;

/*! \return The name hint of the variable */
const std::string& name_hint() const { return vid->name_hint; }
const String& name_hint() const { return vid->name_hint; }

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("vid", &vid);
Expand All @@ -188,7 +188,7 @@ class VarNode : public ExprNode {
hash_reduce.FreeVarHashImpl(this);
}

TVM_DLL static Var make(std::string name_hint, Type type_annotation);
TVM_DLL static Var make(String name_hint, Type type_annotation);

TVM_DLL static Var make(Id vid, Type type_annotation);

Expand All @@ -203,7 +203,7 @@ class Var : public Expr {
* \param name_hint The name hint of a variable.
* \param type_annotation The type annotation of a variable.
*/
TVM_DLL Var(std::string name_hint, Type type_annotation) : Var(Id(name_hint), type_annotation) {}
TVM_DLL Var(String name_hint, Type type_annotation) : Var(Id(name_hint), type_annotation) {}

/*!
* \brief The constructor
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relay/op_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class OpSpecialization : public ObjectRef {
* \param name Name of the implementation
* \param plevel Priority level of the implementation
*/
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name,
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name,
int plevel);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);
Expand Down Expand Up @@ -150,7 +150,7 @@ class OpStrategy : public ObjectRef {
* \param name Name of the implementation
* \param plevel Priority level of the implementation
*/
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name,
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name,
int plevel);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ using Sequential = tvm::transform::Sequential;
*/
TVM_DLL Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, const std::string& name, const tvm::Array<runtime::String>& required);
int opt_level, const String& name, const tvm::Array<runtime::String>& required);

/*! \brief Remove expressions which does not effect the program result.
*
Expand Down Expand Up @@ -298,7 +298,7 @@ TVM_DLL Pass ConvertLayout(const Map<std::string, Array<String>>& desired_layout
*
* \return The pass.
*/
TVM_DLL Pass Legalize(const std::string& legalize_map_attr_name = "FTVMLegalize");
TVM_DLL Pass Legalize(const String& legalize_map_attr_name = "FTVMLegalize");

/*!
* \brief Canonicalize cast expressions to make operator fusion more efficient.
Expand Down Expand Up @@ -387,7 +387,7 @@ TVM_DLL Function InferType(const Function& f, const IRModule& mod, const GlobalV
* an Expr consumed by multiple callers.
* \return The rewritten expression.
*/
TVM_DLL Expr ForwardRewrite(const Expr& expr, const std::string& rewrite_map_attr_name,
TVM_DLL Expr ForwardRewrite(const Expr& expr, const String& rewrite_map_attr_name,
std::function<ObjectRef(const Call&)> fcontext = nullptr,
std::function<Expr(const Expr&)> fmulti_ref_trigger = nullptr);

Expand Down
1 change: 1 addition & 0 deletions python/tvm/ir/json_compact.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _convert(item, nodes):
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.Id": [_update_from_std_str("name_hint")],
"relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.Type": _rename("Type"),
"relay.TupleType": _rename("TupleType"),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def __init__(self, name_hint, type_annotation=None):
@property
def name_hint(self):
"""Get name hint of the current var."""
name = self.vid.name_hint
name = str(self.vid.name_hint)
return name


Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ class CompileEngineImpl : public CompileEngineNode {
auto symbol_name = src_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
CHECK(symbol_name.defined()) << "No external symbol is set for:\n"
<< AsText(src_func, false);
auto gv = GlobalVar(std::string(symbol_name.value()));
auto gv = GlobalVar(symbol_name.value());
// No need to keep compiler attribute at this point, functions have been
// extracted for specific codegen.
src_func = WithAttr(std::move(src_func), attr::kCompiler, NullValue<ObjectRef>());
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ using namespace tvm::runtime;

TVM_REGISTER_NODE_TYPE(IdNode);

Id::Id(std::string name_hint) {
Id::Id(String name_hint) {
ObjectPtr<IdNode> n = make_object<IdNode>();
n->name_hint = std::move(name_hint);
data_ = std::move(n);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ Var::Var(Id vid, Type type_annotation) {

TVM_REGISTER_NODE_TYPE(VarNode);

TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](std::string str, Type type_annotation) {
TVM_REGISTER_GLOBAL("relay.ir.Var").set_body_typed([](String str, Type type_annotation) {
return Var(str, type_annotation);
});

Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/op_strategy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ te::Schedule OpImplementation::Schedule(const Attrs& attrs, const Array<te::Tens
}

void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute,
tvm::relay::FTVMSchedule fschedule, std::string name,
tvm::relay::FTVMSchedule fschedule, String name,
int plevel) {
auto n = make_object<OpImplementationNode>();
n->fcompute = fcompute;
Expand All @@ -52,7 +52,7 @@ void OpSpecialization::AddImplementation(tvm::relay::FTVMCompute fcompute,
(*this)->implementations.push_back(OpImplementation(n));
}

void OpStrategy::AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, std::string name,
void OpStrategy::AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule, String name,
int plevel) {
auto curr_cond = te::SpecializedCondition::Current();
auto self = this->operator->();
Expand Down
2 changes: 1 addition & 1 deletion src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ bool FunctionPassNode::SkipFunction(const Function& func) const {

Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, const std::string& name, const tvm::Array<runtime::String>& required) {
int opt_level, const String& name, const tvm::Array<runtime::String>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return FunctionPass(pass_func, pass_info);
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/algorithm/topk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ bool TopKRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}

Expr MakeTopK(Expr data, int k, int axis, std::string ret_type, bool is_ascend, DataType dtype) {
Expr MakeTopK(Expr data, int k, int axis, String ret_type, bool is_ascend, DataType dtype) {
auto attrs = make_object<TopKAttrs>();
attrs->k = k;
attrs->axis = axis;
Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/annotation/annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ Beginning of a region that is handled by a given compiler.
});

TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
.set_body_typed([](Expr expr, std::string compiler) {
.set_body_typed([](Expr expr, String compiler) {
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_begin");
Expand All @@ -207,7 +207,7 @@ End of a region that is handled by a given compiler.
});

TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
.set_body_typed([](Expr expr, std::string compiler) {
.set_body_typed([](Expr expr, String compiler) {
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_end");
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/debug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ RELAY_REGISTER_OP("debug")
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<FTVMCompute>("FTVMCompute", DebugCompute);

Expr MakeDebug(Expr expr, std::string name) {
Expr MakeDebug(Expr expr, String name) {
auto dattrs = make_object<DebugAttrs>();
if (name.size() > 0) {
dattrs->debug_func = EnvFunc::Get(name);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/image/dilation2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ Array<Array<Layout> > Dilation2DInferCorrectLayout(const Attrs& attrs,
// Positional relay function to create dilation2d operator
// used by frontend FFI.
Expr MakeDilation2D(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> dilations, std::string data_layout, std::string kernel_layout,
Array<IndexExpr> dilations, String data_layout, String kernel_layout,
DataType out_dtype) {
auto attrs = make_object<Dilation2DAttrs>();
attrs->strides = std::move(strides);
Expand Down
6 changes: 3 additions & 3 deletions src/relay/op/image/resize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

// Positional relay function to create image operator
// used by frontend FFI.
Expr MakeResize(Expr data, Array<IndexExpr> size, std::string layout, std::string method,
std::string coordinate_transformation_mode, DataType out_dtype) {
Expr MakeResize(Expr data, Array<IndexExpr> size, String layout, String method,
String coordinate_transformation_mode, DataType out_dtype) {
auto attrs = make_object<ResizeAttrs>();
attrs->size = std::move(size);
attrs->layout = std::move(layout);
Expand Down Expand Up @@ -133,7 +133,7 @@ bool CropAndResizeRel(const Array<Type>& types, int num_inputs, const Attrs& att
}

Expr MakeCropAndResize(Expr data, Expr boxes, Expr box_indices, Array<IndexExpr> crop_size,
std::string layout, std::string method, double extrapolation_value,
String layout, String method, double extrapolation_value,
DataType out_dtype) {
auto attrs = make_object<CropAndResizeAttrs>();
attrs->crop_size = std::move(crop_size);
Expand Down
4 changes: 2 additions & 2 deletions src/relay/op/nn/bitserial.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}

Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack_type,
std::string name) {
String name) {
auto attrs = make_object<BitPackAttrs>();
attrs->bits = bits;
attrs->pack_axis = pack_axis;
Expand Down Expand Up @@ -150,7 +150,7 @@ bool BinaryConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
// used by frontend FFI.
Expr MakeBinaryConv2D(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
IndexExpr channels, Array<IndexExpr> kernel_size, int activation_bits,
int weight_bits, std::string data_layout, std::string kernel_layout,
int weight_bits, String data_layout, String kernel_layout,
DataType pack_dtype, DataType out_dtype, bool unipolar) {
auto attrs = make_object<BinaryConv2DAttrs>();
attrs->strides = std::move(strides);
Expand Down
Loading