Skip to content

Commit

Permalink
feat(dataclass): Introduce mlcd.prototype (#19)
Browse files Browse the repository at this point in the history
This PR introduces `mlcd.prototype` that aims to help easier migration
between C++ and Python.

The method `prototype` prints, for example, when `lang="c++"`, the C++
definition of existing Python dataclasses defined in MLC API. More
specifically, the command below:

```
python -c "import mlc.dataclasses as mlcd; print(mlcd.prototype(\"mlc.sym.*\", lang=\"c++\" export_macro=\"MLC_SYM_EXPORTS\"))"
```

prints all C++ definition code for types with prefix `mlc.sym.*`, and
the export macro is `MLC_SYM_EXPORTS`.
  • Loading branch information
potatomashed authored Feb 3, 2025
1 parent b584462 commit a8017ce
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 23 deletions.
1 change: 1 addition & 0 deletions python/mlc/_cython/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
type_key2py_type_info,
type_register_fields,
type_register_structure,
type_table,
)

LIB: _ctypes.CDLL = _core.LIB
Expand Down
4 changes: 4 additions & 0 deletions python/mlc/_cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1656,6 +1656,10 @@ def type_create(int32_t parent_type_index, str type_key):
return type_info


cpdef list type_table():
return list(TYPE_INDEX_TO_INFO)


cdef const char* _DLPACK_CAPSULE_NAME = "dltensor"
cdef const char* _DLPACK_CAPSULE_NAME_USED = "used_dltensor"
cdef const char* _DLPACK_CAPSULE_NAME_VER = "dltensor_versioned"
Expand Down
8 changes: 7 additions & 1 deletion python/mlc/dataclasses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
from .c_class import c_class
from .py_class import PyClass, py_class
from .utils import Structure, add_vtable_method, field, prototype_cxx, prototype_py, vtable_method
from .utils import (
Structure,
add_vtable_method,
field,
prototype,
vtable_method,
)
4 changes: 2 additions & 2 deletions python/mlc/dataclasses/c_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
get_parent_type,
inspect_dataclass_fields,
method_init,
prototype_py,
prototype,
)

ClsType = typing.TypeVar("ClsType")
Expand Down Expand Up @@ -117,5 +117,5 @@ def _check_c_class(
if warned:
warnings.warn(
f"One or multiple warnings in `{type_cls.__module__}.{type_cls.__qualname__}`. Its prototype is:\n"
+ prototype_py(type_info)
+ prototype(type_info, lang="py")
)
63 changes: 46 additions & 17 deletions python/mlc/dataclasses/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import ctypes
import dataclasses
import functools
import inspect
import re
import typing
from collections.abc import Callable
from io import StringIO
Expand All @@ -16,6 +18,7 @@
TypeMethod,
type_add_method,
type_index2type_methods,
type_table,
)
from mlc.core import typing as mlc_typing

Expand Down Expand Up @@ -325,13 +328,9 @@ def add_vtable_methods_for_type_cls(type_cls: type, type_index: int) -> None:
type_add_method(type_index, name, func, kind=0)


def prototype_py(type_info: type | TypeInfo) -> str:
if not isinstance(type_info, TypeInfo):
if (type_info := getattr(type_info, "_mlc_type_info", None)) is None: # type: ignore[assignment]
raise ValueError(f"Invalid type: {type_info}")
def _prototype_py(type_info: TypeInfo) -> str:
assert isinstance(type_info, TypeInfo)
cls_name = type_info.type_key.rsplit(".", maxsplit=1)[-1]

io = StringIO()
print(f"@mlc.dataclasses.c_class({type_info.type_key!r})", file=io)
print(f"class {cls_name}:", file=io)
Expand All @@ -352,12 +351,11 @@ def prototype_py(type_info: type | TypeInfo) -> str:
return io.getvalue().rstrip()


def prototype_cxx(type_info: type | TypeInfo) -> str:
if not isinstance(type_info, TypeInfo):
if (type_info := getattr(type_info, "_mlc_type_info", None)) is None: # type: ignore[assignment]
raise ValueError(f"Invalid type: {type_info}")
def _prototype_cxx(
type_info: TypeInfo,
export_macro: str = "_EXPORTS",
) -> str:
assert isinstance(type_info, TypeInfo)

