Skip to content

Commit

Permalink
[RUNTIME][String] Overload string operators (#5806)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored Jun 17, 2020
1 parent def496d commit 7f37eb4
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 73 deletions.
185 changes: 117 additions & 68 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -1175,72 +1175,6 @@ class String : public ObjectRef {
*/
inline String& operator=(const char* other);

/*!
* \brief Compare is less than other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator<(const std::string& other) const { return this->compare(other) < 0; }
bool operator<(const String& other) const { return this->compare(other) < 0; }
bool operator<(const char* other) const { return this->compare(other) < 0; }

/*!
* \brief Compare is greater than other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator>(const std::string& other) const { return this->compare(other) > 0; }
bool operator>(const String& other) const { return this->compare(other) > 0; }
bool operator>(const char* other) const { return this->compare(other) > 0; }

/*!
* \brief Compare is less than or equal to other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator<=(const std::string& other) const { return this->compare(other) <= 0; }
bool operator<=(const String& other) const { return this->compare(other) <= 0; }
bool operator<=(const char* other) const { return this->compare(other) <= 0; }

/*!
* \brief Compare is greater than or equal to other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator>=(const std::string& other) const { return this->compare(other) >= 0; }
bool operator>=(const String& other) const { return this->compare(other) >= 0; }
bool operator>=(const char* other) const { return this->compare(other) >= 0; }

/*!
* \brief Compare is equal to other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator==(const std::string& other) const { return this->compare(other) == 0; }
bool operator==(const String& other) const { return this->compare(other) == 0; }
bool operator==(const char* other) const { return compare(other) == 0; }

/*!
* \brief Compare is not equal to other std::string
*
* \param other The other string
*
* \return the comparison result
*/
bool operator!=(const std::string& other) const { return this->compare(other) != 0; }
bool operator!=(const String& other) const { return this->compare(other) != 0; }
bool operator!=(const char* other) const { return this->compare(other) != 0; }

/*!
* \brief Compares this String object to other
*
Expand Down Expand Up @@ -1372,6 +1306,29 @@ class String : public ObjectRef {
*/
static int memncmp(const char* lhs, const char* rhs, size_t lhs_count, size_t rhs_count);

/*!
* \brief Concatenate two char sequences
*
* \param lhs Pointers to the lhs char array
* \param lhs_size The size of the lhs char array
* \param rhs Pointers to the rhs char array
* \param rhs_size The size of the rhs char array
*
* \return The concatenated char sequence
*/
static String Concat(const char* lhs, size_t lhs_size, const char* rhs, size_t rhs_size) {
std::string ret(lhs, lhs_size);
ret.append(rhs, rhs_size);
return String(ret);
}

// Overload + operator
friend String operator+(const String& lhs, const String& rhs);
friend String operator+(const String& lhs, const std::string& rhs);
friend String operator+(const std::string& lhs, const String& rhs);
friend String operator+(const String& lhs, const char* rhs);
friend String operator+(const char* lhs, const String& rhs);

friend struct tvm::ObjectEqual;
};

Expand Down Expand Up @@ -1410,10 +1367,102 @@ inline String& String::operator=(std::string other) {

inline String& String::operator=(const char* other) { return operator=(std::string(other)); }

inline String operator+(const std::string lhs, const String& rhs) {
return lhs + rhs.operator std::string();
inline String operator+(const String& lhs, const String& rhs) {
size_t lhs_size = lhs.size();
size_t rhs_size = rhs.size();
return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
}

inline String operator+(const String& lhs, const std::string& rhs) {
size_t lhs_size = lhs.size();
size_t rhs_size = rhs.size();
return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
}

inline String operator+(const std::string& lhs, const String& rhs) {
size_t lhs_size = lhs.size();
size_t rhs_size = rhs.size();
return String::Concat(lhs.data(), lhs_size, rhs.data(), rhs_size);
}

inline String operator+(const char* lhs, const String& rhs) {
size_t lhs_size = std::strlen(lhs);
size_t rhs_size = rhs.size();
return String::Concat(lhs, lhs_size, rhs.data(), rhs_size);
}

inline String operator+(const String& lhs, const char* rhs) {
size_t lhs_size = lhs.size();
size_t rhs_size = std::strlen(rhs);
return String::Concat(lhs.data(), lhs_size, rhs, rhs_size);
}

// Overload < operator
inline bool operator<(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) < 0; }

inline bool operator<(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) > 0; }

inline bool operator<(const String& lhs, const String& rhs) { return lhs.compare(rhs) < 0; }

inline bool operator<(const String& lhs, const char* rhs) { return lhs.compare(rhs) < 0; }

inline bool operator<(const char* lhs, const String& rhs) { return rhs.compare(lhs) > 0; }

// Overload > operator
inline bool operator>(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) > 0; }

