Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BYO-FT support, with some LoRA support #224

Closed
wants to merge 33 commits into from

Conversation

Lunderberg
Copy link
Member

Commits that have been PRed to upstream mlc-ai/mlc-llm are named with the corresponding PR number.

Previously, the `transform_dequantize` pass would output a function
where all parameters were named `f"param_{i}"`, which could be quite
difficult to read.  This commit updates `transform_dequantize` to
propagate the parameter names.

If a parameter is not quantized, its name remains the same.  If a
parameter is quantized and produces a single output tensor, the new
name is `f"{old_name}.{new_dtype}"`.  If a parameter is quantized and
produces multiple output tensors, they are named
`f"{old_name}.{new_dtype}.{i}"`.
Prior to this commit, the `loaded_idx_set`, `loaded_torch_bins`,
`cached_relax_params`, and `cached_torch_params` objects were passed
by the calling scope into `ParamManager.get_param_loading_functions`.
These objects are implementation details for caching purposes, and are
not required by the calling scope.

This commit updates `ParamManager.get_param_loading_functions` to
construct the cache object internally.  The closures `get_item` and
`set_item` returned from `get_param_loading_functions` have reference
to the internal cache objects.
Prior to this commit, `get_param_loading_functions` generated and
returned both `get_item` and `set_item`.  The code paths for these two
were entirely independent, and did not depend on each other.  This
commit splits it into `get_param_get_item`, which returns the
`get_item` function, and `get_param_set_item`, which returns the
`set_item` function.
A relax function may have `None` for `function.attrs`.  If this
occurs, checking `"Composite" in function.attrs` would raise an error.
…erAlloc

Prior to this commit, the implementation of
`mlc_llm.transform.LiftTIRGlobalBufferAlloc` assumed that every
PrimFunc in a module was schedulable.  If the module contained a
non-schedulable PrimFunc, the `assert isinstance(func.body,
tir.BlockRealize)` would produce an error.

This commit updates `LiftTIRGlobalBufferAlloc` to leave unrecognized
PrimFunc instances, and their callsites, unmodified.
Prior to this commit, the `weight.shard_dim` and
`weight.shard_strategy` fields were defined when
`combine_matmul=True`, but were left undefined in some locations for
`combine_matmul=False`.  This commit adds definitions for these cases.
A pass must not mutate the `IRModule` that it receives as input.
Unlike most functions exposed through the python API, the
`IRModule.__setitem__` method mutates the underlying `IRModuleNode`.
This can impact other users of the shared `IRModule` object, which
expect mutation to be done using copy-on-write.

See apache/tvm#11372 for more details.
Prior to this commit, a model name with multiple path
components (e.g. `dist/models/group_name/model_name`) would have
duplicated path components
(e.g. `dist/group_name/artifact_path/group_name/libname.so`).
This commit resolves the duplication.
Instead of a python function that returns an updated `IRModule`, the
new `optimize_mod_pipeline` function returns a `tvm.ir.transform.Pass`
which can be applied to an `IRModule`.
Sets the entry functions for a module.
This allows it to be used as part of a optimization pipeline specified
as a `tvm.ir.transform.Sequential`.
…manager

Prior to this commit, the `ReorderTransformFunc` required several
components of the `ParamManager` to use.  The functionality it
provides, reordering dataflow blocks to minimize the liveset, is
useful outside of the context of the `ParamManager`.  This commit
makes the following changes, allowing it to be used independently of
the `ParamManager`.

- Generate the `pidx2binname` dictionary outside of `ReorderTransformFunc`

- Allow parameters to be separate `func.params`, rather than a single
  bundled tuple parameter.
Prior to this commit, `copy_tokenizer` provided the
`$ARTIFACT_DIR/params` directory as the destination argument to
`shutil.copy`.  If the directory already exists, then this would work
correctly.  If the directory did not already exist, then it would
create a file named `params` within the artifact directory.  This code
path could be triggered by running `build.py` first with
`--build-model-only`, and later with `--convert-weights-only`.

This commit updates `copy_tokenizer` to create the `params` directory
if it does not already exist, and to provide a full output path to
`shutil.copy` instead of an output directory.

