Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions ffi/include/tvm/ffi/any.h
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,8 @@ struct AnyHash {
uint64_t val_hash = [&]() -> uint64_t {
if (src.data_.type_index == TypeIndex::kTVMFFIStr ||
src.data_.type_index == TypeIndex::kTVMFFIBytes) {
const BytesObjBase* src_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(src);
const details::BytesObjBase* src_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
return details::StableHashBytes(src_str->data, src_str->size);
} else {
return src.data_.v_uint64;
Expand All @@ -572,10 +572,10 @@ struct AnyEqual {
// specialy handle string hash
if (lhs.data_.type_index == TypeIndex::kTVMFFIStr ||
lhs.data_.type_index == TypeIndex::kTVMFFIBytes) {
const BytesObjBase* lhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(lhs);
const BytesObjBase* rhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(rhs);
const details::BytesObjBase* lhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
const details::BytesObjBase* rhs_str =
details::AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
return Bytes::memequal(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size);
}
return false;
Expand Down
26 changes: 12 additions & 14 deletions ffi/include/tvm/ffi/string.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

namespace tvm {
namespace ffi {

namespace details {
/*! \brief Base class for bytes and string. */
class BytesObjBase : public Object, public TVMFFIByteArray {};

Expand All @@ -73,8 +73,6 @@ class StringObj : public BytesObjBase {
TVM_FFI_DECLARE_STATIC_OBJECT_INFO(StringObj, Object);
};

namespace details {

// String moved from std::string
// without having to trigger a copy
template <typename Base>
Expand Down Expand Up @@ -115,21 +113,21 @@ class Bytes : public ObjectRef {
* \param other a char array.
*/
Bytes(const char* data, size_t size) // NOLINT(*)
: ObjectRef(details::MakeInplaceBytes<BytesObj>(data, size)) {}
: ObjectRef(details::MakeInplaceBytes<details::BytesObj>(data, size)) {}
/*!
* \brief constructor from char [N]
*
* \param other a char array.
*/
Bytes(TVMFFIByteArray bytes) // NOLINT(*)
: ObjectRef(details::MakeInplaceBytes<BytesObj>(bytes.data, bytes.size)) {}
: ObjectRef(details::MakeInplaceBytes<details::BytesObj>(bytes.data, bytes.size)) {}
/*!
* \brief constructor from char [N]
*
* \param other a char array.
*/
Bytes(std::string other) // NOLINT(*)
: ObjectRef(make_object<details::BytesObjStdImpl<BytesObj>>(std::move(other))) {}
: ObjectRef(make_object<details::BytesObjStdImpl<details::BytesObj>>(std::move(other))) {}
/*!
* \brief Swap this String with another string
* \param other The other string
Expand Down Expand Up @@ -163,7 +161,7 @@ class Bytes : public ObjectRef {
*/
operator std::string() const { return std::string{get()->data, size()}; }

TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef, BytesObj);
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bytes, ObjectRef, details::BytesObj);

/*!
* \brief Compare two char sequence
Expand Down Expand Up @@ -245,7 +243,7 @@ class String : public ObjectRef {
*/
template <size_t N>
String(const char other[N]) // NOLINT(*)
: ObjectRef(details::MakeInplaceBytes<StringObj>(other, N)) {}
: ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, N)) {}

/*!
* \brief constructor
Expand All @@ -258,37 +256,37 @@ class String : public ObjectRef {
* \param other a char array.
*/
String(const char* other) // NOLINT(*)
: ObjectRef(details::MakeInplaceBytes<StringObj>(other, std::strlen(other))) {}
: ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, std::strlen(other))) {}

/*!
* \brief constructor from raw string
*
* \param other a char array.
*/
String(const char* other, size_t size) // NOLINT(*)
: ObjectRef(details::MakeInplaceBytes<StringObj>(other, size)) {}
: ObjectRef(details::MakeInplaceBytes<details::StringObj>(other, size)) {}

/*!
* \brief Construct a new string object
* \param other The std::string object to be copied
*/
String(const std::string& other) // NOLINT(*)
: ObjectRef(details::MakeInplaceBytes<StringObj>(other.data(), other.size())) {}
: ObjectRef(details::MakeInplaceBytes<details::StringObj>(other.data(), other.size())) {}

/*!
* \brief Construct a new string object
* \param other The std::string object to be moved
*/
String(std::string&& other) // NOLINT(*)
: ObjectRef(make_object<details::BytesObjStdImpl<StringObj>>(std::move(other))) {}
: ObjectRef(make_object<details::BytesObjStdImpl<details::StringObj>>(std::move(other))) {}

/*!
* \brief constructor from TVMFFIByteArray
*
* \param other a TVMFFIByteArray.
*/
explicit String(TVMFFIByteArray other)
: ObjectRef(details::MakeInplaceBytes<StringObj>(other.data, other.size)) {}
: ObjectRef(details::MakeInplaceBytes<details::StringObj>(other.data, other.size)) {}

/*!
* \brief Swap this String with another string
Expand Down Expand Up @@ -423,7 +421,7 @@ class String : public ObjectRef {
*/
operator std::string() const { return std::string{get()->data, size()}; }

TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);
TVM_FFI_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, details::StringObj);

private:
/*!
Expand Down
8 changes: 4 additions & 4 deletions ffi/src/ffi/extra/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ class StructEqualHandler {
case TypeIndex::kTVMFFIStr:
case TypeIndex::kTVMFFIBytes: {
// compare bytes
const BytesObjBase* lhs_str =
AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(lhs);
const BytesObjBase* rhs_str =
AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(rhs);
const details::BytesObjBase* lhs_str =
AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(lhs);
const details::BytesObjBase* rhs_str =
AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(rhs);
return Bytes::memncmp(lhs_str->data, rhs_str->data, lhs_str->size, rhs_str->size) == 0;
}
case TypeIndex::kTVMFFIArray: {
Expand Down
8 changes: 4 additions & 4 deletions ffi/src/ffi/extra/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class StructuralHashHandler {
case TypeIndex::kTVMFFIStr:
case TypeIndex::kTVMFFIBytes: {
// return same hash as AnyHash
const BytesObjBase* src_str =
AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(src);
const details::BytesObjBase* src_str =
AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
return details::StableHashCombine(src_data->type_index,
details::StableHashBytes(src_str->data, src_str->size));
}
Expand Down Expand Up @@ -196,8 +196,8 @@ class StructuralHashHandler {
} else {
if (src_data->type_index == TypeIndex::kTVMFFIStr ||
src_data->type_index == TypeIndex::kTVMFFIBytes) {
const BytesObjBase* src_str =
AnyUnsafe::CopyFromAnyViewAfterCheck<const BytesObjBase*>(src);
const details::BytesObjBase* src_str =
AnyUnsafe::CopyFromAnyViewAfterCheck<const details::BytesObjBase*>(src);
// return same hash as AnyHash
return details::StableHashCombine(src_data->type_index,
details::StableHashBytes(src_str->data, src_str->size));
Expand Down
46 changes: 30 additions & 16 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,26 +242,40 @@ class PassContext : public ObjectRef {
template <typename ValueType>
static int32_t RegisterConfigOption(const char* key) {
// NOTE: we could further update the function later.
int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
auto* reflection = ReflectionVTable::Global();
auto type_key = ffi::TypeIndexToTypeKey(tindex);

auto legalization = [=](ffi::Any value) -> ffi::Any {
if (auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
return reflection->CreateObject(type_key, opt_map.value());
} else {
if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
auto* reflection = ReflectionVTable::Global();
auto type_key = ffi::TypeIndexToTypeKey(tindex);
auto legalization = [=](ffi::Any value) -> ffi::Any {
if (auto opt_map = value.try_cast<Map<String, ffi::Any>>()) {
return reflection->CreateObject(type_key, opt_map.value());
} else {
auto opt_val = value.try_cast<ValueType>();
if (!opt_val.has_value()) {
TVM_FFI_THROW(AttributeError)
<< "Expect config " << key << " to have type " << type_key << ", but instead get "
<< ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
}
return *opt_val;
}
};
RegisterConfigOption(key, type_key, legalization);
} else {
// non-object type, do not support implicit conversion from map
std::string type_str = ffi::TypeTraits<ValueType>::TypeStr();
auto legalization = [=](ffi::Any value) -> ffi::Any {
auto opt_val = value.try_cast<ValueType>();
if (!opt_val.has_value()) {
TVM_FFI_THROW(AttributeError)
<< "Expect config " << key << " to have type " << type_key << ", but instead get "
<< "Expect config " << key << " to have type " << type_str << ", but instead get "
<< ffi::details::AnyUnsafe::GetMismatchTypeInfo<ValueType>(value);
} else {
return *opt_val;
}
return value;
}
};

RegisterConfigOption(key, tindex, legalization);
return tindex;
};
RegisterConfigOption(key, type_str, legalization);
}
return 0;
}

// accessor.
Expand All @@ -274,7 +288,7 @@ class PassContext : public ObjectRef {
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Register configuration key value type.
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index,
TVM_DLL static void RegisterConfigOption(const char* key, String value_type_str,
std::function<ffi::Any(ffi::Any)> legalization);

// Classes to get the Python `with` like syntax.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/profiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ class MetricCollectorNode : public Object {
/*! \brief Stop collecting metrics.
* \param obj The object created by the corresponding `Start` call.
* \returns A set of metric names and the associated values. Values must be
* one of DurationNode, PercentNode, CountNode, or StringObj.
* one of DurationNode, PercentNode, CountNode, or String.
*/
virtual Map<String, ffi::Any> Stop(ffi::ObjectRef obj) = 0;

Expand Down
22 changes: 18 additions & 4 deletions include/tvm/target/target_kind.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,27 @@ struct ValueTypeInfoMaker<ValueType, std::false_type, std::false_type> {
using ValueTypeInfo = TargetKindNode::ValueTypeInfo;

ValueTypeInfo operator()() const {
int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
ValueTypeInfo info;
info.type_index = tindex;
info.type_key = runtime::Object::TypeIndex2Key(tindex);
info.key = nullptr;
info.val = nullptr;
return info;
if constexpr (std::is_base_of_v<ObjectRef, ValueType>) {
int32_t tindex = ffi::TypeToRuntimeTypeIndex<ValueType>::v();
info.type_index = tindex;
info.type_key = runtime::Object::TypeIndex2Key(tindex);
return info;
} else if constexpr (std::is_same_v<ValueType, String>) {
// special handle string since it can be backed by multiple types.
info.type_index = ffi::TypeIndex::kTVMFFIStr;
info.type_key = ffi::TypeTraits<ValueType>::TypeStr();
return info;
} else {
// TODO(tqchen) consider upgrade to leverage any system to support union type
constexpr int32_t tindex = ffi::TypeToFieldStaticTypeIndex<ValueType>::value;
static_assert(tindex != ffi::TypeIndex::kTVMFFIAny, "Do not support union type for now");
info.type_index = tindex;
info.type_key = runtime::Object::TypeIndex2Key(tindex);
return info;
}
}
};

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/exec/disco_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def _str_func(x: str):


@register_func("tests.disco.str_obj", override=True)
def _str_obj_func(x: String):
assert isinstance(x, String)
def _str_obj_func(x: str):
assert isinstance(x, str)
return String(x + "_suffix")


Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ffi/cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ cdef inline int make_args(tuple py_args, TVMFFIAny* out, list temp_args) except
out[i].type_index = kTVMFFINDArray
out[i].v_ptr = (<NDArray>arg).chandle
temp_args.append(arg)
elif isinstance(arg, PyNativeObject):
elif isinstance(arg, PyNativeObject) and arg.__tvm_ffi_object__ is not None:
arg = arg.__tvm_ffi_object__
out[i].type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
out[i].v_ptr = (<Object>arg).chandle
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/ffi/cython/string.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class String(str, PyNativeObject):
"""
def __new__(cls, value):
val = str.__new__(cls, value)
val.__init_tvm_ffi_object_by_constructor__(_STR_CONSTRUCTOR, value)
val.__tvm_ffi_object__ = None
return val

# pylint: disable=no-self-argument
Expand All @@ -65,7 +65,7 @@ class Bytes(bytes, PyNativeObject):
"""
def __new__(cls, value):
val = bytes.__new__(cls, value)
val.__init_tvm_ffi_object_by_constructor__(_BYTES_CONSTRUCTOR, value)
val.__tvm_ffi_object__ = None
return val

# pylint: disable=no-self-argument
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tvm.error import TVMError
from tvm.ir import Array, IRModule, Map
from tvm.rpc import RPCSession
from tvm.runtime import PackedFunc, String
from tvm.runtime import PackedFunc
from tvm.tir import FloatImm, IntImm


Expand Down Expand Up @@ -352,7 +352,7 @@ def _json_de_tvm(obj: Any) -> Any:
return obj
if isinstance(obj, (IntImm, FloatImm)):
return obj.value
if isinstance(obj, (str, String)):
if isinstance(obj, (str,)):
return str(obj)
if isinstance(obj, Array):
return [_json_de_tvm(i) for i in obj]
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import tvm
from .. import tir
from ..tir import PrimExpr
from ..runtime import String
from . import _ffi_api
from .expr import Tuple as rx_Tuple
from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
Expand Down Expand Up @@ -114,7 +113,7 @@ def convert_to_expr(value: Any) -> Expr:
if isinstance(tvm_value, PrimExpr):
return PrimValue(value)
# Case 3
if isinstance(tvm_value, (str, String)):
if isinstance(tvm_value, (str,)):
return StringImm(value)
# Case 4
if isinstance(value, (tuple, list)):
Expand Down
3 changes: 1 addition & 2 deletions python/tvm/tir/schedule/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from tvm.runtime import Object

from ...ir import Array, Map, save_json
from ...runtime import String
from ..expr import FloatImm, IntImm
from ..function import IndexMap
from . import _ffi_api
Expand All @@ -45,7 +44,7 @@ def _json_from_tvm(obj):
return [_json_from_tvm(i) for i in obj]
elif isinstance(obj, Map):
return {_json_from_tvm(k): _json_from_tvm(v) for k, v in obj.items()}
elif isinstance(obj, String):
elif isinstance(obj, str):
return str(obj)
elif isinstance(obj, (IntImm, FloatImm)):
return obj
Expand Down
4 changes: 2 additions & 2 deletions src/contrib/msc/core/printer/msc_base_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) {
} else {
output_ << float_imm->value;
}
} else if (const auto* string_obj = value.as<ffi::StringObj>()) {
output_ << "\"" << tvm::support::StrEscape(string_obj->data, string_obj->size) << "\"";
} else if (auto opt_str = value.as<ffi::String>()) {
output_ << "\"" << tvm::support::StrEscape((*opt_str).data(), (*opt_str).size()) << "\"";
} else {
LOG(FATAL) << "TypeError: Unsupported literal value type: " << value.GetTypeKey();
}
Expand Down
4 changes: 2 additions & 2 deletions src/contrib/msc/core/printer/prototxt_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace contrib {
namespace msc {

LiteralDoc PrototxtPrinter::ToLiteralDoc(const ffi::Any& obj) {
if (obj.as<ffi::StringObj>()) {
return LiteralDoc::Str(Downcast<String>(obj), std::nullopt);
if (auto opt_str = obj.as<ffi::String>()) {
return LiteralDoc::Str(*opt_str, std::nullopt);
} else if (obj.as<IntImmNode>()) {
return LiteralDoc::Int(Downcast<IntImm>(obj)->value, std::nullopt);
} else if (obj.as<FloatImmNode>()) {
Expand Down
Loading
Loading