add nemotron5 conversion
JRD971000 authored and Ali Taghibakhshi committed Nov 12, 2024
1 parent 5c5b023 commit 0d9bb4f
Showing 3 changed files with 297 additions and 17 deletions.
@@ -76,7 +76,8 @@ model:
   post_process: True # add pooler
   megatron_legacy: False
   persist_layer_norm: True
-
+  squared_relu_activation: True
+  params_dtype: bf16
   tokenizer:
     library: 'huggingface'
     type: 'EleutherAI/gpt-neox-20b'
@@ -87,7 +88,7 @@ model:
     use_fast: True

   # Distributed checkpoint setup
-  dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
+  dist_ckpt_format: 'torch_dist' # Set to 'torch_dist' to use PyTorch distributed checkpoint format.
   dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU
   dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint

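For orientation, a minimal sketch of how the two new keys might be consumed once the config is loaded; the key names mirror the diff above, while the file name and the surrounding code are assumptions, not part of the commit:

# Illustrative only: reads the new model-level keys from a NeMo-style OmegaConf config.
# "megatron_mamba_config.yaml" is a placeholder path, not taken from this commit.
import torch
from omegaconf import OmegaConf

cfg = OmegaConf.load("megatron_mamba_config.yaml")

# params_dtype: bf16 -> torch.bfloat16; anything else falls back to float32 here
params_dtype = torch.bfloat16 if cfg.model.get("params_dtype") == "bf16" else torch.float32

# squared_relu_activation: True -> request the squared-ReLU MLP activation
use_squared_relu = bool(cfg.model.get("squared_relu_activation", False))

print(params_dtype, use_squared_relu)
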
@@ -13,6 +13,7 @@
 # limitations under the License.

 import torch
+import torch.nn.functional as F
 from omegaconf.dictconfig import DictConfig
 from pytorch_lightning.trainer.trainer import Trainer

@@ -38,6 +39,8 @@

 HAVE_MEGATRON_CORE = False

+def squared_relu(x):
+    return torch.pow(F.relu(x), 2)

 class MegatronMambaModel(MegatronGPTModel):
     """
@@ -62,6 +65,15 @@ def model_provider_func(self, pre_process, post_process):
         self.transformer_config.add_bias_linear = self.cfg.get('add_bias_linear', False)
         self.transformer_config.gated_linear_unit = self.cfg.get('gated_linear_unit', False)
         self.transformer_config.layernorm_epsilon = self.cfg.get('layernorm_epsilon', 1e-5)
+        if self.cfg.get('params_dtype'):
+            self.transformer_config.params_dtype = torch.bfloat16
+        else:
+            self.transformer_config.params_dtype = torch.float32
+        self.transformer_config.params_dtype=torch.bfloat16
+        if self.cfg.get('kv_channels'):
+            self.transformer_config.kv_channels = self.cfg.get('kv_channels')
+        if self.cfg.get('squared_relu_activation'):
+            self.transformer_config.activation_func = squared_relu

         model = MambaModel(
             config=self.transformer_config,
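The squared_relu helper defined above is what gets assigned to transformer_config.activation_func whenever squared_relu_activation is set in the config. A quick standalone check of what that callable computes (illustrative only, reusing the commit's definition):

# Illustrative only: exercises the squared-ReLU activation added in this commit.
import torch
import torch.nn.functional as F

def squared_relu(x):
    # ReLU followed by an element-wise square
    return torch.pow(F.relu(x), 2)

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0])
print(squared_relu(x))  # tensor([0.0000, 0.0000, 0.0000, 0.2500, 4.0000])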