Sample pack trust remote code v2 (#1873)
* fix the multipack patch for remote code models

* add DeepSeek-V2-Lite example with FSDP
winglian authored Aug 27, 2024
1 parent f6362d2 commit 1e43660
Showing 3 changed files with 69 additions and 2 deletions.
67 changes: 67 additions & 0 deletions examples/deepseek-v2/fft-fsdp-16b.yaml
@@ -0,0 +1,67 @@
base_model: deepseek-ai/DeepSeek-V2-Lite
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 8
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-5

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 100
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
special_tokens:
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_sharding_strategy: FULL_SHARD
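
For orientation, a minimal sketch of what fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP combined with fsdp_transformer_layer_cls_to_wrap: DeepseekV2DecoderLayer roughly maps to in raw PyTorch FSDP. The name-based class lookup below is an assumption for illustration: DeepseekV2DecoderLayer is defined in the model's remote code (hence trust_remote_code: true), not in transformers itself, so it cannot be imported statically.

import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def transformer_wrap_policy(model: nn.Module):
    # Collect the decoder-layer class by name from the loaded remote-code model.
    layer_classes = {
        type(module)
        for module in model.modules()
        if type(module).__name__ == "DeepseekV2DecoderLayer"
    }
    # Each matched layer becomes its own FSDP unit, matching the
    # full_shard + auto_wrap settings in the config above.
    return functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls=layer_classes
    )


# sharded_model = FSDP(model, auto_wrap_policy=transformer_wrap_policy(model))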
2 changes: 2 additions & 0 deletions src/axolotl/monkeypatch/multipack.py
@@ -94,3 +94,5 @@ def patch_remote(model_name, config_name, modeling_name):
    module_name = model_config.__class__.__module__.replace(config_name, modeling_name)
    modeling_arch = importlib.import_module(module_name)
    modeling_arch._get_unpad_data = get_unpad_data  # pylint: disable=protected-access
    # workaround to make the patch stick
    modeling_arch._axolotl_multipack_patch = True  # pylint: disable=protected-access
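
A hedged reading of the new marker attribute: because the remote-code modeling module is imported dynamically, later code can check the marker and skip re-importing and re-patching, which would otherwise clobber the replaced _get_unpad_data. A minimal sketch, using a hypothetical ensure_multipack_patched helper that is not part of this commit:

import importlib


def ensure_multipack_patched(module_name: str, get_unpad_data) -> None:
    # Hypothetical helper: re-import the remote-code modeling module and
    # patch it only if the axolotl marker is not already present.
    modeling_arch = importlib.import_module(module_name)
    if getattr(modeling_arch, "_axolotl_multipack_patch", False):
        return  # already patched; leave the module untouched
    modeling_arch._get_unpad_data = get_unpad_data
    modeling_arch._axolotl_multipack_patch = True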
2 changes: 0 additions & 2 deletions src/axolotl/monkeypatch/utils.py
@@ -17,11 +17,9 @@ def get_max_seqlen_in_batch(attention_mask: torch.Tensor) -> torch.Tensor:
    max_num = int(torch.max(attention_mask).item())
    batch_size, _ = attention_mask.shape
    counts = torch.zeros((batch_size, max_num), dtype=torch.int32)

    for i in range(1, max_num + 1):
        mask = attention_mask == i
        counts[:, i - 1] = torch.sum(mask, dim=-1).to(dtype=torch.int32)

    result = counts.flatten()
    nonzero_indices = torch.nonzero(result).squeeze(-1)
    return result[nonzero_indices]
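
For context, get_max_seqlen_in_batch extracts per-sequence lengths from a sample-packed attention mask, where each row labels its packed sequences 1, 2, ... and 0 marks padding. A small worked example with illustrative values:

import torch

attention_mask = torch.tensor(
    [
        [1, 1, 1, 2, 2, 0],  # row 0 packs sequences of lengths 3 and 2
        [1, 1, 2, 2, 2, 2],  # row 1 packs sequences of lengths 2 and 4
    ]
)
print(get_max_seqlen_in_batch(attention_mask))
# tensor([3, 2, 2, 4], dtype=torch.int32)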
