From a8017ced44f2c319f53051fc0af1cf7d5afa7a7a Mon Sep 17 00:00:00 2001 From: Mashed Potato <38517644+potatomashed@users.noreply.github.com> Date: Sun, 2 Feb 2025 17:34:44 -0800 Subject: [PATCH] feat(dataclass): Introduce `mlcd.prototype` (#19) This PR introduces `mlcd.prototype` that aims to help easier migration between C++ and Python. The method `prototype` prints, for example, when `lang="c++"`, the C++ definition of existing Python dataclasses defined in MLC API. More specifically, the command below: ``` python -c "import mlc.dataclasses as mlcd; print(mlcd.prototype(\"mlc.sym.*\", lang=\"c++\" export_macro=\"MLC_SYM_EXPORTS\"))" ``` prints all C++ definition code for types with prefix `mlc.sym.*`, and the export macro is `MLC_SYM_EXPORTS`. --- python/mlc/_cython/__init__.py | 1 + python/mlc/_cython/core.pyx | 4 ++ python/mlc/dataclasses/__init__.py | 8 ++- python/mlc/dataclasses/c_class.py | 4 +- python/mlc/dataclasses/utils.py | 63 ++++++++++++++++------ tests/python/test_dataclasses_prototype.py | 6 +-- 6 files changed, 63 insertions(+), 23 deletions(-) diff --git a/python/mlc/_cython/__init__.py b/python/mlc/_cython/__init__.py index 9fc0cbae..4a0246bf 100644 --- a/python/mlc/_cython/__init__.py +++ b/python/mlc/_cython/__init__.py @@ -58,6 +58,7 @@ type_key2py_type_info, type_register_fields, type_register_structure, + type_table, ) LIB: _ctypes.CDLL = _core.LIB diff --git a/python/mlc/_cython/core.pyx b/python/mlc/_cython/core.pyx index a99a896b..1a9eea8e 100644 --- a/python/mlc/_cython/core.pyx +++ b/python/mlc/_cython/core.pyx @@ -1656,6 +1656,10 @@ def type_create(int32_t parent_type_index, str type_key): return type_info +cpdef list type_table(): + return list(TYPE_INDEX_TO_INFO) + + cdef const char* _DLPACK_CAPSULE_NAME = "dltensor" cdef const char* _DLPACK_CAPSULE_NAME_USED = "used_dltensor" cdef const char* _DLPACK_CAPSULE_NAME_VER = "dltensor_versioned" diff --git a/python/mlc/dataclasses/__init__.py b/python/mlc/dataclasses/__init__.py index 99f1de90..7a90826c 100644 --- a/python/mlc/dataclasses/__init__.py +++ b/python/mlc/dataclasses/__init__.py @@ -1,3 +1,9 @@ from .c_class import c_class from .py_class import PyClass, py_class -from .utils import Structure, add_vtable_method, field, prototype_cxx, prototype_py, vtable_method +from .utils import ( + Structure, + add_vtable_method, + field, + prototype, + vtable_method, +) diff --git a/python/mlc/dataclasses/c_class.py b/python/mlc/dataclasses/c_class.py index 2e573518..e13caac9 100644 --- a/python/mlc/dataclasses/c_class.py +++ b/python/mlc/dataclasses/c_class.py @@ -18,7 +18,7 @@ get_parent_type, inspect_dataclass_fields, method_init, - prototype_py, + prototype, ) ClsType = typing.TypeVar("ClsType") @@ -117,5 +117,5 @@ def _check_c_class( if warned: warnings.warn( f"One or multiple warnings in `{type_cls.__module__}.{type_cls.__qualname__}`. Its prototype is:\n" - + prototype_py(type_info) + + prototype(type_info, lang="py") ) diff --git a/python/mlc/dataclasses/utils.py b/python/mlc/dataclasses/utils.py index 6ff3a257..2271dc0b 100644 --- a/python/mlc/dataclasses/utils.py +++ b/python/mlc/dataclasses/utils.py @@ -2,7 +2,9 @@ import ctypes import dataclasses +import functools import inspect +import re import typing from collections.abc import Callable from io import StringIO @@ -16,6 +18,7 @@ TypeMethod, type_add_method, type_index2type_methods, + type_table, ) from mlc.core import typing as mlc_typing @@ -325,13 +328,9 @@ def add_vtable_methods_for_type_cls(type_cls: type, type_index: int) -> None: type_add_method(type_index, name, func, kind=0) -def prototype_py(type_info: type | TypeInfo) -> str: - if not isinstance(type_info, TypeInfo): - if (type_info := getattr(type_info, "_mlc_type_info", None)) is None: # type: ignore[assignment] - raise ValueError(f"Invalid type: {type_info}") +def _prototype_py(type_info: TypeInfo) -> str: assert isinstance(type_info, TypeInfo) cls_name = type_info.type_key.rsplit(".", maxsplit=1)[-1] - io = StringIO() print(f"@mlc.dataclasses.c_class({type_info.type_key!r})", file=io) print(f"class {cls_name}:", file=io) @@ -352,12 +351,11 @@ def prototype_py(type_info: type | TypeInfo) -> str: return io.getvalue().rstrip() -def prototype_cxx(type_info: type | TypeInfo) -> str: - if not isinstance(type_info, TypeInfo): - if (type_info := getattr(type_info, "_mlc_type_info", None)) is None: # type: ignore[assignment] - raise ValueError(f"Invalid type: {type_info}") +def _prototype_cxx( + type_info: TypeInfo, + export_macro: str = "_EXPORTS", +) -> str: assert isinstance(type_info, TypeInfo) - parent_type_info = type_info.get_parent() namespaces = type_info.type_key.split(".") cls_name = namespaces[-1] @@ -388,22 +386,22 @@ def prototype_cxx(type_info: type | TypeInfo) -> str: if i != 0: print(", ", file=io, end="") print(f"{ty} {name}", file=io, end="") - print("): ", file=io, end="") - for i, (name, _) in enumerate(fields): - if i != 0: - print(", ", file=io, end="") - print(f"{name}({name})", file=io, end="") + print("): _mlc_header{}", file=io, end="") + for name, _ in fields: + print(f", {name}({name})", file=io, end="") print(" {}", file=io) # Step 2.3. Macro to define object type print( - f' MLC_DEF_DYN_TYPE(_EXPORTS, {cls_name}Obj, {parent_obj_name}, "{type_info.type_key}");', + f' MLC_DEF_DYN_TYPE({export_macro}, {cls_name}Obj, {parent_obj_name}, "{type_info.type_key}");', file=io, ) print(f"}}; // struct {cls_name}Obj\n", file=io) # Step 3. Object reference class print(f"struct {cls_name} : public {parent_ref_name} {{", file=io) # Step 3.1. Define fields for reflection - print(f" MLC_DEF_OBJ_REF(_EXPORTS, {cls_name}, {cls_name}Obj, {parent_ref_name})", file=io) + print( + f" MLC_DEF_OBJ_REF({export_macro}, {cls_name}, {cls_name}Obj, {parent_ref_name})", file=io + ) for name, _ in fields: print(f' .Field("{name}", &{cls_name}Obj::{name})', file=io) # Step 3.2. Define `__init__` method for reflection @@ -416,3 +414,34 @@ def prototype_cxx(type_info: type | TypeInfo) -> str: for ns in reversed(namespaces[:-1]): print(f"}} // namespace {ns}", file=io) return io.getvalue().rstrip() + + +def prototype( + match: str | type | TypeInfo | Callable[[TypeInfo], bool], + lang: Literal["c++", "py"] = "c++", + export_macro: str = "_EXPORTS", +) -> str: + type_info_list: list[TypeInfo] + if ( + isinstance(match, type) + and (type_info := getattr(match, "_mlc_type_info", None)) is not None + ): + assert isinstance(type_info, TypeInfo) + type_info_list = [type_info] + elif isinstance(match, TypeInfo): + type_info_list = [match] + elif isinstance(match, str): + pattern = re.compile(match) + type_info_list = [i for i in type_table() if i and pattern.fullmatch(i.type_key)] + elif callable(match): + type_info_list = [i for i in type_table() if i and match(i)] + else: + raise ValueError(f"Invalid `match`: {match}") + fn: Callable[[TypeInfo], str] + if lang == "c++": + fn = functools.partial(_prototype_cxx, export_macro=export_macro) + elif lang == "py": + fn = _prototype_py + else: + raise ValueError(f"Invalid `lang`: {lang}") + return "\n\n".join(fn(i) for i in type_info_list) diff --git a/tests/python/test_dataclasses_prototype.py b/tests/python/test_dataclasses_prototype.py index ab01b292..546d26ec 100644 --- a/tests/python/test_dataclasses_prototype.py +++ b/tests/python/test_dataclasses_prototype.py @@ -45,7 +45,7 @@ class py_class: opt_dict_any_str: dict[Any, str] | None opt_dict_str_list_int: dict[str, list[int]] | None """.strip() - actual = mlcd.prototype_py(PyClassForTest).strip() + actual = mlcd.prototype(PyClassForTest, lang="py").strip() assert actual == expected @@ -93,7 +93,7 @@ def test_prototype_cxx() -> None: ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::Any>> opt_dict_str_any; ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Str>> opt_dict_any_str; ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::List>> opt_dict_str_list_int; - explicit py_classObj(bool bool_, int64_t i8, int64_t i16, int64_t i32, int64_t i64, double f32, double f64, void* raw_ptr, DLDataType dtype, DLDevice device, ::mlc::Any any, ::mlc::Func func, ::mlc::List<::mlc::Any> ulist, ::mlc::Dict<::mlc::Any, ::mlc::Any> udict, ::mlc::Str str_, ::mlc::Str str_readonly, ::mlc::List<::mlc::Any> list_any, ::mlc::List<::mlc::List> list_list_int, ::mlc::Dict<::mlc::Any, ::mlc::Any> dict_any_any, ::mlc::Dict<::mlc::Str, ::mlc::Any> dict_str_any, ::mlc::Dict<::mlc::Any, ::mlc::Str> dict_any_str, ::mlc::Dict<::mlc::Str, ::mlc::List> dict_str_list_int, ::mlc::Optional opt_bool, ::mlc::Optional opt_i64, ::mlc::Optional opt_f64, ::mlc::Optional opt_raw_ptr, ::mlc::Optional opt_dtype, ::mlc::Optional opt_device, ::mlc::Optional<::mlc::Func> opt_func, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_ulist, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_udict, ::mlc::Optional<::mlc::Str> opt_str, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_list_any, ::mlc::Optional<::mlc::List<::mlc::List>> opt_list_list_int, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_dict_any_any, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::Any>> opt_dict_str_any, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Str>> opt_dict_any_str, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::List>> opt_dict_str_list_int): bool_(bool_), i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype), device(device), any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly), list_any(list_any), list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any), dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_bool(opt_bool), opt_i64(opt_i64), opt_f64(opt_f64), opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func), opt_ulist(opt_ulist), opt_udict(opt_udict), opt_str(opt_str), opt_list_any(opt_list_any), opt_list_list_int(opt_list_list_int), opt_dict_any_any(opt_dict_any_any), opt_dict_str_any(opt_dict_str_any), opt_dict_any_str(opt_dict_any_str), opt_dict_str_list_int(opt_dict_str_list_int) {} + explicit py_classObj(bool bool_, int64_t i8, int64_t i16, int64_t i32, int64_t i64, double f32, double f64, void* raw_ptr, DLDataType dtype, DLDevice device, ::mlc::Any any, ::mlc::Func func, ::mlc::List<::mlc::Any> ulist, ::mlc::Dict<::mlc::Any, ::mlc::Any> udict, ::mlc::Str str_, ::mlc::Str str_readonly, ::mlc::List<::mlc::Any> list_any, ::mlc::List<::mlc::List> list_list_int, ::mlc::Dict<::mlc::Any, ::mlc::Any> dict_any_any, ::mlc::Dict<::mlc::Str, ::mlc::Any> dict_str_any, ::mlc::Dict<::mlc::Any, ::mlc::Str> dict_any_str, ::mlc::Dict<::mlc::Str, ::mlc::List> dict_str_list_int, ::mlc::Optional opt_bool, ::mlc::Optional opt_i64, ::mlc::Optional opt_f64, ::mlc::Optional opt_raw_ptr, ::mlc::Optional opt_dtype, ::mlc::Optional opt_device, ::mlc::Optional<::mlc::Func> opt_func, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_ulist, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_udict, ::mlc::Optional<::mlc::Str> opt_str, ::mlc::Optional<::mlc::List<::mlc::Any>> opt_list_any, ::mlc::Optional<::mlc::List<::mlc::List>> opt_list_list_int, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Any>> opt_dict_any_any, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::Any>> opt_dict_str_any, ::mlc::Optional<::mlc::Dict<::mlc::Any, ::mlc::Str>> opt_dict_any_str, ::mlc::Optional<::mlc::Dict<::mlc::Str, ::mlc::List>> opt_dict_str_list_int): _mlc_header{}, bool_(bool_), i8(i8), i16(i16), i32(i32), i64(i64), f32(f32), f64(f64), raw_ptr(raw_ptr), dtype(dtype), device(device), any(any), func(func), ulist(ulist), udict(udict), str_(str_), str_readonly(str_readonly), list_any(list_any), list_list_int(list_list_int), dict_any_any(dict_any_any), dict_str_any(dict_str_any), dict_any_str(dict_any_str), dict_str_list_int(dict_str_list_int), opt_bool(opt_bool), opt_i64(opt_i64), opt_f64(opt_f64), opt_raw_ptr(opt_raw_ptr), opt_dtype(opt_dtype), opt_device(opt_device), opt_func(opt_func), opt_ulist(opt_ulist), opt_udict(opt_udict), opt_str(opt_str), opt_list_any(opt_list_any), opt_list_list_int(opt_list_list_int), opt_dict_any_any(opt_dict_any_any), opt_dict_str_any(opt_dict_str_any), opt_dict_any_str(opt_dict_any_str), opt_dict_str_list_int(opt_dict_str_list_int) {} MLC_DEF_DYN_TYPE(_EXPORTS, py_classObj, ::mlc::Object, "mlc.testing.py_class"); }; // struct py_classObj @@ -142,5 +142,5 @@ def test_prototype_cxx() -> None: } // namespace testing } // namespace mlc """.strip() - actual = mlcd.prototype_cxx(PyClassForTest).strip() + actual = mlcd.prototype(PyClassForTest, lang="c++").strip() assert actual == expected