inline bool operator>(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) < 0; }

inline bool operator>(const String& lhs, const String& rhs) { return lhs.compare(rhs) > 0; }

inline bool operator>(const String& lhs, const char* rhs) { return lhs.compare(rhs) > 0; }

inline bool operator>(const char* lhs, const String& rhs) { return rhs.compare(lhs) < 0; }

// Overload <= operator
inline bool operator<=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) <= 0; }

inline bool operator<=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }

inline bool operator<=(const String& lhs, const String& rhs) { return lhs.compare(rhs) <= 0; }

inline bool operator<=(const String& lhs, const char* rhs) { return lhs.compare(rhs) <= 0; }

inline bool operator<=(const char* lhs, const String& rhs) { return rhs.compare(lhs) >= 0; }

// Overload >= operator
inline bool operator>=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) >= 0; }

inline bool operator>=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) <= 0; }

inline bool operator>=(const String& lhs, const String& rhs) { return lhs.compare(rhs) >= 0; }

inline bool operator>=(const String& lhs, const char* rhs) { return lhs.compare(rhs) >= 0; }

inline bool operator>=(const char* lhs, const String& rhs) { return rhs.compare(rhs) <= 0; }

// Overload == operator
inline bool operator==(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) == 0; }

inline bool operator==(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) == 0; }

inline bool operator==(const String& lhs, const String& rhs) { return lhs.compare(rhs) == 0; }

inline bool operator==(const String& lhs, const char* rhs) { return lhs.compare(rhs) == 0; }

inline bool operator==(const char* lhs, const String& rhs) { return rhs.compare(lhs) == 0; }

// Overload != operator
inline bool operator!=(const String& lhs, const std::string& rhs) { return lhs.compare(rhs) != 0; }

inline bool operator!=(const std::string& lhs, const String& rhs) { return rhs.compare(lhs) != 0; }

inline bool operator!=(const String& lhs, const String& rhs) { return lhs.compare(rhs) != 0; }

inline bool operator!=(const String& lhs, const char* rhs) { return lhs.compare(rhs) != 0; }

inline bool operator!=(const char* lhs, const String& rhs) { return rhs.compare(lhs) != 0; }

inline std::ostream& operator<<(std::ostream& out, const String& input) {
out.write(input.data(), input.size());
return out;
Expand Down
2 changes: 1 addition & 1 deletion src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ void IRModuleNode::ImportFromStd(const String& path) {
auto* f = tvm::runtime::Registry::Get("tvm.relay.std_path");
CHECK(f != nullptr) << "The Relay std_path is not set, please register tvm.relay.std_path.";
std::string std_path = (*f)();
this->Import(std_path + "/" + path.operator std::string());
this->Import(std_path + "/" + path);
}

std::unordered_set<String> IRModuleNode::Imports() const { return this->import_set_; }
Expand Down
4 changes: 1 addition & 3 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -456,9 +456,7 @@ Doc RelayTextPrinter::VisitExpr_(const FunctionNode* op) {
return PrintFunc(Doc::Text("fn "), GetRef<Function>(op));
}

Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) {
return Doc::Text('@' + op->name_hint.operator std::string());
}
Doc RelayTextPrinter::VisitExpr_(const GlobalVarNode* op) { return Doc::Text("@" + op->name_hint); }

