Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME][String] Overload string operators #5806

Merged
merged 5 commits into from
Jun 17, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 124 additions & 68 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -1175,72 +1175,6 @@ class String : public ObjectRef {
*/
inline String& operator=(const char* other);

/*!
* \brief Compare is less than other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator<(const std::string& other) const { return this->compare(other) < 0; }
bool operator<(const String& other) const { return this->compare(other) < 0; }
bool operator<(const char* other) const { return this->compare(other) < 0; }

/*!
* \brief Compare is greater than other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator>(const std::string& other) const { return this->compare(other) > 0; }
bool operator>(const String& other) const { return this->compare(other) > 0; }
bool operator>(const char* other) const { return this->compare(other) > 0; }

/*!
* \brief Compare is less than or equal to other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator<=(const std::string& other) const { return this->compare(other) <= 0; }
bool operator<=(const String& other) const { return this->compare(other) <= 0; }
bool operator<=(const char* other) const { return this->compare(other) <= 0; }

/*!
* \brief Compare is greater than or equal to other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator>=(const std::string& other) const { return this->compare(other) >= 0; }
bool operator>=(const String& other) const { return this->compare(other) >= 0; }
bool operator>=(const char* other) const { return this->compare(other) >= 0; }

/*!
* \brief Compare is equal to other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator==(const std::string& other) const { return this->compare(other) == 0; }
bool operator==(const String& other) const { return this->compare(other) == 0; }
bool operator==(const char* other) const { return compare(other) == 0; }

/*!
* \brief Compare is not equal to other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator!=(const std::string& other) const { return this->compare(other) != 0; }
bool operator!=(const String& other) const { return this->compare(other) != 0; }
bool operator!=(const char* other) const { return this->compare(other) != 0; }

/*!
* \brief Compares this String object to other
*
Expand Down Expand Up @@ -1372,6 +1306,31 @@ class String : public ObjectRef {
*/
static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count);

/*!
* \brief Concatenate two char sequences
*
* \param lhs Pointers to the lhs char array
* \param lhs_size The size of the lhs char array
* \param rhs Pointers to the rhs char array
* \param rhs_size The size of the rhs char array
*
* \return The concatenated char sequence
*/
static char* Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) {
char* concat = new char[lhs_size + rhs_size + 1];
tqchen marked this conversation as resolved.
Show resolved Hide resolved
std::memcpy(concat, lhs, lhs_size);
std::memcpy(concat + lhs_size, rhs, rhs_size);
concat[lhs_size + rhs_size] = '\0';
return concat;
}

// Overload + operator
friend String operator+(const String& lhs, const String& rhs);
friend String operator+(const String& lhs, const std::string& rhs);
friend String operator+(const std::string& lhs, const String& rhs);
friend String operator+(const String& lhs, const char* rhs);
friend String operator+(const char* lhs, const String& rhs);

friend struct tvm::ObjectEqual;
};

Expand Down Expand Up @@ -1410,10 +1369,107 @@ inline String& String::operator=(std::string other) {

inline String& String::operator=(const char* other) { return operator=(std::string(other)); }

inline String operator+(const std::string lhs, const String& rhs) {
return lhs + rhs.operator std::string();
inline String operator+(const String& lhs, const String& rhs) {
size_t lhs_size = lhs.size();
size_t rhs_size = rhs.size();
char* concat = String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
return String(concat);
}

inline String operator+(const String& lhs, const std::string& rhs) {
size_t lhs_size = lhs.size();
size_t rhs_size = rhs.size();
char* concat = String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
return String(concat);
}

inline String operator+(const std::string& lhs, const String& rhs) {
size_t lhs_size = lhs.size();
size_t rhs_size = rhs.size();
char* concat = String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
return String(concat);
}

inline String operator+(const char* lhs, const String& rhs) {
size_t lhs_size = std::strlen(lhs);
size_t rhs_size = rhs.size();
char* concat = String::Concat(lhs, lhs_size, rhs.data(), rhs_size);
return String(concat);
}

inline String operator+(const String& lhs, const char* rhs) {
size_t lhs_size = lhs.size();
size_t rhs_size = std::strlen(rhs);
char* concat = String::Concat(lhs.data(), lhs_size, rhs, rhs_size);
return String(concat);
}

// Overload < operator
inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; }
tqchen marked this conversation as resolved.
Show resolved Hide resolved

inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; }

inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; }

inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; }

inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; }

// Overload > operator
inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; }
zhiics marked this conversation as resolved.
Show resolved Hide resolved

inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; }

inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; }

inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; }

inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; }

