-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SLM] Support BERT architecture. Implement a text embedding module
- Loading branch information
Showing
9 changed files
with
630 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
"""The Python API for MLC Embeddings.""" | ||
|
||
import json | ||
from pathlib import Path | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
import numpy as np | ||
import tvm | ||
from tvm import relax | ||
from tvm.contrib import tvmjs | ||
from tvm.runtime import Device, Module | ||
from tvm.runtime.relax_vm import VirtualMachine | ||
|
||
from mlc_llm.chat_module import _get_model_path | ||
from mlc_llm.serve import engine_utils | ||
from mlc_llm.support.auto_device import detect_device | ||
from mlc_llm.tokenizer import Tokenizer | ||
|
||
|
||
def _extract_metadata(mod: Module): | ||
return json.loads(VirtualMachine(mod, tvm.runtime.device("cpu"))["_metadata"]()) | ||
|
||
|
||
def _load_params( | ||
model_weight_path: str, device: Device, model_metadata: Dict[str, Any] | ||
) -> List[tvm.nd.NDArray]: | ||
params, meta = tvmjs.load_ndarray_cache(model_weight_path, device) | ||
param_names = [param["name"] for param in model_metadata["params"]] | ||
assert len(param_names) == meta["ParamSize"] | ||
|
||
plist = [] | ||
for param_name in param_names: | ||
plist.append(params[param_name]) | ||
return plist | ||
|
||
|
||
def _get_tvm_module( | ||
model_weight_path: str, lib_path: str, device: Device, instrument: tvm.runtime.PackedFunc = None | ||
): | ||
ex = tvm.runtime.load_module(lib_path) | ||
vm = relax.VirtualMachine(ex, device) | ||
if instrument: | ||
vm.set_instrument(instrument) | ||
metadata = _extract_metadata(ex) | ||
params = _load_params(model_weight_path, device, metadata) | ||
return vm.module, params, metadata | ||
|
||
|
||
class DefaultDebugInstrument: | ||
"""The default debug instrument to use if users don't specify | ||
a customized one. | ||
This debug instrument will dump the arguments and output of each | ||
VM Call instruction into a .npz file. It will also alert the user | ||
if any function outputs are NaN or INF. | ||
""" | ||
|
||
def __init__(self, debug_out: Path): | ||
"""Constructor | ||
Parameters | ||
---------- | ||
debug_out : Path | ||
the directory to dump the .npz files | ||
""" | ||
self.counter = 0 | ||
self.first_nan_occurred = False | ||
self.first_inf_occurred = False | ||
self.debug_out = debug_out | ||
debug_out.mkdir(exist_ok=True, parents=True) | ||
|
||
def reset(self, debug_out: Path): | ||
"""Reset the state of the Instrument class | ||
Parameters | ||
---------- | ||
debug_out : Path | ||
the directory to dump the .npz files | ||
""" | ||
self.counter = 0 | ||
self.first_nan_occurred = False | ||
self.first_inf_occurred = False | ||
self.debug_out = debug_out | ||
debug_out.mkdir(exist_ok=True, parents=True) | ||
|
||
def __call__(self, func, name, before_run, ret_val, *args): | ||
# Determine what functions to look at | ||
if before_run: # Whether before the function is called or after | ||
return | ||
if name.startswith("vm.builtin.") and "attention_with_fused_qkv" not in name: | ||
return | ||
|
||
# Decide what to print or save about the function's arguments (where args[-1] is the | ||
# buffer we write the result to) | ||
func_name = f"f{self.counter}_{name}" | ||
|
||
# Save the arguments to npz | ||
arg_dict = {} | ||
for i, arg in enumerate(args): | ||
if isinstance(arg, tvm.nd.NDArray): | ||
arg_dict[f"arg_{i}"] = arg.numpy() | ||
|
||
np.savez(self.debug_out / f"{func_name}.npz", **arg_dict) | ||
|
||
self.counter += 1 | ||
|
||
|
||
class MLCEmbeddings: # pylint: disable=too-few-public-methods | ||
"""A class to embed queries using MLC LLM encoder models. | ||
Parameters | ||
---------- | ||
model: str | ||
The model folder after compiling with MLC-LLM build process. The parameter | ||
can either be the model name with its quantization scheme | ||
(e.g. ``Llama-2-7b-chat-hf-q4f16_1``), or a full path to the model | ||
folder. In the former case, we will use the provided name to search | ||
for the model folder over possible paths. | ||
model_lib_path : str | ||
The full path to the model library file to use (e.g. a ``.so`` file). | ||
device : Optional[str] | ||
The description of the device to run on. User should provide a string in the | ||
form of 'device_name:device_id' or 'device_name', where 'device_name' is one of | ||
'cuda', 'metal', 'vulkan', 'rocm', 'opencl', 'auto' (automatically detect the | ||
local device), and 'device_id' is the device id to run on. If no 'device_id' | ||
is provided, it will be set to 0 by default. | ||
debug_dir: Path | ||
The output folder to store the dumped debug files. If None, will not dump any debug files. | ||
""" | ||
|
||
def __init__( # pylint: disable=too-many-arguments | ||
self, | ||
model: str, | ||
model_lib_path: str, | ||
device: Optional[str] = "auto", | ||
debug_dir: Optional[str] = None, | ||
): | ||
|
||
self.device = detect_device(device) | ||
instrument = DefaultDebugInstrument(Path(debug_dir)) if debug_dir else None | ||
self.mod, self.params, self.metadata = _get_tvm_module( | ||
model, model_lib_path, self.device, instrument | ||
) | ||
self.model_path, _ = _get_model_path(model) | ||
self.tokenizer = Tokenizer(self.model_path) | ||
self.prefill_func = self.mod["prefill"] | ||
|
||
def embed(self, queries: List[str]) -> tvm.runtime.NDArray: | ||
""" | ||
Embeds a list of queries in a single batch. | ||
Parameters | ||
---------- | ||
queries : List[str] | ||
A list of queries to embed. | ||
Returns | ||
------- | ||
List[float] | ||
A list of embeddings for the queries. | ||
""" | ||
tokens, attention_mask = self._tokenize_queries(queries) | ||
tokens_tvm = tvm.nd.array(tokens.astype("int32"), device=self.device) | ||
attention_mask_tvm = tvm.nd.array(attention_mask.astype("int32"), device=self.device) | ||
output = self.prefill_func(tokens_tvm, attention_mask_tvm, self.params) | ||
return output | ||
|
||
def _tokenize_queries(self, queries: List[str]) -> Tuple[np.ndarray, np.ndarray]: | ||
tokens = engine_utils.process_prompts(queries, self.tokenizer.encode) | ||
max_query_length = max(len(token_seq) for token_seq in tokens) | ||
|
||
token_inputs = np.zeros((len(tokens), max_query_length), dtype=np.int32) | ||
attention_mask = np.zeros((len(tokens), max_query_length), dtype=np.int32) | ||
|
||
for i, token_seq in enumerate(tokens): | ||
token_inputs[i, : len(token_seq)] = token_seq | ||
attention_mask[i, : len(token_seq)] = 1 | ||
|
||
return token_inputs, attention_mask |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
""" | ||
This file specifies how MLC's BERT parameter maps from other formats, for example HuggingFace | ||
PyTorch, HuggingFace safetensors. | ||
""" | ||
import functools | ||
|
||
import numpy as np | ||
|
||
from mlc_llm.loader import ExternMapping | ||
from mlc_llm.quantization import Quantization | ||
|
||
from .bert_model import BertConfig, BertModel | ||
|
||
|
||
def huggingface(model_config: BertConfig, quantization: Quantization) -> ExternMapping: | ||
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to | ||
the names of HuggingFace PyTorch parameters. | ||
Parameters | ||
---------- | ||
model_config : BertConfig | ||
The configuration of the BERT model. | ||
quantization : Quantization | ||
The quantization configuration. | ||
Returns | ||
------- | ||
param_map : ExternMapping | ||
The parameter mapping from MLC to HuggingFace PyTorch. | ||
""" | ||
model = BertModel(model_config) | ||
if quantization is not None: | ||
model.to(quantization.model_dtype) | ||
_, _named_params, _ = model.export_tvm( # type: ignore[misc] | ||
spec=model.get_default_spec(), | ||
allow_extern=True, | ||
) | ||
named_parameters = dict(_named_params) | ||
|
||
mapping = ExternMapping() | ||
|
||
for i in range(model_config.num_hidden_layers): | ||
attn = f"encoder.layer.{i}.attention.self" | ||
mlc_name = f"{attn}.qkv.weight" | ||
mlc_param = named_parameters[mlc_name] | ||
mapping.add_mapping( | ||
mlc_name, | ||
[ | ||
f"{attn}.query.weight", | ||
f"{attn}.key.weight", | ||
f"{attn}.value.weight", | ||
], | ||
functools.partial( | ||
lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), | ||
dtype=mlc_param.dtype, | ||
), | ||
) | ||
|
||
mlc_name = f"{attn}.qkv.bias" | ||
mlc_param = named_parameters[mlc_name] | ||
mapping.add_mapping( | ||
mlc_name, | ||
[ | ||
f"{attn}.query.bias", | ||
f"{attn}.key.bias", | ||
f"{attn}.value.bias", | ||
], | ||
functools.partial( | ||
lambda q, k, v, dtype: np.concatenate([q, k, v], axis=0).astype(dtype), | ||
dtype=mlc_param.dtype, | ||
), | ||
) | ||
|
||
for mlc_name, mlc_param in named_parameters.items(): | ||
if mlc_name not in mapping.param_map: | ||
mapping.add_mapping( | ||
mlc_name, | ||
[mlc_name], | ||
functools.partial( | ||
lambda x, dtype: x.astype(dtype), | ||
dtype=mlc_param.dtype, | ||
), | ||
) | ||
|
||
return mapping |
Oops, something went wrong.