Skip to content

Commit

Permalink
update nemo args for mcore flash decode arg change (#11138)
Browse files Browse the repository at this point in the history
* update mcore transformer layer, block, attention forward args for flash decode

Signed-off-by: Huiying Li <willwin.lee@gmail.com>

* Apply isort and black reformatting

Signed-off-by: HuiyingLi <HuiyingLi@users.noreply.github.com>

* update mcore tag in Dockerfile.ci

Signed-off-by: HuiyingLi <willwin.lee@gmail.com>

---------

Signed-off-by: Huiying Li <willwin.lee@gmail.com>
Signed-off-by: HuiyingLi <HuiyingLi@users.noreply.github.com>
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-authored-by: HuiyingLi <HuiyingLi@users.noreply.github.com>
  • Loading branch information
2 people authored and yashaswikarnati committed Nov 21, 2024
1 parent 911ebfb commit 8e6e438
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMO_RUN_T
# Install NeMo requirements
ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
ARG MODELOPT_VERSION=0.19.0
ARG MCORE_TAG=213c8a23fa9fe95d19eff0932a1e6e71767f0962
ARG MCORE_TAG=441cb9250101cf2cc406f0439b802f34f923f251

ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
RUN \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ def forward(
context=None,
context_mask=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
inference_params=None,
packed_seq_params=None,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class FalconTransformerLayer(TransformerLayer):
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""

def __init__(
Expand Down Expand Up @@ -106,6 +106,8 @@ def forward(
context=None,
context_mask=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
inference_params=None,
packed_seq_params=None,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ def forward(
context=None,
context_mask=None,
rotary_pos_emb=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
inference_params=None,
packed_seq_params=None, # TODO: handle this
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def forward(
rotary_pos_emb = None
self.decoder.input_tensor = None
if self.position_embedding_type == 'rope':
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(None, self.decoder, hidden_states, self.config)
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
None, self.decoder, hidden_states, self.config, None
)
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

hidden_states = self.decoder(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,21 @@ def forward(
context: Tensor = None,
context_mask: Tensor = None,
rotary_pos_emb: Tensor = None,
rotary_pos_cos: Tensor = None,
rotary_pos_sin: Tensor = None,
inference_params: InferenceParams = None,
packed_seq_params: PackedSeqParams = None,
):
hidden_states = super().forward(
hidden_states, attention_mask, context, context_mask, rotary_pos_emb, inference_params, packed_seq_params
hidden_states,
attention_mask,
context,
context_mask,
rotary_pos_emb,
rotary_pos_cos,
rotary_pos_sin,
inference_params,
packed_seq_params,
)

mlp_head_adapter = self.get_adapter_module(AdapterName.MLP_HEAD_ADAPTER)
Expand Down Expand Up @@ -220,6 +230,8 @@ def forward(
inference_params=None,
rotary_pos_emb=None,
packed_seq_params=None,
rotary_pos_cos=None,
rotary_pos_sin=None,
):
# hidden_states: [sq, b, h]

Expand All @@ -237,8 +249,8 @@ def forward(
# ===================================================
# Adjust key, value, and rotary_pos_emb for inference
# ===================================================
key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
inference_params, key, value, rotary_pos_emb
query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
inference_params, query, key, value, rotary_pos_emb
)

if packed_seq_params is not None:
Expand Down

0 comments on commit 8e6e438

Please sign in to comment.