Skip to content

Commit

Permalink
add sbert to IR (#8445)
Browse files Browse the repository at this point in the history
* add  sbert to IR

Signed-off-by: ataghibakhsh <ataghibakhsh@nvidia.com>

* add doc

Signed-off-by: ataghibakhsh <ataghibakhsh@nvidia.com>

* fix the  auto_tokenizer property method reset bug

Signed-off-by: ataghibakhsh <ataghibakhsh@nvidia.com>

* addressed bot comments

Signed-off-by: ataghibakhsh <ataghibakhsh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: ataghibakhsh <ataghibakhsh@nvidia.com>
Co-authored-by: Eric Harper <complex451@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Feb 17, 2024
1 parent 5c1b8d1 commit 8222634
Show file tree
Hide file tree
Showing 6 changed files with 1,198 additions and 8 deletions.
96 changes: 91 additions & 5 deletions docs/source/nlp/information_retrieval.rst
Original file line number Diff line number Diff line change
@@ -1,10 +1,96 @@
.. _information_retrieval:

Information Retrieval
=====================
Sentence-BERT
=============

We recommend you try the Information Retrieval model in a Jupyter notebook (can run on `Google's Colab <https://colab.research.google.com/notebooks/intro.ipynb>`_): `NeMo/tutorials/nlp/Information_Retrieval_MSMARCO.ipynb <https://github.com/NVIDIA/NeMo/blob/stable/tutorials/nlp/Information_Retrieval_MSMARCO.ipynb>`__.
Sentence-BERT (SBERT) is a modification of the BERT model that is specifically trained to generate semantically meaningful sentence embeddings.
The model architecture and pre-training process are detailed in the `Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks <https://aclanthology.org/D19-1410.pdf>`__ paper. Similar to BERT,
Sentence-BERT utilizes a BERT-based architecture, but it is trained using a siamese and triplet network structure to derive fixed-sized sentence embeddings that capture semantic information.
Sentence-BERT is commonly used to generate high-quality sentence embeddings for various downstream natural language processing tasks, such as semantic textual similarity, clustering, and information retrieval

Connect to an instance with a GPU (**Runtime** -> **Change runtime type** -> select **GPU** for hardware the accelerator),
Data Input for the Senntence-BERT model
---------------------------------------

An example script on how to train the model can be found here: `NeMo/examples/nlp/information_retrieval <https://github.com/NVIDIA/NeMo/tree/stable/examples/nlp/information_retrieval>`__.
The fine-tuning data for the Sentence-BERT (SBERT) model should consist of data instances,
each comprising a query, a positive document, and a list of negative documents. Negative mining is
not supported in NeMo yet; therefore, data preprocessing should be performed offline before training.
The dataset should be in JSON format. For instance, the dataset should have the following structure:

.. code-block:: python
[
{
"question": "Query",
"pos_doc": ["Positive"],
"neg_doc": ["Negative_1", "Negative_2", ..., "Negative_n"]
},
{
// Next data instance
},
...,
{
// Subsequent data instance
}
]
This format ensures that the fine-tuning data is appropriately structured for training the Sentence-BERT model.


Fine-tuning the Sentence-BERT model
-----------------------------------

For fine-tuning Sentence-BERT model, you need to initialze the Sentence-BERT model with BERT model
checkpoint. To do so, you should either have a ``.nemo`` checkpoint or need to convert a HuggingFace
BERT checkpoint to NeMo using the following:

.. code-block:: python
python NeMo/scripts/nlp_language_modeling/convert_bert_hf_to_nemo.py \
--input_name_or_path "intfloat/e5-large-unsupervised" \
--output_path /path/to/output/nemo/file.nemo
Then you can fine-tune the sentence-BERT model using the following script:

.. code-block:: python
#!/bin/bash
PROJECT= # wandb project name
NAME= # wandb run name
export WANDB_API_KEY= # your_wandb_key
NUM_DEVICES=1 # number of gpus to train on
CONFIG_PATH="/NeMo/examples/nlp/information_retrieval/conf/"
CONFIG_NAME="megatron_bert_config"
PATH_TO_NEMO_MODEL= # Path to conveted nemo model from hf
DATASET_PATH= # Path to json dataset
SAVE_DIR= # where the checkpoint and logs are saved
mkdir -p $SAVE_DIR
python /NeMo/examples/nlp/language_modeling/megatron_sbert_pretraining.py \
--config-path=${CONFIG_PATH} \
--config-name=${CONFIG_NAME} \
restore_from_path=${PATH_TO_NEMO_MODEL} \
trainer.devices=${NUM_DEVICES} \
trainer.val_check_interval=100 \
trainer.max_epochs=1 \
+trainer.num_sanity_val_steps=0 \
model.global_batch_size=8 \ # should be NUM_DEVICES * model.micro_batch_size
model.micro_batch_size=8 \
model.tokenizer.library="huggingface" \
model.tokenizer.type="intfloat/e5-large-unsupervised" \
++model.data.data_prefix=${DATASET_PATH} \
++model.tokenizer.do_lower_case=False \
++model.data.evaluation_sample_size=100 \
++model.data.hard_negatives_to_train=4 \
++model.data.evaluation_steps=100 \
exp_manager.explicit_log_dir=${SAVE_DIR} \
exp_manager.create_wandb_logger=True \
exp_manager.resume_if_exists=True \
exp_manager.wandb_logger_kwargs.name=${NAME} \
exp_manager.wandb_logger_kwargs.project=${PROJECT}
160 changes: 160 additions & 0 deletions examples/nlp/information_retrieval/conf/megatron_sbert_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
name: megatron_bert
restore_from_path: null # used when starting from a .nemo file

trainer:
devices: 1
num_nodes: 1
accelerator: gpu
precision: 16
logger: False # logger provided by exp_manager
enable_checkpointing: False
use_distributed_sampler: False
max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch.
max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
log_every_n_steps: 10
val_check_interval: 100
limit_val_batches: 50
limit_test_batches: 500
accumulate_grad_batches: 1
gradient_clip_val: 1.0
benchmark: False

exp_manager:
explicit_log_dir: null
exp_dir: null
name: megatron_bert
create_wandb_logger: False
wandb_logger_kwargs:
project: null
name: null
resume_if_exists: True
resume_ignore_no_checkpoint: True
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: val_loss
save_top_k: 10
mode: min
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
filename: 'megatron_bert--{val_loss:.2f}-{step}-{consumed_samples}'
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}


