diff --git a/README.rst b/README.rst
index f7374641d66d..41e326cef44c 100644
--- a/README.rst
+++ b/README.rst
@@ -41,17 +41,43 @@
Latest News
-----------
-- 2023/12/06 `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility>`_
+.. raw:: html
-.. image:: https://github.com/sbhavani/TransformerEngine/blob/main/docs/examples/H200-NeMo-performance.png
- :target: https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility
- :alt: H200-NeMo-performance
- :width: 600
+
+  <details open>
+    <summary><b>Large Language Models and Multimodal</b></summary>
+
+    <details>
+      <summary>Accelerate your generative AI journey with NVIDIA NeMo framework on GKE (2024/03/16)</summary>
-NeMo Framework has been updated with state-of-the-art features,
-such as FSDP, Mixture-of-Experts, and RLHF with TensorRT-LLM to provide speedups up to 4.2x for Llama-2 pre-training on H200.
-**All of these features will be available in an upcoming release.**
+
+      An end-to-end walkthrough for training generative AI models on Google Kubernetes Engine (GKE) with the NVIDIA NeMo Framework is available at https://github.com/GoogleCloudPlatform/nvidia-nemo-on-gke. The walkthrough includes detailed instructions on setting up a Google Cloud project and pre-training a GPT model with the NeMo Framework.
+      <br><br>
+    </details>
+
+    <details>
+      <summary>Bria Builds Responsible Generative AI for Enterprises Using NVIDIA NeMo, Picasso (2024/03/06)</summary>
+
+      Bria, a Tel Aviv startup at the forefront of visual generative AI for enterprises, now leverages the NVIDIA NeMo Framework. The Bria.ai platform uses reference implementations from the NeMo Multimodal collection, trained on NVIDIA Tensor Core GPUs, to enable high-throughput, low-latency image generation. Bria has also adopted NVIDIA Picasso, a foundry for visual generative AI models, to run inference.
+      <br><br>
+    </details>
+
+    <details>
+      <summary>New NVIDIA NeMo Framework Features and NVIDIA H200 (2023/12/06)</summary>
+
+      NVIDIA NeMo Framework now includes several optimizations and enhancements, including: 1) Fully Sharded Data Parallelism (FSDP) to improve the efficiency of training large-scale AI models, 2) Mixture-of-Experts (MoE)-based LLM architectures with expert parallelism for efficient LLM training at scale, 3) Reinforcement Learning from Human Feedback (RLHF) with TensorRT-LLM for inference-stage acceleration, and 4) up to 4.2x speedups for Llama 2 pre-training on NVIDIA H200 Tensor Core GPUs.
+      <br><br>
+    </details>
+
+    <details>
+      <summary>NVIDIA now powers training for Amazon Titan Foundation models (2023/11/28)</summary>
+
+      NVIDIA NeMo Framework now empowers the Amazon Titan foundation models (FMs) with efficient training of large language models (LLMs). The Titan FMs form the basis of Amazon’s generative AI service, Amazon Bedrock. The NeMo Framework provides a versatile framework for building, customizing, and running LLMs.
+      <br><br>
+    </details>
+  </details>
+
Introduction
diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py
index 5562e7a34f1f..b136446d97fb 100644
--- a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py
+++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py
@@ -554,7 +554,7 @@ def _graph_reinitialize(
logits_dim=self.joint.num_classes_with_blank,
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
- include_duration_confidence=self.include_duration_confidence
+ include_duration_confidence=self.include_duration_confidence,
)
self.state.all_durations = self.durations.to(self.state.device)
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index a660af46f13d..e5e48cdc10da 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -367,9 +367,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))
self.loss_broadcast_src_rank = None
- self.return_output_tensors = cfg.data.get('return_output_tensors', False)
- self.validation_drop_last = cfg.data.get('validation_drop_last', True)
- self.sample_weight = cfg.data.get('sample_weight', 'token')
+ data_cfg = cfg.get('data', {})
+ self.return_output_tensors = data_cfg.get('return_output_tensors', False)
+ self.validation_drop_last = data_cfg.get('validation_drop_last', True)
+ self.sample_weight = data_cfg.get('sample_weight', 'token')
self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False)
self.inference_params = None
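A minimal sketch, not part of the diff, of why the guarded lookup above matters: with OmegaConf, `cfg.data` on a config that has no `data` section either raises (struct mode) or resolves to None, so the old chained `cfg.data.get(...)` failed before `.get()` was ever reached, while `cfg.get('data', {})` falls back to an empty mapping. The config keys below are illustrative, not from NeMo.

from omegaconf import OmegaConf

cfg = OmegaConf.create({'precision': 16})   # hypothetical config with no 'data' section
data_cfg = cfg.get('data', {})              # falls back to {} instead of failing
assert data_cfg.get('sample_weight', 'token') == 'token'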
diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
index 64e8fe44e1e8..2aeb014c1b40 100644
--- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
+++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py
@@ -86,29 +86,27 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
linear_qkv_output, _ = self.linear_qkv(hidden_states)
layernorm_output = None
- # In megatron/core/models/gpt/gpt_layer_specs.py TELayerNormColumnParallelLinear is used for linear_qkv.
- # TELayerNormColumnParallelLinear fused LN and linear, both will be returned.
- # In nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_spec.py TEColumnParallelLinear is used for linear_qkv,
+        # In megatron/core/models/gpt/gpt_layer_specs.py, when a fused module (e.g. TELayerNormColumnParallelLinear)
+        # is used for linear_qkv, both the QKV projection and the layernorm output are returned.
+        # In nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_spec.py the non-fused TEColumnParallelLinear is used for linear_qkv,
# which only returns linear.
- if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear):
- mixed_qkv, layernorm_output = linear_qkv_output
- elif isinstance(self.linear_qkv, TEColumnParallelLinear): # only mixed_qkv
+        if isinstance(linear_qkv_output, tuple):
+            if len(linear_qkv_output) == 2:  # fused module: (mixed_qkv, layernorm_output)
+                mixed_qkv, layernorm_output = linear_qkv_output
+            else:
+                raise ValueError(f"Unexpected number of outputs from linear_qkv: {len(linear_qkv_output)}")
+        else:  # non-fused module returns only mixed_qkv
+            mixed_qkv = linear_qkv_output
- else:
- raise ValueError(
- f"Unrecognized module type '{type(self.linear_qkv)}' when getting query, key, value tensors for mcore mixins. "
- )
# LoRA logic
if self.is_adapter_available():
lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER)
if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']:
- if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear):
+ if layernorm_output is not None:
lora_mixed_qkv = lora_kqv_adapter(layernorm_output)
- elif isinstance(self.linear_qkv, TEColumnParallelLinear):
- lora_mixed_qkv = lora_kqv_adapter(hidden_states)
else:
- raise ValueError(f"Unrecognized module type '{type(self.linear_qkv)}' when applying lora.")
+ lora_mixed_qkv = lora_kqv_adapter(hidden_states)
+
mixed_qkv = mixed_qkv + lora_mixed_qkv
# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
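A minimal sketch of the dispatch pattern the new branch uses: inspect what `linear_qkv` returned instead of testing the module's class, so any fused or non-fused implementation is handled without referencing TE types. `unpack_linear_qkv` is an illustrative name, not a NeMo API.

import torch

def unpack_linear_qkv(output):
    """Return (mixed_qkv, layernorm_output or None) for fused and non-fused linears."""
    if isinstance(output, tuple):
        if len(output) != 2:
            raise ValueError(f"Unexpected number of outputs from linear_qkv: {len(output)}")
        return output                 # fused module: (mixed_qkv, layernorm_output)
    return output, None               # non-fused module: bare mixed_qkv tensor

mixed_qkv, ln_out = unpack_linear_qkv(torch.randn(4, 2, 8))                           # non-fused
mixed_qkv, ln_out = unpack_linear_qkv((torch.randn(4, 2, 8), torch.randn(4, 2, 8)))   # fused

The same `layernorm_output is not None` test then decides whether LoRA consumes the layernorm output or falls back to `hidden_states`.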
diff --git a/tools/ctc_segmentation/scripts/prepare_data.py b/tools/ctc_segmentation/scripts/prepare_data.py
index c6ea024273fb..476e719eb51b 100644
--- a/tools/ctc_segmentation/scripts/prepare_data.py
+++ b/tools/ctc_segmentation/scripts/prepare_data.py
@@ -26,6 +26,8 @@
from tqdm import tqdm
from nemo.collections.asr.models import ASRModel
+from nemo.collections.asr.models.ctc_models import EncDecCTCModel
+from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
from nemo.utils import model_utils
try:
@@ -354,7 +356,19 @@ def _split(sentences, delimiter):
asr_model = ASRModel.from_pretrained(model_name=args.model) # type: ASRModel
model_name = args.model
- vocabulary = asr_model.cfg.decoder.vocabulary
+    if not isinstance(asr_model, (EncDecCTCModel, EncDecHybridRNNTCTCModel)):
+        raise NotImplementedError(
+            "Model is not an instance of NeMo EncDecCTCModel or EncDecHybridRNNTCTCModel."
+            " Currently only instances of these models are supported."
+        )
+
+ # get vocabulary list
+ if hasattr(asr_model, 'tokenizer'): # i.e. tokenization is BPE-based
+ vocabulary = asr_model.tokenizer.vocab
+ elif hasattr(asr_model.decoder, "vocabulary"): # i.e. tokenization is character-based
+ vocabulary = asr_model.cfg.decoder.vocabulary
+ else:
+ raise ValueError("Unexpected model type. Vocabulary list not found.")
if os.path.isdir(args.in_text):
text_files = glob(f"{args.in_text}/*.txt")
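For context on the two vocabulary sources: a BPE model's `tokenizer.vocab` holds subword pieces, while a character model's `cfg.decoder.vocabulary` is a plain character list. A hedged sketch follows; the checkpoint names are examples of NeMo pretrained models, and the exact vocabulary entries will vary.

from nemo.collections.asr.models import ASRModel

bpe_model = ASRModel.from_pretrained("stt_en_conformer_ctc_small")
print(bpe_model.tokenizer.vocab[:5])                  # subword pieces, e.g. ['<unk>', '▁the', ...]

char_model = ASRModel.from_pretrained("QuartzNet15x5Base-En")
print(list(char_model.cfg.decoder.vocabulary)[:5])    # characters, e.g. [' ', 'a', 'b', ...]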
diff --git a/tools/ctc_segmentation/scripts/run_ctc_segmentation.py b/tools/ctc_segmentation/scripts/run_ctc_segmentation.py
index dddeb9a42dc2..c9d9ed2d8731 100644
--- a/tools/ctc_segmentation/scripts/run_ctc_segmentation.py
+++ b/tools/ctc_segmentation/scripts/run_ctc_segmentation.py
@@ -27,6 +27,8 @@
from utils import get_segments
import nemo.collections.asr as nemo_asr
+from nemo.collections.asr.models.ctc_models import EncDecCTCModel
+from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
parser = argparse.ArgumentParser(description="CTC Segmentation")
parser.add_argument("--output_dir", default="output", type=str, help="Path to output directory")
@@ -72,18 +74,19 @@
logging.basicConfig(handlers=handlers, level=level)
if os.path.exists(args.model):
- asr_model = nemo_asr.models.EncDecCTCModel.restore_from(args.model)
- elif args.model in nemo_asr.models.EncDecCTCModel.get_available_model_names():
- asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(args.model, strict=False)
+ asr_model = nemo_asr.models.ASRModel.restore_from(args.model)
else:
- try:
- asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained(args.model)
- except:
- raise ValueError(
- f"Provide path to the pretrained checkpoint or choose from {nemo_asr.models.EncDecCTCModel.get_available_model_names()}"
- )
+ asr_model = nemo_asr.models.ASRModel.from_pretrained(args.model, strict=False)
+
+    if not isinstance(asr_model, (EncDecCTCModel, EncDecHybridRNNTCTCModel)):
+        raise NotImplementedError(
+            "Model is not an instance of NeMo EncDecCTCModel or EncDecHybridRNNTCTCModel."
+            " Currently only instances of these models are supported."
+        )
- bpe_model = isinstance(asr_model, nemo_asr.models.EncDecCTCModelBPE)
+    bpe_model = isinstance(asr_model, (nemo_asr.models.EncDecCTCModelBPE, nemo_asr.models.EncDecHybridRNNTCTCBPEModel))
# get tokenizer used during training, None for char based models
if bpe_model:
@@ -91,8 +94,18 @@
else:
tokenizer = None
+ if isinstance(asr_model, EncDecHybridRNNTCTCModel):
+ asr_model.change_decoding_strategy(decoder_type="ctc")
+
# extract ASR vocabulary and add blank symbol
- vocabulary = ["ε"] + list(asr_model.cfg.decoder.vocabulary)
+ if hasattr(asr_model, 'tokenizer'): # i.e. tokenization is BPE-based
+ vocabulary = asr_model.tokenizer.vocab
+ elif hasattr(asr_model.decoder, "vocabulary"): # i.e. tokenization is character-based
+ vocabulary = asr_model.cfg.decoder.vocabulary
+ else:
+ raise ValueError("Unexpected model type. Vocabulary list not found.")
+
+ vocabulary = ["ε"] + list(vocabulary)
logging.debug(f"ASR Model vocabulary: {vocabulary}")
data = Path(args.data)
@@ -136,9 +149,14 @@
logging.debug(f"len(signal): {len(signal)}, sr: {sample_rate}")
logging.debug(f"Duration: {original_duration}s, file_name: {path_audio}")
-        log_probs = asr_model.transcribe(audio=[str(path_audio)], batch_size=1, return_hypotheses=True)[
-            0
-        ].alignments
+        hypotheses = asr_model.transcribe([str(path_audio)], batch_size=1, return_hypotheses=True)
+        # Hybrid models return a tuple of (best hypotheses, all n-best hypotheses); keep only the best
+        if isinstance(hypotheses, tuple) and len(hypotheses) == 2:
+            hypotheses = hypotheses[0]
+        log_probs = hypotheses[0].alignments  # "[0]" unpacks the batch dimension (batch_size=1 here)
+
# move blank values to the first column (ctc-package compatibility)
blank_col = log_probs[:, -1].reshape((log_probs.shape[0], 1))
log_probs = np.concatenate((blank_col, log_probs[:, :-1]), axis=1)
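The final lines keep the long-standing blank-column shuffle: NeMo CTC models emit the blank symbol in the last column of the log-probability matrix, while the ctc-segmentation package expects it in the first. A tiny worked example with made-up numbers:

import numpy as np

log_probs = np.array([[0.1, 0.2, 0.7],    # columns: [a, b, blank]
                      [0.5, 0.3, 0.2]])
blank_col = log_probs[:, -1].reshape((log_probs.shape[0], 1))
log_probs = np.concatenate((blank_col, log_probs[:, :-1]), axis=1)
print(log_probs)  # columns now: [blank, a, b]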