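On AllNLI, training with MPNRL (MultiplePositivesNegativesRankingLoss) has higher training throughput and better GPU memory utilization than training with MNRL (sentence-transformers' MultipleNegativesRankingLoss). The two are on par in terms of task performance, but I need to run more experiments to confirm this.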
No-duplicates sampling causes batch sizes to decay when there's high skewness in the number of positives per anchor.
Plot for AllNLI, no-duplicates sampling
Reproduce by running:
```bash
python compare_dataloaders.py \
    --dataset_name "sentence-transformers/all-nli" \
    --dataset_config "triplet" \
    --dataset_split "train" \
    --batch_size 128 \
    --dataset_size_train 10000 \
    --seed 42
```
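To see why skew in positives per anchor shrinks batches, here is a simplified model of no-duplicates sampling. It is a sketch under an assumption: it only compares anchors, whereas the sentence-transformers sampler compares the row's text values more generally. A row joins the current batch only if its anchor isn't already in it; skipped rows wait for a later batch. The function name `no_duplicates_batch_sizes` and the toy data are hypothetical.

```python
import random


def no_duplicates_batch_sizes(anchors: list[str], batch_size: int, seed: int = 0) -> list[int]:
    """Return the size of each batch produced by a simplified no-duplicates sampler.

    Assumption (hypothetical model): only anchors are compared. A row is skipped,
    and retried in a later batch, if its anchor is already in the current batch.
    """
    rng = random.Random(seed)
    remaining = list(range(len(anchors)))
    rng.shuffle(remaining)
    sizes = []
    while remaining:
        seen = set()
        batch = []
        for idx in remaining:
            if anchors[idx] in seen:
                continue  # duplicate anchor: leave this row for a later batch
            seen.add(anchors[idx])
            batch.append(idx)
            if len(batch) == batch_size:
                break
        sizes.append(len(batch))
        taken = set(batch)
        remaining = [idx for idx in remaining if idx not in taken]
    return sizes


# Toy data (hypothetical): one anchor owns 50 rows, 100 other anchors own 1 row each.
anchors = ["a"] * 50 + [f"b{i}" for i in range(100)]
sizes = no_duplicates_batch_sizes(anchors, batch_size=32)
print("first batches:", sizes[:5])
print("last batches:", sizes[-5:])
print(f"mean fill: {sum(sizes) / (len(sizes) * 32):.1%}")
```

Each batch can hold at most one row for anchor "a", so once the other anchors are used up, every remaining batch contains a single row. That is the batch-size decay described above, in its most extreme form.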
Here are CUDA memory snapshots across time for MNRL on AllNLI (first 10k triplets, input batch size of 200):
The drops in memory are caused by drops in the batch size, which leaves a long tail of under-utilization. Peak usage is determined by the first few batches, which account for only a small portion of the training time.
It's simpler to use a loss that seamlessly handles multiple positives per anchor, because then batches don't need to be de-duplicated and keep their full size. As a result, training throughput is higher, and GPU utilization (in terms of % memory and % time) is more stable. Data loading itself is also 15x faster, since there's no de-duplication.
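As an illustration of what "handling multiple positives" can mean, here is a minimal NumPy sketch of one masked formulation. It assumes each unique anchor appears once in the batch, each positive carries the index of its anchor, and an anchor's *other* positives are masked out instead of being scored as negatives. This is not necessarily how mpnrl.losses.MultiplePositivesNegativesRankingLoss is implemented; the function name, the `scale=20.0` default, and the omission of hard negatives are assumptions.

```python
import numpy as np


def multiple_positives_ranking_loss(
    anchors: np.ndarray,        # (n_anchors, dim): one row per *unique* anchor
    positives: np.ndarray,      # (n_positives, dim)
    pos_to_anchor: np.ndarray,  # (n_positives,): index of each positive's anchor
    scale: float = 20.0,        # assumed temperature, as in MNRL-style losses
) -> float:
    # Cosine similarity: L2-normalize, then take dot products.
    a = anchors / np.linalg.norm(anchors, axis=1, keepdims=True)
    p = positives / np.linalg.norm(positives, axis=1, keepdims=True)
    # scores[j, k] = scaled similarity between positive j's anchor and positive k.
    scores = scale * (a[pos_to_anchor] @ p.T)
    # Mask the anchor's other positives so they aren't treated as negatives.
    same_anchor = pos_to_anchor[:, None] == pos_to_anchor[None, :]
    np.fill_diagonal(same_anchor, False)  # keep the target positive itself
    scores = np.where(same_anchor, -np.inf, scores)
    # Cross-entropy where the target for row j is column j (stable log-softmax).
    row_max = scores.max(axis=1, keepdims=True)
    log_norm = row_max[:, 0] + np.log(np.exp(scores - row_max).sum(axis=1))
    log_probs = np.diag(scores) - log_norm
    return float(-log_probs.mean())


anchors = np.eye(2)
positives = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
# Anchor 0 has two positives and anchor 1 has one.
loss = multiple_positives_ranking_loss(anchors, positives, np.array([0, 0, 1]))
# Deliberately wrong assignment, for contrast.
wrong = multiple_positives_ranking_loss(anchors, positives, np.array([1, 1, 0]))
print(loss, wrong)
```

Without the mask, the first two positives would each score the other as an equally similar negative, capping the target probability near one half (a loss of about log 2) even for perfect embeddings. That is the false-negative problem that de-duplication works around for MNRL.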
Here are CUDA memory snapshots across time for MPNRL:
Here's a comparison of time-based GPU utilization:
The small experiment in ./demos/train_allnli.ipynb demonstrates that task/statistical performance is on par with MNRL.
In an experiment on the first 100k triplets in AllNLI with an input batch size of 200, MNRL took ~33 minutes to train while MPNRL took ~20 minutes. Statistical performance was similar.
```bash
python -m pip install git+https://github.com/kddubey/mpnrl.git
```

To run ./run.py, clone the repo and then:

```bash
python -m pip install ".[demos]"
```
NOTE: this isn't meant to be a stable Python package. There are many TODOs.
Make sure not to use the no-duplicates sampler with MPNRL.
```python
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.training_args import BatchSamplers

import mpnrl

model = SentenceTransformer("your-model")

# Records of {"anchor": ..., "positive": ..., "negative"(s): ...}
train_dataset = ...

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    args=SentenceTransformerTrainingArguments(
        # ... your other training arguments
        # Use the plain batch sampler, not BatchSamplers.NO_DUPLICATES
        batch_sampler=BatchSamplers.BATCH_SAMPLER,
    ),
    loss=mpnrl.losses.MultiplePositivesNegativesRankingLoss(model),
    data_collator=mpnrl.data_collator.GroupingDataCollator(
        train_dataset, tokenize_fn=model.tokenize
    ),
)
trainer.train()
```
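The collator's job can be pictured as grouping a batch by anchor, so that each anchor is encoded once and every positive knows which anchor it belongs to. Below is a hypothetical, pure-Python sketch of that grouping step. `group_by_anchor` is not part of mpnrl, this is an assumption about what GroupingDataCollator does rather than its implementation, and negatives and tokenization are left out for brevity.

```python
def group_by_anchor(
    records: list[dict[str, str]],
) -> tuple[list[str], list[str], list[int]]:
    """Deduplicate anchors within a batch (hypothetical helper).

    Returns the unique anchors, all positives, and for each positive
    the index of its anchor in the unique-anchor list.
    """
    anchor_index: dict[str, int] = {}
    positives: list[str] = []
    pos_to_anchor: list[int] = []
    for record in records:
        # setdefault assigns the next free index the first time an anchor appears
        idx = anchor_index.setdefault(record["anchor"], len(anchor_index))
        positives.append(record["positive"])
        pos_to_anchor.append(idx)
    anchors = list(anchor_index)  # dicts preserve insertion order
    return anchors, positives, pos_to_anchor


records = [
    {"anchor": "a", "positive": "p1"},
    {"anchor": "b", "positive": "p2"},
    {"anchor": "a", "positive": "p3"},
]
anchors, positives, pos_to_anchor = group_by_anchor(records)
print(anchors, positives, pos_to_anchor)
```

Because duplicates are grouped rather than dropped, every row in the batch is used, which is why no special sampler is needed.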
There's a small demo in ./demos/train_allnli.ipynb.
- mpnrl.collator TODOs.
- mpnrl.loss TODOs.
- Measure how long it takes for MNRL vs MPNRL to get to a good model (Pearson/Spearman correlation on validation data).
- Repeat for a few datasets and study how the level of data duplication affects these outcomes.