Skip to content

Commit

Permalink
Fix documentation for Numba (#7065) (#7077)
Browse files Browse the repository at this point in the history
* Fix documentation for Numba



* Update force float32 flag dynamically



* Update force float32 flag dynamically



* Fix nemo version



---------

Signed-off-by: smajumdar <titu1994@gmail.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
Signed-off-by: jubick1337 <mattyson.so@gmail.com>
  • Loading branch information
3 people authored and jubick1337 committed Aug 8, 2023
1 parent f3a7ba0 commit c3a55f9
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 13 deletions.
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ Built for speed, NeMo can utilize NVIDIA's Tensor Cores and scale out training t
Requirements
------------

1) Python 3.8 or above
2) Pytorch 1.10.0 or above
1) Python 3.9 or above
2) Pytorch 1.13.1 or above
3) NVIDIA GPU for training

Documentation
Expand Down
2 changes: 1 addition & 1 deletion docs/source/nlp/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ Datasets
.. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTDataset
:show-inheritance:

.. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset.GPTSFTChatDataset
.. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_chat_dataset.GPTSFTChatDataset
:show-inheritance:

.. autoclass:: nemo.collections.nlp.data.language_modeling.megatron.retro_dataset.RETRODataset
Expand Down
4 changes: 2 additions & 2 deletions docs/source/starthere/intro.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ Prerequisites

Before you begin using NeMo, it's assumed you meet the following prerequisites.

#. You have Python version 3.6, 3.7 or 3.8.
#. You have Python version 3.9, 3.10.

#. You have Pytorch version 1.8.1.
#. You have Pytorch version 1.13.1 or 2.0+.

#. You have access to an NVIDIA GPU for training.

Expand Down
7 changes: 5 additions & 2 deletions nemo/collections/asr/losses/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class RNNTLossConfig:
min_version='0.53.0',
is_available=NUMBA_RNNT_AVAILABLE,
installation_msg=NUMBA_INSTALLATION_MESSAGE,
force_float32=not numba_utils.NUMBA_FP16_SUPPORTED,
force_float32=False, # This is only temporarily false, will be dynamically updated during resolution
),
"pytorch": RNNTLossConfig(
loss_name="pytorch",
Expand Down Expand Up @@ -258,6 +258,9 @@ def resolve_rnnt_loss(loss_name: str, blank_idx: int, loss_kwargs: dict = None)
_warn_unused_additional_kwargs(loss_name, loss_kwargs)

elif loss_name == 'warprnnt_numba':
# Update loss config's forced float32 flag if set to None
loss_config.force_float32 = not numba_utils.is_numba_cuda_fp16_supported()

fastemit_lambda = loss_kwargs.pop('fastemit_lambda', 0.0)
clamp = loss_kwargs.pop('clamp', -1.0)
loss_func = RNNTLossNumba(blank=blank_idx, reduction='none', fastemit_lambda=fastemit_lambda, clamp=clamp)
Expand Down Expand Up @@ -444,7 +447,7 @@ def forward(self, log_probs, targets, input_lengths, target_lengths):
max_targets_len = target_lengths.max()

# Force cast joint to float32
if not self._force_float32 and numba_utils.NUMBA_FP16_SUPPORTED:
if not self._force_float32 and numba_utils.is_numba_cuda_fp16_supported():
# Execute the kernel in fp16
pass
elif self._force_float32 and log_probs.dtype != torch.float32:
Expand Down
11 changes: 6 additions & 5 deletions nemo/core/utils/numba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
__NUMBA_MINIMUM_VERSION__ = os.environ.get("NEMO_NUMBA_MINVER", __NUMBA_DEFAULT_MINIMUM_VERSION__)

__NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__ = "0.57.0"
NUMBA_FP16_SUPPORTED = model_utils.check_lib_version(
'numba', __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__, operator=operator.ge
)[0]


NUMBA_INSTALLATION_MESSAGE = (
Expand Down Expand Up @@ -171,12 +168,16 @@ def is_numba_cuda_fp16_supported(return_reason: bool = False) -> Union[bool, Tup
use_nvidia_binding = False
reason += "Env variable `NUMBA_CUDA_USE_NVIDIA_BINDING` is not available or has not set to `1`."

if NUMBA_FP16_SUPPORTED:
numba_fp16_version_correct = model_utils.check_lib_version(
'numba', __NUMBA_MINIMUM_VERSION_FP16_SUPPORTED__, operator=operator.ge
)[0]

if numba_fp16_version_correct:
reason += f"Numba CUDA FP16 is supported in installed numba version."
else:
reason += f"Numba CUDA FP16 is not supported in installed numba version."

result = use_nvidia_binding and NUMBA_FP16_SUPPORTED
result = use_nvidia_binding and numba_fp16_version_correct

if return_reason:
return result, reason
Expand Down
3 changes: 2 additions & 1 deletion nemo/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import copy
import importlib
import os
from dataclasses import dataclass, is_dataclass
from enum import Enum
Expand Down Expand Up @@ -554,7 +555,7 @@ def check_lib_version(lib_name: str, checked_version: str, operator) -> Tuple[Op
if '.' in lib_name:
mod = import_class_by_path(lib_name)
else:
mod = __import__(lib_name)
mod = importlib.import_module(lib_name)

if hasattr(mod, '__version__'):
lib_ver = version.Version(mod.__version__)
Expand Down

0 comments on commit c3a55f9

Please sign in to comment.