model:
# model parallelism
mcore_bert: False
micro_batch_size: 4
global_batch_size: 8
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null

# model architecture
encoder_seq_length: 512
max_position_embeddings: ${.encoder_seq_length}
position_embedding_type: 'learned_absolute' # Position embedding type. Options ['learned_absolute', 'rope', 'alibi', 'kerple' , 'xpos', 'sandwich'] xpos and sandwich are experimental.
num_layers: 24
hidden_size: 1024
ffn_hidden_size: 4096 # Transformer FFN hidden size. Usually 4 * hidden_size.
num_attention_heads: 16
skip_head: True
transformer_block_type: post_ln
init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.')
hidden_dropout: 0.1 # Dropout probability for hidden state transformer.
kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
apply_query_key_layer_scaling: False # scale Q * K^T by 1 / layer-number.
normalization: layernorm
layernorm_epsilon: 1e-12
make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
pre_process: True # add embedding
post_process: True # add pooler
bert_binary_head: True # BERT binary head
megatron_legacy: True

tokenizer:
library: 'huggingface'
type: 'intfloat/e5-large-unsupervised'
model: null
vocab_file: null
merge_file: null

# precision
native_amp_init_scale: 4294967296 # 2 ** 32
native_amp_growth_interval: 1000
fp32_residual_connection: False # Move residual connections to fp32
fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16

# Megatron O2-style half-precision
megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
grad_allreduce_chunk_size_mb: 125
grad_div_ar_fusion: False

# miscellaneous
seed: 1234
use_cpu_initialization: False # Init weights on the CPU (slow for large models)
onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)

