Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Target] Improve string interpretation in Target creation #12152

Merged
merged 1 commit into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 188 additions & 70 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/tag.h>
#include <tvm/target/target.h>
Expand All @@ -30,8 +31,13 @@

#include <algorithm>
#include <cctype>
#include <cstring>
#include <ios>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../runtime/object_internal.h"

Expand Down Expand Up @@ -62,6 +68,17 @@ class TargetInternal {

private:
static std::unordered_map<String, ObjectRef> QueryDevice(int device_id, const TargetNode* target);
static bool IsQuoted(const std::string& str);
static std::string Quote(const std::string& str);
static std::string JoinString(const std::vector<std::string>& array, char separator);
static std::vector<std::string> SplitString(const std::string& str, char separator);
static std::string Interpret(const std::string& str);
static std::string Uninterpret(const std::string& str);
static std::string StringifyAtomicType(const ObjectRef& obj);
static std::string StringifyArray(const ArrayNode& array);

static constexpr char quote = '\'';
static constexpr char escape = '\\';
};

/********** Helper functions **********/
Expand Down Expand Up @@ -135,48 +152,50 @@ static std::string RemovePrefixDashes(const std::string& s) {
return s.substr(n_dashes);
}

static int FindFirstSubstr(const std::string& str, const std::string& substr) {
size_t pos = str.find_first_of(substr);
return pos == std::string::npos ? -1 : pos;
}

