Merge pull request #190 from X-LANCE/dev-slam-omni
Add reproduction for SLAM-Omni
Showing 219 changed files with 81,661 additions and 24 deletions.
@@ -0,0 +1,147 @@
# SLAM-Omni
[![Python 3.10](https://img.shields.io/badge/Python-3.10-blue.svg)](https://www.python.org/downloads/release/python-3100/) [![arXiv](https://img.shields.io/badge/arXiv-2412.15649-B31B1B.svg)](https://arxiv.org/abs/2412.15649) [![GitHub Demo Page](https://img.shields.io/badge/Github-Demo%20Page-orange.svg)](https://slam-omni.github.io/) [![License](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT)

(*Reproduction of the [paper](https://arxiv.org/abs/2412.15649) SLAM-Omni: Timbre-Controllable Voice Interaction System with Single-Stage Training.*)

## Environment Setup
After preparing the SLAM-LLM environment, install the additional dependencies:
```bash
pip install -r ./examples/s2s/requirements.txt
```

Alternatively, you can use our provided Docker image:
```bash
docker pull worstchan/slam-omni:v0
docker run -it --gpus all --name slam-omni worstchan/slam-omni:v0 /bin/bash
```

## Data Preparation

Our project supports two data formats: **Parquet** and **JSONL**. The open-source datasets are available on the Hugging Face Hub in **Parquet** format. Example usage is provided in [this notebook](./demo/demo_data/demo.ipynb).
### Supported Datasets
We provide three re-synthesized datasets for SLAM-Omni training:
- [VoiceAssistant-400K](https://huggingface.co/datasets/worstchan/VoiceAssistant-400K-SLAM-Omni): Single-round English dialogue dataset.
- [UltraChat-300K](https://huggingface.co/datasets/worstchan/UltraChat-300K-SLAM-Omni): Multi-round English dialogue dataset.
- [Belle_1.4M](https://huggingface.co/datasets/worstchan/Belle_1.4M-SLAM-Omni): Multi-round Chinese dialogue dataset.

#### Usage
You can load any of these datasets using the following code:
```python
from datasets import load_dataset

# Replace "DATASET_NAME" with one of the following:
# - "worstchan/VoiceAssistant-400K-SLAM-Omni"
# - "worstchan/UltraChat-300K-SLAM-Omni"
# - "worstchan/Belle_1.4M-SLAM-Omni"

ds = load_dataset("DATASET_NAME")
```
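For large corpora, you can also stream samples instead of downloading the full Parquet shards up front. A small sketch, assuming the default "train" split (check the dataset card for the actual split names):

```python
from datasets import load_dataset

# Stream the dataset lazily; no full download required.
# The split name "train" is an assumption -- verify it on the dataset card.
ds = load_dataset("worstchan/VoiceAssistant-400K-SLAM-Omni", split="train", streaming=True)
print(next(iter(ds)).keys())  # inspect the fields of one sample
```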

### JSONL
We also support the JSONL format for its concise structure. Below is an example:
```jsonl
{"key": "1", "source_wav": "/xxx/1.wav", "source_text": "Can you recommend some Chinese food for me?", "target_wav": "/xxx/1.wav", "target_text": "Sure! I recommend trying dumplings, Peking duck, and mapo tofu for a mix of flavors and textures in Chinese cuisine. These dishes offer a good balance of savory, spicy, and crispy elements."}
```
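A minimal sketch of reading such a manifest, assuming one JSON object per line with the fields shown above:

```python
import json

def load_manifest(path: str) -> list[dict]:
    """Read a JSONL manifest with one dialogue sample per line."""
    samples = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():  # tolerate trailing blank lines
                samples.append(json.loads(line))
    return samples

samples = load_manifest("train.jsonl")  # hypothetical file name
print(samples[0]["source_text"])
```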

## Checkpoints
We reproduced the single-stage fine-tuning results of SLAM-Omni with a group size of **3**. The following checkpoints are available for download:
- [Single-Round Dialogue (English)](https://drive.google.com/drive/folders/1ZmM1h5ZTvS-piuN-msmctmZdi51GWLAu?usp=sharing): Trained on VoiceAssistant-400K.
- [Multi-Round Dialogue (English)](https://drive.google.com/drive/folders/1xBNrqR2LWC0uEjezjx4aUgdsbstisboS?usp=sharing): Trained on VoiceAssistant-400K and UltraChat-300K.
- [Multi-Round Dialogue (Chinese)](https://drive.google.com/drive/folders/1sExIp-UDdL37gb-mh9YlhuDIib0-wUVP?usp=sharing): Trained on Belle_1.4M.

## Training

You can pre-train the S2S model using TTS or ASR tasks with our provided scripts, though we recommend proceeding directly to fine-tuning. Alternatively, you may directly train a TTS or ASR model under the SLAM-Omni framework. For detailed instructions, refer to the [pre-training README](./scripts/pretrain/README.md).

### Fine-tuning
We provide two primary fine-tuning options for **SLAM-Omni** modeling:
```bash
# Fine-tune with grouping strategy (Recommended)
bash ./examples/s2s/scripts/finetune/finetune_s2s_group.sh

# Fine-tune without grouping
bash ./examples/s2s/scripts/finetune/finetune_s2s.sh
```
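For intuition, the grouping strategy shortens the autoregressive sequence by predicting several audio tokens per decoding step. A toy sketch of the reshaping with group size 3 (the value used for our checkpoints); this is an illustration, not the repository's implementation:

```python
import numpy as np

# 12 codec tokens with group size 3 -> 4 decoding steps instead of 12.
tokens = np.arange(12)                   # stand-in for a 1-D audio token sequence
group_size = 3
assert tokens.size % group_size == 0     # real code would pad the sequence first
groups = tokens.reshape(-1, group_size)  # each row is emitted in one step
print(groups.shape)                      # (4, 3)
```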

We also include scripts for reproducing [Mini-Omni](https://github.com/gpt-omni/mini-omni). Note that this requires the original [VoiceAssistant-400K](https://huggingface.co/datasets/gpt-omni/VoiceAssistant-400K) dataset for training:
```bash
bash ./examples/s2s/scripts/finetune/mini-omni/finetune_s2s.sh
```

#### Note💫
In principle, our framework supports training **any codec-based spoken dialogue model**: simply re-synthesize the target tokens (e.g., CosyVoice2 tokens) during training for compatibility.
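As a rough sketch of what that re-synthesis amounts to, the loop below rewrites a JSONL manifest with new target tokens. The `codec` object and the `target_token` field are illustrative assumptions, not APIs or schema from this repository:

```python
import json

def retokenize_manifest(manifest_in: str, manifest_out: str, codec) -> None:
    """Replace each sample's target tokens with those of another codec.

    `codec` is a hypothetical stand-in for any tokenizer exposing an
    `encode(wav_path) -> list[int]` method (e.g., a CosyVoice2 wrapper).
    """
    with open(manifest_in, encoding="utf-8") as fin, \
         open(manifest_out, "w", encoding="utf-8") as fout:
        for line in fin:
            sample = json.loads(line)
            sample["target_token"] = codec.encode(sample["target_wav"])
            fout.write(json.dumps(sample, ensure_ascii=False) + "\n")
```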

## Inference
We provide scripts for both **online** and **batch** inference. You can use your trained model or the provided checkpoints. For detailed guidance, refer to the [inference README](./scripts/inference/README.md).

### Online Inference
Run the following commands for real-time inference:

```bash
# Multi-turn (Recommended)
bash ./examples/s2s/scripts/inference/inference_s2s_online_multi-round.sh

# Single-turn
bash ./examples/s2s/scripts/inference/inference_s2s_online.sh
```

For Mini-Omni modeling, use the following commands:
```bash
# Single-turn non-streaming
bash ./examples/s2s/scripts/inference/mini-omni/inference_s2s_online.sh

# Single-turn streaming
bash ./examples/s2s/scripts/inference/mini-omni/inference_s2s_online_stream.sh
```

### Batch Inference

For batch inference, ensure the data format matches the training format (**Parquet** or **JSONL**). Use the following commands:

```bash
# SLAM-Omni framework
bash ./examples/s2s/scripts/inference/inference_s2s_batch.sh

# Mini-Omni framework
bash ./examples/s2s/scripts/inference/mini-omni/inference_s2s_batch.sh
```

## TODO
- [ ] Add evaluation scripts.
- [ ] Add streaming inference scripts for SLAM-Omni.

<!-- ## Gradio Demo -->

## Citation
SLAM-Omni:
```bibtex
@article{chen2024slam,
  title={SLAM-Omni: Timbre-Controllable Voice Interaction System with Single-Stage Training},
  author={Chen, Wenxi and Ma, Ziyang and Yan, Ruiqi and Liang, Yuzhe and Li, Xiquan and Xu, Ruiyang and Niu, Zhikang and Zhu, Yanqiao and Yang, Yifan and Liu, Zhanxun and others},
  journal={arXiv preprint arXiv:2412.15649},
  year={2024}
}
```
Mini-Omni:
```bibtex
@article{xie2024mini,
  title={Mini-omni: Language models can hear, talk while thinking in streaming},
  author={Xie, Zhifei and Wu, Changqiao},
  journal={arXiv preprint arXiv:2408.16725},
  year={2024}
}
```

## Acknowledgement
- We borrow some code from [Mini-Omni](https://github.com/gpt-omni/mini-omni) for SNAC-based modeling.
- We borrow some code from [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for the vocoder.

## License
Our code is released under the MIT License. The Chinese dialogue model is licensed under GPL-3.0 due to its use of Belle data and is intended for research purposes only.
12 binary files not shown.
@@ -0,0 +1,19 @@
{
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 1e-4
        }
    },
    "fp16": {
        "enabled": true
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu"
        }
    }
}
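This config trains in fp16 with ZeRO stage 3 and CPU optimizer offload. A hedged sketch of passing it to the DeepSpeed entry point through a hydra override; the script path is illustrative, so check the provided finetune scripts for the exact invocation:

```bash
# Illustrative launch only -- the entry-point path is an assumption.
deepspeed examples/s2s/deepspeed_finetune_s2s.py \
    ++deepspeed_config=examples/s2s/conf/ds_config.json
```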
@@ -0,0 +1,3 @@
dataset_config:
  # We put the prompt here because the hydra override in the shell script only supports a small subset of characters.
  prompt: "Conduct a spoken conversation with the user. "
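For context on the comment above: simple scalar overrides pass cleanly from a shell script, while a long prompt with spaces and punctuation is fragile to quote through bash and hydra, so it lives in this YAML instead. A sketch with an illustrative entry-point name:

```bash
# Easy: a plain scalar override from a shell script.
python deepspeed_finetune_s2s.py ++train_config.lr=1e-4

# Fragile: long prompts with spaces/punctuation fight shell + hydra quoting,
# so they are kept in a YAML config file like this one.
```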
@@ -0,0 +1,2 @@
dataset_config:
  prompt: "Transcribe the provided audio into accurate text. "
@@ -0,0 +1,4 @@
dataset_config:
  # We put the prompt here because the hydra override in the shell script only supports a small subset of characters.
  # prompt: "Transcribe speech to text. Output the transcription directly without redundant content. Ensure that the output is not duplicated. "
  prompt: "Generate a natural and expressive spoken version of the given text. "
@@ -0,0 +1,47 @@
from slam_llm.pipeline.finetune_deepspeed import main as train
from slam_llm.utils.deepspeed_utils import deepspeed_main_wrapper

import logging
from dataclasses import dataclass, field
from omegaconf import DictConfig, ListConfig, OmegaConf
from s2s_config import ModelConfig, TrainConfig, DataConfig, LogConfig


@dataclass
class RunConfig:
    dataset_config: DataConfig = field(default_factory=DataConfig)
    model_config: ModelConfig = field(default_factory=ModelConfig)
    train_config: TrainConfig = field(default_factory=TrainConfig)
    log_config: LogConfig = field(default_factory=LogConfig)
    debug: bool = field(default=False, metadata={"help": "Use pdb when true"})
    metric: str = field(default="acc", metadata={"help": "The metric for evaluation"})
    deepspeed_config: str = field(default="examples/asr_librispeech/conf/ds_config.json", metadata={"help": "Path to the DeepSpeed config JSON"})


@deepspeed_main_wrapper(config_name=None, version_base=None)
def main_hydra(cfg: DictConfig):
    # Merge the structured defaults above with hydra command-line overrides.
    run_config = RunConfig()
    cfg = OmegaConf.merge(run_config, cfg)

    def to_plain_list(cfg_item):
        # Recursively convert OmegaConf containers into plain Python objects.
        if isinstance(cfg_item, ListConfig):
            return OmegaConf.to_container(cfg_item, resolve=True)
        elif isinstance(cfg_item, DictConfig):
            return {k: to_plain_list(v) for k, v in cfg_item.items()}
        else:
            return cfg_item

    # kwargs = to_plain_list(cfg)
    kwargs = cfg
    log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())

    logging.basicConfig(level=log_level)

    if kwargs.get("debug", False):
        import pdb
        pdb.set_trace()

    train(kwargs)


if __name__ == "__main__":
    main_hydra()