## Activation Checkpointing
# NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed.
# These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+).
# See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details.
# 'full' will checkpoint the entire transformer layer.
activations_checkpoint_granularity: null # 'selective' or 'full'
activations_checkpoint_method: null # 'uniform', 'block'
# 'uniform' divides the total number of transformer layers and checkpoints the input activation
# of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
# 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity
activations_checkpoint_num_layers: null
# when using 'uniform' this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory.
# when using 'block' this this will checkpoint the first activations_checkpoint_num_layers per pipeline stage.
num_micro_batches_with_partial_activation_checkpoints: null
# This feature is valid only when used with pipeline-model-parallelism.
# When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed
# and recomputed within a window of micro-batches. The rest of micro-batches in the window checkpoint all Transformer layers. The size of window is
# set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint
# per micro-batch is set by 'activations_checkpoint_num_layers' with 'activations_checkpoint_method' of 'block'.
# This feature enables using activation checkpoint at a fraction of micro-batches up to the point of full GPU memory usage.
activations_checkpoint_layers_per_pipeline: null
# This feature is valid only when used with pipeline-model-parallelism.
# When an integer value (rounded down when float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later
# pipeline stages. For example, 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 to checkpoint 3 layers less than
# stage 0 and stage 2 to checkpoint 6 layers less stage 0, and so on. This is possible because later pipeline stage
# uses less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints',
# this feature removes most of activation checkpoints at the last pipeline stage, which is the critical execution path.
sequence_parallel: False

data:
# Path to data must be specified by the user.
# can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]",
# Or see example below:
# data_prefix:
# - .5
# - /raid/data/pile/my-gpt3_00_text_document
# - .5
# - /raid/data/pile/my-gpt3_01_text_document
data_prefix: [1.0, /path/to/data]
index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix
data_impl: mmap
splits_string: 900,50,50
seq_length: ${model.encoder_seq_length}
skip_warmup: True
num_workers: 0
dataloader_type: single # cyclic, LDDL
reset_position_ids: False # Reset position ids after end-of-document token
reset_attention_mask: False # Reset attention mask after end-of-document token
eod_mask_loss: False # Mask loss for the end of document tokens
masked_lm_prob: 0.15 # Probability of replacing a token with mask.
short_seq_prob: 0.1 # Probability of producing a short sequence.

optim:
name: fused_adam
lr: 2e-4
weight_decay: 0.01
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 500
constant_steps: 50000
min_lr: 2e-5
59 changes: 59 additions & 0 deletions examples/nlp/information_retrieval/megatron_sbert_finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict

from nemo.collections.nlp.models.information_retrieval.megatron_sbert_model import MegatronSBertModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronBertTrainerBuilder
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="megatron_bert_config")
def main(cfg) -> None:
if cfg.model.data.dataloader_type != "LDDL":
mp.set_start_method("spawn", force=True)

logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

trainer = MegatronBertTrainerBuilder(cfg).create_trainer()
exp_manager(trainer, cfg.exp_manager)

model_cfg = MegatronSBertModel.merge_cfg_with(cfg.restore_from_path, cfg)

assert (
model_cfg.micro_batch_size * cfg.trainer.devices == model_cfg.global_batch_size
), "Gradiant accumulation is not supported for contrastive learning yet"

OmegaConf.set_struct(model_cfg, True)
with open_dict(model_cfg):
model_cfg.precision = trainer.precision

model = MegatronSBertModel.restore_from(
restore_path=cfg.restore_from_path,
trainer=trainer,
save_restore_connector=NLPSaveRestoreConnector(),
override_config_path=model_cfg,
strict=True,
)

trainer.fit(model)


if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,6 @@ def __init__(
if token is not None and token not in self.tokenizer.get_vocab():
new_tokens_in_vocab.append(token)

# value is required for megatron-core
self.unique_identifiers = OrderedDict()

if len(new_tokens_in_vocab) > 0:
"""
Special tokens that were not previously included in the tokenizer's vocabulary file will be added to
Expand Down
Loading

0 comments on commit 8222634

Please sign in to comment.