Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions ffi/include/tvm/ffi/reflection/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,8 @@ inline Function GetMethod(std::string_view type_key, const char* method_name) {
*/
template <typename Callback>
inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) {
using ResultType = decltype(callback(type_info->fields));
static_assert(std::is_same_v<ResultType, void>, "Callback must return void");
// iterate through acenstors in parent to child order
// skip the first one since it is always the root object
for (int i = 1; i < type_info->type_depth; ++i) {
Expand All @@ -417,6 +419,34 @@ inline void ForEachFieldInfo(const TypeInfo* type_info, Callback callback) {
}
}

/*!
* \brief Visit each field info of the type info and run callback which returns bool for early stop.
*
* \tparam Callback The callback function type, which returns bool for early stop.
*
* \param type_info The type info.
* \param callback_with_early_stop The callback function.
* \return true if any of early stop is triggered.
*
* \note This function calls both the child and parent type info and can be used for searching.
*/
template <typename Callback>
inline bool ForEachFieldInfoWithEarlyStop(const TypeInfo* type_info,
Callback callback_with_early_stop) {
// iterate through acenstors in parent to child order
// skip the first one since it is always the root object
for (int i = 1; i < type_info->type_depth; ++i) {
const TVMFFITypeInfo* parent_info = type_info->type_acenstors[i];
for (int j = 0; j < parent_info->num_fields; ++j) {
if (callback_with_early_stop(parent_info->fields + j)) return true;
}
}
for (int i = 0; i < type_info->num_fields; ++i) {
if (callback_with_early_stop(type_info->fields + i)) return true;
}
return false;
}

} // namespace reflection
} // namespace ffi
} // namespace tvm
Expand Down
61 changes: 61 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include <dmlc/common.h>
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/ir/expr.h>
#include <tvm/node/structural_equal.h>
#include <tvm/node/structural_hash.h>
Expand Down Expand Up @@ -970,5 +971,65 @@ inline void BaseAttrsNode::PrintDocString(std::ostream& os) const { // NOLINT(*
}
}

/*!
* \brief Adapter for AttrsNode with the new reflection API.
*
* We will phaseout the old AttrsNode in future in favor of the new reflection API.
* This adapter allows us to gradually migrate to the new reflection API.
*
* \tparam DerivedType The final attribute type.
*/
template <typename DerivedType>
class AttrsNodeReflAdapter : public BaseAttrsNode {
public:
void InitByPackedArgs(const ffi::PackedArgs& args, bool allow_unknown) final {
LOG(FATAL) << "`" << DerivedType::_type_key << "` uses new reflection mechanism for init";
}
void VisitNonDefaultAttrs(AttrVisitor* v) final {
LOG(FATAL) << "`" << DerivedType::_type_key
<< "` uses new reflection mechanism for visit non default attrs";
}
void VisitAttrs(AttrVisitor* v) final {
LOG(FATAL) << "`" << DerivedType::_type_key
<< "` uses new reflection mechanism for visit attrs";
}

bool SEqualReduce(const DerivedType* other, SEqualReducer equal) const {
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex());
bool success = true;
ffi::reflection::ForEachFieldInfoWithEarlyStop(
type_info, [&](const TVMFFIFieldInfo* field_info) {
ffi::reflection::FieldGetter field_getter(field_info);
ffi::Any field_value = field_getter(self());
ffi::Any other_field_value = field_getter(other);
if (!equal.AnyEqual(field_value, other_field_value)) {
success = false;
return true;
}
return false;
});
return success;
}

void SHashReduce(SHashReducer hash_reducer) const {
const TVMFFITypeInfo* type_info = TVMFFIGetTypeInfo(DerivedType::RuntimeTypeIndex());
ffi::reflection::ForEachFieldInfo(type_info, [&](const TVMFFIFieldInfo* field_info) {
ffi::reflection::FieldGetter field_getter(field_info);
ffi::Any field_value = field_getter(self());
hash_reducer(field_value);
});
}

