Allocate kv-cache on demand
li-plus committed Mar 12, 2024
1 parent 4289555 commit 815544d
Showing 7 changed files with 40 additions and 9 deletions.
18 changes: 16 additions & 2 deletions chatglm.cpp
@@ -905,7 +905,7 @@ void BaseModelForCausalLM::sampling_softmax_inplace(TokenIdScore *first, TokenId
std::vector<int> BaseModelForCausalLM::generate(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
BaseStreamer *streamer) {
CHATGLM_CHECK(gen_config.max_length <= config.max_length)
<< "requested max_length (" << gen_config.max_length << ") is larger than model's max_length ("
<< "Requested max_length (" << gen_config.max_length << ") exceeds pre-configured model max_length ("
<< config.max_length << ")";

std::vector<int> output_ids;
@@ -1700,7 +1700,16 @@ StateDict InternLMForCausalLM::state_dict() const {

// ===== pipeline =====

- Pipeline::Pipeline(const std::string &path) {
+ Pipeline::Pipeline(const std::string &path, int max_length) {
+     auto _update_config_max_length = [](ModelConfig &config, int max_length) {
+         if (max_length > 0) {
+             CHATGLM_CHECK(max_length <= config.max_length)
+                 << "Requested max_length (" << max_length << ") exceeds the max possible model sequence length ("
<< config.max_length;
+             config.max_length = max_length;
+         }
+     };

mapped_file = std::make_unique<MappedFile>(path);
ModelLoader loader(mapped_file->data, mapped_file->size);

@@ -1718,6 +1727,7 @@ Pipeline::Pipeline(const std::string &path) {
// load config
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>(), 1e-5f, ActivationType::GELU, true, true,
true, false, RopeType::CHATGLM, -1, AttentionMaskType::CHATGLM);
+ _update_config_max_length(config, max_length);

// load tokenizer
int proto_size = loader.read_basic<int>();
@@ -1734,6 +1744,7 @@ Pipeline::Pipeline(const std::string &path) {
// load config
ModelConfig config(model_type, loader.read_basic<ConfigRecordV2>(), 1e-5f, ActivationType::SILU, true, false,
false, false, RopeType::GPTJ, 2, AttentionMaskType::CAUSAL);
+ _update_config_max_length(config, max_length);

// load tokenizer
int proto_size = loader.read_basic<int>();
@@ -1758,6 +1769,7 @@ Pipeline::Pipeline(const std::string &path) {
// load config
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>(), 1e-6f, ActivationType::SILU, false, false,
false, false, RopeType::NEOX, 1, AttentionMaskType::CAUSAL);
+ _update_config_max_length(config, max_length);

// load tokenizer
int proto_size = loader.read_basic<int>();
@@ -1774,6 +1786,7 @@ Pipeline::Pipeline(const std::string &path) {
// load config
ModelConfig config(model_type, loader.read_basic<ConfigRecordV1>(), 1e-6f, ActivationType::SILU, false, false,
false, true, RopeType::DISABLED, -1, AttentionMaskType::CAUSAL);
+ _update_config_max_length(config, max_length);

// load tokenizer
int proto_size = loader.read_basic<int>();
@@ -1797,6 +1810,7 @@ Pipeline::Pipeline(const std::string &path) {
config = ModelConfig(model_type, rec, 1e-6f, ActivationType::SILU, false, false, false, false,
RopeType::NEOX, 1, AttentionMaskType::CAUSAL);
}
+ _update_config_max_length(config, max_length);

// load tokenizer
int proto_size = loader.read_basic<int>();
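The new argument clamps config.max_length before the model is built, which is presumably what lets the kv-cache be allocated on demand at the requested size rather than always at the model's full context length. A minimal sketch of the resulting Python-side behavior (not part of this diff; the model path is a placeholder, and the RuntimeError mapping assumes pybind11's default exception translation):

```python
import chatglm_cpp

# Clamp the model's max_length to 512 tokens; "chatglm-ggml.bin" is a placeholder path.
pipeline = chatglm_cpp.Pipeline("chatglm-ggml.bin", max_length=512)
assert pipeline.model.config.max_length == 512

# Omitting max_length (or passing -1 on the C++ side) keeps the converted default.
pipeline = chatglm_cpp.Pipeline("chatglm-ggml.bin")

# Values above the model's own limit fail the CHATGLM_CHECK in
# _update_config_max_length; from Python this is expected to surface as a RuntimeError.
try:
    chatglm_cpp.Pipeline("chatglm-ggml.bin", max_length=1_000_000)
except RuntimeError as err:
    print(err)
```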
2 changes: 1 addition & 1 deletion chatglm.h
@@ -1065,7 +1065,7 @@ class InternLMForCausalLM : public BasicModelForCausalLM<InternLMModel> {

class Pipeline {
public:
- Pipeline(const std::string &path);
+ Pipeline(const std::string &path, int max_length = -1);

std::vector<int> generate(const std::vector<int> &input_ids, const GenerationConfig &gen_config,
BaseStreamer *streamer = nullptr) const;
2 changes: 1 addition & 1 deletion chatglm_cpp/_C.pyi
@@ -174,7 +174,7 @@ class ModelType:
def value(self) -> int:
...
class Pipeline:
- def __init__(self, path: str) -> None:
+ def __init__(self, path: str, max_length: int = -1) -> None:
...
@property
def model(self) -> BaseModelForCausalLM:
10 changes: 7 additions & 3 deletions chatglm_cpp/__init__.py
@@ -27,10 +27,14 @@ def _ensure_chat_message(message: Union[ChatMessage, Dict[str, Any]]) -> ChatMes


class Pipeline(_C.Pipeline):
- def __init__(self, model_path: str, *, dtype: Optional[str] = None) -> None:
+ def __init__(self, model_path: str, *, max_length: Optional[int] = None, dtype: Optional[str] = None) -> None:
+     kwargs = {}
+     if max_length is not None:
+         kwargs.update(max_length=max_length)

if Path(model_path).is_file():
# load ggml model
- super().__init__(str(model_path))
+ super().__init__(str(model_path), **kwargs)
else:
# convert hf model to ggml format
from chatglm_cpp.convert import convert
@@ -40,7 +44,7 @@ def __init__(self, model_path: str, *, dtype: Optional[str] = None) -> None:

with tempfile.NamedTemporaryFile("wb") as f:
convert(f, model_path, dtype=dtype)
- super().__init__(f.name)
+ super().__init__(f.name, **kwargs)

def chat(
self,
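Downstream, the clamped limit also caps per-call generation via the CHATGLM_CHECK in BaseModelForCausalLM::generate from the first hunk above. A rough sketch of that interaction (again not part of this diff; the path and lengths are made up, and chat()'s max_length keyword is assumed to be the existing generation option):

```python
import chatglm_cpp

pipeline = chatglm_cpp.Pipeline("chatglm-ggml.bin", max_length=512)
messages = [chatglm_cpp.ChatMessage(role="user", content="Hello")]

# Fine: the requested generation length stays within the clamped model max_length.
pipeline.chat(messages, max_length=256)

# Exceeding the clamped limit should trip the generate()-side CHATGLM_CHECK.
try:
    pipeline.chat(messages, max_length=1024)
except RuntimeError as err:
    print(err)
```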
2 changes: 1 addition & 1 deletion chatglm_pybind.cpp
@@ -160,7 +160,7 @@ PYBIND11_MODULE(_C, m) {
// ===== Pipeline ====

py::class_<Pipeline>(m, "Pipeline")
- .def(py::init<const std::string &>(), "path"_a)
+ .def(py::init<const std::string &, int>(), "path"_a, "max_length"_a = -1)
.def_property_readonly("model", [](const Pipeline &self) { return self.model.get(); })
.def_property_readonly("tokenizer", [](const Pipeline &self) { return self.tokenizer.get(); });
}
2 changes: 1 addition & 1 deletion main.cpp
@@ -171,7 +171,7 @@ static inline void print_message(const chatglm::ChatMessage &message) {
static void chat(Args &args) {
ggml_time_init();
int64_t start_load_us = ggml_time_us();
- chatglm::Pipeline pipeline(args.model_path);
+ chatglm::Pipeline pipeline(args.model_path, args.max_length);
int64_t end_load_us = ggml_time_us();

std::string model_name = pipeline.model->config.model_type_name();
13 changes: 13 additions & 0 deletions tests/test_chatglm_cpp.py
@@ -35,6 +35,19 @@ def check_pipeline(model_path, prompt, target, gen_kwargs={}):
assert stream_output == target


+ @pytest.mark.skipif(not CHATGLM_MODEL_PATH.exists(), reason="model file not found")
+ def test_pipeline_options():
+     # check max_length option
+     pipeline = chatglm_cpp.Pipeline(CHATGLM_MODEL_PATH)
+     assert pipeline.model.config.max_length == 2048
+     pipeline = chatglm_cpp.Pipeline(CHATGLM_MODEL_PATH, max_length=234)
+     assert pipeline.model.config.max_length == 234
+
+     # check if resources are properly released
+     for _ in range(100):
+         chatglm_cpp.Pipeline(CHATGLM_MODEL_PATH)


@pytest.mark.skipif(not CHATGLM_MODEL_PATH.exists(), reason="model file not found")
def test_chatglm_pipeline():
check_pipeline(
