[Model] Initial batching support for Llama (mlc-ai#1048)
This PR introduces initial batched-input support for Llama
models. To keep the code manageable, we keep both the single-sequence
flow and the batching flow in the Llama modeling code.

Now, with `--enable-batching` passed as a build argument, we build Llama
with batching support.
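For illustration, a batching-enabled Llama build would be invoked roughly as follows (a sketch only: the `mlc_llm.build` entry point and the model, quantization, and target values are assumed here, not prescribed by this PR):

    python3 -m mlc_llm.build --model Llama-2-7b-chat-hf --quantization q4f16_1 --target cuda --enable-batching

Since the batched entry functions in this PR are wired through the separated-embedding path, `--sep-embed` will likely need to be passed alongside `--enable-batching`.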

NOTE: The paged attention kernel/TIR function is not included in this PR,
so a library built with batching enabled is not yet runnable.
We will follow up with the attention kernel in a future PR.

This PR guarantees that the existing single-sequence inference (Python
API, CLI, etc.) is not broken.

P.S. The batching flow is subject to further bug fixes as we integrate
the attention function and run the end-to-end flow.
MasterJH5574 authored Oct 14, 2023
1 parent edab9b5 commit d854105
Showing 6 changed files with 614 additions and 88 deletions.
15 changes: 7 additions & 8 deletions cpp/llm_chat.cc
@@ -481,7 +481,7 @@ class LLMChat {
     // Step 6. KV cache creation.
     this->kv_cache_ = ft_.create_kv_cache_func_();
     // Step 7. Pre-allocate fixed size ndarray
-    this->temperature_arr_ = NDArray::Empty({}, DataType::Float(32), device_);
+    this->temperature_arr_ = NDArray::Empty({1}, DataType::Float(32), device_);
     float temperature = static_cast<float>(this->temperature_);
     this->temperature_arr_.CopyFromBytes(&temperature, sizeof(float));
     if (ft_.use_disco) {
@@ -947,19 +947,18 @@ class LLMChat {
     // the generation_config will not override the original config
     // since is only used for this generation
     double gen_temperature;
-    NDArray gen_temperature_arr;
     double gen_repetition_penalty;
     double gen_top_p;
     if (generation_config.count("temperature")) {
       CHECK(generation_config["temperature"].is<double>());
       gen_temperature = generation_config["temperature"].get<double>();
-
-      gen_temperature_arr = NDArray::Empty({}, DataType::Float(32), device_);
-      float temperature_cast = static_cast<float>(gen_temperature);
-      gen_temperature_arr.CopyFromBytes(&temperature_cast, sizeof(float));
+      if (gen_temperature != this->temperature_) {
+        this->temperature_ = gen_temperature;
+        float temperature_cast = static_cast<float>(gen_temperature);
+        this->temperature_arr_.CopyFromBytes(&temperature_cast, sizeof(float));
+      }
     } else {
       gen_temperature = this->temperature_;
-      gen_temperature_arr = this->temperature_arr_;
     }
     if (generation_config.count("repetition_penalty")) {
       CHECK(generation_config["repetition_penalty"].is<double>());
@@ -979,7 +978,7 @@
       if (gen_temperature < 1e-6f) {
         this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
       } else {
-        this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, gen_temperature_arr));
+        this->UpdateLogitsOrProbOnCPUSync(this->Softmax(logits_on_device, this->temperature_arr_));
       }
     } else {
       this->UpdateLogitsOrProbOnCPUSync(logits_on_device);
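
A side note on the llm_chat.cc change above: instead of allocating a temporary temperature NDArray whenever the generation config overrides the temperature, the pre-allocated one-element member array is now refreshed in place only when the value actually changes. A minimal Python sketch of the same reuse pattern with the TVM runtime API (illustration only; the PR's code is the C++ above):

    import numpy as np
    import tvm

    dev = tvm.cpu()  # assumption: any TVM device works the same way for this pattern

    # Pre-allocate a one-element device array once, mirroring temperature_arr_.
    temperature = 0.7
    temperature_arr = tvm.nd.empty((1,), "float32", dev)
    temperature_arr.copyfrom(np.array([temperature], dtype="float32"))

    def apply_generation_temperature(new_temperature: float) -> None:
        # Refresh the cached array only when the override differs from the current value.
        global temperature
        if new_temperature != temperature:
            temperature = new_temperature
            temperature_arr.copyfrom(np.array([new_temperature], dtype="float32"))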
35 changes: 25 additions & 10 deletions mlc_llm/core.py
@@ -79,6 +79,11 @@ class BuildArgs:
         Build with separated embedding layer, only applicable to LlaMa. This
         feature is in testing stage, and will be formally replaced after massive
         overhaul of embedding feature for all models and use cases.
+    enable_batching: bool
+        Build the model for batched inference.
+        This is a temporary flag used to control the model execution flow in single-
+        sequence and batching settings for now. We will eventually merge two flows
+        in the future and remove this flag then.
     """
     model: str = field(
         default="auto",
@@ -180,21 +185,29 @@ class BuildArgs:
             "action": "store_true",
         },
     )
-    no_cutlass_attn: bool = field(
+    enable_batching: bool = field(
         default=False,
         metadata={
             "help": (
-                "Disable offloading attention operations to CUTLASS."
+                "Build the model for batched inference."
+                "This is a temporary flag used to control the model execution flow in single-"
+                "sequence and batching settings for now. We will eventually merge two flows"
+                "in the future and remove this flag then."
             ),
             "action": "store_true",
         },
     )
+    no_cutlass_attn: bool = field(
+        default=False,
+        metadata={
+            "help": ("Disable offloading attention operations to CUTLASS."),
+            "action": "store_true",
+        },
+    )
     no_cutlass_norm: bool = field(
         default=False,
         metadata={
-            "help": (
-                "Disable offloading layer and RMS norm operations to CUTLASS."
-            ),
+            "help": ("Disable offloading layer and RMS norm operations to CUTLASS."),
             "action": "store_true",
         },
     )
@@ -231,9 +244,7 @@ class BuildArgs:
     use_flash_attn_mqa: bool = field(
         default=False,
         metadata={
-            "help": (
-                "Offload multi-query attention workload to Flash Attention."
-            ),
+            "help": ("Offload multi-query attention workload to Flash Attention."),
             "action": "store_true",
         },
     )
@@ -380,6 +391,8 @@ def mod_transform_before_build(
     ]
     if args.sep_embed:
         model_names = ["embed", "prefill_with_embed"] + model_names[1:]
+        if args.enable_batching:
+            model_names[2] = "decode_with_embed"
     if args.model.lower().startswith("rwkv-"):
         model_names += ["reset_kv_cache"]
 
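For reference, a minimal Python sketch of how the exported entry-function list is assembled under these flags (the initial contents of `model_names` and the flag values are assumed for illustration; the sketch also assumes the batching branch applies only under `--sep-embed`, which is what the index-2 replacement implies):

    # Assumed starting list; only "prefill" and "decode" in the first two slots matter here.
    model_names = ["prefill", "decode", "softmax_with_temperature", "get_metadata"]

    sep_embed = True        # --sep-embed
    enable_batching = True  # --enable-batching

    if sep_embed:
        # "prefill" is replaced by a separate "embed" plus "prefill_with_embed".
        model_names = ["embed", "prefill_with_embed"] + model_names[1:]
        if enable_batching:
            # After the rewrite above, index 2 holds "decode"; the batched
            # flow exports "decode_with_embed" in its place.
            model_names[2] = "decode_with_embed"

    print(model_names)
    # ['embed', 'prefill_with_embed', 'decode_with_embed', 'softmax_with_temperature', 'get_metadata']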
@@ -458,7 +471,7 @@ def mod_transform_before_build(
             ),
             annotate_workspace,
             relax.transform.AllocateWorkspace(),
-            relax.transform.RunCodegen(options, entry_functions=model_names)
+            relax.transform.RunCodegen(options, entry_functions=model_names),
         ]
     )(mod)
 
@@ -558,7 +571,9 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
     with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": use_cuda_graph}):
         # The num_input attribute is needed to capture transformed weights passed as input
         # into a cuda graph.
-        mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3})
+        # NOTE: CUDA graph for batching is not enabled and is left as a TODO item.
+        if not args.enable_batching:
+            mod_deploy["decode"] = mod_deploy["decode"].with_attr({"num_input": 3})
         ex = relax.build(mod_deploy, args.target, system_lib=args.system_lib)
 
     output_filename = f"{args.model}-{args.quantization.name}-{target_kind}.{args.lib_format}"