add sbert to IR #8445

Merged
merged 7 commits on Feb 17, 2024
96 changes: 91 additions & 5 deletions docs/source/nlp/information_retrieval.rst
@@ -1,10 +1,96 @@
.. _information_retrieval:

Sentence-BERT
=============

Sentence-BERT (SBERT) is a modification of the BERT model that is specifically trained to generate semantically meaningful sentence embeddings.
The model architecture and pre-training process are detailed in the `Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks <https://aclanthology.org/D19-1410.pdf>`__ paper. Like BERT,
Sentence-BERT utilizes a BERT-based architecture, but it is trained using a siamese and triplet network structure to derive fixed-size sentence embeddings that capture semantic information.
Sentence-BERT is commonly used to generate high-quality sentence embeddings for various downstream natural language processing tasks, such as semantic textual similarity, clustering, and information retrieval.
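
To make the idea concrete, the following is a minimal sketch (not NeMo code) of how a fixed-size sentence embedding can be derived from a BERT-style encoder by mean pooling its token embeddings; the HuggingFace checkpoint name is only illustrative:

.. code-block:: python

    import torch
    from transformers import AutoModel, AutoTokenizer

    # Any BERT-style encoder works here; the checkpoint name is an example.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")

    sentences = ["How do I fine-tune SBERT?", "Fine-tuning Sentence-BERT in NeMo"]
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        token_embeddings = model(**batch).last_hidden_state  # [batch, seq, hidden]

    # Mean pooling: average the token embeddings, ignoring padding positions.
    mask = batch["attention_mask"].unsqueeze(-1).float()
    sentence_embeddings = (token_embeddings * mask).sum(dim=1) / mask.sum(dim=1)

    # Semantically similar sentences yield a high cosine similarity.
    print(torch.nn.functional.cosine_similarity(sentence_embeddings[0], sentence_embeddings[1], dim=0))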

Data Input for the Sentence-BERT model
--------------------------------------

The fine-tuning data for the Sentence-BERT (SBERT) model should consist of data instances,
each comprising a query, a positive document, and a list of negative documents. Negative mining is
not supported in NeMo yet; therefore, data preprocessing should be performed offline before training.
The dataset must be in JSON format, structured as follows:

.. code-block:: json

    [
        {
            "question": "Query",
            "pos_doc": ["Positive"],
            "neg_doc": ["Negative_1", "Negative_2", ..., "Negative_n"]
        },
        {
            // Next data instance
        },
        ...,
        {
            // Subsequent data instance
        }
    ]

This format ensures that the fine-tuning data is appropriately structured for training the Sentence-BERT model.
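
Because negative mining happens offline, a small preprocessing script can assemble the expected layout. The sketch below (illustrative data and file name, with naive random negatives standing in for a real mining strategy) writes such a dataset:

.. code-block:: python

    import json
    import random

    # Illustrative (query, positive document) pairs.
    pairs = [
        {"question": "What is Sentence-BERT?", "pos_doc": "SBERT learns sentence embeddings ..."},
        {"question": "How are negatives used?", "pos_doc": "Contrastive losses contrast positives ..."},
    ]
    corpus = [p["pos_doc"] for p in pairs]

    instances = []
    for pair in pairs:
        # Naive offline "mining": sample documents other than the positive.
        negatives = random.sample([d for d in corpus if d != pair["pos_doc"]], k=1)
        instances.append({"question": pair["question"], "pos_doc": [pair["pos_doc"]], "neg_doc": negatives})

    with open("sbert_train.json", "w") as f:
        json.dump(instances, f, indent=2)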


Fine-tuning the Sentence-BERT model
-----------------------------------

To fine-tune the Sentence-BERT model, you need to initialize it from a BERT model
checkpoint. You should either have a ``.nemo`` checkpoint or convert a HuggingFace
BERT checkpoint to NeMo using the following script:

.. code-block:: bash

    python NeMo/scripts/nlp_language_modeling/convert_bert_hf_to_nemo.py \
        --input_name_or_path "intfloat/e5-large-unsupervised" \
        --output_path /path/to/output/nemo/file.nemo
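
A quick way to sanity-check the conversion is to list the archive contents (a ``.nemo`` file is a tar archive containing the model config and weights); the output path below is the one used above:

.. code-block:: python

    import tarfile

    # .nemo checkpoints are tar archives; listing them confirms the
    # conversion produced a config and weight files.
    with tarfile.open("/path/to/output/nemo/file.nemo") as archive:
        for name in archive.getnames():
            print(name)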

Then you can fine-tune the Sentence-BERT model using the following script:

.. code-block:: bash

    #!/bin/bash

    PROJECT= # wandb project name
    NAME= # wandb run name
    export WANDB_API_KEY= # your_wandb_key

    NUM_DEVICES=1 # number of GPUs to train on

    CONFIG_PATH="/NeMo/examples/nlp/information_retrieval/conf/"
    CONFIG_NAME="megatron_bert_config"
    PATH_TO_NEMO_MODEL= # path to the .nemo model converted from HF
    DATASET_PATH= # path to the JSON dataset
    SAVE_DIR= # where the checkpoints and logs are saved
    mkdir -p $SAVE_DIR

    # global_batch_size must equal NUM_DEVICES * micro_batch_size
    # (gradient accumulation is not supported).
    MICRO_BATCH_SIZE=8
    GLOBAL_BATCH_SIZE=8

    python /NeMo/examples/nlp/language_modeling/megatron_sbert_pretraining.py \
        --config-path=${CONFIG_PATH} \
        --config-name=${CONFIG_NAME} \
        restore_from_path=${PATH_TO_NEMO_MODEL} \
        trainer.devices=${NUM_DEVICES} \
        trainer.val_check_interval=100 \
        trainer.max_epochs=1 \
        +trainer.num_sanity_val_steps=0 \
        model.global_batch_size=${GLOBAL_BATCH_SIZE} \
        model.micro_batch_size=${MICRO_BATCH_SIZE} \
        model.tokenizer.library="huggingface" \
        model.tokenizer.type="intfloat/e5-large-unsupervised" \
        ++model.data.data_prefix=${DATASET_PATH} \
        ++model.tokenizer.do_lower_case=False \
        ++model.data.evaluation_sample_size=100 \
        ++model.data.hard_negatives_to_train=4 \
        ++model.data.evaluation_steps=100 \
        exp_manager.explicit_log_dir=${SAVE_DIR} \
        exp_manager.create_wandb_logger=True \
        exp_manager.resume_if_exists=True \
        exp_manager.wandb_logger_kwargs.name=${NAME} \
        exp_manager.wandb_logger_kwargs.project=${PROJECT}
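
Gradient accumulation is not yet supported for contrastive learning, so ``model.global_batch_size`` must equal ``trainer.devices`` multiplied by ``model.micro_batch_size``; the training script asserts this at startup.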
160 changes: 160 additions & 0 deletions examples/nlp/information_retrieval/conf/megatron_sbert_config.yaml
@@ -0,0 +1,160 @@
name: megatron_bert
restore_from_path: null # used when starting from a .nemo file

trainer:
  devices: 1
  num_nodes: 1
  accelerator: gpu
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice, we don't usually train for more than 1 epoch.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
  val_check_interval: 100
  limit_val_batches: 50
  limit_test_batches: 500
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0
  benchmark: False

exp_manager:
  explicit_log_dir: null
  exp_dir: null
  name: megatron_bert
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: val_loss
    save_top_k: 10
    mode: min
    always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
    filename: 'megatron_bert--{val_loss:.2f}-{step}-{consumed_samples}'
    model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}


model:
  # model parallelism
  mcore_bert: False
  micro_batch_size: 4
  global_batch_size: 8
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1
  virtual_pipeline_model_parallel_size: null

  # model architecture
  encoder_seq_length: 512
  max_position_embeddings: ${.encoder_seq_length}
  position_embedding_type: 'learned_absolute' # Position embedding type. Options: ['learned_absolute', 'rope', 'alibi', 'kerple', 'xpos', 'sandwich']; xpos and sandwich are experimental.
  num_layers: 24
  hidden_size: 1024
  ffn_hidden_size: 4096 # Transformer FFN hidden size. Usually 4 * hidden_size.
  num_attention_heads: 16
  skip_head: True
  transformer_block_type: post_ln
  init_method_std: 0.02 # Standard deviation of the zero-mean normal distribution used for weight initialization.
  hidden_dropout: 0.1 # Dropout probability for the hidden-state transformer.
  kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null.
  apply_query_key_layer_scaling: False # scale Q * K^T by 1 / layer-number
  normalization: layernorm
  layernorm_epsilon: 1e-12
  make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computational efficiency.
  pre_process: True # add embedding
  post_process: True # add pooler
  bert_binary_head: True # BERT binary head
  megatron_legacy: True

  tokenizer:
    library: 'huggingface'
    type: 'intfloat/e5-large-unsupervised'
    model: null
    vocab_file: null
    merge_file: null

  # precision
  native_amp_init_scale: 4294967296 # 2 ** 32
  native_amp_growth_interval: 1000
  fp32_residual_connection: False # Move residual connections to fp32
  fp16_lm_cross_entropy: False # Move the cross-entropy unreduced loss calculation for lm head to fp16

  # Megatron O2-style half-precision
  megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
  grad_allreduce_chunk_size_mb: 125
  grad_div_ar_fusion: False

  # miscellaneous
  seed: 1234
  use_cpu_initialization: False # Init weights on the CPU (slow for large models)
  onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
  gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)

  ## Activation Checkpointing
  # NeMo Megatron supports 'selective' activation checkpointing, where only the memory-intensive part of attention is checkpointed.
  # These memory-intensive activations are also less compute-intensive, which makes activation checkpointing more efficient for LLMs (20B+).
  # See "Reducing Activation Recomputation in Large Transformer Models" (https://arxiv.org/abs/2205.05198) for more details.
  # 'full' will checkpoint the entire transformer layer.
  activations_checkpoint_granularity: null # 'selective' or 'full'
  activations_checkpoint_method: null # 'uniform', 'block'
  # 'uniform' divides the total number of transformer layers and checkpoints the input activation
  # of each chunk at the specified granularity. When used with 'selective', 'uniform' checkpoints all attention blocks in the model.
  # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity.
  activations_checkpoint_num_layers: null
  # When using 'uniform', this creates groups of transformer layers to checkpoint. Usually set to 1. Increase to save more memory.
  # When using 'block', this will checkpoint the first activations_checkpoint_num_layers per pipeline stage.
  num_micro_batches_with_partial_activation_checkpoints: null
  # This feature is valid only when used with pipeline model parallelism.
  # When an integer value is provided, it sets the number of micro-batches where only a partial number of Transformer layers get checkpointed
  # and recomputed within a window of micro-batches. The rest of the micro-batches in the window checkpoint all Transformer layers. The size of the window is
  # set by the maximum outstanding micro-batch backpropagations, which varies at different pipeline stages. The number of partial layers to checkpoint
  # per micro-batch is set by 'activations_checkpoint_num_layers' with an 'activations_checkpoint_method' of 'block'.
  # This feature enables using activation checkpointing at a fraction of micro-batches up to the point of full GPU memory usage.
  activations_checkpoint_layers_per_pipeline: null
  # This feature is valid only when used with pipeline model parallelism.
  # When an integer value (rounded down when a float is given) is provided, it sets the number of Transformer layers to skip checkpointing at later
  # pipeline stages. For example, an 'activations_checkpoint_layers_per_pipeline' of 3 makes pipeline stage 1 checkpoint 3 fewer layers than
  # stage 0, stage 2 checkpoint 6 fewer layers than stage 0, and so on. This is possible because later pipeline stages
  # use less GPU memory with fewer outstanding micro-batch backpropagations. Used with 'num_micro_batches_with_partial_activation_checkpoints',
  # this feature removes most of the activation checkpoints at the last pipeline stage, which is the critical execution path.
  sequence_parallel: False

  data:
    # Path to data must be specified by the user.
    # Can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]"
    # Or see the example below:
    # data_prefix:
    #   - .5
    #   - /raid/data/pile/my-gpt3_00_text_document
    #   - .5
    #   - /raid/data/pile/my-gpt3_01_text_document
    data_prefix: [1.0, /path/to/data]
    index_mapping_dir: null # path to save index mapping .npy files; by default, saved in the same location as data_prefix
    data_impl: mmap
    splits_string: 900,50,50
    seq_length: ${model.encoder_seq_length}
    skip_warmup: True
    num_workers: 0
    dataloader_type: single # cyclic, LDDL
    reset_position_ids: False # Reset position ids after end-of-document token
    reset_attention_mask: False # Reset attention mask after end-of-document token
    eod_mask_loss: False # Mask loss for the end-of-document tokens
    masked_lm_prob: 0.15 # Probability of replacing a token with mask.
    short_seq_prob: 0.1 # Probability of producing a short sequence.

  optim:
    name: fused_adam
    lr: 2e-4
    weight_decay: 0.01
    betas:
      - 0.9
      - 0.98
    sched:
      name: CosineAnnealing
      warmup_steps: 500
      constant_steps: 50000
      min_lr: 2e-5
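
As a quick aside, this config can be loaded and tweaked programmatically with OmegaConf, the same library the training script uses; the relative path below is illustrative:

from omegaconf import OmegaConf

# Load the config and override a couple of values in code.
cfg = OmegaConf.load("examples/nlp/information_retrieval/conf/megatron_sbert_config.yaml")
cfg.model.micro_batch_size = 8
cfg.model.global_batch_size = 8  # devices * micro_batch_size; no gradient accumulation

# Interpolations such as ${model.encoder_seq_length} resolve on access.
print(cfg.model.data.seq_length)  # 512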
59 changes: 59 additions & 0 deletions examples/nlp/information_retrieval/megatron_sbert_finetune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict

from nemo.collections.nlp.models.information_retrieval.megatron_sbert_model import MegatronSBertModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronBertTrainerBuilder
from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="megatron_bert_config")
def main(cfg) -> None:
    if cfg.model.data.dataloader_type != "LDDL":
        mp.set_start_method("spawn", force=True)

    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    trainer = MegatronBertTrainerBuilder(cfg).create_trainer()
    exp_manager(trainer, cfg.exp_manager)

    model_cfg = MegatronSBertModel.merge_cfg_with(cfg.restore_from_path, cfg)

    assert (
        model_cfg.micro_batch_size * cfg.trainer.devices == model_cfg.global_batch_size
    ), "Gradient accumulation is not supported for contrastive learning yet"

    OmegaConf.set_struct(model_cfg, True)
    with open_dict(model_cfg):
        model_cfg.precision = trainer.precision

    model = MegatronSBertModel.restore_from(
        restore_path=cfg.restore_from_path,
        trainer=trainer,
        save_restore_connector=NLPSaveRestoreConnector(),
        override_config_path=model_cfg,
        strict=True,
    )

    trainer.fit(model)


if __name__ == '__main__':
    main()
@@ -122,9 +122,6 @@ def __init__(
if token is not None and token not in self.tokenizer.get_vocab():
new_tokens_in_vocab.append(token)

# value is required for megatron-core
self.unique_identifiers = OrderedDict()

if len(new_tokens_in_vocab) > 0:
"""
Special tokens that were not previously included in the tokenizer's vocabulary file will be added to