Skip to content

Commit

Permalink
fix(core): Correct Error Checking (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
potatomashed authored Feb 10, 2025
1 parent cd145d9 commit 4f414eb
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 26 deletions.
4 changes: 2 additions & 2 deletions cpp/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ thread_local Any last_error;
} // namespace

MLC_API MLCAny MLCGetLastError() {
MLCAny ret;
static_cast<Any &>(ret) = std::move(last_error);
MLCAny ret = static_cast<MLCAny &>(last_error);
static_cast<MLCAny &>(last_error) = MLCAny();
return ret;
}

Expand Down
16 changes: 8 additions & 8 deletions include/mlc/base/lib.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct VTable {
this->Swap(other);
return *this;
}
~VTable() { MLC_CHECK_ERR(::MLCVTableDelete(self), nullptr); }
~VTable() { MLC_CHECK_ERR(::MLCVTableDelete(self)); }

template <typename R, typename... Args> R operator()(Args... args) const;
template <typename Obj> VTable &Set(Func func);
Expand Down Expand Up @@ -41,17 +41,17 @@ struct Lib {
static FuncObj *_init(int32_t type_index) { return VTableGetFunc(init, type_index, "__init__"); }
static VTable MakeVTable(const char *name) {
MLCVTableHandle vtable = nullptr;
MLC_CHECK_ERR(::MLCVTableCreate(_lib, name, &vtable), nullptr);
MLC_CHECK_ERR(::MLCVTableCreate(_lib, name, &vtable));
return VTable(vtable);
}
MLC_INLINE static MLCTypeInfo *GetTypeInfo(int32_t type_index) {
MLCTypeInfo *type_info = nullptr;
MLC_CHECK_ERR(::MLCTypeIndex2Info(_lib, type_index, &type_info), nullptr);
MLC_CHECK_ERR(::MLCTypeIndex2Info(_lib, type_index, &type_info));
return type_info;
}
MLC_INLINE static MLCTypeInfo *GetTypeInfo(const char *type_key) {
MLCTypeInfo *type_info = nullptr;
MLC_CHECK_ERR(::MLCTypeKey2Info(_lib, type_key, &type_info), nullptr);
MLC_CHECK_ERR(::MLCTypeKey2Info(_lib, type_key, &type_info));
return type_info;
}
MLC_INLINE static const char *GetTypeKey(int32_t type_index) {
Expand All @@ -77,14 +77,14 @@ struct Lib {
}
MLC_INLINE static MLCTypeInfo *TypeRegister(int32_t parent_type_index, int32_t type_index, const char *type_key) {
MLCTypeInfo *info = nullptr;
MLC_CHECK_ERR(::MLCTypeRegister(_lib, parent_type_index, type_key, type_index, &info), nullptr);
MLC_CHECK_ERR(::MLCTypeRegister(_lib, parent_type_index, type_key, type_index, &info));
return info;
}

private:
static FuncObj *VTableGetFunc(MLCVTableHandle vtable, int32_t type_index, const char *vtable_name) {
MLCAny func{};
MLC_CHECK_ERR(::MLCVTableGetFunc(vtable, type_index, true, &func), &func);
MLC_CHECK_ERR(::MLCVTableGetFunc(vtable, type_index, true, &func));
if (!::mlc::base::IsTypeIndexPOD(func.type_index)) {
::mlc::base::DecRef(func.v.v_obj);
}
Expand All @@ -100,12 +100,12 @@ struct Lib {
}
static MLCVTableHandle VTableGetGlobal(const char *name) {
MLCVTableHandle ret = nullptr;
MLC_CHECK_ERR(::MLCVTableGetGlobal(_lib, name, &ret), nullptr);
MLC_CHECK_ERR(::MLCVTableGetGlobal(_lib, name, &ret));
return ret;
}
static MLC_SYMBOL_HIDE inline MLCTypeTableHandle _lib = []() {
MLCTypeTableHandle ret = nullptr;
MLC_CHECK_ERR(::MLCHandleGetGlobal(&ret), nullptr);
MLC_CHECK_ERR(::MLCHandleGetGlobal(&ret));
return ret;
}();
static MLC_SYMBOL_HIDE inline MLCVTableHandle cxx_str = VTableGetGlobal("__cxx_str__");
Expand Down
4 changes: 2 additions & 2 deletions include/mlc/base/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@
} \
MLC_UNREACHABLE()

#define MLC_CHECK_ERR(Call, Ret) \
#define MLC_CHECK_ERR(Call) \
if (int32_t err_code = (Call)) { \
::mlc::base::FuncCallCheckError(err_code, (Ret)); \
::mlc::base::FuncCallCheckError(err_code, nullptr); \
}

namespace mlc {
Expand Down
10 changes: 5 additions & 5 deletions include/mlc/core/all.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,12 @@ inline Any Lib::IRPrint(AnyView obj, AnyView printer, AnyView path) {
return ret;
}
inline int32_t Lib::FuncSetGlobal(const char *name, FuncObj *func, bool allow_override) {
MLC_CHECK_ERR(::MLCFuncSetGlobal(_lib, name, Any(func), allow_override), nullptr);
MLC_CHECK_ERR(::MLCFuncSetGlobal(_lib, name, Any(func), allow_override));
return 0;
}
inline FuncObj *Lib::FuncGetGlobal(const char *name, bool allow_missing) {
Any ret;
MLC_CHECK_ERR(::MLCFuncGetGlobal(_lib, name, &ret), &ret);
MLC_CHECK_ERR(::MLCFuncGetGlobal(_lib, name, &ret));
if (!ret.defined() && !allow_missing) {
MLC_THROW(KeyError) << "Missing global function: " << name;
}
Expand Down Expand Up @@ -205,12 +205,12 @@ template <typename R, typename... Args> inline R VTable::operator()(Args... args
AnyViewArray<N> stack_args;
Any ret;
stack_args.Fill(std::forward<Args>(args)...);
MLC_CHECK_ERR(::MLCVTableCall(self, N, stack_args.v, &ret), &ret);
MLC_CHECK_ERR(::MLCVTableCall(self, N, stack_args.v, &ret));
return ret;
}
template <typename Obj> inline VTable &VTable::Set(Func func) {
constexpr bool override_mode = false;
int32_t type_index = Obj::_type_index;
MLC_CHECK_ERR(::MLCVTableSetFunc(this->self, type_index, func.get(), override_mode), nullptr);
MLC_CHECK_ERR(::MLCVTableSetFunc(this->self, Obj::_type_index, func.get(), override_mode));
return *this;
}

Expand Down
4 changes: 2 additions & 2 deletions include/mlc/core/func.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ inline void FuncCall(const void *self, int32_t num_args, const MLCAny *args, MLC
const MLCFunc *func = static_cast<const MLCFunc *>(self);
if (func->call && reinterpret_cast<void *>(func->safe_call) == reinterpret_cast<void *>(FuncObj::SafeCallImpl)) {
func->call(func, num_args, args, ret);
} else {
MLC_CHECK_ERR(func->safe_call(func, num_args, args, ret), ret);
} else if (int32_t err_code = func->safe_call(func, num_args, args, ret)) {
FuncCallCheckError(err_code, ret);
}
}
template <int32_t num_args> inline auto GetGlobalFuncCall(const char *name) {
Expand Down
8 changes: 6 additions & 2 deletions include/mlc/core/func_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,16 @@ template <typename FuncType, typename> MLC_INLINE FuncObj *FuncObj::Allocator::N
inline Ref<FuncObj> FuncObj::FromForeign(void *self, MLCDeleterType deleter, MLCFuncSafeCallType safe_call) {
if (deleter == nullptr) {
return Ref<FuncObj>::New([self, safe_call](int32_t num_args, const MLCAny *args, MLCAny *ret) {
MLC_CHECK_ERR(safe_call(self, num_args, args, ret), ret);
if (int32_t err_code = safe_call(self, num_args, args, ret)) {
::mlc::base::FuncCallCheckError(err_code, ret);
}
});
} else {
return Ref<FuncObj>::New(
[self = std::shared_ptr<void>(self, deleter), safe_call](int32_t num_args, const MLCAny *args, MLCAny *ret) {
MLC_CHECK_ERR(safe_call(self.get(), num_args, args, ret), ret);
if (int32_t err_code = safe_call(self.get(), num_args, args, ret)) {
::mlc::base::FuncCallCheckError(err_code, ret);
}
});
}
}
Expand Down
8 changes: 3 additions & 5 deletions include/mlc/core/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,12 @@ struct _Reflect {
reinterpret_cast<MLCFunc *>(func_any_to_ref.v.v_obj), //
kStaticFn});
}
MLC_CHECK_ERR(::MLCTypeRegisterFields(nullptr, this->type_index, this->fields.size(), this->fields.data()),
nullptr);
MLC_CHECK_ERR(::MLCTypeRegisterFields(nullptr, this->type_index, this->fields.size(), this->fields.data()));
MLC_CHECK_ERR(::MLCTypeRegisterStructure(nullptr, this->type_index, static_cast<int32_t>(this->structure_kind),
this->sub_structure_indices.size(), this->sub_structure_indices.data(),
this->sub_structure_kinds.data()),
nullptr);
this->sub_structure_kinds.data()));
for (const MLCTypeMethod &method : this->methods) {
MLC_CHECK_ERR(::MLCTypeAddMethod(nullptr, this->type_index, method), nullptr);
MLC_CHECK_ERR(::MLCTypeAddMethod(nullptr, this->type_index, method));
}
}
return 0;
Expand Down

0 comments on commit 4f414eb

Please sign in to comment.