
Dialogue Summarization

Installation

pip install -r requirements.txt

Note: The current code base has been tested with 2 NVIDIA RTX 3090 GPUs. We also use the wandb cloud-based logger, which may require you to register and log in first.
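The example commands below pass --report_to "none"; if you want wandb logging instead, install the wandb client and log in once before training:

pip install wandb
wandb login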

Data generation

Download the AI Hub dialogue summary dataset.

python data_generator.py --path aihub_dialogue_summary --output data 
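data_generator.py is expected to turn the raw AI Hub files into the inputs used by the later steps: data/dialogue.json for post-training and data/train.csv, data/valid.csv, data/predict.csv for fine-tuning. A minimal sketch of that conversion, assuming (not confirmed from the script) AI Hub-style JSON with per-turn speaker IDs and a reference summary:

import csv
import json
from pathlib import Path

def read_pairs(aihub_dir):
    # Yield (dialogue, summary) pairs; the JSON keys below are assumptions,
    # not the actual AI Hub schema handled by data_generator.py.
    for path in Path(aihub_dir).glob("**/*.json"):
        data = json.loads(path.read_text(encoding="utf-8"))
        for item in data.get("data", []):
            turns = [f'{u["participantID"]}: {u["utterance"]}'
                     for u in item["body"]["dialogue"]]
            yield " ".join(turns), item["body"]["summary"]

def write_csv(pairs, out_path):
    # Column names are assumptions; point run_summarization.py at whatever
    # columns the real script writes.
    with open(out_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["dialogue", "summary"])
        writer.writerows(pairs)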

Update tokenizer for dialogue summary dataset

Speaker IDs such as P01 and P02 are treated as single tokens.

  1. Download the gogamza/kobart-base-v1 tokenizer
  2. Replace <unusedX> tokens with the speaker ID tokens
python update_tokenizer.py
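A rough equivalent of this update, written against the Hugging Face API instead of editing the vocabulary directly: registering the speaker IDs as additional special tokens keeps each P0X as a single token, although unlike the <unusedX> replacement it grows the vocabulary, so the model embeddings must be resized before training. The P01–P09 range below is an assumption.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gogamza/kobart-base-v1")

# Treat each speaker ID as one token instead of letting it split into sub-words.
speaker_tokens = [f"P{i:02d}" for i in range(1, 10)]  # assumed range
tokenizer.add_special_tokens({"additional_special_tokens": speaker_tokens})
tokenizer.save_pretrained("kobart-dialogue")

# When loading the model later: model.resize_token_embeddings(len(tokenizer))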

Post-training

torchrun --nproc_per_node=2 run_post_train.py \
        --model_path "kobart-dialogue" \
        --model_name "gogamza/kobart-base-v1" \
        --run_name "kobart-post_train" \
        --do_train \
        --report_to "none" \
        --dialogue_max_seq_length 1024 \
        --train_file data/dialogue.json \
        --output_dir "checkpoints/kobart-post_train" \
        --learning_rate 0.0005 \
        --warmup_steps 50000 \
        --per_device_train_batch_size 8 \
        --max_steps 500000 \
        --save_strategy steps \
        --dataloader_num_workers 16 \
        --save_steps 10000 \
        --save_total_limit 3 \
        --gradient_accumulation_steps 1 \
        --logging_steps 100 
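Before moving on to fine-tuning, it can be worth sanity-checking that a post-trained checkpoint loads together with the updated tokenizer. The checkpoint directory below is illustrative; use whichever checkpoint-XXXX directory save_steps actually produced.

from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("kobart-dialogue")
model = BartForConditionalGeneration.from_pretrained(
    "checkpoints/kobart-post_train/checkpoint-500000"  # illustrative path
)

# The embedding table must cover the speaker tokens added to the tokenizer.
assert model.get_input_embeddings().weight.shape[0] >= len(tokenizer)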

Fine-tuning

torchrun --nproc_per_node=1 run_summarization.py \
  --model_path "kobart-dialogue" \
  --model_name "gogamza/kobart-base-v1" \
  --run_name "kobart-dialsumm" \
  --do_train \
  --do_eval \
  --do_predict \
  --report_to "none" \
  --train_file data/train.csv \
  --validation_file data/valid.csv \
  --predict_file data/predict.csv \
  --max_source_length 512 \
  --max_target_length 128 \
  --output_dir "checkpoints/kobart-dialsumm" \
  --learning_rate 5e-5 \
  --warmup_steps 50 \
  --per_device_train_batch_size 32 \
  --per_device_eval_batch_size 8 \
  --max_steps 1000 \
  --save_strategy steps \
  --evaluation_strategy steps \
  --dataloader_num_workers 10 \
  --save_steps 10 \
  --eval_steps 10 \
  --logging_steps 10 \
  --save_total_limit 3 \
  --load_best_model_at_end \
  --label_smoothing_factor 0.1 \
  --gradient_accumulation_steps 2 \
  --overwrite_cache \
  --fp16 \
  --predict_with_generate
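Once fine-tuning finishes, the resulting checkpoint can be used for inference roughly as follows; the dialogue string and the checkpoint path are placeholders:

from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("kobart-dialogue")
model = BartForConditionalGeneration.from_pretrained("checkpoints/kobart-dialsumm")

dialogue = "P01: 오늘 회의 몇 시에 시작해? P02: 오후 3시에 시작해."  # placeholder input
inputs = tokenizer(dialogue, return_tensors="pt", truncation=True, max_length=512)
summary_ids = model.generate(**inputs, num_beams=4, max_length=128)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))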

Loss Graph

(Loss graph image: training loss during post-training.)

  • Post-training process
    • The loss converges in the 400–600 step range (early stopping was applied in that region); afterwards the loss rises and then converges again.
    • With roughly 280,000 post-training examples, the loss converges near its minimum at about 2 epochs and then rises (presumably overfitting).

Performance

Model                            ROUGE-1   ROUGE-2   ROUGE-L
Naive fine-tuning                31.3613   17.6152   28.2803
Post-training then fine-tuning   32.5910   18.5439   29.4671
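A hedged sketch of how such ROUGE scores can be computed with the rouge_score package; the repository's exact scoring setup may differ, and the default rouge_score tokenizer drops non-ASCII text, so a whitespace tokenizer is substituted here for Korean:

from rouge_score import rouge_scorer

class WhitespaceTokenizer:
    # rouge_score expects an object exposing a tokenize() method.
    def tokenize(self, text):
        return text.split()

scorer = rouge_scorer.RougeScorer(
    ["rouge1", "rouge2", "rougeL"], tokenizer=WhitespaceTokenizer()
)
# Arguments are (reference, prediction); reporting the F-measure is an assumption.
scores = scorer.score("참조 요약 문장입니다.", "모델이 생성한 요약 문장입니다.")
print({name: round(s.fmeasure, 4) for name, s in scores.items()})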

About

Post-training to improve the performance of a dialogue summarization model.