Doc RelayTextPrinter::VisitExpr_(const OpNode* op) { return Doc::Text(op->name); }

Expand Down
2 changes: 1 addition & 1 deletion src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ Var Var::copy_with_suffix(const String& suffix) const {
} else {
new_ptr = make_object<VarNode>(*node);
}
new_ptr->name_hint = new_ptr->name_hint.operator std::string() + suffix.operator std::string();
new_ptr->name_hint = new_ptr->name_hint + suffix;
return Var(new_ptr);
}

Expand Down
48 changes: 48 additions & 0 deletions tests/cpp/container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -458,23 +458,54 @@ TEST(String, compare) {

// compare with string
CHECK_EQ(str_source.compare(source), 0);
CHECK(str_source == source);
CHECK(source == str_source);
CHECK(str_source <= source);
CHECK(source <= str_source);
CHECK(str_source >= source);
CHECK(source >= str_source);
CHECK_LT(str_source.compare(mismatch1), 0);
CHECK(str_source < mismatch1);
CHECK(mismatch1 != str_source);
CHECK_GT(str_source.compare(mismatch2), 0);
CHECK(str_source > mismatch2);
CHECK(mismatch2 < str_source);
CHECK_GT(str_source.compare(mismatch3), 0);
CHECK(str_source > mismatch3);
CHECK_LT(str_source.compare(mismatch4), 0);
CHECK(str_source < mismatch4);
CHECK(mismatch4 > str_source);

// compare with char*
CHECK_EQ(str_source.compare(source.data()), 0);
CHECK(str_source == source.data());
CHECK(source.data() == str_source);
CHECK(str_source <= source.data());
CHECK(source <= str_source.data());
CHECK(str_source >= source.data());
CHECK(source >= str_source.data());
CHECK_LT(str_source.compare(mismatch1.data()), 0);
CHECK(str_source < mismatch1.data());
CHECK(str_source != mismatch1.data());
CHECK(mismatch1.data() != str_source);
CHECK_GT(str_source.compare(mismatch2.data()), 0);
CHECK(str_source > mismatch2.data());
CHECK(mismatch2.data() < str_source);
CHECK_GT(str_source.compare(mismatch3.data()), 0);
CHECK(str_source > mismatch3.data());
CHECK_LT(str_source.compare(mismatch4.data()), 0);
CHECK(str_source < mismatch4.data());
CHECK(mismatch4.data() > str_source);

// compare with String
CHECK_LT(str_source.compare(str_mismatch1), 0);
CHECK(str_source < str_mismatch1);
CHECK_GT(str_source.compare(str_mismatch2), 0);
CHECK(str_source > str_mismatch2);
CHECK_GT(str_source.compare(str_mismatch3), 0);
CHECK(str_source > str_mismatch3);
CHECK_LT(str_source.compare(str_mismatch4), 0);
CHECK(str_source < str_mismatch4);
}

TEST(String, c_str) {
Expand Down Expand Up @@ -513,6 +544,23 @@ TEST(String, Cast) {
String s2 = Downcast<String>(r);
}

TEST(String, Concat) {
String s1("hello");
String s2("world");
std::string s3("world");
String res1 = s1 + s2;
String res2 = s1 + s3;
String res3 = s3 + s1;
String res4 = s1 + "world";
String res5 = "world" + s1;

CHECK_EQ(res1.compare("helloworld"), 0);
CHECK_EQ(res2.compare("helloworld"), 0);
CHECK_EQ(res3.compare("worldhello"), 0);
CHECK_EQ(res4.compare("helloworld"), 0);
CHECK_EQ(res5.compare("worldhello"), 0);
}

TEST(Optional, Composition) {
Optional<String> opt0(nullptr);
Optional<String> opt1 = String("xyz");
Expand Down

0 comments on commit 7f37eb4

Please sign in to comment.