parent_type_info = type_info.get_parent()
namespaces = type_info.type_key.split(".")
cls_name = namespaces[-1]
Expand Down Expand Up @@ -388,22 +386,22 @@ def prototype_cxx(type_info: type | TypeInfo) -> str:
if i != 0:
print(", ", file=io, end="")
print(f"{ty} {name}", file=io, end="")
print("): ", file=io, end="")
for i, (name, _) in enumerate(fields):
if i != 0:
print(", ", file=io, end="")
print(f"{name}({name})", file=io, end="")
print("): _mlc_header{}", file=io, end="")
for name, _ in fields:
print(f", {name}({name})", file=io, end="")
print(" {}", file=io)
# Step 2.3. Macro to define object type
print(
f' MLC_DEF_DYN_TYPE(_EXPORTS, {cls_name}Obj, {parent_obj_name}, "{type_info.type_key}");',
f' MLC_DEF_DYN_TYPE({export_macro}, {cls_name}Obj, {parent_obj_name}, "{type_info.type_key}");',
file=io,
)
print(f"}}; // struct {cls_name}Obj\n", file=io)
# Step 3. Object reference class
print(f"struct {cls_name} : public {parent_ref_name} {{", file=io)
# Step 3.1. Define fields for reflection
print(f" MLC_DEF_OBJ_REF(_EXPORTS, {cls_name}, {cls_name}Obj, {parent_ref_name})", file=io)
print(
f" MLC_DEF_OBJ_REF({export_macro}, {cls_name}, {cls_name}Obj, {parent_ref_name})", file=io
)
for name, _ in fields:
print(f' .Field("{name}", &{cls_name}Obj::{name})', file=io)
# Step 3.2. Define `__init__` method for reflection
Expand All @@ -416,3 +414,34 @@ def prototype_cxx(type_info: type | TypeInfo) -> str:
for ns in reversed(namespaces[:-1]):
print(f"}} // namespace {ns}", file=io)
return io.getvalue().rstrip()


