Merge pull request NVlabs#88 from XueFuzhao/main
Fix SP Grad Flow
XueFuzhao authored Jun 6, 2024
2 parents 903fd6b + 80ff76e commit 4507918
Showing 35 changed files with 159 additions and 92 deletions.
105 changes: 33 additions & 72 deletions llava/model/llava_arch.py
@@ -620,86 +620,47 @@ def repack_multimodal_data(


# Gather inputs_embeds and reshard it evenly
inputs_embeds_list = [torch.zeros((bs, ulysess_seq_len[i], inputs_embeds.shape[-1]), dtype=inputs_embeds.dtype, device=inputs_embeds.device, requires_grad=True) for i in range(sp_degree)]
dist.all_gather(inputs_embeds_list, inputs_embeds, group=sp_group)
global_inputs_embeds_list = []
for i in range(bs):
    global_inputs_embeds_batch_list = []
    for j in range(sp_degree):
        global_inputs_embeds_batch_list.append(inputs_embeds_list[j][i, :effective_seqlen_batch_list[i][j]])
    global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(global_inputs_embeds_list, batch_first=True, padding_value=0)
new_inputs_embeds = torch.narrow(global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)

# # Gather inputs_embeds and reshard it evenly
# global_inputs_embeds = torch.zeros(
# (bs, global_seq_len, inputs_embeds.shape[-1]),
# dtype=inputs_embeds.dtype,
# device=inputs_embeds.device,
# )
# if sp_rank == 0:
# start_idx = 0
# else:
# start_idx = torch.sum(torch.cat(ulysess_seq_len[:sp_rank], dim=0)).item()
# global_inputs_embeds[:, start_idx:start_idx+shard_seqlen] += inputs_embeds
# dist.all_reduce(global_inputs_embeds, group=sp_group)
# TODO: Fix the case where there are not enough images.
# inputs_embeds_list = [torch.zeros((bs, ulysess_seq_len[i], inputs_embeds.shape[-1]), dtype=inputs_embeds.dtype, device=inputs_embeds.device, requires_grad=True) for i in range(sp_degree)]
# dist.all_gather(inputs_embeds_list, inputs_embeds, group=sp_group)
# global_inputs_embeds_list = []
# for i in range(bs):
# global_inputs_embeds_batch_list = []
# for j in range(sp_degree):
# local_start_idx = shard_seqlen * j
# local_end_idx = local_start_idx + effective_seqlen_batch_list[i][j] if j < sp_degree - 1 else global_seq_len
# global_inputs_embeds_batch_list.append(global_inputs_embeds[i, local_start_idx, :local_end_idx])
# global_inputs_embeds_batch_list.append(inputs_embeds_list[j][i, :effective_seqlen_batch_list[i][j]])
# global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
# global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(global_inputs_embeds_list, batch_first=True, padding_value=0)




# new_inputs_embeds = torch.narrow(global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)



# global_attention_mask = torch.cat(attention_mask_list, dim=1)
# global_attention_mask[:, start_idx:start_idx+shard_seqlen] += attention_mask
# dist.all_reduce(global_attention_mask, group=sp_group)
# new_attention_mask = torch.narrow(global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)



# # Gather attention_mask and reshard it evenly
# global_attention_mask = torch.zeros((bs, global_seq_len), dtype=attention_mask.dtype, device=attention_mask.device)
# global_attention_mask[:, start_idx:start_idx+shard_seqlen] += attention_mask
# dist.all_reduce(global_attention_mask, group=sp_group)
# new_attention_mask = torch.narrow(global_attention_mask, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)


# print(f"reshard: from {start_idx_reshard}, to {end_idx_reshard}. Total {global_seq_len}. Old len {shard_seqlen}")
# Gather position_ids and reshard it evenly
# global_position_ids = torch.zeros((bs, global_seq_len), dtype=position_ids.dtype, device=position_ids.device)
# global_position_ids[:, start_idx:start_idx+shard_seqlen] += position_ids
# dist.all_reduce(global_position_ids, group=sp_group)
# new_position_ids = torch.narrow(global_position_ids, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)



# # Gather inputs_embeds and reshard it evenly
# global_inputs_embeds = torch.zeros(
# (bs, global_seq_len, inputs_embeds.shape[-1]),
# dtype=inputs_embeds.dtype,
# device=inputs_embeds.device,
# )
# global_inputs_embeds[:, start_idx:start_idx+shard_seqlen] += inputs_embeds
# dist.all_reduce(global_inputs_embeds, group=sp_group)
# new_inputs_embeds = torch.narrow(global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)

# # Gather labels and reshard it evenly
# global_labels = torch.zeros((bs, global_seq_len), dtype=labels.dtype, device=labels.device)
# global_labels[:, start_idx:start_idx+shard_seqlen] = labels
# dist.all_reduce(global_labels, group=sp_group)
# dist.barrier(group=sp_group)
# new_labels = torch.narrow(global_labels, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)

# Gather all hidden states and flatten them
ulysess_seq_len_cat = torch.cat(ulysess_seq_len, dim=0)
global_inputs_embeds_list = []
if sp_rank == 0:
    original_start_id = 0
    original_end_id = torch.sum(ulysess_seq_len_cat[:sp_rank+1]).item()
elif sp_rank == sp_degree - 1:
    original_start_id = torch.sum(ulysess_seq_len_cat[:sp_rank]).item()
    original_end_id = torch.sum(ulysess_seq_len_cat[:sp_rank+1]).item()
else:
    original_start_id = torch.sum(ulysess_seq_len_cat[:sp_rank]).item()
    original_end_id = torch.sum(ulysess_seq_len_cat[:sp_rank+1]).item()
all_inputs_embeds = torch.zeros(bs, torch.sum(ulysess_seq_len_cat), inputs_embeds.shape[-1], dtype=inputs_embeds.dtype, device=inputs_embeds.device).contiguous()
all_inputs_embeds[:, original_start_id:original_end_id, :] += inputs_embeds
dist.barrier(group=sp_group)
dist.all_reduce(all_inputs_embeds, group=sp_group)
dist.barrier(group=sp_group)
for i in range(bs):
    global_inputs_embeds_batch_list = []
    for j in range(sp_degree):
        prev_len = torch.sum(ulysess_seq_len_cat[:j]).item() if j > 0 else 0
        start_id = prev_len
        end_id = prev_len + effective_seqlen_batch_list[i][j]
        global_inputs_embeds_batch_list.append(all_inputs_embeds[i, start_id:end_id])
    global_inputs_embeds_list.append(torch.cat(global_inputs_embeds_batch_list, dim=0))
global_inputs_embeds = torch.nn.utils.rnn.pad_sequence(global_inputs_embeds_list, batch_first=True, padding_value=0)
new_inputs_embeds = torch.narrow(global_inputs_embeds, 1, start_idx_reshard, end_idx_reshard - start_idx_reshard)


return (
    None,
    new_position_ids,
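The gist of "Fix SP Grad Flow": torch.distributed.all_gather fills freshly allocated output tensors, so the gathered result has no backward path to inputs_embeds, and setting requires_grad=True on those outputs (as the deleted block did) does not create one. The added block instead accumulates each rank's shard into a zero-initialized global buffer with an in-place add, which autograd does track, and then uses all_reduce to sum in the other ranks' shards; each rank thereby keeps a gradient path to its own slice. A minimal sketch of the pattern, with illustrative names and an equal-shard simplification the real code does not make (it handles the ragged ulysess_seq_len lengths):

import torch
import torch.distributed as dist

def gather_sp_shards(local, sp_rank, sp_degree, sp_group):
    # local: (bs, shard_len, hidden), produced by differentiable ops upstream
    bs, shard_len, hidden = local.shape
    buf = torch.zeros(bs, shard_len * sp_degree, hidden,
                      dtype=local.dtype, device=local.device)
    start = sp_rank * shard_len
    buf[:, start:start + shard_len] += local  # local write stays on the autograd tape
    dist.all_reduce(buf, group=sp_group)      # the sum fills in the other ranks' shards
    return buf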
3 changes: 0 additions & 3 deletions llava/train/utils.py
@@ -199,15 +199,12 @@ def reshard_hiddne_states_and_labels(hidden_states, labels):
dist.barrier(group=sp_group)
dist.all_reduce(all_hidden_states, group=sp_group)
dist.barrier(group=sp_group)
# flatten_global_hidden_states = torch.cat(all_hidden_states, dim=1)[:, :-1, :].view(-1, hidden_states.shape[-1])
# flatten_global_hidden_states = torch.cat(all_hidden_states, dim=1)[:, :-1, :].contiguous().view(-1, hidden_states.shape[-1])
flatten_global_hidden_states = all_hidden_states[:, :-1, :].contiguous().view(-1, hidden_states.shape[-1])
# Get the local hidden states
effective_flatten_global_hidden_states = flatten_global_hidden_states[flatten_effective_label_index]
if repeat_num > 1:
    effective_flatten_global_hidden_states = effective_flatten_global_hidden_states.repeat(repeat_num, 1)
effective_local_hidden_states = torch.narrow(effective_flatten_global_hidden_states, 0, start_id, end_id - start_id)
# print("hidden_states.shape: ", effective_local_hidden_states.shape, "labels.shape: ", effective_local_labels.shape, "sp_rank: ", sp_rank, "flatten_label_mask: ", torch.sum(flatten_label_mask),)

return effective_local_hidden_states, effective_local_labels

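Note that the surviving line differs from the two commented-out variants chiefly by calling .contiguous() before .view: slicing off the last position with [:, :-1, :] (the usual next-token shift) leaves a non-contiguous tensor, and .view refuses to reshape those. A standalone check with illustrative shapes:

import torch

x = torch.randn(2, 5, 4)
y = x[:, :-1, :]                # dim-0 stride still spans 5 rows -> non-contiguous
try:
    y.view(-1, 4)               # RuntimeError: view size is not compatible ...
except RuntimeError as err:
    print("view fails:", err)
z = y.contiguous().view(-1, 4)  # copy into contiguous memory first
print(z.shape)                  # torch.Size([8, 4])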
8 changes: 2 additions & 6 deletions scripts/batch_launch_train_siglipso400m.sh
@@ -1,20 +1,16 @@
#!/bin/bash
#SBATCH --job-name=vila-7b-internvid10m-pretraining:nvr_lpr_aiagent
#SBATCH --nodes=32
#SBATCH --nodes=16
#SBATCH --gres=gpu:8
#SBATCH --time=4:00:00
#SBATCH -A nvr_lpr_aiagent
<<<<<<< HEAD
#SBATCH --partition=grizzly,polar,polar2,polar3,polar4
=======
#SBATCH --partition=grizzly,polar,,polar2,polar3,polar4
>>>>>>> 4f906a9425c2318c89084b2a7eda650bc32f05ac
#SBATCH --exclusive
#SBATCH --dependency=singleton
#SBATCH --output=7b-internvid-10m-ego1m-training.out


# srun --label bash ~/workspace/VILA-Internal/scripts/v1_5/video/1_train_projector_vicuna13b_siglipso400m.sh
srun --label bash ~/workspace/VILA-Internal/scripts/v1_5/video/2_train_mmc4_coyo_sharegpt4v_internvid10m_ego4d1m_vicuna_siglipso400m_video.sh
srun --label bash ~/workspace/VILA-Internal/scripts/v1_5/video/1_mm_align.sh
# srun --label bash ~/workspace/VILA-Internal/scripts/v1_5/video/3_sft_videov2_siglipso400m_2.sh

4 changes: 2 additions & 2 deletions scripts/batch_launch_train_siglipso400m_test.sh
Original file line number Diff line number Diff line change
@@ -4,13 +4,13 @@
#SBATCH --gres=gpu:8
#SBATCH --time=1:00:00
#SBATCH -A nvr_lpr_aiagent
#SBATCH --partition=grizzly,polar,polar2,polar3,polar4
#SBATCH --partition=interactive,grizzly,polar,polar2,polar3,polar4
#SBATCH --exclusive
#SBATCH --dependency=singleton
#SBATCH --output=training-test-3-01.out


srun --label ~/workspace/VILA-Internal/scripts/v1_5/video/1_train_projector_vilavideo7b_v03_test.sh
srun --label bash ~/workspace/VILA-Internal/scripts/v1_5/video/1_mm_align_test.sh
# srun --label bash ~/workspace/VILA-Internal/scripts/v1_5/video/1_train_projector_vicuna13b_siglipso400m_video_test.sh


56 changes: 56 additions & 0 deletions scripts/v1_5/video/1_mm_align.sh
@@ -0,0 +1,56 @@
#!/bin/bash
source /lustre/fsw/portfolios/nvr/users/${USER}/anaconda3/bin/activate
conda init
source ~/.bashrc
conda activate vila
which python

cd ~/workspace/VILA-Internal

master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR="$MASTER_ADDR

n_node=$SLURM_JOB_NUM_NODES
bs=$((128 * 4 / n_node))
echo "number of nodes:" $n_node
echo "per device batch size:" $bs
echo "node rank:" $SLURM_PROCID

torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
--master_addr $MASTER_ADDR --node_rank=$SLURM_PROCID \
llava/train/train_hybrid.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path /home/ligengz/downloads/Meta-Llama-3-8B-Instruct \
--version plain \
--data_mixture llava_1_5_mm_align \
--vision_tower google/siglip-so400m-patch14-384 \
--mm_vision_select_feature cls_patch \
--mm_projector mlp_downsample \
--tune_mm_projector True \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio resize \
--bf16 True \
--output_dir ./checkpoints/vilavideo-llama3-8b-spdebug \
--num_train_epochs 1 \
--per_device_train_batch_size $bs \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 20000 \
--save_total_limit 1 \
--learning_rate 1e-3 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 4096 \
--gradient_checkpointing True \
--dataloader_num_workers 8 \
--lazy_preprocess True \
--report_to wandb \
--seq_parallel_size 4
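A side note on the batch arithmetic in these launch scripts: ranks inside a sequence-parallel group consume the same samples, so the effective data-parallel world is the total GPU count divided by seq_parallel_size, and bs=$((128 * 4 / n_node)) appears chosen to hold the global batch constant across node counts. A quick sanity check, assuming 8 GPUs per node and gradient_accumulation_steps 1 (both true here):

def global_batch(base, sp, n_node, gpus_per_node=8):
    bs = base * sp // n_node                 # per-device batch, as in the script
    dp_world = n_node * gpus_per_node // sp  # independent data-parallel replicas
    return bs * dp_world

assert global_batch(base=128, sp=4, n_node=16) == 1024
assert global_batch(base=128, sp=4, n_node=32) == 1024  # invariant to node count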
56 changes: 56 additions & 0 deletions scripts/v1_5/video/1_mm_align_test.sh
@@ -0,0 +1,56 @@
#!/bin/bash
source /lustre/fsw/portfolios/nvr/users/${USER}/anaconda3/bin/activate
conda init
source ~/.bashrc
conda activate vila
which python

cd ~/workspace/VILA-Internal

master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR=$master_addr
echo "MASTER_ADDR="$MASTER_ADDR

n_node=$SLURM_JOB_NUM_NODES
bs=$((4 * 4 / n_node))
echo "number of nodes:" $n_node
echo "per device batch size:" $bs
echo "node rank:" $SLURM_PROCID

torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
--master_addr $MASTER_ADDR --node_rank=$SLURM_PROCID \
llava/train/train_hybrid.py \
--deepspeed ./scripts/zero3.json \
--model_name_or_path /home/ligengz/downloads/Meta-Llama-3-8B-Instruct \
--version plain \
--data_mixture llava_1_5_mm_align \
--vision_tower google/siglip-so400m-patch14-384 \
--mm_vision_select_feature cls_patch \
--mm_projector mlp_downsample \
--tune_mm_projector True \
--mm_vision_select_layer -2 \
--mm_use_im_start_end False \
--mm_use_im_patch_token False \
--image_aspect_ratio resize \
--bf16 True \
--output_dir ./checkpoints/vilavideo-llama3-8b-spdebug-test \
--num_train_epochs 1 \
--per_device_train_batch_size $bs \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 20000 \
--save_total_limit 1 \
--learning_rate 1e-3 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 True \
--model_max_length 4096 \
--gradient_checkpointing True \
--dataloader_num_workers 8 \
--lazy_preprocess True \
--report_to wandb \
--seq_parallel_size 4
@@ -19,7 +19,7 @@ echo "node rank:" $SLURM_PROCID

torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=25001 \
--master_addr $MASTER_ADDR --node_rank=$SLURM_PROCID \
llava/train/train_mem.py \
llava/train/train_hybrid.py \
--deepspeed ./scripts/zero2.json \
--model_name_or_path /home/jasonlu/models/vicuna-1.5/vicuna-7b-v1.5 \
--version plain \
File renamed without changes.
File renamed without changes.
7 changes: 4 additions & 3 deletions scripts/v1_5/video_osmo/1_vilavideo_aligh_test_v3.sh
@@ -10,7 +10,8 @@ echo "MASTER_ADDR="$MASTER_ADDR

n_node=$WORLD_SIZE
seq_parallel_size=8
bs=$((4 * seq_parallel_size / n_node))
# bs=$((128 * seq_parallel_size / n_node))
bs=2
echo "number of nodes:" $n_node
echo "per device batch size:" $bs
echo "node rank:" $NODE_RANK
@@ -32,11 +33,11 @@ torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=$MASTER_PORT \
--mm_use_im_patch_token False \
--image_aspect_ratio resize \
--bf16 True \
--output_dir ./checkpoints/vilavideo8b_align_test_v3_ga8_sp8_2nodes_loss_debug \
--output_dir ./checkpoints/vilavideo8b_align_test_v3_ga8_sp8_2nodes_loss_debug_v3 \
--num_train_epochs 1 \
--per_device_train_batch_size $bs \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 8 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 200000 \
10 changes: 5 additions & 5 deletions scripts/v1_5/video_osmo/1_vilavideo_align_v2.sh
@@ -9,16 +9,16 @@ cd ~/VILA
echo "MASTER_ADDR="$MASTER_ADDR

n_node=$WORLD_SIZE
seq_parallel_size=8
bs=$((128 * seq_parallel_size / n_node))
seq_parallel_size=4
bs=$((512 * seq_parallel_size / n_node))
echo "number of nodes:" $n_node
echo "per device batch size:" $bs
echo "node rank:" $NODE_RANK

torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=$MASTER_PORT \
--master_addr $MASTER_ADDR --node_rank=$NODE_RANK \
llava/train/train_hybrid.py \
--deepspeed ./scripts/zero3_mics_mini.json \
--deepspeed ./scripts/zero3_mics_mini_fixed.json \
--model_name_or_path /mnt/amlfs-01/home/fuzhaox/checkpoints/Meta-Llama-3-8B-Instruct \
--version llama_3 \
--data_mixture osmo_ccs_recaptioned+osmo_internvid_1300K \
@@ -32,11 +32,11 @@ torchrun --nnodes=$n_node --nproc_per_node=8 --master_port=$MASTER_PORT \
--mm_use_im_patch_token False \
--image_aspect_ratio resize \
--bf16 True \
--output_dir ./checkpoints/vilavideo8b_align_v041 \
--output_dir ./checkpoints/vilavideo8b_align_v042 \
--num_train_epochs 1 \
--per_device_train_batch_size $bs \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--gradient_accumulation_steps 1 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 200 \
