
Commit 7a9b071

dhiaEddineRhaiem, younesbelkada, ArthurZucker, vasqu, and ydshieh authored
[Falcon H1] Fix slow path forward pass (#38320)
* Create push-important-models.yml
* feat: add falcon-h1
* fixup
* address comment
* fix
* fix copies
* fix copies
* fix
* fix
* fix
* fix
* fix copies
* fix
* fix copies
* fix test import to at least trigget the cis
* yups
* update
* fix make fix copies
* fix inits?
* fix style
* skip annoying test
* add integration test for Falcon H1
* fix copies
* fix
* fix typo
* make style
* fix slow path generations
* clean debug traces
* debug
* remove debug traces final confirmation
* clean debug traces final
* fix format and lineup
* make style
* debug
* Update src/transformers/models/falcon_h1/modular_falcon_h1.py
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
* adress comments
* fix fix-copies
* fix integration test
* Merge pull request #7 from ydshieh/fix-slow-path update
* another update (#8)
* update
* update
---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Younes Belkada <younesbelkada@gmail.com>
Co-authored-by: younesbelkada <younes.belkada@tii.ae>
Co-authored-by: Arthur Zucker <arthur.zucker@gmail.com>
Co-authored-by: Anton Vlasjuk <73884904+vasqu@users.noreply.github.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent b5b76b5 commit 7a9b071


3 files changed, +41 -30 lines changed


src/transformers/models/falcon_h1/modeling_falcon_h1.py

Lines changed: 10 additions & 6 deletions
@@ -604,9 +604,10 @@ def cuda_kernels_forward(
     ):
         # 1. Gated MLP's linear projection
         hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+        # Add Multipliers
         hidden_states = hidden_states * self.ssm_in_multiplier
         projected_states = self.in_proj(hidden_states)
-        projected_states = projected_states * self.mup_vector
+        projected_states = projected_states * self.mup_vector # ADD Mup Multipliers
         d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads

         # Set up dimensions for reshapes later
@@ -800,10 +801,13 @@ def torch_forward(

         # 1. Gated MLP's linear projection
         input_states = apply_mask_to_padding_states(input_states, attention_mask)
+        # Add Multipliers
+        input_states = input_states * self.ssm_in_multiplier
         projected_states = self.in_proj(input_states)
-        gate, hidden_states_B_C, dt = projected_states.split(
-            [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
-        )
+        projected_states = projected_states * self.mup_vector # ADD Mup Multipliers
+        gate, hidden_states_B_C, dt = projected_states.split([
+            self.intermediate_size, self.conv_dim, self.num_heads
+        ], dim=-1)

         use_precomputed_states = (
             cache_params is not None
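
The hunk above brings `torch_forward` in line with `cuda_kernels_forward`: the input is scaled by `ssm_in_multiplier`, projected, scaled per channel by `mup_vector`, and only then split into `gate`, `hidden_states_B_C`, and `dt`. The following is a minimal pure-Python sketch of that ordering, not the model code. The sizes, the multiplier values, and the toy `in_proj` are hypothetical, and `conv_dim = intermediate_size + 2 * n_groups * ssm_state_size` is an assumption inferred from the `d_to_remove` expression, since this diff does not show its definition.

```python
# Toy stand-in for the slow-path projection step; all sizes and values are hypothetical.
intermediate_size, n_groups, ssm_state_size, num_heads = 2, 1, 1, 2
# Assumption: conv_dim follows the usual Mamba2-style layout (not shown in this diff).
conv_dim = intermediate_size + 2 * n_groups * ssm_state_size
split_sizes = [intermediate_size, conv_dim, num_heads]
proj_width = sum(split_sizes)

ssm_in_multiplier = 0.5
# One multiplier per output channel of in_proj (hypothetical values).
mup_vector = [2.0] * intermediate_size + [1.0] * conv_dim + [4.0] * num_heads


def in_proj(x):
    """Hypothetical linear layer: copies the scalar input to every output channel."""
    return [x] * proj_width


def split(values, sizes):
    """Split a flat list into consecutive chunks, like Tensor.split(sizes, dim=-1)."""
    chunks, start = [], 0
    for size in sizes:
        chunks.append(values[start:start + size])
        start += size
    return chunks


def slow_path_projection(x):
    x = x * ssm_in_multiplier                                    # 1. scale the input
    projected = in_proj(x)                                       # 2. project
    projected = [p * m for p, m in zip(projected, mup_vector)]   # 3. scale per channel
    return split(projected, split_sizes)                         # 4. split afterwards


gate, hidden_states_B_C, dt = slow_path_projection(2.0)
print(gate, hidden_states_B_C, dt)
```

Because the split happens after the channel-wise scaling, each of the three chunks carries its own segment of `mup_vector`; skipping steps 1 and 3, as the old slow path did, leaves all three chunks unscaled.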
@@ -914,8 +918,8 @@ def torch_forward(
             hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
             B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
             C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
-            B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
-            C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
+            B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
+            C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
             pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

             D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
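
The hunk above changes how the per-group `B` and `C` states are expanded to one entry per head. `Tensor.repeat` tiles the whole group axis, whereas `repeat_interleave` repeats each group in place, so the two give different head-to-group assignments whenever `n_groups > 1`. The sketch below mimics both behaviours on a plain Python list standing in for dimension 2 of `B`; the sizes are hypothetical and no PyTorch call is made.

```python
def tile(groups, times):
    """Mimics Tensor.repeat along one axis: the whole sequence is repeated."""
    return list(groups) * times


def interleave(groups, times):
    """Mimics Tensor.repeat_interleave along one axis: each element is repeated in place."""
    return [g for g in groups for _ in range(times)]


num_heads, n_groups = 4, 2      # hypothetical sizes
groups = ["g0", "g1"]           # stand-in for the group axis of B or C
times = num_heads // n_groups

old_layout = tile(groups, times)         # layout produced before the fix
new_layout = interleave(groups, times)   # layout produced after the fix
print(old_layout)
print(new_layout)
```

With a single group the two layouts coincide, so a mismatch of this kind would only show up for configurations with `n_groups > 1`.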

src/transformers/models/falcon_h1/modular_falcon_h1.py

Lines changed: 10 additions & 6 deletions
@@ -415,9 +415,10 @@ def cuda_kernels_forward(
     ):
         # 1. Gated MLP's linear projection
         hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
+        # Add Multipliers
         hidden_states = hidden_states * self.ssm_in_multiplier
         projected_states = self.in_proj(hidden_states)
-        projected_states = projected_states * self.mup_vector
+        projected_states = projected_states * self.mup_vector # ADD Mup Multipliers
         d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads

         # Set up dimensions for reshapes later
@@ -611,10 +612,13 @@ def torch_forward(

         # 1. Gated MLP's linear projection
         input_states = apply_mask_to_padding_states(input_states, attention_mask)
+        # Add Multipliers
+        input_states = input_states * self.ssm_in_multiplier
         projected_states = self.in_proj(input_states)
-        gate, hidden_states_B_C, dt = projected_states.split(
-            [self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
-        )
+        projected_states = projected_states * self.mup_vector # ADD Mup Multipliers
+        gate, hidden_states_B_C, dt = projected_states.split([
+            self.intermediate_size, self.conv_dim, self.num_heads
+        ], dim=-1)

         use_precomputed_states = (
             cache_params is not None
@@ -725,8 +729,8 @@ def torch_forward(
             hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
             B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
             C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
-            B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
-            C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
+            B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
+            C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
             pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

             D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

tests/models/falcon_h1/test_modeling_falcon_h1.py

Lines changed: 21 additions & 18 deletions
@@ -484,24 +484,27 @@ def test_falcon_h1_hard(self):
         """
         An integration test for Falcon-H1.
         """
-        EXPECTED_TEXT = (
-            "Tell me about the french revolution.\n"
-            "The French Revolution (1789–1799) was a period of radical social and political upheaval in France that "
-            "fundamentally transformed the nation and had profound effects on the rest of Europe and the world. Here are the key aspects of the revolution:\n\n"
-            "### **Causes**\n"
-            "1. **Economic Crisis**: France was in severe financial trouble due to costly wars (particularly the American Revolution), extravagant spending by the monarchy, and inefficient taxation.\n"
-            "2. **Social Inequality**: The rigid class system (the Ancien Régime) divided society into the privileged nobility and clergy (First Estate) and the common people (Third Estate), who bore the brunt of taxation and had few rights.\n"
-            "3. **Enlightenment Ideas**: Philosophers like Rousseau, Voltaire, and Montesquieu inspired ideas of liberty, equality, and popular sovereignty.\n"
-            "4. **Settlement of 1789**: The Estates-General convened to address the financial crisis, leading to the Third Estate's assertion of its rights and the eventual formation of the National Assembly.\n\n"
-            "### **Key Events**\n"
-            "1. **Opening of the Revolution (1789)**:\n"
-            "- **Storming of the Bastille**: Symbolic of the fall of royal tyranny.\n"
-            "- **Declaration of the Rights of Man and of the Citizen**: Proclaimed universal rights to liberty, property, and security.\n"
-            "- **Creation of the National Assembly**: The Third Estate declared itself the representative body of France.\n\n"
-            "2. **Radical Phase (1792–1794)**:\n"
-            "- **Reign of Terror**: Led by Maximilien Robespierre, the Committee of Public Safety enforced radical egalitarianism through the guillotine, executing thousands of perceived enemies of the revolution (monarchists, clergy, aristocrats, and counter-revolutionaries).\n"
-            "- **Execution of Louis XVI**: The king was guillotined in June 1793, symbolizing the end of the monarchy.\n"
-        )
+        EXPECTED_TEXT = """
+            user
+            Tell me about the french revolution.
+            assistant
+            The French Revolution (1789–1799) was a period of radical social and political upheaval in France that fundamentally transformed the nation and had profound effects on the rest of Europe and the world. Here are the key aspects of the revolution:
+
+            ### **Causes**
+            1. **Economic Crisis**: France was in severe financial trouble due to costly wars (particularly the American Revolution), extravagant spending by the monarchy, and inefficient taxation.
+            2. **Social Inequality**: The rigid class system (the Ancien Régime) divided society into the privileged nobility and clergy (First Estate) and the commoners (Third Estate), who bore the brunt of taxation and had few rights.
+            3. **Enlightenment Ideas**: Philosophers like Voltaire, Rousseau, and Montesquieu inspired ideas of liberty, equality, and popular sovereignty.
+            4. **Settlement of 1789**: The Estates-General convened to address the financial crisis, leading to the Third Estate's assertion of its rights and the eventual abolition of the feudal system.
+
+            ### **Key Events**
+            1. **Storming of the Bastille (July 14, 1789)**: A symbol of royal tyranny, the Bastille fortress was stormed by revolutionaries, sparking widespread rebellion.
+            2. **Declaration of the Rights of Man and of the Citizen (August 1789)**: A foundational document proclaiming liberty, equality, and fraternity.
+            3. **National Assembly and King’s Trial (1791–1792)**: King Louis XVI and his ministers were tried and executed (King Louis was guillotined, Marie Antoinette was banished), marking the end of the monarchy.
+            4. **Rise of the Jacobins and Reign of Terror (1793–1794)**: Radical leaders like Maximilien Robespierre sought to purge France of counter-revolutionaries, leading to mass executions and widespread fear.
+            5. **Thermidorian Reaction
+        """
+        # Remove the first char (`\n`) and the consecutive whitespaces caused by the formatting.
+        EXPECTED_TEXT = EXPECTED_TEXT.strip().replace(" " * 12, "")

         model_id = "tiiuae/Falcon-H1-1.5B-Deep-Instruct"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
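
The new `EXPECTED_TEXT` is a triple-quoted string indented to match the test body, so the test normalizes it with `.strip().replace(" " * 12, "")` before comparing it with the generated text. The sketch below shows what that normalization does, using a shortened, hypothetical string with the same 12-space indentation rather than the real expected output.

```python
# Shortened, hypothetical stand-in for the indented EXPECTED_TEXT literal.
EXPECTED_TEXT = """
            user
            Tell me about the french revolution.
            assistant
            The French Revolution was ...

            ### **Causes**
        """
# strip() drops the leading newline and the first line's indentation, plus the
# trailing newline and the indentation before the closing quotes.
# replace() then removes the 12-space indentation in front of the remaining lines.
normalized = EXPECTED_TEXT.strip().replace(" " * 12, "")
print(normalized)
```

Note that the replacement removes every run of 12 consecutive spaces, not only leading indentation, so this approach assumes the expected text itself never contains such a run.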
