Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Target] Migrate data structure of TargetNode #5960

Merged
merged 2 commits into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/ir/transform.h>
#include <tvm/node/container.h>
#include <tvm/support/with.h>
#include <tvm/target/target_id.h>

#include <string>
#include <unordered_set>
Expand All @@ -42,52 +43,58 @@ namespace tvm {
*/
class TargetNode : public Object {
public:
/*! \brief The name of the target device */
std::string target_name;
/*! \brief The name of the target device */
std::string device_name;
/*! \brief The type of the target device */
int device_type;
/*! \brief The maximum threads that a schedule should use for this device */
int max_num_threads = 1;
/*! \brief The warp size that should be used by the LowerThreadAllreduce pass */
int thread_warp_size = 1;
/*! \brief The id of the target device */
TargetId id;
/*! \brief Tag of the the target, can be empty */
String tag;
/*! \brief Keys for this target */
Array<runtime::String> keys_array;
/*! \brief Options for this target */
Array<runtime::String> options_array;
/*! \brief Collection of imported libs */
Array<runtime::String> libs_array;
Array<String> keys;
/*! \brief Collection of attributes */
Map<String, ObjectRef> attrs;

/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("target_name", &target_name);
v->Visit("device_name", &device_name);
v->Visit("device_type", &device_type);
v->Visit("max_num_threads", &max_num_threads);
v->Visit("thread_warp_size", &thread_warp_size);
v->Visit("keys_array", &keys_array);
v->Visit("options_array", &options_array);
v->Visit("libs_array", &libs_array);
v->Visit("id", &id);
v->Visit("tag", &tag);
v->Visit("keys_", &keys);
v->Visit("attrs", &attrs);
v->Visit("_str_repr_", &str_repr_);
}

/*! \brief Get the keys for this target as a vector of string */
TVM_DLL std::vector<std::string> keys() const;
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
auto it = attrs.find(attr_key);
if (it != attrs.end()) {
return Downcast<Optional<TObjectRef>>((*it).second);
} else {
return default_value;
}
}

/*! \brief Get the options for this target as a vector of string */
TVM_DLL std::vector<std::string> options() const;
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}

/*! \brief Get the keys for this target as a vector of string */
TVM_DLL std::vector<std::string> GetKeys() const;

/*! \brief Get the keys for this target as an unordered_set of string */
TVM_DLL std::unordered_set<std::string> libs() const;
TVM_DLL std::unordered_set<std::string> GetLibs() const;

static constexpr const char* _type_key = "Target";
TVM_DECLARE_FINAL_OBJECT_INFO(TargetNode, Object);

private:
/*! \brief Internal string repr. */
mutable std::string str_repr_;
friend class Target;
};

/*!
Expand All @@ -102,7 +109,17 @@ class Target : public ObjectRef {
* \brief Create a Target given a string
* \param target_str the string to parse
*/
TVM_DLL static Target Create(const std::string& target_str);
TVM_DLL static Target Create(const String& target_str);
/*!
* \brief Construct a Target node from the given name and options.
* \param name The major target name. Should be one of
* {"aocl", "aocl_sw_emu", "c", "cuda", "ext_dev", "hexagon", "hybrid", "llvm",
* "metal", "nvptx", "opencl", "rocm", "sdaccel", "stackvm", "vulkan"}
* \param options Additional options appended to the target
* \return The constructed Target
*/
TVM_DLL static Target CreateTarget(const std::string& name,
const std::vector<std::string>& options);
/*!
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
Expand Down
55 changes: 55 additions & 0 deletions include/tvm/target/target_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ template <typename, typename, typename>
struct ValueTypeInfoMaker;
}

class Target;

/*! \brief Perform schema validation */
TVM_DLL void TargetValidateSchema(const Map<String, ObjectRef>& config);

