4. Adding a model
Aphrodite already supports a variety of HuggingFace model architectures. You can see the full list here. If your model is not among them, you can add it yourself.
All the models supported by Aphrodite have been adapted from the HuggingFace modeling code (for example, modeling_llama.py). To add a new model, start from its HuggingFace modeling file and make the following changes:
- Remove any unnecessary code related to training.
- Change the input parameters:
```diff
  def forward(
      self,
      input_ids: torch.Tensor,
-     attention_mask: Optional[torch.Tensor] = None,
-     position_ids: Optional[torch.LongTensor] = None,
-     past_key_values: Optional[List[torch.FloatTensor]] = None,
-     inputs_embeds: Optional[torch.FloatTensor] = None,
-     labels: Optional[torch.LongTensor] = None,
-     use_cache: Optional[bool] = None,
-     output_attentions: Optional[bool] = None,
-     output_hidden_states: Optional[bool] = None,
-     return_dict: Optional[bool] = None,
- ) -> Union[Tuple, CausalLMOutputWithPast]:
+     positions: torch.Tensor,
+     kv_caches: List[KVCache],
+     input_metadata: InputMetadata,
+ ) -> Optional[SamplerOutput]:
```
- Update the code to account for the fact that input_ids and positions are now flattened tensors.
- Replace the attention ops with PagedAttention, PagedAttentionWithRoPE, or PagedAttentionWithALiBi, depending on the model's architecture.
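A HuggingFace model receives input_ids as a padded (batch_size, seq_len) tensor. In the flattened layout, the tokens of every sequence in the batch are concatenated into one 1-D tensor, and positions gives each token's position within its own sequence, so hidden states inside the model have shape (num_tokens, hidden_size) rather than (batch_size, seq_len, hidden_size). The sketch below illustrates this for a prompt (prefill) step, using plain Python lists in place of torch tensors. flatten_batch is a hypothetical helper: Aphrodite prepares these inputs itself, and your model only has to stop assuming a batch dimension.

```python
from typing import List, Tuple


def flatten_batch(prompts: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: concatenate the token IDs of all prompts into one flat
    list and build the matching positions, restarting at 0 for each prompt."""
    input_ids: List[int] = []
    positions: List[int] = []
    for tokens in prompts:
        input_ids.extend(tokens)
        positions.extend(range(len(tokens)))
    return input_ids, positions


# Two prompts of different lengths. A padded HF batch would have shape (2, 3);
# the flattened form holds exactly 5 tokens and needs no padding.
batch = [[101, 7592, 2088], [101, 2129]]
ids, pos = flatten_batch(batch)
print(ids)  # [101, 7592, 2088, 101, 2129]
print(pos)  # [0, 1, 2, 0, 1]
```

Because positions restart for each sequence, any code that derived positions from a sequence-length axis (for example, with torch.arange over seq_len) must use the positions argument instead.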
Aphrodite supports the basic multi-head attention mechanisms (MHA, MQA, GQA, VGQA) and their variants with rotary positional embeddings (RoPE) and ALiBi. If your model employs a different attention mechanism, you will need to implement a new attention layer.
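MHA, MQA and GQA differ only in how many key/value heads the query heads share. MHA gives every query head its own KV head, and MQA uses a single KV head for all query heads. GQA splits the query heads into equal groups, and each group shares one KV head. The hypothetical helper below (not an Aphrodite API) computes that query-to-KV-head mapping. In a HuggingFace config, the KV head count usually appears as a field such as num_key_value_heads (for example, in LlamaConfig).

```python
def kv_head_for_query_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Hypothetical helper: index of the key/value head used by query head `q_head`."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads that share one KV head
    return q_head // group_size


num_heads = 8
for name, num_kv_heads in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    mapping = [kv_head_for_query_head(h, num_heads, num_kv_heads) for h in range(num_heads)]
    print(name, mapping)
# MHA [0, 1, 2, 3, 4, 5, 6, 7]
# GQA [0, 0, 0, 0, 1, 1, 1, 1]
# MQA [0, 0, 0, 0, 0, 0, 0, 0]
```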
If your model is too large to fit on a single GPU, you can use tensor parallelism. To do this, replace your model's linear and embedding layers with their tensor-parallel versions. For the embedding layer, you can simply replace nn.Embedding with VocabParallelEmbedding. For the output LM head, you can use ParallelLMHead. For the linear layers, we provide the following options:
- ReplicatedLinear: Replicates the input and weights across multiple GPUs. No memory saving.
- RowParallelLinear: The input tensor is partitioned along the hidden dimension, and the weight matrix is partitioned along the rows (input dimension). An all-reduce is performed after the matmul to sum the partial results. Typically used for the second FFN layer and the output linear transformation of the attention layer.
- ColumnParallelLinear: The input tensor is replicated, and the weight matrix is partitioned along the columns (output dimension). The result is partitioned along the column dimension. Typically used for the first FFN layer and the separate QKV transformations of the attention layer in the original Transformer.
- MergedColumnParallelLinear: A column-parallel linear layer that merges multiple ColumnParallelLinear operators. Typically used for the first FFN layer with gated activation functions (e.g. SiLU). This class handles the sharded weight loading logic for multiple weight matrices.
- QKVParallelLinear: Parallel linear layer for the query, key, and value projections of the multi-head and grouped-query attention mechanisms. When the number of key/value heads is less than the world size, this class replicates the KV heads properly. This class handles the weight loading and replication of the weight matrices.
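The column and row splits are easiest to see on small matrices. The sketch below uses plain Python lists, the y = x @ W convention from the list above (W has shape (input_dim, output_dim)), and a simulated tensor-parallel size of 2. It checks that a ColumnParallelLinear followed by a RowParallelLinear, as in an FFN block, reproduces the unsharded result with a single all-reduce at the end. The helper kv_heads_per_rank is a hypothetical illustration of the KV-head replication described for QKVParallelLinear. It assumes that KV heads are split evenly when there are at least as many heads as ranks, and replicated otherwise.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(x: Matrix, w: Matrix) -> Matrix:
    """Plain matrix product: x is (m, k), w is (k, n), result is (m, n)."""
    return [[sum(x[i][p] * w[p][j] for p in range(len(w))) for j in range(len(w[0]))]
            for i in range(len(x))]


def split_columns(w: Matrix, parts: int) -> List[Matrix]:
    """Shard w along its columns (output dimension), as ColumnParallelLinear does."""
    n = len(w[0]) // parts
    return [[row[r * n:(r + 1) * n] for row in w] for r in range(parts)]


def split_rows(w: Matrix, parts: int) -> List[Matrix]:
    """Shard w along its rows (input dimension), as RowParallelLinear does."""
    k = len(w) // parts
    return [w[r * k:(r + 1) * k] for r in range(parts)]


def all_reduce_sum(parts: List[Matrix]) -> Matrix:
    """Element-wise sum of same-shaped matrices: what an all-reduce computes."""
    return [[sum(p[i][j] for p in parts) for j in range(len(parts[0][0]))]
            for i in range(len(parts[0]))]


def kv_heads_per_rank(total_kv_heads: int, tp_size: int) -> Tuple[int, int]:
    """Hypothetical: (KV heads stored per rank, copies of each KV head across ranks).
    Assumes the two counts divide evenly."""
    if tp_size >= total_kv_heads:
        return 1, tp_size // total_kv_heads  # fewer KV heads than ranks: replicate
    return total_kv_heads // tp_size, 1      # enough KV heads: split them evenly


tp = 2                                                     # simulated world size
x = [[1.0, 2.0, 3.0, 4.0]]                                 # (1, 4) input
w1 = [[float(i + j) for j in range(4)] for i in range(4)]  # first FFN layer
w2 = [[float(i - j) for j in range(4)] for i in range(4)]  # second FFN layer

# Column parallel: each rank sees the full x and produces one slice of h.
h_shards = [matmul(x, w) for w in split_columns(w1, tp)]
# Row parallel: rank r consumes h_shards[r] directly; partial outputs are summed.
y_parallel = all_reduce_sum([matmul(h, w) for h, w in zip(h_shards, split_rows(w2, tp))])

y_reference = matmul(matmul(x, w1), w2)
print(y_parallel == y_reference)  # True
print(kv_heads_per_rank(8, 2), kv_heads_per_rank(2, 8))  # (4, 1) (1, 4)
```

Rank r's column-parallel output is exactly the input shard that rank r's row-parallel layer needs, so the two layers run back to back without communication. An element-wise activation between them (omitted here) does not change this. Also note that torch.nn.Linear stores its weight as (out_features, in_features), so on the stored tensor the shard axes are transposed relative to this sketch.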
Note that all the linear layers above take a linear_method argument. Aphrodite sets it according to the quantization scheme in use, which is how weight quantization is supported.
You now need to implement the load_weights method in your *ForCausalLM class. This method should load the weights from the HuggingFace checkpoint file(s) and assign them to the corresponding layers in your model. In particular, for MergedColumnParallelLinear and QKVParallelLinear layers, if the original model has separate weight matrices, you need to load the different parts separately.
Finally, include your *ForCausalLM class in aphrodite/modeling/models/__init__.py and register it in _MODEL_REGISTRY in aphrodite/modeling/loader.py.
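The registration step comes down to a dictionary lookup from an architecture name to a model class. The sketch below assumes that the loader matches registry keys against the architectures list in the checkpoint's config.json. This is an assumption, so confirm it in loader.py. Here, model_registry stands in for the registry in aphrodite/modeling/loader.py, and MyModelForCausalLM is a placeholder for your class.

```python
from typing import Dict, List, Type


class LlamaForCausalLM:  # stand-in for an already supported model class
    pass


class MyModelForCausalLM:  # placeholder for your new *ForCausalLM class
    pass


# Stand-in for the registry in aphrodite/modeling/loader.py: keys are the
# architecture names that HuggingFace writes into config.json.
model_registry: Dict[str, Type] = {
    "LlamaForCausalLM": LlamaForCausalLM,
    "MyModelForCausalLM": MyModelForCausalLM,
}


def get_model_class(architectures: List[str]) -> Type:
    """Return the registered class for the first supported architecture name."""
    for arch in architectures:
        if arch in model_registry:
            return model_registry[arch]
    raise ValueError(
        f"Model architectures {architectures} are not supported. "
        f"Supported architectures: {list(model_registry)}"
    )


print(get_model_class(["MyModelForCausalLM"]).__name__)  # MyModelForCausalLM
```

Under that assumption, the registry key must match the architecture name in your checkpoint's config exactly. A mismatch surfaces as an unsupported-architecture error when the model is loaded.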