Array<AttrFieldInfo> ListFieldInfo() const final {
// use the new reflection to list field info
return Array<AttrFieldInfo>();
}

private:
DerivedType* self() const {
return const_cast<DerivedType*>(static_cast<const DerivedType*>(this));
}
};

} // namespace tvm
#endif // TVM_IR_ATTRS_H_
64 changes: 40 additions & 24 deletions include/tvm/relax/attrs/ccl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,54 +24,70 @@
#ifndef TVM_RELAX_ATTRS_CCL_H_
#define TVM_RELAX_ATTRS_CCL_H_

#include <tvm/ffi/reflection/reflection.h>
#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*! \brief Attributes used in allreduce operators */
struct AllReduceAttrs : public tvm::AttrsNode<AllReduceAttrs> {
struct AllReduceAttrs : public tvm::AttrsNodeReflAdapter<AllReduceAttrs> {
String op_type;
bool in_group;

TVM_DECLARE_ATTRS(AllReduceAttrs, "relax.attrs.AllReduceAttrs") {
TVM_ATTR_FIELD(op_type).describe(
"The type of reduction operation to be applied to the input data. Now only sum is "
"supported.");
TVM_ATTR_FIELD(in_group).describe(
"Whether the reduction operation performs in group or globally or in group as default.");
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AllReduceAttrs>()
.def_ro("op_type", &AllReduceAttrs::op_type,
"The type of reduction operation to be applied to the input data. Now only sum is "
"supported.")
.def_ro("in_group", &AllReduceAttrs::in_group,
"Whether the reduction operation performs in group or globally or in group as "
"default.");
}

static constexpr const char* _type_key = "relax.attrs.AllReduceAttrs";
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AllReduceAttrs, BaseAttrsNode);
}; // struct AllReduceAttrs

