-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Change target string to Target object in the TE compiler and interpreter #8835
Changes from 13 commits
c6447b6
3f19fca
8aaee0c
eb288ac
76257bb
d67e885
ee7881e
e3ca300
ee60645
8da2c54
1ebe623
4a65400
4205389
29f802c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ | |
#include <tvm/target/target_kind.h> | ||
|
||
#include <string> | ||
#include <unordered_map> | ||
#include <unordered_set> | ||
#include <vector> | ||
|
||
|
@@ -203,5 +204,59 @@ void CheckAndUpdateHostConsistency(Map<Integer, Target>* target, Target* host); | |
* \param host The Target typed object for target host to be updated | ||
*/ | ||
void CheckAndUpdateHostConsistency(Map<Target, IRModule>* target, Target* host); | ||
|
||
// TODO(@electriclilies): Move to somewhere in backend and add note about appropriate use | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hey what about moving these methods temporarily to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. moved them! |
||
|
||
/*! \brief Target hash function */ | ||
struct TargetStrHash { | ||
/*! | ||
* \brief Calculate the hash code of a Target based on the string value of the Target | ||
This will be removed when maps from Targets to IRModules are removed from the codebase. | ||
* \param target The Target to hash | ||
* \return String hash of the target | ||
*/ | ||
size_t operator()(const Target& target) const { | ||
return String::HashBytes(target->str().c_str(), target->str().size()); | ||
} | ||
}; | ||
|
||
/*! \brief Target equality function based on the string value of Target | ||
This will be removed when maps from Targets to IRModules are removed from the | ||
codebase.*/ | ||
struct TargetStrEqual { | ||
/*! | ||
* \brief Check if the two Targets are equal | ||
* \param target One Target | ||
* \param other_target The other Target | ||
* \return String equality of the targets | ||
*/ | ||
const bool operator()(const Target& target, const Target& other_target) const { | ||
TargetStrHash target_hash = TargetStrHash(); | ||
return target_hash(target) == target_hash(other_target); | ||
} | ||
}; | ||
|
||
/*! | ||
* \brief Convert a Map<Target, IRModule> to std::unordered_map<Target, IRmodule, TargetStrHash, | ||
* TargetStrEqual> Target equality is currently based on pointer equality, which is a problem since | ||
* we have a lot of Map<Target, IRModule> in the codebase. This function converts the map to a | ||
* version that is keyed based on string value of the Target instead. Note that once we remove | ||
* Map<Target, IRModule>, this function will be removed. | ||
* \param input_map The map to convert | ||
* \return The converted map | ||
*/ | ||
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> | ||
TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map); | ||
|
||
/*! | ||
* \brief Convert a std::unordered_map<Target, IRmodule, TargetStrHash, TargetStrEqual> to | ||
* Map<Target, IRModule> This function is a helper that undoes TargetModuleMapToTargetStr. Note that | ||
* once we remove Map<Target, IRModule>, this function will be removed. | ||
* \param input_map The map to convert | ||
* \return The converted map | ||
*/ | ||
Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap( | ||
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map); | ||
|
||
} // namespace tvm | ||
#endif // TVM_TARGET_TARGET_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,7 +53,11 @@ namespace { | |
struct PairHash { | ||
template <typename T1, typename T2> | ||
std::size_t operator()(const std::pair<T1, T2>& k) const { | ||
return std::hash<T1>()(k.first) ^ std::hash<T2>()(k.second); | ||
return dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second)); | ||
} | ||
template <typename T2> | ||
std::size_t operator()(const std::pair<Target, T2>& k) const { | ||
return dmlc::HashCombine(ObjectHash()(k.first), std::hash<T2>()(k.second)); | ||
} | ||
}; | ||
|
||
|
@@ -289,7 +293,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>, | |
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> { | ||
public: | ||
// TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule. | ||
Interpreter(IRModule mod, Map<String, IRModule> per_target_module, Device device, Target target) | ||
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target) | ||
: mod_(mod), | ||
per_target_module_(per_target_module), | ||
device_(device), | ||
|
@@ -373,7 +377,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>, | |
*/ | ||
PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array<GlobalVar>& all_tir_fn_vars, | ||
Target target) { | ||
std::pair<std::string, std::string> packed_func_key(target->str(), tir_fn_var->name_hint); | ||
std::pair<Target, std::string> packed_func_key(target, tir_fn_var->name_hint); | ||
auto packed_itr = compiled_packed_funcs_.find(packed_func_key); | ||
if (packed_itr != compiled_packed_funcs_.end()) { | ||
// Already compiled. | ||
|
@@ -382,8 +386,10 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>, | |
|
||
// Project out just the function(s) we need. | ||
IRModule lowered_projected_mod; | ||
auto mod_itr = per_target_module_.find(target->str()); | ||
ICHECK(mod_itr != per_target_module_.end()) | ||
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> per_target_module_std_map_ = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: don't append a _ for local vars since the convention is it indicates a member var. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done! |
||
TargetModuleMapToTargetStrModuleMap(per_target_module_); | ||
auto mod_itr = per_target_module_std_map_.find(target); | ||
ICHECK(mod_itr != per_target_module_std_map_.end()) | ||
<< "No target module for target '" << target->str() << "'"; | ||
const IRModule& target_module = (*mod_itr).second; | ||
for (const auto& var : all_tir_fn_vars) { | ||
|
@@ -407,7 +413,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>, | |
PackedFunc packed_func = runtime_module.GetFunction(var->name_hint); | ||
ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint | ||
<< "' in compiled module for target '" << target->str() << "'"; | ||
compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func); | ||
compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func); | ||
} | ||
|
||
// Return just what we need for this call. | ||
|
@@ -874,11 +880,10 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>, | |
// Map from target key to lowered TIR functions derived from mod_. | ||
// Note that primitives are implicitly executed on target_, while shape functions are implicitly | ||
// executed on the default 'cpu' host. Thus this map has at most two entries. | ||
Map<String, IRModule> per_target_module_; | ||
Map<Target, IRModule> per_target_module_; | ||
// Cached packed functions for the primitives and shape functions, keyed by target and | ||
// global var name. | ||
std::unordered_map<std::pair<std::string, std::string>, PackedFunc, PairHash> | ||
compiled_packed_funcs_; | ||
std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_; | ||
// Unique device on which primitives (but not shape functions) will be executed. | ||
// (For simplicity we only run the interpreter on a single device.) | ||
Device device_; | ||
|
@@ -895,7 +900,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>, | |
* rewritten \p mod and target-specific modules containing bindings for all TIR primitive | ||
* functions needed by the rewritten module. | ||
*/ | ||
std::pair<IRModule, Map<String, IRModule>> Prepare(IRModule mod, Device device, Target target) { | ||
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) { | ||
// Run minimal transforms on module to establish invariants needed by interpreter. | ||
transform::Sequential seq({transform::SimplifyInference(), | ||
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive' | ||
|
@@ -1014,7 +1019,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De | |
// and can just eval it directly. | ||
expr_to_eval = expr; | ||
} | ||
std::pair<IRModule, Map<String, IRModule>> main_and_lowered = | ||
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered = | ||
Prepare(mod_with_expr, device, target); | ||
std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>( | ||
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, | ||
|
@@ -1057,7 +1062,7 @@ ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions, | |
std::unordered_set<String> import_set, Device device, Target target) { | ||
std::pair<IRModule, GlobalVar> mod_and_global = | ||
IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set); | ||
std::pair<IRModule, Map<String, IRModule>> main_and_lowered = | ||
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered = | ||
Prepare(mod_and_global.first, device, target); | ||
Interpreter intrp( | ||
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should probably remove this as it's not used in this file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We will send a follow up P that does this just for the sake of forward progress. Thanks!