In addition, the `utils.save_params` method has an explicit check to
ensure that any previously existing `$ARTIFACT_DIR/params` is a
directory, to throw an error earlier if not.  (This was the location
of the first observed error when running with `--convert-weights-only`
after running `--build-model-only`.)
@masahi
Copy link
Member

masahi commented Feb 28, 2024

What's the plan for migrating this to the new repo? cc @sunggg

@Lunderberg
Copy link
Member Author

The commits e523125 and 5289e47 ("[BYO-FT] Support execution of transform_params during initialization" and "[Debug] Add LOG.debug statements for safetensor loading") should be the only two that impact the runtime environment. The remainder would all be compile-time changes, and can stay within the mlc-llm repo.

Inside a `T.block`, loop variables may not be used, and access to them
must be done through the corresponding `T.axis.remap` output.
In addition to the inference functions, produce a `transform_params`
in the compiled module.  This function allows the weight conversion to
be performed at a later point.
Loading fine-tunes requires that the model matches the variable names
in the initial model definition.
The `T.reads` and `T.writes` annotations may not reference variables
declared within the body of the block.
For debug purposes, symlink all `*.safetensors` files into the
`original_params` directory of the output.
Validates that the output of `transform_params` is the same as the
input weights required by the inference functions.
This is useful for debugging LoRA implementations, if the `"decode"`
function doesn't yet exist.

The `num_input` attribute should now be generated for all relax to TIR
kernels, and should not need to be added afterward.
Previously, only apply this normalization if the `QuantizationSpec`
provides `float16` where the model had been initially `bfloat16`.
Now, apply this normalization if the model is defined as either
`float16` or `bfloat16`.
@Lunderberg Lunderberg force-pushed the lunderberg/byo_ft_and_lora_support branch from 13755c8 to 1b5620e Compare March 1, 2024 21:48
@masahi
Copy link
Member

masahi commented Mar 1, 2024

Do you have the changes you made to examples/python/run_llama_batched_vllm.py to test exercise a lora-ed model end to end?

Also I need an example lora file that is known to work.

@@ -1516,7 +1553,8 @@ def get_model(args, hf_config):
**hf_config,
dtype=dtype,
position_embedding_base=position_embedding_base,
combine_matmul=True,
# TODO: Re-enable with CombineParallelMatmul
combine_matmul=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is CombineParallelMatmul with lora supported now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, no. It will require improvements to LiftTransformParams, in order to lift out a parameter transformation that can be used for every function in an IRModule.