static Optional<String> JoinString(const std::vector<String>& array, char separator) {
char escape = '\\';
char quote = '\'';

if (array.empty()) {
return NullOpt;
bool TargetInternal::IsQuoted(const std::string& str) {
std::string::size_type start = 0, end = str.size();
if (end < 2 || str[start] != quote || str[end - 1] != quote) {
return false;
}

std::ostringstream os;

for (size_t i = 0; i < array.size(); ++i) {
if (i > 0) {
os << separator;
bool escaping = false;
for (auto i = start + 1, e = end - 1; i < e; ++i) {
if (escaping) {
escaping = false;
} else if (str[i] == escape) {
escaping = true;
} else if (str[i] == quote) {
return false;
}
}
// If the reduced string ends with \, then the terminating quote is escaped.
return !escaping;
}

std::string str = array[i];
/*!
 * \brief Wrap a string in a pair of quote characters.
 *
 * The contents are copied verbatim: nothing inside \p str is escaped or
 * otherwise altered by this function.
 *
 * \param str The text to enclose.
 * \return A new string made of a quote, \p str, and a closing quote.
 */
std::string TargetInternal::Quote(const std::string& str) {
  std::string quoted;
  quoted += quote;
  quoted += str;
  quoted += quote;
  return quoted;
}

if ((str.find(separator) == std::string::npos) && (str.find(quote) == std::string::npos)) {
os << str;
} else {
os << quote;
for (char c : str) {
if (c == quote) {
os << escape;
}
os << c;
}
os << quote;
std::string TargetInternal::JoinString(const std::vector<std::string>& array, char separator) {
std::string result;
ICHECK(separator != quote && separator != escape)
<< "string join separator cannot be " << quote << " or " << escape;

bool is_first = true;
for (const auto& s : array) {
if (!is_first) {
result.push_back(separator);
}
result.append(s);
is_first = false;
}
return String(os.str());
}

static std::vector<std::string> SplitString(const std::string& str, char separator) {
char escape = '\\';
char quote = '\'';
return result;
}

std::vector<std::string> TargetInternal::SplitString(const std::string& str, char separator) {
std::vector<std::string> output;

const char* start = str.data();
Expand All @@ -199,10 +218,12 @@ static std::vector<std::string> SplitString(const std::string& str, char separat
if ((*pos == separator) && !pos_quoted) {
finish_word();
pos++;
} else if ((*pos == escape) && (pos + 1 < end) && (pos[1] == quote)) {
current_word << quote;
} else if (*pos == escape && pos + 1 < end) {
current_word << escape;
current_word << pos[1];
pos += 2;
} else if (*pos == quote) {
current_word << quote;
pos_quoted = !pos_quoted;
pos++;
} else {
Expand All @@ -218,12 +239,91 @@ static std::vector<std::string> SplitString(const std::string& str, char separat
return output;
}

std::string TargetInternal::Interpret(const std::string& str) {
// String interpretation deals with quotes (') and escapes(\).
// - An escape character must be followed by another character forming an
// "escape sequence". (Trailing escape is not allowed.) An escape prevents
// interpretation of the character that follows. This happens regardless of
// whether the escape sequence appears within quoted substring or not.
// - A quote character, when interpreted, marks the beginning or the end of a
// quoted substring. (A quoted substring cannot contain unescaped quotes.)
// - Any other character, when interpreted, represents itself.
//
// Interpretation happens in two steps:
// 1. If the entire string is quoted, the quotes are removed first, and the
// resulting string is treated as unquoted.
// 2. Each character or escape sequence is interpreted, and the result is copied
// to the result. When not inside a quoted substring, the interpretation of an
// escape sequence is the escaped character, otherwise it is the entire escape
// sequence.
//
// Examples:
// blah -> blah Nothing happened
// 'blah' -> blah Enclosing quotes removed
// 'bl'ah -> 'bl'ah Non-enclosing quotes remain
// '\'blah\'' -> 'blah' Enclosing quotes removed, escaped quotes
// interpreted.
// '\'\\\'blah\\\'\'' -> '\'blah\'' Same as above.
//
// Note that
// '\'\\\'blah\\\'\'' -> '\'blah\'' -> 'blah'

std::string result;
if (str.empty()) {
return result;
}

// Check if the entire string is enclosed in quotes ''. If so, strip the quotes
// and treat the string as unquoted (so that escapes are interpreted). Doing that
// will allow '\'foo\'' to become 'foo', instead of \'foo\'.
std::string::size_type start = 0, end = str.size();
if (IsQuoted(str)) {
start++;
end--;
}

bool inside_quote = false;
bool escaping = false;

for (auto i = start, e = end; i < e; ++i) {
std::string::value_type c = str[i];
if (escaping) {
escaping = false;
} else if (c == escape) {
escaping = true;
if (!inside_quote) {
continue;
}
} else if (c == quote) {
inside_quote = !inside_quote;
}
result.push_back(c);
}

return result;
}

std::string TargetInternal::Uninterpret(const std::string& str) {
// Do the opposite to `Interpret`, so that Interpret(Uninterpret(str)) == str.
std::string result;

for (std::string::size_type i = 0, e = str.size(); i < e; ++i) {
std::string::value_type c = str[i];
if (c == escape || c == quote) {
result.push_back(escape);
}
result.push_back(c);
}

return result;
}

static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key,
std::string* value) {
int pos;
std::string::size_type pos;
std::string& result_k = *key;
std::string& result_v = *value;
if ((pos = FindFirstSubstr(s, "=")) != -1) {
if ((pos = s.find_first_of('=')) != std::string::npos) {
// case 1. --key=value
result_k = s.substr(0, pos);
result_v = s.substr(pos + 1);
Expand Down Expand Up @@ -267,37 +367,42 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi

ObjectRef TargetInternal::ParseType(const std::string& str,
const TargetKindNode::ValueTypeInfo& info) {
std::string interp_str = Interpret(str);
if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing integer
std::istringstream is(str);
std::istringstream is(interp_str);
int v;
if (!(is >> v)) {
std::string lower(str.size(), '\x0');
std::transform(str.begin(), str.end(), lower.begin(),
std::string lower(interp_str.size(), '\x0');
std::transform(interp_str.begin(), interp_str.end(), lower.begin(),
[](unsigned char c) { return std::tolower(c); });
// Bool is a subclass of IntImm, so allow textual boolean values.
if (lower == "true") {
v = 1;
} else if (lower == "false") {
v = 0;
} else {
throw Error(": Cannot parse into type \"Integer\" from string: " + str);
throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str);
}
}
return Integer(v);
} else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing string, strip leading/trailing spaces
auto start = str.find_first_not_of(' ');
auto end = str.find_last_not_of(' ');
return String(str.substr(start, (end - start + 1)));
// Parsing string, strip leading/trailing spaces, and enclosing quotes if any
auto start = interp_str.find_first_not_of(' ');
auto end = interp_str.find_last_not_of(' ');
if (start == std::string::npos || end == std::string::npos) {
// The whole string is made of spaces.
return String();
}
return String(interp_str.substr(start, (end - start + 1)));

} else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing target
return Target(TargetInternal::FromString(str));
return Target(TargetInternal::FromString(interp_str));
} else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
// Parsing array
std::vector<ObjectRef> result;
for (const std::string& substr : SplitString(str, ',')) {
for (const std::string& substr : SplitString(interp_str, ',')) {
try {
ObjectRef parsed = TargetInternal::ParseType(substr, *info.key);
result.push_back(parsed);
Expand All @@ -308,7 +413,8 @@ ObjectRef TargetInternal::ParseType(const std::string& str,
}
return Array<ObjectRef>(result);
}
throw Error(": Unsupported type \"" + info.type_key + "\" for parsing from string: " + str);
throw Error(": Unsupported type \"" + info.type_key +
"\" for parsing from string: " + interp_str);
}

ObjectRef TargetInternal::ParseType(const ObjectRef& obj,
Expand Down Expand Up @@ -385,14 +491,35 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj,

/********** Stringifying **********/

static inline Optional<String> StringifyAtomicType(const ObjectRef& obj) {
std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) {
if (const auto* p = obj.as<IntImmNode>()) {
return String(std::to_string(p->value));
return std::to_string(p->value);
}
if (const auto* p = obj.as<StringObj>()) {
return GetRef<String>(p);
auto s = static_cast<std::string>(GetRef<String>(p));
auto u = Uninterpret(s);
if (u.find_first_of(' ') != std::string::npos && !IsQuoted(u)) {
u = Quote(u);
}
return u;
}
return NullOpt;
LOG(FATAL) << "Cannot stringify this object";
return ""; // unreachable
}

/*!
 * \brief Render an array as a single comma-separated string.
 *
 * Each element is converted with StringifyAtomicType() and then passed through
 * Uninterpret(). An element whose escaped form contains a comma and is not
 * already quoted (per IsQuoted()) is wrapped with Quote(). The pieces are
 * finally concatenated with JoinString() using ',' as the separator.
 *
 * \param array The array whose items are to be stringified.
 * \return The joined textual form of all items.
 */
std::string TargetInternal::StringifyArray(const ArrayNode& array) {
  constexpr char kItemSeparator = ',';
  std::vector<std::string> parts;

  for (const ObjectRef& elem : array) {
    std::string text = Uninterpret(StringifyAtomicType(elem));
    const bool has_separator = text.find(kItemSeparator) != std::string::npos;
    if (has_separator && !IsQuoted(text)) {
      text = Quote(text);
    }
    parts.push_back(text);
  }

  return JoinString(parts, kItemSeparator);
}

Optional<String> TargetInternal::StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) {
Expand All @@ -402,30 +529,21 @@ Optional<String> TargetInternal::StringifyAttrsToRaw(const Map<String, ObjectRef
keys.push_back(kv.first);
}
std::sort(keys.begin(), keys.end());
std::vector<String> result;
std::vector<std::string> result;

for (const auto& key : keys) {
const ObjectRef& obj = attrs[key];
Optional<String> value = NullOpt;
std::string value;
if (const auto* array = obj.as<ArrayNode>()) {
std::vector<String> items;
for (const ObjectRef& item : *array) {
Optional<String> str = StringifyAtomicType(item);
if (str.defined()) {
items.push_back(str.value());
} else {
items.clear();
break;
}
}
value = JoinString(items, ',');
value = String(StringifyArray(*array));
} else {
value = StringifyAtomicType(obj);
}
if (value.defined()) {
result.push_back("-" + key + "=" + value.value());
if (!value.empty()) {
result.push_back("-" + key + "=" + value);
}
}
return JoinString(result, ' ');
return String(JoinString(result, ' '));
}

const std::string& TargetNode::str() const {
Expand Down
Loading