Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Object][FFI] Introduce runtime::String::CanConvertFrom #5718

Merged
merged 2 commits into from
Jun 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,15 @@ class String : public ObjectRef {
*/
operator std::string() const { return std::string{get()->data, size()}; }

/*!
* \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String
* \param val The value to be checked
* \return A boolean indicating if val can be converted to String
*/
static bool CanConvertFrom(const TVMArgValue& val) {
junrushao marked this conversation as resolved.
Show resolved Hide resolved
return val.type_code() == kTVMStr || val.IsObjectRef<tvm::runtime::String>();
}

/*!
* \brief Hash the binary bytes
* \param data The data pointer
Expand Down
2 changes: 1 addition & 1 deletion src/ir/attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ void DictAttrsNode::InitByPackedArgs(const runtime::TVMArgs& args, bool allow_un
runtime::TVMArgValue val = args[i + 1];
if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator ObjectRef());
} else if (val.type_code() == kTVMStr) {
} else if (String::CanConvertFrom(val)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@junrushao1994 : Sorry to bug you on this. I was wondering whether, the first if condition is always true, when the val is String objref type. Please check once. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your question! Hmmmm but I am not sure I understand if correctly. Just to confirm, "The first if condition", are you referring to "val.type_code() == kTVMStr" in CanConvertFrom? Thank you!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Python frontend is still able to pass runtime::String back to C++ via packed function. One case I found in our testcases is: https://github.com/apache/incubator-tvm/pull/5687/files#diff-5243f3909dfa484d85d3d8471f259169R422-R424

Copy link
Contributor

@ANSHUMAN87 ANSHUMAN87 Jun 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your question! Hmmmm but I am not sure I understand if correctly. Just to confirm, "The first if condition", are you referring to "val.type_code() == kTVMStr" in CanConvertFrom? Thank you!

Sorry for confusion.
The first if condition --> if (val.IsObjectRef()) might be true always for String objref, so maybe the condition inside String::CanConvertFrom(val) will not reach.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The Python frontend is still able to pass runtime::String back to C++ via packed function. One case I found in our testcases is: https://github.com/apache/incubator-tvm/pull/5687/files#diff-5243f3909dfa484d85d3d8471f259169R422-R424

I think that may be because of implicit conversion of String(Overloaded) to std::string!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yah I got what you mean, and thank you for the clarification @ANSHUMAN87! Totally agree with your points, but we want to unify the way we check "FFI -> string" conversion, so that future users won't end up messing them up. This is the reason why I made this change: After this PR, the only places that use raw kTVMStr are RPC and C-API.

Thank you again for the comments @zhiics @ANSHUMAN87!

dict.Set(key, val.operator String());
} else {
dict.Set(key, val.operator PrimExpr());
Expand Down
52 changes: 16 additions & 36 deletions src/node/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,29 +292,16 @@ TVM_REGISTER_REFLECTION_VTABLE(MapNode, MapNodeTrait)

TVM_REGISTER_GLOBAL("node.Map").set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0);
if (args.size() != 0 && args[0].type_code() == kTVMStr) {
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kTVMStr) << "key of str map need to be str";
CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of the map to be object";
data.emplace(
std::make_pair(String(args[i].operator std::string()), args[i + 1].operator ObjectRef()));
}
auto node = make_object<MapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
} else {
// Container node.
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].IsObjectRef<ObjectRef>()) << "key of map need to be object";
CHECK(args[i + 1].IsObjectRef<ObjectRef>()) << "value of map to be object";
data.emplace(std::make_pair(args[i].operator ObjectRef(), args[i + 1].operator ObjectRef()));
}
auto node = make_object<MapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
MapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) {
ObjectRef k =
String::CanConvertFrom(args[i]) ? args[i].operator String() : args[i].operator ObjectRef();
ObjectRef v = args[i + 1];
data.emplace(std::move(k), std::move(v));
}
auto node = make_object<MapNode>();
node->data = std::move(data);
*ret = Map<ObjectRef, ObjectRef>(node);
});

TVM_REGISTER_GLOBAL("node.MapSize").set_body([](TVMArgs args, TVMRetValue* ret) {
Expand All @@ -331,27 +318,20 @@ TVM_REGISTER_GLOBAL("node.MapGetItem").set_body([](TVMArgs args, TVMRetValue* re
CHECK(ptr->IsInstance<MapNode>());

auto* n = static_cast<const MapNode*>(ptr);
if (args[1].type_code() == kTVMStr) {
auto it = n->data.find(String(args[1].operator std::string()));
CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
*ret = (*it).second;
} else {
auto it = n->data.find(args[1].operator ObjectRef());
CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
*ret = (*it).second;
}
auto it = n->data.find(String::CanConvertFrom(args[1]) ? args[1].operator String()
: args[1].operator ObjectRef());
CHECK(it != n->data.end()) << "cannot find the corresponding key in the Map";
*ret = (*it).second;
});

TVM_REGISTER_GLOBAL("node.MapCount").set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle);
CHECK(ptr->IsInstance<MapNode>());
const MapNode* n = static_cast<const MapNode*>(ptr);
if (args[1].type_code() == kTVMStr) {
*ret = static_cast<int64_t>(n->data.count(String(args[1].operator std::string())));
} else {
*ret = static_cast<int64_t>(n->data.count(args[1].operator ObjectRef()));
}
int64_t cnt = n->data.count(String::CanConvertFrom(args[1]) ? args[1].operator String()
: args[1].operator ObjectRef());
*ret = cnt;
});

TVM_REGISTER_GLOBAL("node.MapItems").set_body([](TVMArgs args, TVMRetValue* ret) {
Expand Down
9 changes: 9 additions & 0 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,15 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
return val->data == rhs.operator std::string();
}
break;
case kTVMObjectHandle:
if (rhs.IsObjectRef<String>()) {
if (auto* val = lhs.as<tir::StringImmNode>()) {
return rhs.operator String() == val->value;
} else if (auto* val = lhs.as<StringObj>()) {
return rhs.operator String() == val->data;
}
}
break;
default:
CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code();
}
Expand Down
3 changes: 2 additions & 1 deletion src/runtime/graph/debug/graph_runtime_debug.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
/*!
* \file graph_runtime_debug.cc
*/
#include <tvm/runtime/container.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
Expand Down Expand Up @@ -173,7 +174,7 @@ PackedFunc GraphRuntimeDebug::GetFunction(const std::string& name,
});
} else if (name == "debug_get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kTVMStr) {
if (String::CanConvertFrom(args[0])) {
this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
} else {
this->DebugGetNodeOutput(args[0], args[1]);
Expand Down
15 changes: 6 additions & 9 deletions src/runtime/graph/graph_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,17 +390,17 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
// Return member functions during query.
if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kTVMStr) {
int in_idx = this->GetInputIndex(args[0]);
if (String::CanConvertFrom(args[0])) {
int in_idx = this->GetInputIndex(args[0].operator String());
if (in_idx >= 0) this->SetInput(in_idx, args[1]);
} else {
this->SetInput(args[0], args[1]);
}
});
} else if (name == "set_input_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kTVMStr) {
int in_idx = this->GetInputIndex(args[0]);
if (String::CanConvertFrom(args[0])) {
int in_idx = this->GetInputIndex(args[0].operator String());
if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
} else {
this->SetInputZeroCopy(args[0], args[1]);
Expand All @@ -417,11 +417,8 @@ PackedFunc GraphRuntime::GetFunction(const std::string& name,
} else if (name == "get_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int in_idx = 0;
if (args[0].type_code() == kTVMStr) {
in_idx = this->GetInputIndex(args[0]);
} else if (args[0].IsObjectRef<runtime::String>()) {
auto str = args[0].AsObjectRef<runtime::String>();
in_idx = this->GetInputIndex(str);
if (String::CanConvertFrom(args[0])) {
in_idx = this->GetInputIndex(args[0].operator String());
} else {
in_idx = args[0];
}
Expand Down