NCCL only multi-gpu multi-node training without MPI #426

Open · chinthysl wants to merge 1 commit into master from the multinode-slurm branch
Conversation

@chinthysl (Contributor) commented May 17, 2024

Scheduling jobs using Slurm seems much easier in a multi-node training setup compared to setting up MPI for the cluster.
This draft contains the changes to use mpirun for single-node training and Slurm for multi-node training.

PyTorch uses one of Gloo, MPI, or NCCL as its DDP backend. Maybe we don't need to use both MPI and NCCL together; it should be either CUDA-aware MPI or NCCL. Something to discuss further.

I got some interesting performance numbers in a large-scale training setup using llm.c.
I used the Ahrefs DGX H100 (80GB) Superpod. The cluster uses NVLINK for intra-node communications and InfiniBand RDMA for inter-node communications.

Numbers were taken just before #421 from @ngc92 landed; the current master should have higher tokens/sec with lower latency.

Without cuDNN we can reach up to batch_size=12:

| NumNodes | NumProcess | Latency (ms) | tokens/sec (k) |
|---------:|-----------:|-------------:|---------------:|
| 1        | 8          | 395          | 249            |
| 2        | 16         | 674          | 296            |
| 4        | 32         | 817          | 484            |
| 8        | 64         | 1056         | 748            |
| 16       | 128        | 1087         | 1455           |
| 24       | 192        | 1320         | 1794           |
| 48       | 384        | 2054         | 2303           |

[chart: oreo-multinode-llmc]

With cuDNN we can reach up to batch_size=24:

| NumNodes | NumProcess | Latency (ms) | tokens/sec (k) |
|---------:|-----------:|-------------:|---------------:|
| 1        | 8          | 539          | 366            |
| 2        | 16         | 798          | 498            |
| 4        | 32         | 991          | 798            |
| 8        | 64         | 1158         | 1365           |
| 16       | 128        | 1247         | 2531           |
| 24       | 192        | 1478         | 3198           |

[chart: oreo-multinode-llmc-cudnn]

@karpathy (Owner)

Very cool! I'll take a look and also see if I can find a slurm cluster to play with this on.
Do you by any chance have a PyTorch baseline at this scale?
Are you careful to use gradient accumulation? Example command I've been using for gpt2-xl on my box so far:

mpirun -np 4 ./train_gpt2cu -e gpt2_1558M_bf16.bin -i data/TinyStories -v 250 -s 250 -g 144 -o stories.log -z 1 -b 12 -d 294912 -l 1e-4

Notice in particular the use of -d, which specifies the total batch size and which we'd want to be at about 0.5M, following the GPT-3 paper. The above example has it at ~300K.
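To make the arithmetic concrete, here is how -b, -d and the number of processes determine gradient accumulation (a small illustrative calculation, assuming the default sequence length T=1024 since -t isn't passed in the command above):

```c
#include <stdio.h>

// Illustrative only: derive grad_accum_steps from the example mpirun command.
int main(void) {
    int B = 12;                      // -b 12, micro batch size per GPU
    int T = 1024;                    // assumed default sequence length (-t not given)
    int num_processes = 4;           // mpirun -np 4
    int total_batch_size = 294912;   // -d 294912, ~300K tokens per optimizer step
    int tokens_per_micro_step = B * T * num_processes;                // 49152
    int grad_accum_steps = total_batch_size / tokens_per_micro_step;  // 6
    printf("grad_accum_steps = %d\n", grad_accum_steps);
    return 0;
}
```

So each optimizer step accumulates gradients over 6 forward/backward micro-steps; a larger -d raises grad_accum_steps accordingly.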

@chinthysl (Contributor, Author)

Thanks for the suggestions. I haven't pulled in the gradient accumulation changes yet and am waiting for the updated numbers.
I'm also planning to get multi-node numbers from PyTorch training, so hopefully we'll have a better perf comparison soon.

@chinthysl (Contributor, Author) commented May 20, 2024

@karpathy I feel like the total number of tokens should be more than 0.5M. GPT-2 claims to have used 256 GPUs with context length 1024.
One sequence per GPU is already 256K total tokens, so the batch size per worker becomes 2. If that is the case, the actual GPT-2 training might not have fully used the 32GB memory of the V100 GPUs. Maybe (1024*2) tokens were enough to fully utilize the SMs of each GPU, so OpenAI increased the number of GPUs to gain more FLOPS!

@PeterZhizhin (Contributor)

I thought OpenMPI supports Slurm: https://docs.open-mpi.org/en/main/launching-apps/slurm.html

Can you give some insight on why it didn't work for you?

@chinthysl (Contributor, Author) commented May 22, 2024

@PeterZhizhin These are the issues I'm facing when I try to dispatch MPI-dependent jobs using Slurm.

  • Can't use MPI's full-featured mpirun launcher (srun --ntasks=N --ntasks-per-node=8 --nodelist=node[000-00X] mpirun ./train):
    PMIx is not installed, and MPI was not built with --with-pmix support.

  • Can't use Slurm's "direct launch" capability (srun --ntasks=N --ntasks-per-node=8 --nodelist=node[000-00X] --mpi=pmix ./train):
    PMIx is not installed, and neither MPI nor Slurm was built with --with-pmix support.

Generally, in a cluster setup, Slurm is used as nothing more than a tool to dispatch processes across nodes; PMIx support is an extra on top of that, afaiu.

In particular, when we look at torchrun, Lightning, and Accelerate, one main process gets launched per node, and that main process uses Python multiprocessing to spawn one child process per GPU. Afaiu these frameworks run without any job-scheduler dependency.
Instead, they configure one node as a server (TCP socket) via MASTER_ADDR and MASTER_PORT, and WORLD_SIZE and RANK are set per process. NCCL, MPI, or Gloo then acts only as the data communication backend.
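In the same spirit, an NCCL-only setup can do this rendezvous with Slurm environment variables plus a file on a shared filesystem instead of a TCP store. A minimal sketch of the idea (hedged: this is not the PR's exact code; error handling is mostly omitted, polling the id file is a simplification, and the dfs_path argument only mirrors the shared-filesystem path that multi_gpu_config_init takes):

```c
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <cuda_runtime.h>
#include <nccl.h>

// Rank 0 writes the ncclUniqueId to a file visible from every node; other ranks
// poll for it, then everyone joins the communicator with ncclCommInitRank.
ncclComm_t init_nccl_from_slurm(const char* dfs_path) {
    int rank  = atoi(getenv("SLURM_PROCID"));   // global process rank
    int world = atoi(getenv("SLURM_NTASKS"));   // total number of processes
    int local = atoi(getenv("SLURM_LOCALID"));  // GPU index on this node
    cudaSetDevice(local);

    char id_file[1024];
    snprintf(id_file, sizeof(id_file), "%s/nccl_unique_id.bin", dfs_path);

    ncclUniqueId id;
    if (rank == 0) {
        ncclGetUniqueId(&id);
        FILE* f = fopen(id_file, "wb");
        fwrite(&id, sizeof(id), 1, f);
        fclose(f);
    } else {
        FILE* f;
        while ((f = fopen(id_file, "rb")) == NULL) { usleep(100000); }  // wait for rank 0
        while (fread(&id, sizeof(id), 1, f) != 1) { rewind(f); usleep(100000); }
        fclose(f);
    }
    ncclComm_t comm;
    ncclCommInitRank(&comm, world, id, rank);
    return comm;
}
```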

Ideally we should remove the MPI and Slurm dependencies altogether and implement a similar socket-based (or other) server-client interface to synchronize the process group independently. That would also let us add features like fault tolerance (currently, if one GPU or node goes down, training hangs forever).

But since we are not there yet, we need either MPI or Slurm as a dependency.
Therefore I thought of letting MPI handle single-node training and Slurm handle multi-node.

Open to discussion.

@PeterZhizhin (Contributor)

@chinthysl thank you! Something we could also do, instead of relying on MPI for the single-node case, is to just spawn processes via fork(). Then, for Slurm, we could ask it to launch one process per node.
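To make that idea concrete, a rough sketch of a fork()-based single-node launcher (hedged: llm.c does not contain this code; the -pn/-pr flags are the ones this PR adds, everything else is illustrative and assumes 8 GPUs on the node):

```c
#include <stdio.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

// The parent fork()s one worker per GPU and passes rank/world size via the PR's
// -pn/-pr flags, so no MPI launcher is needed for single-node runs.
int main(void) {
    int num_gpus = 8;  // assumed number of GPUs on this node
    char world_str[8];
    snprintf(world_str, sizeof(world_str), "%d", num_gpus);
    for (int rank = 0; rank < num_gpus; rank++) {
        pid_t pid = fork();
        if (pid == 0) {  // child: exec one training process for GPU `rank`
            char rank_str[8];
            snprintf(rank_str, sizeof(rank_str), "%d", rank);
            execlp("./train_gpt2cu", "./train_gpt2cu",
                   "-pn", world_str, "-pr", rank_str, (char*)NULL);
            perror("execlp");
            _exit(1);
        }
    }
    while (wait(NULL) > 0) {}  // parent waits for all workers
    return 0;
}
```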

train_gpt2.cu Outdated

if (multi_gpu_config->slurm_managed) { //If the process is managed by slurm we don't need to use MPI
if (d_buffer == NULL) cudaCheck(cudaMalloc(&d_buffer, sizeof(float)));
cudaCheck(cudaMemcpy(d_buffer, &value, sizeof(float), cudaMemcpyHostToDevice));
Contributor:

The reason why I was asking if we could use MPI on Slurm clusters is because of this.

This feels a bit ugly, while a single MPI_AllReduce is a lot cleaner.
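For comparison, the MPI path being referred to is roughly a single in-place all-reduce on a host scalar, with no device staging buffer (a hedged sketch, not llm.c's exact code):

```c
#include <mpi.h>

// Average a per-process scalar loss across all ranks using MPI only.
float multi_gpu_mean_loss_mpi(float local_loss, int num_processes) {
    float sum = local_loss;
    MPI_Allreduce(MPI_IN_PLACE, &sum, 1, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
    return sum / num_processes;
}
```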

Contributor (Author):

I agree, this one doesn't look great. I'm also thinking about alternatives.

@chinthysl (Contributor, Author) commented Jun 3, 2024:

@karpathy @PeterZhizhin I introduced a CUDA unified memory buffer with cudaMemPrefetchAsync, but additional host-to-device and device-to-host memcopies still happen internally, so MPI is cleaner in that sense.
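Roughly, the unified-memory variant looks like this (a hedged sketch based on the snippets in this PR; it assumes the surrounding cudaCheck/ncclCheck helpers and MultiGpuConfig fields such as nccl_comm, which are not all shown here):

```c
// Average a per-process scalar loss with NCCL only, staged in a cudaMallocManaged buffer.
float multi_gpu_mean_loss_nccl(float local_loss, MultiGpuConfig* cfg, float* unified_buffer) {
    *unified_buffer = local_loss;  // host write into managed memory
    cudaCheck(cudaMemPrefetchAsync(unified_buffer, sizeof(float), cfg->device_idx, 0));
    ncclCheck(ncclAllReduce(unified_buffer, unified_buffer, 1, ncclFloat, ncclSum,
                            cfg->nccl_comm, /*stream=*/0));
    cudaCheck(cudaDeviceSynchronize());  // make the reduced value visible to the host again
    return *unified_buffer / cfg->num_processes;
}
```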

On the other hand, I'm thinking about a cleaner way to get rid of the MPI and Slurm dependencies by introducing an independent server-client interface to manage the distributed setup.

train_gpt2.cu (outdated review thread, resolved)
@chinthysl force-pushed the multinode-slurm branch 2 times, most recently from d5e4477 to c00738c on June 3, 2024 at 11:00.
@ghost commented Jun 5, 2024

Currently doing multi-node testing on a Slurm cluster using FineWeb 10B for now; training is being done on two gpu2 nodes (2x A100 80GB per node). I am getting the following error, which stops the job:

*** The MPI_Barrier() function was called before MPI_INIT was invoked.
*** This is disallowed by the MPI standard.
*** Your MPI job will now abort.

A complete log file for a job can be found below.

[alita@bee ~]$ cat fail.out
Allocated nodes: 
qbd487
qbd488
srun: lua: Submitted job 12165
Started Multi Node training managed by Slurm: ProcessID:3, NumProcess::4, DeviceId:1
Started Multi Node training managed by Slurm: ProcessID:2, NumProcess::4, DeviceId:0
Started Multi Node training managed by Slurm: ProcessID:1, NumProcess::4, DeviceId:1
Started Multi Node training managed by Slurm: ProcessID:0, NumProcess::4, DeviceId:0
+-----------------------+----------------------------------------------------+
| Parameter             | Value                                              |
+-----------------------+----------------------------------------------------+
| train data pattern    | /home/alita/work/10b/fineweb_train_*.bin           |
| val data pattern      | /home/alita/work/10b/fineweb_val_*.bin             |
| output log dir        | /work/alita/gpu2/                                  |
| checkpoint_every      | 10                                                 |
| resume                | 0                                                  |
| micro batch size B    | 64                                                 |
| sequence length T     | 1024                                               |
| total batch size      | 1048576                                            |
| learning rate (LR)    | 6.000000e-04                                       |
| warmup iterations     | 700                                                |
| final LR fraction     | 0.000000e+00                                       |
| weight decay          | 1.000000e-01                                       |
| max_steps             | -1                                                 |
| val_loss_every        | 250                                                |
| val_max_steps         | 20                                                 |
| sample_every          | 20000                                              |
| genT                  | 64                                                 |
| overfit_single_batch  | 0                                                  |
| use_master_weights    | enabled                                            |
| recompute             | 1                                                  |
+-----------------------+----------------------------------------------------+
| device                | NVIDIA A100 80GB PCIe                              |
| TFlops                | -1.0                                               |
| precision             | BF16                                               |
+-----------------------+----------------------------------------------------+
| load_filename         | d12                                                |
| max_sequence_length T | 1024                                               |
| vocab_size V          | 50257                                              |
| padded_vocab_size Vp  | 50304                                              |
| num_layers L          | 12                                                 |
| num_heads NH          | 12                                                 |
| channels C            | 768                                                |
| num_parameters        | 124475904                                          |
+-----------------------+----------------------------------------------------+
| train_num_batches     | 9780                                               |
| val_num_batches       | 20                                                 |
+-----------------------+----------------------------------------------------+
| run hellaswag         | yes                                                |
+-----------------------+----------------------------------------------------+
| Zero Stage1 is enabled                                                     |
| num_processes         | 4                                                  |
| zero_stage            | 1                                                  |
+-----------------------+----------------------------------------------------+
num_parameters: 124475904 => bytes: 248951808
allocated 237 MiB for model parameters
batch_size B=64 * seq_len T=1024 * num_processes=4 and total_batch_size=1048576
=> setting grad_accum_steps=4
allocating 23034 MiB for activations
val loss 2.752490
allocating 237 MiB for parameter gradients
allocating 480 MiB for activation gradients
allocating 118 MiB for AdamW optimizer state m
allocating 118 MiB for AdamW optimizer state v
allocating 118 MiB for master copy of params
step    1/9780 | train loss 2.752462 | norm 1.6601 | lr 8.57e-07 | 1857.39 ms | -100.0% bf16 MFU | 564544 tok/s
step    2/9780 | train loss 2.739455 | norm 1.6413 | lr 1.71e-06 | 1364.95 ms | -100.0% bf16 MFU | 768215 tok/s
step    3/9780 | train loss 2.714078 | norm 1.6619 | lr 2.57e-06 | 1362.65 ms | -100.0% bf16 MFU | 768881 tok/s
step    4/9780 | train loss 2.678862 | norm 1.7498 | lr 3.43e-06 | 1361.10 ms | -100.0% bf16 MFU | 769409 tok/s
step    5/9780 | train loss 2.642340 | norm 1.7980 | lr 4.29e-06 | 1362.95 ms | -100.0% bf16 MFU | 769391 tok/s
step    6/9780 | train loss 2.606456 | norm 1.8861 | lr 5.14e-06 | 1362.69 ms | -100.0% bf16 MFU | 769412 tok/s
step    7/9780 | train loss 2.576031 | norm 1.9981 | lr 6.00e-06 | 1360.37 ms | -100.0% bf16 MFU | 769674 tok/s
step    8/9780 | train loss 2.547013 | norm 2.0359 | lr 6.86e-06 | 1360.54 ms | -100.0% bf16 MFU | 769845 tok/s
step    9/9780 | train loss 2.519169 | norm 2.0552 | lr 7.71e-06 | 1360.91 ms | -100.0% bf16 MFU | 769941 tok/s
step   10/9780 | train loss 2.500124 | norm 2.0364 | lr 8.57e-06 | 1361.36 ms | -100.0% bf16 MFU | 769982 tok/s
Writing model to /work/alita/gpu2//model_00000010.bin
Writing state to /work/alita/gpu2//state_00000010_00001.bin
Writing state to /work/alita/gpu2//state_00000010_00003.bin
Writing state to /work/alita/gpu2//state_00000010_00002.bin
*** The MPI_Barrier() function was called before MPI_INIT was invoked.
*** This is disallowed by the MPI standard.
*** Your MPI job will now abort.
[qbd487:398785] Local abort before MPI_INIT completed completed successfully, but am not able to aggregate error messages, and not able to guarantee that all other processes were killed!
*** The MPI_Barrier() function was called before MPI_INIT was invoked.
*** This is disallowed by the MPI standard.
*** Your MPI job will now abort.
*** The MPI_Barrier() function was called before MPI_INIT was invoked.
*** This is disallowed by the MPI standard.
*** Your MPI job will now abort.
[qbd488:615259] Local abort before MPI_INIT completed completed successfully, but am not able to aggregate error messages, and not able to guarantee that all other processes were killed!
[qbd488:615260] Local abort before MPI_INIT completed completed successfully, but am not able to aggregate error messages, and not able to guarantee that all other processes were killed!
Writing state to /work/alita/gpu2//state_00000010_00000.bin
*** The MPI_Barrier() function was called before MPI_INIT was invoked.
*** This is disallowed by the MPI standard.
*** Your MPI job will now abort.
[qbd487:398786] Local abort before MPI_INIT completed completed successfully, but am not able to aggregate error messages, and not able to guarantee that all other processes were killed!
srun: error: qbd487: task 1: Exited with exit code 1
srun: error: qbd488: tasks 2-3: Exited with exit code 1
srun: error: qbd487: task 0: Exited with exit code 1

I added mpiCheck(MPI_Init(argc, argv)); at line 376 of train_gpt2.cu and rebuilt the training binary, which seems to fix things!

[alita@bee ~]$ cat gpt2-gpu2.out 
Allocated nodes: 
qbd487
qbd488
srun: lua: Submitted job 12168
[qbd487:400139] PMIX ERROR: ERROR in file ../../../../../../src/mca/gds/ds12/gds_ds12_lock_pthread.c at line 169
[qbd488:616078] PMIX ERROR: ERROR in file ../../../../../../src/mca/gds/ds12/gds_ds12_lock_pthread.c at line 169
[qbd488:616077] PMIX ERROR: ERROR in file ../../../../../../src/mca/gds/ds12/gds_ds12_lock_pthread.c at line 169
[qbd487:400138] PMIX ERROR: ERROR in file ../../../../../../src/mca/gds/ds12/gds_ds12_lock_pthread.c at line 169
Started Multi Node training managed by Slurm: ProcessID:0, NumProcess::4, DeviceId:0
+-----------------------+----------------------------------------------------+
| Parameter             | Value                                              |
+-----------------------+----------------------------------------------------+
| train data pattern    | /home/alita/work/10b/fineweb_train_*.bin           |
| val data pattern      | /home/alita/work/10b/fineweb_val_*.bin             |
| output log dir        | /work/alita/gpu2/                                  |
| checkpoint_every      | 10                                                 |
| resume                | 0                                                  |
| micro batch size B    | 64                                                 |
| sequence length T     | 1024                                               |
| total batch size      | 1048576                                            |
| learning rate (LR)    | 6.000000e-04                                       |
| warmup iterations     | 700                                                |
| final LR fraction     | 0.000000e+00                                       |
| weight decay          | 1.000000e-01                                       |
| max_steps             | -1                                                 |
| val_loss_every        | 250                                                |
| val_max_steps         | 20                                                 |
| sample_every          | 20000                                              |
| genT                  | 64                                                 |
| overfit_single_batch  | 0                                                  |
| use_master_weights    | enabled                                            |
| recompute             | 1                                                  |
+-----------------------+----------------------------------------------------+
| device                | NVIDIA A100 80GB PCIe                              |
| TFlops                | -1.0                                               |
| precision             | BF16                                               |
+-----------------------+----------------------------------------------------+
| load_filename         | d12                                                |
| max_sequence_length T | 1024                                               |
| vocab_size V          | 50257                                              |
| padded_vocab_size Vp  | 50304                                              |
| num_layers L          | 12                                                 |
| num_heads NH          | 12                                                 |
| channels C            | 768                                                |
| num_parameters        | 124475904                                          |
+-----------------------+----------------------------------------------------+
| train_num_batches     | 9780                                               |
| val_num_batches       | 20                                                 |
+-----------------------+----------------------------------------------------+
| run hellaswag         | yes                                                |
+-----------------------+----------------------------------------------------+
| Zero Stage1 is enabled                                                     |
| num_processes         | 4                                                  |
| zero_stage            | 1                                                  |
+-----------------------+----------------------------------------------------+
num_parameters: 124475904 => bytes: 248951808
allocated 237 MiB for model parameters
batch_size B=64 * seq_len T=1024 * num_processes=4 and total_batch_size=1048576
=> setting grad_accum_steps=4
allocating 23034 MiB for activations
Started Multi Node training managed by Slurm: ProcessID:2, NumProcess::4, DeviceId:0
Started Multi Node training managed by Slurm: ProcessID:3, NumProcess::4, DeviceId:1
Started Multi Node training managed by Slurm: ProcessID:1, NumProcess::4, DeviceId:1
val loss 2.752490
allocating 237 MiB for parameter gradients
allocating 480 MiB for activation gradients
allocating 118 MiB for AdamW optimizer state m
allocating 118 MiB for AdamW optimizer state v
allocating 118 MiB for master copy of params
step    1/9780 | train loss 2.752463 | norm 1.6601 | lr 8.57e-07 | 1956.23 ms | -100.0% bf16 MFU | 536020 tok/s
step    2/9780 | train loss 2.739456 | norm 1.6413 | lr 1.71e-06 | 1359.70 ms | -100.0% bf16 MFU | 771182 tok/s
step    3/9780 | train loss 2.714082 | norm 1.6619 | lr 2.57e-06 | 1359.34 ms | -100.0% bf16 MFU | 771288 tok/s
step    4/9780 | train loss 2.678853 | norm 1.7499 | lr 3.43e-06 | 1358.38 ms | -100.0% bf16 MFU | 771513 tok/s
step    5/9780 | train loss 2.642331 | norm 1.7979 | lr 4.29e-06 | 1362.87 ms | -100.0% bf16 MFU | 770940 tok/s
step    6/9780 | train loss 2.606469 | norm 1.8861 | lr 5.14e-06 | 1362.00 ms | -100.0% bf16 MFU | 770705 tok/s
step    7/9780 | train loss 2.576045 | norm 7.2377 | lr 6.00e-06 | 1362.61 ms | -100.0% bf16 MFU | 770484 tok/s
step    8/9780 | train loss 2.547086 | norm 2.0359 | lr 6.86e-06 | 1360.29 ms | -100.0% bf16 MFU | 770544 tok/s
step    9/9780 | train loss 2.519274 | norm 2.0550 | lr 7.71e-06 | 1362.20 ms | -100.0% bf16 MFU | 770429 tok/s
step   10/9780 | train loss 2.500306 | norm 2.0366 | lr 8.57e-06 | 1363.29 ms | -100.0% bf16 MFU | 770256 tok/s
Writing model to /work/alita/gpu2//model_00000010.bin
Writing state to /work/alita/gpu2//state_00000010_00002.bin
Writing state to /work/alita/gpu2//state_00000010_00003.bin
Writing state to /work/alita/gpu2//state_00000010_00001.bin
Writing state to /work/alita/gpu2//state_00000010_00000.bin

Note: I set -n of the training script to ten for debugging purposes and have since reverted back to five thousand.

An sbatch script for the two-node job can be found below. I plan to eventually do a twelve-node (2x A100 80GB) job using FineWeb 100B.

#!/bin/bash

#SBATCH -N 2
#SBATCH -n 4
#SBATCH --cpus-per-task=32
#SBATCH -t 24:00:00
#SBATCH -A loni_test001
#SBATCH -p gpu2
#SBATCH --gres=gpu:2
#SBATCH --job-name=gpt2-gpu2
#SBATCH -o gpt2-gpu2.out

echo "Allocated nodes: "
scontrol show hostname $SLURM_NODELIST

module load openmpi

export DFS_PATH='/home/alita/work/10b'
cd /work/alita/multi

srun --mpi=pmix ./train_gpt2cu \
    -i '/home/alita/work/10b/fineweb_train_*.bin' \
    -j '/home/alita/work/10b/fineweb_val_*.bin' \
    -o /work/alita/gpu2/ \
    -e 'd12' \
    -b 64 -t 1024 \
    -d 1048576 \
    -r 1 \
    -z 1 \
    -c 0.1 \
    -n 5 \
    -l 0.0006 \
    -q 0.0 \
    -u 700 \
    -v 250 -s 20000 \
    -h 1 \
    -n 10

@chinthysl force-pushed the multinode-slurm branch 2 times, most recently from 233f11d to b4fd422 on June 5, 2024 at 08:11.
@chinthysl (Contributor, Author)

@0xAlita I added an alternative to MPI_Barrier for that issue; it should work now. Please try without --mpi=pmix if you are using this branch.
Also, what happens when you use this Slurm script with a binary built from master? I see you have PMIx support for MPI and Slurm in your cluster, so you might not need these changes at all!

@chinthysl (Contributor, Author) commented Jun 5, 2024

@karpathy Interesting observation here. I was able to get some numbers from multi-node training of the d48 model; training was done on DGX H100s.
I tried to set a constant batch size per process to keep the gradient accumulation steps at 1.
According to the following results, llm.c has higher performance than PyTorch training at scale:
more than a 0.5M tok/sec difference with 128 GPUs (16 nodes of 8x H100).

[chart: output (1)]

Sample slurm run - llmc

n_nodes=16
n_proc=$((8 * n_nodes))
micro_batch_size=8
total_batch_size=$((1024 * n_proc * micro_batch_size))

srun --ntasks-per-node=8 \
  --gres=gpu:8 \
  --cpus-per-task=128 \
  --ntasks=$n_proc \
  --nodelist=node[000-015] \
  ./train_gpt2cu \
  -i "dev/data/fineweb350B/fineweb_train_*.bin" \
  -j "dev/data/fineweb350B/fineweb_val_*.bin" \
  -o "logs/llmc_d48_n16_b8" \
  -v 500 -s 500 -g 144 \
  -h 1 \
  -t 1024 \
  -b $micro_batch_size \ 
  -d $total_batch_size \
  -r 0 \
  -z 1 \
  -c 0.1 \
  -l 0.0006 \
  -q 0.0 \
  -u 700 \
  -y 1 \
  -x 500 \
  -e "d48"

Sample slurm run - pytorch

srun --ntasks-per-node=8 \
     --gres=gpu:8 \
     --cpus-per-task=16 \ 
     --ntasks=$n_proc \
     --nodelist=node[000-015] \
     --output=logs/llmc_d48_n16_b8/%j.log \     
     ./train_gpt2cu \
     -i "dev/data/fineweb350B/fineweb_train_*.bin" \
     -j "dev/data/fineweb350B/fineweb_val_*.bin" \
     -o "logs/llmc_d48_n16_b8" \
     -v 500 -s 500 -g 144 \
     -h 1 \
     -t 1024 \
     -b $micro_batch_size \ 
     -d $total_batch_size \
     -r 0 \
     -z 1 \
     -c 0.1 \
     -l 0.0006 \
     -q 0.0 \
     -u 700 \
     -y 1 \
     -x 500 \
     -e "d48"

On Jun 12, 2024, @chinthysl changed the title from "NCCL only multi gpu training for slurm enabled cluster" to "NCCL only multi-gpu multi-node training without MPI".
@chinthysl (Contributor, Author)

I removed the MPI dependencies and tested for performance changes. There are no visible performance differences between master and this branch.

@chinthysl (Contributor, Author) commented Jun 18, 2024

@karpathy I was able to train the 1.5B model using 59 DGX H100 nodes on FineWeb350B.
I set the batch size to 1024*16 tokens per GPU (process), so 1024*16*8*59 = 7,733,248 (7.73M) tokens per iteration over 100K steps (I think it consumed ~770B tokens, around 2 epochs over the dataset).
It took around 12-15 hrs and ran smoothly without any issues. As you mentioned in discussion #580, the 1.5B model also shows better training performance.
[chart: fineweb_100k]

@eliebak (Contributor) commented Jun 18, 2024

hey @chinthysl do you have the ckpt somewhere? :)

@karpathy (Owner)

Thank you for posting @chinthysl, very cool. We had a small discussion about it with the core devs on our Discord; please join us sometime on the CUDA MODE Discord, #llmdotc, for higher-throughput chatter. We don't have 500 GPUs, but I'm gearing up to do a repro on at least one 8x H100 80GB GPU node, which I estimate will run for a few days. I also noticed your run is not exactly configured to repro GPT-2; for that you'd want to do 100B tokens (1 epoch), so something like this: https://github.com/karpathy/llm.c/blob/master/scripts/run_gpt2_1558M.sh

cheers!

ncclCheck(ncclGetUniqueId(&nccl_id));
idFile = fopen(filename, "wb");
Contributor:

fopenCheck

Contributor:

generally, all file ops here should use the checked versions instead

#ifdef MULTI_GPU
if (multi_gpu_config->num_processes > 1) {
mpiCheck(MPI_Barrier(MPI_COMM_WORLD));
if (unified_buffer == NULL) cudaCheck(cudaMallocManaged(&unified_buffer, sizeof(float)));
Contributor:

I'm not sure if a "barrier" function should have any business conditionally allocating memory that needs to be manually freed later on
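One way this could be restructured (a hedged sketch, not a change in this PR): allocate the staging buffer once in multi_gpu_config_init, free it in the corresponding teardown, and let the barrier itself only issue the collective. Since NCCL has no dedicated barrier primitive, a 1-element all-reduce plays that role; the unified_buffer and nccl_comm field names are assumptions following the snippets in this review.

```c
// Barrier across all processes: every rank must reach this all-reduce before any
// rank returns from the following synchronize.
void multi_gpu_barrier(const MultiGpuConfig* cfg) {
    if (cfg->num_processes > 1) {
        ncclCheck(ncclAllReduce(cfg->unified_buffer, cfg->unified_buffer, 1,
                                ncclFloat, ncclSum, cfg->nccl_comm, /*stream=*/0));
        cudaCheck(cudaDeviceSynchronize());
    }
}
```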

cudaStream_t nccl_stream; // CUDA Stream to perform NCCL operations.
cudaEvent_t compute_nccl_sync; // Event used to synchronize NCCL with the compute
#endif
} MultiGpuConfig;

MultiGpuConfig multi_gpu_config_init(int num_processes, int process_rank, int gpus_per_node, char *dfs_path) {
Contributor:

Could the other code handle a different number of GPUs per node? Not sure how often that is realistic, but if so, it should at least be a conscious decision to remove that functionality.

result.num_processes = num_processes;
result.device_idx = process_rank % gpus_per_node;

FILE* idFile;
@ngc92 (Contributor) commented Jun 20, 2024:

Does this file get cleaned up?

Contributor (Author):

No!

@@ -21,7 +21,7 @@ Long story short, try `-r 1` (recompute GeLU, trading off speed and memory) to c
It might be that you only have one GPU and not a whole box of them. Every script is fairly easy to change for just a single GPU. For llm.c, simply change line 1 to line 2 and leave everything else the same:

```bash
mpirun -np 8 ./train_gpt2cu \
mpirun -np 8 bach -c './train_gpt2cu -pn 8 -pr $OMPI_COMM_WORLD_RANK'
Contributor:

bach? bash?


if (unified_buffer == NULL) cudaCheck(cudaMallocManaged(&unified_buffer, sizeof(float)));
*unified_buffer = value;
cudaCheck(cudaMemPrefetchAsync(unified_buffer, sizeof(float), multi_gpu_config->device_idx, 0));
Contributor:

does this prefetch actually give any advantage, given that the access already happens on the next line?

for (int i = 1; i < argc; i+=2) {
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
if (argv[i][0] != '-') { error_usage(); } // must start with dash
if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
if (!(strlen(argv[i]) == 2 || strlen(argv[i]) == 3)) { error_usage(); } // must be -x (one dash, one letter)
Contributor:

comment is outdated

Owner:

(the argparse docs above would be outdated, too)

@@ -36,7 +37,9 @@ while true; do
-u 700 \
-n 5000 \
-y 1 \
-e "d12"
-e "d12" \
-pn 8 \
Owner:

OMPI_COMM_WORLD_SIZE maybe?

@@ -325,6 +325,7 @@ typedef struct {
int* targets; // the target tokens for the current forward pass
float mean_loss; // after a forward pass with targets, will be populated with the mean loss
float accumulated_mean_loss; // Mean loss after aggregating it on all GPUs
float* unified_buffer; // GPU buffer to avg loss across process
Owner:

It doesn't seem like the unified buffer should be part of the model; could it maybe live "outside", e.g. inside int main, or as a global inside the zero.cuh file or so?