"""
param_path = pathlib.Path(param_path)

safetensors = LazySafetensorDir(param_path)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it safe to assume that safetensors are always available? If this reads from the HF model cards, I've seen the cases that this is not true.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Conversion into safetensors format will be part of the asset ingestion step, so we should always have them available. This is also on the fallback path if the params directory generated with tvm.contrib.tvmjs isn't available, so the normal workflow should still work as intended.

(Side-note: I like the safetensors format a lot more than the pytorch format, because we can read a single tensor without needing to unpack the entire file. This avoids a large amount of caching that is needed when reading from pytorch files, as

@@ -67,36 +132,144 @@ def broadcast_from_worker_0(sess: di.Session, src, shape, dtype):
return dst


def load_tvm_weights(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code path seems untested. The function is broken.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are correct, looks like this was a merge error on my part.

@masahi
Copy link
Member

masahi commented Mar 11, 2024

@Lunderberg Have you tested Mixtral for BYO-FT? Even without matmul combining, MLC and PT params don't map one to one. The logic to stack 2D MoE weights into a single 3D weight tensor is needed.

@Lunderberg
Copy link
Member Author

@masahi I have, and am currently running into two issues for Mixtral.

  1. The easiest location to inject the R.concat for MoE weights ends up being part of the runtime function, which causes a ~10x decrease in performance. This should be resolved in the long-term by the LiftTransformParams that @zxybazh is working on, but am about to make a workaround in the meantime.

  2. The Mixtral model is giving results that are irrelevant to the prompt. I'm seeing the same behavior for both the BYO-FT flow, and for the tvmjs flow, so either the issue predates the changes I've made or my testing script is incorrect. Either way, focusing on (1) for Mixtral for now.

@masahi
Copy link
Member

masahi commented Mar 11, 2024

The Mixtral model is giving results that are irrelevant to the prompt.

That's odd. With or without BYO-FT, I'm getting good outputs from Mixtral in the new repo. In the old repo, you should get the same output if you run serve/tests/test_engine.py without BYO-FT.

Prompt = 'Hello, my name is'
Generated 0-th sample = ' Katie and I am a 20-something year old living in the beautiful city of Vancouver'

Prompt = 'The capital of France is'
Generated 0-th sample = ' a city that is known for its beauty, culture, and history. It is a city that is'

Prompt = 'The president of the United States is a powerful man. But he can also be'
Generated 0-th sample = ' a very petty man.

On Tuesday, President Donald Trump took to Twitter to attack the'

Prompt = 'The future of AI is full of promise. But we need to carefully'
Generated 0-th sample = ' consider the ethical implications of the technology.

Artificial intelligence (AI) is a powerful tool'

)

module = sess.load_vm_module(lib_path.as_posix())
params = module["transform_params"](worker_id)
Copy link
Member

@masahi masahi Mar 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After moving matmul weight combining to transform_params, I met a different OOM problem. Param loading works fine but the OOM happens after memory profiling. I realized that this is a fairly nasty issue, and that you will also probably encounter this when you enable matmul combining for Mixtral.

By running transform_params on GPU, params and temporaries are allocated inside the pooled allocator. During memory profiling, the allocator gets the "biggest" input and is supposed to make allocations for all activations derived from that input. But if there are already freed memory in the pool before memory profiling, the allocator may reclaim them to serve an allocation request, rather than doing a fresh allocation. This breaks the contracts of memory profiling, and leads to an under-estimation of required memory to run an inference on the biggest input, or a failure to find a sutiable memory block in the pool during serving.

So we need to make sure that the memory pool in the pooled allocator is empty before memory profiling. The easiest way I found is to use a different allocator in transform_params. My commit in the new repo https://github.com/octoml/mlc-serve/commit/63fd90f648a9bb596e4b9a3dc31a9a2598cce492 does that. It depends on a TVM commit https://github.com/octoml/tvm/commit/44a0d80b8d4a2a12f1fd5d8b08c027bd249481f2 which is now in mlc-serve-old-final.

This assumes that the peak VRAM footprint when running transform_params with the naive allocator doesn't exceed the total bytes of params (otherwise memory profiling becomes inaccurate). It seems to work fine according to my experiment: Running a benchmark script against Mixtral, with or without transform_params, shows a very similar VRAM usage.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds reasonable for the short-term, and is a good workaround for the memory profiler limitations. For the long-term, I'm wondering if we should make a transform that generates a memory footprint calculator for each function. This transform would determine the liveset of variables at each step in the Relax compute graph, in terms of any symbolic shapes in the function parameters, and would generate a function that returns the maximum footprint across all livesets in a function.

What I'm picturing for the output:

@R.function
def max_bytes_memory(
    func_name: str,
    vocab_size: R.Prim(value="tir_vocab_size"),
    seq_len: R.Prim(value="tir_seq_len"),
) -> R.Prim("int64"):
    tir_vocab_size = T.int64()
    tir_seq_len = T.int64()

    if func_name == "prefill":
        output = R.prim_value(...)
    elif func_name == "decode":
        output = R.prim_value(...)
    else:
        R.assert_op(
            False,
            func_name,
            format="Function name must in ['prefill','decode'], but was {}",
        )
        output = R.prim_value(-1)
    return output

That way, instead of needing to perform a memory-profiling call with a specific setup for the memory pool, we could call the generated function and use the precomputed memory footprint.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, as long as this approach computes the same value as the current dynamic one, I agree that we should switch to this approach. Feel free to work on it if you are interested (might be necessary for LoRA as discussed).

@masahi masahi closed this Mar 18, 2024
Lunderberg pushed a commit to Lunderberg/mlc-llm that referenced this pull request Apr 1, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants