
Commit

Add missing import guards for causal_conv1d and mamba_ssm dependencies (#10429) (#10506)

* Add causal_conv1d import guard



* Add mamba_ssm import guard



* Apply isort and black reformatting



---------

Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
Signed-off-by: janekl <janekl@users.noreply.github.com>
Co-authored-by: Jan Lasek <janek.lasek@gmail.com>
Co-authored-by: janekl <janekl@users.noreply.github.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>
4 people authored Sep 26, 2024
1 parent d401c9d commit c5381e2
Showing 2 changed files with 32 additions and 5 deletions.
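
The change applies NeMo's usual optional-dependency pattern: wrap the import in try/except, record the outcome in a HAVE_* flag, and raise only when a class that actually needs the package is constructed, so importing the module itself never fails. Below is a minimal, self-contained sketch of that pattern; optional_pkg and GuardedLayer are placeholder names, not part of this commit.

# Sketch of the import-guard pattern this commit applies; names are illustrative.
try:
    from optional_pkg import fast_kernel  # placeholder for causal_conv1d / mamba_ssm

    HAVE_OPTIONAL_PKG = True
except (ImportError, ModuleNotFoundError):
    HAVE_OPTIONAL_PKG = False


class GuardedLayer:
    def __init__(self):
        # The module imported fine; only using this class requires the package.
        if not HAVE_OPTIONAL_PKG:
            raise ImportError("Package optional_pkg is required to use GuardedLayer")
        self.kernel = fast_kernel

Deferring the error to construction time lets users who never touch these classes run the rest of NeMo without installing causal-conv1d or mamba-ssm.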
First changed file:
@@ -19,7 +19,6 @@
 import torch
 import torch._dynamo
 from accelerated_scan.triton import scan
-from causal_conv1d import causal_conv1d_fn
 from einops import rearrange
 from torch import nn

@@ -40,6 +39,13 @@
     TransformerConfig = ApexGuardDefaults
     HAVE_MEGATRON_CORE = False

+try:
+    from causal_conv1d import causal_conv1d_fn
+
+    HAVE_CAUSAL_CONV1D = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_CAUSAL_CONV1D = False
+
 torch._dynamo.config.suppress_errors = True


@@ -277,6 +283,8 @@ def __call__(

 class Conv1D(MegatronModule):
     def __init__(self, config, width, temporal_width):
+        if not HAVE_CAUSAL_CONV1D:
+            raise ImportError("Package causal_conv1d is required to use Conv1D")
         super().__init__(config=config)
         self.config = config
         self.width = width
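
As a side note, the same flag can also drive a fallback path when the fused kernel is unavailable. The sketch below is not from this commit: depthwise_causal_conv is a hypothetical helper, the plain-PyTorch branch is illustrative, and the causal_conv1d_fn call signature is assumed from the package's documented API rather than taken from the diff.

# Hedged sketch: flag-based dispatch to a plain-PyTorch fallback.
import torch
import torch.nn.functional as F

try:
    from causal_conv1d import causal_conv1d_fn

    HAVE_CAUSAL_CONV1D = True
except (ImportError, ModuleNotFoundError):
    HAVE_CAUSAL_CONV1D = False


def depthwise_causal_conv(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """x: (batch, channels, seqlen); weight: (channels, kernel_width)."""
    if HAVE_CAUSAL_CONV1D:
        return causal_conv1d_fn(x, weight)  # fused kernel from the optional package (assumed signature)
    # Fallback: left-pad so the depthwise convolution stays causal.
    kernel_width = weight.shape[-1]
    padded = F.pad(x, (kernel_width - 1, 0))
    return F.conv1d(padded, weight.unsqueeze(1), groups=weight.shape[0])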
Second changed file:
@@ -13,30 +13,49 @@
 # limitations under the License.

 import torch
-from megatron.core.models.mamba import MambaModel
-from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
 from omegaconf.dictconfig import DictConfig
 from pytorch_lightning.trainer.trainer import Trainer

 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.utils import logging

+try:
+    import mamba_ssm
+
+    HAVE_MAMBA_SSM = True
+
+except ModuleNotFoundError:
+
+    HAVE_MAMBA_SSM = False
+
+try:
+    from megatron.core.models.mamba import MambaModel
+    from megatron.core.models.mamba.mamba_layer_specs import mamba_stack_spec
+
+    HAVE_MEGATRON_CORE = True
+
+except (ImportError, ModuleNotFoundError):
+
+    HAVE_MEGATRON_CORE = False
+

 class MegatronMambaModel(MegatronGPTModel):
     """
     Megatron Mamba pretraining.
     """

     def __init__(self, cfg: DictConfig, trainer: Trainer):
-
+        if not HAVE_MEGATRON_CORE or not HAVE_MAMBA_SSM:
+            raise ImportError("Both megatron.core and mamba_ssm packages are required to use MegatronMambaModel")
         self.vocab_size = cfg.get('vocab_size', 65536)
         self.cfg = cfg
         super().__init__(cfg=cfg, trainer=trainer)
         logging.warning("Overriding mcore_gpt=True")
         self.mcore_gpt = True

     def model_provider_func(self, pre_process, post_process):
-
+        if not HAVE_MEGATRON_CORE or not HAVE_MAMBA_SSM:
+            raise ImportError("Both megatron.core and mamba_ssm packages are required to use MegatronMambaModel")
         self.hybrid_override_pattern = self.cfg.get(
             'hybrid_override_pattern', "M" * self.transformer_config.num_layers
         )
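
With the second file's guards in place, the module imports cleanly even when the optional packages are missing, so availability can be checked before building the model. A brief usage sketch follows; the import path is assumed (the capture above does not show file names) and NeMo itself is presumed installed.

# Hedged usage sketch: probe the flags exposed by the guarded module.
from nemo.collections.nlp.models.language_modeling import megatron_mamba_model as mamba_mod

print("mamba_ssm available:     ", mamba_mod.HAVE_MAMBA_SSM)
print("megatron.core available: ", mamba_mod.HAVE_MEGATRON_CORE)

if not (mamba_mod.HAVE_MAMBA_SSM and mamba_mod.HAVE_MEGATRON_CORE):
    # Constructing MegatronMambaModel now raises the ImportError added above,
    # instead of the import of this module failing outright.
    print("Install mamba-ssm and a compatible megatron.core to use MegatronMambaModel")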