/*! \brief Attributes used in allgather operators */
struct AllGatherAttrs : public tvm::AttrsNode<AllGatherAttrs> {
struct AllGatherAttrs : public tvm::AttrsNodeReflAdapter<AllGatherAttrs> {
int num_workers;
bool in_group;

TVM_DECLARE_ATTRS(AllGatherAttrs, "relax.attrs.AllGatherAttrs") {
TVM_ATTR_FIELD(num_workers)
.describe(
"The number of workers, also the number of parts the given buffer should be chunked "
"into.");
TVM_ATTR_FIELD(in_group).describe(
"Whether the allgather operation performs in group or globally or in group as default.");
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<AllGatherAttrs>()
.def_ro("num_workers", &AllGatherAttrs::num_workers,
"The number of workers, also the number of parts the given buffer should be "
"chunked into.")
.def_ro("in_group", &AllGatherAttrs::in_group,
"Whether the allgather operation performs in group or globally or in group as "
"default.");
}

static constexpr const char* _type_key = "relax.attrs.AllGatherAttrs";
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(AllGatherAttrs, BaseAttrsNode);
}; // struct AllGatherAttrs

/*! \brief Attributes used in scatter operators */
struct ScatterCollectiveAttrs : public tvm::AttrsNode<ScatterCollectiveAttrs> {
struct ScatterCollectiveAttrs : public tvm::AttrsNodeReflAdapter<ScatterCollectiveAttrs> {
int num_workers;
int axis;

TVM_DECLARE_ATTRS(ScatterCollectiveAttrs, "relax.attrs.ScatterCollectiveAttrs") {
TVM_ATTR_FIELD(num_workers)
.describe(
"The number of workers, also the number of parts the given buffer should be chunked "
"into.");
TVM_ATTR_FIELD(axis).describe(
"The axis of the tensor to be scattered. The tensor will be chunked along "
"this axis.");
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<ScatterCollectiveAttrs>()
.def_ro("num_workers", &ScatterCollectiveAttrs::num_workers,
"The number of workers, also the number of parts the given buffer should be "
"chunked into.")
.def_ro("axis", &ScatterCollectiveAttrs::axis,
"The axis of the tensor to be scattered. The tensor will be chunked along "
"this axis.");
}

static constexpr const char* _type_key = "relax.attrs.ScatterCollectiveAttrs";
TVM_FFI_DECLARE_FINAL_OBJECT_INFO(ScatterCollectiveAttrs, BaseAttrsNode);
}; // struct ScatterCollectiveAttrs

} // namespace relax
Expand Down
4 changes: 2 additions & 2 deletions src/contrib/msc/core/ir/graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void FuncAttrGetter::VisitExpr_(const CallNode* op) {
if (op->attrs.defined()) {
Map<String, String> attrs;
AttrGetter getter(&attrs);
const_cast<BaseAttrsNode*>(op->attrs.get())->VisitAttrs(&getter);
getter(op->attrs);
for (const auto& pair : attrs) {
if (attrs_.count(pair.first)) {
int cnt = 1;
Expand Down Expand Up @@ -350,7 +350,7 @@ const MSCJoint GraphBuilder::AddNode(const Expr& expr, const Optional<Expr>& bin
attrs = FuncAttrGetter().GetAttrs(call_node->op);
} else if (call_node->attrs.defined()) {
AttrGetter getter(&attrs);
const_cast<BaseAttrsNode*>(call_node->attrs.get())->VisitAttrs(&getter);
getter(call_node->attrs);
}
} else if (const auto* const_node = expr.as<ConstantNode>()) {
if (const_node->is_scalar()) {
Expand Down
54 changes: 53 additions & 1 deletion src/contrib/msc/core/ir/graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_CONTRIB_MSC_CORE_IR_GRAPH_BUILDER_H_

#include <dmlc/json.h>
#include <tvm/ffi/reflection/reflection.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/runtime/ndarray.h>
Expand Down Expand Up @@ -106,14 +107,65 @@ struct MSCRBuildConfig {
}
};

class AttrGetter : public AttrVisitor {
class AttrGetter : private AttrVisitor {
public:
/*!
* \brief Get the attributes as Map<String, String>
* \param attrs the attributes.
*/
explicit AttrGetter(Map<String, String>* attrs) : attrs_(attrs) {}

void operator()(const Attrs& attrs) {
// dispatch between new reflection and old reflection
const TVMFFITypeInfo* attrs_tinfo = TVMFFIGetTypeInfo(attrs->type_index());
if (attrs_tinfo->extra_info != nullptr) {
tvm::ffi::reflection::ForEachFieldInfo(attrs_tinfo, [&](const TVMFFIFieldInfo* field_info) {
Any field_value = tvm::ffi::reflection::FieldGetter(field_info)(attrs);
this->VisitAny(String(field_info->name), field_value);
});
} else {
// TODO(tvm-team): remove this once all objects are transitioned to the new reflection
const_cast<BaseAttrsNode*>(attrs.get())->VisitAttrs(this);
}
}

private:
void VisitAny(String key, Any value) {
switch (value.type_index()) {
case kTVMFFINone: {
attrs_->Set(key, "");
break;
}
case kTVMFFIBool: {
attrs_->Set(key, std::to_string(value.cast<bool>()));
break;
}
case kTVMFFIInt: {
attrs_->Set(key, std::to_string(value.cast<int64_t>()));
break;
}
case kTVMFFIFloat: {
attrs_->Set(key, std::to_string(value.cast<double>()));
break;
}
case kTVMFFIDataType: {
attrs_->Set(key, runtime::DLDataTypeToString(value.cast<DLDataType>()));
}
case kTVMFFIStr: {
attrs_->Set(key, value.cast<String>());
break;
}
default: {
if (value.type_index() >= kTVMFFIStaticObjectBegin) {
attrs_->Set(key, StringUtils::ToString(value.cast<ObjectRef>()));
} else {
LOG(FATAL) << "Unsupported type: " << value.type_index();
}
break;
}
}
}

void Visit(const char* key, double* value) final { attrs_->Set(key, std::to_string(*value)); }

void Visit(const char* key, int64_t* value) final { attrs_->Set(key, std::to_string(*value)); }
Expand Down
Loading
Loading