update nemo args for mcore flash decode arg change #11138

Merged · 5 commits · Nov 5, 2024
Changes from all commits
Dockerfile.ci — 2 changes: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMO_RUN_T
 # Install NeMo requirements
 ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea
 ARG MODELOPT_VERSION=0.19.0
-ARG MCORE_TAG=213c8a23fa9fe95d19eff0932a1e6e71767f0962
+ARG MCORE_TAG=441cb9250101cf2cc406f0439b802f34f923f251

 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
 RUN \
@@ -206,6 +206,8 @@ def forward(
         context=None,
         context_mask=None,
         rotary_pos_emb=None,
+        rotary_pos_cos=None,
+        rotary_pos_sin=None,
         inference_params=None,
         packed_seq_params=None,
     ):
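This is the recurring change in the PR: Megatron-Core's flash decode support threads precomputed rotary cos/sin tables through TransformerLayer.forward, so every NeMo override of forward has to accept rotary_pos_cos and rotary_pos_sin in the same position in the signature. A minimal sketch of the pattern, where Base and CustomLayer are hypothetical stand-ins for the mcore layer and a NeMo subclass:

    # Minimal sketch; Base and CustomLayer are hypothetical stand-ins for
    # mcore's TransformerLayer and a NeMo override of it.
    class Base:
        def forward(
            self,
            hidden_states,
            attention_mask,
            context=None,
            context_mask=None,
            rotary_pos_emb=None,
            rotary_pos_cos=None,  # added by the flash decode change
            rotary_pos_sin=None,  # added by the flash decode change
            inference_params=None,
            packed_seq_params=None,
        ):
            return hidden_states

    class CustomLayer(Base):
        # Overrides must mirror the parent signature: a positional call
        # that still passes inference_params right after rotary_pos_emb
        # would now land on rotary_pos_cos instead.
        def forward(self, hidden_states, attention_mask, **kwargs):
            return super().forward(hidden_states, attention_mask, **kwargs)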
@@ -55,7 +55,7 @@ class FalconTransformerLayer(TransformerLayer):

     Transformer layer takes input with size [s, b, h] and returns an
     output of the same size.
-    
+
     """

     def __init__(
@@ -106,6 +106,8 @@ def forward(
         context=None,
         context_mask=None,
         rotary_pos_emb=None,
+        rotary_pos_cos=None,
+        rotary_pos_sin=None,
         inference_params=None,
         packed_seq_params=None,
     ):
@@ -250,6 +250,8 @@ def forward(
         context=None,
         context_mask=None,
         rotary_pos_emb=None,
+        rotary_pos_cos=None,
+        rotary_pos_sin=None,
         inference_params=None,
         packed_seq_params=None,  # TODO: handle this
     ):
@@ -160,7 +160,9 @@ def forward(
         rotary_pos_emb = None
         self.decoder.input_tensor = None
         if self.position_embedding_type == 'rope':
-            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(None, self.decoder, hidden_states, self.config)
+            rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
+                None, self.decoder, hidden_states, self.config, None
+            )
             rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len)

         hidden_states = self.decoder(hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb)
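The extra trailing None lines up this call with the updated get_rotary_seq_len signature in the pinned Megatron-Core commit, which takes one more parameter than before. An annotated sketch of the call, where the name packed_seq_params for the fifth parameter is an assumption, not something this diff confirms:

    # Assumed updated mcore signature (fifth parameter name is an assumption):
    #   get_rotary_seq_len(inference_params, transformer, transformer_input,
    #                      transformer_config, packed_seq_params)
    rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(
        None,           # inference_params: not used on this path
        self.decoder,   # the transformer block
        hidden_states,  # transformer input
        self.config,    # transformer config
        None,           # packed_seq_params (assumed name): not used here
    )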
@@ -80,11 +80,21 @@ def forward(
         context: Tensor = None,
         context_mask: Tensor = None,
         rotary_pos_emb: Tensor = None,
+        rotary_pos_cos: Tensor = None,
+        rotary_pos_sin: Tensor = None,
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
     ):
         hidden_states = super().forward(
-            hidden_states, attention_mask, context, context_mask, rotary_pos_emb, inference_params, packed_seq_params
+            hidden_states,
+            attention_mask,
+            context,
+            context_mask,
+            rotary_pos_emb,
+            rotary_pos_cos,
+            rotary_pos_sin,
+            inference_params,
+            packed_seq_params,
         )

         mlp_head_adapter = self.get_adapter_module(AdapterName.MLP_HEAD_ADAPTER)
@@ -220,6 +230,8 @@ def forward(
         inference_params=None,
         rotary_pos_emb=None,
         packed_seq_params=None,
+        rotary_pos_cos=None,
+        rotary_pos_sin=None,
     ):
         # hidden_states: [sq, b, h]

@@ -237,8 +249,8 @@ def forward(
         # ===================================================
         # Adjust key, value, and rotary_pos_emb for inference
         # ===================================================
-        key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
-            inference_params, key, value, rotary_pos_emb
+        query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference(
+            inference_params, query, key, value, rotary_pos_emb
         )

         if packed_seq_params is not None:
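This last hunk tracks a behavioral change in the helper itself: with flash decode, _adjust_key_value_for_inference also adjusts the query, so it now takes query as an argument and returns it in the tuple. Callers that keep the old 4-value unpacking break with a ValueError. A hedged compatibility sketch, assuming only the two call shapes visible in this diff:

    import inspect

    def adjust_for_inference(attn, inference_params, query, key, value, rotary_pos_emb):
        # Hypothetical shim over the version skew: newer mcore passes and
        # returns query; older mcore leaves query untouched.
        helper = attn._adjust_key_value_for_inference
        if 'query' in inspect.signature(helper).parameters:
            return helper(inference_params, query, key, value, rotary_pos_emb)
        key, value, rotary_pos_emb, attn_mask_type = helper(
            inference_params, key, value, rotary_pos_emb
        )
        return query, key, value, rotary_pos_emb, attn_mask_type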