Skip to content

Commit

Permalink
added mlp and attn bias option to flash and paged llama models (IBM#85)
Browse files Browse the repository at this point in the history
#### Motivation

The `Calico` models currently set the mlp and attention bias to true,
but the flash and paged llama implementations hard-coded both to false.
This change uses the config params introduced in
huggingface/transformers#30031 to set those
values properly.

#### Modifications

- added `attention_bias` and `mlp_bias` to the config for the Flash and Paged Llama
implementations (default is `False`)
- set the bias in attention and mlp to the corresponding config value

#### Result

Models that contain attention and mlp bias
should now load properly

---------

Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
  • Loading branch information
2 people authored and Xaenalt committed Aug 12, 2024
1 parent 6e68de5 commit 5953686
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 10 deletions.
5 changes: 5 additions & 0 deletions server/text_generation_server/inference_engine/tgis_native.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ def __init__(
model_class = FlashRWForCausalLM

elif model_type == "llama":
# See: https://github.com/ibm-granite/vllm_granite/blob/main/vllm/model_executor/models/llama.py#L353-L354
if self._config.tie_word_embeddings:
aliases = {
"lm_head.weight": ["model.embed_tokens.weight"]
}
if PAGED_ATTENTION:
from text_generation_server.models.custom_modeling.paged_llama_modeling import PagedLlamaForCausalLM
model_class = PagedLlamaForCausalLM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=10000.0,
attention_bias=False,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
Expand All @@ -84,6 +86,8 @@ def __init__(
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.mlp_bias = mlp_bias

super().__init__(
pad_token_id=pad_token_id,
Expand Down Expand Up @@ -168,7 +172,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))


class FlashLlamaAttention(torch.nn.Module):
Expand Down Expand Up @@ -209,13 +213,13 @@ def __init__(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
bias=config.attention_bias,
)

def forward(
Expand Down Expand Up @@ -298,13 +302,13 @@ def __init__(self, prefix, config, weights):
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
bias=config.mlp_bias,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
bias=config.mlp_bias,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(
tie_word_embeddings=False,
rope_scaling=None,
rope_theta=10000.0,
attention_bias=False,
mlp_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
Expand All @@ -85,6 +87,8 @@ def __init__(
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.mlp_bias = mlp_bias

super().__init__(
pad_token_id=pad_token_id,
Expand Down Expand Up @@ -169,7 +173,7 @@ def _load_gqa(config, prefix: str, weights):
config.hidden_size,
], f"{list(weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

return TensorParallelColumnLinear(get_linear(weight, bias=None, quantize=config.quantize))
return TensorParallelColumnLinear(get_linear(weight, bias=config.attention_bias, quantize=config.quantize))


class PagedLlamaAttention(torch.nn.Module):
Expand Down Expand Up @@ -207,13 +211,13 @@ def __init__(
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=False,
bias=config.attention_bias,
)
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
bias=config.attention_bias,
)

def forward(
Expand Down Expand Up @@ -280,13 +284,13 @@ def __init__(self, prefix, config, weights):
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
bias=config.mlp_bias,
)
self.down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
bias=config.mlp_bias,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
Expand Down

0 comments on commit 5953686

Please sign in to comment.