integration-torchtune #653

Merged: 1 commit merged into main from feat/integration-torchtune2 on Jul 25, 2024

Conversation

@Zeyi-Lin (Member) commented on Jul 25, 2024

Description

Integration with the torchtune fine-tuning framework; successfully tested on fine-tuning the Gemma-2b model.

[image: screenshot of the successful Gemma-2b fine-tuning run]

Usage:

  1. Install torchtune:
pip install torchtune
  2. Download the gemma-2b pretrained model locally:
from modelscope import snapshot_download
model_dir = snapshot_download('AI-ModelScope/gemma-2b', cache_dir="./")
  3. Create 2B_qlora_single_device.yaml (the metric_logger section is the SwanLab integration point; a minimal standalone sketch follows these steps):
# Config for single-device QLoRA finetuning in lora_finetune_single_device.py
# using a gemma 2B model
#
# This config assumes the gemma-2b weights have already been downloaded to
# ./AI-ModelScope/gemma-2b (step 2 above, via modelscope); the upstream
# torchtune alternative is:
#   tune download google/gemma-2b --hf-token <HF_TOKEN> --ignore-patterns ""
#
# To launch on a single device, run the following command from root:
#   tune run lora_finetune_single_device --config gemma/2B_qlora_single_device
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
#   tune run lora_finetune_single_device --config gemma/2B_qlora_single_device checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on single device.

# Tokenizer
tokenizer:
  _component_: torchtune.models.gemma.gemma_tokenizer
  path: ./AI-ModelScope/gemma-2b/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.gemma.qlora_gemma_2b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: True
  lora_rank: 64
  lora_alpha: 16

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: ./AI-ModelScope/gemma-2b/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ./output/gemma-2b
  model_type: GEMMA
resume_from_checkpoint: False

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5

lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Fine-tuning arguments
batch_size: 4
epochs: 3
max_steps_per_epoch: null
gradient_accumulation_steps: 4
compile: False

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: swanlab.integration.torchtune.SwanLabLogger
  project: "gemma-finetune"
  experiment_name: "gemma-2b"
  log_dir: ${output_dir}
output_dir: ./output/alpaca-gemma-lora
log_every_n_steps: 1
log_peak_memory_stats: False

# Showcases the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.utils.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
  4. Start training:
tune run lora_finetune_single_device --config 2B_qlora_single_device.yaml
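
The metric_logger block in the YAML above is the actual SwanLab integration point added by this PR. Below is a minimal standalone sketch of how that logger is driven; it assumes SwanLabLogger follows torchtune's metric-logger interface (log / log_dict / close) and accepts the same keyword arguments shown in the config, and that swanlab itself is installed (pip install swanlab). In a real run the torchtune recipe constructs the logger from the YAML, so this is for illustration only.

from swanlab.integration.torchtune import SwanLabLogger

# Mirrors the metric_logger section of 2B_qlora_single_device.yaml.
logger = SwanLabLogger(
    project="gemma-finetune",
    experiment_name="gemma-2b",
    log_dir="./output/alpaca-gemma-lora",
)

# The recipe pushes scalar metrics roughly like this every log_every_n_steps
# steps (the values here are placeholders, not real training numbers).
logger.log_dict({"loss": 1.234, "lr": 2e-5}, step=1)
logger.close()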

Closes: #617

@Zeyi-Lin requested a review from @SAKURA-CAT on Jul 25, 2024 at 15:06
@Zeyi-Lin self-assigned this on Jul 25, 2024
@SAKURA-CAT merged commit 57e33d6 into main on Jul 25, 2024
@SAKURA-CAT deleted the feat/integration-torchtune2 branch on Jul 25, 2024 at 15:14

Successfully merging this pull request may close these issues.

[REQUEST] Integrate torchtune