diff --git a/python/rwkv_cpp/rwkv_cpp_model.py b/python/rwkv_cpp/rwkv_cpp_model.py index 80304db..1f79972 100644 --- a/python/rwkv_cpp/rwkv_cpp_model.py +++ b/python/rwkv_cpp/rwkv_cpp_model.py @@ -70,6 +70,14 @@ def __init__( self._valid: bool = True + @property + def arch_version_major(self) -> int: + return self._library.rwkv_get_arch_version_major(self._ctx) + + @property + def arch_version_minor(self) -> int: + return self._library.rwkv_get_arch_version_minor(self._ctx) + @property def n_vocab(self) -> int: return self._library.rwkv_get_n_vocab(self._ctx) diff --git a/python/rwkv_cpp/rwkv_cpp_shared_library.py b/python/rwkv_cpp/rwkv_cpp_shared_library.py index a42dec5..0c2630e 100644 --- a/python/rwkv_cpp/rwkv_cpp_shared_library.py +++ b/python/rwkv_cpp/rwkv_cpp_shared_library.py @@ -79,6 +79,12 @@ def __init__(self, shared_library_path: str) -> None: ] self.library.rwkv_eval_sequence_in_chunks.restype = ctypes.c_bool + self.library.rwkv_get_arch_version_major.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_arch_version_major.restype = ctypes.c_uint32 + + self.library.rwkv_get_arch_version_minor.argtypes = [ctypes.c_void_p] + self.library.rwkv_get_arch_version_minor.restype = ctypes.c_uint32 + self.library.rwkv_get_n_vocab.argtypes = [ctypes.c_void_p] self.library.rwkv_get_n_vocab.restype = ctypes.c_size_t @@ -261,6 +267,30 @@ def rwkv_eval_sequence_in_chunks( ): raise ValueError('rwkv_eval_sequence_in_chunks failed, check stderr') + def rwkv_get_arch_version_major(self, ctx: RWKVContext) -> int: + """ + Returns the major version used by the given model. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_arch_version_major(ctx.ptr) + + def rwkv_get_arch_version_minor(self, ctx: RWKVContext) -> int: + """ + Returns the minor version used by the given model. + + Parameters + ---------- + ctx : RWKVContext + RWKV context obtained from rwkv_init_from_file. + """ + + return self.library.rwkv_get_arch_version_minor(ctx.ptr) + def rwkv_get_n_vocab(self, ctx: RWKVContext) -> int: """ Returns the number of tokens in the given model's vocabulary. diff --git a/rwkv.cpp b/rwkv.cpp index 19a78c4..c5bad1c 100644 --- a/rwkv.cpp +++ b/rwkv.cpp @@ -152,6 +152,16 @@ extern "C" RWKV_API uint32_t rwkv_get_logits_buffer_element_count(const struct r return rwkv_get_logits_len(ctx); } +// API function. +size_t rwkv_get_arch_version_major(const struct rwkv_context * ctx) { + return (size_t) ctx->model->arch_version_major; +} + +// API function. +size_t rwkv_get_arch_version_minor(const struct rwkv_context * ctx) { + return (size_t) ctx->model->arch_version_minor; +} + // API function. size_t rwkv_get_n_vocab(const struct rwkv_context * ctx) { return (size_t) ctx->model->header.n_vocab; diff --git a/rwkv.h b/rwkv.h index 978ed45..455a35a 100644 --- a/rwkv.h +++ b/rwkv.h @@ -172,6 +172,12 @@ extern "C" { float * logits_out ); + // Returns the major version used by the given model. + RWKV_API size_t rwkv_get_arch_version_major(const struct rwkv_context * ctx); + + // Returns the minor version used by the given model. + RWKV_API size_t rwkv_get_arch_version_minor(const struct rwkv_context * ctx); + // Returns the number of tokens in the given model's vocabulary. // Useful for telling 20B_tokenizer models (n_vocab = 50277) apart from World models (n_vocab = 65536). RWKV_API size_t rwkv_get_n_vocab(const struct rwkv_context * ctx);