Junfeng Jiang¹, Qiang Zhang², Akiko Aizawa¹, Renjing Xu²
¹University of Tokyo, ²The Hong Kong University of Science and Technology
BioMed-LLaMA-7b is a large language model (LLM) with 7 billion parameters. It was continuously pretrained from Meta AI's LLaMA-7b checkpoint on biomedical abstracts and papers from The Pile, namely its PubMed-abstract and PubMed-central subsets.
In this repository, we also provide the code for continuous pretraining, finetuning, and evaluation. We hope this work will benefit the biomedical NLP community.
The Pile is a large-scale, high-quality dataset of diverse text designed for pretraining large language models. It contains 825 GiB of text from 22 diverse sources, including Wikipedia, PubMed abstracts, and PubMed Central papers. We extracted the PubMed-abstract and PubMed-central subsets from The Pile as our pretraining resources. Together they contain approximately 30M abstracts and 5M papers.
After extraction, we obtained 213 GiB of text containing about 63B tokens. We trained the LLaMA-7b model on these data for 1 epoch to avoid overfitting to the pretraining data.
Since this is continuous pretraining, we mainly follow the hyperparameters of LLaMA-7b, as shown below.
Hyperparameter | Value |
---|---|
max_seq_length | 2048 |
lr | 3e-5 |
batch size | 2048 |
betas | [0.9, 0.95] |
weight decay | 0.1 |
gradient clipping | 1.0 |
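As a rough sanity check, the sketch below collects the hyperparameters above into a Python config. It then works out how many optimizer steps one epoch over the ~63B-token corpus implies.

The sketch makes two assumptions that are not stated explicitly above:

- "batch size" counts sequences.
- Every sequence is packed to the full 2048-token context.

`PretrainConfig`, `tokens_per_step`, `steps_for_one_epoch`, and `build_optimizer` are hypothetical names, not functions from this repository.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class PretrainConfig:
    """Hyperparameters from the table above."""
    max_seq_length: int = 2048
    lr: float = 3e-5
    batch_size: int = 2048             # assumption: sequences per optimizer step
    betas: tuple = (0.9, 0.95)
    weight_decay: float = 0.1
    grad_clip: float = 1.0             # max global gradient norm


def tokens_per_step(cfg: PretrainConfig) -> int:
    # Assumes every sequence is packed to the full context length.
    return cfg.batch_size * cfg.max_seq_length


def steps_for_one_epoch(cfg: PretrainConfig, corpus_tokens: float) -> int:
    # Optimizer steps needed to see the whole corpus once, rounded up.
    return math.ceil(corpus_tokens / tokens_per_step(cfg))


def build_optimizer(model, cfg: PretrainConfig):
    """Map the table onto PyTorch AdamW (torch is needed only when this is called)."""
    import torch

    # Applies weight decay to every parameter; many recipes exclude norms/biases.
    return torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, betas=cfg.betas, weight_decay=cfg.weight_decay
    )


cfg = PretrainConfig()
print(tokens_per_step(cfg))             # 4194304
print(steps_for_one_epoch(cfg, 63e9))   # 15021
```

Under these assumptions, each step sees about 4.2M tokens, so a single epoch takes roughly 15k optimizer steps. In a PyTorch training loop, the gradient-clipping entry would correspond to calling `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg.grad_clip)` before `optimizer.step()`.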
The model was trained for about a week on an 8-node HPC cluster with 32 NVIDIA A100-80GB GPUs in total.
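The cluster size and wall-clock time above also imply an average training throughput. This back-of-the-envelope sketch makes two simplifications:

- "About a week" is treated as exactly 7 days.
- All ~63B tokens are assumed to be processed exactly once.

The result is therefore an estimate, not a measured number.

```python
SECONDS_PER_DAY = 24 * 60 * 60


def tokens_per_gpu_second(total_tokens: float, days: float, num_gpus: int) -> float:
    """Average per-GPU throughput implied by a token budget and wall-clock time."""
    return total_tokens / (days * SECONDS_PER_DAY * num_gpus)


# Figures from this README; "about a week" is taken as exactly 7 days.
rate = tokens_per_gpu_second(total_tokens=63e9, days=7, num_gpus=32)
print(f"{rate:,.0f} tokens/s per GPU")  # 3,255 tokens/s per GPU
print(7 * 32, "GPU-days")              # 224 GPU-days
```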
We adopted several optimization strategies to speed up training and reduce memory consumption:
- We used PyTorch FSDP to shard model parameters, gradients, and optimizer states across GPUs. However, the network bandwidth across nodes in our cluster is limited, so we adopted a hybrid sharding strategy to reduce inter-node communication cost. If you need this feature for your project, you can install our modified version of transformers from Coldog2333/transformers (based on transformers v4.28.1).
- We also applied gradient accumulation to reduce inter-GPU communication cost.
- We used xformers for memory-efficient attention computation, which reduces memory consumption and speeds up training.
- We also used mixed-precision training (bf16 + tf32) to reduce memory consumption and speed up training. Although LLaMA's model weights are stored in float16, we did not observe any difference between fp16 and bf16 training in our preliminary experiments.
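The hybrid sharding layout can be illustrated with plain Python, using this cluster's 8 nodes with 4 GPUs each:

- Model states are sharded within a node and replicated across nodes.
- Parameter all-gathers therefore stay on fast intra-node links.
- Only the gradient all-reduces cross the slower inter-node network.

PyTorch FSDP exposes this layout as `ShardingStrategy.HYBRID_SHARD`. The helpers below are hypothetical: they only compute the rank groups and the gradient-accumulation count. The micro-batch size of 4 is an assumption, since it is not reported here.

```python
def hybrid_shard_groups(num_nodes: int, gpus_per_node: int):
    """Return (shard_groups, replicate_groups) as lists of global ranks.

    Parameters are sharded inside each node (fast intra-node links) and
    replicated across nodes, so only gradient all-reduce crosses nodes.
    """
    shard_groups = [
        [node * gpus_per_node + local for local in range(gpus_per_node)]
        for node in range(num_nodes)
    ]
    replicate_groups = [
        [node * gpus_per_node + local for node in range(num_nodes)]
        for local in range(gpus_per_node)
    ]
    return shard_groups, replicate_groups


def accumulation_steps(global_batch: int, world_size: int, micro_batch: int) -> int:
    """Micro-batches each GPU accumulates before one optimizer step."""
    if global_batch % world_size or (global_batch // world_size) % micro_batch:
        raise ValueError("global batch must split evenly across GPUs and micro-batches")
    return global_batch // world_size // micro_batch


shard, replicate = hybrid_shard_groups(num_nodes=8, gpus_per_node=4)
print(shard[0])      # [0, 1, 2, 3]: the GPUs of one node jointly hold one model copy
print(replicate[0])  # [0, 4, 8, 12, 16, 20, 24, 28]
# Hypothetical micro-batch of 4 sequences per GPU.
print(accumulation_steps(global_batch=2048, world_size=32, micro_batch=4))  # 16
```

Gradient accumulation reduces communication only if gradient synchronization is skipped on all but the last micro-batch of each step. FSDP's `no_sync()` context manager is one way to do that, at the cost of holding unsharded gradients in memory.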
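xformers provides fused GPU kernels such as `xformers.ops.memory_efficient_attention`. The NumPy sketch below is not that implementation. It only illustrates the general idea such kernels rely on:

- Keys and values are processed in chunks with an online softmax.
- The full T × T score matrix is never materialized.
- The output is the same as standard causal attention.

```python
import numpy as np


def naive_attention(q, k, v):
    """Causal softmax attention that materializes the full (T, T) score matrix."""
    t, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    mask = np.tril(np.ones((t, t), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (weights / weights.sum(axis=-1, keepdims=True)) @ v


def chunked_attention(q, k, v, chunk=4):
    """Same result, but only a (T, chunk) block of scores exists at a time.

    A running max and running denominator let earlier partial sums be
    rescaled exactly when a later chunk contains a larger score.
    """
    t, d = q.shape
    out = np.zeros((t, v.shape[1]))
    row_max = np.full((t, 1), -np.inf)
    row_sum = np.zeros((t, 1))
    rows = np.arange(t)[:, None]
    for start in range(0, t, chunk):
        stop = min(start + chunk, t)
        cols = np.arange(start, stop)[None, :]
        s = q @ k[start:stop].T / np.sqrt(d)
        s = np.where(cols <= rows, s, -np.inf)            # causal mask
        new_max = np.maximum(row_max, s.max(axis=-1, keepdims=True))
        rescale = np.exp(row_max - new_max)               # shrink old partial sums
        p = np.exp(s - new_max)
        out = out * rescale + p @ v[start:stop]
        row_sum = row_sum * rescale + p.sum(axis=-1, keepdims=True)
        row_max = new_max
    return out / row_sum


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((10, 8)) for _ in range(3))
print(np.allclose(naive_attention(q, k, v), chunked_attention(q, k, v, chunk=3)))  # True
```

On a GPU, the savings come from never writing the full score matrix to global memory. NumPy only demonstrates that the chunked arithmetic is exact.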
Running-average smoothing is applied to the training loss curve for visualization.
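The smoothing window is not specified here, nor whether the running average is simple or exponential. A minimal sketch, assuming a trailing simple moving average over the last `window` logged losses (the window size is a hypothetical parameter):

```python
from collections import deque


def running_average(values, window=100):
    """Trailing moving average: each point is the mean of the last `window` losses."""
    buf = deque(maxlen=window)
    total = 0.0
    smoothed = []
    for x in values:
        if len(buf) == buf.maxlen:
            total -= buf[0]          # value about to be evicted by append()
        buf.append(x)
        total += x
        smoothed.append(total / len(buf))
    return smoothed


print(running_average([4.0, 2.0, 3.0, 1.0], window=2))  # [4.0, 3.0, 2.5, 2.0]
```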
We mainly compare with vanilla LLaMA-7B, PMC-LLaMA, and BioMedLM; other models are included for some of the downstream tasks. Evaluating language models on downstream tasks such as QA is not trivial, since they tend to generate free-form answers. We therefore estimate their potential accuracy using lm-evaluation-harness: we compute the perplexity of each option given the question (and the abstract, for PubMedQA). The option with the lowest perplexity is chosen as the final answer.
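A minimal sketch of the perplexity-based answer selection described above. It assumes the per-token log-probabilities of each option, conditioned on the question, have already been obtained from the model. The log-probabilities below are made-up numbers. `perplexity` and `pick_option` are illustrative helpers, not lm-evaluation-harness functions.

```python
import math


def perplexity(token_logprobs):
    """exp of the mean negative log-likelihood of an option's tokens."""
    return math.exp(-sum(token_logprobs) / len(token_logprobs))


def pick_option(options):
    """Return the key of the option with the lowest perplexity."""
    return min(options, key=lambda name: perplexity(options[name]))


# Hypothetical per-token log-probabilities of each answer option, conditioned
# on the question (plus the abstract for PubMedQA).
options = {
    "A": [-0.6, -0.3, -0.3, -0.2],  # 4 tokens, total -1.4
    "B": [-1.2],                    # 1 token,  total -1.2
}
print(pick_option(options))  # A (perplexity 1.42 vs 3.32)
```

Ranking by summed log-probability would instead pick B (-1.2 vs -1.4). How option scores are normalized can therefore change which answer is selected.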
Since MedQA and MedMCQA are not implemented in EleutherAI's lm-evaluation-harness, we implemented them ourselves. Please use the version of lm-evaluation-harness in this repository to evaluate on them.
Note that BioMedLM was trained on the same pretraining resources but for more epochs (6 epochs, 300B tokens in total). PMC-LLaMA-7B was trained on 4.8M PubMed Central papers for 5 epochs.
Model | Strategy | PubMed-A | PubMed-C | USMLE (4/5) | MedMCQA | PubMedQA |
---|---|---|---|---|---|---|
Random | - | - | 0.25 / 0.2 | 0.25 | 0.33 |
GPT-Neo (2.7B) | 0-shot | 19.1207 | 20.8701 | 0.2781 / 0.2412 | 0.2570 | 0.5640 |
BioMedLM (2.7B) | 0-shot | 15.6959 | 18.6799 | 0.2993 / 0.2624 | 0.2744 | 0.5520 |
LLaMA-7B | 0-shot | 20.1107 | 29.0583 | 0.3339 / 0.2742 | 0.2933 | 0.7520 |
PMC-LLaMA-7B | 0-shot | 36.8191 | 39.5381 | 0.3441 / 0.2883 | 0.2850 | 0.6640 |
BioMed-LLaMA-7B | 0-shot | 15.7774 | 20.9322 | 0.3535 / 0.3032 | 0.2921 | 0.6160 |
LLaMA-7B | few-shot | - | - | 0.3661 (3) / 0.3174 (3) | 0.2991 (10) | 0.713 (1) |
BioMed-LLaMA-7B | few-shot | - | - | 0.3668 (3) / 0.3229 (3) | 0.3007 (10) | 0.702 (1) |
LLaMA-7B | fine-tune | - | - | 0.3946±0.008 | 0.4994 | 0.764 |
BioMed-LLaMA-7B | fine-tune | - | - | 0.4072±0.012 | 0.5357 | 0.763 |
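*PubMed-A / PubMed-C: perplexity on Pile/PubMed-Abstracts and Pile/PubMed-Central (lower is better). USMLE: MedQA-USMLEQA, reported as 4-option / 5-option accuracy. MedMCQA and PubMedQA report accuracy. Numbers in parentheses are the number of few-shot demonstrations.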
Existing commercial LLMs achieve excellent performance on medical tasks such as USMLE-QA, especially with few-shot inference. However, they usually have an enormous number of parameters. Inference therefore requires substantial computational resources and time, especially when few-shot demonstrations are added to the input prompt. Finetuning these models on the demonstrations is also impossible. In contrast, our model is much smaller, and we have many downstream tasks to evaluate. We therefore performed instruction tuning on these few-shot examples instead of in-context prompting.
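The instruction-tuning data format is not specified here. As an illustration only, the sketch below converts a multiple-choice question into a prompt/target pair. It uses an Alpaca-style prompt template modeled on the Stanford Alpaca format. Three things are assumptions:

- whether BioMed-LLaMA uses this template,
- the instruction wording, and
- the example question, which is invented.

```python
# Alpaca-style template (assumed; modeled on the Stanford Alpaca format).
ALPACA_TEMPLATE = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes "
    "the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
)


def mcqa_to_instruction(question, options, answer_key):
    """Turn one multiple-choice example into a (prompt, target) pair."""
    choices = "\n".join(f"{key}. {text}" for key, text in options.items())
    prompt = ALPACA_TEMPLATE.format(
        # Hypothetical instruction wording.
        instruction="Answer the following medical question by choosing one option.",
        input=f"{question}\n{choices}",
    )
    target = f"{answer_key}. {options[answer_key]}"
    return prompt, target


prompt, target = mcqa_to_instruction(
    "Which vitamin deficiency causes scurvy?",
    {"A": "Vitamin A", "B": "Vitamin C", "C": "Vitamin D", "D": "Vitamin K"},
    "B",
)
print(prompt + "\n" + target)
```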
We collected diverse instruction tuning data from various resources:
Source | #Samples | Mixture proportion | Domain |
---|---|---|---|
MedQA-USMLE/train | 10178 | 21.45% | Medical |
MedMCQA/train | 182822 | 25.69% | Medical |
PubMedQA/train | 211269 | 14.84% | Medical |
AlpacaDataCleaned | 51760 | 18.18% | Open |
visual-med-alpaca | 54412 | 19.11% | Medical |
medpalm | 24 | 0.05% | Medical |
medpalm-cot | 19 | 0.04% | Medical |
medpalm2-cot | 19 | 0.04% | Medical |
mmlu-cot | 282 | 0.6% | Science |
codex-cot | 3 | 0.006% | Medical |
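The mixture proportions in the table differ sharply from each source's share of the raw samples. MedQA-USMLE supplies only about 2% of all listed samples but 21.45% of the mixture. MedMCQA and PubMedQA, by contrast, are downweighted relative to their size.

One common way to realize such a mixture is to draw the source of each training example with probability proportional to its mixture weight. The sketch below shows that approach as an assumption, not the authors' procedure.

```python
import random
from collections import Counter

# Mixture proportions (%) and raw sample counts, copied from the table above.
MIXTURE = {
    "MedQA-USMLE/train": 21.45,
    "MedMCQA/train": 25.69,
    "PubMedQA/train": 14.84,
    "AlpacaDataCleaned": 18.18,
    "visual-med-alpaca": 19.11,
    "medpalm": 0.05,
    "medpalm-cot": 0.04,
    "medpalm2-cot": 0.04,
    "mmlu-cot": 0.6,
    "codex-cot": 0.006,
}
NUM_SAMPLES = {
    "MedQA-USMLE/train": 10178,
    "MedMCQA/train": 182822,
    "PubMedQA/train": 211269,
    "AlpacaDataCleaned": 51760,
    "visual-med-alpaca": 54412,
    "medpalm": 24,
    "medpalm-cot": 19,
    "medpalm2-cot": 19,
    "mmlu-cot": 282,
    "codex-cot": 3,
}


def raw_share(num_samples):
    """Fraction of all listed samples each source would have without reweighting."""
    total = sum(num_samples.values())
    return {name: count / total for name, count in num_samples.items()}


def sample_sources(mixture, n, seed=0):
    """Draw the source of each of `n` training examples in proportion to its weight."""
    rng = random.Random(seed)
    names = list(mixture)
    return rng.choices(names, weights=[mixture[name] for name in names], k=n)


print(f"{raw_share(NUM_SAMPLES)['MedQA-USMLE/train']:.1%}")  # 2.0%
counts = Counter(sample_sources(MIXTURE, 100_000))
print(counts.most_common(3))  # MedMCQA/train is the most frequent source
```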
After instruction tuning, we find that BioMed-LLaMA benefits more from instruction tuning than vanilla LLaMA does, especially on MedQA-USMLE. However, performance on MedMCQA and PubMedQA is lower than with finetuning. We think there are three possible reasons:
- Many MedMCQA and PubMedQA training samples are available, but the instruction-tuning mixture contains only part of them. The models may therefore not learn the full distribution of the training data, and perform worse than when finetuned on the whole training sets.
- MedMCQA questions are quite short, whereas the other instruction-tuning data generally have longer inputs.
- PubMedQA answers are very short (Yes/No/Maybe), which makes them harder to optimize during joint training.
Model | Strategy | USMLE (4) | MedMCQA | PubMedQA |
---|---|---|---|---|
LLaMA-7B | instructed | 0.4391 | 0.4236 | 0.744 |
BioMed-LLaMA-7B | instructed | 0.487 | 0.4475 | 0.757 |
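The comparison with finetuning can be checked directly against the two result tables. The sketch copies the finetuning and instruction-tuning accuracies and computes the change per task. It assumes the single finetuning USMLE number refers to the 4-option setting, which the first table does not state explicitly.

```python
# Accuracies copied from the finetuning and instruction-tuning result tables.
FINETUNED = {
    "LLaMA-7B": {"USMLE": 0.3946, "MedMCQA": 0.4994, "PubMedQA": 0.764},
    "BioMed-LLaMA-7B": {"USMLE": 0.4072, "MedMCQA": 0.5357, "PubMedQA": 0.763},
}
INSTRUCTED = {
    "LLaMA-7B": {"USMLE": 0.4391, "MedMCQA": 0.4236, "PubMedQA": 0.744},
    "BioMed-LLaMA-7B": {"USMLE": 0.487, "MedMCQA": 0.4475, "PubMedQA": 0.757},
}


def deltas(before, after):
    """Absolute accuracy change (instructed minus finetuned) per model and task."""
    return {
        model: {task: round(after[model][task] - before[model][task], 4)
                for task in before[model]}
        for model in before
    }


for model, change in deltas(FINETUNED, INSTRUCTED).items():
    print(model, change)
```

Both models lose accuracy on MedMCQA and PubMedQA relative to finetuning. The USMLE gain is larger for BioMed-LLaMA-7B (+0.0798) than for LLaMA-7B (+0.0445).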
This research is supported by HKUST and JST SPRING, which provided computing resources and funding. We thank Meta AI for sharing the LLaMA models, and other researchers for sharing their data and code. Special thanks to @anchen1011 for valuable suggestions on this project.
Please cite this repo if you find the code or content useful for your research.
@misc{biomedllama,
author = {Junfeng Jiang and Qiang Zhang and Akiko Aizawa and Renjing Xu},
title = {BioMed-LLaMA: Continuous Pretraining LLaMA with Biomedical Abstracts and Papers},
year = {2023},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/Coldog2333/BioMed-LLaMA}},
}
@article{touvron2023llama,
title={LLaMA: Open and Efficient Foundation Language Models},
author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and Rodriguez, Aurelien and Joulin, Armand and Grave, Edouard and Lample, Guillaume},
journal={arXiv preprint arXiv:2302.13971},
year={2023}
}