BYO-FT support, with some LoRA support #224
Conversation
Previously, the `transform_dequantize` pass would output a function where all parameters were named `f"param_{i}"`, which could be quite difficult to read. This commit updates `transform_dequantize` to propagate the parameter names. If a parameter is not quantized, its name remains the same. If a parameter is quantized and produces a single output tensor, the new name is `f"{old_name}.{new_dtype}"`. If a parameter is quantized and produces multiple output tensors, they are named `f"{old_name}.{new_dtype}.{i}"`.
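For illustration, a hypothetical sketch of the naming rule described above (this helper is not part of the pass itself):

```python
# Sketch only: restates the naming convention; not the actual pass internals.
def dequantized_param_names(old_name: str, new_dtype: str, num_outputs: int, quantized: bool):
    if not quantized:
        return [old_name]                      # unquantized: name is unchanged
    if num_outputs == 1:
        return [f"{old_name}.{new_dtype}"]     # single quantized output tensor
    return [f"{old_name}.{new_dtype}.{i}" for i in range(num_outputs)]
```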
Prior to this commit, the `loaded_idx_set`, `loaded_torch_bins`, `cached_relax_params`, and `cached_torch_params` objects were passed by the calling scope into `ParamManager.get_param_loading_functions`. These objects are implementation details for caching purposes, and are not required by the calling scope. This commit updates `ParamManager.get_param_loading_functions` to construct the cache object internally. The closures `get_item` and `set_item` returned from `get_param_loading_functions` have reference to the internal cache objects.
Prior to this commit, `get_param_loading_functions` generated and returned both `get_item` and `set_item`. The code paths for these two were entirely independent, and did not depend on each other. This commit splits it into `get_param_get_item`, which returns the `get_item` function, and `get_param_set_item`, which returns the `set_item` function.
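A rough sketch of the split API under these assumptions; the real `ParamManager` methods take different arguments, and `load_param` here stands in for the actual loader:

```python
# Hedged sketch: the caches are built inside the factory functions and captured
# by the returned closures, so the calling scope never sees them.
def get_param_get_item(param_names, load_param):
    cached_params = {}  # internal cache; no longer constructed by the caller

    def get_item(i):
        name = param_names[i]
        if name not in cached_params:
            cached_params[name] = load_param(name)
        return cached_params[name]

    return get_item


def get_param_set_item(num_params):
    collected = [None] * num_params  # internal storage for converted weights

    def set_item(i, value):
        collected[i] = value

    return set_item
```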
A relax function may have `None` for `function.attrs`. If this occurs, checking `"Composite" in function.attrs` would raise an error.
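A minimal sketch of the guarded check, assuming `func` is a `relax.Function` whose `attrs` may be `None`:

```python
def is_composite(func) -> bool:
    # func.attrs may be None for a plain relax.Function, so guard before the lookup
    return func.attrs is not None and "Composite" in func.attrs
```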
…erAlloc Prior to this commit, the implementation of `mlc_llm.transform.LiftTIRGlobalBufferAlloc` assumed that every PrimFunc in a module was schedulable. If the module contained a non-schedulable PrimFunc, the `assert isinstance(func.body, tir.BlockRealize)` would produce an error. This commit updates `LiftTIRGlobalBufferAlloc` to leave unrecognized PrimFunc instances, and their callsites, unmodified.
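A hedged sketch of the guard, assuming the pass inspects each PrimFunc before rewriting it:

```python
from tvm import tir

def is_schedulable(func: tir.PrimFunc) -> bool:
    # Non-schedulable PrimFuncs (e.g. hand-written low-level TIR whose body is not
    # a single BlockRealize) are left unmodified, along with their call sites.
    return isinstance(func.body, tir.BlockRealize)
```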
Prior to this commit, the `weight.shard_dim` and `weight.shard_strategy` fields were defined when `combine_matmul=True`, but were left undefined in some locations for `combine_matmul=False`. This commit adds definitions for these cases.
A pass must not mutate the `IRModule` that it receives as input. Unlike most functions exposed through the python API, the `IRModule.__setitem__` method mutates the underlying `IRModuleNode`. This can impact other users of the shared `IRModule` object, which expect mutation to be done using copy-on-write. See apache/tvm#11372 for more details.
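A minimal sketch of the copy-on-write-friendly alternative, assuming a pass that replaces a single function (the helper name is illustrative):

```python
import tvm

def with_replaced_function(mod: tvm.IRModule, gvar, new_func) -> tvm.IRModule:
    # mod[gvar] = new_func  # would mutate the shared IRModuleNode in place
    functions = {gv: func for gv, func in mod.functions.items()}
    functions[gvar] = new_func
    return tvm.IRModule(functions)  # the caller's module is left untouched
```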
Prior to this commit, a model name with multiple path components (e.g. `dist/models/group_name/model_name`) would have duplicated path components (e.g. `dist/group_name/artifact_path/group_name/libname.so`). This commit resolves the duplication.
Instead of a python function that returns an updated `IRModule`, the new `optimize_mod_pipeline` function returns a `tvm.ir.transform.Pass` which can be applied to an `IRModule`.
Sets the entry functions for a module.
This allows it to be used as part of an optimization pipeline specified as a `tvm.ir.transform.Sequential`.
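For illustration, a hedged sketch of composing such a pass inside a `tvm.ir.transform.Sequential`; the exact arguments of `optimize_mod_pipeline` are not shown and are an assumption:

```python
import tvm

def build_pipeline(optimize_pass: tvm.ir.transform.Pass) -> tvm.ir.transform.Pass:
    # optimize_pass would be the result of optimize_mod_pipeline(...);
    # its argument list is omitted here.
    return tvm.ir.transform.Sequential(
        [
            optimize_pass,
            tvm.relax.transform.FoldConstant(),  # any other IRModule-level pass
        ],
        name="build_pipeline",
    )
```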
…manager Prior to this commit, the `ReorderTransformFunc` pass required several components of the `ParamManager` in order to be used. The functionality it provides, reordering dataflow blocks to minimize the liveset, is useful outside of the context of the `ParamManager`. This commit makes the following changes, allowing it to be used independently of the `ParamManager`:
- Generate the `pidx2binname` dictionary outside of `ReorderTransformFunc`.
- Allow parameters to be separate `func.params`, rather than a single bundled tuple parameter.
Prior to this commit, `copy_tokenizer` provided the `$ARTIFACT_DIR/params` directory as the destination argument to `shutil.copy`. If the directory already exists, then this would work correctly. If the directory did not already exist, then it would create a file named `params` within the artifact directory. This code path could be triggered by running `build.py` first with `--build-model-only`, and later with `--convert-weights-only`. This commit updates `copy_tokenizer` to create the `params` directory if it does not already exist, and to provide a full output path to `shutil.copy` instead of an output directory. In addition, the `utils.save_params` method has an explicit check to ensure that any previously existing `$ARTIFACT_DIR/params` is a directory, to throw an error earlier if not. (This was the location of the first observed error when running with `--convert-weights-only` after running `--build-model-only`.)
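A minimal sketch of the fix, assuming a single tokenizer file (the helper name and arguments are illustrative):

```python
import os
import shutil

def copy_tokenizer_file(src_path: str, artifact_path: str) -> None:
    params_dir = os.path.join(artifact_path, "params")
    os.makedirs(params_dir, exist_ok=True)  # create the directory if missing
    # Provide a full destination path rather than a directory, so a missing
    # directory can no longer be mistaken for a destination file named "params".
    shutil.copy(src_path, os.path.join(params_dir, os.path.basename(src_path)))
```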
What's the plan for migrating this to the new repo? cc @sunggg
The commits that have been PRed to upstream mlc-ai/mlc-llm are named with the corresponding PR number.
Inside a `T.block`, loop variables may not be used directly; accesses must instead go through the corresponding `T.axis.remap` bindings.
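A minimal TIR sketch of the pattern (not code from this PR), assuming a simple elementwise kernel:

```python
from tvm.script import tir as T

@T.prim_func
def add_one(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    for i in T.serial(16):
        with T.block("compute"):
            vi = T.axis.remap("S", [i])     # bind the loop var; use vi, not i, inside the block
            B[vi] = A[vi] + T.float32(1)
```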
In addition to the inference functions, produce a `transform_params` in the compiled module. This function allows the weight conversion to be performed at a later point.
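A hedged, single-GPU usage sketch; paths and the calling convention are assumptions, and the multi-GPU path goes through a disco session instead:

```python
import tvm
from tvm import relax

def convert_weights(lib_path: str, raw_params):
    """Sketch: raw_params is assumed to be a list of tvm.nd.NDArray holding the
    original weights; the exact arguments of "transform_params" (e.g. a worker id
    in the multi-GPU case) depend on how the function was lifted."""
    lib = tvm.runtime.load_module(lib_path)
    vm = relax.VirtualMachine(lib, tvm.cuda(0))
    return vm["transform_params"](raw_params)
```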
Loading fine-tunes requires that the model's parameter names match the variable names in the initial model definition.
The `T.reads` and `T.writes` annotations may not reference variables declared within the body of the block.
For debug purposes, symlink all `*.safetensors` files into the `original_params` directory of the output.
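An illustrative sketch of the debug symlinks (directory names are assumptions):

```python
import pathlib

def symlink_original_params(model_dir: str, artifact_dir: str) -> None:
    out_dir = pathlib.Path(artifact_dir) / "original_params"
    out_dir.mkdir(parents=True, exist_ok=True)
    for src in pathlib.Path(model_dir).glob("*.safetensors"):
        link = out_dir / src.name
        if not link.exists():
            link.symlink_to(src.resolve())   # point back at the original file
```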
Validates that the output of `transform_params` is the same as the input weights required by the inference functions.
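A hedged sketch of such a check, comparing shapes and dtypes; the real validation may compare more than this, and the argument names are illustrative:

```python
def validate_transformed_params(transformed, expected_specs):
    # transformed: tensors produced by transform_params
    # expected_specs: parameter specs (shape/dtype) required by the inference functions
    assert len(transformed) == len(expected_specs), (len(transformed), len(expected_specs))
    for tensor, spec in zip(transformed, expected_specs):
        assert tuple(tensor.shape) == tuple(spec.shape), (tensor.shape, spec.shape)
        assert tensor.dtype == spec.dtype, (tensor.dtype, spec.dtype)
```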
This is useful for debugging LoRA implementations when the `"decode"` function doesn't yet exist. The `num_input` attribute should now be generated for all Relax-to-TIR kernels, and should not need to be added afterward.
Previously, this normalization was only applied if the `QuantizationSpec` provided `float16` where the model had initially been `bfloat16`. Now, it is applied whenever the model is defined as either `float16` or `bfloat16`.
Do you have the changes you made to …? Also, I need an example LoRA file that is known to work.
@@ -1516,7 +1553,8 @@ def get_model(args, hf_config):
         **hf_config,
         dtype=dtype,
         position_embedding_base=position_embedding_base,
-        combine_matmul=True,
+        # TODO: Re-enable with CombineParallelMatmul
+        combine_matmul=False,
Is `CombineParallelMatmul` with LoRA supported now?
Unfortunately, no. It will require improvements to `LiftTransformParams`, in order to lift out a parameter transformation that can be used for every function in an `IRModule`.
""" | ||
param_path = pathlib.Path(param_path) | ||
|
||
safetensors = LazySafetensorDir(param_path) |
Is it safe to assume that safetensors are always available? If this reads from the HF model cards, I've seen cases where that is not true.
Conversion into safetensors format will be part of the asset ingestion step, so we should always have them available. This is also on the fallback path if the `params` directory generated with `tvm.contrib.tvmjs` isn't available, so the normal workflow should still work as intended.
(Side-note: I like the safetensors format a lot more than the PyTorch format, because we can read a single tensor without needing to unpack the entire file. This avoids a large amount of caching that is needed when reading from PyTorch files, as …)
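As an illustration of that side-note (not code from this PR), reading a single tensor with the `safetensors` library; the filename and framework choice are assumptions:

```python
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="np") as f:
    names = list(f.keys())            # the header is read without loading any tensors
    tensor = f.get_tensor(names[0])   # only this one tensor is materialized
```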
@@ -67,36 +132,144 @@ def broadcast_from_worker_0(sess: di.Session, src, shape, dtype):
    return dst


def load_tvm_weights(
This code path seems untested. The function is broken.
You are correct; it looks like this was a merge error on my part.
@Lunderberg Have you tested Mixtral for BYO-FT? Even without matmul combining, MLC and PT params don't map one to one. The logic to stack 2D MoE weights into a single 3D weight tensor is needed.
@masahi I have, and am currently running into two issues for Mixtral.
That's odd. With or without BYO-FT, I'm getting good outputs from Mixtral in the new repo. In the old repo, you should get the same output if you run …
    )

    module = sess.load_vm_module(lib_path.as_posix())
    params = module["transform_params"](worker_id)
After moving matmul weight combining to `transform_params`, I met a different OOM problem. Param loading works fine, but the OOM happens after memory profiling. I realized that this is a fairly nasty issue, and that you will also probably encounter it when you enable matmul combining for Mixtral.
By running `transform_params` on GPU, params and temporaries are allocated inside the pooled allocator. During memory profiling, the allocator gets the "biggest" input and is supposed to make allocations for all activations derived from that input. But if there is already freed memory in the pool before memory profiling, the allocator may reclaim it to serve an allocation request rather than doing a fresh allocation. This breaks the contract of memory profiling, and leads either to an under-estimation of the memory required to run inference on the biggest input, or to a failure to find a suitable memory block in the pool during serving.
So we need to make sure that the memory pool in the pooled allocator is empty before memory profiling. The easiest way I found is to use a different allocator in `transform_params`. My commit in the new repo https://github.com/octoml/mlc-serve/commit/63fd90f648a9bb596e4b9a3dc31a9a2598cce492 does that. It depends on a TVM commit https://github.com/octoml/tvm/commit/44a0d80b8d4a2a12f1fd5d8b08c027bd249481f2 which is now in `mlc-serve-old-final`.
This assumes that the peak VRAM footprint when running `transform_params` with the naive allocator doesn't exceed the total bytes of params (otherwise memory profiling becomes inaccurate). It seems to work fine according to my experiment: running a benchmark script against Mixtral, with or without `transform_params`, shows very similar VRAM usage.
That sounds reasonable for the short-term, and is a good workaround for the memory profiler limitations. For the long-term, I'm wondering if we should make a transform that generates a memory footprint calculator for each function. This transform would determine the liveset of variables at each step in the Relax compute graph, in terms of any symbolic shapes in the function parameters, and would generate a function that returns the maximum footprint across all livesets in a function.
What I'm picturing for the output:
@R.function
def max_bytes_memory(
    func_name: str,
    vocab_size: R.Prim(value="tir_vocab_size"),
    seq_len: R.Prim(value="tir_seq_len"),
) -> R.Prim("int64"):
    tir_vocab_size = T.int64()
    tir_seq_len = T.int64()
    if func_name == "prefill":
        output = R.prim_value(...)
    elif func_name == "decode":
        output = R.prim_value(...)
    else:
        R.assert_op(
            False,
            func_name,
            format="Function name must be in ['prefill', 'decode'], but was {}",
        )
        output = R.prim_value(-1)
    return output
That way, instead of needing to perform a memory-profiling call with a specific setup for the memory pool, we could call the generated function and use the precomputed memory footprint.
Yeah, as long as this approach computes the same value as the current dynamic one, I agree that we should switch to it. Feel free to work on it if you are interested (it might be necessary for LoRA, as discussed).