// Overload <= operator
inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; }

inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }

inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; }

inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; }

inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }

// Overload >= operator
inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; }

inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; }

inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; }

inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; }

inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; }

// Overload == operator
inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; }

inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; }

inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; }

inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; }

inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; }

// Overload != operator
inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; }

inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; }

inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; }

inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; }

inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; }

inline std::ostream& operator<<(std::ostream& out, const String& input) {
out.write(input.data(), input.size());
return out;
Expand Down
2 changes: 1 addition & 1 deletion src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ void IRModuleNode::ImportFromStd(const String& path) {
auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
std::string std_path = (*f)();
this->Import(std_path + "/" + path.operator std::string());
this->Import(std_path + "/" + path);
}

std::unordered_set<String> IRModuleNode::Imports() const { return this->import_set_; }
Expand Down
4 changes: 1 addition & 3 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,7 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}

Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
return Doc::Text('@' + op->name_hint.operator std::string());
}
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text("@" + op->name_hint); }

Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }

Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Var Var::copy_with_suffix(const String& suffix) const {
} else {
new_ptr = make_object<VarNode>(*node);
}
new_ptr->name_hint = new_ptr->name_hint.operator std::string() + suffix.operator std::string();
new_ptr->name_hint = new_ptr->name_hint + suffix;
return Var(new_ptr);
}

Expand Down
48 changes: 48 additions & 0 deletions tests/cpp/container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -458,23 +458,54 @@ TEST(String, compare) {

// compare with string
CHECK_EQ(str_source.compare(source), 0);
CHECK(str_source == source);
CHECK(source == str_source);
CHECK(str_source <= source);
CHECK(source <= str_source);
CHECK(str_source >= source);
CHECK(source >= str_source);
CHECK_LT(str_source.compare(mismatch1), 0);
CHECK(str_source < mismatch1);
CHECK(mismatch1 != str_source);
CHECK_GT(str_source.compare(mismatch2), 0);
CHECK(str_source > mismatch2);
CHECK(mismatch2 < str_source);
CHECK_GT(str_source.compare(mismatch3), 0);
CHECK(str_source > mismatch3);
CHECK_LT(str_source.compare(mismatch4), 0);
CHECK(str_source < mismatch4);
CHECK(mismatch4 > str_source);

// compare with char*
CHECK_EQ(str_source.compare(source.data()), 0);
CHECK(str_source == source.data());
CHECK(source.data() == str_source);
CHECK(str_source <= source.data());
CHECK(source <= str_source.data());
CHECK(str_source >= source.data());
CHECK(source >= str_source.data());
CHECK_LT(str_source.compare(mismatch1.data()), 0);
CHECK(str_source < mismatch1.data());
CHECK(str_source != mismatch1.data());
CHECK(mismatch1.data() != str_source);
CHECK_GT(str_source.compare(mismatch2.data()), 0);
CHECK(str_source > mismatch2.data());
CHECK(mismatch2.data() < str_source);
CHECK_GT(str_source.compare(mismatch3.data()), 0);
CHECK(str_source > mismatch3.data());
CHECK_LT(str_source.compare(mismatch4.data()), 0);
CHECK(str_source < mismatch4.data());
CHECK(mismatch4.data() > str_source);

// compare with String
CHECK_LT(str_source.compare(str_mismatch1), 0);
CHECK(str_source < str_mismatch1);
CHECK_GT(str_source.compare(str_mismatch2), 0);
CHECK(str_source > str_mismatch2);
CHECK_GT(str_source.compare(str_mismatch3), 0);
CHECK(str_source > str_mismatch3);
CHECK_LT(str_source.compare(str_mismatch4), 0);
CHECK(str_source < str_mismatch4);
}

TEST(String, c_str) {
Expand Down Expand Up @@ -513,6 +544,23 @@ TEST(String, Cast) {
String s2 = Downcast<String>(r);
}

TEST(String, Concat) {
String s1("hello");
String s2("world");
std::string s3("world");
String res1 = s1 + s2;
String res2 = s1 + s3;
String res3 = s3 + s1;
String res4 = s1 + "world";
String res5 = "world" + s1;

CHECK_EQ(res1.compare("helloworld"), 0);
CHECK_EQ(res2.compare("helloworld"), 0);
CHECK_EQ(res3.compare("worldhello"), 0);
CHECK_EQ(res4.compare("helloworld"), 0);
CHECK_EQ(res5.compare("worldhello"), 0);
}

TEST(Optional, Composition) {
Optional<String> opt0(nullptr);
Optional<String> opt1 = String("xyz");
Expand Down