BYO-FT support, with some LoRA support #224

Status: Closed. Wants to merge 33 commits.

Commits (33), showing changes from all commits:
b8b9280  [PR-1685][ParamManager] Preserve variable names in transform_dequantize (Lunderberg, Jan 11, 2024)
29c49fc  [PR-1686][ParamManager] Simplify get_param_loading_functions signature (Lunderberg, Jan 11, 2024)
1ded3c2  [PR-1687][ParamManager] Separate get_param_loading_functions for get/set (Lunderberg, Jan 11, 2024)
8573cc4  [PR-1756][Transform] Add check for function.attrs (Lunderberg, Feb 14, 2024)
18a6d1c  [PR-1757][Transform] Handle non-schedulable func in LiftTIRGlobalBuff… (Lunderberg, Feb 14, 2024)
76cb787  [PR-1758][Models] Define sharding strategy when combine_matmul=False (Lunderberg, Feb 2, 2024)
de2b289  [PR-1760][Bugfix] Remove mutation of IRModule in ReorderTransformFunc (Lunderberg, Feb 8, 2024)
aadff29  [PR-1851][Bugfix] Handle model names with multiple path components (Lunderberg, Feb 27, 2024)
f1fde2a  [PR-1852][Build] Replace mod_transform_before_build with IRModule pass (Lunderberg, Jan 11, 2024)
1db8a2b  [PR-1855][Utils][Transform] Added SetEntryFuncs transform (Lunderberg, Feb 8, 2024)
c136c11  [PR-1856][Build] Update transform_params_for_each_rank to IRModule pass (Lunderberg, Feb 1, 2024)
ec531cf  [PR-1857][Utils] Allow ReorderTransformFunc to be used without param … (Lunderberg, Feb 15, 2024)
f5330d4  [Utils][Bugfix] Provide full path to shutil.copy in copy_tokenizer (Lunderberg, Jan 11, 2024)
035c915  [Model] Update Mixtral to have well-formed TIR (Lunderberg, Feb 29, 2024)
1069f69  Apply black auto-format (Lunderberg, Feb 27, 2024)
7e826c1  [BYO-FT] Generate a `transform_params` function in compiled module (Lunderberg, Jan 30, 2024)
ec048b0  [BYO-FT] Set combine_matmul=False for llama.py, VLLM-llama, mistral (Lunderberg, Feb 6, 2024)
de36ed9  [BYO-FT] Support execution of transform_params during initialization (Lunderberg, Feb 1, 2024)
ad51336  [Llama] Produce well-formed TIR for PagedAttention (Lunderberg, Jan 24, 2024)
03082e8  [Debug] Output `original_params` directory (Lunderberg, Feb 6, 2024)
a4376c3  [Debug] Implement validate_transform_params (Lunderberg, Feb 8, 2024)
234e777  [Debug] Add verify_well_formed calls (Lunderberg, Feb 6, 2024)
6f4aad1  [Debug] Print optimized model (Lunderberg, Feb 8, 2024)
71e5e93  [Debug] Added assert in ParamManager indicating required call sequence (Lunderberg, Jan 17, 2024)
febc355  [Debug] Add LOG.debug statements for safetensor loading (Lunderberg, Feb 22, 2024)
8b51544  [LoRA] Add --lora=path/to/lora.safetensors argument (Lunderberg, Jan 10, 2024)
4e09fa7  [LoRA] Implement utility functions to get the rank of each LoRA (Lunderberg, Jan 17, 2024)
2b62e5a  [LoRA] Assert that `num_input` is present instead of redefining it (Lunderberg, Jan 17, 2024)
d202e7b  [LoRA] Implement optimization passes for LoRA models (Lunderberg, Feb 8, 2024)
c15ed34  [LoRA] Add transforms to inject/optimize LoRA (Lunderberg, Feb 8, 2024)
4741afb  Handle bfloat16 -> float16 conversions for any dimension of tensor (Lunderberg, Feb 29, 2024)
677fe40  Normalize from (bfloat16 | float16) to float16 (Lunderberg, Mar 1, 2024)
1b5620e  Black auto-format (Lunderberg, Mar 1, 2024)
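For readers coming to this PR for the LoRA commits above: the injected computation is the standard low-rank update, the frozen base matmul plus a rank-r correction loaded from the --lora safetensors file. A minimal NumPy sketch of the arithmetic (illustrative only; the names and the scaling factor are assumptions, not code from this PR):

import numpy as np

rng = np.random.default_rng(0)
d, r = 8, 2                           # hidden size and LoRA rank
x = rng.standard_normal((4, d))       # activations
w = rng.standard_normal((d, d))       # frozen base weight
lora_a = rng.standard_normal((d, r))  # low-rank factors, e.g. loaded from lora.safetensors
lora_b = rng.standard_normal((r, d))
alpha = 1.0                           # scaling factor (assumed)

# Injected computation: base projection plus the rank-r LoRA correction.
y = x @ w + alpha * (x @ lora_a) @ lora_b
print(y.shape)  # (4, 8)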
822 changes: 725 additions & 97 deletions mlc_llm/core.py

Large diffs are not rendered by default.

52 changes: 49 additions & 3 deletions mlc_llm/relax_model/commons.py
@@ -85,6 +85,27 @@ def shard_k_weight_scale(weight: relax.TensorStructInfo):
func = te.create_prim_func([a, w])
return func

def shard_axis_0(weight: relax.TensorStructInfo):
(red, spatial), dtype = weight.shape, weight.dtype
red, spatial = int(red), int(spatial)
if param_shape_is_already_sharded:
red *= num_shards
a = te.placeholder((red, spatial), dtype=dtype)
w = topi.reshape(a, (num_shards, red // num_shards, spatial))
func = te.create_prim_func([a, w])
return func

def shard_axis_1(weight: relax.TensorStructInfo):
(spatial, red), dtype = weight.shape, weight.dtype
spatial, red = int(spatial), int(red)
if param_shape_is_already_sharded:
red *= num_shards
a = te.placeholder((spatial, red), dtype=dtype)
w = topi.reshape(a, (spatial, num_shards, red // num_shards))
w = topi.transpose(w, (1, 0, 2))
func = te.create_prim_func([a, w])
return func

def shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
(spatial, red), dtype = weight.shape, weight.dtype
spatial, red = int(spatial), int(red)
@@ -135,6 +156,8 @@ def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
"shard_mlp_k": shard_k_weight_scale,
"shard_o_proj_k": shard_k_weight_scale,
"shard_gate_up": shard_gate_up_weight_scale,
"shard_axis_0": shard_axis_0,
"shard_axis_1": shard_axis_1,
"moe_shard_mlp_k": moe_shard_k_weight_scale,
"moe_shard_gate_up": moe_shard_gate_up_weight_scale,
}
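The new shard_axis_0 / shard_axis_1 strategies registered above only reshape (and, for axis 1, transpose) the weight so that each rank can take its slice along the reduction axis. A minimal NumPy sketch of the equivalent computation (illustrative only, not code from this PR):

import numpy as np

def shard_axis_0(weight: np.ndarray, num_shards: int) -> np.ndarray:
    # (red, spatial) -> (num_shards, red // num_shards, spatial); rank i takes out[i]
    red, spatial = weight.shape
    return weight.reshape(num_shards, red // num_shards, spatial)

def shard_axis_1(weight: np.ndarray, num_shards: int) -> np.ndarray:
    # (spatial, red) -> (num_shards, spatial, red // num_shards); rank i takes out[i]
    spatial, red = weight.shape
    return weight.reshape(spatial, num_shards, red // num_shards).transpose(1, 0, 2)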
@@ -176,6 +199,27 @@ def shard_k_weight(weight: relax.TensorStructInfo):
func = te.create_prim_func([a, w])
return func

def shard_axis_0(weight: relax.TensorStructInfo):
(red, spatial), dtype = weight.shape, weight.dtype
red, spatial = int(red), int(spatial)
if param_shape_is_already_sharded:
red *= num_shards
a = te.placeholder((red, spatial), dtype=dtype)
w = topi.reshape(a, (num_shards, red // num_shards, spatial))
func = te.create_prim_func([a, w])
return func

def shard_axis_1(weight: relax.TensorStructInfo):
(spatial, red), dtype = weight.shape, weight.dtype
spatial, red = int(spatial), int(red)
if param_shape_is_already_sharded:
red *= num_shards
a = te.placeholder((spatial, red), dtype=dtype)
w = topi.reshape(a, (spatial, num_shards, red // num_shards))
w = topi.transpose(w, (1, 0, 2))
func = te.create_prim_func([a, w])
return func

def shard_gate_up_weight_scale(x: relax.TensorStructInfo):
(red, spatial), dtype = x.shape, x.dtype
red, spatial = int(red), int(spatial)
@@ -197,6 +241,8 @@ def shard_gate_up_weight_scale(x: relax.TensorStructInfo):
"shard_mlp_k": shard_k_weight,
"shard_o_proj_k": shard_k_weight,
"shard_gate_up": shard_gate_up_weight_scale,
"shard_axis_0": shard_axis_0,
"shard_axis_1": shard_axis_1,
}


@@ -221,7 +267,7 @@ def add_to_shard_info(param_name: str, func_name: Optional[str]):

shard_info_dict[param_name] = shard_info

q_params = param_manager.get_quantized_param_info("prefill").fields
q_params = [param.struct_info for param in param_manager.get_quantized_params("prefill")]
for _, param in param_manager.params.items():
if param.shard_strategy is None:
pass
@@ -272,7 +318,7 @@ def create_shard_transformation_func(param_manager, args, model_config) -> tvm.I
param_shape_is_already_sharded=args.build_model_only,
)

q_params = param_manager.get_quantized_param_info("prefill").fields
q_params = [param.struct_info for param in param_manager.get_quantized_params("prefill")]

# The order of the quantized parameters must be preserved.
# Therefore, we need to loop over q_params and look up information
@@ -289,7 +335,7 @@
)

bb = relax.BlockBuilder() # pylint: disable=invalid-name
with bb.function("transform_params"):
with bb.function("transform_params", attrs={"num_input": 1}):
rank = tir.SizeVar("rank", "int64")
# TODO(Lunderberg): Support primitive inputs to relax
# functions. Currently, using a PrimStructInfo as the
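The attrs={"num_input": 1} annotation added above marks how many leading parameters of "transform_params" are runtime inputs (here only the shard rank); the remaining parameters are treated as weights that passes such as LiftTransformParams may pre-compute. A short sketch of setting and reading that attribute with the BlockBuilder API (illustrative names, a hedged example rather than code from this PR):

from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((16, 16), "float16"))  # runtime input
w = relax.Var("w", relax.TensorStructInfo((16, 16), "float16"))  # weight parameter
with bb.function("example", params=[x, w], attrs={"num_input": 1}):
    y = bb.emit(relax.op.matmul(x, w))
    bb.emit_func_output(y)

mod = bb.get()
# Only the first parameter is a runtime input; the rest can be lifted/pre-computed.
assert int(mod["example"].attrs["num_input"]) == 1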
61 changes: 50 additions & 11 deletions mlc_llm/relax_model/llama.py
@@ -265,6 +265,12 @@ def __init__(self, config: LlamaConfig):
self.gate_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False)
self.down_proj = Linear(intermediate_size, hidden_size, dtype=dtype, bias=False)
self.up_proj = Linear(hidden_size, intermediate_size, dtype=dtype, bias=False)
self.gate_proj.weight.shard_dim = 0
self.gate_proj.weight.shard_strategy = "shard_axis_0"
self.down_proj.weight.shard_dim = 1
self.down_proj.weight.shard_strategy = "shard_axis_1"
self.up_proj.weight.shard_dim = 0
self.up_proj.weight.shard_strategy = "shard_axis_0"

self.act = {"silu": relax.op.nn.silu, "gelu": relax.op.nn.gelu}[config.hidden_act]

@@ -375,6 +381,9 @@ def __init__(self, config: LlamaConfig):
self.q_proj.weight.shard_dim = 0
self.k_proj.weight.shard_dim = 0
self.v_proj.weight.shard_dim = 0
self.q_proj.weight.shard_strategy = "shard_axis_0"
self.k_proj.weight.shard_strategy = "shard_axis_0"
self.v_proj.weight.shard_strategy = "shard_axis_0"

self.o_proj = Linear(
self.head_dim * self.num_query_heads,
@@ -1250,7 +1259,6 @@ def emit_paged_kv_cache_op(bb: relax.BlockBuilder, config: LlamaConfig) -> None:
num_heads = config.num_key_value_heads
head_dim = config.hidden_size // config.num_attention_heads

# fmt: off
@T.prim_func
def kv_cache_transpose_append(
var_pages: T.handle,
@@ -1269,7 +1277,11 @@ def kv_cache_transpose_append(
page_size = T.SizeVar("page_size", "int64")
num_pages = T.int64()

pages = T.match_buffer(var_pages, (num_pages, num_layers, 2, num_heads, page_size, head_dim), config.dtype)
pages = T.match_buffer(
var_pages,
(num_pages, num_layers, 2, num_heads, page_size, head_dim),
config.dtype,
)
k_data = T.match_buffer(var_k_data, (ntoken, num_heads, head_dim), config.dtype)
v_data = T.match_buffer(var_v_data, (ntoken, num_heads, head_dim), config.dtype)
last_page_offset = T.match_buffer(var_last_page_offset, (nseq,), "int32")
@@ -1281,10 +1293,23 @@
for global_pos, h, f in T.grid(ntoken, num_heads, head_dim):
with T.block("k_transpose_append"):
vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
seq_idx: T.int64 = T.Cast("int64", pos2seqidx[vgpos])
seqlen: T.int64 = T.Cast("int64", (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx])

seq_idx = T.meta_var(pos2seqidx[vgpos].astype("int64"))
seqlen = T.meta_var(
(
(page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1)
* page_size
+ last_page_offset[seq_idx]
).astype("int64")
)

pages[
page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)],
page_table_values[
page_table_indptr[seq_idx]
+ T.floordiv(
seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size
)
],
layer_id,
0,
vh,
@@ -1293,17 +1318,29 @@
] = k_data[vgpos, vh, vf]
with T.block("v_transpose_append"):
vgpos, vh, vf = T.axis.remap("SSS", [global_pos, h, f])
seq_idx: T.int64 = T.Cast("int64", pos2seqidx[vgpos])
seqlen: T.int64 = T.Cast("int64", (page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1) * page_size + last_page_offset[seq_idx])

seq_idx = T.meta_var(pos2seqidx[vgpos].astype("int64"))
seqlen = T.meta_var(
(
(page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1)
* page_size
+ last_page_offset[seq_idx]
).astype("int64")
)

pages[
page_table_values[page_table_indptr[seq_idx] + T.floordiv(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size)],
page_table_values[
page_table_indptr[seq_idx]
+ T.floordiv(
seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size
)
],
layer_id,
1,
vh,
T.floormod(seqlen - (append_length_indptr[seq_idx + 1] - vgpos), page_size),
vf,
] = v_data[vgpos, vh, vf]
# fmt: on

bb.add_func(kv_cache_transpose_append, "kv_cache_transpose_append")
bb.add_func(relax.extern("paged_kv_cache.attention_kernel_prefill"), "attention_prefill")
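The index arithmetic in kv_cache_transpose_append is easier to follow written out in plain Python. A sketch (hypothetical helper, not code from this PR) of how a token's destination page and in-page slot are located:

def locate_kv_slot(global_pos, pos2seqidx, page_table_indptr, page_table_values,
                   last_page_offset, append_length_indptr, page_size):
    seq_idx = pos2seqidx[global_pos]
    # Total tokens in the sequence once this append completes.
    seqlen = ((page_table_indptr[seq_idx + 1] - page_table_indptr[seq_idx] - 1)
              * page_size + last_page_offset[seq_idx])
    # Distance of this token from the end of the sequence.
    back_offset = append_length_indptr[seq_idx + 1] - global_pos
    logical_pos = seqlen - back_offset
    page = page_table_values[page_table_indptr[seq_idx] + logical_pos // page_size]
    slot = logical_pos % page_size   # matches the T.floormod(...) index in the kernel
    return page, slot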
@@ -1516,7 +1553,8 @@ def get_model(args, hf_config):
**hf_config,
dtype=dtype,
position_embedding_base=position_embedding_base,
combine_matmul=True,
# TODO: Re-enable with CombineParallelMatmul
combine_matmul=False,
Review comment (Member): Is CombineParallelMatmul with lora supported now?

Reply (Member, PR author): Unfortunately, no. It will require improvements to LiftTransformParams, in order to lift out a parameter transformation that can be used for every function in an IRModule.
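For context, CombineParallelMatmul fuses several matmuls that share the same activation into one matmul over a concatenated weight; that weight concatenation is itself a parameter transform, which is why it interacts with LiftTransformParams. A rough NumPy sketch of the fusion (illustrative only, not code from this PR):

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
w_q, w_k, w_v = (rng.standard_normal((8, 8)) for _ in range(3))

# Unfused: three parallel matmuls against the same activation.
q, k, v = x @ w_q, x @ w_k, x @ w_v

# Fused: concatenate the weights (a parameter-only transform), run one matmul,
# then split the output.
w_qkv = np.concatenate([w_q, w_k, w_v], axis=1)
q2, k2, v2 = np.split(x @ w_qkv, 3, axis=1)

assert np.allclose(q, q2) and np.allclose(k, k2) and np.allclose(v, v2)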

num_shards=args.num_shards,
build_model_only=args.build_model_only,
)
@@ -1526,7 +1564,8 @@
dtype=dtype,
max_sequence_length=hf_config["max_position_embeddings"],
position_embedding_base=position_embedding_base,
combine_matmul=True,
# TODO: Re-enable with CombineParallelMatmul
combine_matmul=False,
num_shards=args.num_shards,
build_model_only=args.build_model_only,
)
9 changes: 6 additions & 3 deletions mlc_llm/relax_model/llama_batched_vllm.py
@@ -1052,7 +1052,8 @@ def get_model(args, hf_config):
dtype=dtype,
max_sequence_length=hf_config["max_position_embeddings"],
position_embedding_base=position_embedding_base,
combine_matmul=True,
# combine_matmul=True,
combine_matmul=False,
num_shards=args.num_shards,
build_model_only=args.build_model_only,
quantization_scheme=args.quantization,
@@ -1072,7 +1073,8 @@
**hf_config,
dtype=dtype,
position_embedding_base=position_embedding_base,
combine_matmul=True,
# combine_matmul=True,
combine_matmul=False,
num_shards=args.num_shards,
build_model_only=args.build_model_only,
)
@@ -1082,7 +1084,8 @@
dtype=dtype,
max_sequence_length=hf_config["max_position_embeddings"],
position_embedding_base=position_embedding_base,
combine_matmul=True,
# combine_matmul=True,
combine_matmul=False,
num_shards=args.num_shards,
build_model_only=args.build_model_only,
)
27 changes: 16 additions & 11 deletions mlc_llm/relax_model/mistral.py
@@ -664,7 +664,9 @@ def forward(self, input_ids: relax.Expr):


class MistralModel(nn.Module):
def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False):
def __init__(
self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False
):
self.num_shards = config.num_shards
self.padding_idx = config.pad_token_id
self.embed_tokens = None
@@ -730,7 +732,9 @@ def forward(


class MistralForCausalLM(nn.Module):
def __init__(self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False):
def __init__(
self, config: MistralConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False
):
self.model = MistralModel(config, vocab_size_var, sep_embed)
self.lm_head = Linear(config.hidden_size, vocab_size_var, dtype=config.dtype, bias=False)

@@ -827,13 +831,13 @@ def create_encoding_func(

bsz = 1
seq_len = tvm.tir.SizeVar("n", "int64") # number of tokens for the input
rolling_cache_len = tvm.tir.SizeVar("c", "int64") # rolling_cache_len captures number of elements in the cache
rolling_cache_len = tvm.tir.SizeVar(
"c", "int64"
) # rolling_cache_len captures number of elements in the cache
kv_seq_len = tvm.tir.SizeVar(
"k", "int64"
) # kv_seq_len captures number of elements in cache + seq_len
cache_offset = tvm.tir.SizeVar(
"o", "int64"
) # slidinf window kv cache offset
cache_offset = tvm.tir.SizeVar("o", "int64")  # sliding window kv cache offset

hidden_size = config.hidden_size
with bb.function(func_name):
@@ -888,13 +892,13 @@ def create_decoding_func(
func_name = "decode"

bsz = 1
rolling_cache_len = tvm.tir.SizeVar("c", "int64") # rolling_cache_len captures number of elements in the cache
rolling_cache_len = tvm.tir.SizeVar(
"c", "int64"
) # rolling_cache_len captures number of elements in the cache
kv_seq_len = tvm.tir.SizeVar(
"k", "int64"
) # kv_seq_len captures number of elements in cache + seq_len
cache_offset = tvm.tir.SizeVar(
"o", "int64"
) # sliding window kv cache offset
cache_offset = tvm.tir.SizeVar("o", "int64") # sliding window kv cache offset

with bb.function(func_name):
model = MistralForCausalLM(config, tvm.tir.SizeVar("vocab_size", "int64"))
@@ -992,7 +996,8 @@ def get_model(args, hf_config):
config = MistralConfig(
**hf_config,
dtype=dtype,
combine_matmul=True,
# combine_matmul=True,
combine_matmul=False,
num_shards=args.num_shards,
build_model_only=args.build_model_only,
)
4 changes: 2 additions & 2 deletions mlc_llm/relax_model/mixtral.py
@@ -350,14 +350,14 @@ def top2_softmax_func(
for j in T.unroll(2):
with T.block("cast"):
vj = T.axis.remap("S", [j])
local_top_k_f32[vj] = T.cast(local_top_k[j], "float32")
local_top_k_f32[vj] = T.cast(local_top_k[vj], "float32")
with T.block("max"):
local_top_k_max[0] = T.max(local_top_k_f32[0], local_top_k_f32[1])
for j in T.unroll(2):
with T.block("output"):
vj = T.axis.remap("S", [j])
out[vi, vj] = T.cast(
T.exp(local_top_k_f32[j] - local_top_k_max[0])
T.exp(local_top_k_f32[vj] - local_top_k_max[0])
/ (
T.exp(local_top_k_f32[0] - local_top_k_max[0])
+ T.exp(local_top_k_f32[1] - local_top_k_max[0])
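The Mixtral hunk above only replaces the loop variable j with the block iteration variable vj so the TIR is well-formed; the computation itself is unchanged. For reference, the numerically stable two-element softmax it implements, sketched in plain Python (illustrative only):

import math

def top2_softmax(x0: float, x1: float) -> tuple:
    # Subtract the max before exponentiating, mirroring the "max" block in the kernel.
    m = max(x0, x1)
    e0, e1 = math.exp(x0 - m), math.exp(x1 - m)
    return e0 / (e0 + e1), e1 / (e0 + e1)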