Expand All @@ -54,6 +56,10 @@ class TargetIdNode : public Object {
public:
/*! \brief Name of the target id */
String name;
/*! \brief Device type of target id */
int device_type;
/*! \brief Default keys of the target */
Array<String> default_keys;
/*! \brief Stores the required type_key and type_index of a specific attr of a target */
struct ValueTypeInfo {
String type_key;
Expand All @@ -62,6 +68,14 @@ class TargetIdNode : public Object {
std::unique_ptr<ValueTypeInfo> val;
};

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("device_type", &device_type);
v->Visit("default_keys", &default_keys);
}

Map<String, ObjectRef> ParseAttrsFromRawString(const std::vector<std::string>& options);

static constexpr const char* _type_key = "TargetId";
TVM_DECLARE_FINAL_OBJECT_INFO(TargetIdNode, Object);

Expand All @@ -72,9 +86,12 @@ class TargetIdNode : public Object {
void ValidateSchema(const Map<String, ObjectRef>& config) const;
/*! \brief A hash table that stores the type information of each attr of the target key */
std::unordered_map<String, ValueTypeInfo> key2vtype_;
/*! \brief A hash table that stores the default value of each attr of the target key */
std::unordered_map<String, ObjectRef> key2default_;
/*! \brief Index used for internal lookup of attribute registry */
uint32_t index_;
friend void TargetValidateSchema(const Map<String, ObjectRef>&);
friend class Target;
friend class TargetId;
template <typename, typename>
friend class AttrRegistry;
Expand All @@ -91,6 +108,7 @@ class TargetIdNode : public Object {
*/
class TargetId : public ObjectRef {
public:
TargetId() = default;
/*! \brief Get the attribute map given the attribute name */
template <typename ValueType>
static inline TargetIdAttrMap<ValueType> GetAttrMap(const String& attr_name);
Expand All @@ -110,6 +128,7 @@ class TargetId : public ObjectRef {
template <typename, typename>
friend class AttrRegistry;
friend class TargetIdRegEntry;
friend class Target;
};

/*!
Expand Down Expand Up @@ -148,13 +167,31 @@ class TargetIdRegEntry {
template <typename ValueType>
inline TargetIdRegEntry& set_attr(const String& attr_name, const ValueType& value,
int plevel = 10);
/*!
* \brief Set DLPack's device_type the target
* \param device_type Device type
*/
inline TargetIdRegEntry& set_device_type(int device_type);
/*!
* \brief Set DLPack's device_type the target
* \param keys The default keys
*/
inline TargetIdRegEntry& set_default_keys(std::vector<String> keys);
/*!
* \brief Register a valid configuration option and its ValueType for validation
* \param key The configuration key
* \tparam ValueType The value type to be registered
*/
template <typename ValueType>
inline TargetIdRegEntry& add_attr_option(const String& key);
/*!
* \brief Register a valid configuration option and its ValueType for validation
* \param key The configuration key
* \param default_value The default value of the key
* \tparam ValueType The value type to be registered
*/
template <typename ValueType>
inline TargetIdRegEntry& add_attr_option(const String& key, ObjectRef default_value);
/*! \brief Set name of the TargetId to be the same as registry if it is empty */
inline TargetIdRegEntry& set_name();
/*!
Expand Down Expand Up @@ -286,6 +323,16 @@ inline TargetIdRegEntry& TargetIdRegEntry::set_attr(const String& attr_name, con
return *this;
}

inline TargetIdRegEntry& TargetIdRegEntry::set_device_type(int device_type) {
id_->device_type = device_type;
return *this;
}

inline TargetIdRegEntry& TargetIdRegEntry::set_default_keys(std::vector<String> keys) {
id_->default_keys = keys;
return *this;
}

template <typename ValueType>
inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key) {
CHECK(!id_->key2vtype_.count(key))
Expand All @@ -294,6 +341,14 @@ inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key) {
return *this;
}

template <typename ValueType>
inline TargetIdRegEntry& TargetIdRegEntry::add_attr_option(const String& key,
ObjectRef default_value) {
add_attr_option<ValueType>(key);
id_->key2default_[key] = default_value;
return *this;
}

inline TargetIdRegEntry& TargetIdRegEntry::set_name() {
if (id_->name.empty()) {
id_->name = name;
Expand Down
9 changes: 4 additions & 5 deletions python/tvm/autotvm/tophub.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,10 @@ def context(target, extra_files=None):
tgt = _target.create(tgt)

possible_names = []
for opt in tgt.options:
if opt.startswith("-device"):
device = _alias(opt[8:])
possible_names.append(device)
possible_names.append(tgt.target_name)
device = tgt.attrs.get("device", "")
if device != "":
possible_names.append(_alias(device))
possible_names.append(tgt.id.name)

all_packages = list(PACKAGE_VERSION.keys())
for name in possible_names:
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _build_for_device(input_mod, target, target_host):
"""
target = _target.create(target)
target_host = _target.create(target_host)
device_type = ndarray.context(target.target_name, 0).device_type
device_type = ndarray.context(target.id.name, 0).device_type

mod_mixed = input_mod
mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
Expand Down Expand Up @@ -402,7 +402,7 @@ def build(inputs,
if not target_host:
for tar, _ in target_input_mod.items():
tar = _target.create(tar)
device_type = ndarray.context(tar.target_name, 0).device_type
device_type = ndarray.context(tar.id.name, 0).device_type
if device_type == ndarray.cpu(0).device_type:
target_host = tar
break
Expand Down
22 changes: 11 additions & 11 deletions python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def softmax_strategy_cuda(attrs, inputs, out_type, target):
wrap_compute_softmax(topi.nn.softmax),
wrap_topi_schedule(topi.cuda.schedule_softmax),
name="softmax.cuda")
if target.target_name == "cuda" and "cudnn" in target.libs:
if target.id.name == "cuda" and "cudnn" in target.libs:
strategy.add_implementation(
wrap_compute_softmax(topi.cuda.softmax_cudnn),
wrap_topi_schedule(topi.cuda.schedule_softmax_cudnn),
Expand Down Expand Up @@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
dilation_h, dilation_w,
pre_flag=False)
if judge_winograd_shape:
if target.target_name == "cuda" and \
if target.id.name == "cuda" and \
nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
judge_winograd_tensorcore:
strategy.add_implementation(
Expand All @@ -162,7 +162,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
topi.cuda.schedule_conv2d_nhwc_winograd_direct),
name="conv2d_nhwc_winograd_direct.cuda",
plevel=5)
if target.target_name == "cuda":
if target.id.name == "cuda":
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
Expand All @@ -181,7 +181,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
else:
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
# add cudnn implementation
if target.target_name == "cuda" and "cudnn" in target.libs:
if target.id.name == "cuda" and "cudnn" in target.libs:
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
padding[1] == padding[3]:
strategy.add_implementation(
Expand Down Expand Up @@ -209,7 +209,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
else: # group_conv2d
# add cudnn implementation, if any
cudnn_impl = False
if target.target_name == "cuda" and "cudnn" in target.libs:
if target.id.name == "cuda" and "cudnn" in target.libs:
if layout in ["NCHW", "NHWC"] and padding[0] == padding[2] and \
padding[1] == padding[3]:
strategy.add_implementation(
Expand Down Expand Up @@ -264,7 +264,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
padding, stride_h, stride_w,
dilation_h, dilation_w,
pre_flag=True)
if target.target_name == "cuda" and \
if target.id.name == "cuda" and \
nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
judge_winograd_tensorcore:
strategy.add_implementation(
Expand Down Expand Up @@ -362,7 +362,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
plevel=10)
N, _, _, _, _ = get_const_tuple(data.shape)
_, _, _, CI, CO = get_const_tuple(kernel.shape)
if target.target_name == "cuda":
if target.id.name == "cuda":
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
Expand All @@ -373,7 +373,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
name="conv3d_ndhwc_tensorcore.cuda",
plevel=20)

if target.target_name == "cuda" and "cudnn" in target.libs:
if target.id.name == "cuda" and "cudnn" in target.libs:
strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_cudnn, True),
wrap_topi_schedule(topi.cuda.schedule_conv3d_cudnn),
name="conv3d_cudnn.cuda",
Expand Down Expand Up @@ -458,7 +458,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_dense_large_batch),
name="dense_large_batch.cuda",
plevel=5)
if target.target_name == "cuda":
if target.id.name == "cuda":
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
Expand All @@ -468,7 +468,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_dense_tensorcore),
name="dense_tensorcore.cuda",
plevel=20)
if target.target_name == "cuda" and "cublas" in target.libs:
if target.id.name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_dense(topi.cuda.dense_cublas),
wrap_topi_schedule(topi.cuda.schedule_dense_cublas),
Expand All @@ -485,7 +485,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
plevel=10)
if target.target_name == "cuda" and "cublas" in target.libs:
if target.id.name == "cuda" and "cublas" in target.libs:
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul_cublas),
wrap_topi_schedule(topi.generic.schedule_extern),
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
wrap_compute_dense(topi.rocm.dense),
wrap_topi_schedule(topi.rocm.schedule_dense),
name="dense.rocm")
if target.target_name == "rocm" and "rocblas" in target.libs:
if target.id.name == "rocm" and "rocblas" in target.libs:
assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
strategy.add_implementation(
wrap_compute_dense(topi.rocm.dense_rocblas),
Expand Down
Loading