def prototype(
match: str | type | TypeInfo | Callable[[TypeInfo], bool],
lang: Literal["c++", "py"] = "c++",
export_macro: str = "_EXPORTS",
) -> str:
type_info_list: list[TypeInfo]
if (
isinstance(match, type)
and (type_info := getattr(match, "_mlc_type_info", None)) is not None
):
assert isinstance(type_info, TypeInfo)
type_info_list = [type_info]
elif isinstance(match, TypeInfo):
type_info_list = [match]
elif isinstance(match, str):
pattern = re.compile(match)
type_info_list = [i for i in type_table() if i and pattern.fullmatch(i.type_key)]
elif callable(match):
type_info_list = [i for i in type_table() if i and match(i)]
else:
raise ValueError(f"Invalid `match`: {match}")
fn: Callable[[TypeInfo], str]
if lang == "c++":
fn = functools.partial(_prototype_cxx, export_macro=export_macro)
elif lang == "py":
fn = _prototype_py
else:
raise ValueError(f"Invalid `lang`: {lang}")
return "\n\n".join(fn(i) for i in type_info_list)
6 changes: 3 additions & 3 deletions tests/python/test_dataclasses_prototype.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class py_class:
opt_dict_any_str: dict[Any, str] | None
opt_dict_str_list_int: dict[str, list[int]] | None
""".strip()
actual = mlcd.prototype_py(PyClassForTest).strip()
actual = mlcd.prototype(PyClassForTest, lang="py").strip()
assert actual == expected


Expand Down Expand Up @@ -93,7 +93,7 @@ def test_prototype_cxx() -> None:
::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::Any>> opt_dict_str_any;
::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Str>> opt_dict_any_str;
::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::List<int64_t>>> opt_dict_str_list_int;
explicit py_classObj(bool bool_, int64_t i8, int64_t i16, int64_t i32, int64_t i64, double f32, double f64, void* raw_ptr, DLDataType dtype, DLDevice device, ::mlc::Any any, ::mlc::Func func, ::mlc::List<::mlc::Any> ulist, ::mlc::Dict<::mlc::Any, ::mlc::Any> udict, ::mlc::Str str_, ::mlc::Str str_readonly, ::mlc::List<::mlc::Any> list_any, ::mlc::List<::mlc::List<int64_t>> list_list_int, ::mlc::Dict<::mlc::Any, ::mlc::Any> dict_any_any, ::mlc::Dict<::mlc::Str, ::mlc::Any> dict_str_any, ::mlc::Dict<::mlc::Any, ::mlc::Str> dict_any_str, ::mlc::Dict<::mlc::Str, ::mlc::List<int64_t>> dict_str_list_int, ::mlc::Optional<bool> opt_bool, ::mlc::Optional<int64_t> opt_i64, ::mlc::Optional<double> opt_f64, ::mlc::Optional<void*> opt_raw_ptr, ::mlc::Optional<DLDataType> opt_dtype, ::mlc::Optional<DLDevice> opt_device, ::mlc::Optional<::mlc::Func> opt_func, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_ulist, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_udict, ::mlc::Optional<::mlc::Str> opt_str, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_list_any, ::mlc::Optional<::mlc::List<::mlc::List<int64_t>>> opt_list_list_int, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_dict_any_any, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::Any>> opt_dict_str_any, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Str>> opt_dict_any_str, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::List<int64_t>>> opt_dict_str_list_int): bool_(bool_), i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype), device(device), any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly), list_any(list_any), list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any), dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_bool(opt_bool), opt_i64(opt_i64), opt_f64(opt_f64), opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func), opt_ulist(opt_ulist), opt_udict(opt_udict), opt_str(opt_str), opt_list_any(opt_list_any), opt_list_list_int(opt_list_list_int), opt_dict_any_any(opt_dict_any_any), opt_dict_str_any(opt_dict_str_any), opt_dict_any_str(opt_dict_any_str), opt_dict_str_list_int(opt_dict_str_list_int) {}
explicit py_classObj(bool bool_, int64_t i8, int64_t i16, int64_t i32, int64_t i64, double f32, double f64, void* raw_ptr, DLDataType dtype, DLDevice device, ::mlc::Any any, ::mlc::Func func, ::mlc::List<::mlc::Any> ulist, ::mlc::Dict<::mlc::Any, ::mlc::Any> udict, ::mlc::Str str_, ::mlc::Str str_readonly, ::mlc::List<::mlc::Any> list_any, ::mlc::List<::mlc::List<int64_t>> list_list_int, ::mlc::Dict<::mlc::Any, ::mlc::Any> dict_any_any, ::mlc::Dict<::mlc::Str, ::mlc::Any> dict_str_any, ::mlc::Dict<::mlc::Any, ::mlc::Str> dict_any_str, ::mlc::Dict<::mlc::Str, ::mlc::List<int64_t>> dict_str_list_int, ::mlc::Optional<bool> opt_bool, ::mlc::Optional<int64_t> opt_i64, ::mlc::Optional<double> opt_f64, ::mlc::Optional<void*> opt_raw_ptr, ::mlc::Optional<DLDataType> opt_dtype, ::mlc::Optional<DLDevice> opt_device, ::mlc::Optional<::mlc::Func> opt_func, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_ulist, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_udict, ::mlc::Optional<::mlc::Str> opt_str, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_list_any, ::mlc::Optional<::mlc::List<::mlc::List<int64_t>>> opt_list_list_int, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_dict_any_any, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::Any>> opt_dict_str_any, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Str>> opt_dict_any_str, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::List<int64_t>>> opt_dict_str_list_int): _mlc_header{}, bool_(bool_), i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype), device(device), any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly), list_any(list_any), list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any), dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_bool(opt_bool), opt_i64(opt_i64), opt_f64(opt_f64), opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func), opt_ulist(opt_ulist), opt_udict(opt_udict), opt_str(opt_str), opt_list_any(opt_list_any), opt_list_list_int(opt_list_list_int), opt_dict_any_any(opt_dict_any_any), opt_dict_str_any(opt_dict_str_any), opt_dict_any_str(opt_dict_any_str), opt_dict_str_list_int(opt_dict_str_list_int) {}
MLC_DEF_DYN_TYPE(_EXPORTS, py_classObj, ::mlc::Object, "mlc.testing.py_class");
}; // struct py_classObj
Expand Down Expand Up @@ -142,5 +142,5 @@ def test_prototype_cxx() -> None:
} // namespace testing
} // namespace mlc
""".strip()
actual = mlcd.prototype_cxx(PyClassForTest).strip()
actual = mlcd.prototype(PyClassForTest, lang="c++").strip()
assert actual == expected

0 comments on commit a8017ce

Please sign in to comment.