diff --git a/.github/workflows/monthly_test.yaml b/.github/workflows/monthly_test.yaml index 32fc9374..1fc432fb 100644 --- a/.github/workflows/monthly_test.yaml +++ b/.github/workflows/monthly_test.yaml @@ -23,7 +23,7 @@ jobs: run: | source activate ${evo_env_torch21_flash2} jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -N 1 -n 8 --gres=gpu:8 pytest -s -v --color=yes -m "check_norm_msp" ./tests/test_training/test_norm_weight.py + srun -p ${SLURM_PARTITION} --exclusive --kill-on-bad-exit=1 --job-name=$jobname -N 1 -n 1 --gres=gpu:8 pytest -s -v --color=yes -m "check_norm_msp" ./tests/test_training/test_norm_weight.py exit_code=$? sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname @@ -42,7 +42,7 @@ jobs: run: | source activate ${evo_env_torch21_flash2} jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -N 1 -n 8 --gres=gpu:8 pytest -s -v --color=yes -m "check_norm_fsp" ./tests/test_training/test_norm_weight.py + srun -p ${SLURM_PARTITION} --exclusive --kill-on-bad-exit=1 --job-name=$jobname -N 1 -n 1 --gres=gpu:8 pytest -s -v --color=yes -m "check_norm_fsp" ./tests/test_training/test_norm_weight.py exit_code=$? sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname @@ -61,7 +61,7 @@ jobs: run: | source activate ${evo_env_torch21_flash2} jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} - srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -N 1 -n 8 --gres=gpu:8 pytest -s -v --color=yes -m "check_norm_isp" ./tests/test_training/test_norm_weight.py + srun -p ${SLURM_PARTITION} --exclusive --kill-on-bad-exit=1 --job-name=$jobname -N 1 -n 1 --gres=gpu:8 pytest -s -v --color=yes -m "check_norm_isp" ./tests/test_training/test_norm_weight.py exit_code=$? 
sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname diff --git a/.github/workflows/upload_to_pypi.yaml b/.github/workflows/upload_to_pypi.yaml index 8ff0818e..78a49e3a 100644 --- a/.github/workflows/upload_to_pypi.yaml +++ b/.github/workflows/upload_to_pypi.yaml @@ -27,17 +27,31 @@ jobs: run: | pip install setuptools wheel twine - - name: get latest tag - run: | - latest_tag=$(git describe --tags --abbrev=0) - echo "$latest_tag" > version.txt - - name: build and upload package run: | source activate ${evo_env_torch21_flash2} + python_path=$(which python) && echo "Python executable is at: $python_path" + latest_tag=$(git describe --tags --abbrev=0) + echo "$latest_tag" > version.txt export PYTHONPATH=$PWD:$PYTHONPATH + export LLMPLATFORM=/mnt/petrelfs/share_data/llm_env + export CUDA_PATH=${LLMPLATFORM}/dep/cuda-11.8 + export GCC_HOME=${LLMPLATFORM}/dep/gcc-10.2.0 + export MPFR_HOME=${LLMPLATFORM}/dep/mpfr-4.1.0 + export LD_LIBRARY_PATH=${CUDA_PATH}/lib64:${CUDA_PATH}/extras/CUPTI/lib64/:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=${GCC_HOME}/lib64:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=${MPFR_HOME}/lib:$LD_LIBRARY_PATH + export CC=${GCC_HOME}/bin/gcc + export CXX=${GCC_HOME}/bin/c++ jobname=${GITHUB_RUN_ID}-${GITHUB_JOB}-${GITHUB_RUN_ATTEMPT} srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -N 1 -n 1 --gres=gpu:1 python setup.py sdist bdist_wheel + cd csrc/rotary/ + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -N 1 -n 1 --gres=gpu:1 python setup.py sdist bdist_wheel + cd ../xentropy/ + srun -p ${SLURM_PARTITION} --kill-on-bad-exit=1 --job-name=$jobname -N 1 -n 1 --gres=gpu:1 python setup.py sdist bdist_wheel + cd ../../ exit_code=$? twine upload -u __token__ -p ${{ secrets.PYPI_API_TOKEN }} dist/* + twine upload -u __token__ -p ${{ secrets.PYPI_API_TOKEN }} csrc/rotary/dist/* + twine upload -u __token__ -p ${{ secrets.PYPI_API_TOKEN }} csrc/xentropy/dist/* sh ./ci_scripts/common/check_slurm_cancled.sh $exit_code $jobname diff --git a/.github/workflows/weekly_test.yaml b/.github/workflows/weekly_test.yaml index 2303f9af..2d737caa 100644 --- a/.github/workflows/weekly_test.yaml +++ b/.github/workflows/weekly_test.yaml @@ -177,6 +177,8 @@ jobs: echo "::add-mask::${{env.WORKSPACE_PREFIX}}" echo "::add-mask::$path_prefix" - uses: actions/checkout@v3 + with: + ref: ${{ github.event_name == 'schedule' && 'develop' || github.event_name == 'workflow_dispatch' && '' }} - name: training_8GPU_ISP run: | @@ -195,6 +197,8 @@ jobs: echo "::add-mask::${{env.WORKSPACE_PREFIX}}" echo "::add-mask::$path_prefix" - uses: actions/checkout@v3 + with: + ref: ${{ github.event_name == 'schedule' && 'develop' || github.event_name == 'workflow_dispatch' && '' }} - name: training_8GPU_ISP_CKPT run: | diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index 891e8ee3..ef20dc60 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -146,6 +146,14 @@ norm_type="rmsnorm", layer_norm_epsilon=1e-5, use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] 
+ # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. num_experts=4, moe_use_residual=False, diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 9e0fc91d..891885c3 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -1,5 +1,5 @@ JOB_NAME = "7b_internlm2_train" -model_type="INTERNLM2_PUBLIC" +model_type = "INTERNLM2_PUBLIC" DO_ALERT = False VOCAB_SIZE = 92544 @@ -144,6 +144,14 @@ layer_norm_epsilon=1e-5, num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, ) """ @@ -197,3 +205,18 @@ # metric_dtype can be "fp32" or other string # only when set to "fp32" will use fp32 to calc in metrics # metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) diff --git a/configs/7B_isp_sft.py b/configs/7B_isp_sft.py index dc6408cd..7e88772f 100644 --- a/configs/7B_isp_sft.py +++ b/configs/7B_isp_sft.py @@ -146,6 +146,14 @@ layer_norm_epsilon=1e-5, use_flash_attn=True, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, ) """ zero1 parallel (dict): diff --git a/configs/7B_llama2.py b/configs/7B_llama2.py index 9f464164..baacee63 100644 --- a/configs/7B_llama2.py +++ b/configs/7B_llama2.py @@ -6,8 +6,8 @@ SEQ_LEN = 2048 HIDDEN_SIZE = 4096 NUM_ATTENTION_HEAD = 32 -NUM_KV_ATTENTION_HEAD = 8 -MLP_RATIO = 3.5 +NUM_KV_ATTENTION_HEAD = 32 +MLP_RATIO = 2.6875 NUM_LAYER = 32 @@ -144,6 +144,14 @@ layer_norm_epsilon=1e-5, num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] 
+ # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, ) """ diff --git a/configs/7B_sft.py b/configs/7B_sft.py index c2ae7078..eba87bcd 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -71,6 +71,10 @@ valid_folder=VALID_FOLDER, empty_cache_and_diag_interval=200, diag_outlier_ratio=1.1, + # whether use shared memory to load meta files + use_shm=False, + # when use shm, the default shm_path is "/dev/shm/metacache" + # shm_path="/dev/shm/metacache" ) grad_scaler = dict( @@ -100,6 +104,11 @@ reduce_bucket_size=512 * 1024 * 1024, # grad clipping clip_grad_norm=1.0, + # whether use new optm + use_split_tensor_optim=False, + # when use split tensor optm + # Perform all gather with a set of parameters of all_gather_size + all_gather_size=512 * 1024 * 1024, ) loss = dict( @@ -145,6 +154,14 @@ norm_type="rmsnorm", layer_norm_epsilon=1e-5, use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. ) """ diff --git a/configs/_base_/models/internlm2_1B.py b/configs/_base_/models/internlm2_1B.py index ff0569d3..7d063919 100644 --- a/configs/_base_/models/internlm2_1B.py +++ b/configs/_base_/models/internlm2_1B.py @@ -25,7 +25,7 @@ mlp_ratio=MLP_RATIO, multiple_of=MULTIPLE_OF, norm_type="rmsnorm", - adapt_hf=True, + qk_interleaved=False, apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, diff --git a/configs/_base_/models/internlm2_20B.py b/configs/_base_/models/internlm2_20B.py index 82b06249..1347b98f 100644 --- a/configs/_base_/models/internlm2_20B.py +++ b/configs/_base_/models/internlm2_20B.py @@ -23,7 +23,7 @@ num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, mlp_ratio=MLP_RATIO, norm_type="rmsnorm", - adapt_hf=True, + qk_interleaved=False, apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, diff --git a/configs/_base_/models/internlm2_7B.py b/configs/_base_/models/internlm2_7B.py index 81f5acd4..94cae4b3 100644 --- a/configs/_base_/models/internlm2_7B.py +++ b/configs/_base_/models/internlm2_7B.py @@ -23,7 +23,7 @@ num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, mlp_ratio=MLP_RATIO, norm_type="rmsnorm", - adapt_hf=False, + qk_interleaved=True, apply_post_layer_norm=False, no_bias=True, layer_norm_epsilon=1e-5, diff --git a/csrc/rotary/rotary.cpp b/csrc/rotary/rotary.cpp new file mode 100644 index 00000000..206fda39 --- /dev/null +++ b/csrc/rotary/rotary.cpp @@ -0,0 +1,37 @@ +#include +#include + +#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) 
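A note on the `qk_interleaved` flag introduced across the configs above: the two query/key layouts differ only in how the head dimension is ordered, and can be converted into each other with a slice/merge. A minimal standalone sketch (helper names are illustrative, not InternEvo code; 1-D tensors for clarity):

```python
import torch

def interleaved_to_split(x: torch.Tensor) -> torch.Tensor:
    # [q1, q2, q3, q4, ...] -> [q1, q3, ..., q2, q4, ...]  (the qk_interleaved=False layout)
    return torch.cat([x[..., 0::2], x[..., 1::2]], dim=-1)

def split_to_interleaved(x: torch.Tensor) -> torch.Tensor:
    # [q1, q3, ..., q2, q4, ...] -> [q1, q2, q3, q4, ...]  (the qk_interleaved=True layout)
    half = x.shape[-1] // 2
    return torch.stack([x[..., :half], x[..., half:]], dim=-1).flatten(-2)

q = torch.arange(1.0, 9.0)  # [1, 2, ..., 8]
assert torch.equal(split_to_interleaved(interleaved_to_split(q)), q)
```

Which layout a checkpoint uses determines how RoPE indexes the head dimension, hence the per-config flag.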
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, + const torch::Tensor cos, const torch::Tensor sin, + torch::Tensor out1, torch::Tensor out2, + const bool conj); + +void apply_rotary(const torch::Tensor x1, const torch::Tensor x2, + const torch::Tensor cos, const torch::Tensor sin, + torch::Tensor out1, torch::Tensor out2, + const bool conj) { + CHECK_DEVICE(x1); CHECK_DEVICE(x2); + CHECK_DEVICE(cos); CHECK_DEVICE(sin); + CHECK_DEVICE(out1); CHECK_DEVICE(out1); + TORCH_CHECK(x1.dtype() == x2.dtype()); + TORCH_CHECK(cos.dtype() == sin.dtype()); + TORCH_CHECK(out1.dtype() == out2.dtype()); + TORCH_CHECK(x1.dtype() == cos.dtype()); + TORCH_CHECK(x1.dtype() == out1.dtype()); + TORCH_CHECK(x1.sizes() == x2.sizes()); + TORCH_CHECK(cos.sizes() == sin.sizes()); + TORCH_CHECK(out1.sizes() == out2.sizes()); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x1.get_device()}; + + apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("apply_rotary", &apply_rotary, "Apply rotary embedding"); +} diff --git a/csrc/rotary/rotary_cuda.cu b/csrc/rotary/rotary_cuda.cu new file mode 100644 index 00000000..584b57c9 --- /dev/null +++ b/csrc/rotary/rotary_cuda.cu @@ -0,0 +1,41 @@ +#include +#include +#include + +void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2, + const torch::Tensor cos, const torch::Tensor sin, + torch::Tensor out1, torch::Tensor out2, + const bool conj) { + auto iter = at::TensorIteratorConfig() + .add_output(out1) + .add_output(out2) + .add_input(x1) + .add_input(x2) + .add_input(cos) + .add_input(sin) + .check_all_same_dtype(false) + .promote_inputs_to_common_dtype(false) + .build(); + + if (!conj) { + AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { + at::native::gpu_kernel_multiple_outputs( + iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, + scalar_t sin) -> thrust::tuple { + scalar_t out1 = float(x1) * float(cos) - float(x2) * float(sin); + scalar_t out2 = float(x1) * float(sin) + float(x2) * float(cos); + return {out1, out2}; + }); + }); + } else { + AT_DISPATCH_FLOATING_TYPES_AND2(at::kBFloat16, at::kHalf, x1.scalar_type(), "rotary_kernel", [&] { + at::native::gpu_kernel_multiple_outputs( + iter, [] GPU_LAMBDA (scalar_t x1, scalar_t x2, scalar_t cos, + scalar_t sin) -> thrust::tuple { + scalar_t out1 = float(x1) * float(cos) + float(x2) * float(sin); + scalar_t out2 = -float(x1) * float(sin) + float(x2) * float(cos); + return {out1, out2}; + }); + }); + } +} diff --git a/csrc/rotary/setup.py b/csrc/rotary/setup.py new file mode 100644 index 00000000..7809fc14 --- /dev/null +++ b/csrc/rotary/setup.py @@ -0,0 +1,131 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import sys +import warnings +import os +from packaging.version import parse, Version + +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME +from setuptools import setup, find_packages +from wheel.bdist_wheel import bdist_wheel +import subprocess + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + 
bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + torch_binary_version = parse(torch.version.cuda) + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != torch_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. " + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.2"): + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + +if not torch.cuda.is_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). 
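For reviewers unfamiliar with the extension being packaged here, a hedged usage sketch of `rotary_emb.apply_rotary` as declared in csrc/rotary/rotary.cpp above. The half-split slicing and broadcast shapes follow flash-attention's rotary module and are assumptions, not code from this PR; it assumes the extension has been built and a CUDA device is available:

```python
import torch
import rotary_emb  # extension name from csrc/rotary/setup.py

seqlen, nheads, rotary_dim = 128, 8, 64
x = torch.randn(seqlen, nheads, rotary_dim, device="cuda", dtype=torch.float16)

# Standard RoPE angles; cos/sin must share x's dtype (checked in rotary.cpp).
pos = torch.arange(seqlen, device="cuda")[:, None]
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, device="cuda") / rotary_dim))
freqs = pos * inv_freq[None, :]                  # (seqlen, rotary_dim // 2)
cos = freqs.cos().to(torch.float16)[:, None, :]  # broadcast over heads
sin = freqs.sin().to(torch.float16)[:, None, :]

x1, x2 = x[..., : rotary_dim // 2], x[..., rotary_dim // 2 :]
out = torch.empty_like(x)
o1, o2 = out[..., : rotary_dim // 2], out[..., rotary_dim // 2 :]

# conj=False computes o1 = x1*cos - x2*sin, o2 = x1*sin + x2*cos (see rotary_cuda.cu);
# conj=True applies the inverse rotation, as used in the backward pass.
rotary_emb.apply_rotary(x1, x2, cos, sin, o1, o2, False)
```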
+ print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" + "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" + elif bare_metal_version >= Version("11.1"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" + elif bare_metal_version == Version("11.0"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" + + +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) + +cmdclass = {} +ext_modules = [] + +raise_if_cuda_home_none("rotary_emb") +# Check, if CUDA11 is installed for compute capability 8.0 +cc_flag = [] +_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +if bare_metal_version < Version("11.0"): + raise RuntimeError("rotary_emb is only supported on CUDA 11 and above") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_70,code=sm_70") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") +if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + +ext_modules.append( + CUDAExtension( + 'rotary_emb', [ + 'rotary.cpp', + 'rotary_cuda.cu', + ], + extra_compile_args={'cxx': ['-g', '-march=native', '-funroll-loops'], + 'nvcc': append_nvcc_threads([ + '-O3', '--use_fast_math', '--expt-extended-lambda' + ] + cc_flag) + } + ) +) + +class CustomBdistWheel(bdist_wheel): + def finalize_options(self): + bdist_wheel.finalize_options(self) + self.plat_name = 'manylinux2014_x86_64' + +setup( + name="rotary_emb", + version="0.5.2", + ext_modules=ext_modules, + cmdclass={ + "build_ext": BuildExtension, + "bdist_wheel": CustomBdistWheel + }, +) diff --git a/csrc/xentropy/README.md b/csrc/xentropy/README.md new file mode 100644 index 00000000..7970f393 --- /dev/null +++ b/csrc/xentropy/README.md @@ -0,0 +1,9 @@ +This CUDA extension implements optimized cross-entropy loss, adapted from Apex's +[Xentropy](https://github.com/NVIDIA/apex/tree/master/apex/contrib/xentropy). +We make it work for bfloat16 and support in-place backward to save memory. + +It has only been tested on A100s. + +```sh +cd csrc/xentropy && pip install . 
+``` diff --git a/csrc/xentropy/interface.cpp b/csrc/xentropy/interface.cpp new file mode 100644 index 00000000..41a783fd --- /dev/null +++ b/csrc/xentropy/interface.cpp @@ -0,0 +1,59 @@ +#include + +// CUDA forward declarations +std::vector softmax_xentropy_cuda( + const at::Tensor &input, + const at::Tensor &labels, + const float smoothing, + const int total_classes); + +at::Tensor softmax_xentropy_backward_cuda( + const at::Tensor &grad_loss, + at::Tensor &logits, + const at::Tensor &max_log_sum_exp, + const at::Tensor &labels, + const float smoothing, + const bool inplace, + const int total_classes); + +// C++ interface + +#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + +std::vector softmax_xentropy_forward( + const at::Tensor &input, + const at::Tensor &labels, + const float smoothing, + const int total_classes=-1) { + // For tensor parallel cross entropy with smoothing, we want to pass in the total number + // of classes so that smoothing can be applied correctly. If total_classes=-1, use the + // last dimension of the input tensor. + CHECK_INPUT(input); + CHECK_INPUT(labels); + + return softmax_xentropy_cuda(input, labels, smoothing, total_classes); +} + +at::Tensor softmax_xentropy_backward( + const at::Tensor &grad_loss, + at::Tensor &logits, + const at::Tensor &max_log_sum_exp, + const at::Tensor &labels, + const float smoothing, + const bool inplace, + const int total_classes=-1) { + CHECK_INPUT(grad_loss); + CHECK_INPUT(logits); + CHECK_INPUT(max_log_sum_exp); + CHECK_INPUT(labels); + + return softmax_xentropy_backward_cuda(grad_loss, logits, max_log_sum_exp, labels, + smoothing, inplace, total_classes); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("forward", &softmax_xentropy_forward, "Softmax cross entropy loss with label smoothing forward (CUDA)", py::arg("input"), py::arg("labels"), py::arg("smoothing"), py::arg("total_classes")=-1); + m.def("backward", &softmax_xentropy_backward, "Softmax cross entropy loss with label smoothing backward (CUDA)", py::arg("grad_loss"), py::arg("logits"), py::arg("max_log_sum_exp"), py::arg("labels"), py::arg("smoothing"), py::arg("inplace"), py::arg("total_classes")=-1); +} diff --git a/csrc/xentropy/setup.py b/csrc/xentropy/setup.py new file mode 100644 index 00000000..2c5174e9 --- /dev/null +++ b/csrc/xentropy/setup.py @@ -0,0 +1,147 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import sys +import warnings +import os +from packaging.version import parse, Version + +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME +from setuptools import setup, find_packages +from wheel.bdist_wheel import bdist_wheel +import subprocess + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + torch_binary_version = parse(torch.version.cuda) + + print("\nCompiling cuda extensions with") 
+ print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != torch_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. " + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.2"): + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + +if not torch.cuda.is_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" + "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" + elif bare_metal_version >= Version("11.1"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" + elif bare_metal_version == Version("11.0"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" + + +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) + +cmdclass = {} +ext_modules = [] + +# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h +# See https://github.com/pytorch/pytorch/pull/70650 +generator_flag = [] +torch_dir = torch.__path__[0] +if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + +raise_if_cuda_home_none("--xentropy") +# Check, if CUDA11 is installed for compute capability 8.0 +cc_flag = [] +_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +if bare_metal_version < Version("11.0"): + raise RuntimeError("xentropy is only supported on CUDA 11 and above") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_70,code=sm_70") +cc_flag.append("-gencode") 
+cc_flag.append("arch=compute_80,code=sm_80") +if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + +ext_modules.append( + CUDAExtension( + name="xentropy_cuda_lib", + sources=[ + "interface.cpp", + "xentropy_kernel.cu" + ], + extra_compile_args={ + "cxx": ["-O3"] + generator_flag, + "nvcc": append_nvcc_threads( + ["-O3"] + + generator_flag + + cc_flag + ), + }, + include_dirs=[this_dir], + ) +) + +class CustomBdistWheel(bdist_wheel): + def finalize_options(self): + bdist_wheel.finalize_options(self) + self.plat_name = 'manylinux2014_x86_64' + +setup( + name="xentropy", + version="0.1.2", + description="Cross-entropy loss", + ext_modules=ext_modules, + cmdclass={ + "build_ext": BuildExtension, + "bdist_wheel": CustomBdistWheel + }, +) diff --git a/csrc/xentropy/xentropy_kernel.cu b/csrc/xentropy/xentropy_kernel.cu new file mode 100644 index 00000000..8d8836e6 --- /dev/null +++ b/csrc/xentropy/xentropy_kernel.cu @@ -0,0 +1,760 @@ +// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/xentropy/xentropy_kernel.cu +// TD [2022-09-17]: We make it work for bfloat16, and add an option to do the backward inplace (to save memory). +/** + * From PyTorch: + * + * Copyright (c) 2016- Facebook, Inc (Adam Paszke) + * Copyright (c) 2014- Facebook, Inc (Soumith Chintala) + * Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) + * Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) + * Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) + * Copyright (c) 2011-2013 NYU (Clement Farabet) + * Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) + * Copyright (c) 2006 Idiap Research Institute (Samy Bengio) + * Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + * + * From Caffe2: + * + * Copyright (c) 2016-present, Facebook Inc. All rights reserved. + * + * All contributions by Facebook: + * Copyright (c) 2016 Facebook Inc. + * + * All contributions by Google: + * Copyright (c) 2015 Google Inc. + * All rights reserved. + * + * All contributions by Yangqing Jia: + * Copyright (c) 2015 Yangqing Jia + * All rights reserved. + * + * All contributions from Caffe: + * Copyright(c) 2013, 2014, 2015, the respective contributors + * All rights reserved. + * + * All other contributions: + * Copyright(c) 2015, 2016 the respective contributors + * All rights reserved. + * + * Caffe2 uses a copyright model similar to Caffe: each contributor holds + * copyright over their contributions to Caffe2. The project versioning records + * all such contribution and copyright details. If a contributor wants to further + * mark their specific copyright on a particular contribution, they should + * indicate their copyright solely in the commit message of the change when it is + * committed. + * + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * 3. 
Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America + * and IDIAP Research Institute nor the names of its contributors may be + * used to endorse or promote products derived from this software without + * specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + */ +#include +#include +#include + +#include +#include + +// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h +// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ + switch(TYPE) \ + { \ + case at::ScalarType::Float: \ + { \ + using scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::Half: \ + { \ + using scalar_t_##LEVEL = at::Half; \ + __VA_ARGS__; \ + break; \ + } \ + case at::ScalarType::BFloat16: \ + { \ + using scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + break; \ + } \ + default: \ + AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ + } +// #else +// #define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, LEVEL, NAME, ...) \ +// switch(TYPE) \ +// { \ +// case at::ScalarType::Float: \ +// { \ +// using scalar_t_##LEVEL = float; \ +// __VA_ARGS__; \ +// break; \ +// } \ +// case at::ScalarType::Half: \ +// { \ +// using scalar_t_##LEVEL = at::Half; \ +// __VA_ARGS__; \ +// break; \ +// } \ +// default: \ +// AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ +// } +// #endif + +#define ALIGN_BYTES 16 + +using Tensor = at::Tensor; +using TensorList = at::TensorList; +using ScalarType = at::ScalarType; +using at::acc_type; + +template +struct LogSoftMaxForwardEpilogue { + __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum) + : logsum(max_input + std::log(sum)) {} + + __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_log_sum_exp) + : logsum(max_log_sum_exp) {} + + __device__ __forceinline__ OutT operator()(T input) const { + return static_cast(input - logsum); + } + + const AccumT logsum; +}; + +template +struct LogSoftMaxBackwardEpilogue { + __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) + : sum(sum) {} + + __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { + return static_cast(gradOutput - std::exp(static_cast(output)) * sum); + } + + const AccumT sum; +}; + + + +const int max_threads = 1024; + +inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { + uint64_t block_size = 1; + uint64_t max_block_size = std::min(dim_size / ILP, static_cast(max_threads)); + while (block_size < (max_block_size/2)) block_size *= 2; + // Launch at least a single warp - the kernel assumes that. 
+ block_size = std::max(block_size, static_cast(32)); + return dim3(block_size); +} + +template +struct Add { + __device__ __forceinline__ T operator()(T a, T b) const { + return a + b; + } +}; + +template +struct Max { + __device__ __forceinline__ T operator()(T a, T b) const { + return a < b ? b : a; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// +// Regular kernel (fast when dim_size is large; requires inner_size == 1) +//////////////////////////////////////////////////////////////////////////////// + + +template +struct MaxFloat +{ + __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { + return ::max(max, (AccumT)v); + } +}; + +template +struct AddFloat +{ + __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { + return sum + v; + } +}; + +template +struct SumExpFloat +{ + __device__ __forceinline__ SumExpFloat(AccumT v) + : max_k(v) {} + + __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { + return sum + std::exp(v - max_k); + } + + const AccumT max_k; +}; + +template class Reduction, typename AccumT> +__device__ __forceinline__ AccumT +blockReduce(AccumT* smem, AccumT val, + const Reduction& r, + AccumT defaultVal) +{ + // To avoid RaW races from chaining blockReduce calls together, we need a sync here + __syncthreads(); + + smem[threadIdx.x] = val; + + __syncthreads(); + + AccumT warpVal = defaultVal; + + // First warp will perform per-warp reductions for the remaining warps + uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; + if (threadIdx.x < 32) { + int lane = threadIdx.x % 32; + if (lane < blockDim.x / 32) { +#pragma unroll + for (int i = 0; i < 32; ++i) { + warpVal = r(warpVal, smem[lane * 32 + i]); + } + __syncwarp(mask); + smem[lane] = warpVal; + } + } + + __syncthreads(); + + // First thread will perform a reduction of the above per-warp reductions + AccumT blockVal = defaultVal; + + if (threadIdx.x == 0) { + for (int i = 0; i < blockDim.x / 32; ++i) { + blockVal = r(blockVal, smem[i]); + } + smem[0] = blockVal; + } + + // Sync and broadcast + __syncthreads(); + return smem[0]; +} + +template class Reduction1, template class Reduction2, typename AccumT> +__device__ __forceinline__ void +blockReduce(AccumT* smem, + AccumT* reducVal1, + AccumT val1, + const Reduction1& r1, + AccumT defaultVal1, + AccumT* reducVal2, + AccumT val2, + const Reduction2& r2, + AccumT defaultVal2) +{ + // To avoid RaW races from chaining blockReduce calls together, we need a sync here + __syncthreads(); + + smem[threadIdx.x] = val1; + smem[blockDim.x + threadIdx.x] = val2; + + __syncthreads(); + + AccumT warpVal1 = defaultVal1; + AccumT warpVal2 = defaultVal2; + + // First warp will perform per-warp reductions for the remaining warps + uint32_t mask = (((uint64_t)1) << (blockDim.x / 32)) - 1; + if (threadIdx.x < 32) { + int lane = threadIdx.x % 32; + if (lane < blockDim.x / 32) { +#pragma unroll + for (int i = 0; i < 32; ++i) { + warpVal1 = r1(warpVal1, smem[lane * 32 + i]); + warpVal2 = r2(warpVal2, smem[lane * 32 + i + blockDim.x]); + } + __syncwarp(mask); + smem[lane] = warpVal1; + smem[lane + blockDim.x] = warpVal2; + } + } + + __syncthreads(); + + // First thread will perform a reduction of the above per-warp reductions + AccumT blockVal1 = defaultVal1; + AccumT blockVal2 = defaultVal2; + + if (threadIdx.x == 0) { + for (int i = 0; i < blockDim.x / 32; ++i) { + blockVal1 = r1(blockVal1, smem[i]); + blockVal2 = r2(blockVal2, smem[i + blockDim.x]); + } + smem[0] = blockVal1; + 
smem[blockDim.x] = blockVal2; + } + + // Sync and broadcast + __syncthreads(); + *reducVal1 = smem[0]; + *reducVal2 = smem[blockDim.x]; + __syncthreads(); +} + +template class Reduction, int ILP, typename T, typename AccumT> +__device__ __forceinline__ AccumT +ilpReduce(int shift, + T* data, + int size, + const Reduction& r, + AccumT defaultVal) +{ + typedef typename std::aligned_storage::type LoadT; + AccumT threadVal = defaultVal; + int offset = threadIdx.x; + + // shift and do 1 + if(shift > 0){ + data -= shift; + size += shift; + if(threadIdx.x >= shift){ + threadVal = r(threadVal, data[offset]); + } + size -= blockDim.x; + data += blockDim.x; + } + int last = size % (ILP * blockDim.x); + + T v[ILP]; + LoadT* value = reinterpret_cast(&v); + + for (; offset * ILP < (size - last); offset += blockDim.x) { + *value = reinterpret_cast(data)[offset]; + + for (int j = 0; j < ILP; ++j) { + threadVal = r(threadVal, v[j]); + } + } + + offset = size - last + threadIdx.x; + // Epilogue + for (; offset < size; offset += blockDim.x) + threadVal = r(threadVal, data[offset]); + + return threadVal; +} + +template class Reduction1, template class Reduction2, int ILP, typename T, typename AccumT> +__device__ __forceinline__ void +ilpReduce(int shift, + T* data, + int size, + AccumT* reducVal1, + const Reduction1& r1, + AccumT defaultVal1, + AccumT* reducVal2, + const Reduction2& r2, + AccumT defaultVal2) +{ + typedef typename std::aligned_storage::type LoadT; + + AccumT threadVal1 = defaultVal1; + AccumT threadVal2 = defaultVal2; + int offset = threadIdx.x; + + // shift and do 1 + if(shift > 0){ + data -= shift; + size += shift; + if(threadIdx.x >= shift){ + threadVal1 = r1(threadVal1, data[offset]); + threadVal2 = r2(threadVal2, data[offset]); + } + size -= blockDim.x; + data += blockDim.x; + } + int last = size % (ILP * blockDim.x); + + T v[ILP]; + LoadT* value = reinterpret_cast(&v); + + for (; offset * ILP < (size - last); offset += blockDim.x) { + *value = reinterpret_cast(data)[offset]; + + for (int j = 0; j < ILP; ++j) { + threadVal1 = r1(threadVal1, v[j]); + threadVal2 = r2(threadVal2, v[j]); + } + } + + offset = size - last + threadIdx.x; + // Epilogue + for (; offset < size; offset += blockDim.x) { + threadVal1 = r1(threadVal1, data[offset]); + threadVal2 = r2(threadVal2, data[offset]); + } + + *reducVal1 = threadVal1; + *reducVal2 = threadVal2; +} + +template class Epilogue> +__global__ void +cunn_SoftMaxXEntropyForward( + accscalar_t *losses, + outscalar_t *max_log_sum_exp, + scalar_t *input, + int64_t *labels, + int64_t classes, + const float smoothing, + const int total_classes) +{ + extern __shared__ unsigned char smem[]; + auto sdata = reinterpret_cast(smem); + // forward pointers to batch[blockIdx.x] + // each block handles a sample in the mini-batch + input += blockIdx.x * classes; + //output += blockIdx.x * classes; + const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t); + + int64_t label = labels[blockIdx.x]; + + // find the max and sum + accscalar_t threadMax, threadSum, max_k, sum_k; + ilpReduce( + shift, input, classes, + &threadMax, MaxFloat(), + -at::numeric_limits::max(), + &threadSum, AddFloat(), + static_cast(0)); + + blockReduce( + sdata, + &max_k, threadMax, Max(), + -at::numeric_limits::max(), + &sum_k, threadSum, Add(), + static_cast(0)); + + accscalar_t threadExp = ilpReduce(shift, input, classes, SumExpFloat(max_k), static_cast(0)); + accscalar_t sumAll = blockReduce( + sdata, threadExp, Add(), static_cast(0)); + + Epilogue epilogue(max_k, sumAll); + + // 
calculate per element loss with label smoothing + // reserve max + log_sum_exp for bprop + if (threadIdx.x == 0) { + accscalar_t lse = max_k + std::log(sumAll); + accscalar_t log_prob = (label >= 0 && label < classes) ? epilogue(static_cast(input[label])) : 0.f; + losses[blockIdx.x] = (lse - sum_k / total_classes) * smoothing - log_prob * (1 - smoothing); + max_log_sum_exp[blockIdx.x] = lse; + } +} + +template +__device__ __forceinline__ void +apply(scalar_t *gradInput, + scalar_t *logits, + outscalar_t *max_log_sum_exp, + outscalar_t *gradOutput, + int64_t *labels, + const float smoothing, + int classes, + const int total_classes) +{ + accscalar_t smooth_positives = 1.0 - smoothing; + accscalar_t smooth_negatives = smoothing / total_classes; + accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; + int64_t label = labels[blockIdx.x]; + accscalar_t coeff = max_log_sum_exp[blockIdx.x]; + + int offset = threadIdx.x; + int last = classes % (ILP * blockDim.x); + + for (; offset < classes - last; offset += blockDim.x * ILP) { + accscalar_t tmpLogits[ILP]; + +#pragma unroll + for (int j = 0; j < ILP; ++j) { + tmpLogits[j] = static_cast(logits[offset + j * blockDim.x]); + } + +#pragma unroll + for (int j = 0; j < ILP; ++j) + gradInput[offset + j * blockDim.x] = tmpGradOutput * ( + std::exp(tmpLogits[j] - coeff) - static_cast( + (offset + j * blockDim.x == label) ? 1 : 0) * + smooth_positives - smooth_negatives); + } + + for (; offset < classes; offset += blockDim.x) + gradInput[offset] = tmpGradOutput * (std::exp( + static_cast(logits[offset]) - coeff) - + static_cast((offset == label) ? 1 : 0) * + smooth_positives - smooth_negatives); +} + + +template +__device__ __forceinline__ void +aligned_apply(int shift, + scalar_t *gradInput, + scalar_t *logits, + outscalar_t *max_log_sum_exp, + outscalar_t *gradOutput, + int64_t *labels, + const float smoothing, + int classes, + const int total_classes) +{ + accscalar_t smooth_positives = 1.0 - smoothing; + accscalar_t smooth_negatives = smoothing / total_classes; + accscalar_t tmpGradOutput = gradOutput[blockIdx.x]; + int64_t label = labels[blockIdx.x]; + accscalar_t coeff = max_log_sum_exp[blockIdx.x]; + + int offset = threadIdx.x; + + // shift and do 1 + if(shift > 0){ + logits -= shift; + gradInput -= shift; + classes += shift; + if(threadIdx.x >= shift){ + gradInput[offset] = tmpGradOutput * (std::exp( + static_cast(logits[offset]) - coeff) - + static_cast(((offset - shift) == label) ? 1 : 0) * + smooth_positives - smooth_negatives); + } + classes -= blockDim.x; + gradInput += blockDim.x; + logits += blockDim.x; + shift -= blockDim.x; + } + + int last = classes % (ILP * blockDim.x); + + typedef typename std::aligned_storage::type LoadT; + // input + scalar_t v[ILP]; + LoadT* value = reinterpret_cast(&v); + // output + scalar_t r[ILP]; + LoadT* result = reinterpret_cast(&r); + + for (; offset * ILP < (classes - last); offset += blockDim.x) { + *value = reinterpret_cast(logits)[offset]; + +#pragma unroll + for (int j = 0; j < ILP; ++j) { + r[j] = tmpGradOutput * (std::exp( + static_cast(v[j]) - coeff) - + static_cast(((ILP * offset + j - shift) == label) ? 1 : 0) * + smooth_positives - smooth_negatives); + } + reinterpret_cast(gradInput)[offset] = *result; + } + + offset = classes - last + threadIdx.x; + for (; offset < classes; offset += blockDim.x) + gradInput[offset] = tmpGradOutput * (std::exp( + static_cast(logits[offset]) - coeff) - + static_cast(((offset - shift) == label) ? 
1 : 0) * + smooth_positives - smooth_negatives); + +} + +template class Epilogue> +__global__ void +cunn_SoftMaxXEntropyBackward( + scalar_t *gradInput, + scalar_t *logits, + outscalar_t *max_log_sum_exp, + outscalar_t *gradOutput, + int64_t *labels, + const float smoothing, + int classes, + const int total_classes) +{ + gradInput += blockIdx.x * classes; + logits += blockIdx.x * classes; + + // Do vectorized load/store when input/output have same alignment + const int shift = ((uint64_t)logits) % ALIGN_BYTES / sizeof(scalar_t); + const int shift_ = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t); + if (shift == shift_){ + aligned_apply(shift, gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); + } + else { + apply(gradInput, logits, max_log_sum_exp, gradOutput, labels, smoothing, classes, total_classes <= 0 ? classes : total_classes); + } + +} + +template class Epilogue> +std::vector host_softmax_xentropy( + const Tensor & input_, + const Tensor & labels_, + const float smoothing, + const int total_classes) { + // For tensor parallel cross entropy with smoothing, we want to pass in the total number + // of classes so that smoothing can be applied correctly. If total_classes=-1, use the + // last dimension of the input tensor. + AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long"); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)input_.get_device()}; + + auto input = input_.contiguous(); + Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float)); + Tensor losses = at::empty_like(labels_, input_.options().dtype(ScalarType::Float)); + + static_assert(std::is_same, float>::value || + std::is_same, double>::value, + "accscalar_t for half should be float or double"); + AT_ASSERTM(input.dim() == 2, "Currently only 2 dim input supported"); + AT_ASSERTM(labels_.dim() == 1, "Labels should be 1 dimensional"); + AT_ASSERTM(input.size(0) == labels_.size(0), "Input and label should have same number of examples"); + AT_ASSERTM(input.numel() > 0, "Number of classes in input should not be 0"); + + const int64_t dim = 1; + int64_t outer_size = 1; + int64_t dim_size = input.size(dim); + int64_t inner_size = 1; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + for (int64_t i = 0; i < dim; ++i) + outer_size *= input.size(i); + for (int64_t i = dim + 1; i < input.dim(); ++i) + inner_size *= input.size(i); + // This kernel spawns a block per each element in the batch. + // XXX: it assumes that inner_size == 1 + TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); + + dim3 grid(outer_size); + + using namespace at; + DISPATCH_FLOAT_AND_HALF_AND_BF16(input.scalar_type(), 0, "host_softmax_xentropy", + using accscalar_t = at::acc_type; + const int ILP = sizeof(float4)/sizeof(scalar_t_0); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + cunn_SoftMaxXEntropyForward + <<>>( + losses.data_ptr(), max_log_sum_exp.data_ptr(), + input.data_ptr(), labels_.data_ptr(), + dim_size, smoothing, total_classes <= 0 ? 
dim_size : total_classes + ); + ); + + C10_CUDA_CHECK(cudaGetLastError()); + + std::vector ret = {losses, max_log_sum_exp}; + return ret; +} + +template class Epilogue> +Tensor host_softmax_xentropy_backward( + const at::Tensor &grad_loss, + at::Tensor &logits_, + const at::Tensor &max_log_sum_exp, + const at::Tensor &labels, + const float smoothing, + bool inplace, + const int total_classes) { + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()}; + + const int64_t dim = 1; + Tensor gI = inplace ? logits_ : at::empty_like(logits_); + if (grad_loss.numel() == 0) { + return gI; + } + + auto grad = grad_loss.contiguous(); + auto logits = logits_.contiguous(); + + static_assert(std::is_same, float>::value || + std::is_same, double>::value, + "accscalar_t for half should be float or double"); + if (grad.dim() == 0) grad = grad.view(1); + + AT_ASSERTM(logits_.dim() == 2, "Currently only 2 dim input supported"); + AT_ASSERTM(labels.dim() == 1, "Labels should be 1 dimensional"); + AT_ASSERTM(logits_.numel() > 0, "Number of classes in input should not be 0"); + AT_ASSERTM(logits_.size(0) == labels.size(0), "Input and label should have same number of examples"); + AT_ASSERTM(labels.size(0) == grad.size(0), "Label and loss should have same number of examples"); + + int64_t outer_size = 1; + int64_t dim_size = logits.size(dim); + int64_t inner_size = 1; + for (int64_t i = 0; i < dim; ++i) + outer_size *= logits.size(i); + for (int64_t i = dim + 1; i < logits.dim(); ++i) + inner_size *= logits.size(i); + // See descriptions of kernels above. + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + TORCH_CHECK(inner_size == 1, "Currently only inner size 1 supported"); + + dim3 grid(outer_size); + + DISPATCH_FLOAT_AND_HALF_AND_BF16(gI.scalar_type(), 0, "host_softmax_xentropy_backward", + using accscalar_t = acc_type; + const int ILP = sizeof(float4)/sizeof(scalar_t_0); + dim3 block = SoftMax_getBlockSize(ILP, dim_size); + cunn_SoftMaxXEntropyBackward + <<>>( + gI.data_ptr(), logits.data_ptr(), + max_log_sum_exp.data_ptr(), + grad.data_ptr(), labels.data_ptr(), + smoothing, dim_size, total_classes + ); + ); + + C10_CUDA_CHECK(cudaGetLastError()); + return gI; +} + +std::vector softmax_xentropy_cuda(const Tensor &input, const Tensor &labels, const float smoothing, const int total_classes){ + return host_softmax_xentropy(input, labels, smoothing, total_classes); +} + +at::Tensor softmax_xentropy_backward_cuda( + const at::Tensor &grad_loss, + at::Tensor &logits, + const at::Tensor &max_log_sum_exp, + const at::Tensor &labels, + const float smoothing, + const bool inplace, + const int total_classes) { + AT_ASSERTM((grad_loss.scalar_type() == ScalarType::Float), "expected grad types to be at::Float"); + return host_softmax_xentropy_backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes); +} diff --git a/doc/code-docs/source/checkpoint.rst b/doc/code-docs/source/checkpoint.rst index aab161e9..c01c6950 100644 --- a/doc/code-docs/source/checkpoint.rst +++ b/doc/code-docs/source/checkpoint.rst @@ -16,7 +16,7 @@ CheckpointManager - ``checkpoint_every``: 检查点存储频率,参数类型 ``int``,默认为: ``50``。 -- ``load_ckpt_folder``: 初始化检查点/权重加载路径。参数类型 ``str``,默认为: ``None``,详见 :ref:`load-ckpt-folder`。 +- ``load_ckpt_info``: 初始化检查点/权重加载信息。参数类型 ``dict``,默认为: ``None``,详见 :ref:`load-ckpt-info`。 - ``async_upload``: 是否开启异步上传,默认值为:``False``,详见 :ref:`asyncupload`。 @@ -36,8 +36,8 
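Stepping back from the kernel above: a hedged sketch of how the resulting `xentropy_cuda_lib` module is called, alongside a pure-PyTorch reference of the smoothed loss the forward kernel computes. The reference helper and the tolerance check are illustrative, not part of this PR:

```python
import torch
import xentropy_cuda_lib  # extension name from csrc/xentropy/setup.py

def smoothed_xentropy_reference(logits, labels, smoothing, total_classes):
    # Mirrors the forward kernel:
    #   loss = (lse - sum(logits) / total_classes) * smoothing
    #          - (logits[label] - lse) * (1 - smoothing)
    logits = logits.float()
    lse = torch.logsumexp(logits, dim=-1)
    log_prob = logits.gather(-1, labels[:, None]).squeeze(-1) - lse
    return (lse - logits.sum(-1) / total_classes) * smoothing - log_prob * (1 - smoothing)

batch, vocab = 4, 32000
logits = torch.randn(batch, vocab, device="cuda", dtype=torch.bfloat16)
labels = torch.randint(0, vocab, (batch,), device="cuda")

# forward(input, labels, smoothing, total_classes=-1); -1 means "use the last logits dim".
losses, max_lse = xentropy_cuda_lib.forward(logits, labels, 0.1)
ref = smoothed_xentropy_reference(logits, labels, 0.1, vocab)
print((losses - ref).abs().max())  # expected to be small (fp32 accumulation in both paths)

# backward(grad_loss, logits, max_log_sum_exp, labels, smoothing, inplace, total_classes=-1);
# grad_loss must be fp32, and inplace=True overwrites `logits` to save memory.
grad_logits = xentropy_cuda_lib.backward(torch.ones_like(losses), logits,
                                         max_lse, labels, 0.1, False)
```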
@@ CheckpointManager ckpt = dict( enable_save_ckpt=False, # enable ckpt save. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - load_ckpt_folder=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"), - auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_folder'. + load_ckpt_info=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"), + auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_info'. checkpoint_every=CHECKPOINT_EVERY, async_upload=True, # async ckpt upload. (only work for boto3, volc and oss2 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. @@ -52,7 +52,7 @@ CheckpointManager 加载与存储格式约定 -------------------------- -.. _load-ckpt-folder: +.. _load-ckpt-info: (1) 路径格式约定 ~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -75,10 +75,10 @@ InternEvo对config中出现的所有存储路径都遵循以下的路径格式 -(2) 模型加载(load_ckpt_folder)格式约定 +(2) 模型加载(load_ckpt_info)格式约定 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -load_ckpt_folder 由三个字段组成, ``path`` 、 ``content`` 和 ``ckpt_type`` 。 +load_ckpt_info 由三个字段组成, ``path`` 、 ``content`` 和 ``ckpt_type`` 。 - ``path``:给出了检查点/初始化模型权重的加载路径(path的格式见下小节) @@ -92,17 +92,23 @@ load_ckpt_folder 由三个字段组成, ``path`` 、 ``content`` 和 ``ckpt_ty - ``ckpt_type``:表示加载的模型权重类型,目前支持的字段包括: - - ``internlm``:internevo约定的checkpoint存储格式。 + - ``internevo``:internevo约定的checkpoint存储格式。 + - ``llama``:llama约定的checkpoint存储格式。 + - ``hf_llama``:huggingface llama约定的checkpoint存储格式。 + - ``hf_model``:适用于加载huggingface所有模型的checkpoint存储格式。 下面给出两个例子: .. code-block:: python # 从文件存储相对路径 ckpt_model 中加载已有模型权重初始化模型,适合 sft 等训练初始化 - load_ckpt_folder= dict(path="local:ckpt_model", content=["model",], ckpt_type="internlm") + load_ckpt_info = dict(path="local:ckpt_model", content=("model",), ckpt_type="internevo") # 从文件存储相对路径 ckpt_model 中加载所有的状态,适合断点续训的场景 - load_ckpt_folder= dict(path="local:ckpt_model", content=["all",], ckpt_type="internlm") + load_ckpt_info = dict(path="local:ckpt_model", content=("all",), ckpt_type="internevo") + + # 从 huggingface 下载指定模型,加载checkpoint + load_ckpt_info = dict(path="internlm/internlm-7b", content=("model",), ckpt_type="hf_model") .. _asyncupload: @@ -144,13 +150,13 @@ config.ckpt 中相关的参数: 检查点自动加载功能的目的是在resume训练时,自动加载 ``save_ckpt_folder`` 路径下最新的检查点(包括snapshot检查点)。配合上自动重启机制,可以实现无人干预的任务自动恢复。 -该功能默认开启,所以要注意如果需要加载 ``load_ckpt_folder`` 路径下的模型权重,要将 ``auto_resume`` 设置为 False,否则可能会产生预期外的行为。 +该功能默认开启,所以要注意如果需要加载 ``load_ckpt_info`` 路径下的模型权重,要将 ``auto_resume`` 设置为 False,否则可能会产生预期外的行为。 config.ckpt 中相关的参数: - ``auto_resume``: 是否开启检查点自动恢复。参数类型 ``bool``,默认为 ``True``。 -``auto_resume`` 如果为True,则尝试从 ``save_ckpt_folder`` 路径中自动加载最新的ckpt,如果找不到,则从step 0开始训练。如果为False,则尝试从 ``load_ckpt_folder`` 中加载模型参数。 +``auto_resume`` 如果为True,则尝试从 ``save_ckpt_folder`` 路径中自动加载最新的ckpt,如果找不到,则从step 0开始训练。如果为False,则尝试从 ``load_ckpt_info`` 中加载模型参数。 .. _stopfile: diff --git a/doc/code-docs/source/initialize.rst b/doc/code-docs/source/initialize.rst index d1e7511b..ff3985ee 100644 --- a/doc/code-docs/source/initialize.rst +++ b/doc/code-docs/source/initialize.rst @@ -77,17 +77,19 @@ InternEvo 在配置文件中使用字段 ``model_type`` 和 ``model`` 来控制 - 字段 ``model_type`` 指明了要初始化的模型类型 - 字段 ``model`` 中的参数指定了在模型初始化过程中的参数设置 -值得注意的是,用户可以定义新的模型类型,并使用装饰器 ``@MODEL_INITIALIZER.register_module`` 注册模型的初始化函数,其中 ``MODEL_INITIALIZER`` 是类 ``internlm.util.registry.Registry`` 的一个实例化对象,示例如下所示: +值得注意的是,用户可以定义新的模型类型,并通过 ``register_module`` 注册模型的初始化函数,示例如下所示: .. 
code-block:: python - MODEL_TYPE = "NEW_MODEL" + model_initializer = Registry("model_initializer") - @MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) - def build_new_model_with_cfg(*args, **kwargs): + def register_model_initializer() -> None: + model_initializer.register_module("INTERNLM", InternLM1) .. _InternLM-optim-init: +其中,"INTERNLM"为新的模型类型,InternLM1为新模型的入口函数。 + 优化器初始化 ------------------------- diff --git a/doc/en/install.md b/doc/en/install.md index 8890c982..eae4a12c 100644 --- a/doc/en/install.md +++ b/doc/en/install.md @@ -1,17 +1,43 @@ ## Installation - ### Environment Preparation -The required packages and corresponding version are shown as follows: - Python == 3.10 +- GPU with Ampere or Hopper architecture (such as H100, A100) +- Linux OS + +### Install through pip +It is recommended to build a Python-3.10 virtual environment using conda, command is as follows: +```bash +conda create --name internevo python=3.10 -y +conda activate internevo +``` + +First, install the specified versions of torch, torchvision, torchaudio, and torch-scatter: +```bash +pip install --extra-index-url https://download.pytorch.org/whl/cu118 torch==2.1.0+cu118 torchvision==0.16.0+cu118 torchaudio==2.1.0+cu118 +pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu118.html +``` + +Install InternEvo: +```bash +pip install InternEvo +``` + +Install flash-attention (version v2.2.1): +```bash +pip install flash-attn==2.2.1 +``` + +Install Apex (version 23.05): +Apex is an optional package; If you choose to install it, follow the instructions in Install through source code. + +### Install through source code +#### Required Packages +The required packages and corresponding version are shown as follows: - GCC == 10.2.0 - MPFR == 4.1.0 - CUDA >= 11.8 - Pytorch >= 2.1.0 - Transformers >= 4.28.0 -- Flash-Attention >= v2.2.1 -- Apex == 23.05 -- GPU with Ampere or Hopper architecture (such as H100, A100) -- Linux OS After installing the above dependencies, some system environment variables need to be updated: ```bash @@ -24,15 +50,7 @@ export CC=${GCC_HOME}/bin/gcc export CXX=${GCC_HOME}/bin/c++ ``` -### Environment Installation -Install through pip command: -```bash -pip install InternEvo==xxx (xxx is the version you want to install) -``` -This installs only InternEvo project, do not involve the required packages or submodules. - -Or install through source code: - +#### Install Procedure Clone the project `InternEvo` and its dependent submodules from the github repository, as follows: ```bash git clone git@github.com:InternLM/InternEvo.git --recurse-submodules @@ -52,10 +70,8 @@ Install flash-attention (version v2.2.1): cd ./third_party/flash-attention python setup.py install cd ./csrc -cd fused_dense_lib && pip install -v . -cd ../xentropy && pip install -v . +cd xentropy && pip install -v . cd ../rotary && pip install -v . -cd ../layer_norm && pip install -v . cd ../../../../ ``` @@ -69,6 +85,11 @@ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation - cd ../../ ``` +### Additional Installation +```bash +pip install git+https://github.com/databricks/megablocks@v0.3.2 # MOE need +``` + ### Environment Image Users can use the provided dockerfile combined with docker.Makefile to build their own images, or obtain images with InternEvo runtime environment installed from https://hub.docker.com/r/internlm/internevo/tags. 
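After either install path in doc/en/install.md, a quick sanity check along these lines can confirm the pinned versions; this is a suggestion rather than part of the documented steps:

```python
import torch
import flash_attn

print(torch.__version__)          # expected: 2.1.0+cu118
print(flash_attn.__version__)     # expected: 2.2.1
print(torch.cuda.is_available())  # expected: True on an Ampere/Hopper machine
```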
@@ -100,3 +121,26 @@ The default directory in the container is `/InternEvo`, please start training ac ```bash torchrun --nproc_per_node=8 --nnodes=1 train.py --config configs/7B_sft.py --launcher torch ``` + +## Environment Installation (NPU) +For machines with NPU, the version of the installation environment can refer to that of GPU. Use Ascend's torch_npu instead of torch on NPU machines. Additionally, Flash-Attention and Apex are no longer supported for installation on NPU. The corresponding functionalities have been internally implemented in the InternEvo codebase. The following tutorial is only for installing torch_npu. + +Official documentation for torch_npu: https://gitee.com/ascend/pytorch + +### Example Installation of Environment +- Linux OS +- torch_npu: v2.1.0-6.0.rc1 +- NPU card: 910B + +#### Installing torch_run +Refer to the documentation: https://gitee.com/ascend/pytorch/tree/v2.1.0-6.0.rc1/ + +You can try installing according to the methods in the documentation or download the specified version of torch_npu from https://gitee.com/ascend/pytorch/releases for installation, as shown below: + +```bash +pip3 install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu +pip3 install pyyaml +pip3 install setuptools +wget https://gitee.com/ascend/pytorch/releases/download/v6.0.rc1-pytorch2.1.0/torch_npu-2.1.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +pip install torch_npu-2.1.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +``` diff --git a/doc/install.md b/doc/install.md index 8201d19b..f1493472 100644 --- a/doc/install.md +++ b/doc/install.md @@ -1,17 +1,43 @@ ## 环境安装 - ### 环境准备 -首先,需要安装的依赖包及对应版本列表如下: - Python == 3.10 +- Ampere或者Hopper架构的GPU (例如H100, A100) +- Linux OS + +### pip方式安装 +推荐使用 conda 构建一个 Python-3.10 的虚拟环境,命令如下: +```bash +conda create --name internevo-env python=3.10 -y +conda activate internevo-env +``` + +首先,安装指定版本的torch, torchvision, torchaudio以及torch-scatter: +```bash +pip install --extra-index-url https://download.pytorch.org/whl/cu118 torch==2.1.0+cu118 torchvision==0.16.0+cu118 torchaudio==2.1.0+cu118 +pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu118.html +``` + +安装InternEvo: +```bash +pip install InternEvo +``` + +安装 flash-attention (version v2.2.1): +```bash +pip install flash-attn==2.2.1 +``` + +安装 Apex (version 23.05): +apex为非必须安装包,如果安装,参考下述源码方式安装。 + +### 源码方式安装 +#### 依赖包 +首先,需要安装的依赖包及对应版本列表如下: - GCC == 10.2.0 - MPFR == 4.1.0 - CUDA >= 11.8 - Pytorch >= 2.1.0 - Transformers >= 4.28.0 -- Flash-Attention >= v2.2.1 -- Apex == 23.05 -- Ampere或者Hopper架构的GPU (例如H100, A100) -- Linux OS 以上依赖包安装完成后,需要更新配置系统环境变量: ```bash @@ -24,15 +50,8 @@ export CC=${GCC_HOME}/bin/gcc export CXX=${GCC_HOME}/bin/c++ ``` -### 环境安装 -可以通过pip命令直接安装,命令如下: -```bash - -pip install InternEvo==xxx (xxx是需要安装的版本号信息) -``` -这种方式仅安装了InternEvo项目,其依赖的软件包及子模块尚未安装。 - -也可以通过源码安装,将项目`InternEvo`及其依赖子模块,从 github 仓库中 clone 下来,命令如下: +#### 安装过程 +将项目`InternEvo`及其依赖子模块,从 github 仓库中 clone 下来,命令如下: ```bash git clone git@github.com:InternLM/InternEvo.git --recurse-submodules ``` @@ -51,10 +70,8 @@ pip install -r requirements/runtime.txt cd ./third_party/flash-attention python setup.py install cd ./csrc -cd fused_dense_lib && pip install -v . -cd ../xentropy && pip install -v . +cd xentropy && pip install -v . cd ../rotary && pip install -v . -cd ../layer_norm && pip install -v . 
cd ../../../../ ``` @@ -68,6 +85,11 @@ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation - cd ../../ ``` +### 额外安装 +```bash +pip install git+https://github.com/databricks/megablocks@v0.3.2 # MOE相关 +``` + ### 环境镜像 用户可以使用提供的 dockerfile 结合 docker.Makefile 来构建自己的镜像,或者也可以从 https://hub.docker.com/r/internlm/internevo/tags 获取安装了 InternEvo 运行环境的镜像。 @@ -99,3 +121,28 @@ docker run --gpus all -it -m 500g --cap-add=SYS_PTRACE --cap-add=IPC_LOCK --shm- ```bash torchrun --nproc_per_node=8 --nnodes=1 train.py --config configs/7B_sft.py --launcher torch ``` + +## 环境安装(NPU) +在搭载NPU的机器上安装环境的版本可参考GPU,在NPU上使用昇腾torch_npu代替torch,同时Flash-Attention和Apex不再支持安装,相应功能已由InternEvo代码内部实现。以下教程仅为torch_npu安装。 + +torch_npu官方文档:https://gitee.com/ascend/pytorch + +### 环境安装样例 +- Linux OS +- torch_npu: v2.1.0-6.0.rc1 +- NPU显卡:910B + + +#### 安装torch_run + +参考文档:https://gitee.com/ascend/pytorch/tree/v2.1.0-6.0.rc1/ + +安装时可尝试根据文档内方式安装,或者从 https://gitee.com/ascend/pytorch/releases 下载指定版本torch_npu进行安装,如下所示: + +```bash +pip3 install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu +pip3 install pyyaml +pip3 install setuptools +wget https://gitee.com/ascend/pytorch/releases/download/v6.0.rc1-pytorch2.1.0/torch_npu-2.1.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +pip install torch_npu-2.1.0.post3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl +``` diff --git a/doc/usage.md b/doc/usage.md index a1dcef62..ad78fe2e 100644 --- a/doc/usage.md +++ b/doc/usage.md @@ -83,32 +83,36 @@ MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" # Ckpt folder format: # fs: 'local:/mnt/nfs/XXX' SAVE_CKPT_FOLDER = "local:llm_ckpts" -LOAD_CKPT_FOLDER = "local:llm_ckpts/49" # boto3 Ckpt folder format: # import os # BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint # SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" CHECKPOINT_EVERY = 50 ckpt = dict( enable_save_ckpt=False, # enable ckpt save. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. - # load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"), - load_ckpt_folder="local:llm_ckpts/", # 'load_ckpt_info' setting guide: # 1. the 'path' indicate ckpt path, # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" - # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, now only 'normal' type is supported. - load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "llama", "hf_llama", "hf_model". + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internevo"), + # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering + # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) + # with an automatic restart mechanism upon training reboot. + # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint + # path specified in `load_ckpt_info` by default. + # If you want to initialize your model weights from another model, you must set `auto_resume` to False. + # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. 
+ auto_resume=True, checkpoint_every=CHECKPOINT_EVERY, async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. ) -TRAIN_FOLDER = "/path/to/dataset" -VALID_FOLDER = "/path/to/dataset" +TRAIN_FOLDER = None # "/path/to/dataset" +VALID_FOLDER = None # "/path/to/dataset" data = dict( seq_len=SEQ_LEN, # micro_num means the number of micro_batch contained in one gradient update @@ -122,13 +126,22 @@ data = dict( pack_sample_into_one=False, total_steps=50000, skip_batches="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. rampup_batch_size="", # Datasets with less than 50 rows will be discarded min_length=50, - # train_folder=TRAIN_FOLDER, - # valid_folder=VALID_FOLDER, - empty_cache_and_diag_interval=10, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, diag_outlier_ratio=1.1, + # whether use shared memory to load meta files + use_shm=False, + # when use shm, the default shm_path is "/dev/shm/metacache" + # shm_path="/dev/shm/metacache" ) grad_scaler = dict( @@ -153,11 +166,16 @@ grad_scaler = dict( hybrid_zero_optimizer = dict( # Enable low_level_optimzer overlap_communication overlap_sync_grad=True, - overlap_sync_param=True, + overlap_sync_param=False, # bucket size for nccl communication params reduce_bucket_size=512 * 1024 * 1024, # grad clipping clip_grad_norm=1.0, + # whether use new optm + use_split_tensor_optim=False, + # when use split tensor optm + # Perform all gather with a set of parameters of all_gather_size + all_gather_size=512 * 1024 * 1024, ) loss = dict( @@ -187,6 +205,7 @@ beta2_scheduler = dict( cur_iter=-1, ) +use_fp32_norm = False model = dict( checkpoint=False, # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1] num_attention_heads=NUM_ATTENTION_HEAD, @@ -198,28 +217,50 @@ model = dict( num_layers=NUM_LAYER, mlp_ratio=MLP_RATIO, apply_post_layer_norm=False, - dtype="torch.float16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" + dtype="torch.bfloat16", # Support: "torch.float16", "torch.half", "torch.bfloat16", "torch.float32", "torch.tf32" norm_type="rmsnorm", layer_norm_epsilon=1e-5, use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. ) """ -zero1 parallel: - 1. 
if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. - 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. + msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. pipeline parallel (dict): 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. -tensor parallel: tensor parallel size, usually the number of GPUs per node. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. 
""" parallel = dict( - zero1=8, + zero1=dict(size=-1), + tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), - sequence_parallel=False, + weight=dict(size=1, overlap=True, memory_pool=True), ) cudnn_deterministic = False @@ -231,6 +272,10 @@ monitor = dict( enable_feishu_alert=DO_ALERT, feishu_alert_address=None, # feishu webhook to send alert message light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, ), ) ``` @@ -264,23 +309,56 @@ data = dict( ``` 数据集的详细内容可参考``数据准备``模块相关的介绍。 +同时,也支持huggingface格式的数据集处理。 +train_folder设置为huggingface上可以通过load_dataset直接下载的数据集路径,如:"roneneldan/TinyStories" +在data中,需要新增type及tokenizer_path字段,标示数据集是huggingface格式,并指定tokenizer路径,如: +```python +TRAIN_FOLDER = "roneneldan/TinyStories" +SEQ_LEN = 2048 +data = dict( + type="hf", + tokenizer_path="internlm/internlm-7b", + seq_len=SEQ_LEN, # 数据样本长度,默认值为 2048 + micro_num=1, # micro_num 是指在一次模型参数更新中会处理的 micro_batch 的数目,默认值为 1 + micro_bsz=1, # packed_length = micro_bsz * SEQ_LEN,为一次处理的 micro_batch 的数据大小,默认值为 1 + total_steps=50000, # 总的所需执行的 step 的数目,默认值为 50000 + min_length=50, # 若数据集文件中,数据行数少于50,将会被废弃 + train_folder=TRAIN_FOLDER, # 数据集文件路径,默认值为 None;若 train_folder 为空,则以自动生成的随机数据集 +进行训练测试 + pack_sample_into_one=False, # 数据整理的逻辑,决定是按照 seq_len 维度或者是 sequence 的真实长度来进行attention计算 +) +``` + #### 模型配置 如果在启动训练时要加载模型 `checkpoint`,可进行如下相关配置: ```python SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt" -LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt" +# MODEL_ONLY_FOLDER = "internlm/internlm-7b" +MODEL_ONLY_FOLDER = "local:/path/to/load/resume/ckpt" ckpt = dict( + enable_save_ckpt=True, # 是否开启保存 checkpoint 功能 save_ckpt_folder=SAVE_CKPT_FOLDER, # 存储模型和优化器 checkpoint 的路径 checkpoint_every=float("inf"), # 每多少个 step 存储一次 checkpoint,默认值为 inf # 断点续训时,加载模型和优化器等权重的路径,将从指定的 step 恢复训练 # content 表示哪些状态会被加载,支持: "model", "sampler", "optimizer", "scheduler", "all" - # ckpt_type 表示加载的模型类型,目前支持: "internlm" + # ckpt_type 表示加载的模型类型,目前支持: "internevo", "llama", "hf_llama", "hf_model" + # 其中,"hf_model"类型表示从huggingface上下载模型加载ckpt,MODEL_ONLY_FOLDER需要设置为可以 + # 通过AutoModel直接加载的模型路径,如:"internlm/internlm-7b" load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"), + # 'auto_resume' 旨在在遇到由硬件故障引起的训练中断/挂起时,自动从 'save_ckpt_folder' 加载最新的检查点, + # 使用调度系统(例如 k8s/slurm)在训练重启时自动重启机制。 + # 请注意,如果未设置 auto_resume(其默认值为 True),它将不会默认加载 load_ckpt_info 中指定的检查点路径。 + # 如果你想从另一个模型初始化你的模型权重,必须将 auto_resume 设置为 False。 + # 如果你想从头开始训练,请将 auto_resume 设置为 False 并将 'load_ckpt_info' 设置为 None。 + auto_resume=False, + async_upload=True, # 异步检查点上传。(仅适用于 boto3 检查点) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # 异步上传期间临时文件的路径。 + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # 快照检查点保存频率。 ) ``` 注意: -- 路径若以 `local:` 为前缀,则存储在本地文件系统;若以 `boto3:` 为前缀,则存储在远程 oss 上 +- 路径若以 `local:` 为前缀,则存储在本地文件系统;若以 `boto3:` 为前缀,则存储在远程 oss 上;若无前缀,为huggingface上可以直接下载的模型路径。 模型相关关键参数配置如下所示: ```python @@ -306,7 +384,7 @@ model = dict( layer_norm_epsilon=1e-5, ) ``` -注意:用户可自定义模型类型名和模型结构,并配置相对应的模型参数。通过`utils/registry.py`下的`MODEL_INITIALIZER`对象进行模型初始化函数接口注册,在训练主函数`train.py`中初始化模型时,可通过`model_type`配置获取指定的模型初始化接口函数。 +注意:用户可自定义模型类型名和模型结构,并配置相对应的模型参数。通过`internlm/model/registry.py`下的`model_initializer`对象进行模型初始化函数接口注册,在训练主函数`train.py`中初始化模型时,可通过`model_type`配置获取指定的模型初始化接口函数。 *如果基于 InternLM 7B继续训练,可以参考 [ModelZoo](https://github.com/InternLM/InternLM/tree/main#model-zoo) 中 OpenXLab 链接下载权重* @@ -315,21 +393,32 @@ model = dict( 训练并行配置样例如下: ```python 
parallel = dict( - zero1=8, - tensor=1, + zero1=dict(size=-1), + tensor=dict(size=1, mode="mtp"), pipeline=dict(size=1, interleaved_overlap=True), - sequence_parallel=False, + weight=dict(size=1, overlap=True, memory_pool=True), ) ``` -- zero1:zero 并行策略,分如下三种情况,默认值为 -1 - - 当`zero1 <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配 - - 当`zero1 == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数 - - 当`zero1 > 1`且`zero1 <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集 -- tensor:张量并行大小,通常是每个节点的 GPU 数量,默认值为 1 -- pipeline:流水线并行策略 - - size:流水线并行大小,默认值为 1 - - interleaved_overlap:bool 类型,交错式调度时,开启或关闭通信优化,默认值为关闭 -- sequence_parallel:是否开启序列化并行,默认值为 False +- zero1(字典): + 1. size: 整数 + - 当`zero1 <= 0`,则 zero1 进程组的大小等于数据并行进程组的大小,因此优化器状态参数将在数据并行范围内分配 + - 当`zero1 == 1`,则不使用 zero1 ,所有数据并行组保留完整的优化器状态参数 + - 当`zero1 > 1`且`zero1 <= data_parallel_world_size`,则 zero1 进程组是数据并行进程组的子集 + 2. fsdp: 布尔值,启用/禁用torch的完全分片数据并行,默认为False。 +- tensor(字典): + 1. size: 整数,张量并行的大小。 + 2. mode: 字符串,张量并行模式,应该是 ['mtp', 'msp', 'fsp', 'isp'] 中的一个, + - 默认为 'mtp',意味着没有序列并行的纯Megatron张量并行。 + - msp: 带序列并行的Megatron张量并行,序列并行大小 = 张量并行大小。 + - fsp: 通过flash-attn带序列并行的张量并行,序列并行大小 = 张量并行大小。 + - isp: 定制的内部序列并行,不带张量并行,可以与权重并行一起使用。 +- pipeline(字典): + 1. size: 整数,流水线并行的大小。 + 2. interleaved_overlap: 布尔值,启用/禁用在使用交错流水线调度器时的通信重叠,默认为False。 +- weight(字典): + 1. size: 整数,权重并行的大小。 + 2. overlap: 布尔值,启用/禁用all_gather/reduce_scatter通信重叠,默认为False。 + 3. memory_pool: 布尔值,启用/禁用内存池,默认为False。 注意:`数据并行大小 = 总的 GPU 数目 / 流水线并行大小 / 张量并行大小` @@ -370,6 +459,31 @@ $ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py - 2023-07-07 12:29:16,994 INFO train.py:323 in record_current_batch_training_metrics -- tflops=189.3109313713174,step=5,loss=9.822169303894043,tgs (tokens/gpu/second)=4262.67,lr=1.4000000000000001e-06,loss_scale=65536.0,grad_norm=47.10386835560855,micro_num=4,num_consumed_tokens=786432,inf_nan_skip_batches=0,num_samples_in_batch=17,largest_length=2048,largest_batch=6,smallest_batch=3,adam_beta2=0.95,fwd_bwd_time=3.69 ``` +### 加载训练的checkpoint并生成 + +若在 slurm 上启动分布式运行环境,多节点 16 卡的运行命令如下所示: +```bash +$ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python generate.py --config ./configs/7B_sft.py +``` + +在配置文件中添加`generation`配置 +``` +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) +``` + ### 长文本生成 在推理阶段,我们可以使用 Dynamic NTK RoPE 来代替原始的 RoPE,从而使得模型能够适应长文本的输入输出,达到 16K 的外推效果。 diff --git a/generate.py b/generate.py new file mode 100644 index 00000000..4ae76029 --- /dev/null +++ b/generate.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import gc +import json +import logging +import os +import shutil +import socket +import traceback +from pathlib import Path + +import numpy as np +import torch +from tqdm import tqdm + +from internlm.accelerator import get_accelerator +from internlm.apis.inference import SequenceGenerator +from internlm.core.context import global_context as gpc +from internlm.data import build_generation_loader_with_data_type +from internlm.initialize import initialize_distributed_env +from internlm.monitor import initialize_monitor_manager +from internlm.monitor.monitor import monitor_manager as mm +from internlm.train import initialize_model, initialize_parallel_communicator +from internlm.utils.common import ( + enable_pytorch_expandable_segments, + launch_time, + parse_args, +) 
+from internlm.utils.gputest import empty_cache_and_diag +from internlm.utils.logger import get_logger +from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.parallel import get_parallel_log_file_name +from internlm.utils.storage_manager import init_storage_manager +from tools.load_internlm2_model import get_model_device, merge_pp_within_tp + +# global llm logger +logger = logging.getLogger(__file__) +internlm_accelerator = get_accelerator() + + +def get_latest_subdirectory(folder_path): + if ":" in folder_path: + prefix, folder_path = folder_path.split(":", 1) + prefix += ":" + else: + prefix = "" + subdirectories = [name for name in os.listdir(folder_path) if os.path.isdir(os.path.join(folder_path, name))] + subdirectories_sorted = sorted( + subdirectories, key=lambda x: os.path.getctime(os.path.join(folder_path, x)), reverse=True + ) + if subdirectories_sorted: + return prefix + os.path.join(folder_path, subdirectories_sorted[0]) + else: + return None + + +def main(): + enable_pytorch_expandable_segments() + + generation_config = gpc.config["generation"] + + generation_config = type( + "", + (object,), + { + "output_folder": Path(generation_config["output_folder"]), + "ckpt_folder": generation_config["ckpt_folder"] + if "ckpt_folder" in generation_config + else get_latest_subdirectory(gpc.config.ckpt.save_ckpt_folder), + "data_folder": generation_config["data_folder"] if "data_folder" in generation_config else None, + "batch_size": generation_config.get("batch_size", None), + "eos_id": generation_config.get("eos_id", 2), + "bos_id": generation_config.get("bos_id", 1), + "pad_id": generation_config.get("bos_id", 1), + "additional_eos_token_list": generation_config.get("additional_eos_token_list", None), + "max_length": generation_config.get("max_length", 100), + "do_sample": generation_config.get("do_sample", True), + "temperature": generation_config.get("temperature", 1.0), + "num_beams": generation_config.get("num_beams", 1), + "top_k": generation_config.get("top_k", 50), + "top_p": generation_config.get("top_p", 1.0), + "repetition_penalty": generation_config.get("repetition_penalty", 1), + "length_penalty": generation_config.get("length_penalty", 1.0), + }, + ) + + if not os.path.exists(generation_config.output_folder.absolute()): + generation_config.output_folder.mkdir(exist_ok=True, parents=True) + + # get and broadcast current time + current_time = launch_time() + objs = [current_time] + torch.distributed.broadcast_object_list(objs, src=0) + current_time = objs[0].replace(":", ".") + global logger + logger = get_logger( + __file__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name() + ) + + try: + init_storage_manager(False, None, None) + except AssertionError: + pass + except Exception as e: + raise e + + # initialize model + model = initialize_model() + _ = initialize_parallel_communicator(model) + model = model.model + + state_dict = merge_pp_within_tp(generation_config.ckpt_folder, del_model_prefix=True) + missing_k, unexpected_keys = model.load_state_dict(state_dict, strict=False) + if len(missing_k) != 0: + logger.warning(f"Warning: missing keys {missing_k}") + if len(unexpected_keys) != 0: + logger.warning(f"Warning: unexpected keys {unexpected_keys}") + + param_dtype = gpc.config.model.dtype + if isinstance(param_dtype, str): + try: + param_dtype = eval(param_dtype) # pylint: disable=W0123 + finally: + pass + if param_dtype == "torch.tf32": + param_dtype = torch.float32 + 
torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + model.to(param_dtype) + model.eval() + torch.distributed.barrier() + + data_cfg = gpc.config.data + if generation_config.data_folder: + data_cfg.valid_folder = generation_config.data_folder + gene_dls = build_generation_loader_with_data_type(data_cfg, generation_config) + + sequenece_generator = SequenceGenerator( + decoder=model, + eos_token_id=generation_config.eos_id, + pad_token_id=generation_config.bos_id, + bos_token_id=generation_config.pad_id, + additional_eos_token_list=generation_config.additional_eos_token_list, + ) + + ds_count = 0 + gc.disable() + with torch.inference_mode(): + for ds_name, gene_dl in gene_dls.items(): + if len(gene_dl) == 0: + logger.info(f"Validation dataset: {ds_name} is empty") + continue + timer(f"dataset {ds_count}").start() + + # pylint: disable=forgotten-debug-statement + all_output_str = [] + # pylint: disable=unused-variable + for val_idx, (labels, input_ids) in tqdm( + enumerate(gene_dl), + desc="generate.", + total=len(gene_dl), + position=1, + leave=False, + ): + empty_cache_and_diag(val_idx, interval=gpc.config.data.empty_cache_and_diag_interval) + input_ids = torch.LongTensor(input_ids) + if input_ids.size(1) >= generation_config.max_length: + logger.warning( + f"Not generating for the {val_idx}'th batch, because the sequence " + f"length of the batch is {input_ids.size(1)} over the max generation" + f"length {generation_config.max_length}" + ) + output_ids = input_ids[:, : generation_config.max_length, ...] + else: + input_ids = input_ids.clamp(min=0, max=gpc.config.model.vocab_size).to(get_model_device(model)) + output_ids = sequenece_generator.generate( + tokens=input_ids, + max_length=generation_config.max_length, + do_sample=generation_config.do_sample, + temperature=generation_config.temperature, + num_beams=generation_config.num_beams, + top_k=generation_config.top_k, + top_p=generation_config.top_p, + repetition_penalty=generation_config.repetition_penalty, + length_penalty=generation_config.length_penalty, + ) + for output in output_ids: + not_pad_indices = torch.nonzero(output != generation_config.pad_id) + if not_pad_indices.nelement() != 0: + sequence = output[not_pad_indices[0] :] + else: + sequence = output + sequence = sequence.tolist() + line = str.encode(json.dumps({"tokens": sequence})) + all_output_str.append( + ( + line, + len(line), + ) + ) + + bin_meta, last_position = [], 0 + with open(generation_config.output_folder.joinpath(f"{ds_name}.bin"), "wb") as file: + for line, token_num in all_output_str: + file.write(line) + bin_meta.append((last_position, token_num)) + last_position += len(line) + + with open(generation_config.output_folder.joinpath(f"{ds_name}.bin.meta"), "wb") as file: + np.save(file, bin_meta) + + timer(f"dataset {ds_count}").stop() + ds_count += 1 + + +if __name__ == "__main__": + args = parse_args() + hostname = socket.gethostname() + + # initialize distributed environment + initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) + assert hasattr(gpc, "config") and gpc.config is not None + assert "generation" in gpc.config, f"Please set `generation` config in `{args.config}` file" + assert ( + "output_folder" in gpc.config["generation"] + ), "Must set `output_folder` for the save folder of generation data" + + # initialize monitor manager context + with initialize_monitor_manager( + job_name=gpc.config.JOB_NAME, 
alert_address=gpc.config.monitor.alert.feishu_alert_address + ): + try: + main() + except Exception: + logger.error( + f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", + ) + mm.monitor_exception( + alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc() + ) + + # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + finally: + # local rank0 delete all files in shm_path, when use shm + devices_per_node = internlm_accelerator.device_count() + local_rank = gpc.get_global_rank() % devices_per_node + if gpc.config.data.use_shm and local_rank == 0: + if os.path.exists(gpc.config.data.shm_path): + shutil.rmtree(gpc.config.data.shm_path) diff --git a/internlm/apis/__init__.py b/internlm/apis/__init__.py index e69de29b..ba807b5e 100644 --- a/internlm/apis/__init__.py +++ b/internlm/apis/__init__.py @@ -0,0 +1,6 @@ +from .inference_utils import InferenceParams, process_parallel_output + +__all__ = [ + "InferenceParams", + "process_parallel_output", +] diff --git a/internlm/apis/inference.py b/internlm/apis/inference.py index 7a51e34d..d3b5de87 100644 --- a/internlm/apis/inference.py +++ b/internlm/apis/inference.py @@ -1,48 +1,18 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import List, Tuple +from typing import Dict, List, Tuple, Union import torch import torch.nn.functional as F from torch import nn -__all__ = ["SequenceGenerator"] - - -class InferenceParams: - """ - Intermediate cache objects for inference - """ +from internlm.apis import InferenceParams, process_parallel_output +from internlm.core.context import ParallelMode # noqa: E402 +from internlm.core.context import global_context as gpc # noqa: E402 +from internlm.core.trainer import Trainer - def __init__( - self, - max_sequence_len, - max_batch_size, - sequence_len_offset=0, - batch_size_offset=0, - key_value_memory_dict: dict = None, - lengths_per_sample=None, - attention_mask=None, - ) -> None: - - self.max_sequence_len: int = max_sequence_len - self.max_batch_size: int = max_batch_size - self.sequence_len_offset: int = sequence_len_offset - self.batch_size_offset: int = batch_size_offset - if key_value_memory_dict is None: - key_value_memory_dict = {} - self.key_value_memory_dict: dict = key_value_memory_dict - self.fused_ft_kernel: bool = False - self.lengths_per_sample = lengths_per_sample - self.attention_mask = attention_mask - - def reorder_state(self, indices): - if self.lengths_per_sample is not None: - self.lengths_per_sample = self.lengths_per_sample.index_select(index=indices, dim=0) - for key, value in list(self.key_value_memory_dict.items()): - value = value.index_select(index=indices, dim=0) - self.key_value_memory_dict[key] = value +__all__ = ["SequenceGenerator"] def _get_model_device(model): @@ -357,17 +327,8 @@ def _streaming_no_beam_search_generate( eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device) has_bos = torch.all(tokens[:, 0].eq(bos_token_id)) - if has_bos: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - else: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) + attention_mask = get_attention_mask(tokens, has_bos, bos_token_id=bos_token_id) + if 
inference_params is None: inference_params = InferenceParams( max_sequence_len=max_length, @@ -379,7 +340,16 @@ def _streaming_no_beam_search_generate( attention_mask=attention_mask, ) - scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + if isinstance(decoder, torch.nn.Module): + scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + data = {"input_ids": tokens, "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True + ) + scores = torch.cat(model_output, dim=0) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") if isinstance(scores, (list, tuple)): scores = scores[0] @@ -401,19 +371,20 @@ def _streaming_no_beam_search_generate( while cur_len < real_max_length: # batch_size x vocab_size - if has_bos: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] + attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id) + + if isinstance(decoder, torch.nn.Module): + inference_params.attention_mask = attention_mask + scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + inference_params.set_attention_mask(attention_mask) + data = {"input_ids": token_ids[:, -1:], "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True + ) + scores = torch.cat(model_output, dim=0) else: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) - inference_params.attention_mask = attention_mask - scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") if isinstance(scores, (list, tuple)): scores = scores[0] @@ -502,17 +473,9 @@ def _no_beam_search_generate( eos_token_id = torch.LongTensor(eos_token_id).to(tokens.device) has_bos = torch.all(tokens[:, 0].eq(bos_token_id)) - if has_bos: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - else: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) + + attention_mask = get_attention_mask(tokens, has_bos, bos_token_id) + if inference_params is None: inference_params = InferenceParams( max_sequence_len=max_length, @@ -524,75 +487,104 @@ def _no_beam_search_generate( attention_mask=attention_mask, ) - scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + if isinstance(decoder, torch.nn.Module): + scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + data = {"input_ids": tokens, "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True 
+ ) + scores = process_parallel_output(model_output) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") - if isinstance(scores, (list, tuple)): - scores = scores[0] - scores = scores[:, -1].float() - inference_params.sequence_len_offset += tokens.size(1) - if eos_token_id is not None: - scores[:, eos_token_id] = -1e12 + if gpc.is_last_rank(ParallelMode.PIPELINE): + if isinstance(scores, (list, tuple)): + scores = scores[0] + scores = scores[:, -1].float() + if eos_token_id is not None: + scores[:, eos_token_id] = -1e12 - # The first token generated. - next_tokens = scores.argmax(dim=-1, keepdim=True) + # The first token generated. + next_tokens = scores.argmax(dim=-1, keepdim=True) + else: + next_tokens = tokens.new_zeros([batch_size, 1]) + if gpc.is_initialized(ParallelMode.PIPELINE): + # broadcast to other rank in PP group + torch.distributed.broadcast( + next_tokens, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) token_ids = torch.cat([tokens, next_tokens], dim=1) cur_len = token_ids.size(1) dones = token_ids.new_zeros(batch_size).eq(1) + inference_params.sequence_len_offset += tokens.size(1) + real_max_length = max_length max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long) while cur_len < real_max_length: # batch_size x vocab_size - if has_bos: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) + attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id) + + if isinstance(decoder, torch.nn.Module): + inference_params.attention_mask = attention_mask + scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + inference_params.set_attention_mask(attention_mask) + data = {"input_ids": token_ids[:, -1:], "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True + ) + scores = process_parallel_output(model_output) else: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) - inference_params.attention_mask = attention_mask - scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") - if isinstance(scores, (list, tuple)): - scores = scores[0] - scores = scores[:, -1].float() inference_params.sequence_len_offset += 1 - - if repetition_penalty != 1.0: - token_scores = scores.gather(dim=1, index=token_ids) - lt_zero_mask = token_scores.lt(0).float() - ge_zero_mask = lt_zero_mask.eq(0).float() - token_scores = ( - lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores - ) - scores.scatter_(dim=1, index=token_ids, src=token_scores) - # scores: [bsz, vocab_size] - if eos_token_id is not None and length_penalty != 1.0: - # batch_size x vocab_size - eos_token_scores = scores[:, eos_token_id].clone() - scores = scores / cur_len**length_penalty - scores[:, 
eos_token_id] = eos_token_scores - del eos_token_scores - - if do_sample: - if temperature > 0 and temperature != 1: - scores = scores / temperature - - scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2) - # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523 - probs = F.softmax(scores, dim=-1) + 1e-12 - - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size + if gpc.is_last_rank(ParallelMode.PIPELINE): + if isinstance(scores, (list, tuple)): + scores = scores[0] + scores = scores[:, -1].float() + + if repetition_penalty != 1.0: + token_scores = scores.gather(dim=1, index=token_ids) + lt_zero_mask = token_scores.lt(0).float() + ge_zero_mask = lt_zero_mask.eq(0).float() + token_scores = ( + lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores + ) + scores.scatter_(dim=1, index=token_ids, src=token_scores) + # scores: [bsz, vocab_size] + if eos_token_id is not None and length_penalty != 1.0: + # batch_size x vocab_size + eos_token_scores = scores[:, eos_token_id].clone() + scores = scores / cur_len**length_penalty + scores[:, eos_token_id] = eos_token_scores + del eos_token_scores + + if do_sample: + if temperature > 0 and temperature != 1: + scores = scores / temperature + + scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=2) + # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523 + probs = F.softmax(scores, dim=-1) + 1e-12 + + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) # batch_size + else: + next_tokens = torch.argmax(scores, dim=-1) # batch_size else: - next_tokens = torch.argmax(scores, dim=-1) # batch_size - + next_tokens = tokens.new_zeros(batch_size) + + if gpc.is_initialized(ParallelMode.PIPELINE): + # broadcast to other rank in PP group + torch.distributed.broadcast( + next_tokens, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) if eos_token_id is not None: # When the generated result exceeds the length, its eos_token_id is set to the most basic terminator. 
next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), eos_token_id[0]) @@ -640,7 +632,7 @@ def _beam_search_generate( bos_token_id=1, ) -> torch.LongTensor: - device = _get_model_device(decoder) + device = tokens.device batch_size = tokens.size(0) if eos_token_id is not None: @@ -654,19 +646,7 @@ def _beam_search_generate( has_bos = torch.all(tokens[:, 0].eq(bos_token_id)) - if has_bos: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) - else: - bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) + attention_mask = get_attention_mask(tokens, has_bos, bos_token_id=bos_token_id) if inference_params is None: inference_params = InferenceParams( @@ -679,29 +659,56 @@ def _beam_search_generate( attention_mask=attention_mask, ) - scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + if isinstance(decoder, torch.nn.Module): + scores = decoder(**{"input_ids": tokens, "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + data = {"input_ids": tokens, "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True + ) + scores = process_parallel_output(model_output) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") - if isinstance(scores, (list, tuple)): - scores = scores[0] - scores = scores[:, -1].float() inference_params.sequence_len_offset += tokens.size(1) - if eos_token_id is not None: - scores[:, eos_token_id] = -1e12 - vocab_size = scores.size(1) - assert vocab_size >= num_beams, "num_beams should be smaller than " "the number of vocabulary size." - # The first token generated. - if do_sample: - probs = F.softmax(scores, dim=-1) + 1e-12 - # (batch_size, num_beams) - next_tokens = torch.multinomial(probs, num_samples=num_beams) - logits = probs.log() - # (batch_size, num_beams) - next_scores = logits.gather(dim=1, index=next_tokens) + if gpc.is_last_rank(ParallelMode.PIPELINE): + if isinstance(scores, (list, tuple)): + scores = scores[0] + scores = scores[:, -1].float() + if eos_token_id is not None: + scores[:, eos_token_id] = -1e12 + vocab_size = scores.size(1) + assert vocab_size >= num_beams, "num_beams should be smaller than " "the number of vocabulary size." + + # The first token generated. 
+ if do_sample: + probs = F.softmax(scores, dim=-1) + 1e-12 + # (batch_size, num_beams) + next_tokens = torch.multinomial(probs, num_samples=num_beams) + logits = probs.log() + # (batch_size, num_beams) + next_scores = logits.gather(dim=1, index=next_tokens) + else: + scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size) + # obtain (batch_size, num_beams), (batch_size, num_beams) + next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) else: - scores = F.log_softmax(scores, dim=-1) # (batch_size, vocab_size) - # obtain (batch_size, num_beams), (batch_size, num_beams) - next_scores, next_tokens = torch.topk(scores, num_beams, dim=1, largest=True, sorted=True) + next_tokens = tokens.new_zeros([batch_size, num_beams]) + next_scores = torch.zeros([batch_size, num_beams], dtype=torch.float32, device=next_tokens.device) + + if gpc.is_initialized(ParallelMode.PIPELINE): + # broadcast to other rank in PP group + torch.distributed.broadcast( + next_tokens, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) + torch.distributed.broadcast( + next_scores, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) indices = torch.arange(batch_size, dtype=torch.long).to(device) indices = indices.repeat_interleave(num_beams) @@ -726,79 +733,102 @@ def _beam_search_generate( batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids) while cur_len < real_max_length: - if has_bos: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - bos_sum = bos_pos.cumsum(dim=-1) - bos_pos = torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) - else: - bos_pos = torch.where(token_ids.eq(bos_token_id), 1, 0) - to_atten_x = bos_pos[:, :, None] - to_atten_y = bos_pos[:, None, :] - # attention_mask = torch.einsum('bno,bom->bnm', to_atten_x, to_atten_y).eq(1) - attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) + attention_mask = get_attention_mask(token_ids, has_bos, bos_token_id=bos_token_id) - inference_params.attention_mask = attention_mask # (bsz x num_beams, vocab_size) - scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) - - if isinstance(scores, (list, tuple)): - scores = scores[0] - scores = scores[:, -1].float() - inference_params.sequence_len_offset += 1 - if repetition_penalty != 1.0: - token_scores = scores.gather(dim=1, index=token_ids) - lt_zero_mask = token_scores.lt(0).float() - ge_zero_mask = lt_zero_mask.eq(0).float() - token_scores = ( - lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores + if isinstance(decoder, torch.nn.Module): + inference_params.attention_mask = attention_mask + scores = decoder(**{"input_ids": token_ids[:, -1:], "inference_params": inference_params}) + elif isinstance(decoder, Trainer): + inference_params.set_attention_mask(attention_mask) + data = {"input_ids": token_ids[:, -1:], "inference_params": inference_params} + model_output, _, _ = decoder.execute_schedule( + (data, None), forward_only=True, return_loss=False, return_output_label=True ) - scores.scatter_(dim=1, index=token_ids, src=token_scores) - - if eos_token_id is not None: - max_len_eos_mask = max_lengths.eq(cur_len + 1) - # When the generated result exceeds the length, its 
eos_token_id is set to the most basic terminator. - eos_scores = scores[:, eos_token_id[0]] - scores[:, eos_token_id[0]] = torch.where(max_len_eos_mask, eos_scores + 1e32, eos_scores) - - if do_sample: - if temperature > 0 and temperature != 1: - scores = scores / temperature - - scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) - # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523 - probs = F.softmax(scores, dim=-1) + 1e-12 + scores = process_parallel_output(model_output) + else: + raise NotImplementedError(f"Unsupported decoder type: {type(decoder)}") - # batch_size' x (num_beams+1) - _tokens = torch.multinomial(probs, num_samples=num_beams + 1) + inference_params.sequence_len_offset += 1 - logits = probs.log() - # batch_size' x (num_beams+1) - _scores = logits.gather(dim=1, index=_tokens) - # batch_size' x (num_beams+1) - _scores = _scores + beam_scores[:, None] - _scores = _scores.view(batch_size, num_beams * (num_beams + 1)) - next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True) - _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1)) - # (batch_size, 2*num_beams) - next_tokens = _tokens.gather(dim=1, index=ids) - # (batch_size, 2*num_beams) - from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long() + if gpc.is_last_rank(ParallelMode.PIPELINE): + + if isinstance(scores, (list, tuple)): + scores = scores[0] + scores = scores[:, -1].float() + if repetition_penalty != 1.0: + token_scores = scores.gather(dim=1, index=token_ids) + lt_zero_mask = token_scores.lt(0).float() + ge_zero_mask = lt_zero_mask.eq(0).float() + token_scores = ( + lt_zero_mask * repetition_penalty * token_scores + ge_zero_mask / repetition_penalty * token_scores + ) + scores.scatter_(dim=1, index=token_ids, src=token_scores) + + if eos_token_id is not None: + max_len_eos_mask = max_lengths.eq(cur_len + 1) + # When the generated result exceeds the length, its eos_token_id is set to the most basic terminator. 
+ eos_scores = scores[:, eos_token_id[0]] + scores[:, eos_token_id[0]] = torch.where(max_len_eos_mask, eos_scores + 1e32, eos_scores) + + if do_sample: + if temperature > 0 and temperature != 1: + scores = scores / temperature + + scores = top_k_top_p_filtering(scores, top_k, top_p, min_tokens_to_keep=num_beams + 1) + # add 1e-12 to avoid https://github.com/pytorch/pytorch/pull/27523 + probs = F.softmax(scores, dim=-1) + 1e-12 + + # batch_size' x (num_beams+1) + _tokens = torch.multinomial(probs, num_samples=num_beams + 1) + + logits = probs.log() + # batch_size' x (num_beams+1) + _scores = logits.gather(dim=1, index=_tokens) + # batch_size' x (num_beams+1) + _scores = _scores + beam_scores[:, None] + _scores = _scores.view(batch_size, num_beams * (num_beams + 1)) + next_scores, ids = _scores.topk(2 * num_beams, dim=1, largest=True, sorted=True) + _tokens = _tokens.view(batch_size, num_beams * (num_beams + 1)) + # (batch_size, 2*num_beams) + next_tokens = _tokens.gather(dim=1, index=ids) + # (batch_size, 2*num_beams) + from_which_beam = torch.floor(ids.float() / (num_beams + 1)).long() + else: + # (batch_size * num_beams, vocab_size) + scores = F.log_softmax(scores, dim=-1) + # (batch_size * num_beams, vocab_size) + _scores = scores + beam_scores[:, None] + # (batch_size, num_beams*vocab_size) + _scores = _scores.view(batch_size, -1) + # (bsz, 2*num_beams) + next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) + # (batch_size, 2*num_beams) + from_which_beam = torch.floor(ids.float() / vocab_size).long() + next_tokens = ids % vocab_size # (batch_size, 2*num_beams) else: - # (batch_size * num_beams, vocab_size) - scores = F.log_softmax(scores, dim=-1) - # (batch_size * num_beams, vocab_size) - _scores = scores + beam_scores[:, None] - # (batch_size, num_beams*vocab_size) - _scores = _scores.view(batch_size, -1) - # (bsz, 2*num_beams) - next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True) - # (batch_size, 2*num_beams) - from_which_beam = torch.floor(ids.float() / vocab_size).long() - next_tokens = ids % vocab_size # (batch_size, 2*num_beams) + next_tokens = tokens.new_zeros([batch_size, 2 * num_beams]) + next_scores = torch.zeros([batch_size, 2 * num_beams], dtype=torch.float32, device=next_tokens.device) + from_which_beam = torch.zeros([batch_size, 2 * num_beams], dtype=torch.int64, device=next_tokens.device) + + if gpc.is_initialized(ParallelMode.PIPELINE): + # broadcast to other rank in PP group + torch.distributed.broadcast( + next_tokens, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) + torch.distributed.broadcast( + next_scores, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) + torch.distributed.broadcast( + from_which_beam, + src=gpc.get_ranks_in_group(ParallelMode.PIPELINE)[-1], + group=gpc.get_group(ParallelMode.PIPELINE), + ) not_eos_mask = torch.all(next_tokens[..., None].ne(eos_token_id), dim=-1) keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams) @@ -964,3 +994,128 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("Inf") indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) logits[indices_to_remove] = filter_value return logits + + +@torch.no_grad() +def get_attention_mask(tokens, has_bos, bos_token_id=1): + if has_bos: + bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) + bos_sum = bos_pos.cumsum(dim=-1) + bos_pos = 
torch.where(bos_sum.eq(bos_sum[:, -1:]), 0, 1) + to_atten_x = bos_pos[:, :, None] + to_atten_y = bos_pos[:, None, :] + else: + bos_pos = torch.where(tokens.eq(bos_token_id), 1, 0) + to_atten_x = bos_pos[:, :, None] + to_atten_y = bos_pos[:, None, :] + # attention_mask = torch.logical_or(to_atten_x, to_atten_y).eq(1) + to_atten_y_new = to_atten_y.repeat(1, to_atten_x.shape[1], 1) + to_atten_x_new = to_atten_x.repeat(1, 1, to_atten_y.shape[2]) + attention_mask = torch.logical_or(to_atten_x_new, to_atten_y_new).eq(1) + + return attention_mask + + +def batch_tokenize_process_fn( + batch: Union[List[str], List[Dict], Dict], tokenizer, add_bos: bool = True, add_eos: bool = False +) -> Union[List, Dict]: + """Data post-processing function for tokenize. + + This function can be directly used in the map function of ``DatasetDict`` and supports batched=True. + + Args: + batch (Union[List[str], List[Dict], Dict]): Data used to tokenize which can be of the following + categories: + (a) A list whose content can be a string or a dictionary. If it is a dictionary, + it needs to contain the "content" field; + (b) A dictionary-like object, which should contain the "content" field. + tokenizer : Currently only sentencepiece is supported. + add_bos (bool, optional): Whether to add bos token. Defaults to True. + add_eos (bool, optional): Whether to add eos token. Defaults to False. + + Returns: + Union[List, Dict]: tokenized data. + """ + + def _tokenize(text): + tokens = [tokenizer.bos_id()] if add_bos else [] + tokens += tokenizer.encode(text) + if add_eos: + tokens.append(tokenizer.eos_id()) + return tokens + + if isinstance(batch, (List, Tuple)): + if len(batch) == 0: + return None + if isinstance(batch[0], str): + return [_tokenize(w) for w in batch] + if isinstance(batch[0], Dict): + for sample in batch: + sample["input_ids"] = _tokenize(sample["content"]) + return batch + elif isinstance(batch, str): + raise NotImplementedError("Do not support a single str as input.") + else: + try: + batch["input_ids"] = [_tokenize(w) for w in batch["content"]] + batch.pop("content") + return batch + except Exception as e: + print(f"The type of parameter ``batch`` is wrong, type:{type(batch)}, batch: {batch}.") + raise e + + +def pad_input_ids(batch: List[Dict], pad_token_id: int = 0, return_dict: bool = False) -> Union[Dict, torch.Tensor]: + """Tokenize a list of prompts with Left Padding. + + Args: + batch (List[Dict, List]): if batch[0] is a dict, then key 'input_ids' must exist, + and value must be a list of integers. + pad_token_id (int, optional): Defaults to 0. + return_dict (bool, optional): Defaults to False. 
+ + Returns: + Union[Dict, torch.Tensor]: input_ids or dict(input_ids=input_ids) + """ + assert isinstance(batch, list), "batch must be a list" + + input_ids = [] + max_length = max([len(w["input_ids"] if isinstance(w, Dict) else w) for w in batch]) + for sample in batch: + cur_input_ids = sample["input_ids"] if isinstance(sample, Dict) else sample + assert len(cur_input_ids) > 0, "got empty list" + assert isinstance(cur_input_ids[0], int), f"only support a list of integers, but got {type(cur_input_ids[0])}" + cur_input_ids = torch.LongTensor(cur_input_ids) + # left padding for generation + input_ids.append( + torch.cat( + [ + cur_input_ids.new_full((max_length - len(cur_input_ids),), fill_value=pad_token_id), + cur_input_ids, + ] + ) + ) + input_ids = torch.stack(input_ids) + return input_ids if not return_dict else {"input_ids": input_ids} + + +def batch_tokenize( + prompts: List[str], tokenizer, return_dict: bool = False, pad_token_id: int = 1 +) -> Union[Dict, torch.Tensor]: + """Tokenize a list of prompts with Left Padding. Return the tokens. + + Args: + prompts (List[str]): a list of prompts + tokenizer : Currently only sentencepiece is supported. + return_dict (bool, optional): Defaults to False. + pad_token_id (int, optional): Defaults to 1. + + Returns: + Union[Dict, torch.Tensor]: input_ids or dict(input_ids=input_ids) + """ + + tokenizer_out = batch_tokenize_process_fn(prompts, tokenizer) + + tokens = pad_input_ids(tokenizer_out, return_dict=return_dict, pad_token_id=pad_token_id) + + return tokens diff --git a/internlm/apis/inference_utils.py b/internlm/apis/inference_utils.py new file mode 100644 index 00000000..423e7aaf --- /dev/null +++ b/internlm/apis/inference_utils.py @@ -0,0 +1,69 @@ +import torch + +from internlm.core.context import ParallelMode # noqa: E402 +from internlm.core.context import global_context as gpc # noqa: E402 +from internlm.core.parallel.comm.utils import _gather as gather + + +class InferenceParams: + """ + Intermediate cache objects for inference + """ + + def __init__( + self, + max_sequence_len, + max_batch_size, + sequence_len_offset=0, + batch_size_offset=0, + key_value_memory_dict: dict = None, + lengths_per_sample=None, + attention_mask=None, + window_size=None, + ) -> None: + + self.max_sequence_len: int = max_sequence_len + self.max_batch_size: int = max_batch_size + self.sequence_len_offset: int = sequence_len_offset + self.batch_size_offset: int = batch_size_offset + if key_value_memory_dict is None: + key_value_memory_dict = {} + self.key_value_memory_dict: dict = key_value_memory_dict + self.fused_ft_kernel: bool = False + self.lengths_per_sample = lengths_per_sample + self.attention_mask = attention_mask + self.full_attention_mask = attention_mask + self.window_size = window_size + + def reorder_state(self, indices): + if self.lengths_per_sample is not None: + self.lengths_per_sample = self.lengths_per_sample.index_select(index=indices, dim=0) + for key, value in list(self.key_value_memory_dict.items()): + value = value.index_select(index=indices, dim=0) + self.key_value_memory_dict[key] = value + + def set_batch_offset(self, offset, bsz): + """Called by `BaseScheduler._load_micro_batch`. 
+ when micro-batch is enabled, the working attention mask is only a view of `full_attention_mask` + """ + self.batch_size_offset = offset + self.attention_mask = self.full_attention_mask[offset : offset + bsz] + + def set_attention_mask(self, mask): + """useful when generate using Engine/trainer rather than directly using model""" + self.full_attention_mask = mask + + +def process_parallel_output(model_output): + # 1. concat + if gpc.is_last_rank(ParallelMode.PIPELINE): + if not isinstance(model_output, torch.Tensor): + model_output = torch.cat(model_output, dim=0) + else: + return None + + # gather tp parallel output + if gpc.config.model.parallel_output and gpc.is_initialized(ParallelMode.TENSOR): + return gather(model_output, ParallelMode.TENSOR, -1) + else: + return model_output diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index 5be788af..122d4cf6 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -17,7 +17,7 @@ ckpt_info_sanity_check, ) from internlm.monitor import send_alert_message -from internlm.solver.optimizer import HybridZeroOptimizer, reload_zero_fp32_buff +from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2 from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer @@ -210,8 +210,8 @@ def try_load_internlm_ckpt_func(ckpt_mm, load_info, *args, func=None, **kwargs): load_content_str += f"{CheckpointLoadContent.MODEL}, " internlm_accelerator.synchronize() - if isinstance(ckpt_mm.optimizer, HybridZeroOptimizer): - reload_zero_fp32_buff(ckpt_mm.optimizer) + if isinstance(ckpt_mm.optimizer, (HybridZeroOptimizer, HybridZeroOptimizer_v2)): + ckpt_mm.optimizer.reload_zero_fp32_buff() class CheckpointManager: @@ -552,9 +552,10 @@ def try_resume_training(self, train_state: TrainState, current_time=""): # If we only load model weight, we need rewrite zero optim's fp32 buffer. 
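# Illustrative sketch (not from the patch): how the InferenceParams cache introduced above in
# internlm/apis/inference_utils.py might be driven for batched generation. The batch size,
# sequence length and mask below are made-up values; only the slicing behaviour of
# set_batch_offset is being demonstrated.
import torch
from internlm.apis import InferenceParams

full_mask = torch.ones(4, 256, dtype=torch.bool)  # assumed padding mask for 4 prompts
inf_params = InferenceParams(max_sequence_len=256, max_batch_size=4, attention_mask=full_mask)

# the scheduler takes a per-micro-batch view of the full mask: rows [offset, offset + bsz)
inf_params.set_batch_offset(offset=0, bsz=2)
assert inf_params.attention_mask.shape[0] == 2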
if ( - load_content.only_load(CheckpointLoadContent.MODEL) and isinstance(self.optimizer, HybridZeroOptimizer) + "optimizer" not in load_content.load_set + and isinstance(self.optimizer, (HybridZeroOptimizer, HybridZeroOptimizer_v2)) ) or gpc.config.get("only_load_lr", False): - reload_zero_fp32_buff(self.optimizer) + self.optimizer.reload_zero_fp32_buff() if gpc.is_rank_for_log(): logger.info(f"load_ckpt_info : {self.load_ckpt_info}") diff --git a/internlm/checkpoint/components.py b/internlm/checkpoint/components.py index 2158a5bb..25435d3c 100644 --- a/internlm/checkpoint/components.py +++ b/internlm/checkpoint/components.py @@ -11,7 +11,7 @@ from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState from internlm.model.moe.moe import MoE -from internlm.solver.optimizer import HybridZeroOptimizer +from internlm.solver.optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2 from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger from internlm.utils.parallel import is_using_isp @@ -345,7 +345,7 @@ def load_optimizer_checkpoint(folder, optim): states = llm_load(os.path.join(folder, fp), map_location=get_current_device()) - if isinstance(optim, HybridZeroOptimizer): + if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)): fp_meta = os.path.join(folder, optim.rank_unique_id) try: zero_devide_optim_plan = llm_load(fp_meta) @@ -393,7 +393,7 @@ def save_optimizer_checkpoint(optim, state_path): dp_size = gpc.get_world_size(ParallelMode.DATA) states = optim.state_dict() - if isinstance(optim, HybridZeroOptimizer): + if isinstance(optim, (HybridZeroOptimizer, HybridZeroOptimizer_v2)): if is_using_isp(): fp = f"optimizer_tp{tp_rank}_wp{wp_rank}_pp{pp_rank}_dp{dp_rank}.pt" llm_save(os.path.join(state_path, fp), states) diff --git a/internlm/checkpoint/load_funcs.py b/internlm/checkpoint/load_funcs.py index ee4ed472..423695ad 100644 --- a/internlm/checkpoint/load_funcs.py +++ b/internlm/checkpoint/load_funcs.py @@ -6,9 +6,10 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.solver.pipeline_utils import partition_uniform +from internlm.core.parallel.shard import partition_uniform from internlm.utils.logger import get_logger from internlm.utils.storage_manager import get_fns, llm_load +from transformers import AutoModelForCausalLM logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -147,12 +148,6 @@ def load_hf_llama_pretrained_weights(folder, model): if f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq" in states: states.pop(f"model.layers.{layer_ids}.self_attn.rotary_emb.inv_freq") - if gpc.config.model_type in ("LLAMA2",): - w2 = states.pop(f"layers.{i}.feed_forward.w2.weight") - w3 = states.pop(f"layers.{i}.feed_forward.w3.weight") - states[f"layers.{i}.feed_forward.w2.weight"] = w3 - states[f"layers.{i}.feed_forward.w3.weight"] = w2 - for name in list(states.keys()): if name.startswith(f"layers.{i}"): current_states[name.replace(f".{i}.", f".{idx}.")] = states.pop(name) @@ -304,8 +299,22 @@ def load_internlm_with_dynamic_parallel_size(folder, model): ) +def load_hf_model_pretrained_weights(folder, model): + """NOTE: when loading huggingface's model pretrained weights, you should set `adapt_hf=True` in your config.""" + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading 
pretrained model from {folder}") + + pretrained_model = AutoModelForCausalLM.from_pretrained(folder, trust_remote_code=True) + model.load_state_dict(pretrained_model.state_dict(), strict=False) + + if gpc.is_rank_for_log(): + logger.info("Pretrained weights loaded successfully") + + LOAD_FUNC_DICT = { "llama": load_llama_pretrained_weights, "hf_llama": load_hf_llama_pretrained_weights, "internlm_test": load_internlm_with_dynamic_parallel_size, + "hf_model": load_hf_model_pretrained_weights, } diff --git a/internlm/core/communication/utils.py b/internlm/core/communication/utils.py deleted file mode 100644 index 5d08327a..00000000 --- a/internlm/core/communication/utils.py +++ /dev/null @@ -1,231 +0,0 @@ -# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication - -from collections import OrderedDict -from typing import Dict, List, Tuple, Union - -import torch -import torch.distributed as dist -from torch import nn - -from internlm.core.communication.isp import ISPCommunicator -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.core.naive_amp import NaiveAMPModel -from internlm.model.modules.embedding import Embedding1D -from internlm.model.ops.linear import BaseScaleColumnParallelLinear -from internlm.utils.common import get_current_device - -TensorShape = Union[torch.Size, List[int], Tuple[int]] - - -def send_meta_helper(obj, next_rank, tensor_kwargs): - send_shape = torch.tensor(obj.size(), **tensor_kwargs) - send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs) - dist.send(send_ndims, next_rank) - dist.send(send_shape, next_rank) - - -def send_obj_meta(obj, next_rank=None): - """Sends obj meta information before sending a specific obj. - Since the recipient must know the shape of the obj in p2p communications, - meta information of the obj should be sent before communications. This function - synchronizes with :func:`recv_obj_meta`. - - Args: - obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent. - need_meta (bool, optional): If False, meta information won't be sent. - next_rank (int): The rank of the next member in pipeline parallel group. - - Returns: - bool: False - """ - if next_rank is None: - next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) - - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} - if isinstance(obj, torch.Tensor): - send_obj_nums = torch.tensor(1, **tensor_kwargs) - dist.send(send_obj_nums, next_rank) - send_meta_helper(obj, next_rank, tensor_kwargs) - else: - send_obj_nums = torch.tensor(len(obj), **tensor_kwargs) - dist.send(send_obj_nums, next_rank) - for tensor_to_send in obj: - send_meta_helper(tensor_to_send, next_rank, tensor_kwargs) - - -def recv_meta_helper(prev_rank, tensor_kwargs): - recv_ndims = torch.empty((), **tensor_kwargs) - dist.recv(recv_ndims, prev_rank) - recv_shape = torch.empty(recv_ndims, **tensor_kwargs) - dist.recv(recv_shape, prev_rank) - return recv_shape - - -def recv_obj_meta(prev_rank=None) -> torch.Size: - """Receives obj meta information before receiving a specific obj. - Since the recipient must know the shape of the obj in p2p communications, - meta information of the obj should be received before communications. This function - synchronizes with :func:`send_obj_meta`. - - Args: - obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received. - prev_rank (int): The rank of the source of the obj. 
- - Returns: - Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received. - """ - if prev_rank is None: - prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) - - tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} - recv_obj_nums = torch.empty((), **tensor_kwargs) - dist.recv(recv_obj_nums, prev_rank) - if recv_obj_nums.item() == 1: - recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) - obj_shape = torch.Size(recv_shape) - else: - obj_shape = [] - for _ in range(recv_obj_nums.item()): - recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) - obj_shape.append(torch.Size(recv_shape)) - - return obj_shape - - -def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor: - """Break a tensor into equal 1D chunks. - - Args: - tensor (:class:`torch.Tensor`): Tensor to be split before communication. - new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. - - Returns: - :class:`torch.Tensor`: The split tensor - """ - partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR) - start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR) - end_index = start_index + partition_size - if new_buffer: - data = torch.empty(partition_size, dtype=tensor.dtype, device=get_current_device(), requires_grad=False) - data.copy_(tensor.view(-1)[start_index:end_index]) - else: - data = tensor.view(-1)[start_index:end_index] - return data - - -def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: - """Opposite of above function, gather values from model parallel ranks. - - Args: - tensor (:class:`torch.Tensor`): Tensor to be gathered after communication. - Returns: - :class:`torch.Tensor`: The gathered tensor. 
- """ - world_size = gpc.get_world_size(ParallelMode.TENSOR) - numel = torch.numel(tensor) - numel_gathered = world_size * numel - gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=get_current_device(), requires_grad=False) - chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)] - dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR)) - return gathered - - -class ParamAsyncBcastHandler: - """ - Model Partition Handler for overlap broadcast with forward - """ - - def __init__( - self, zero1_mode: ParallelMode, model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicator = None - ) -> None: - self._block_to_param: Dict[nn.Module, List[nn.Parameter]] = OrderedDict() - self._param_to_rank: Dict[nn.Parameter, int] = {} - self._block_to_rank: Dict[nn.Module, int] = {} - self._bcast_handles: Dict[int, List[dist.Work]] = {} - - zero1_size = gpc.get_world_size(zero1_mode) - total_param_num = sum(p.numel() for p in model.parameters()) - avg_param_num = total_param_num * 1.0 // zero1_size - - # initialize an empty list for _bcast_handles of each rank - self._bcast_handles = {rank: [] for rank in range(zero1_size)} - - # just want to share same for loop for ModuleList and Module - if not isinstance(model, nn.ModuleList): - model = [model] - - # record the parameters to transformer/embeding/head/norm block - for _chunk in model: - if isinstance(_chunk, NaiveAMPModel): - _chunk = _chunk.model - - for _, children in _chunk.named_children(): - # should be the transformer block definaton in modeling_xxx.py - if isinstance(children, nn.ModuleList): - # record the block that a parameter belongs to - for _, block in enumerate(children): - # self._block_to_param[f"{name}.{idx}"] = list(block.parameters()) - self._block_to_param[block] = list(block.parameters()) - else: - # record the block that a parameter belongs to - # self._block_to_param[name] = list(children.parameters()) - self._block_to_param[children] = list(children.parameters()) - - alloc_num = 0 - rank_to_go = 0 - - # process the parameters in block_to_param sequencially, - # allocate each parameter to a local rank of ParallelMode.ZERO1, - # NOTE that we do NOT consider following scenarios: - # 1) whether a parameter is trainable; - # 2) paramters maybe in different optimizer group - for block, params in self._block_to_param.items(): - # allocate a model block to a local rank of ParallelMode.ZERO1 - self._block_to_rank[block] = [rank_to_go] - for p in params: - alloc_num = alloc_num + p.numel() - # in this case, allocate the param to next rank if possible - if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1: - rank_to_go = rank_to_go + 1 - alloc_num = 0 - self._block_to_rank[block].append(rank_to_go) - # allocate a parameter to a local rank of ParallelMode.ZERO1 - self._param_to_rank[p] = rank_to_go - - # register_forward_pre_hook for transformer/embeding/norm/xxx block - self._register_sync_parameters_hook(isp_communicator) - - def _register_sync_parameters_hook(self, isp_communicator: ISPCommunicator = None) -> None: - def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W0613 - bcast_handles = [] - # gather all required broadcast hanles into a list - for rank in self._block_to_rank[model]: - bcast_handles.extend(self._bcast_handles[rank]) - # need to clear _bcast_handles since they would be processed later - self._bcast_handles[rank] = [] - # wait all required broadcast handles to be completed - for handle in bcast_handles: - 
handle.wait() - - # register_forward_pre_hook for transformer/embeding/norm/xxx block - for block, _ in self._block_to_rank.items(): - # TODO: remove special handling for embedding and head layers, - # instead implement support for weight parallelism of embedding and head layers within the ISP. - - # NOTE: Although the layernorm layer does not have explicit processing, - # both ISPCommunicator and ParamAsyncBcastHandler handle transformer blocks as granularity, - # so everything is fine. - - embedding_head_cls = (Embedding1D, BaseScaleColumnParallelLinear) - - if isp_communicator is None or isinstance(block, embedding_head_cls): - block.register_forward_pre_hook(_pre_forward_hook) - if isp_communicator: - isp_communicator.register_prerequisite_for_forward_prefetch_hooks(_pre_forward_hook) - - def get_rank_by_param(self, param) -> int: - return self._param_to_rank[param] - - def add_bcast_handle(self, rank, handle) -> None: - self._bcast_handles[rank].append(handle) diff --git a/internlm/core/engine.py b/internlm/core/engine.py index b9739316..5989536d 100644 --- a/internlm/core/engine.py +++ b/internlm/core/engine.py @@ -93,6 +93,11 @@ def criterion(self): """Returns the criterion (loss function) attached to the engine.""" return self._criterion + @criterion.setter + def criterion(self, criterion): + """Sets the criterion (loss function).""" + self._criterion = criterion + def _all_reduce_gradients(self): """Handles all-reduce operations of gradients across different parallel groups.""" for handler in self._gradient_handlers: diff --git a/internlm/core/naive_amp.py b/internlm/core/naive_amp.py index 4e1427eb..7cac640d 100644 --- a/internlm/core/naive_amp.py +++ b/internlm/core/naive_amp.py @@ -4,7 +4,7 @@ # adopted from https://github.com/hpcaitech/ColossalAI/tree/main/colossalai/amp from functools import partial -from typing import Any, Union +from typing import Any, List, Union import torch import torch.distributed as dist @@ -206,3 +206,10 @@ def _post_forward_hook_for_fp32( torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True sub_module.register_forward_pre_hook(partial(_pre_forward_hook_for_fp32)) + + +def unwrap_naive_amp(model: Union[nn.Module, nn.ModuleList]) -> List[nn.Module]: + if not isinstance(model, nn.ModuleList): + model = [model] + + return [_chunk.model if isinstance(_chunk, NaiveAMPModel) else _chunk for _chunk in model] diff --git a/internlm/model/llava_modules/__init__.py b/internlm/core/parallel/__init__.py similarity index 100% rename from internlm/model/llava_modules/__init__.py rename to internlm/core/parallel/__init__.py diff --git a/internlm/core/parallel/comm/__init__.py b/internlm/core/parallel/comm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/internlm/core/communication/isp.py b/internlm/core/parallel/comm/isp.py similarity index 69% rename from internlm/core/communication/isp.py rename to internlm/core/parallel/comm/isp.py index 2976899a..71dde3a3 100644 --- a/internlm/core/communication/isp.py +++ b/internlm/core/parallel/comm/isp.py @@ -1,18 +1,62 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +""" +communication for isp parallel. 
+""" +from abc import ABC, abstractmethod from functools import partial -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Tuple, Union import torch from torch import distributed as dist from torch import nn +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.core.naive_amp import NaiveAMPModel -from internlm.model.ops.linear import ISPLinear -from internlm.model.utils import all_gather_raw, reduce_scatter_raw +from internlm.core.naive_amp import unwrap_naive_amp +from internlm.core.parallel.comm.utils import ( + DUMMY_HANDLE_CONST, + AsyncCommHandle, + all_gather_raw, + reduce_scatter_raw, +) +from internlm.model.modules.linear import ParallelLinearWithCommExt from internlm.utils.common import SchedulerHook, get_current_device +from internlm.utils.utils import ( + CuSeqlenType, + QKVPackType, + check_attention_argument, + params_dispatch_with_condition, +) + + +# not really useful, only for code hint. +class WPCommunicator(ABC): + """ + Common communicator interface for weight parallel + """ + + @abstractmethod + def communication_mode(self) -> str: + """ + communication mode of communictor + """ + pass + + @abstractmethod + def weight_hook(self, tensor: torch.Tensor, async_op: bool = False, **kwargs) -> torch.Tensor: + """ + communication for weight when forward/backward. + """ + pass + + @abstractmethod + def grad_hook(self, tensor: torch.Tensor, async_op: bool = False, **kwargs) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + communication for grad when backward. + """ + pass class ISPCommModelConfig: @@ -148,7 +192,7 @@ def __init__(self) -> None: self.bias_global_output: Dict[str, torch.Tensor] = {} -class ISPCommunicator: +class ISPCommunicator(WPCommunicator): """ ISP Communicator for managing the all-gather and reduce_scatter of Intern Sequence Parallel. """ @@ -195,16 +239,11 @@ def __init__( # init overlap states if necessary. if self.overlap: - # just want to share same for loop for modulelist and module. - model = model if isinstance(model, nn.ModuleList) else [model] # build overlap states for every chunk. - for chunk_id, chunk in enumerate(model): - if isinstance(chunk, NaiveAMPModel): - chunk = chunk.model + for chunk_id, chunk in enumerate(unwrap_naive_amp(model)): self._parse_model_structure(chunk_id, chunk) - # register overlap hooks for every chunk. - for chunk_id in range(len(model)): self.switch_current_model_chunk(chunk_id) + # register overlap hooks for every chunk. self._register_sync_parameters_hook() # switch to chunk 0 at first. 
self.switch_current_model_chunk(0) @@ -219,8 +258,12 @@ def __init__( def _parse_model_structure(self, cid: int, model: nn.Module) -> None: self._overlap_states[cid] = ISPOverlapState() + def get_model(obj: nn.Module) -> nn.Module: + return get_model(obj.model) if hasattr(obj, "model") else obj + # Important: only works for llama-class models - for _, children in model.named_children(): + children_name = get_model(model).named_children() + for _, children in children_name: if isinstance(children, nn.ModuleList): self._overlap_states[cid].ckpt_block_num = int(self.model_conf.activation_checkpointing * len(children)) @@ -232,7 +275,7 @@ def _parse_model_structure(self, cid: int, model: nn.Module) -> None: if name in ["out_proj", "wo"]: self._overlap_states[cid].isp_outs.append(child) self._overlap_states[cid].module_to_index[child] = idx - if isinstance(child, ISPLinear): + if isinstance(child, ParallelLinearWithCommExt): if name not in self._module_shapes: origin_shape = tuple( [child.weight.shape[0] * gpc.weight_parallel_size] @@ -436,6 +479,9 @@ def _get_constant_zero(self, size: tuple) -> torch.Tensor: device=self.model_conf.device, ).contiguous() + def communication_mode(self) -> str: + return "wp" + def switch_current_model_chunk(self, chunk_id: int) -> None: self._isp_outs = self._overlap_states[chunk_id].isp_outs self._isp_modules = self._overlap_states[chunk_id].isp_modules @@ -478,44 +524,51 @@ def register_prerequisite_for_forward_prefetch_hooks(self, prerequisite_func: Ca # communication operation interfaces - def all_gather(self, tensor: torch.Tensor, module: nn.Module, is_bias: bool = False): + def weight_hook( + self, tensor: torch.Tensor, async_op: bool = False, module: nn.Module = None, is_bias: bool = False + ) -> torch.Tensor: if dist.get_world_size(self.process_group) <= 1: return tensor if not self.overlap: - result, _ = all_gather_raw(tensor, self.process_group, async_op=False) + result, _ = all_gather_raw(tensor, self.process_group, async_op=async_op) elif is_bias: + assert module is not None, "The module parameter must be specified" result = self._bias_global_output[module] else: + assert module is not None, "The module parameter must be specified" result = self._weight_global_output[module] return result - def reduce_scatter( + def grad_hook( self, tensor: torch.Tensor, - model: nn.Module, - op: dist.ReduceOp, + async_op: bool = False, + module: nn.Module = None, + reduce_op: dist.ReduceOp = dist.ReduceOp.AVG, is_bias: bool = False, - ): + ) -> Tuple[torch.Tensor, AsyncCommHandle]: if dist.get_world_size(self.process_group) <= 1: - return tensor, None + return tensor, DUMMY_HANDLE_CONST if not self.overlap: - result, handle = reduce_scatter_raw(tensor, self.process_group, op=op, async_op=True) + result, handle = reduce_scatter_raw(tensor, self.process_group, op=reduce_op, async_op=async_op) else: + assert module is not None, "The module parameter must be specified" + if is_bias: - assert hasattr(model.bias, "isp_reduce_scatter_name") - key = getattr(model.bias, "isp_reduce_scatter_name") + assert hasattr(module.bias, "isp_reduce_scatter_name") + key = getattr(module.bias, "isp_reduce_scatter_name") else: - assert hasattr(model.weight, "isp_reduce_scatter_name") - key = getattr(model.weight, "isp_reduce_scatter_name") + assert hasattr(module.weight, "isp_reduce_scatter_name") + key = getattr(module.weight, "isp_reduce_scatter_name") self.reduce_scatter_handlers[key] = reduce_scatter_raw( tensor, self.process_group, - op=op, - async_op=True, + op=reduce_op, + 
async_op=async_op, memory_pool_allocator=( self.memory_pool.allocate_reduce_scatter_memory if self.enable_memory_pool else None ), @@ -528,7 +581,7 @@ def reduce_scatter( *tensor.shape[1:], ) ), - None, + DUMMY_HANDLE_CONST, ) return result, handle @@ -543,33 +596,190 @@ def __init__(self, overlap_handler: ISPCommunicator, zero_optim) -> None: self._isp_communicator = overlap_handler self._zero_optim = zero_optim - def before_forward(self, scheduler, inputs) -> None: + def before_forward(self, scheduler, inputs) -> None: # pylint: disable=W0613 self._isp_communicator.is_forward = True # switch model chunk before forward chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank self._isp_communicator.switch_current_model_chunk(chunk_id) - def after_forward(self, scheduler, outputs) -> None: + def after_forward(self, scheduler, outputs) -> None: # pylint: disable=W0613 pass - def before_criterion(self, scheduler, outputs, label) -> None: + def before_criterion(self, scheduler, outputs, label) -> None: # pylint: disable=W0613 pass - def after_criterion(self, scheduler, loss) -> None: + def after_criterion(self, scheduler, loss) -> None: # pylint: disable=W0613 pass - def before_backward(self, scheduler, outputs, outputs_grad) -> None: + def before_backward(self, scheduler, outputs, outputs_grad) -> None: # pylint: disable=W0613 self._isp_communicator.is_forward = False # switch model chunk before backward chunk_id = 0 if gpc.virtual_pipeline_parallel_rank is None else gpc.virtual_pipeline_parallel_rank self._isp_communicator.switch_current_model_chunk(chunk_id) - def after_backward(self, scheduler, inputs_grad) -> None: + def after_backward(self, scheduler, inputs_grad) -> None: # pylint: disable=W0613 # accumulate left gradients in last bucket after backward. self._zero_optim.accumulate_left_grads_after_backward() # reset lazy memory pools for reduce scatter after every micro step. if self._isp_communicator and self._isp_communicator.enable_memory_pool: self._isp_communicator.memory_pool.reset_lazy_pools() - def post_helper_func(self, scheduler, outputs, label) -> None: + def post_helper_func(self, scheduler, outputs, label) -> None: # pylint: disable=W0613 pass + + +# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py +class _SeqAllToAll(torch.autograd.Function): + "sequence alltoall function" + + @staticmethod + def forward(ctx, group: dist.ProcessGroup, input_: torch.Tensor, scatter_idx: int, gather_idx: int) -> torch.Tensor: + ctx.group = group + ctx.scatter_idx = scatter_idx + ctx.gather_idx = gather_idx + + if dist.get_world_size(group) <= 1: + return input_ + + seq_world_size = dist.get_world_size(group) + + input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)] + output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] + # TODO: use all_to_all_single instead + dist.all_to_all(output_list, input_list, group=group) + return torch.cat(output_list, dim=gather_idx).contiguous() + + @staticmethod + def backward(ctx, *grad_output: torch.Tensor) -> Tuple[None, torch.Tensor, None, None]: + if dist.get_world_size(ctx.group) <= 1: + return (None, *grad_output, None, None) + + return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None) + + +# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py +class DistributedAttention(nn.Module): + """Initialization. 
+ + Arguments: + local_attention (Module): local self-attention module + sequence_process_group (ProcessGroup): sequence parallel process group + """ + + def __init__( + self, + local_attention: nn.Module, + sequence_process_group: dist.ProcessGroup, + ) -> None: + super().__init__() + self.local_attn = local_attention + self.spg = sequence_process_group + + @params_dispatch_with_condition(condition=check_attention_argument) + def forward(self) -> torch.Tensor: + assert False, "Should never arrive" + + @forward.register(conditions=(str(QKVPackType.QKVPACKED), str(CuSeqlenType.With))) + @forward.register(conditions=(str(QKVPackType.QKVPACKED), str(CuSeqlenType.WithOut))) + def _(self, qkv: torch.Tensor, **kwargs) -> torch.Tensor: + """forward + + Arguments: + qkv (Tensor): packed qkv input to the layer + kwargs: other args + + Returns: + * output (Tensor): context output + """ + # qkv shape: [1, packlen, 3, n_head, head_dim] or [batch, seqlen, 3, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + qkv = _SeqAllToAll.apply(self.spg, qkv, 3, 1) + + context = self.local_attn(qkv, **kwargs) + + # context shape: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in seqlen(packlen) and gather in n_head + context = _SeqAllToAll.apply(self.spg, context, 1, 2) + + return context + + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.With))) + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.WithOut))) + def _(self, q: torch.Tensor, kv: torch.Tensor, **kwargs) -> torch.Tensor: + """forward + + Arguments: + q (Tensor): q input to the layer + kv (Tensor): packed kv input to the layer + kwargs: other args + + Returns: + output (Tensor): context output + """ + # q shpae: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + q = _SeqAllToAll.apply(self.spg, q, 2, 1) + # kv shape: [1, packlen, 2, n_head, head_dim] or [batch, seqlen, 2, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + kv = _SeqAllToAll.apply(self.spg, kv, 3, 1) + + context = self.local_attn(q, kv, **kwargs) + + # context shape: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in seqlen(packlen) and gather in n_head + context = _SeqAllToAll.apply(self.spg, context, 1, 2) + + return context + + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.With))) + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.WithOut))) + def _(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs) -> torch.Tensor: + """forward + + Arguments: + q (Tensor): q input to the layer + k (Tensor): k input to the layer + v (Tensor): v input to the layer + kwargs: other args + + Returns: + * output (Tensor): context output + """ + # self._scatter_gather_idx["q"] = [1, 0] # q/k/v shape: [sequence, head, head_dim] + # q shpae: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + q = _SeqAllToAll.apply(self.spg, q, 2, 1) + # k shpae: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + k = _SeqAllToAll.apply(self.spg, k, 2, 1) + # v shpae: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in n_head and gather in seqlen(packlen) + v = _SeqAllToAll.apply(self.spg, v, 2, 1) + + context = self.local_attn(q, k, v, **kwargs) + + # context 
shape: [1, packlen, n_head, head_dim] or [batch, seqlen, n_head, head_dim] + # scatter in seqlen(packlen) and gather in n_head + context = _SeqAllToAll.apply(self.spg, context, 1, 2) + + return context + + +def auto_wrap_distributed_attention(cls: nn.Module) -> Callable[[bool, Any, float], nn.Module]: + """ + Wrap a local attention module to a distributed one, which will be used in the ISP parallelism. + """ + + # should we impl distributed attention as a metaclass? + def _attetion_constructor( + local_attn_cls: type, causal=False, softmax_scale=None, attention_dropout=0.0 + ) -> nn.Module: + if gpc.config.parallel["tensor"].get("mode", "mtp") != "isp": + return local_attn_cls(causal, softmax_scale, attention_dropout) + else: + return DistributedAttention( + local_attention=local_attn_cls(causal, softmax_scale, attention_dropout), + sequence_process_group=gpc.get_group(ParallelMode.TENSOR), + ) + + return partial(_attetion_constructor, local_attn_cls=cls) diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py new file mode 100644 index 00000000..ca8c1900 --- /dev/null +++ b/internlm/core/parallel/comm/tensor.py @@ -0,0 +1,369 @@ +""" +communication for tensor/sequence parallel. +""" + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Tuple + +import torch +from torch import distributed as dist + +from internlm.core.context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.comm.utils import ( + DUMMY_HANDLE_CONST, + AsyncCommHandle, + _gather, + _split, + all_gather_raw, + all_reduce_raw, + gather_forward_split_backward, + reduce_scatter_raw, + split_forward_gather_backward, +) +from internlm.model.modules.embedding import Embedding1D +from internlm.model.moe.moe import MoE + +# input gather dim +_GATHER_DIM = 1 # shape: [batch, seqlen, dim] or [1, packlen, dim] +_REDUCE_DIM = 1 # shape: [batch, seqlen, dim] or [1, packlen, dim] + + +class LinearRole(Enum): + COLUMN = "column" + ROW = "row" + + +# not really useful, only for code hint. +class TPCommunicator(ABC): + """ + Common communicator interafce for tensor/sequence parallel. + """ + + @abstractmethod + def save_total_input(self) -> bool: + """ + Should linear save total input after all gather as activation in sequence parallel. + """ + pass + + @abstractmethod + def communication_mode(self) -> str: + """ + communication mode of communictor + """ + pass + + @abstractmethod + def input_hook( + self, _input: torch.Tensor, async_op: bool = False, is_forward: bool = True + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + communication for input when forward/backward. + """ + pass + + @abstractmethod + def grad_output_hook( + self, grad_output: torch.Tensor, async_op: bool = False + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + communication for grad_output when backward. + """ + pass + + @abstractmethod + def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + communication for grad_input when backward. + """ + pass + + @abstractmethod + def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + communication for output when forward. 
+ """ + pass + + +class TensorParallelCommunicator(TPCommunicator): + """ + tensor parallel communicator for linear + """ + + def __init__(self, process_group: dist.ProcessGroup, role: LinearRole) -> None: + assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}" + + self._process_group = process_group + self._role = role + + self._save_total_input = False + + def save_total_input(self) -> bool: + return self._save_total_input + + def communication_mode(self) -> str: + return "tp" + + def input_hook( + self, _input: torch.Tensor, async_op: bool = False, is_forward: bool = True # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + tensor parallel should do nothing for input. + """ + return _input, DUMMY_HANDLE_CONST + + def grad_output_hook( + self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + tensor parallel should do nothing for grad_output. + """ + return grad_output, DUMMY_HANDLE_CONST + + def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all reduce grad_input only for column parallel linear when backward. + """ + if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.ROW: + return grad_input, DUMMY_HANDLE_CONST + + return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op) + + def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all reduce output only for row parallel linear when forward. + """ + if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + return output, DUMMY_HANDLE_CONST + + return all_reduce_raw(output, process_group=self._process_group, async_op=async_op) + + +class SequenceParallelCommunicator(TPCommunicator): + """ + sequence parallel communicator for linear + """ + + def __init__( + self, process_group: dist.ProcessGroup, role: LinearRole, save_total_input_as_activation: bool = False + ) -> None: + assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}" + + self._process_group = process_group + self._role = role + + self._save_total_input = save_total_input_as_activation + + def save_total_input(self) -> bool: + return self._save_total_input + + def communication_mode(self) -> str: + return "sp" + + def input_hook( + self, _input: torch.Tensor, async_op: bool = False, is_forward: bool = True + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all gather input only for column parallel linear when forward/backward. + """ + # 1. world_size <= 1 + # 2. row parallel linear should not allgather input. + # 3. column parallel linear should not allgather input if save_total_input_as_activation and backward is True. + if ( + dist.get_world_size(self._process_group) <= 1 + or self._role == LinearRole.ROW + or (is_forward is False and self._save_total_input) + ): + return _input, DUMMY_HANDLE_CONST + + return all_gather_raw(_input, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) + + def grad_output_hook( + self, grad_output: torch.Tensor, async_op: bool = False + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all gather grad_output only for row parallel linear when backward. 
+ """ + if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + return grad_output, DUMMY_HANDLE_CONST + + return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) + + def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + reduce scatter grad_input only for column parallel linear when backward. + """ + if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.ROW: + return grad_input, DUMMY_HANDLE_CONST + + return reduce_scatter_raw( + grad_input, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM + ) + + def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + reduce scatter output only for row parallel linear when forward. + """ + if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + return output, DUMMY_HANDLE_CONST + + return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM) + + +class HeadTensorParallelCommunicator(TensorParallelCommunicator): + """ + tensor parallel communicator for head linear + """ + + def __init__(self, parallel_mode: ParallelMode, retain_out_sharded: bool = True) -> None: + super().__init__(process_group=gpc.get_group(parallel_mode), role=LinearRole.COLUMN) + + self._parallel_mode = parallel_mode + self._retain_out_sharded = retain_out_sharded + + def grad_output_hook( + self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + split grad_output if retain_out_sharded is False. + """ + if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + return grad_output, DUMMY_HANDLE_CONST + + return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST + + def output_hook( + self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all gather output for head layer if retain_out_sharded is False. + """ + if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + return output, DUMMY_HANDLE_CONST + + return _gather(output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST + + +class HeadSequenceParallelCommunicator(SequenceParallelCommunicator): + """ + sequence parallel communicator for head linear + """ + + def __init__( + self, parallel_mode: ParallelMode, retain_out_sharded: bool = True, save_total_input_as_activation: bool = False + ) -> None: + super().__init__( + process_group=gpc.get_group(parallel_mode), + role=LinearRole.COLUMN, + save_total_input_as_activation=save_total_input_as_activation, + ) + + self._parallel_mode = parallel_mode + self._retain_out_sharded = retain_out_sharded + + # rewrite grad_output communication hook + def grad_output_hook( + self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + split grad_output if retain_out_sharded is False. 
+ """ + if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + return grad_output, DUMMY_HANDLE_CONST + + return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST + + # rewrite ouput communication hook + def output_hook( + self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, AsyncCommHandle]: + """ + all gather output for head layer if retain_out_sharded is False. + """ + if self._retain_out_sharded or dist.get_world_size(self._process_group) <= 1: + return output, DUMMY_HANDLE_CONST + + return _gather(output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST + + +class MoESequenceParallelCommunicator: + """ + sequence parallel communicator for moe layer + """ + + def __init__(self, parallel_mode: ParallelMode) -> None: + self._parallel_mode = parallel_mode + + def register_module_hook(self, module: MoE) -> None: + assert isinstance(module, MoE), "MoE sequence parallel communicator is only support moe module" + + module.register_forward_pre_hook(self.input_hook, with_kwargs=True) + module.register_forward_hook(self.output_hook) + + def input_hook(self, module: MoE, args, kwargs) -> torch.Tensor: # pylint: disable=W0613 + """ + allgather input before forward and split grad_input after backward. + """ + _input = args[0] if len(args) > 0 else kwargs.pop("hidden_states") + _input = gather_forward_split_backward(_input, self._parallel_mode, dim=_GATHER_DIM) + + return (_input, *args), kwargs + + def output_hook(self, module: MoE, args: Any, output: Tuple[Any]) -> Tuple[Any]: # pylint: disable=W0613 + """ + split output after forward and allgather grad_output before backward. + """ + _output, *_others = output + _output = split_forward_gather_backward(_output, self._parallel_mode, dim=_REDUCE_DIM) + + return (_output, *_others) + + +class EmbbedingTensorParallelCommunicator: + """ + tensor parallel communicator for embbeding layer + """ + + def __init__(self, parallel_mode: ParallelMode) -> None: + self._parallel_mode = parallel_mode + + def register_module_hook(self, module: Embedding1D) -> None: + assert isinstance(module, Embedding1D), "Embbeding tensor parallel communicator is only support Embedding1D" + + module.register_forward_hook(self.output_hook) + + def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tuple[Any]: # pylint: disable=W0613 + """ + split output after forward and allgather grad_output before backward. + """ + _emb_dim = 2 # [bsz, seqlen, emb_dim] + + return gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim) + + +class EmbbedingSequenceParallelCommunicator: + """ + sequence parallel communictor for embbeding layer + """ + + def __init__(self, parallel_mode: ParallelMode) -> None: + self._parallel_mode = parallel_mode + + def register_module_hook(self, module: Embedding1D) -> None: + assert isinstance(module, Embedding1D), "Embbeding sequence parallel communicator is only support Embedding1D" + + module.register_forward_hook(self.output_hook) + + def output_hook(self, module: Embedding1D, args: Any, output: Tuple[Any]) -> Tuple[Any]: # pylint: disable=W0613 + """ + split output after forward and allgather grad_output before backward. 
+ """ + _emb_dim, _seq_dim = 2, 1 # [bsz, seqlen, emb_dim] + + output = gather_forward_split_backward(output, self._parallel_mode, dim=_emb_dim) + output = split_forward_gather_backward(output, self._parallel_mode, dim=_seq_dim) + + return output diff --git a/internlm/core/parallel/comm/utils.py b/internlm/core/parallel/comm/utils.py new file mode 100644 index 00000000..dbfeb3fd --- /dev/null +++ b/internlm/core/parallel/comm/utils.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from abc import ABC, abstractmethod +from typing import Callable + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed import ProcessGroup + +from internlm.core.context import global_context as gpc + + +class AsyncCommHandle(ABC): + """A interface for asynchronous communication handles.""" + + @abstractmethod + def wait(self) -> None: + """wait asynchronous communication to complete.""" + + +class DummyAsyncCommHandle(AsyncCommHandle): + """A fake communication handle used to maintain consistency in code writing""" + + def wait(self) -> None: + pass + + +DUMMY_HANDLE_CONST = DummyAsyncCommHandle() + + +# Raw operation, does not support autograd, but does support async +def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): + input_ = input_.contiguous() + handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) + return input_, handle + + +class ReduceScatterFunc(torch.autograd.Function): + """Reduce scatter the input from the sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup, reduce_dim: int = 0) -> Tensor: + ctx.process_group = process_group + ctx.reduce_dim = reduce_dim + output, _ = reduce_scatter_raw(input_, process_group, reduce_dim=reduce_dim) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + gather_dim = ctx.reduce_dim + grad_input, _ = all_gather_raw(grad_output, ctx.process_group, gather_dim=gather_dim) + return grad_input, None, None + + +# Supports autograd, but does not support async +reduce_scatter = ReduceScatterFunc.apply + + +class AllReduceFunc(torch.autograd.Function): + """Gather the input from sequence parallel region and concatenate.""" + + @staticmethod + def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: + ctx.process_group = process_group + output, _ = all_reduce_raw(input_, process_group) + return output + + @staticmethod + def backward(ctx, grad_output: Tensor): + _ = ctx # avoid lint warning W0613 + return grad_output, None + + +# Supports autograd, but does not support async +all_reduce = AllReduceFunc.apply + + +def _split(input_, parallel_mode, dim=-1): + # skip if only one rank involved + world_size = gpc.get_world_size(parallel_mode) + if world_size == 1: + return input_ + + # Split along last dimension. 
+ dim_size = input_.size(dim) + assert dim_size % world_size == 0, ( + f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " + f"cannot split tensor evenly" + ) + + tensor_list = torch.split(input_, dim_size // world_size, dim=dim) + rank = gpc.get_local_rank(parallel_mode) + output = tensor_list[rank].contiguous() + output = output.detach().clone() + + return output + + +def _gather(input_, parallel_mode, dim=-1): + # skip if only one rank involved + world_size = gpc.get_world_size(parallel_mode) + if world_size == 1: + return input_ + + # all gather + rank = gpc.get_local_rank(parallel_mode) + tensor_list = [torch.empty_like(input_) for _ in range(world_size)] + tensor_list[rank] = input_ + group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) + dist.all_gather(tensor_list, input_, group=group) + + # concat + output = torch.cat(tensor_list, dim=dim).contiguous() + + return output + + +class _GatherForwardSplitBackward(torch.autograd.Function): + """Gather the input from model parallel region and concatenate. + + Args: + input_: input matrix. + parallel_mode: parallel mode. + dim: dimension + """ + + @staticmethod + def symbolic(input_): + return _gather(input_, parallel_mode=None) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _gather(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _split(grad_output, ctx.mode, ctx.dim), None, None + + +def gather_forward_split_backward(input_, parallel_mode, dim): + return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) + + +class _SplitForwardGatherBackward(torch.autograd.Function): + """ + Split the input and keep only the corresponding chuck to the rank. + + Args: + input_: input matrix. + parallel_mode: parallel mode. 
+ dim: dimension + """ + + @staticmethod + def symbolic(input_): + return _split(input_, parallel_mode=None) + + @staticmethod + def forward(ctx, input_, parallel_mode, dim): + ctx.mode = parallel_mode + ctx.dim = dim + return _split(input_, parallel_mode, dim) + + @staticmethod + def backward(ctx, grad_output): + return _gather(grad_output, ctx.mode, ctx.dim), None, None + + +def split_forward_gather_backward(input_, parallel_mode, dim): + return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) + + +def all_gather_raw( + input_: Tensor, + process_group: ProcessGroup, + async_op: bool = False, + gather_dim: int = 0, + memory_pool_allocator: Callable = None, +): + world_size = dist.get_world_size(process_group) + if world_size <= 1: + return input_, None + + if memory_pool_allocator is not None: + output = memory_pool_allocator() + else: + shape = list(input_.shape) + shape[gather_dim] = shape[gather_dim] * world_size + output = torch.empty(shape, dtype=input_.dtype, device=input_.device) + + handle = dist.all_gather_into_tensor(output, input_.contiguous(), group=process_group, async_op=async_op) + return output, handle + + +def reduce_scatter_raw( + input_: Tensor, + process_group: ProcessGroup, + op=dist.ReduceOp.SUM, + async_op: bool = False, + reduce_dim: int = 0, + memory_pool_allocator: Callable = None, +): + world_size = dist.get_world_size(process_group) + assert input_.shape[reduce_dim] % world_size == 0 + + if world_size <= 1: + return input_, None + + shape_list = list(input_.shape) + shape_list[reduce_dim] = shape_list[reduce_dim] // world_size + + if memory_pool_allocator is not None: + output = memory_pool_allocator(tuple(shape_list)) + else: + output = torch.empty( + shape_list, + dtype=input_.dtype, + device=input_.device, + ).contiguous() + + handle = dist.reduce_scatter_tensor(output, input_.contiguous(), op=op, group=process_group, async_op=async_op) + return output, handle diff --git a/internlm/core/parallel/comm/zero.py b/internlm/core/parallel/comm/zero.py new file mode 100644 index 00000000..0e5d18eb --- /dev/null +++ b/internlm/core/parallel/comm/zero.py @@ -0,0 +1,187 @@ +""" +communication for zero parallel +""" + +from collections import OrderedDict +from typing import Dict, List, Union + +from torch import distributed as dist +from torch import nn + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.naive_amp import unwrap_naive_amp +from internlm.core.parallel.comm.isp import ISPCommunicator +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import ScaleColumnParallelLinear +from internlm.solver.optimizer.utils import flatten + + +class ParamAsyncBcastHandler: + """ + Model Partition Handler for overlap broadcast with forward + """ + + def __init__( + self, zero1_mode: ParallelMode, model: Union[nn.Module, nn.ModuleList], isp_communicator: ISPCommunicator = None + ) -> None: + self._block_to_param: Dict[nn.Module, List[nn.Parameter]] = OrderedDict() + self._param_to_rank: Dict[nn.Parameter, int] = {} + self._block_to_rank: Dict[nn.Module, int] = {} + self._bcast_handles: Dict[int, List[dist.Work]] = {} + self._block_to_name: Dict[nn.Module, str] = {} + + zero1_size = gpc.get_world_size(zero1_mode) + total_param_num = sum(p.numel() for p in model.parameters()) + avg_param_num = total_param_num * 1.0 // zero1_size + + # initialize an empty list for _bcast_handles of each rank + self._bcast_handles = {rank: [] for rank in 
range(zero1_size)} + # initialize an empty list for _allgather_handles + self._block_allgather_handles = {} + self._block_master_params = {} + self._block_working_params = {} + self._block_gathered_params = {} + self._block_allgather_order = {} + + # record the parameters to transformer/embeding/head/norm block + for _chunk in unwrap_naive_amp(model): + for name, children in _chunk.named_children(): + # should be the transformer block definaton in modeling_xxx.py + if isinstance(children, nn.ModuleList): + # record the block that a parameter belongs to + for idx, block in enumerate(children): + block_name = name + f"_{idx}" + # self._block_to_param[f"{name}.{idx}"] = list(block.parameters()) + self._block_to_param[block] = list(block.parameters()) + self._block_to_name[block] = block_name + else: + # record the block that a parameter belongs to + # self._block_to_param[name] = list(children.parameters()) + self._block_to_param[children] = list(children.parameters()) + self._block_to_name[children] = name + + alloc_num = 0 + rank_to_go = 0 + + # process the parameters in block_to_param sequencially, + # allocate each parameter to a local rank of ParallelMode.ZERO1, + # NOTE that we do NOT consider following scenarios: + # 1) whether a parameter is trainable; + # 2) paramters maybe in different optimizer group + for block, params in self._block_to_param.items(): + # allocate a model block to a local rank of ParallelMode.ZERO1 + self._block_to_rank[block] = [rank_to_go] + for p in params: + alloc_num = alloc_num + p.numel() + # in this case, allocate the param to next rank if possible + if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1: + rank_to_go = rank_to_go + 1 + alloc_num = 0 + self._block_to_rank[block].append(rank_to_go) + # allocate a parameter to a local rank of ParallelMode.ZERO1 + self._param_to_rank[p] = rank_to_go + + for block_name in self._block_to_name.values(): + self._block_allgather_handles[block_name] = None + self._block_master_params[block_name] = [] + self._block_working_params[block_name] = [] + self._block_gathered_params[block_name] = [] + self._block_allgather_order[block_name] = -1 + + # register_forward_pre_hook for transformer/embeding/norm/xxx block + if ( + "use_split_tensor_optim" not in gpc.config.hybrid_zero_optimizer + or not gpc.config.hybrid_zero_optimizer.use_split_tensor_optim + ): + self._register_sync_parameters_hook(isp_communicator) + else: + self._register_sync_parameters_hook_v2(isp_communicator) + + def _register_sync_parameters_hook(self, isp_communicator: ISPCommunicator = None) -> None: + def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W0613 + bcast_handles = [] + # gather all required broadcast hanles into a list + for rank in self._block_to_rank[model]: + bcast_handles.extend(self._bcast_handles[rank]) + # need to clear _bcast_handles since they would be processed later + self._bcast_handles[rank] = [] + # wait all required broadcast handles to be completed + for handle in bcast_handles: + handle.wait() + + # register_forward_pre_hook for transformer/embeding/norm/xxx block + for block, _ in self._block_to_rank.items(): + # TODO: remove special handling for embedding and head layers, + # instead implement support for weight parallelism of embedding and head layers within the ISP. + + # NOTE: Although the layernorm layer does not have explicit processing, + # both ISPCommunicator and ParamAsyncBcastHandler handle transformer blocks as granularity, + # so everything is fine. 
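# Illustrative sketch (not from the patch): the greedy ZeRO-1 allocation above in toy form.
# Blocks are walked in order and spill to the next rank once the running parameter count
# exceeds avg_param_num * 1.01. With four equal-sized blocks and zero1_size == 2 this gives
# two blocks per rank; the parameter counts are made-up values.
block_numels = [100, 100, 100, 100]
zero1_size = 2
avg = sum(block_numels) / zero1_size
assignment, alloc, rank = [], 0, 0
for numel in block_numels:
    alloc += numel
    if alloc > avg * 1.01 and rank < zero1_size - 1:
        rank, alloc = rank + 1, 0
    assignment.append(rank)
assert assignment == [0, 0, 1, 1]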
+ if isp_communicator is None or isinstance(block, (Embedding1D, ScaleColumnParallelLinear)): + block.register_forward_pre_hook(_pre_forward_hook) + if isp_communicator: + isp_communicator.register_prerequisite_for_forward_prefetch_hooks(_pre_forward_hook) + + def _register_sync_parameters_hook_v2(self, isp_communicator: ISPCommunicator = None) -> None: + def _pre_forward_hook(model: nn.Module, *args, **kwargs): # pylint: disable=W0613 + # For each block, wait corresponding all_gather handle to be completed + # For each all_gather handle, several consecutive blocks may be involved + # In this case only the first block of the handle needs to deal with it + block_name = self._block_to_name[model] + if self._block_allgather_order[block_name] == 1: + if self._block_allgather_handles[block_name] is None: + return + self._block_allgather_handles[block_name].wait() + + # reorganize gatherd params to update working param + # [[A1, B1], [A2, B2]] -> [[A1.reshape, A2.reshape], [B1.reshape, B2.reshape]] + block_master_params = self._block_master_params[block_name] + gathered_params = self._block_gathered_params[block_name] + all_splited_param_list = [] + offset = 0 + for p in block_master_params: + param_size = p.numel() + all_splited_param = [] + for all_params in gathered_params: + split_params = all_params[offset : offset + param_size].reshape(p.shape) + all_splited_param.append(split_params) + offset += param_size + all_splited_param_list.append(all_splited_param) + assert len(all_splited_param_list) == len(self._block_working_params[block_name]) + # Update working parameters + for working_param, all_splited_param in zip( + self._block_working_params[block_name], all_splited_param_list + ): + working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].view_as(working_param)) + + self._block_allgather_handles[block_name] = None + self._block_gathered_params[block_name] = [] + self._block_working_params[block_name] = [] + + # register_forward_pre_hook for transformer/embeding/norm/xxx block + for block, _ in self._block_to_rank.items(): + # TODO: remove special handling for embedding and head layers, + # instead implement support for weight parallelism of embedding and head layers within the ISP. + + # NOTE: Although the layernorm layer does not have explicit processing, + # both ISPCommunicator and ParamAsyncBcastHandler handle transformer blocks as granularity, + # so everything is fine. 
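# Illustrative sketch (not from the patch): the per-parameter regrouping done in
# _pre_forward_hook above. Each all-gathered flat buffer holds every master shard of the block
# concatenated; slicing it per parameter yields the [[A1, A2], [B1, B2]] layout that is then
# flattened into the working parameters. Shapes and values below are made up.
import torch

master_params = [torch.zeros(2, 2), torch.zeros(3)]      # pretend shards A (2x2) and B (3,)
gathered = [torch.arange(7.0), torch.arange(7.0, 14.0)]  # one flat buffer per ZeRO-1 rank
per_param, offset = [], 0
for p in master_params:
    n = p.numel()
    per_param.append([buf[offset:offset + n].reshape(p.shape) for buf in gathered])
    offset += n
assert [t.shape for t in per_param[0]] == [torch.Size([2, 2])] * 2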
+ if isp_communicator is None or isinstance(block, (Embedding1D, ScaleColumnParallelLinear)): + block.register_forward_pre_hook(_pre_forward_hook) + if isp_communicator: + isp_communicator.register_prerequisite_for_forward_prefetch_hooks(_pre_forward_hook) + + def get_rank_by_param(self, param) -> int: + return self._param_to_rank[param] + + def add_bcast_handle(self, rank, handle) -> None: + self._bcast_handles[rank].append(handle) + + def add_allgather_handle(self, handle, master_param, working_param, gatherd_param, block_name) -> None: + assert self._block_allgather_handles[block_name] is None + self._block_allgather_handles[block_name] = handle + self._block_master_params[block_name] = master_param + self._block_working_params[block_name] = working_param + self._block_gathered_params[block_name] = gatherd_param + self._block_allgather_order[block_name] = 1 diff --git a/internlm/core/parallel/shard.py b/internlm/core/parallel/shard.py new file mode 100644 index 00000000..33c187ec --- /dev/null +++ b/internlm/core/parallel/shard.py @@ -0,0 +1,119 @@ +""" +shard strategies for parallel +""" + +from typing import Callable + +import torch +from torch import nn + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +# The head layer in ISP mode is actually a special case, +# and we would prefer a unified segmentation and communication logic. +def get_tensor_split_parallel_mode(is_head: bool = False) -> ParallelMode: + tp_mode = gpc.config.parallel.tensor.mode + + if tp_mode == "isp" and is_head is False: + return ParallelMode.WEIGHT + else: + return ParallelMode.TENSOR + + +def get_head_parallel_mode() -> ParallelMode: + return ParallelMode.TENSOR + + +def get_parallel_strategies_split_mode(linear_name: str) -> str: + tp_mode = gpc.config.parallel.tensor.mode + + if linear_name in ("head", "output"): + return "head" + elif linear_name in ("wqkv", "wq", "wk", "wv", "wkv", "w1", "w3", "w13"): + return "column" + elif linear_name in ("wo", "out_proj", "w2") and tp_mode == "isp": + return "column" + elif linear_name in ("wo", "out_proj", "w2"): + return "row" + else: + return "unknown" + + +def partition_uniform(num_items: int, pipeline_parallel_size: int, num_chunks: int): + assert ( + num_items % num_chunks == 0 + ), "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" + + parts = [[] for _ in range(pipeline_parallel_size)] + partition_items = num_items // num_chunks + for idx in range(num_chunks): + base_idx = idx * partition_items + chunk_size = partition_items // pipeline_parallel_size + left = pipeline_parallel_size - partition_items % pipeline_parallel_size + if chunk_size == 0: + raise ValueError("Some nodes in Pipeline have no requests") + + for p in range(pipeline_parallel_size): + st = base_idx + base_idx += chunk_size + (p >= left) + parts[p].append((st, base_idx)) + + indexes = [] + for _parts in parts: + for s, e in _parts: + indexes.extend(list(range(s, e))) + assert len(indexes) == len(set(indexes)), indexes # should have no duplicates + assert set(indexes) == set(list(range(num_items))), (indexes, num_items) # should have the same indexes as expected + return parts + + +def pipeline_parallel_sharding_wrapper( + num_layers: int, num_chunks: int, model_builder: Callable, device: torch.device, **kwargs +): + """ + build generic model 1d + + Args: + num_layers (int): The number of layer. 
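A usage sketch for `partition_uniform` (assuming the InternEvo package with this new `shard.py` is importable; the layer counts are hypothetical): with 32 layers, 4 pipeline stages and 2 chunks, each stage owns one contiguous span of layers per chunk.

```python
from internlm.core.parallel.shard import partition_uniform

parts = partition_uniform(num_items=32, pipeline_parallel_size=4, num_chunks=2)
# parts[0] == [(0, 4), (16, 20)]    # stage 0 builds layers 0-3 and 16-19
# parts[3] == [(12, 16), (28, 32)]  # stage 3 builds layers 12-15 and 28-31
print(parts)
```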
+ num_chunks (int): The number of partitions in pipeline parallel. + device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default. + + """ + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) + parts = all_parts[pipeline_rank] + + if gpc.is_rank_for_log(): + logger.info("The layer sharding is %r.", all_parts) + + models = [] + + for start, end in parts: + kwargs["num_layers"] = end - start + kwargs["first"] = start == 0 + # If there is no content in the final layer, assign the last layer. + kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 + kwargs["device"] = device + kwargs["start_layer_idx"] = start + + chunk = model_builder(**kwargs).to(device) + setattr(chunk, "first_layer", start) + setattr(chunk, "last_layer", end) + + models.append(chunk) + + torch.distributed.barrier() + + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + + return model diff --git a/internlm/core/scheduler/base_scheduler.py b/internlm/core/scheduler/base_scheduler.py index 1800ccc1..da060ade 100644 --- a/internlm/core/scheduler/base_scheduler.py +++ b/internlm/core/scheduler/base_scheduler.py @@ -8,6 +8,7 @@ import torch +from internlm.apis import InferenceParams from internlm.core.engine import Engine @@ -44,10 +45,26 @@ def _load_micro_batch(self, data: Dict, label: torch.Tensor, offset: int, bsz_st so the data of batch is unpacked and 'bsz_stride' is equal to 'micro_bsz'. In all other cases 'bsz_stride' should be equal to 1. """ - assert isinstance(data, dict) and isinstance(label, torch.Tensor) - micro_batch_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()} - micro_batch_label = label[offset : offset + bsz_stride] - + assert isinstance(data, dict) + + micro_batch_data = {} + for k, v in data.items(): + if isinstance(v, torch.Tensor): + micro_batch_data[k] = v[offset : offset + bsz_stride] + elif isinstance(v, InferenceParams): + v.set_batch_offset(offset, bsz_stride) + micro_batch_data[k] = v + elif isinstance(v, (list, tuple)): + micro_batch_data[k] = v[offset : offset + bsz_stride] + else: + raise NotImplementedError(f"value of type {type(v)} is not supported") + + if isinstance(label, torch.Tensor): + micro_batch_label = label[offset : offset + bsz_stride] + elif isinstance(label, Dict): + micro_batch_label = {k: v[offset : offset + bsz_stride] if v.dim() > 0 else v for k, v in label.items()} + else: + micro_batch_label = label return micro_batch_data, micro_batch_label @abstractmethod diff --git a/internlm/core/communication/__init__.py b/internlm/core/scheduler/comm/__init__.py similarity index 100% rename from internlm/core/communication/__init__.py rename to internlm/core/scheduler/comm/__init__.py diff --git a/internlm/core/communication/p2p.py b/internlm/core/scheduler/comm/p2p.py similarity index 100% rename from internlm/core/communication/p2p.py rename to internlm/core/scheduler/comm/p2p.py diff --git a/internlm/core/scheduler/comm/utils.py b/internlm/core/scheduler/comm/utils.py new file mode 100644 index 00000000..d9e6f7e8 --- /dev/null +++ b/internlm/core/scheduler/comm/utils.py @@ -0,0 +1,125 @@ +# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/communication + +from typing import List, Tuple, Union + +import torch +import torch.distributed as dist + +from internlm.core.context import ParallelMode +from internlm.core.context 
import global_context as gpc +from internlm.utils.common import get_current_device + +TensorShape = Union[torch.Size, List[int], Tuple[int]] + + +def send_meta_helper(obj, next_rank, tensor_kwargs): + send_shape = torch.tensor(obj.size(), **tensor_kwargs) + send_ndims = torch.tensor(len(obj.size()), **tensor_kwargs) + dist.send(send_ndims, next_rank) + dist.send(send_shape, next_rank) + + +def send_obj_meta(obj, next_rank=None): + """Sends obj meta information before sending a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be sent before communications. This function + synchronizes with :func:`recv_obj_meta`. + + Args: + obj (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): obj to be sent. + need_meta (bool, optional): If False, meta information won't be sent. + next_rank (int): The rank of the next member in pipeline parallel group. + + Returns: + bool: False + """ + if next_rank is None: + next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE) + + tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + if isinstance(obj, torch.Tensor): + send_obj_nums = torch.tensor(1, **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + send_meta_helper(obj, next_rank, tensor_kwargs) + else: + send_obj_nums = torch.tensor(len(obj), **tensor_kwargs) + dist.send(send_obj_nums, next_rank) + for tensor_to_send in obj: + send_meta_helper(tensor_to_send, next_rank, tensor_kwargs) + + +def recv_meta_helper(prev_rank, tensor_kwargs): + recv_ndims = torch.empty((), **tensor_kwargs) + dist.recv(recv_ndims, prev_rank) + recv_shape = torch.empty(recv_ndims, **tensor_kwargs) + dist.recv(recv_shape, prev_rank) + return recv_shape + + +def recv_obj_meta(prev_rank=None) -> torch.Size: + """Receives obj meta information before receiving a specific obj. + Since the recipient must know the shape of the obj in p2p communications, + meta information of the obj should be received before communications. This function + synchronizes with :func:`send_obj_meta`. + + Args: + obj_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): The shape of the obj to be received. + prev_rank (int): The rank of the source of the obj. + + Returns: + Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received. + """ + if prev_rank is None: + prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) + + tensor_kwargs = {"dtype": torch.long, "device": get_current_device()} + recv_obj_nums = torch.empty((), **tensor_kwargs) + dist.recv(recv_obj_nums, prev_rank) + if recv_obj_nums.item() == 1: + recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) + obj_shape = torch.Size(recv_shape) + else: + obj_shape = [] + for _ in range(recv_obj_nums.item()): + recv_shape = recv_meta_helper(prev_rank, tensor_kwargs) + obj_shape.append(torch.Size(recv_shape)) + + return obj_shape + + +def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor: + """Break a tensor into equal 1D chunks. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be split before communication. + new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor. 
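The two-message shape handshake used by `send_obj_meta`/`recv_obj_meta` above, shown locally with the `dist.send`/`dist.recv` pairs replaced by plain assignments so only the encoding is visible (no process group required): the sender transmits the number of dimensions first, the receiver uses it to size the buffer for the shape message, then rebuilds a `torch.Size`.

```python
import torch

obj = torch.zeros(2, 3, 5)
send_ndims = torch.tensor(len(obj.size()), dtype=torch.long)  # first message
send_shape = torch.tensor(obj.size(), dtype=torch.long)       # second message

recv_ndims = send_ndims.clone()                               # stands in for dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(int(recv_ndims), dtype=torch.long)   # buffer sized from the first message
recv_shape.copy_(send_shape)                                  # stands in for dist.recv(recv_shape, prev_rank)
assert torch.Size(recv_shape.tolist()) == obj.size()
```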
+ + Returns: + :class:`torch.Tensor`: The split tensor + """ + partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.TENSOR) + start_index = partition_size * gpc.get_local_rank(ParallelMode.TENSOR) + end_index = start_index + partition_size + if new_buffer: + data = torch.empty(partition_size, dtype=tensor.dtype, device=get_current_device(), requires_grad=False) + data.copy_(tensor.view(-1)[start_index:end_index]) + else: + data = tensor.view(-1)[start_index:end_index] + return data + + +def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor: + """Opposite of above function, gather values from model parallel ranks. + + Args: + tensor (:class:`torch.Tensor`): Tensor to be gathered after communication. + Returns: + :class:`torch.Tensor`: The gathered tensor. + """ + world_size = gpc.get_world_size(ParallelMode.TENSOR) + numel = torch.numel(tensor) + numel_gathered = world_size * numel + gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=get_current_device(), requires_grad=False) + chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)] + dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR)) + return gathered diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 3aacf77a..4040e8e1 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -85,10 +85,7 @@ def _load_accum_batch(self, data: Any, label: Any): self._grad_accum_offset += self._bsz_stride if self.data_process_func: - _data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"]) - _label = self.data_process_func(_label, _data["cu_seqlens"], padding_v=-100) - _data.pop("cu_seqlens") - _data.pop("indexes") + _data, _label = self.data_process_func(_data, _label) return _data, _label @@ -178,7 +175,6 @@ def forward_backward_step( If True, the model is run for the forward pass, else back propagation will be executed. return_loss (bool, optional): Loss will be returned if True. return_output_label (bool, optional): Output and label will be returned if True. - Returns: Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss), loss and label could be None. 
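For intuition, the split/gather pair above amounts to the following local arithmetic (hypothetical tensor-parallel world size; no communication involved): each rank keeps one contiguous slice of the flattened tensor, and gathering the slices back in rank order reproduces the original data.

```python
import torch

world_size = 4
tensor = torch.arange(16.0)
partition_size = tensor.numel() // world_size
shards = [tensor.view(-1)[r * partition_size:(r + 1) * partition_size] for r in range(world_size)]
assert torch.equal(torch.cat(shards), tensor)  # what dist.all_gather reassembles
```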
""" diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index 66d1cca2..269ddb96 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -9,11 +9,11 @@ import torch import torch.distributed as dist -import internlm.core.communication as comm from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.core.naive_amp import NaiveAMPModel +from internlm.core.scheduler import comm from internlm.utils.common import ( SchedulerHook, check_data_is_packed, @@ -79,7 +79,7 @@ def pack_return_tensors(return_tensors): raise TypeError("Output of model must be tensor or list/tuple of tensors") if isinstance(label[0], torch.Tensor): label = torch.cat(label, dim=0) - else: + elif isinstance(label[0], dict): merged_label = {k: [] for k in label[0].keys()} for d in label: for k, v in d.items(): @@ -220,16 +220,9 @@ def load_micro_batch(self): micro_batch_data, micro_batch_label = self._load_micro_batch( data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, bsz_stride=self.bsz_stride ) - if self.data_process_func: - micro_batch_data["input_ids"] = self.data_process_func( - micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"] - ) - micro_batch_label = self.data_process_func( - micro_batch_label, micro_batch_data["cu_seqlens"], padding_v=-100 - ) - micro_batch_data.pop("cu_seqlens") - micro_batch_data.pop("indexes") + if self.data_process_func: + micro_batch_data, micro_batch_label = self.data_process_func(micro_batch_data, micro_batch_label) micro_batch_data["label"] = micro_batch_label self.microbatch_offset += self.bsz_stride @@ -575,7 +568,7 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T if num_1f1b_micropairs > 0: if not gpc.is_first_rank(ParallelMode.PIPELINE): if forward_recv_shapes is None: - forward_recv_shapes = comm.recv_obj_meta(forward_recv_shapes) + forward_recv_shapes = comm.recv_obj_meta() input_obj = comm.recv_forward( forward_recv_shapes, dtype=self.dtype, @@ -821,15 +814,7 @@ def load_micro_batch(self, model_chunk_id): bsz_stride=self.bsz_stride, ) if self.data_process_func: - micro_batch_data["input_ids"] = self.data_process_func( - micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"] - ) - micro_batch_label = self.data_process_func( - micro_batch_label, micro_batch_data["cu_seqlens"], padding_v=-100 - ) - - micro_batch_data.pop("cu_seqlens") - micro_batch_data.pop("indexes") + micro_batch_data, micro_batch_label = self.data_process_func(micro_batch_data, micro_batch_label) micro_batch_data["label"] = micro_batch_label self.microbatch_offset[model_chunk_id] += self.bsz_stride @@ -977,7 +962,7 @@ def _run_warmup_loop( """ if not gpc.is_pipeline_first_stage(): if self._input_obj_shapes[0] is None: - self._input_obj_shapes[0] = comm.recv_obj_meta(self._input_obj_shapes[0]) + self._input_obj_shapes[0] = comm.recv_obj_meta() self._input_objs[0].append( comm.recv_forward( self._input_obj_shapes[0], diff --git a/internlm/core/trainer.py b/internlm/core/trainer.py index b1890318..12150157 100644 --- a/internlm/core/trainer.py +++ b/internlm/core/trainer.py @@ -64,7 +64,7 @@ def __init__(self, config, batch_sampler) -> None: self.tgs_statistic = { "sum_step": 0, "sum_tg": 0, - "sum_time": 0, + "total_time": 0, "sum_last_tg_10": 0, "sum_last_time_10": 0, "sum_last_tg_50": 0, diff --git 
a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py new file mode 100644 index 00000000..8933a5df --- /dev/null +++ b/internlm/core/trainer_builder.py @@ -0,0 +1,333 @@ +import gc +import logging +import time +from functools import partial +from typing import Dict, Optional + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader + +from internlm.checkpoint.checkpoint_manager import CheckpointManager +from internlm.core.context import global_context as gpc +from internlm.core.context.process_group_initializer import ParallelMode +from internlm.core.trainer import Trainer +from internlm.data.streaming.utils import hf_simple_resume +from internlm.data.train_state import get_train_state +from internlm.eval.evaluation import evaluate_on_val_dls +from internlm.initialize.initialize_trainer import initialize_trainer +from internlm.model.losses.ce_loss import FlashGPTLMLoss +from internlm.model.metrics import AccPerplex +from internlm.monitor.monitor import send_alert_message +from internlm.train.pipeline import ( + get_scheduler_hooks, + initialize_llm_profile, + initialize_optimizer, + initialize_parallel_communicator, + load_new_batch, + record_current_batch_training_metrics, +) +from internlm.utils.common import ( + BatchSkipper, + enable_pytorch_expandable_segments, + get_current_device, + get_megatron_flops, + launch_time, +) +from internlm.utils.gputest import empty_cache_and_diag +from internlm.utils.logger import get_logger +from internlm.utils.megatron_timers import megatron_timer as timer +from internlm.utils.parallel import get_parallel_log_file_name +from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler +from internlm.utils.writer import Writer + +# global llm logger +logger = logging.getLogger(__file__) + + +class TrainerBuilder(Trainer): + """ + Manage InternEvo training process. + + Args: + model (torch.nn.Module): The model to be trained. + train_dl (torch.utils.data.DataLoader): The training data loader. + val_dls (Optional[Dict[str, torch.utils.data.DataLoader]]): The validation data loaders. + kwargs: Additional keyward arguments. + """ + + def __init__( + self, + model: torch.nn.Module, + train_dl: DataLoader, + val_dls: Optional[Dict[str, DataLoader]] = None, + **kwargs, + ): + """ + Initialize InternEvo TrainerBuilder class. + + Args: + model (torch.nn.Module): The model to be trained. + train_dl (torch.utils.data.DataLoader): The training data loader. + val_dls (Optional[Dict[str, torch.utils.data.DataLoader]]): The validation data loaders. + kwargs: Additional keyward arguments. 
+ """ + + # record very_begining_time + very_begining_time = time.time() + + # set torch expandable_segments + enable_pytorch_expandable_segments() + + # get and broadcast current time + current_time = launch_time() + objs = [current_time] + dist.broadcast_object_list(objs, src=0) + current_time = objs[0].replace(":", ".") + global logger + logger = get_logger( + __file__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name() + ) + + # initialize isp communicator + isp_communicator = initialize_parallel_communicator(model) + + with open(kwargs["config"], "r") as f: + config_lines = f.readlines() + + # initialize loss function + criterion = FlashGPTLMLoss( + parallel_output=gpc.config.model.parallel_output, label_smoothing=gpc.config.loss.label_smoothing + ) + + # initialize and resume train state + train_state = get_train_state(train_dl) + + # initialize optimizer + optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) + + # initialize checkpoint manager + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + train_dl=train_dl, + model_config=gpc.config.model, + model_config_file="".join(config_lines), + feishu_address=gpc.config.monitor.alert.feishu_alert_address, + ) + + # load other persistent training states + ckpt_manager.try_resume_training(train_state, current_time) + + # initialize customed llm writer + writer = Writer( + job_name=gpc.config.JOB_NAME, + launch_time=current_time, + file_name=get_parallel_log_file_name(), + tensorboard_folder=gpc.config.tensorboard_folder, + resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt. + step_count=train_state.step_count, # resume from ckpt. 
+ config=config_lines, + logger=logger, + enable_tb=gpc.config.enable_tb, + queue_max_length=gpc.config.tensorboard.queue_max_length, + total_steps=gpc.config.data.total_steps, + ) + + # initialize metric for calculating accuracy and perplexity + metric = AccPerplex( + device=get_current_device(), + tp_pg=gpc.get_group(ParallelMode.TENSOR), + dp_pg=gpc.get_group(ParallelMode.DATA), + dataset_types=kwargs["dataset_types"], + ) + + # initialize simple memory profiler + if kwargs["profiling"]: + self.memory_profiler = SimpleMemoryProfiler( + model, + optimizer.optim, + log_folder=f"RUN/{gpc.config.JOB_NAME}/{current_time}/memory_trace/rank{gpc.get_global_rank()}_" + + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_" + + f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_" + + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}", + ) + else: + self.memory_profiler = None + + # initialize batch skipper + skip_batches = gpc.config.data.skip_batches + if gpc.config.data.type == "hf" and gpc.config.ckpt.auto_resume: + skip_batches = hf_simple_resume(train_state) + self.batch_skipper = BatchSkipper(skip_batches) + + # set TrainerBuilder attributes + self.very_begining_time = very_begining_time + self.profiling = kwargs["profiling"] + self.current_time = current_time + self.train_dl = train_dl + self.val_dls = val_dls + self.train_state = train_state + self.optimizer = optimizer + self.beta2_scheduler = beta2_scheduler + self.isp_communicator = isp_communicator + self.writer = writer + self.ckpt_manager = ckpt_manager + self.metric = metric + + # initialize trainer + engine, scheduler = initialize_trainer( + model=model, + optimizer=optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler, + beta2_scheduler=beta2_scheduler, + scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), + ) + + super().__init__(engine, scheduler) + + def fit(self): + """ + Launch InternEvo TrainerBuilder training process. 
+ """ + + self.train() + train_iter = iter(self.train_dl) + + with initialize_llm_profile(profiling=self.profiling, start_time=self.current_time) as prof: + # close automatic garbage collection + gc.disable() + # start iterating the train data and begin training + for batch_count in range(self.train_state.batch_count, gpc.config.data.total_steps): + empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) + # internlm_accelerator.memory._record_memory_history() + start_time = time.time() + timer("one-batch").start() + + # load batch data + batch, train_iter = load_new_batch( + train_dl=self.train_dl, train_iter=train_iter, train_state=self.train_state + ) + + # record the consumed samples in training + self.train_state.batch_count = batch_count + self.train_state.num_consumed_samples_in_epoch += len(batch[1]) + if self.batch_skipper(batch_count): # skip this batch + if gpc.is_rank_for_log(): + logger.info(f"Skip batch count:`{batch_count}`...") + timer("one-batch").stop() + continue + + # zero the grads of parameters + self.zero_grad() + # process data + if batch[0].get("type_ids", None) is not None: + self.metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) + # if batch[0].get("cu_seqlens", None) is not None: + # metric.set_cu_seqlens(cu_seqlens=batch[0].pop("cu_seqlens", None)) + + # do forward and backward + timer("fwd-bwd").start() + + moe_loss = None + if hasattr(gpc.config.model, "num_experts"): + _, _, loss, moe_loss = self.execute_schedule( + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) + else: + _, _, loss = self.execute_schedule( # pylint: disable=W0632 + batch, + forward_only=False, + return_loss=True, + return_output_label=False, + ) + timer("fwd-bwd").stop() + + if self.isp_communicator and self.isp_communicator.enable_memory_pool: + self.isp_communicator.memory_pool.reset_lazy_pools() + + # update parameters, and returns (success_update, grad_norm) + trainer_result = self.step() + assert trainer_result is not None + + success_update, grad_norm_groups = trainer_result + if success_update: # update parameters successfully + self.train_state.step_count += 1 + else: + self.train_state.inf_nan_skip_batches += ( + 1 # record the amount of updating parameters unsuccessfully. + ) + if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case + logger.warning(f"Warning: skip parameter update at step {batch_count}.") + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message=f"Warning: skip parameter update at step {batch_count}.", + ) + + get_tflops_func = partial( + get_megatron_flops, + checkpoint=gpc.config.model.checkpoint, + seq_len=gpc.config.data["seq_len"], + hidden_size=gpc.config.model.hidden_size, + num_layers=gpc.config.model.num_layers, + vocab_size=gpc.config.model.vocab_size, + global_batch_size=gpc.config.data.micro_bsz + * gpc.config.data.micro_num + * gpc.get_world_size(ParallelMode.DATA), + global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), + mlp_ratio=gpc.config.model["mlp_ratio"], + ) + + # calculate and record the training metrics, eg. loss, accuracy and so on. 
+ record_current_batch_training_metrics( + get_tflops_func=get_tflops_func, + logger=logger, + writer=self.writer, + success_update=success_update, + batch_count=batch_count, + batch=batch, + train_state=self.train_state, + optimizer=self.optimizer, + beta2_scheduler=self.beta2_scheduler, + trainer=self, + start_time=start_time, + very_begining_time=self.very_begining_time, + loss=loss, + moe_loss=moe_loss, + grad_norm=grad_norm_groups, + metric=self.metric, + ) + + timer("one-batch").stop() + + # evaluate on validation data loaders + if gpc.config.data.valid_every > 0 and self.train_state.step_count % gpc.config.data.valid_every == 0: + evaluate_on_val_dls( + self, + val_dls=self.val_dls, + writer=self.writer, + logger=logger, + step_count=self.train_state.step_count, + ) + + # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" + # # save batch sampler that tracks the true consumed samples + now_break = self.ckpt_manager.try_save_checkpoint(self.train_state) + if now_break: + break + + if self.memory_profiler is not None: + self.memory_profiler.step() + + if batch_count % 2 == 0: + prof.step() + + # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + + self.ckpt_manager.wait_async_upload_finish() diff --git a/internlm/data/__init__.py b/internlm/data/__init__.py index 08ad5d88..35f6ade4 100644 --- a/internlm/data/__init__.py +++ b/internlm/data/__init__.py @@ -1,4 +1,5 @@ from .build_dataloader import ( + build_generation_loader_with_data_type, build_train_loader_with_data_type, build_valid_loader_with_data_type, ) @@ -6,4 +7,5 @@ __all__ = [ "build_train_loader_with_data_type", "build_valid_loader_with_data_type", + "build_generation_loader_with_data_type", ] diff --git a/internlm/data/build_dataloader.py b/internlm/data/build_dataloader.py index c2c0ea69..aa09a960 100644 --- a/internlm/data/build_dataloader.py +++ b/internlm/data/build_dataloader.py @@ -6,11 +6,21 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.data.streaming.batch_sampler import StreamingStaticBatchSampler +from internlm.data.streaming.collaters import nopack_collate_fn, pack_collate_fn +from internlm.data.streaming.dataset import ( + HuggingFacePackedDataset, + HuggingFaceStreamingDataset, +) from internlm.data.tokenized.batch_sampler import ( StaticBatchSampler, get_dpsampler_dataloader, ) -from internlm.data.tokenized.collaters import jsonl_ds_collate_fn, packed_collate_fn +from internlm.data.tokenized.collaters import ( + generation_collate_fn, + jsonl_ds_collate_fn, + packed_collate_fn, +) from internlm.data.tokenized.dataset import get_dataset_dict from internlm.data.tokenized.dummy_dataset import RandomDataset from internlm.data.tokenized.dummy_dataset_multimodal import RandomDatasetMultimodal @@ -108,6 +118,31 @@ def get_tokenized_valid_loader_items(data_cfg): return valid_ds, valid_collate_fn +def get_hf_train_loader_items(data_cfg): + train_ds = HuggingFaceStreamingDataset( + dataset_name=data_cfg.train_folder, + tokenizer_name=data_cfg.tokenizer_path, + model_max_length=data_cfg.seq_len, + subset_name=data_cfg.get("subset_name", None), + ) + if gpc.config.model_type == "hf" and not data_cfg.use_packed_dataset: + train_sampler = StreamingStaticBatchSampler( + batch_size=data_cfg.micro_num * data_cfg.micro_bsz, rampup_batch_size=data_cfg.rampup_batch_size + ) + train_collate_fn = partial( + nopack_collate_fn, micro_num=data_cfg.micro_num, 
micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len + ) + else: + train_ds = HuggingFacePackedDataset(dataset=train_ds, seq_len=data_cfg.seq_len, micro_bsz=data_cfg.micro_bsz) + train_sampler = StreamingStaticBatchSampler( + batch_size=data_cfg.micro_num, rampup_batch_size=data_cfg.rampup_batch_size + ) + train_collate_fn = partial( + pack_collate_fn, micro_num=data_cfg.micro_num, micro_bsz=data_cfg.micro_bsz, seq_len=data_cfg.seq_len + ) + return train_ds, train_sampler, train_collate_fn + + def build_train_loader_with_data_type(): """ Build and return the training data loader based on data type. @@ -115,11 +150,15 @@ def build_train_loader_with_data_type(): Returns: A tuple of (train_dl, dataset_types). """ data_cfg = gpc.config.data + train_folder = data_cfg.get("train_folder", None) - dataset_types = list(get_dataset_type_ids_map(train_folder).keys()) if train_folder else ["en", "cn", "code"] if data_cfg.type == "tokenized": train_ds, train_sampler, train_collate_fn = get_tokenized_train_loader_items(data_cfg) + dataset_types = list(get_dataset_type_ids_map(train_folder).keys()) if train_folder else ["en", "cn", "code"] + elif data_cfg.type == "hf": + train_ds, train_sampler, train_collate_fn = get_hf_train_loader_items(data_cfg) + dataset_types = ["en"] else: raise ValueError(f"dataset type {data_cfg.type} is not supported") @@ -141,7 +180,7 @@ def build_valid_loader_with_data_type(): data_cfg = gpc.config.data - if data_cfg.type == "tokenized": + if data_cfg.type in ["tokenized", "hf"]: valid_ds, valid_collate_fn = get_tokenized_valid_loader_items(data_cfg) else: raise ValueError(f"dataset type {data_cfg.type} is not supported") @@ -178,3 +217,46 @@ def build_valid_loader_with_data_type(): ) return val_dls + + +def build_generation_loader_with_data_type(data_cfg, generation_cfg): + """Generate and return the validation data loader based on data type.""" + + if data_cfg.type == "tokenized": + gene_ds, _ = get_tokenized_valid_loader_items(data_cfg) + else: + raise ValueError(f"dataset type {data_cfg.type} is not supported") + + if gene_ds is None: + return None + + gene_dls = {} + for gene_name, ds in gene_ds.items(): + # making the batch_size of validate larger can speed up the evaluation, but it should not be too large, + # otherwise too much data may be dropped + batch_size = min( + data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA) + ) + batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz + if generation_cfg.batch_size: + batch_size = generation_cfg.batch_size + + if batch_size == 0 and gpc.is_rank_for_log(): + logger.info(f"skip validate {gene_name}.") + continue + + gene_dls[gene_name] = get_dpsampler_dataloader( + ds, + shuffle=False, + num_workers=data_cfg.get("num_worker", 0), + batch_size=batch_size, + collate_fn=partial(generation_collate_fn, pad_id=generation_cfg.pad_id), + ) + + if gpc.is_rank_for_log(): + logger.info( + f"load validation dataset {gene_name} with valid batch size {str(batch_size)} and " + f"samples {str(len(gene_dls[gene_name]))}." 
+ ) + + return gene_dls diff --git a/internlm/data/streaming/__init__.py b/internlm/data/streaming/__init__.py new file mode 100644 index 00000000..513e3243 --- /dev/null +++ b/internlm/data/streaming/__init__.py @@ -0,0 +1,13 @@ +from .batch_sampler import StreamingStaticBatchSampler +from .collaters import nopack_collate_fn, pack_collate_fn +from .dataset import HuggingFacePackedDataset, HuggingFaceStreamingDataset +from .utils import hf_simple_resume + +__all__ = [ + "StreamingStaticBatchSampler", + "nopack_collate_fn", + "pack_collate_fn", + "HuggingFaceStreamingDataset", + "HuggingFacePackedDataset", + "hf_simple_resume", +] diff --git a/internlm/data/streaming/batch_sampler.py b/internlm/data/streaming/batch_sampler.py new file mode 100644 index 00000000..11f9bb8b --- /dev/null +++ b/internlm/data/streaming/batch_sampler.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import sys +from typing import Optional + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +class StreamingStaticBatchSampler: + """ + StreamingStaticBatchSampler is used for the training process. + """ + + def __init__(self, batch_size: int = 1, rampup_batch_size: Optional[str] = None, micro_bsz: int = 1): + if rampup_batch_size: + start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split()) + else: + start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1 + + self.raw_rampup_batch_size = rampup_batch_size + self.start_bsz = start_bsz + self.bsz_incre = bsz_incre + self.incre_every = incre_every + + if gpc.is_initialized(ParallelMode.PIPELINE): + assert ( + batch_size - self.start_bsz + ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}" + assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})" + assert ( + self.start_bsz % micro_bsz == 0 + ), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})" + assert ( + self.bsz_incre % micro_bsz == 0 + ), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})" + + self.batch_size = batch_size + self.num_consumed_samples_in_epoch = 0 + self.batch_count = 0 + + def __len__(self): + return sys.maxsize + + def __iter__(self): + while True: + batch_rampup_idx = self.batch_count // self.incre_every + cur_batch_size = batch_rampup_idx * self.bsz_incre + self.start_bsz + cur_batch_size = min(cur_batch_size, self.batch_size) + + self.num_consumed_samples_in_epoch += cur_batch_size + self.batch_count += 1 + yield [0] * cur_batch_size + + def state_dict(self): + states = { + "batch_size": self.batch_size, + "raw_rampup_batch_size": self.raw_rampup_batch_size, + "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch, + "batch_count": self.batch_count, + } + return states + + def load_state_dict(self, states): + for name in ("raw_rampup_batch_size",): # 'batch_size' + assert states[name] == getattr(self, name), (name, states[name], getattr(self, name)) # should not change + self.num_consumed_samples_in_epoch = states["num_consumed_samples_in_epoch"] + self.batch_count = states["batch_count"] + + def copy(self): + copy_sampler = StreamingStaticBatchSampler(self.batch_size, self.raw_rampup_batch_size) + + copy_sampler.load_state_dict(self.state_dict()) + return copy_sampler diff --git a/internlm/data/streaming/collaters.py b/internlm/data/streaming/collaters.py 
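The batch-size ramp-up implemented by `StreamingStaticBatchSampler.__iter__` above can be reproduced standalone; the schedule below assumes a hypothetical `rampup_batch_size="4 4 50"` with a target batch size of 16, i.e. start at 4 and add 4 every 50 batches until the cap is reached.

```python
def rampup_schedule(batch_count, batch_size=16, start_bsz=4, bsz_incre=4, incre_every=50):
    # same formula as the sampler: linear ramp-up, capped at the target batch size
    cur = (batch_count // incre_every) * bsz_incre + start_bsz
    return min(cur, batch_size)


assert [rampup_schedule(b) for b in (0, 49, 50, 100, 150, 999)] == [4, 4, 8, 12, 16, 16]
```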
new file mode 100644 index 00000000..4391fd23 --- /dev/null +++ b/internlm/data/streaming/collaters.py @@ -0,0 +1,58 @@ +import torch + + +def nopack_collate_fn(batch, micro_num, micro_bsz, seq_len): + input_ids_list = [] + attention_mask_list = [] + labels_list = [] + for b in batch: + attention_mask = torch.tensor(b["attention_mask"]) + input_ids = torch.LongTensor(b["input_ids"]) + input_ids = torch.abs(input_ids * attention_mask) + input_ids = torch.nn.functional.pad(input_ids, (0, seq_len - len(input_ids)), mode="constant", value=0) + attention_mask = torch.nn.functional.pad( + attention_mask, (0, seq_len - len(attention_mask)), mode="constant", value=0 + ) + label = torch.LongTensor([w if w > 0 else -100 for w in input_ids.tolist()][1:] + [-100]) + input_ids_list.append(input_ids) + attention_mask_list.append(attention_mask) + labels_list.append(label) + input_ids = torch.stack(input_ids_list) + attention_mask = torch.stack(attention_mask_list) + labels = torch.stack(labels_list) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "type_ids": torch.zeros(micro_num, micro_bsz, seq_len, dtype=torch.int64), + }, labels + + +def pack_collate_fn(batch, micro_num, micro_bsz, seq_len): + packed_length = micro_bsz * seq_len + + input_ids_list = [] + cu_seqlens_list = [] + indexes_list = [] + labels_list = [] + + for b in batch: + assert len(b["input_ids"]) == packed_length + assert b["cu_seqlens"][0] == 0 and b["cu_seqlens"][-1] == packed_length + assert len(b["indexes"]) == packed_length + assert len(b["labels"]) == packed_length + + input_ids_list.append(torch.LongTensor(b["input_ids"])) + cu_seqlens_list.append(torch.IntTensor(b["cu_seqlens"])) + indexes_list.append(torch.IntTensor(b["indexes"])) + labels_list.append(torch.LongTensor(b["labels"])) + + input_ids = torch.stack(input_ids_list) + indexes = torch.stack(indexes_list) + labels = torch.stack(labels_list) + + return { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens_list, + "indexes": indexes, + "type_ids": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.int64), + }, labels diff --git a/internlm/data/streaming/dataset.py b/internlm/data/streaming/dataset.py new file mode 100644 index 00000000..a3844d70 --- /dev/null +++ b/internlm/data/streaming/dataset.py @@ -0,0 +1,119 @@ +import itertools +import sys + +import datasets +import numpy as np +from datasets.distributed import split_dataset_by_node +from torch.utils.data import Dataset + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from transformers import AutoTokenizer + + +class HuggingFaceStreamingDataset(Dataset): + """ + Streaming and on-the-fly tokenized dataset for huggingface + """ + + def __init__( + self, dataset_name, tokenizer_name, model_max_length, split="train", buffer_size=1000, subset_name=None + ): + self.dataset = datasets.load_dataset(dataset_name, data_dir=subset_name, split=split, streaming=True) + self.dataset = split_dataset_by_node( + self.dataset, rank=gpc.get_local_rank(ParallelMode.DATA), world_size=gpc.get_world_size(ParallelMode.DATA) + ) + self.buffer_size = buffer_size + self.senior_iterator = iter(self) + + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) + self.tokenizer.model_max_length = model_max_length + + def __iter__(self): + buffer = [] + for sample in self.dataset: + buffer.append(sample) + if len(buffer) >= self.buffer_size: + yield from self._tokenize(buffer) + buffer = [] + + if buffer: + yield from 
self._tokenize(buffer) + + def __len__(self): + return sys.maxsize + + def _tokenize(self, samples): + texts = [sample["text"] for sample in samples] + tokenized_outputs = self.tokenizer(texts, truncation=True) + for i in range(len(samples)): + yield {key: tokenized_outputs[key][i] for key in tokenized_outputs} + + def __getitem__(self, _): + return next(self.senior_iterator) + + +class HuggingFacePackedDataset(Dataset): + """ + Simple packed dataset for huggingface. + """ + + def __init__(self, dataset, seq_len, micro_bsz): + self.dataset = dataset + self.seq_len = seq_len + self.micro_bsz = micro_bsz + + self.senior_iterator = iter(self) + + def __iter__(self): + input_ids = [] + cu_seqlens = [0] + labels = [] + for sample in self.dataset: + if len(input_ids + sample["input_ids"]) > self.micro_bsz * self.seq_len: + assert cu_seqlens[-1] <= self.micro_bsz * self.seq_len + input_ids = input_ids + [0] * (self.micro_bsz * self.seq_len - len(input_ids)) + cu_seqlens = ( + cu_seqlens + [self.micro_bsz * self.seq_len] + if cu_seqlens[-1] < self.micro_bsz * self.seq_len + else cu_seqlens + ) + labels = labels + [-100] * (self.micro_bsz * self.seq_len - len(labels)) + yield { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens, + "indexes": list( + itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]) + ), + "labels": labels, + } + input_ids = sample["input_ids"] + cu_seqlens = [0, len(sample["input_ids"])] + labels = sample["input_ids"][1:] + [-100] + else: + input_ids = input_ids + sample["input_ids"] + cu_seqlens.append(len(sample["input_ids"]) + cu_seqlens[-1]) + labels = labels + sample["input_ids"][1:] + [-100] + if input_ids: + assert cu_seqlens[-1] <= self.micro_bsz * self.seq_len + input_ids = input_ids + [0] * (self.micro_bsz * self.seq_len - len(input_ids)) + cu_seqlens = ( + cu_seqlens + [self.micro_bsz * self.seq_len] + if cu_seqlens[-1] < self.micro_bsz * self.seq_len + else cu_seqlens + ) + labels = labels + [-100] * (self.micro_bsz * self.seq_len - len(labels)) + yield { + "input_ids": input_ids, + "cu_seqlens": cu_seqlens, + "indexes": list( + itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]) + ), + "labels": labels, + } + + def __len__(self): + return sys.maxsize + + def __getitem__(self, _): + return next(self.senior_iterator) diff --git a/internlm/data/streaming/utils.py b/internlm/data/streaming/utils.py new file mode 100644 index 00000000..ee331ba2 --- /dev/null +++ b/internlm/data/streaming/utils.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from internlm.core.context import global_context as gpc + + +# simple auto_resume for huggingface streaming dataloader +def hf_simple_resume(train_state): + skip_batches = gpc.config.data.get("skip_batches", "") + if train_state.batch_count > 0: + assert skip_batches == "", "skip_batches should be empty when huggingface dataloader resume from ckpts" + skip_batches = f"0-{train_state.batch_count - 1}" + train_state.batch_count = 0 + train_state.num_consumed_samples_in_epoch = 0 + if hasattr(train_state, "batch_sampler"): + train_state.batch_sampler.batch_count = 0 + train_state.batch_sampler.num_consumed_samples_in_epoch = 0 + train_state.batch_sampler_iter = iter(train_state.batch_sampler) + return skip_batches diff --git a/internlm/data/tokenized/collaters.py b/internlm/data/tokenized/collaters.py index 785ecc60..fab7c5ac 100644 --- a/internlm/data/tokenized/collaters.py +++ b/internlm/data/tokenized/collaters.py @@ -100,3 +100,29 @@ 
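The `indexes` field emitted by `HuggingFacePackedDataset` above is just per-sequence position ids derived from `cu_seqlens`; a toy example with made-up boundaries:

```python
import itertools

import numpy as np

cu_seqlens = [0, 3, 7, 8]  # three sequences of length 3, 4 and 1 packed together
indexes = [int(i) for i in itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])]
print(indexes)  # [0, 1, 2, 0, 1, 2, 3, 0] -- positions restart at every sequence boundary
```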
def jsonl_ds_collate_fn(batch, max_length_per_sample): return {"input_ids": xs, "images": images}, ys else: return {"input_ids": xs}, ys + + +def generation_collate_fn(batch, pad_id=0): + """ + Collate function for generation dataset. + + Args: + batch (List[Dict]): List of dictionaries representing each sample in batch. + Each dictionary contains "tokens". + + Returns: + Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids", + and the tensor of padded "labels". + + """ + xs, ys = [], [] + for x in batch: + tokens = [abs(w) for w in x["tokens"]] + labels = [w if w > 0 else -100 for w in x["tokens"]] + labels = labels[1:] + [-100] + xs.append(torch.as_tensor(tokens[::-1])) + ys.append(torch.as_tensor(labels[::-1])) # y has been shifted + xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True, padding_value=pad_id).flip(dims=[1]) + ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100).flip(dims=[1]) + + return {"input_ids": xs}, ys diff --git a/internlm/data/tokenized/dataset.py b/internlm/data/tokenized/dataset.py index e39a39b7..8991272b 100644 --- a/internlm/data/tokenized/dataset.py +++ b/internlm/data/tokenized/dataset.py @@ -51,6 +51,6 @@ def get_dataset_dict(folder, split="valid") -> Dict: datasets.append(ds) if datasets: ds = ConcatDataset(datasets=datasets) - data_dict[os.path.basename(root)] = ds + data_dict[os.path.basename(root.rstrip(os.path.sep))] = ds return data_dict diff --git a/internlm/data/tokenized/packed_dataset.py b/internlm/data/tokenized/packed_dataset.py index 2d3bfa49..b2a8b109 100644 --- a/internlm/data/tokenized/packed_dataset.py +++ b/internlm/data/tokenized/packed_dataset.py @@ -4,22 +4,27 @@ import itertools as it import operator import os +import shutil from copy import deepcopy +from pathlib import Path from typing import Dict import numpy as np -import torch import torch.distributed as dist from torch.utils.data import ConcatDataset, Dataset from tqdm import tqdm +from internlm.accelerator import get_accelerator from internlm.core.context import global_context as gpc from internlm.data.tokenized.single_dataset import JsonlDataset from internlm.data.utils import get_dataset_type_id, get_dataset_type_ids_map from internlm.utils.logger import get_logger +from .single_dataset import gen_shm_meta_name_without_scalar + DEFAULT_SEED = 1024 logger = get_logger(__file__) +internlm_accelerator = get_accelerator() class PackedDataset(Dataset): @@ -67,7 +72,12 @@ def __getitem__(self, item: int) -> Dict: return self.build_unpack(item) -class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset): +def gen_shm_meta_name_scalar(path: str, num_tokens: int, seed: int): + shm_path_without_num_tokens = gen_shm_meta_name_without_scalar(path) + return "%".join([str(shm_path_without_num_tokens), str(num_tokens), str(seed)]) + + +class PackedDatasetWithoutCuSeqlen(Dataset): """ A dataset wrapper that aggregates samples with different lengths based on packed_length. If a sample is shorter than max_length_per_sample, it will be merged with other samples. 
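`generation_collate_fn` above left-pads by reversing each sequence, right-padding with `pad_sequence`, and flipping back; the trick in isolation (toy token values), which is useful for generation where every prompt should end at the last position:

```python
import torch

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
left_padded = torch.nn.utils.rnn.pad_sequence(
    [s.flip(0) for s in seqs], batch_first=True, padding_value=0  # reverse, then right-pad
).flip(dims=[1])                                                  # flip back -> left padding
print(left_padded)
# tensor([[1, 2, 3],
#         [0, 4, 5]])
```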
@@ -98,13 +108,50 @@ def __init__( ), "The dataset must have lengths attribute and have the same length as the dataset" self.dataset = dataset self.max_length_per_sample = max_length_per_sample - self.lengths = getattr(self.dataset, "lengths") self.bsz = packed_length // max_length_per_sample self.packed_length = packed_length self.debug = debug # Force a seed to be fixed to prevent problems caused by the seed not being restored when restarting - self.seed = DEFAULT_SEED + self.path = self.get_dataset_name() + + if not gpc.config.data.use_shm: + self._process_init() + else: + if self.dataset.found_cache: + assert ( + hasattr(dataset, "lengths") + and hasattr(dataset, "indices") + and hasattr(dataset, "cum_lens") + and hasattr(dataset, "num_tokens") + ) + self.lengths = self.dataset.lengths + self.indices = self.dataset.indices + self.cum_lens = self.dataset.cum_lens + self.num_tokens = self.dataset.num_tokens + assert self.seed == getattr(self.dataset, "seed") + assert packed_length % max_length_per_sample == 0 + assert len(getattr(dataset, "lengths")) == len( + dataset + ), "The dataset must have lengths attribute and have the same length as the dataset" + else: + self._process_init() + # If shm-packed datast found no cache, local rank 0 try save. + if self.dataset.local_rank == 0: + shm_path = Path(gen_shm_meta_name_scalar(self.path, self.num_tokens, self.seed)) + shm_path_str = str(shm_path) + assert not os.path.exists(shm_path_str) + if not os.path.exists(str(shm_path.parent)): + os.makedirs(str(shm_path.parent), exist_ok=True) + + data = np.asarray([self.dataset.offsets, self.lengths, self.indices, self.cum_lens]) + np.save(shm_path_str, data) + # Prevent the risk of competition between jsondataset and packeddataset in + # different processes on the same node. + shutil.move(shm_path_str + ".npy", shm_path_str + ".final") # np.save will auto add .npy + + def _process_init(self): + self.lengths = getattr(self.dataset, "lengths") indices = np.arange(len(self.lengths)) rng = np.random.RandomState(self.seed) rng.shuffle(indices) @@ -234,8 +281,54 @@ def __init__( packed_length: int = 4096, ): super().__init__(dataset, max_length_per_sample, packed_length) - self.sample_indices, self.len_samples_shuffled, self.acm_len_samples = self.accu_sample_len(seed=self.seed) - self.num_tokens = sum(self.lengths) + self.path = self.get_dataset_name() + if not gpc.config.data.use_shm: + self.sample_indices, self.len_samples_shuffled, self.acm_len_samples = self.accu_sample_len(seed=self.seed) + self.num_tokens = sum(self.lengths) + else: + if self.dataset.found_cache: + assert ( + hasattr(dataset, "sample_indices") + and hasattr(dataset, "len_samples_shuffled") + and hasattr(dataset, "acm_len_samples") + and hasattr(dataset, "num_tokens") + ) + self.sample_indices = self.dataset.sample_indices + self.len_samples_shuffled = self.dataset.len_samples_shuffled + self.acm_len_samples = self.dataset.acm_len_samples + self.num_tokens = self.dataset.num_tokens + assert self.seed == getattr(self.dataset, "seed") + assert packed_length % max_length_per_sample == 0 + assert hasattr(dataset, "lengths") + assert len(getattr(dataset, "lengths")) == len( + dataset + ), "The dataset must have lengths attribute and have the same length as the dataset" + else: + self.sample_indices, self.len_samples_shuffled, self.acm_len_samples = self.accu_sample_len( + seed=self.seed + ) + self.num_tokens = sum(self.lengths) + # If shm-packed datast found no cache, local rank 0 try save. 
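The shared-memory meta-cache file name written above follows a simple scheme: the scalar fields (`num_tokens`, `seed`) are appended to the mirrored dataset path with `%` separators, and the completed file is renamed with a `.final` suffix so readers only pick up fully written caches. A sketch using the example path from the `gen_shm_meta_name_without_scalar` docstring and hypothetical scalar values:

```python
from pathlib import Path

shm_prefix = Path("/dev/shm/metacache")                        # default shm_path from the config
bin_path = Path("/llm_data/tokenized/train/cn/train-00000.bin")
base = shm_prefix / bin_path.relative_to(bin_path.anchor)      # strip the leading '/'
cache_name = "%".join([str(base), str(123456), str(1024)])     # num_tokens=123456, seed=DEFAULT_SEED
print(cache_name + ".final")
# /dev/shm/metacache/llm_data/tokenized/train/cn/train-00000.bin%123456%1024.final
```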
+ if self.dataset.local_rank == 0: + shm_path = Path(gen_shm_meta_name_scalar(self.path, self.num_tokens, self.seed)) + shm_path_str = str(shm_path) + assert not os.path.exists(shm_path_str) + if not os.path.exists(str(shm_path.parent)): + os.makedirs(str(shm_path.parent), exist_ok=True) + + data = np.asarray( + [ + self.dataset.offsets, + self.lengths, + self.sample_indices, + self.len_samples_shuffled, + self.acm_len_samples, + ] + ) + np.save(shm_path_str, data) + # Prevent the risk of competition between jsondataset and packeddataset in + # different processes on the same node. + shutil.move(shm_path_str + ".npy", shm_path_str + ".final") # np.save will auto add .npy def get_dataset_name(self): return self.dataset.get_dataset_name() @@ -452,7 +545,12 @@ def get_packed_dataset_without_short_length( ), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}" ds_type_id = get_dataset_type_id(DATASET_TYPE_IDS_MAP, path=fp) - ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num) + ds = JsonlDataset( + fp, + ds_type_id, + min_length=min_length_num, + pack_sample_into_one=pack_sample_into_one, + ) if hasattr(ds, "old_length"): delete_samples += ds.old_length - len(ds) @@ -501,6 +599,7 @@ class PackedDatasetWithPadForMultimodal(PackedDataset): Args: dataset: The original dataset to pack. max_length_per_sample: The maximum length of each original sample. Default is 2048. + padding_side: The padding side. Default is "right". packed_length: The length of each packed sample. Default is 4096. padding_idx: The token id of padding. Default is 0. """ @@ -511,13 +610,17 @@ def __init__( max_length_per_sample: int = 2048, packed_length: int = 4096, padding_idx: int = 0, + padding_side: str = "right", image_token_id: int = 200000, + has_image: bool = True, ): super().__init__(dataset, max_length_per_sample, packed_length) self.padding_idx = padding_idx + self.padding_side = padding_side self.sample_indices, self.belongs = self.accu_sample_len(self.seed) self.num_tokens = sum(self.lengths) self.image_token_id = image_token_id + self.has_image = has_image def get_dataset_name(self): return self.dataset.get_dataset_name() @@ -555,7 +658,10 @@ def __len__(self): def build_pack(self, index): - pack, cu_seqlens, indexes, labels, type_ids, images = [], [0], [], [], [], [] + pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], [] + + if self.has_image: + images = [] start_pos = np.searchsorted(self.belongs, index, "left") end_pos = np.searchsorted(self.belongs, index, "right") @@ -567,8 +673,9 @@ def build_pack(self, index): for sample_idx in cur_samples: sample = self.dataset[sample_idx] length = min(len(sample["tokens"]), self.max_length_per_sample) - cur_images = sample["images"] - images.extend(cur_images) + if self.has_image: + cur_images = sample["images"] + images.extend(cur_images) chunk = sample["tokens"][:length] pack.extend(chunk) cu_seqlens.append(cu_seqlens[-1] + len(chunk)) @@ -582,10 +689,16 @@ def build_pack(self, index): indexes.extend(list(range(length))) if cu_seqlens[-1] != self.packed_length: - pack = pack + [self.padding_idx] * (self.packed_length - cu_seqlens[-1]) - labels = labels + [-100] * (self.packed_length - cu_seqlens[-1]) - type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1]) - indexes.extend([0] * (self.packed_length - cu_seqlens[-1])) + if self.padding_side == "right": + pack = pack + [self.padding_idx] * (self.packed_length - cu_seqlens[-1]) + labels = labels + [-100] * (self.packed_length - cu_seqlens[-1]) + type_ids = type_ids + 
[0] * (self.packed_length - cu_seqlens[-1]) + indexes.extend([0] * (self.packed_length - cu_seqlens[-1])) + else: + pack = [self.padding_idx] * (self.packed_length - cu_seqlens[-1]) + pack + labels = [-100] * (self.packed_length - cu_seqlens[-1]) + labels + type_ids = [0] * (self.packed_length - cu_seqlens[-1]) + type_ids + indexes = [0] * (self.packed_length - cu_seqlens[-1]) + indexes cu_seqlens.append(self.packed_length) out = { diff --git a/internlm/data/tokenized/single_dataset.py b/internlm/data/tokenized/single_dataset.py index 5477d34c..2527dc0a 100644 --- a/internlm/data/tokenized/single_dataset.py +++ b/internlm/data/tokenized/single_dataset.py @@ -9,11 +9,37 @@ import mmap import os import threading +import time from pathlib import Path import numpy as np import torch +from internlm.accelerator import get_accelerator +from internlm.core.context import global_context as gpc +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) +internlm_accelerator = get_accelerator() + + +def gen_shm_meta_name_without_scalar(path: str): + """gen_shm_meta_name_without_scalar + + Args: + path (str): dataset path, like: + /llm_data/tokenized/train/cn/train-00000.bin + """ + bin_path = Path(path) + shm_prefix_path = Path(gpc.config.data.shm_path) + + # Use the entire path as the relative path part + dataset_base_path = bin_path.relative_to(bin_path.anchor) # Removes the root part (e.g., '/' or 'C:\') + + # /dev/shm/metacache/llm_data/tokenized/train/cn/train-00000.bin + shm_path_without_num_tokens = Path(shm_prefix_path, dataset_base_path) + return shm_path_without_num_tokens + class JsonlDataset(torch.utils.data.Dataset): """ @@ -29,7 +55,63 @@ class JsonlDataset(torch.utils.data.Dataset): Note that only the "tokens" key is used. """ - def __init__(self, path: str, dataset_type_id: int = 0, min_length=50): + def __init__(self, path: str, dataset_type_id: int = 0, min_length=50, pack_sample_into_one=False): + if not gpc.config.data.use_shm: + self._process_init(path, dataset_type_id, min_length) + else: + devices_per_node = internlm_accelerator.device_count() + self.local_rank = gpc.get_global_rank() % devices_per_node + shm_path_without_num_tokens = gen_shm_meta_name_without_scalar(path) + + found_cache, shm_path, num_tokens, seed = False, None, None, None + while not found_cache: + if shm_path_without_num_tokens.parent.exists(): + for file in shm_path_without_num_tokens.parent.iterdir(): + fp_str = str(file.resolve()) + if fp_str.startswith(str(shm_path_without_num_tokens.resolve())) and fp_str.endswith(".final"): + # Found cache + scalers = fp_str.split("%") + num_tokens = int(scalers[1]) + seed = int(scalers[2].split(".")[0]) + found_cache = True + shm_path = fp_str + + # for local_rank 0, no need to wait + # go forward to do computing and saving + if self.local_rank == 0: + break + + if not found_cache: + logger.warning(f"GPU {self.local_rank} loading meta: cache not found, waiting...") + time.sleep(1) + + if found_cache: + assert shm_path and num_tokens is not None and seed is not None + self.shm_handler = np.load(shm_path, mmap_mode="r+") + self.offsets = self.shm_handler[0] + self.lengths = self.shm_handler[1] + if pack_sample_into_one: + self.indices = self.shm_handler[2] + self.cum_lens = self.shm_handler[3] + else: + self.sample_indices = self.shm_handler[2] + self.len_samples_shuffled = self.shm_handler[3] + self.acm_len_samples = self.shm_handler[4] + self.num_tokens = num_tokens + self.seed = seed + self.threadlocal = threading.local() + self.path = path + 
self.resolved_path = Path(path).resolve() + self.type_id = dataset_type_id + self.old_length = len(self.offsets) + elif self.local_rank == 0: + self._process_init(path, dataset_type_id, min_length) + else: + assert False, "should not arrive here" + + self.found_cache = found_cache + + def _process_init(self, path: str, dataset_type_id: int = 0, min_length=50): self.path = path self.threadlocal = threading.local() resolved_path = Path(path).resolve() diff --git a/internlm/data/train_state.py b/internlm/data/train_state.py index 6564a4df..cd1cc8a1 100644 --- a/internlm/data/train_state.py +++ b/internlm/data/train_state.py @@ -5,7 +5,7 @@ def get_train_state(dataloader): # initialize and resume train state - if gpc.config.data.type == "tokenized": + if gpc.config.data.type in ["tokenized", "hf"]: train_state = TrainState(gpc.config, dataloader.batch_sampler) else: raise ValueError(f"dataset type {gpc.config.data.type} is not supported") diff --git a/internlm/data/utils.py b/internlm/data/utils.py index 4461c001..19e74ae2 100644 --- a/internlm/data/utils.py +++ b/internlm/data/utils.py @@ -5,7 +5,9 @@ import torch +from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm.utils import _split def get_dataset_type_ids_map(path): @@ -24,34 +26,63 @@ def get_dataset_type_id(dataset_type_ids_map, path): return match_idxes[0] -def unpack_data(input_ids, cu_seqlens, is_type_ids: bool = False, padding_v: int = 0): - """ - input_ids: if input_ids is not type_ids, the shape is (1, packed_length) - else the shape is (micro_num, packed_length) - is_type_ids: whether the input_ids is type_ids - - Return: - output: if input_ids is not type ids, the shape is (micro_bsz, max_length) - else the shape is (micro_num, micro_bsz, max_length) - """ - bsz = input_ids.shape[0] +def _unpack_data(data, cu_seqlens, padding_v: int = 0): + bsz = data.shape[0] num_seq = gpc.config.data["micro_bsz"] seq_len_ = gpc.config.data.seq_len - dtype_ = input_ids.dtype + dtype_ = data.dtype - outputs = torch.empty(bsz, num_seq, seq_len_, device=input_ids.device, dtype=dtype_).fill_(padding_v) + outputs = torch.empty(bsz, num_seq, seq_len_, device=data.device, dtype=dtype_).fill_(padding_v) for i in range(bsz): - output = torch.empty(num_seq, seq_len_, device=input_ids.device, dtype=dtype_).fill_(padding_v) + output = torch.empty(num_seq, seq_len_, device=data.device, dtype=dtype_).fill_(padding_v) cu_seqlens_slice = cu_seqlens[i] for j in range(num_seq): length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j] - output[j, 0:length] = input_ids[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]] + output[j, 0:length] = data[i, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]] outputs[i] = output - # if the input_ids is not type_ids, we need squeeze the first dimension if it is 1. - if bsz == 1 and not is_type_ids: - outputs = outputs.squeeze(0) - return outputs + + +def unpack_type_ids(type_ids, cu_seqlens): + return _unpack_data(type_ids, cu_seqlens) + + +def unpack_data(data, label): + + if gpc.config.model_type == "hf": + return data, label + + data["input_ids"] = _unpack_data(data["input_ids"], data["cu_seqlens"], padding_v=0).squeeze(0) + label = _unpack_data(label, data["cu_seqlens"], padding_v=-100).squeeze(0) + + data.pop("cu_seqlens") + data.pop("indexes") + + return data, label + + +def packed_data_normalizer(data, label): + # Should we normalize packed data in this form of this data processor + # or let the dataset handle it? 
Currently inclined towards the latter. + assert data["input_ids"].shape[0] == 1, "data should be packed with batch size 1" + + data["indexes"] = data["indexes"][0] + data["cu_seqlens"] = data["cu_seqlens"][0].squeeze(0) + data["max_seqlen"] = (data["cu_seqlens"][1:] - data["cu_seqlens"][:-1]).max().item() + + # Move to parallel package for standardization + if gpc.config.parallel.sequence_parallel and gpc.config.parallel["tensor"].get("mode", "mtp") == "isp": + data["indexes"] = _split(data["indexes"], ParallelMode.TENSOR, dim=0) + + if gpc.config.model_type == "hf": + data.pop("cu_seqlens") + data.pop("max_seqlen") + data["position_ids"] = data.pop("indexes") + data["attention_mask"] = torch.ones( + (data["input_ids"].shape), dtype=torch.bool, device=data["input_ids"].device + ) + + return data, label diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py index be42897e..7e440528 100644 --- a/internlm/initialize/initialize_trainer.py +++ b/internlm/initialize/initialize_trainer.py @@ -3,7 +3,7 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/initialize -from typing import Callable, Iterable, List, Optional, Tuple +from typing import Callable, List, Optional, Tuple from torch import nn from torch.nn.modules.loss import _Loss @@ -22,7 +22,7 @@ ) from internlm.core.scheduler.pipeline_scheduler import get_tensor_shape from internlm.core.trainer import Trainer -from internlm.data.utils import unpack_data +from internlm.data.utils import packed_data_normalizer, unpack_data from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler from internlm.utils.common import SchedulerHook, get_current_device @@ -32,8 +32,6 @@ def initialize_trainer( model: nn.Module, optimizer: Optimizer, criterion: Optional[_Loss] = None, - train_dataloader: Optional[Iterable] = None, - test_dataloader: Optional[Iterable] = None, lr_scheduler: Optional[_LRScheduler] = None, beta2_scheduler: Optional[Beta2Scheduler] = None, scheduler_hooks: Optional[List[SchedulerHook]] = None, @@ -45,14 +43,10 @@ def initialize_trainer( model (:class:`torch.nn.Module` or `Callable`): Your model instance or a function to build the model. optimizer (:class:`BaseOptimizer`): Your optimizer for training. criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance. - train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training. - test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing. lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance, optional. Returns: - Tuple (trainer, train_dataloader, test_dataloader, lr_scheduler): - A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)`` - where only ``trainer`` could not be None. 
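For context on the data-path split above: packed_data_normalizer (wired up as data_fn in initialize_trainer.py below) assumes a packed batch of size 1, flattens indexes and cu_seqlens, and derives max_seqlen from the differences of consecutive cu_seqlens entries; for model_type == "hf" it exposes the same information as position_ids plus a dense attention mask instead. A small self-contained illustration with made-up values:

import torch

# Three packed samples of lengths 3, 2 and 4 (illustrative values only).
cu_seqlens = torch.tensor([0, 3, 5, 9], dtype=torch.int32)
indexes = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])  # per-token position inside its own sample

max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
assert max_seqlen == 4

# hf branch: hand the model position_ids and an all-ones boolean attention mask
# instead of cu_seqlens/max_seqlen.
position_ids = indexes
attention_mask = torch.ones(int(cu_seqlens[-1]), dtype=torch.bool)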
+ Tuple (engine, scheduler) """ if isinstance(model, nn.Module): @@ -79,10 +73,9 @@ def initialize_trainer( # initialize scheduler for trainer scheduler = None - if gpc.config.data.use_packed_dataset: - data_fn = None - else: - data_fn = unpack_data + + data_fn = packed_data_normalizer if gpc.config.data.use_packed_dataset else unpack_data + if gpc.is_using_parallel_mode(ParallelMode.PIPELINE): gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num tensor_shape = get_tensor_shape() @@ -132,6 +125,4 @@ def initialize_trainer( clip_grad_norm=clip_grad_norm, ) - trainer = Trainer(engine, scheduler) - - return trainer, train_dataloader, test_dataloader, lr_scheduler + return engine, scheduler diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 0ea000e9..a07c3e76 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -2,7 +2,6 @@ # -*- encoding: utf-8 -*- import argparse -import gc import os from pathlib import Path from typing import Dict, Union @@ -13,10 +12,6 @@ from internlm.core.context import Config from internlm.core.context import global_context as gpc from internlm.core.context.process_group_initializer import ParallelMode -from internlm.model.moe.megablock.utils import ( - check_megablock_installed, - check_stk_installed, -) from internlm.utils.common import get_master_node from internlm.utils.gputest import warmup_process_group from internlm.utils.logger import get_logger @@ -89,7 +84,7 @@ def args_sanity_check(): gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False)) if "tensor" not in gpc.config.parallel: - gpc.config.parallel._add_item("tensor", 1) + gpc.config.parallel._add_item("tensor", dict(size=1, mode="mtp")) if "weight" not in gpc.config.parallel: gpc.config.parallel._add_item("weight", dict(size=1, overlap=False, memory_pool=False)) @@ -165,6 +160,14 @@ def args_sanity_check(): data.diag_outlier_ratio = max(1, data.diag_outlier_ratio) + if "use_shm" not in data: + data._add_item("use_shm", False) + elif data.use_shm and "shm_path" not in data: + data._add_item("shm_path", "/dev/shm/metacache") + + if data.train_folder is None: + data.use_shm = False + if "use_packed_dataset" not in data: data._add_item("use_packed_dataset", True) @@ -339,16 +342,21 @@ def args_sanity_check(): model._add_item("moe_use_residual", False) if "moe_type" not in model: model._add_item("moe_type", "GShard") - # check dependency - if gpc.config.model.moe_type == "MegaBlock": - check_megablock_installed() - if gpc.config.model.moe_type == "MegaBlock-D": - check_megablock_installed() - check_stk_installed() if "mlp_layer_fusion" not in model: model._add_item("mlp_layer_fusion", False) + # qk_interleaved config + if "qk_interleaved" not in gpc.config.model: + if "adapt_hf" in gpc.config.model: + model._add_item("qk_interleaved", not gpc.config.model.adapt_hf) + else: + model._add_item("qk_interleaved", False) + elif "adapt_hf" in gpc.config.model: + assert gpc.config.model.adapt_hf == ( + not gpc.config.model.qk_interleaved + ), "adapt_hf and qk_interleaved must be opposite" + # process the parallel config if "sequence_parallel" not in gpc.config.parallel: gpc.config.parallel._add_item("sequence_parallel", False) @@ -436,6 +444,11 @@ def args_sanity_check(): optim_ckpt._add_item("overlap_sync_grad", False) if "overlap_sync_param" not in optim_ckpt: optim_ckpt._add_item("overlap_sync_param", False) + if "use_split_tensor_optim" not in optim_ckpt: + optim_ckpt._add_item("use_split_tensor_optim", False) + elif 
optim_ckpt.use_split_tensor_optim and "all_gather_size" not in optim_ckpt: + optim_ckpt._add_item("all_gather_size", 512 * 1024 * 1024) + if gpc.is_rank_for_log(): logger.info( f"overlap_sync_grad:{optim_ckpt.overlap_sync_grad}, overlap_sync_param:{optim_ckpt.overlap_sync_param}" @@ -619,9 +632,6 @@ def initialize_distributed_env( """ backend = internlm_accelerator._communication_backend_name - # close automatic garbage collection - gc.disable() - if launcher == "torch": launch_from_torch(config=config, seed=seed, backend=backend) elif launcher == "slurm": diff --git a/internlm/model/__init__.py b/internlm/model/__init__.py index 26ac3e7c..e69de29b 100644 --- a/internlm/model/__init__.py +++ b/internlm/model/__init__.py @@ -1,33 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from .metrics import AccPerplex -from .modeling_internlm import build_model_with_cfg -from .modeling_internlm2 import build_model_with_cfg as build_model_with_cfg2 -from .modeling_llama import build_model_with_cfg as build_model_with_llama_cfg -from .modeling_llava import build_model_with_cfg as build_model_with_llava_cfg -from .modeling_moe import build_model_with_moe_cfg -from .modules.embedding import Embedding1D, RotaryEmbedding -from .modules.mlp import FeedForward -from .modules.multi_head_attention import MHA, DistributedAttention -from .moe.moe import MoE -from .ops.linear import RewardModelLinear, ScaleColumnParallelLinear -from .utils import gather_forward_split_backward - -__all__ = [ - "Embedding1D", - "FeedForward", - "MoE", - "RotaryEmbedding", - "RewardModelLinear", - "ScaleColumnParallelLinear", - "AccPerplex", - "MHA", - "DistributedAttention", - "gather_forward_split_backward", - "build_model_with_cfg", - "build_model_with_cfg2", - "build_model_with_moe_cfg", - "build_model_with_llama_cfg", - "build_model_with_llava_cfg", -] diff --git a/internlm/model/builder.py b/internlm/model/builder.py new file mode 100644 index 00000000..c8adcd41 --- /dev/null +++ b/internlm/model/builder.py @@ -0,0 +1,41 @@ +from typing import List, Union + +from torch import nn + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.parallel.shard import pipeline_parallel_sharding_wrapper +from internlm.model.registry import hf_config_initializer, model_initializer +from internlm.utils.common import get_current_device + + +def create_model(model_type, *args, **kwargs) -> Union[nn.Module, List[nn.Module]]: + num_layers = kwargs.pop("num_layers") + num_chunks = kwargs.pop("num_chunks", 1) + + # TODO: fix use_flash_attn parameter config + kwargs.pop("use_flash_attn", False) + kwargs.pop("apply_post_layer_norm") + kwargs.pop("embed_split_hidden", True) + + kwargs["checkpoint"] = float(kwargs.get("checkpoint", False)) + kwargs["device"] = get_current_device() + + model_buidler = model_initializer.get_module(module_name=model_type) + + if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE): + if model_type == "hf": + hf_config_builder = hf_config_initializer.get_module(module_name=model_type) + config = hf_config_builder(return_dict=False) + model = model_buidler(*args, config).to(kwargs["device"]) + else: + kwargs["first"] = kwargs["last"] = True + kwargs["start_layer_idx"] = 0 + kwargs["num_layers"] = num_layers + model = model_buidler(*args, **kwargs).to(kwargs["device"]) + setattr(model, "first_layer", 0) + setattr(model, "last_layer", num_layers) + else: + model = pipeline_parallel_sharding_wrapper(num_layers, num_chunks, model_buidler, 
*args, **kwargs) + + return model diff --git a/internlm/model/llava/__init__.py b/internlm/model/llava/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/internlm/model/llava_modules/clip_builder.py b/internlm/model/llava/clip_builder.py similarity index 100% rename from internlm/model/llava_modules/clip_builder.py rename to internlm/model/llava/clip_builder.py diff --git a/internlm/model/llava_modules/clip_encoder.py b/internlm/model/llava/clip_encoder.py similarity index 100% rename from internlm/model/llava_modules/clip_encoder.py rename to internlm/model/llava/clip_encoder.py diff --git a/internlm/model/llava_modules/projector_builder.py b/internlm/model/llava/projector_builder.py similarity index 100% rename from internlm/model/llava_modules/projector_builder.py rename to internlm/model/llava/projector_builder.py diff --git a/internlm/model/losses/ce_loss.py b/internlm/model/losses/ce_loss.py index 3fe4858b..69e09d2f 100644 --- a/internlm/model/losses/ce_loss.py +++ b/internlm/model/losses/ce_loss.py @@ -3,9 +3,11 @@ from torch import nn -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.ops.fusion_ops_import_helper import internlm_init_CrossEntropyLoss +from internlm.model.ops.cross_entropy import new_cross_entropy +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) class FlashGPTLMLoss(nn.Module): @@ -24,12 +26,11 @@ def __init__(self, parallel_output=True, label_smoothing=0): label_smoothing = 0 self.label_smoothing = label_smoothing - self.loss_fn = internlm_init_CrossEntropyLoss( - parallel_output=parallel_output, + self.loss_fn = new_cross_entropy( reduction="mean", - inplace_backward=True, - process_group=gpc.get_group(ParallelMode.TENSOR), label_smoothing=self.label_smoothing, + parallel_output=parallel_output, + inplace_backward=True, ) def forward(self, *args): diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 6db8044a..54cc41ba 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -3,17 +3,21 @@ import torch from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.ops.fusion_ops_import_helper import ( - internlm_init_CrossEntropyLoss, - try_import_scatter_sum, -) +from internlm.model.ops.cross_entropy import new_cross_entropy from internlm.utils.common import SchedulerHook, get_current_device +from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer +try: + from torch_scatter import scatter as cuda_scatter + + cuda_scatter_impl = True +except (ModuleNotFoundError, ImportError): + cuda_scatter_impl = False + +logger = get_logger(__file__) internlm_accelerator = get_accelerator() -scatter_sum = try_import_scatter_sum() def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): @@ -51,6 +55,24 @@ def vanilla_scatter( return out.scatter_add_(dim, index, src) +# move to ops when there are more than one files use it. +def _get_scatter_sum_impl(): + if cuda_scatter_impl and internlm_accelerator.get_accelerator_backend() in ( + AcceleratorType.GPU, + AcceleratorType.DIPU, + ): + if gpc.is_rank_for_log(): + logger.warning("Use cuda_scatter. Please note this!") + return cuda_scatter + else: + if gpc.is_rank_for_log(): + logger.warning("Use vanilla_scatter rather than cuda_scatter. 
Please note this!") + return vanilla_scatter + + +scatter_sum_impl = _get_scatter_sum_impl() + + class AccPerplex: """ AccPerplex module for calculating model's accuracy and perplexity metrics. @@ -88,7 +110,7 @@ def __init__(self, device, tp_pg, dp_pg, tokenizer=None, dataset_types: List[str self.ds_tokens = torch.zeros(self.total_type_count, dtype=torch.long, device=device) self.loss_with_type_id = LossWithTypeId(device, dp_pg, dataset_types) - self.scatter_sum = scatter_sum if scatter_sum else vanilla_scatter + self.scatter_sum = scatter_sum_impl def set_current_type_ids(self, type_ids: torch.Tensor): self.batch_shift = 0 @@ -257,13 +279,12 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None: self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device) self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device) - self.loss_fn = internlm_init_CrossEntropyLoss( - parallel_output=gpc.config.model.parallel_output, + self.loss_fn = new_cross_entropy( reduction="none", + parallel_output=gpc.config.model.parallel_output, inplace_backward=True, - process_group=gpc.get_group(ParallelMode.TENSOR), ) - self.scatter_sum = scatter_sum if scatter_sum else vanilla_scatter + self.scatter_sum = scatter_sum_impl def update(self, logits, labels, type_ids=None): with torch.no_grad(): diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index ef5f7e9f..5994e15d 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -12,27 +12,23 @@ from internlm.core.naive_amp import set_output_attr_to_module from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.mlp import get_mlp_cls -from internlm.model.modules.multi_head_attention import MHA -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm -from internlm.model.ops.linear import RewardModelLinear, ScaleColumnParallelLinear +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import MHA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm from internlm.model.utils import ( - gather_forward_split_backward, - split_forward_gather_backward, + convert_attn_args_to_kwargs, + convert_attn_kwargs_to_args, + internlm1_mha_pre_load_convert, + internlm1_mha_save_convert, ) from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.solver.pipeline_utils import partition_uniform -from internlm.utils.common import filter_kwargs, get_current_device from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER - -MODEL_TYPE = "INTERNLM" logger = get_logger(__file__) -RMSNorm = try_import_RMSNorm() -class PackedFlashBaseLayer1D(nn.Module): +class InternLM1Decoder(nn.Module): """ 1D Packed Flash Base Layer. @@ -42,15 +38,22 @@ class PackedFlashBaseLayer1D(nn.Module): mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0 by default. drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + max_position_embeddings (int): The maximum position embeddings. 2048 by default. dtype (torch.dtype): Type of data. torch.float by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. 
checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. layer_idx (int): The index of current layer. 0 by default. + use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. device (Optional[Union[str, torch.device]]): The device will be used. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - use_flash_attn (bool): Whether use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. """ def __init__( @@ -69,11 +72,10 @@ def __init__( residual_in_fp32: bool = False, device: Optional[torch.device] = None, norm_type: str = "rmsnorm", + qk_interleaved: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, - tp_mode: str = "mtp", rope_base: int = 10000, mlp_layer_fusion: bool = False, multiple_of: int = 256, @@ -83,18 +85,14 @@ def __init__( # dropout selective checkpoint can only be enabled when checkpoint is disabled. self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx - self.use_flash_attn = use_flash_attn head_dim = hidden_size // num_attention_heads - self.tp_mode = tp_mode - parallel_mode = ParallelMode.WEIGHT if self.tp_mode == "isp" else ParallelMode.TENSOR self.mixer = MHA( embed_dim=hidden_size, num_heads=num_attention_heads, - process_group=gpc.get_group(parallel_mode), - sequence_process_group=gpc.get_group(ParallelMode.TENSOR), dropout=attn_drop_rate, + bias=True, max_position_embeddings=max_position_embeddings, softmax_scale=1 / math.sqrt(head_dim), causal=True, @@ -102,37 +100,35 @@ def __init__( use_dynamic_ntk_rope=use_dynamic_ntk_rope, rotary_emb_dim=head_dim, rotary_emb_scale_base=0, - use_flash_attn=use_flash_attn, rope_base=rope_base, device=device, dtype=dtype, - tp_mode=self.tp_mode, + qk_interleaved=qk_interleaved, + enable_qkv_fusion=True, ) - self.dropout1 = nn.Dropout(drop_rate) - if norm_type == "rmsnorm": - self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon) - self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - - if use_swiglu or not use_flash_attn: - mlp_cls = get_mlp_cls(self.tp_mode) - self.mlp = mlp_cls( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - process_group=gpc.get_group(parallel_mode), - bias=False, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - sequence_parallel=gpc.config.parallel.sequence_parallel, - multiple_of=multiple_of, - ) + # Compatible with the name of internlm1 Wqkv linear layer + self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert, internlm1_mha_save_convert) + self.dropout1 = nn.Dropout(drop_rate) self.dropout2 = 
nn.Dropout(drop_rate) + + self.norm1 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.norm2 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + + self.mlp = new_feed_forward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + bias=False, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + # TODO: to support more activation functions + activation_type="swiglu" if use_swiglu else "swiglu", + ) + self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm @@ -144,7 +140,7 @@ def reset_parameters(self): for name, param in self.mixer.named_parameters(): if param.ndim == 1: param.data.zero_() - elif "Wqkv" in name: + elif "wqkv" in name: normal_(std=0.006)(param.data) elif self.use_scaled_init: scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data) @@ -166,15 +162,15 @@ def reset_parameters(self): else: normal_(std=0.006 if "fc1" in name else 0.0015)(param.data) - def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None): + def forward(self, hidden_states, **kwargs): if self.checkpoint and self.training: - return activation_checkpoint( - self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen - ) + # NOTICE: activation_checkpiont do not support kwargs when use_reentrant = True. + args = convert_attn_kwargs_to_args(kwargs) + return activation_checkpoint(self._forward, False, hidden_states, *args) else: - return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen) + return self._forward(hidden_states, **kwargs) - def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None): + def _forward(self, hidden_states, *args, **kwargs): r"""Pass the input through the encoder layer. Args: @@ -183,12 +179,6 @@ def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_ cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } def _dropout_and_norm_attn(_hidden_states): _dropped = self.dropout1(_hidden_states) @@ -204,6 +194,7 @@ def _dropout_and_norm_attn(_hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) + mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) hidden_states = self.mixer(hidden_states, **mixer_kwargs) def _dropout_and_norm_ffn(_residual, _hidden_states): @@ -225,7 +216,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states + residual -class PackedFlashInternLm1D(nn.Module): +class InternLM1(nn.Module): """ 1D Packed Flash InternLm. @@ -237,23 +228,27 @@ class PackedFlashInternLm1D(nn.Module): mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + max_position_embeddings (int): The maximum position embeddings. 2048 by default. dtype (torch.dtype): The type of data. torch.float by default. checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number of layers. 0.0 by default. 
- layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. device (Optional[Union[str, torch.device]]): The device will be used. None by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. """ def __init__( @@ -271,7 +266,6 @@ def __init__( layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, - embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, @@ -279,11 +273,11 @@ def __init__( device: Optional[torch.device] = None, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", + qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, rope_base: int = 10000, mlp_layer_fusion: bool = False, multiple_of: int = 256, @@ -291,25 +285,17 @@ def __init__( super().__init__() checkpoint_layer_num = int(num_layers * checkpoint) - self.tp_mode = "mtp" - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output - if is_reward: - head_cls = RewardModelLinear - else: - head_cls = ScaleColumnParallelLinear if first: - self.embedding = Embedding1D( - num_embeddings=vocab_size, embedding_dim=hidden_size, embed_split_hidden=embed_split_hidden - ) + self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) for _, param in self.embedding.named_parameters(): normal_(std=0.0052)(param) - self.embed_grad_scale = embed_grad_scale self.blocks = nn.ModuleList( [ - PackedFlashBaseLayer1D( + InternLM1Decoder( hidden_size=hidden_size, num_attention_heads=num_attention_heads, mlp_ratio=mlp_ratio, @@ -327,36 +313,32 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - 
tp_mode=self.tp_mode, rope_base=rope_base, + qk_interleaved=qk_interleaved, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, ) for lid in range(num_layers) ] ) + if last: - if norm_type == "rmsnorm": - self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.head = head_cls( + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.head = new_linear( + name="head", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, + is_reward=is_reward, weight_scale=embed_grad_scale, ) set_output_attr_to_module(self.head) for _, param in self.head.named_parameters(): normal_(std=0.0052)(param) - self.parallel_output = parallel_output - - def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + def forward(self, hidden_states=None, input_ids=None, **kwargs): # attention_mask: compute attention on the places where the value is 1 if hasattr(self, "embedding") and input_ids is not None: hidden_states = self.embedding(input_ids) @@ -365,172 +347,12 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) - if isinstance(cu_seqlens, list): - assert len(cu_seqlens) == 1 - cu_seqlens = cu_seqlens[0].to(hidden_states.device) - - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.squeeze(0) - - if indexes is not None: - assert len(indexes) == 1 - # The indexes are used to indicate the actual position IDs of each token in the packed input. - indexes = indexes[0] - # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension. - if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp": - indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None - for _, block in enumerate(self.blocks): - hidden_states = block( - hidden_states, - cu_seqlens=cu_seqlens, - indexes=indexes, - inference_params=inference_params, - max_seqlen=max_seqlen, - ) + hidden_states = block(hidden_states, **kwargs) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "head"): - hidden_states = self.head(hidden_states, gather_dim=1, tp_mode=self.tp_mode) + hidden_states = self.head(hidden_states) - if not self.parallel_output and gpc.is_pipeline_last_stage(): - hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) return hidden_states - - -def _build_generic_model_1d(num_layers, num_chunks, **kwargs): - """ - build generic model 1d - - Args: - num_layers (int): The number of layer. - num_chunks (int): The number of partitions in pipeline parallel. - device (Optional[Union[str, torch.device]]): The device will be used. internlm_accelerator.device() by default. 
- - """ - device = get_current_device() - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - logger.info(f"The layer sharding is {all_parts}.") - - models = [] - start_idx, end_idx = 0, 0 - for start, end in parts: - start_idx, end_idx = start, end - kwargs["num_layers"] = end - start - kwargs["first"] = start == 0 - # If there is no content in the final layer, assign the last layer. - kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 - kwargs["device"] = device - kwargs["start_layer_idx"] = start - chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device) - - models.append(chunk) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - setattr(model, "first_layer", start_idx) - setattr(model, "last_layer", end_idx) - return model - - -@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) -def build_model_with_cfg( - num_chunks=1, - checkpoint=0.0, - dtype=torch.float, - embed_split_hidden=False, - num_layers=48, - hidden_size=2048, - vocab_size=50304, - embed_grad_scale=1, - parallel_output=True, - num_attention_heads=32, - max_position_embeddings=2048, - mlp_ratio=4.0, - residual_in_fp32=False, - use_dynamic_ntk_rope=False, - norm_type="rmsnorm", - drop_rate=0, - attn_drop_rate=0, - apply_post_layer_norm=False, # pylint: disable=W0613 - layer_norm_epsilon=1e-5, - is_reward=False, - dropout_selective_checkpoint=True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - use_flash_attn: bool = True, - rope_base: int = 10000, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, -): - """ - Build model with config. - - Args: - num_chunks (int): The number of partitions in pipeline parallel. 1 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. - dtype (torch.dtype): The type of data. torch.float by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - False by default. - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - num_attention_heads (int): The number of attention head. 32 by default. - mlp_ratio (int): The ratio of MLP layers. 4.0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily - because this parameter requires inconsistent data types to be passed between pipelines, - which requires significant modifications to internlm. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - drop_rate (float): The dropout rate of input hidden state. 0 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - is_reward (bool): Whether to use reward model. False by default. 
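Side note on the removed _build_generic_model_1d above: per-rank layer ranges came from partition_uniform(num_layers, pipeline_size, num_chunks), a job that now belongs to pipeline_parallel_sharding_wrapper (see internlm/model/builder.py earlier in this diff). The exact balancing algorithm is not shown here, so the sketch below is only an illustrative even split under that assumption:

def naive_uniform_partition(num_layers, pipeline_size, num_chunks=1):
    """Illustrative stand-in for partition_uniform: split num_layers into
    pipeline_size * num_chunks contiguous [start, end) ranges, as evenly as possible."""
    num_parts = pipeline_size * num_chunks
    base, rem = divmod(num_layers, num_parts)
    parts, start = [], 0
    for i in range(num_parts):
        end = start + base + (1 if i < rem else 0)
        parts.append((start, end))
        start = end
    # Group the chunks per pipeline rank: rank r owns parts[r], parts[r + pipeline_size], ...
    return [[parts[r + c * pipeline_size] for c in range(num_chunks)] for r in range(pipeline_size)]

# 32 layers over 4 pipeline ranks, one chunk each -> 8 layers per rank.
assert naive_uniform_partition(32, 4) == [[(0, 8)], [(8, 16)], [(16, 24)], [(24, 32)]]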
- dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. - use_scaled_init (bool): Whether to use scaled init. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - - """ - - cfg = dict( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden, - vocab_size=vocab_size, - embed_grad_scale=embed_grad_scale, - parallel_output=parallel_output, - mlp_ratio=mlp_ratio, - residual_in_fp32=residual_in_fp32, - max_position_embeddings=max_position_embeddings, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - norm_type=norm_type, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - layer_norm_epsilon=layer_norm_epsilon, - is_reward=is_reward, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - rope_base=rope_base, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - ) - - return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 08065ddc..c3b89412 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -3,11 +3,8 @@ from typing import Optional import torch -import torch.nn.functional as F -from einops import rearrange from torch import nn -from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.initialize.initialize_tensor import ( @@ -16,467 +13,29 @@ scaled_init_method_uniform, uniform_, ) -from internlm.model.modules.embedding import ( - DynamicNTKScalingRotaryEmbedding, - Embedding1D, - RotaryEmbedding, -) -from internlm.model.modules.mlp import get_mlp_cls -from internlm.model.modules.multi_head_attention import ( - _update_kv_cache, - get_gqa_attn_cls, -) -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm -from internlm.model.ops.linear import ( - RewardModelLinear, - ScaleColumnParallelLinearWithNormHead, - get_linear_cls, -) +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import GQA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm from internlm.model.utils import ( - gather_forward_split_backward, - pack_output_after_attn, - split_forward_gather_backward, - unpack_qkv_before_attn, + convert_attn_args_to_kwargs, + convert_attn_kwargs_to_args, ) from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.solver.pipeline_utils import partition_uniform -from internlm.utils.common import filter_kwargs, get_current_device from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER - -MODEL_TYPE = "INTERNLM2_PUBLIC" logger = get_logger(__file__) -RMSNorm = try_import_RMSNorm() -internlm_accelerator = get_accelerator() - - -class MHA(nn.Module): - """ - Multi-head self-attention and cross-attention. - - Args: - embed_dim (int): The dimention of hidden state. - num_heads (int): The number of attention heads. 
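On the builder refactor visible in this file and in the deleted internlm/model/__init__.py: the per-file @MODEL_INITIALIZER.register_module(...) + build_model_with_cfg pattern is gone, and create_model in internlm/model/builder.py now resolves builders through model_initializer.get_module(module_name=model_type). A toy sketch of that registry dispatch; the class and names below are hypothetical stand-ins, not the project's actual registry implementation:

from typing import Callable, Dict


class SimpleRegistry:
    """Toy stand-in for the registry behind model_initializer / hf_config_initializer."""

    def __init__(self):
        self._modules: Dict[str, Callable] = {}

    def register_module(self, module_name: str):
        def decorator(builder: Callable) -> Callable:
            self._modules[module_name] = builder
            return builder

        return decorator

    def get_module(self, module_name: str) -> Callable:
        return self._modules[module_name]


model_initializer = SimpleRegistry()


@model_initializer.register_module("INTERNLM")
def build_internlm1(**kwargs):
    return f"InternLM1({sorted(kwargs)})"


# create_model-style dispatch: the model_type string selects the registered builder.
builder = model_initializer.get_module("INTERNLM")
print(builder(num_layers=32, hidden_size=4096))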
- num_kv_heads (int): The number of attention heads for key and value. - process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. - sequence_process_group (torch.distributed.ProcessGroup): The process group for attention calculation. - bias (bool): Whether the bias is needed for linears. Will be used when initializing QKV matrix and - output projection. False by default. - dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. - softmax_scale (float): The temperature to use for the softmax attention. - causal (boolean): Whether to apply causal attention mask. False by default. - layer_idx (int): The index of current layer. None by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. - rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements - XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. - use_flash_attn (bool): Whether to use flash attention or not.If False, vanilla attention module will be used. - False by default. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - rot_embed_HF_impl (Optional[bool]): Whether to use the rotary embedding implementation from HuggingFace. - True by default. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - num_kv_heads: int, - process_group: Optional[torch.distributed.ProcessGroup], - sequence_process_group: Optional[torch.distributed.ProcessGroup], - max_position_embeddings: int = 2048, - bias: bool = False, - dropout: float = 0.0, - softmax_scale: float = None, - causal: bool = False, - layer_idx: int = None, - use_dynamic_ntk_rope: bool = False, - use_flash_attn: bool = True, - rope_base: int = 10000, - rotary_emb_dim: int = 0, - rotary_emb_scale_base: int = 0, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - rot_embed_HF_impl: Optional[bool] = True, - tp_mode: str = "mtp", - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" - - self.head_dim = self.embed_dim // num_heads - self.num_kv_heads = num_kv_heads - self.kv_dim = self.head_dim * num_kv_heads - self.causal = causal - self.layer_idx = layer_idx - self.rotary_emb_dim = rotary_emb_dim - self.use_flash_attn = use_flash_attn - self.dtype = dtype - - self.q_per_kv = num_heads // num_kv_heads - - self.rot_embed_HF_impl = rot_embed_HF_impl - sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) - - self.max_position_embeddings = max_position_embeddings - self.use_dynamic_ntk_rope = use_dynamic_ntk_rope - self.tp_mode = tp_mode - - if self.rotary_emb_dim > 0: - if self.use_dynamic_ntk_rope: - self.rotary_emb = DynamicNTKScalingRotaryEmbedding( - self.rotary_emb_dim, - base=rope_base, - scale_base=rotary_emb_scale_base, - device=device, - max_position_embeddings=max_position_embeddings, - scaling_factor=1.0, # Currently do not support dynamic scaling. 
- ) - else: - self.rotary_emb = RotaryEmbedding( - self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device - ) - - Wqkv_cls = get_linear_cls(self.tp_mode, "column") - self.wqkv = Wqkv_cls( - embed_dim, - embed_dim + 2 * self.kv_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - - self.inner_attn, self.inner_cross_attn = get_gqa_attn_cls( - use_flash_attn, self.tp_mode, causal, softmax_scale, dropout, sequence_process_group - ) - self.inner_cross_attn_causal = causal - self.inner_cross_attn_softmax_scale = softmax_scale - self.inner_cross_attn_dropout = dropout - - wo_cls = get_linear_cls(self.tp_mode, "row") - self.wo = wo_cls( - embed_dim, - embed_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - - def forward(self, x, seqlen=None, inference_params=None, **kwargs): - if kwargs.get("indexes", None) is not None: - return self._packed_forward(x=x, inference_params=inference_params, **kwargs) - else: - return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) - - def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - """ - bsz, _, _ = x.shape - qkv = self.wqkv(x) - - if seqlen is None: - qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim) - else: - qkv = rearrange(qkv, "(b s) (h gs d) -> b s h gs d", s=seqlen, gs=self.q_per_kv + 2, d=self.head_dim) - - q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]) - - q = rearrange(q, "b s h gs d -> b s (h gs) d") - - if not self.rot_embed_HF_impl: - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - - if inference_params is None: - if self.rotary_emb_dim > 0: - q = self.rotary_emb._single_eval_forward(q) - k = self.rotary_emb._single_eval_forward(k) - kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) - if self.dtype is torch.float32 and self.use_flash_attn: - if q.dtype not in [torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_cross_attn(q=q, kv=kv).to(self.dtype) - else: - context = self.inner_cross_attn(q=q, kv=kv) - - else: - assert self.rotary_emb_dim > 0 - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - empties = inference_params.attention_mask[..., -1].sum(dim=-1) - moved_q = q.clone() - moved_k = k.clone() - if inference_params.sequence_len_offset == 0: - for i in range(len(empties)): - if empties[i] != 0: - moved_q[i][: -empties[i]] = q[i][empties[i] :] - moved_k[i][: -empties[i]] = k[i][empties[i] :] - moved_q = self.rotary_emb._single_eval_forward( - moved_q, seqlen_offset=inference_params.sequence_len_offset - ) - moved_k = self.rotary_emb._single_eval_forward( - moved_k, seqlen_offset=inference_params.sequence_len_offset - ) - for i in range(len(empties)): - if empties[i] != 0: - q[i][empties[i] :] = moved_q[i][: -empties[i]] - k[i][empties[i] :] = moved_k[i][: -empties[i]] - else: - q[i] = moved_q[i] - k[i] = 
moved_k[i] - else: - q = self.rotary_emb._single_forward( - q, - inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - - empties, - ) - k = self.rotary_emb._single_forward( - k, - inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - - empties, - ) - else: - raise NotImplementedError( - "You should make sure you are aware that you are changing the method of generating." - "According to your generation function instead of inference/seq_generator_module.py, " - "You may implement here for normal running." - ) - - kv = torch.stack([k, v], dim=2) - - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - if hasattr(inference_params, "window_size") and inference_params.window_size is not None: - if inference_params.window_size <= inference_params.sequence_len_offset: - assert kv.size(1) == 1, "update kv lenth more than 1" - inference_params.key_value_memory_dict[self.layer_idx][ - :, inference_params.keep_first : inference_params.window_size - 1, ... - ] = inference_params.key_value_memory_dict[self.layer_idx][ - :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ... - ].clone() - inference_params.real_sequence_len_offset = inference_params.sequence_len_offset - inference_params.sequence_len_offset = inference_params.window_size - 1 - - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - - inference_params.sequence_len_offset = inference_params.real_sequence_len_offset - else: - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - else: - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - - # When using FP16, there is a high probability of NAN in the KV. - # Since NAN cannot be removed by multiplying with and 0, it needs - # to be removed manually here. - kv = torch.where(torch.isnan(kv), 0, kv) - - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - from flash_attn import flash_attn_varlen_kvpacked_func - - if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) - attn_mask = inference_params.attention_mask[:, None, ...] 
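The removed MHA here (superseded by the shared GQA module from internlm.model.modules.mha) packs the fused wqkv output so that each key/value head carries q_per_kv query heads plus one key and one value head, via rearrange("b s (h gs d) -> b s h gs d", gs=q_per_kv + 2). A small shape check of that layout with plain reshapes; whether the new GQA module keeps exactly this fused layout is not shown in this hunk:

import torch

bsz, seqlen = 2, 5
num_heads, num_kv_heads, head_dim = 8, 2, 4
q_per_kv = num_heads // num_kv_heads  # 4 query heads share each kv head

# wqkv output width = embed_dim + 2 * kv_dim = (num_heads + 2 * num_kv_heads) * head_dim
qkv = torch.randn(bsz, seqlen, (num_heads + 2 * num_kv_heads) * head_dim)
qkv = qkv.view(bsz, seqlen, num_kv_heads, q_per_kv + 2, head_dim)  # "b s (h gs d) -> b s h gs d"

q = qkv[..., :q_per_kv, :].reshape(bsz, seqlen, num_heads, head_dim)  # "b s h gs d -> b s (h gs) d"
k = qkv[..., -2, :]  # (bsz, seqlen, num_kv_heads, head_dim)
v = qkv[..., -1, :]

assert q.shape == (2, 5, 8, 4) and k.shape == v.shape == (2, 5, 2, 4)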
- attn_mask = torch.logical_or( - torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask - ) - attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) - cu_seqlens = torch.concat( - [ - torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), - attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), - ], - dim=0, - ) - cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) - max_seqlen_q = attn_mask4flsh.shape[-1] - max_seqlen_k = attn_mask4flsh.shape[-1] - total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) - total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( - -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] - ) - if self.dtype is torch.float32: - if total_q.dtype not in [torch.float16, torch.bfloat16]: - total_q = total_q.to(torch.bfloat16) - if total_kv.dtype not in [torch.float16, torch.bfloat16]: - total_kv = total_kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - output = flash_attn_varlen_kvpacked_func( - q=total_q, - kv=total_kv, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=0.0, - causal=True, - ).to(self.dtype) - else: - output = flash_attn_varlen_kvpacked_func( - q=total_q, - kv=total_kv, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - dropout_p=0.0, - causal=True, - ) - - context = torch.zeros_like(q) - context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) - - else: - attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) - if hasattr(inference_params, "window_size") and inference_params.window_size is not None: - if inference_params.window_size <= inference_params.sequence_len_offset: - attn_mask = torch.concat( - [ - attn_mask[..., : inference_params.keep_first], - attn_mask[..., -(inference_params.window_size - inference_params.keep_first) :], - ], - dim=-1, - ) - - k, v = torch.chunk(kv, 2, dim=2) - k = k.squeeze(2) - v = v.squeeze(2) - sp = k.shape - expansion = q.size(2) // k.size(2) - scores = torch.einsum( - "blhd,bnhd->bhln", - q, - k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), - ) / math.sqrt(q.size(-1)) - scores = scores.masked_fill(attn_mask, -65000.0) - scores = F.softmax(scores, dim=-1) # bsz x h x L x L - context = torch.einsum( - "bhmn,bnhd->bmhd", - scores, - v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), - ) - else: - if self.dtype is torch.float32 and self.use_flash_attn: - if q.dtype not in [torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_cross_attn(q=q, kv=kv, causal=True).to(self.dtype) - else: - context = self.inner_cross_attn(q=q, kv=kv, causal=True) - - if seqlen is None: - context = rearrange(context, "b s h d -> b s (h d)") - else: - context = rearrange(context, "b s h d -> (b s) (h d)") - - out = self.wo(context) - return out - - def _packed_forward(self, x, inference_params=None, **kwargs): - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). 
This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - """ - assert self.use_flash_attn is True - - qkv = self.wqkv(x) - - qkv = rearrange(qkv, "b t (h gs d) -> b t h gs d", gs=self.q_per_kv + 2, d=self.head_dim) - - q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]) - - q = rearrange(q, "b t h gs d -> b t (h gs) d") - - # qkv shift - # the rotary embedding in flash attention module in performed by separating the front and back parts, while - # most of others are done by odd-even methods. - if not self.rot_embed_HF_impl: - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - indexes = kwargs.pop("indexes") - q = self.rotary_emb._single_forward(q, indexes=indexes) - k = self.rotary_emb._single_forward(k, indexes=indexes) - - if inference_params is None: - kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) - # for packed data, batch dimension with a size of 1 should be directly squeezed off. - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - q = q.squeeze(0) - kv = kv.squeeze(0) - # since torch_npu only supports fa with no packed data currently, qkv should be unpacked - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - q = unpack_qkv_before_attn(q, kwargs["cu_seqlens"]) - kv = unpack_qkv_before_attn(kv, kwargs["cu_seqlens"]) - - if self.dtype is torch.float32: - if q.dtype not in [torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_attn( - q=q, - kv=kv, - cu_seqlens_q=kwargs["cu_seqlens"], - cu_seqlens_k=kwargs["cu_seqlens"], - max_seqlen_q=kwargs["max_seqlen"], - max_seqlen_k=kwargs["max_seqlen"], - dropout_p=self.inner_cross_attn_dropout, - softmax_scale=self.inner_cross_attn_softmax_scale, - causal=self.inner_cross_attn_causal, - ).to(self.dtype) - else: - context = self.inner_attn( - q=q, - kv=kv, - cu_seqlens_q=kwargs["cu_seqlens"], - cu_seqlens_k=kwargs["cu_seqlens"], - max_seqlen_q=kwargs["max_seqlen"], - max_seqlen_k=kwargs["max_seqlen"], - dropout_p=self.inner_cross_attn_dropout, - softmax_scale=self.inner_cross_attn_softmax_scale, - causal=self.inner_cross_attn_causal, - ) - else: - raise RuntimeError("Not support this right now") - - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - context = rearrange(context, "s h d -> s (h d)") # recover the shape - context = context.unsqueeze(0) # restore bsz dimension - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - context = rearrange(context, "b s h d -> b s (h d)") # recover the shape - context = pack_output_after_attn(context, kwargs["cu_seqlens"]) - - out = self.wo(context) - return out - - -class PackedFlashLlamaLayer1D(nn.Module): +class InternLM2Decoder(nn.Module): """ - InternLM2 layer. + InternLM2 Decoder layer. Args: hidden_size (int): The hidden size of model. 768 by default. num_attention_heads (int): The number of attention heads. 12 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 12. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0 by default. drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. 
@@ -488,10 +47,16 @@ class PackedFlashLlamaLayer1D(nn.Module): use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. device (Optional[Union[str, torch.device]]): The device will be used. + apply_post_layer_norm (bool): Whether to apply layer normalization after the attention and mlp. + Defaults to False. + fused_dropout_add_ln (bool): Whether to fuse dropout, residual addition, and layer normalization. + Defaults to True. + no_bias (bool): Whether to exclude bias in attention and feed-forward networks. Defaults to False. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - use_flash_attn (bool): Whether use flash-attn. True by default. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu @@ -499,6 +64,8 @@ class PackedFlashLlamaLayer1D(nn.Module): ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, init_type (str): Initialization type. Use uniform or normal. "normal" by default, rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. """ def __init__( @@ -521,12 +88,10 @@ def __init__( fused_dropout_add_ln: bool = True, no_bias: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = True, + qk_interleaved: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, - tp_mode: str = "mtp", attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, ffn_uplayer_init_std: float = 0.02, @@ -541,7 +106,6 @@ def __init__( # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx - self.use_flash_attn = use_flash_attn self.prenorm = not apply_post_layer_norm assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" self.fused_dropout_add_ln = fused_dropout_add_ln @@ -552,16 +116,12 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.use_dynamic_ntk_rope = use_dynamic_ntk_rope - self.tp_mode = tp_mode - parallel_mode = ParallelMode.WEIGHT if self.tp_mode == "isp" else ParallelMode.TENSOR head_dim = hidden_size // num_attention_heads - self.attention = MHA( + self.attention = GQA( embed_dim=hidden_size, num_heads=num_attention_heads, num_kv_heads=num_kv_attention_heads, - process_group=gpc.get_group(parallel_mode), - sequence_process_group=gpc.get_group(ParallelMode.TENSOR), dropout=attn_drop_rate, max_position_embeddings=max_position_embeddings, softmax_scale=1 / math.sqrt(head_dim), @@ -570,39 +130,32 @@ def __init__( use_dynamic_ntk_rope=use_dynamic_ntk_rope, rotary_emb_dim=head_dim, rotary_emb_scale_base=0, - use_flash_attn=use_flash_attn, device=device, dtype=dtype, - rot_embed_HF_impl=adapt_hf, + qk_interleaved=qk_interleaved, bias=not no_bias, rope_base=rope_base, - tp_mode=self.tp_mode, + enable_qkv_fusion=True, ) self.dropout1 = nn.Dropout(drop_rate) - if norm_type == "rmsnorm": - self.attention_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.dropout2 = nn.Dropout(drop_rate) + self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) - self.feed_forward = get_mlp_cls(self.tp_mode)( + self.feed_forward = new_feed_forward( hidden_size, int(hidden_size * mlp_ratio), out_features=hidden_size, - process_group=gpc.get_group(parallel_mode), bias=False, device=device, dtype=dtype, mlp_layer_fusion=mlp_layer_fusion, - sequence_parallel=sequence_parallel, multiple_of=multiple_of, + # TODO: to support more activation functions + activation_type="swiglu" if use_swiglu else "swiglu", ) - assert use_swiglu is True, "InternLM2 only support swiglu." - self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm @@ -646,19 +199,15 @@ def reset_parameters(self): param.data ) - def forward( - self, hidden_states, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None - ): + def forward(self, hidden_states, residual=None, **kwargs): if self.checkpoint and self.training: - return activation_checkpoint( - self._forward, False, hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen - ) + # NOTICE: activation_checkpiont do not support kwargs when use_reentrant = True. 
+ args = convert_attn_kwargs_to_args(kwargs) + return activation_checkpoint(self._forward, False, hidden_states, residual, *args) else: - return self._forward(hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen) + return self._forward(hidden_states, residual, **kwargs) - def _forward( - self, hidden_states=None, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None - ): + def _forward(self, hidden_states, residual, *args, **kwargs): r"""Pass the input through the encoder layer. Args: @@ -683,13 +232,9 @@ def _dropout_and_norm_attn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } - hidden_states = self.attention(hidden_states, **mixer_kwargs) + + attn_kwargs = convert_attn_args_to_kwargs(args, kwargs) + hidden_states = self.attention(hidden_states, **attn_kwargs) if not isinstance(self.feed_forward, nn.Identity): if not self.fused_dropout_add_ln: @@ -715,13 +260,8 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states + residual else: assert residual is None - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } - mixer_out = self.attention(hidden_states, **mixer_kwargs) + + mixer_out = self.attention(hidden_states, **kwargs) if self.return_residual: # mixer out is actually a pair here mixer_out, hidden_states = mixer_out hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( @@ -737,28 +277,26 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class PackedFlashLlama1D(nn.Module): +class InternLM2(nn.Module): """ - 1D Packed Flash InternLM2. + InternLM2 Model. Args: - num_layers (int): The number of layer. 12 by default. - hidden_size (int): The size of hidden state. 768 by default. - num_attention_heads (int): The number of attention head. 12 by default. + num_layers (int): The number of layer. 48 by default. + hidden_size (int): The size of hidden state. 2048 by default. + num_attention_heads (int): The number of attention head. 32 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 32. vocab_size (int): The size of vocabulary. 50304 by default. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. max_position_embeddings (int): The maximum position embeddings. 2048 by default. dtype (torch.dtype): The type of data. torch.float by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number - of layers. 1.0 by default. + checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 0.0 by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. 
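The forward/_forward rewrite above funnels the packed-attention metadata through **kwargs and only flattens it back to positional arguments for the reentrant activation_checkpoint path. The real converters live in internlm.model.utils; the sketch below is only a plausible shape for them, and the canonical key order is an assumption.

    from typing import Any, Dict, List

    # Assumed ordering; the actual helpers may name or order the keys differently.
    _ATTN_ARG_ORDER = ("cu_seqlens", "indexes", "inference_params", "max_seqlen")

    def convert_attn_kwargs_to_args(kwargs: Dict[str, Any]) -> List[Any]:
        # Flatten kwargs into a fixed positional order so they survive
        # torch.utils.checkpoint with use_reentrant=True.
        return [kwargs.get(name, None) for name in _ATTN_ARG_ORDER]

    def convert_attn_args_to_kwargs(args, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        # Rebuild the kwargs dict on the other side; on the non-checkpointed
        # path no positional args were used and kwargs pass through as-is.
        if not args:
            return dict(kwargs)
        merged = dict(zip(_ATTN_ARG_ORDER, args))
        merged.update(kwargs)
        return merged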
embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. @@ -766,7 +304,10 @@ class PackedFlashLlama1D(nn.Module): device (Optional[Union[str, torch.device]]): The device will be used. None by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. embedding_init_std (float): std used to init embedding weight. 0.02 by default, attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, @@ -777,28 +318,26 @@ class PackedFlashLlama1D(nn.Module): init_type (str): Initialization type. Use uniform or normal. "normal" by default, rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. norm_head (bool): Whether to use norm head. False by default. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. 
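`multiple_of`, newly documented here, pads the MLP inner dimension so GEMM shapes stay hardware friendly. The exact formula lives inside new_feed_forward; the Llama-style round-up below is only one common way such a constraint is applied.

    def round_up_to_multiple(x: int, multiple_of: int = 256) -> int:
        # Round an MLP inner dimension up to the next multiple of `multiple_of`.
        return ((x + multiple_of - 1) // multiple_of) * multiple_of

    # e.g. a 4096-wide model with an effective SwiGLU ratio of 8/3:
    print(round_up_to_multiple(int(4096 * 8 / 3)))  # 10922 rounds up to 11008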
""" def __init__( self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, + num_layers: int = 48, + hidden_size: int = 2048, + num_attention_heads: int = 32, + num_kv_attention_heads: int = 32, vocab_size: int = 50304, - mlp_ratio: int = 4, + mlp_ratio: float = 4.0, attn_drop_rate: float = 0.0, drop_rate: float = 0.0, max_position_embeddings: int = 2048, dtype: torch.dtype = torch.float, - checkpoint: bool = False, - checkpoint_fraction: float = 1.0, + checkpoint: float = 0.0, layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, - embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, @@ -808,12 +347,11 @@ def __init__( no_bias=False, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = True, + qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, embedding_init_std: float = 0.02, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, @@ -823,44 +361,27 @@ def __init__( init_type: str = "normal", rope_base: int = 10000, norm_head: bool = False, - tp_mode: str = "mtp", mlp_layer_fusion: bool = False, multiple_of: int = 256, ): super().__init__() - self.use_flash_attn = use_flash_attn - - if checkpoint_fraction <= 0: - checkpoint = False - if not checkpoint: - checkpoint_fraction = 0 - checkpoint_layer_num = num_layers * checkpoint_fraction - - self.tp_mode = tp_mode - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") - - if is_reward: - head_cls = RewardModelLinear - else: - head_cls = ScaleColumnParallelLinearWithNormHead + checkpoint_layer_num = int(num_layers * checkpoint) + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output if first: - self.tok_embeddings = Embedding1D( - num_embeddings=vocab_size, embedding_dim=hidden_size, embed_split_hidden=embed_split_hidden - ) + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.tok_embeddings.named_parameters(): if init_type == "normal": normal_(std=embedding_init_std)(param) else: uniform_(std=embedding_init_std)(param) - self.embed_grad_scale = embed_grad_scale - self.layers = nn.ModuleList( [ - PackedFlashLlamaLayer1D( + InternLM2Decoder( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_kv_attention_heads=num_kv_attention_heads, @@ -882,14 +403,12 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - adapt_hf=adapt_hf, + qk_interleaved=qk_interleaved, attn_wqkv_init_std=attn_wqkv_init_std, attn_other_init_std=attn_other_init_std, ffn_uplayer_init_std=ffn_uplayer_init_std, ffn_other_init_std=ffn_other_init_std, init_type=init_type, - tp_mode=self.tp_mode, rope_base=rope_base, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, @@ -900,23 +419,16 @@ def __init__( if last: if not apply_post_layer_norm: - if norm_type == "rmsnorm": - self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - if norm_head and not issubclass(head_cls, ScaleColumnParallelLinearWithNormHead): - raise 
TypeError( - "Parameter ``norm_head`` should only be True when head_cls is " - f"``ScaleColumnParallelLinearWithNormHead``, instead of {head_cls}." - ) - self.output = head_cls( # pylint: disable=E1123 + self.output = new_linear( + name="output", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, + is_reward=is_reward, weight_scale=embed_grad_scale, norm_head=norm_head, ) @@ -926,9 +438,7 @@ def __init__( else: uniform_(std=out_head_init_std)(param) - self.parallel_output = parallel_output - - def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + def forward(self, hidden_states=None, input_ids=None, **kwargs): # attention_mask: compute attention on the places where the value is 1 if hasattr(self, "tok_embeddings") and input_ids is not None: hidden_states = self.tok_embeddings(input_ids) @@ -936,210 +446,13 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) - if isinstance(cu_seqlens, list): - assert len(cu_seqlens) == 1 - cu_seqlens = cu_seqlens[0].to(hidden_states.device) - - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.squeeze(0) - - if indexes is not None: - assert len(indexes) == 1 - # The indexes are used to indicate the actual position IDs of each token in the packed input. - indexes = indexes[0] - # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension. - if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp": - indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None for _, block in enumerate(self.layers): - hidden_states = block( - hidden_states, - residual=None, - cu_seqlens=cu_seqlens, - indexes=indexes, - inference_params=inference_params, - max_seqlen=max_seqlen, - ) + hidden_states = block(hidden_states, residual=None, **kwargs) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "output"): - hidden_states = self.output(hidden_states, gather_dim=1, tp_mode=self.tp_mode) - - if not self.parallel_output and gpc.is_pipeline_last_stage(): - hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) + hidden_states = self.output(hidden_states) return hidden_states - - -def _build_generic_model_1d(num_layers, num_chunks, **kwargs): - """ - build generic model 1d - - Args: - num_layers (int): The number of layer. - num_chunks (int): The number of partitions in pipeline parallel. - device (Optional[Union[str, torch.device]]): The device will be used. internlm_accelerator.device() by default. 
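In the hunk above, new_linear(name="output", ...) absorbs the old RewardModelLinear / ScaleColumnParallelLinearWithNormHead choice behind the is_reward and norm_head flags. The snippet below sketches only the intuition I read into norm_head, single-device and without the weight_scale or is_reward handling; the production class is tensor-parallel aware and may differ in detail.

    import torch
    import torch.nn.functional as F
    from torch import nn

    class NormHeadSketch(nn.Linear):
        # Assumed behaviour of `norm_head`: L2-normalize each row of the output
        # projection before the matmul, so logit magnitude is decoupled from the
        # raw scale of the head weights. Single-device sketch only.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return F.linear(x, F.normalize(self.weight, dim=-1), self.bias)

    head = NormHeadSketch(2048, 50304, bias=False)
    logits = head(torch.randn(1, 16, 2048))  # shape (1, 16, 50304)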
- - """ - device = get_current_device() - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - logger.info(f"The layer sharding is {all_parts}.") - - models = [] - kwargs["checkpoint_fraction"] = float(kwargs.get("checkpoint", False)) - start_idx, end_idx = 0, 0 - for start, end in parts: - start_idx, end_idx = start, end - kwargs["num_layers"] = end - start - kwargs["first"] = start == 0 - # If there is no content in the final layer, assign the last layer. - kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 - kwargs["device"] = device - kwargs["start_layer_idx"] = start - chunk = PackedFlashLlama1D(**filter_kwargs(PackedFlashLlama1D.__init__, kwargs)).to(device) - - models.append(chunk) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - setattr(model, "first_layer", start_idx) - setattr(model, "last_layer", end_idx) - return model - - -@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) -def build_model_with_cfg( - num_chunks=1, - checkpoint=False, - dtype=torch.float, - embed_split_hidden=False, - num_layers=48, - hidden_size=2048, - vocab_size=50304, - embed_grad_scale=1, - parallel_output=True, - num_attention_heads=32, - num_kv_attention_heads=None, - mlp_ratio=4.0, - residual_in_fp32=False, - norm_type="rmsnorm", - adapt_hf=True, - drop_rate=0, - attn_drop_rate=0, - apply_post_layer_norm=False, # pylint: disable=W0613 - no_bias=False, - deepnorm=False, - layer_norm_epsilon=1e-5, - is_reward=False, - dropout_selective_checkpoint=True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - use_flash_attn: bool = True, - embedding_init_std: float = 0.02, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - out_head_init_std: float = 0.02, - init_type: str = "normal", - rope_base: int = 10000, - norm_head: bool = False, - max_position_embeddings=2048, - use_dynamic_ntk_rope=False, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, -): - """ - Builde model with config - - Args: - num_chunks (int): The number of partitions in pipeline parallel. 1 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. - dtype (torch.dtype): The type of data. torch.float by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - False by default. - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - num_attention_heads (int): The number of attention head. 32 by default. - mlp_ratio (int): The ratio of MLP layers. 4.0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily - because this parameter requires inconsistent data types to be passed between pipelines, - which requires significant modifications to internlm. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. 
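The removed _build_generic_model_1d asked partition_uniform for contiguous [start, end) layer ranges per pipeline rank and instantiated one model chunk per range. Below is a schematic version of that sharding, ignoring uneven remainders, which the real partition_uniform in internlm.solver.pipeline_utils handles more carefully.

    def partition_uniform_sketch(num_layers: int, pipeline_size: int, num_chunks: int = 1):
        # Schematic of the layer sharding the removed code logged:
        # contiguous (start, end) ranges, one list of chunks per pipeline rank.
        parts = [[] for _ in range(pipeline_size)]
        chunk_size = num_layers // (pipeline_size * num_chunks)
        start = 0
        for _ in range(num_chunks):
            for rank in range(pipeline_size):
                parts[rank].append((start, start + chunk_size))
                start += chunk_size
        return parts

    print(partition_uniform_sketch(48, 4))  # [[(0, 12)], [(12, 24)], [(24, 36)], [(36, 48)]]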
- drop_rate (float): The dropout rate of input hidden state. 0 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - is_reward (bool): Whether to use reward model. False by default. - dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. - use_scaled_init (bool): Whether to use scaled init. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. - embedding_init_std (float): std used to init embedding weight. 0.02 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - max_position_embeddings (int): The maximum position embeddings. 2048 by default. - use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. - """ - if deepnorm: - raise AssertionError("deepnorm will not be supported in future versions." "Use early versions if necessary.") - - cfg = dict( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads if num_kv_attention_heads else num_attention_heads, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden, - vocab_size=vocab_size, - embed_grad_scale=embed_grad_scale, - parallel_output=parallel_output, - mlp_ratio=mlp_ratio, - apply_post_layer_norm=apply_post_layer_norm, - no_bias=no_bias, - residual_in_fp32=residual_in_fp32, - norm_type=norm_type, - adapt_hf=adapt_hf, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - layer_norm_epsilon=layer_norm_epsilon, - is_reward=is_reward, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - embedding_init_std=embedding_init_std, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - out_head_init_std=out_head_init_std, - init_type=init_type, - rope_base=rope_base, - norm_head=norm_head, - max_position_embeddings=max_position_embeddings, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - ) - - return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modeling_llama.py b/internlm/model/modeling_llama.py index adbb9a9a..e5768ff3 100644 --- a/internlm/model/modeling_llama.py +++ b/internlm/model/modeling_llama.py @@ -2,11 +2,8 @@ from typing import Optional import torch -import torch.nn.functional as F -from einops import rearrange from torch import nn -from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import 
ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import set_output_attr_to_module @@ -16,462 +13,29 @@ scaled_init_method_uniform, uniform_, ) -from internlm.model.modules.embedding import Embedding1D, RotaryEmbedding -from internlm.model.modules.mlp import get_mlp_cls -from internlm.model.modules.multi_head_attention import ( - _update_kv_cache, - get_gqa_attn_cls, -) -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm -from internlm.model.ops.linear import ( - RewardModelLinear, - ScaleColumnParallelLinear, - get_linear_cls, -) +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import GQA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm from internlm.model.utils import ( - gather_forward_split_backward, - pack_output_after_attn, - split_forward_gather_backward, - unpack_qkv_before_attn, + convert_attn_args_to_kwargs, + convert_attn_kwargs_to_args, ) from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.solver.pipeline_utils import partition_uniform -from internlm.utils.common import filter_kwargs, get_current_device from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER - -MODEL_TYPE = "LLAMA2" logger = get_logger(__file__) -RMSNorm = try_import_RMSNorm() -internlm_accelerator = get_accelerator() - - -class MHA(nn.Module): - """ - Multi-head self-attention and cross-attention. - - Args: - embed_dim (int): The dimention of hidden state. - num_heads (int): The number of attention heads. - process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. - sequence_process_group (torch.distributed.ProcessGroup): The process group for attention calculation. - bias (boolean): Whether the bias is needed for linears. Will be used when initializing QKV matrix and - output projection. True by default. - dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. - softmax_scale (float): The temperature to use for the softmax attention. - causal (boolean): Whether to apply causal attention mask. False by default. - layer_idx (int): The index of current layer. None by default. - rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. - rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements - XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. - use_flash_attn (boolean): Whether to use flash attention or not.If False, vanilla attention module will be used. - False by default. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - use_flash_attn (bool): Whether to use flash-attn. True by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. 
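The MHA class being deleted here is superseded by the shared GQA module used by both decoder rewrites. GQA's defining trick is that num_kv_heads key/value heads serve num_heads query heads; the removed dense-attention fallback later in this diff performs the same expansion with unsqueeze/expand/reshape, and a compact equivalent is shown below (GQA handles this internally, so the helper is illustrative only).

    import torch

    def expand_kv_heads(kv: torch.Tensor, num_q_heads: int) -> torch.Tensor:
        # Each of the kv heads is shared by num_q_heads // num_kv_heads query heads.
        h_kv = kv.shape[2]
        group = num_q_heads // h_kv
        return kv.repeat_interleave(group, dim=2)   # (b, s, num_q_heads, head_dim)

    k = torch.randn(2, 16, 8, 64)        # 8 kv heads
    print(expand_kv_heads(k, 32).shape)  # torch.Size([2, 16, 32, 64])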
- - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - num_kv_heads: int, - process_group: Optional[torch.distributed.ProcessGroup], - sequence_process_group: Optional[torch.distributed.ProcessGroup], - bias: bool = True, - dropout: float = 0.0, - softmax_scale: float = None, - causal: bool = False, - layer_idx: int = None, - rope_base: int = 10000, - rotary_emb_dim: int = 0, - rotary_emb_scale_base: int = 0, - use_flash_attn: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - rot_embed_HF_impl: Optional[bool] = False, - tp_mode: str = "mtp", - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" - - self.head_dim = self.embed_dim // num_heads - self.num_kv_heads = num_kv_heads - self.kv_dim = self.head_dim * num_kv_heads - self.causal = causal - self.layer_idx = layer_idx - self.rotary_emb_dim = rotary_emb_dim - self.use_flash_attn = use_flash_attn - self.dtype = dtype - self.tp_mode = tp_mode - - self.rot_embed_HF_impl = rot_embed_HF_impl - sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) - - if self.rotary_emb_dim > 0: - self.rotary_emb = RotaryEmbedding( - self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device - ) - - Wqkv_cls = get_linear_cls(self.tp_mode, "column") - # notice here should change bias=True - self.wq = Wqkv_cls( - embed_dim, - embed_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - self.wk = Wqkv_cls( - embed_dim, - self.kv_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - self.wv = Wqkv_cls( - embed_dim, - self.kv_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - - self.inner_attn, self.inner_cross_attn = get_gqa_attn_cls( - use_flash_attn, self.tp_mode, causal, softmax_scale, dropout, sequence_process_group - ) - self.inner_cross_attn_causal = causal - self.inner_cross_attn_softmax_scale = softmax_scale - self.inner_cross_attn_dropout = dropout - - # output projection always have the bias (for now) - out_proj_cls = get_linear_cls(self.tp_mode, "row") - self.wo = out_proj_cls( - embed_dim, - embed_dim, - process_group, - bias=bias, - sequence_parallel=sequence_parallel, - **factory_kwargs, - ) - - def forward(self, x, seqlen=None, inference_params=None, **kwargs): - if kwargs.get("indexes", None) is not None: - return self._packed_forward(x=x, inference_params=inference_params, **kwargs) - else: - return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) - - def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). 
- """ - bsz, _, _ = x.shape - q, k, v = self.wq(x), self.wk(x), self.wv(x) - if seqlen is None: - q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) - k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) - v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) - else: - q = rearrange(q, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) - k = rearrange(k, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) - v = rearrange(v, "(b s) (h d) -> b s h d", s=seqlen, d=self.head_dim) - - if not self.rot_embed_HF_impl: - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - if inference_params is None: - if self.rotary_emb_dim > 0: - q = self.rotary_emb._single_eval_forward(q) - k = self.rotary_emb._single_eval_forward(k) - kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) - if self.dtype is torch.float32 and self.use_flash_attn: - if q.dtype not in [torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_cross_attn(q=q, kv=kv).to(self.dtype) - else: - context = self.inner_cross_attn(q=q, kv=kv) - - else: - assert self.rotary_emb_dim > 0 - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - empties = inference_params.attention_mask[..., -1].sum(dim=-1) - moved_q = q.clone() - moved_k = k.clone() - if inference_params.sequence_len_offset == 0: - for i in range(len(empties)): - if empties[i] != 0: - moved_q[i][: -empties[i]] = q[i][empties[i] :] - moved_k[i][: -empties[i]] = k[i][empties[i] :] - moved_q = self.rotary_emb._single_eval_forward( - moved_q, seqlen_offset=inference_params.sequence_len_offset - ) - moved_k = self.rotary_emb._single_eval_forward( - moved_k, seqlen_offset=inference_params.sequence_len_offset - ) - for i in range(len(empties)): - if empties[i] != 0: - q[i][empties[i] :] = moved_q[i][: -empties[i]] - k[i][empties[i] :] = moved_k[i][: -empties[i]] - else: - q[i] = moved_q[i] - k[i] = moved_k[i] - else: - q = self.rotary_emb._single_forward( - q, - inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - - empties, - ) - k = self.rotary_emb._single_forward( - k, - inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - - empties, - ) - else: - raise NotImplementedError( - "You should make sure you are aware that you are changing the method of generating." - "According to your generation function instead of inference/seq_generator_module.py, " - "You may implement here for normal running." - ) - - kv = torch.stack([k, v], dim=2) - - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - if hasattr(inference_params, "window_size") and inference_params.window_size is not None: - if inference_params.window_size <= inference_params.sequence_len_offset: - assert kv.size(1) == 1, "update kv lenth more than 1" - inference_params.key_value_memory_dict[self.layer_idx][ - :, inference_params.keep_first : inference_params.window_size - 1, ... - ] = inference_params.key_value_memory_dict[self.layer_idx][ - :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ... 
- ].clone() - inference_params.real_sequence_len_offset = inference_params.sequence_len_offset - inference_params.sequence_len_offset = inference_params.window_size - 1 - - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - - inference_params.sequence_len_offset = inference_params.real_sequence_len_offset - else: - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - else: - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - - # When using FP16, there is a high probability of NAN in the KV. - # Since NAN cannot be removed by multiplying with and 0, it needs - # to be removed manually here. - kv = torch.where(torch.isnan(kv), 0, kv) - - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - from flash_attn.flash_attn_interface import FlashAttnVarlenKVPackedFunc - - if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) - attn_mask = inference_params.attention_mask[:, None, ...] - attn_mask = torch.logical_or( - torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask - ) - attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) - cu_seqlens = torch.concat( - [ - torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), - attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), - ], - dim=0, - ) - cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) - max_seqlen_q = attn_mask4flsh.shape[-1] - max_seqlen_k = attn_mask4flsh.shape[-1] - total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) - total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( - -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] - ) - - if self.dtype is torch.float32: - if total_q.dtype not in [torch.float16, torch.bfloat16]: - total_q = total_q.to(torch.bfloat16) - if total_kv.dtype not in [torch.float16, torch.bfloat16]: - total_kv = total_kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - output = FlashAttnVarlenKVPackedFunc.apply( - total_q, - total_kv, - cu_seqlens, - cu_seqlens, - max_seqlen_q, - max_seqlen_k, - 0.0, - None, - True, - False, - ).to(self.dtype) - else: - output = FlashAttnVarlenKVPackedFunc.apply( - total_q, - total_kv, - cu_seqlens, - cu_seqlens, - max_seqlen_q, - max_seqlen_k, - 0.0, - None, - True, - False, - ) - - context = torch.zeros_like(q) - context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) - - else: - attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) - if hasattr(inference_params, "window_size") and inference_params.window_size is not None: - if inference_params.window_size <= inference_params.sequence_len_offset: - attn_mask = torch.concat( - [ - attn_mask[..., : inference_params.keep_first], - attn_mask[..., -(inference_params.window_size - inference_params.keep_first) :], - ], - dim=-1, - ) - - k, v = torch.chunk(kv, 2, dim=2) - k = k.squeeze(2) - v = v.squeeze(2) - sp = k.shape - expansion = q.size(2) // k.size(2) - scores = torch.einsum( - "blhd,bnhd->bhln", - q, - k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), - ) / math.sqrt(q.size(-1)) - scores = scores.masked_fill(attn_mask, -65000.0) - scores = F.softmax(scores, dim=-1) # bsz x h x L x L - context = torch.einsum( - "bhmn,bnhd->bmhd", - scores, - v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), - ) - else: - if self.dtype is torch.float32 and self.use_flash_attn: - if q.dtype not in 
[torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_cross_attn(q=q, kv=kv, causal=True).to(self.dtype) - else: - context = self.inner_cross_attn(q=q, kv=kv, causal=True) - if seqlen is None: - context = rearrange(context, "b s h d -> b s (h d)") - else: - context = rearrange(context, "b s h d -> (b s) (h d)") - out = self.wo(context) - return out - - def _packed_forward(self, x, inference_params=None, **kwargs): - """ - we delete seqlen=None for lint check, cause this arg is not used. - - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - """ - assert self.use_flash_attn is True - q, k, v = self.wq(x), self.wk(x), self.wv(x) - q = rearrange(q, "b t (h d) -> b t h d", d=self.head_dim) - k = rearrange(k, "b t (h d) -> b t h d", d=self.head_dim) - v = rearrange(v, "b t (h d) -> b t h d", d=self.head_dim) - - # qkv shift - # the rotary embedding in flash attention module in performed by separating the front and back parts, while - # most of others are done by odd-even methods. - if not self.rot_embed_HF_impl: - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - - indexes = kwargs.pop("indexes") - - q = self.rotary_emb._single_forward(q, indexes=indexes) - k = self.rotary_emb._single_forward(k, indexes=indexes) - if inference_params is None: - kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) - # for packed data, batch dimension with a size of 1 should be directly squeezed off. 
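The _packed_forward path above consumes kwargs["cu_seqlens"] and kwargs["max_seqlen"], the varlen metadata that the removed model-level code derived itself via max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item(). A tiny self-contained example of that bookkeeping:

    import torch

    # Three samples of lengths 5, 3 and 8 packed into one 16-token row.
    lengths = torch.tensor([5, 3, 8], dtype=torch.int32)
    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                            lengths.cumsum(0, dtype=torch.int32)])
    max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max())
    print(cu_seqlens.tolist(), max_seqlen)  # [0, 5, 8, 16] 8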
- if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - q = q.squeeze(0) - kv = kv.squeeze(0) - # since torch_npu only supports fa with no packed data currently, qkv should be unpacked - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - q = unpack_qkv_before_attn(q, kwargs["cu_seqlens"]) - kv = unpack_qkv_before_attn(kv, kwargs["cu_seqlens"]) - if self.dtype is torch.float32: - if q.dtype not in [torch.float16, torch.bfloat16]: - q = q.to(torch.bfloat16) - if kv.dtype not in [torch.float16, torch.bfloat16]: - kv = kv.to(torch.bfloat16) - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - context = self.inner_attn( - q=q, - kv=kv, - cu_seqlens_q=kwargs["cu_seqlens"], - cu_seqlens_k=kwargs["cu_seqlens"], - max_seqlen_q=kwargs["max_seqlen"], - max_seqlen_k=kwargs["max_seqlen"], - dropout_p=self.inner_cross_attn_dropout, - softmax_scale=self.inner_cross_attn_softmax_scale, - causal=self.inner_cross_attn_causal, - ).to(self.dtype) - else: - context = self.inner_attn( - q=q, - kv=kv, - cu_seqlens_q=kwargs["cu_seqlens"], - cu_seqlens_k=kwargs["cu_seqlens"], - max_seqlen_q=kwargs["max_seqlen"], - max_seqlen_k=kwargs["max_seqlen"], - dropout_p=self.inner_cross_attn_dropout, - softmax_scale=self.inner_cross_attn_softmax_scale, - causal=self.inner_cross_attn_causal, - ) - else: - raise RuntimeError("Not support this right now") - - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - context = rearrange(context, "s h d -> s (h d)") # recover the shape - context = context.unsqueeze(0) # restore bsz dimension - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - context = rearrange(context, "b s h d -> b s (h d)") # recover the shape - context = pack_output_after_attn(context, kwargs["cu_seqlens"]) - - out = self.wo(context) - return out - - -class PackedFlashLlamaLayer1D(nn.Module): +class Llama2Decoder(nn.Module): """ - 1D Packed Flash Llama Layer. + Llama2 Decoder Layer. Args: hidden_size (int): The hidden size of model. 768 by default. num_attention_heads (int): The number of attention heads. 12 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 12. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0 by default. drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. @@ -481,8 +45,16 @@ class PackedFlashLlamaLayer1D(nn.Module): layer_idx (int): The index of current layer. 0 by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. device (Optional[Union[str, torch.device]]): The device will be used. + apply_post_layer_norm (bool): Whether to apply layer normalization after the attention and mlp. + Defaults to False. + fused_dropout_add_ln (bool): Whether to fuse dropout, residual addition, and layer normalization. + Defaults to True. + no_bias (bool): Whether to exclude bias in attention and feed-forward networks. Defaults to False. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - use_flash_attn (bool): Whether use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. + use_scaled_init (bool): Whether to use scaled initialization for weights. 
+ use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu @@ -490,8 +62,8 @@ class PackedFlashLlamaLayer1D(nn.Module): ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, init_type (str): Initialization type. Use uniform or normal. "normal" by default, rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], - "mtp" by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. """ def __init__( @@ -512,18 +84,16 @@ def __init__( fused_dropout_add_ln: bool = True, no_bias: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = False, + qk_interleaved: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, ffn_uplayer_init_std: float = 0.02, ffn_other_init_std: float = 0.02, init_type: str = "normal", rope_base: int = 10000, - tp_mode: str = "mtp", mlp_layer_fusion: bool = False, multiple_of: int = 256, ): @@ -532,7 +102,6 @@ def __init__( # dropout selective checkpoint can only be enabled when checkpoint is disabled. self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx - self.use_flash_attn = use_flash_attn self.prenorm = not apply_post_layer_norm assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" self.fused_dropout_add_ln = fused_dropout_add_ln @@ -542,52 +111,43 @@ def __init__( self.ffn_other_init_std = ffn_other_init_std head_dim = hidden_size // num_attention_heads - self.tp_mode = tp_mode - parallel_mode = ParallelMode.WEIGHT if self.tp_mode == "isp" else ParallelMode.TENSOR - self.attention = MHA( + self.attention = GQA( embed_dim=hidden_size, num_heads=num_attention_heads, num_kv_heads=num_kv_attention_heads, - process_group=gpc.get_group(parallel_mode), - sequence_process_group=gpc.get_group(ParallelMode.TENSOR), dropout=attn_drop_rate, softmax_scale=1 / math.sqrt(head_dim), causal=True, layer_idx=layer_idx, rotary_emb_dim=head_dim, rotary_emb_scale_base=0, - use_flash_attn=use_flash_attn, device=device, dtype=dtype, - rot_embed_HF_impl=adapt_hf, + qk_interleaved=qk_interleaved, bias=not no_bias, rope_base=rope_base, - tp_mode=self.tp_mode, + enable_qkv_fusion=False, ) self.dropout1 = nn.Dropout(drop_rate) - if norm_type == "rmsnorm": - self.attention_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.dropout2 = nn.Dropout(drop_rate) + self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.feed_forward = get_mlp_cls(self.tp_mode)( + self.feed_forward = new_feed_forward( hidden_size, int(hidden_size * mlp_ratio), out_features=hidden_size, - 
process_group=gpc.get_group(parallel_mode), bias=False, device=device, dtype=dtype, mlp_layer_fusion=mlp_layer_fusion, - sequence_parallel=gpc.config.parallel.get("sequence_parallel", False), multiple_of=multiple_of, + # TODO: to support more activation functions + activation_type="swiglu" if use_swiglu else "swiglu", ) - self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm @@ -631,19 +191,15 @@ def reset_parameters(self): param.data ) - def forward( - self, hidden_states, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None - ): + def forward(self, hidden_states, residual=None, **kwargs): if self.checkpoint and self.training: - return activation_checkpoint( - self._forward, False, hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen - ) + # NOTICE: activation_checkpiont do not support kwargs when use_reentrant = True. + args = convert_attn_kwargs_to_args(kwargs) + return activation_checkpoint(self._forward, False, hidden_states, residual, *args) else: - return self._forward(hidden_states, residual, cu_seqlens, indexes, inference_params, max_seqlen) + return self._forward(hidden_states, residual, **kwargs) - def _forward( - self, hidden_states=None, residual=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None - ): + def _forward(self, hidden_states, residual, *args, **kwargs): r"""Pass the input through the encoder layer. Args: @@ -668,13 +224,9 @@ def _dropout_and_norm_attn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } - hidden_states = self.attention(hidden_states, **mixer_kwargs) + + attn_kwargs = convert_attn_args_to_kwargs(args, kwargs) + hidden_states = self.attention(hidden_states, **attn_kwargs) if not isinstance(self.feed_forward, nn.Identity): if not self.fused_dropout_add_ln: @@ -700,13 +252,8 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states + residual else: assert residual is None - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } - mixer_out = self.attention(hidden_states, **mixer_kwargs) + + mixer_out = self.attention(hidden_states, **kwargs) if self.return_residual: # mixer out is actually a pair here mixer_out, hidden_states = mixer_out hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( @@ -722,34 +269,38 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class PackedFlashLlama1D(nn.Module): +class Llama2(nn.Module): """ - 1D Packed Flash Llama. + Llama2 Model. Args: num_layers (int): The number of layer. 12 by default. hidden_size (int): The size of hidden state. 768 by default. num_attention_heads (int): The number of attention head. 12 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 12. vocab_size (int): The size of vocabulary. 50304 by default. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. dtype (torch.dtype): The type of data. torch.float by default. 
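Throughout these hunks the per-branch RMSNorm / nn.LayerNorm construction collapses into the new_layer_norm factory. As a mental model only (the production version can dispatch to fused kernels and other norm variants), it behaves roughly like:

    import torch
    from torch import nn

    class RMSNormSketch(nn.Module):
        # Rough stand-in for the RMSNorm that new_layer_norm("rmsnorm", ...) returns.
        def __init__(self, hidden_size: int, eps: float = 1e-5):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(hidden_size))
            self.eps = eps

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            variance = x.float().pow(2).mean(-1, keepdim=True)
            return (x.float() * torch.rsqrt(variance + self.eps)).type_as(x) * self.weight

    def new_layer_norm_sketch(norm_type: str, hidden_size: int, eps: float) -> nn.Module:
        # Dispatch on norm_type, mirroring the removed if/else blocks.
        return RMSNormSketch(hidden_size, eps) if norm_type == "rmsnorm" else nn.LayerNorm(hidden_size, eps=eps)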
- checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number - of layers. 1.0 by default. + checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 0.0 by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. device (Optional[Union[str, torch.device]]): The device will be used. None by default. + apply_post_layer_norm (bool): Whether to apply layer normalization after the attention and mlp. + Defaults to False. + no_bias (bool): Whether to exclude bias in attention and feed-forward networks. Defaults to False. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. embedding_init_std (float): std used to init embedding weight. 0.02 by default, attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, @@ -759,25 +310,25 @@ class PackedFlashLlama1D(nn.Module): out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, init_type (str): Initialization type. Use uniform or normal. "normal" by default, rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. 
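embed_grad_scale keeps the GLM-130B stability trick referenced above: the forward value of the token embedding is unchanged, but only a fraction of the gradient reaches the embedding weights, exactly as the hidden_states = self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() line in the forward passes of this diff does. In isolation:

    import torch
    from torch import nn

    s = 0.1                                    # embed_grad_scale, 0.1 by default above
    emb = nn.Embedding(50304, 2048)
    hidden = emb(torch.randint(0, 50304, (1, 16)))
    # Same value in the forward pass, but backward only propagates s * dL/dhidden
    # into emb.weight; the detached branch carries no gradient.
    hidden = s * hidden + (1 - s) * hidden.detach()
    hidden.sum().backward()
    print(emb.weight.grad.abs().sum() > 0)     # tensor(True)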
""" def __init__( self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, + num_layers: int = 48, + hidden_size: int = 2048, + num_attention_heads: int = 32, + num_kv_attention_heads: int = 32, vocab_size: int = 50304, - mlp_ratio: int = 4, + mlp_ratio: float = 4.0, attn_drop_rate: float = 0.0, drop_rate: float = 0.0, dtype: torch.dtype = torch.float, - checkpoint: bool = False, - checkpoint_fraction: float = 1.0, + checkpoint: float = 0.0, layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, - embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, @@ -786,12 +337,11 @@ def __init__( no_bias=False, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = False, + qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, embedding_init_std: float = 0.02, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, @@ -805,35 +355,22 @@ def __init__( ): super().__init__() - self.use_flash_attn = use_flash_attn - if checkpoint_fraction <= 0: - checkpoint = False - if not checkpoint: - checkpoint_fraction = 0 - checkpoint_layer_num = num_layers * checkpoint_fraction - self.tp_mode = "mtp" - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") - - if is_reward: - head_cls = RewardModelLinear - else: - head_cls = ScaleColumnParallelLinear + checkpoint_layer_num = int(num_layers * checkpoint) + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output if first: - self.tok_embeddings = Embedding1D( - num_embeddings=vocab_size, embedding_dim=hidden_size, embed_split_hidden=embed_split_hidden - ) + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.tok_embeddings.named_parameters(): if init_type == "normal": normal_(std=embedding_init_std)(param) else: uniform_(std=embedding_init_std)(param) - self.embed_grad_scale = embed_grad_scale self.layers = nn.ModuleList( [ - PackedFlashLlamaLayer1D( + Llama2Decoder( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_kv_attention_heads=num_kv_attention_heads, @@ -853,15 +390,13 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - adapt_hf=adapt_hf, + qk_interleaved=qk_interleaved, attn_wqkv_init_std=attn_wqkv_init_std, attn_other_init_std=attn_other_init_std, ffn_uplayer_init_std=ffn_uplayer_init_std, ffn_other_init_std=ffn_other_init_std, init_type=init_type, rope_base=rope_base, - tp_mode=self.tp_mode, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, ) @@ -871,18 +406,16 @@ def __init__( if last: if not apply_post_layer_norm: - if norm_type == "rmsnorm": - self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.output = head_cls( + self.output = new_linear( + name="output", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, + is_reward=is_reward, 
weight_scale=embed_grad_scale, ) set_output_attr_to_module(self.output) @@ -892,9 +425,7 @@ def __init__( else: uniform_(std=out_head_init_std)(param) - self.parallel_output = parallel_output - - def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + def forward(self, hidden_states=None, input_ids=None, **kwargs): # attention_mask: compute attention on the places where the value is 1 if hasattr(self, "tok_embeddings") and input_ids is not None: hidden_states = self.tok_embeddings(input_ids) @@ -902,203 +433,14 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) - if isinstance(cu_seqlens, list): - assert len(cu_seqlens) == 1 - cu_seqlens = cu_seqlens[0].to(hidden_states.device) - - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.squeeze(0) - - if indexes is not None: - assert len(indexes) == 1 - # The indexes are used to indicate the actual position IDs of each token in the packed input. - indexes = indexes[0] - # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension. - if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp": - indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None for _, block in enumerate(self.layers): - hidden_states = block( - hidden_states, - residual=None, - cu_seqlens=cu_seqlens, - indexes=indexes, - inference_params=inference_params, - max_seqlen=max_seqlen, - ) + hidden_states = block(hidden_states, residual=None, **kwargs) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "output"): - hidden_states = self.output(hidden_states, gather_dim=1, tp_mode=self.tp_mode) - - if not self.parallel_output and gpc.is_pipeline_last_stage(): - hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) + hidden_states = self.output(hidden_states) return hidden_states - - -def _build_generic_model_1d(num_layers, num_chunks, **kwargs): - """ - build generic model 1d - - Args: - num_layers (int): The number of layer. - num_chunks (int): The number of partitions in pipeline parallel. - device (Optional[Union[str, torch.device]]): The device will be used. internlm_accelerator.device() by default. - - """ - device = get_current_device() - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - logger.info(f"The layer sharding is {all_parts}.") - - models = [] - kwargs["checkpoint_fraction"] = float(kwargs.get("checkpoint", False)) - start_idx, end_idx = 0, 0 - for start, end in parts: - start_idx, end_idx = start, end - kwargs["num_layers"] = end - start - kwargs["first"] = start == 0 - # If there is no content in the final layer, assign the last layer. 
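checkpoint is now a float fraction rather than a bool plus checkpoint_fraction, hence checkpoint_layer_num = int(num_layers * checkpoint) in both constructors and the removed kwargs["checkpoint_fraction"] coercion below. Assuming the per-layer flag is derived as lid < checkpoint_layer_num (that detail sits outside the lines shown here), the knob behaves like this:

    # With 48 decoder layers and checkpoint=0.25, the first 12 layers recompute
    # their activations in the backward pass and the remaining 36 do not.
    num_layers, checkpoint = 48, 0.25
    checkpoint_layer_num = int(num_layers * checkpoint)
    flags = [lid < checkpoint_layer_num for lid in range(num_layers)]
    print(checkpoint_layer_num, flags.count(True))  # 12 12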
- kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 - kwargs["device"] = device - kwargs["start_layer_idx"] = start - chunk = PackedFlashLlama1D(**filter_kwargs(PackedFlashLlama1D.__init__, kwargs)).to(device) - - models.append(chunk) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - setattr(model, "first_layer", start_idx) - setattr(model, "last_layer", end_idx) - return model - - -@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) -def build_model_with_cfg( - num_chunks=1, - checkpoint=False, - dtype=torch.float, - embed_split_hidden=False, - num_layers=48, - hidden_size=2048, - vocab_size=50304, - embed_grad_scale=1, - parallel_output=True, - num_attention_heads=32, - num_kv_attention_heads=None, - mlp_ratio=4.0, - residual_in_fp32=False, - norm_type="rmsnorm", - adapt_hf=False, - drop_rate=0, - attn_drop_rate=0, - apply_post_layer_norm=False, # pylint: disable=W0613 - no_bias=False, - deepnorm=False, - layer_norm_epsilon=1e-5, - is_reward=False, - dropout_selective_checkpoint=True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - use_flash_attn: bool = True, - embedding_init_std: float = 0.02, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - out_head_init_std: float = 0.02, - init_type: str = "normal", - rope_base: int = 10000, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, -): - """ - Builde model with config - - Args: - num_chunks (int): The number of partitions in pipeline parallel. 1 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. - dtype (torch.dtype): The type of data. torch.float by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - False by default. - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - num_attention_heads (int): The number of attention head. 32 by default. - mlp_ratio (int): The ratio of MLP layers. 4.0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily - because this parameter requires inconsistent data types to be passed between pipelines, - which requires significant modifications to internlm. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - drop_rate (float): The dropout rate of input hidden state. 0 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - is_reward (bool): Whether to use reward model. False by default. - dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. - use_scaled_init (bool): Whether to use scaled init. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. - embedding_init_std (float): std used to init embedding weight. 
0.02 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - """ - if deepnorm: - raise AssertionError("deepnorm will not be supported in future versions." "Use early versions if necessary.") - - cfg = dict( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads if num_kv_attention_heads else num_attention_heads, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden, - vocab_size=vocab_size, - embed_grad_scale=embed_grad_scale, - parallel_output=parallel_output, - mlp_ratio=mlp_ratio, - apply_post_layer_norm=apply_post_layer_norm, - no_bias=no_bias, - residual_in_fp32=residual_in_fp32, - norm_type=norm_type, - adapt_hf=adapt_hf, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - layer_norm_epsilon=layer_norm_epsilon, - is_reward=is_reward, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - embedding_init_std=embedding_init_std, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - out_head_init_std=out_head_init_std, - init_type=init_type, - rope_base=rope_base, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - ) - - return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modeling_llava.py b/internlm/model/modeling_llava.py index 37246e14..57b68859 100644 --- a/internlm/model/modeling_llava.py +++ b/internlm/model/modeling_llava.py @@ -7,53 +7,42 @@ from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import set_output_attr_to_module from internlm.initialize.initialize_tensor import normal_, uniform_ -from internlm.model.modeling_llama import PackedFlashLlamaLayer1D +from internlm.model.llava.clip_builder import build_vision_tower +from internlm.model.llava.projector_builder import build_vision_projector +from internlm.model.modeling_llama import Llama2Decoder from internlm.model.modules.embedding import Embedding1D -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm -from internlm.model.ops.linear import RewardModelLinear, ScaleColumnParallelLinear -from internlm.model.utils import ( - gather_forward_split_backward, - split_forward_gather_backward, -) -from internlm.solver.pipeline_utils import partition_uniform -from internlm.utils.common import filter_kwargs +from internlm.model.modules.linear import new_linear +from internlm.model.modules.norm import new_layer_norm from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER - -MODEL_TYPE = "LLAVA" logger = get_logger(__file__) -RMSNorm = try_import_RMSNorm() -class PackedFlashLlava1D(nn.Module): +class Llava(nn.Module): """ 1D Packed Flash Llava. 
Args: - num_layers (int): The number of layer. 12 by default. - hidden_size (int): The size of hidden state. 768 by default. - num_attention_heads (int): The number of attention head. 12 by default. + num_layers (int): The number of layer. 48 by default. + hidden_size (int): The size of hidden state. 2048 by default. + num_attention_heads (int): The number of attention head. 32 by default. + num_kv_attention_heads (int): The number of key/value attention heads. Defaults to 32. vocab_size (int): The size of vocabulary. 50304 by default. mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. dtype (torch.dtype): The type of data. torch.float by default. checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. - checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number - of layers. 1.0 by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. device (Optional[Union[str, torch.device]]): The device will be used. None by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. embedding_init_std (float): std used to init embedding weight. 0.02 by default, attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 
0.02 by default, @@ -70,21 +59,19 @@ class PackedFlashLlava1D(nn.Module): def __init__( self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, - num_kv_attention_heads: int = 12, + num_layers: int = 48, + hidden_size: int = 2048, + num_attention_heads: int = 32, + num_kv_attention_heads: int = 32, vocab_size: int = 50304, mlp_ratio: int = 4, attn_drop_rate: float = 0.0, drop_rate: float = 0.0, dtype: torch.dtype = torch.float, checkpoint: bool = False, - checkpoint_fraction: float = 1.0, layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, - embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, @@ -93,12 +80,11 @@ def __init__( no_bias=False, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = False, + qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, embedding_init_std: float = 0.02, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, @@ -115,39 +101,25 @@ def __init__( ): super().__init__() - self.use_flash_attn = use_flash_attn - if checkpoint_fraction <= 0: - checkpoint = False - if not checkpoint: - checkpoint_fraction = 0 - checkpoint_layer_num = num_layers * checkpoint_fraction - self.tp_mode = "mtp" + checkpoint_layer_num = num_layers * checkpoint + self.dtype = dtype self.image_token_id = image_token_id - - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") - - if is_reward: - head_cls = RewardModelLinear - else: - head_cls = ScaleColumnParallelLinear + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output if first: - self.tok_embeddings = Embedding1D( - num_embeddings=vocab_size, embedding_dim=hidden_size, embed_split_hidden=embed_split_hidden - ) + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) for _, param in self.tok_embeddings.named_parameters(): if init_type == "normal": normal_(std=embedding_init_std)(param) else: uniform_(std=embedding_init_std)(param) - self.embed_grad_scale = embed_grad_scale self.layers = nn.ModuleList( [ - PackedFlashLlamaLayer1D( + Llama2Decoder( hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_kv_attention_heads=num_kv_attention_heads, @@ -167,15 +139,13 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - adapt_hf=adapt_hf, + qk_interleaved=qk_interleaved, attn_wqkv_init_std=attn_wqkv_init_std, attn_other_init_std=attn_other_init_std, ffn_uplayer_init_std=ffn_uplayer_init_std, ffn_other_init_std=ffn_other_init_std, init_type=init_type, rope_base=rope_base, - tp_mode=self.tp_mode, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, ) @@ -185,18 +155,16 @@ def __init__( if last: if not apply_post_layer_norm: - if norm_type == "rmsnorm": - self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) - self.output = head_cls( + self.output = new_linear( + name="output", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, 
device=device, dtype=dtype, + is_reward=is_reward, weight_scale=embed_grad_scale, ) set_output_attr_to_module(self.output) @@ -206,57 +174,42 @@ def __init__( else: uniform_(std=out_head_init_std)(param) - self.parallel_output = parallel_output - assert vit_cfg is not None if first: - from internlm.model.llava_modules.clip_builder import build_vision_tower - + assert vit_cfg is not None self.vit = build_vision_tower(vit_cfg) self.vit.requires_grad_(False) - assert vision_proj_cfg is not None - if first: - from internlm.model.llava_modules.projector_builder import ( - build_vision_projector, - ) - + assert vision_proj_cfg is not None self.vision_proj = build_vision_projector(vision_proj_cfg) # self.vision_proj.requires_grad_(False) - def forward( # pylint: disable=W0102 - self, - hidden_states=None, - images=[], - cu_seqlens=None, - input_ids=None, - indexes=None, - inference_params=None, - ): + def forward(self, hidden_states=None, images=None, input_ids=None, **kwargs): xs = [] pure_text = False - input_ids = input_ids.clone() - assert hasattr(self, "vit") - assert hasattr(self, "vision_proj") - if len(images) == 1 and len(images[0]) == 0: # make sure grad in Qformer for update - images = [torch.rand(1, 3, self.vit.image_size, self.vit.image_size).cuda().to(self.dtype)] - pure_text = True - - for image in images: - assert len(image) > 0 - if len(image) == 0: - x = [] - else: - assert not isinstance(image, list), image - x = image.to(torch.cuda.current_device()).to(self.dtype) - x = self.vit(x) - x = self.vision_proj(x) - xs.append(x) + images = [] if images is None else images + + if hasattr(self, "vit") and hasattr(self, "vision_proj") and hasattr(self, "tok_embeddings"): + # vit + if len(images) == 1 and len(images[0]) == 0: # make sure grad in Qformer for update + images = [torch.rand(1, 3, self.vit.image_size, self.vit.image_size).cuda().to(self.dtype)] + pure_text = True + + for image in images: + assert len(image) > 0 + if len(image) == 0: + x = [] + else: + assert not isinstance(image, list), image + x = image.to(torch.cuda.current_device()).to(self.dtype) + x = self.vit(x) + x = self.vision_proj(x) + xs.append(x) - # attention_mask: compute attention on the places where the value is 1 - if hasattr(self, "tok_embeddings") and input_ids is not None: + # tok embeddings org_ids = input_ids.clone() input_ids[input_ids == self.image_token_id] = 0 hidden_states = self.tok_embeddings(input_ids).clone() + if pure_text and len(xs) > 0: hidden_states = hidden_states + 0 * xs[0].sum() else: @@ -269,208 +222,14 @@ def forward( # pylint: disable=W0102 hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) - if isinstance(cu_seqlens, list): - assert len(cu_seqlens) == 1 - cu_seqlens = cu_seqlens[0].to(hidden_states.device) - - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.squeeze(0) - - if indexes is not None: - assert len(indexes) == 1 - # The indexes are used to indicate the actual position IDs of each token in the packed input. - indexes = indexes[0] - # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension. 
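The packed-sequence bookkeeping being removed here (cu_seqlens, indexes and max_seqlen are now expected to arrive pre-computed through **kwargs) derives per-sample lengths from the cumulative boundaries in cu_seqlens. A small standalone example of that arithmetic, matching the max_seqlen line deleted from these forward passes:

import torch

# cu_seqlens marks where each packed sample starts and ends in the flattened token stream.
cu_seqlens = torch.tensor([0, 3, 7, 12])      # three samples packed into 12 tokens
lengths = cu_seqlens[1:] - cu_seqlens[:-1]    # tensor([3, 4, 5])
max_seqlen = lengths.max().item()             # 5
print(lengths.tolist(), max_seqlen)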
- if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp": - indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None for _, block in enumerate(self.layers): - hidden_states = block( - hidden_states, - residual=None, - cu_seqlens=cu_seqlens, - indexes=indexes, - inference_params=inference_params, - max_seqlen=max_seqlen, - ) + hidden_states = block(hidden_states, residual=None, **kwargs) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "output"): - hidden_states = self.output(hidden_states, gather_dim=1, tp_mode=self.tp_mode) - - if not self.parallel_output: - hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) + hidden_states = self.output(hidden_states) return hidden_states - - -def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs): - """ - build generic model 1d - - Args: - num_layers (int): The number of layer. - num_chunks (int): The number of partitions in pipeline parallel. - device (Optional[Union[str, torch.device]]): The device will be used. torch.device("cuda") by default. - - """ - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - logger.info(f"The layer sharding is {all_parts}.") - - models = [] - kwargs["checkpoint_fraction"] = float(kwargs.get("checkpoint", False)) - start_idx, end_idx = 0, 0 - for start, end in parts: - start_idx, end_idx = start, end - kwargs["num_layers"] = end - start - kwargs["first"] = start == 0 - # If there is no content in the final layer, assign the last layer. 
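All of the deleted _build_generic_model_1d helpers follow the same recipe: partition_uniform hands each pipeline rank a list of (start, end) layer ranges, and the chunk built for a range is flagged first/last when it touches the ends of the layer stack. A toy illustration of that recipe with a simplified stand-in for partition_uniform (the real function lives in internlm.solver.pipeline_utils and may balance uneven splits differently):

def toy_partition_uniform(num_layers, pipeline_size, num_chunks=1):
    """Simplified stand-in: split num_layers into pipeline_size * num_chunks equal ranges."""
    per_chunk = num_layers // (pipeline_size * num_chunks)
    parts, start = [[] for _ in range(pipeline_size)], 0
    for _ in range(num_chunks):
        for rank in range(pipeline_size):
            parts[rank].append((start, start + per_chunk))
            start += per_chunk
    return parts

all_parts = toy_partition_uniform(num_layers=48, pipeline_size=4)
for rank, ranges in enumerate(all_parts):
    for start, end in ranges:
        print(f"rank {rank}: layers [{start}, {end}), first={start == 0}, last={end == 48}")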
- kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 - kwargs["device"] = device - kwargs["start_layer_idx"] = start - chunk = PackedFlashLlava1D(**filter_kwargs(PackedFlashLlava1D.__init__, kwargs)).to(device) - - models.append(chunk) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - setattr(model, "first_layer", start_idx) - setattr(model, "last_layer", end_idx) - return model - - -@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) -def build_model_with_cfg( - num_chunks=1, - checkpoint=False, - dtype=torch.float, - embed_split_hidden=False, - num_layers=48, - hidden_size=2048, - vocab_size=50304, - embed_grad_scale=1, - parallel_output=True, - num_attention_heads=32, - num_kv_attention_heads=None, - mlp_ratio=4.0, - residual_in_fp32=False, - norm_type="rmsnorm", - adapt_hf=False, - drop_rate=0, - attn_drop_rate=0, - apply_post_layer_norm=False, # pylint: disable=W0613 - no_bias=False, - deepnorm=False, - layer_norm_epsilon=1e-5, - is_reward=False, - dropout_selective_checkpoint=True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - use_flash_attn: bool = True, - embedding_init_std: float = 0.02, - attn_wqkv_init_std: float = 0.02, - attn_other_init_std: float = 0.02, - ffn_uplayer_init_std: float = 0.02, - ffn_other_init_std: float = 0.02, - out_head_init_std: float = 0.02, - init_type: str = "normal", - rope_base: int = 10000, - mlp_layer_fusion: bool = False, - multiple_of: int = 256, - image_token_id: int = 200000, - vit_cfg=None, - vision_proj_cfg=None, -): - """ - Builde model with config - - Args: - num_chunks (int): The number of partitions in pipeline parallel. 1 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. - dtype (torch.dtype): The type of data. torch.float by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - False by default. - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - num_attention_heads (int): The number of attention head. 32 by default. - mlp_ratio (int): The ratio of MLP layers. 4.0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily - because this parameter requires inconsistent data types to be passed between pipelines, - which requires significant modifications to internlm. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - drop_rate (float): The dropout rate of input hidden state. 0 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - is_reward (bool): Whether to use reward model. False by default. - dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. - use_scaled_init (bool): Whether to use scaled init. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - use_flash_attn (bool): Whether to use flash-attn. 
True by default. - embedding_init_std (float): std used to init embedding weight. 0.02 by default, - attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, - attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, - ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu - otherwise init fc1 weight in ffn. 0.02 by default, - ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, - out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, - init_type (str): Initialization type. Use uniform or normal. "normal" by default, - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - """ - if deepnorm: - raise AssertionError("deepnorm will not be supported in future versions." "Use early versions if necessary.") - - cfg = dict( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - num_kv_attention_heads=num_kv_attention_heads if num_kv_attention_heads else num_attention_heads, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden, - vocab_size=vocab_size, - embed_grad_scale=embed_grad_scale, - parallel_output=parallel_output, - mlp_ratio=mlp_ratio, - apply_post_layer_norm=apply_post_layer_norm, - no_bias=no_bias, - residual_in_fp32=residual_in_fp32, - norm_type=norm_type, - adapt_hf=adapt_hf, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - layer_norm_epsilon=layer_norm_epsilon, - is_reward=is_reward, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - embedding_init_std=embedding_init_std, - attn_wqkv_init_std=attn_wqkv_init_std, - attn_other_init_std=attn_other_init_std, - ffn_uplayer_init_std=ffn_uplayer_init_std, - ffn_other_init_std=ffn_other_init_std, - out_head_init_std=out_head_init_std, - init_type=init_type, - rope_base=rope_base, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - image_token_id=image_token_id, - vit_cfg=vit_cfg, - vision_proj_cfg=vision_proj_cfg, - ) - - return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index b1e3ed8b..36307453 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -12,30 +12,26 @@ from internlm.core.naive_amp import set_fp32_attr_to_module from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.mlp import get_mlp_cls -from internlm.model.modules.multi_head_attention import MHA -from internlm.model.moe import MoE -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm -from internlm.model.ops.linear import RewardModelLinear, ScaleColumnParallelLinear +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import MHA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm +from internlm.model.moe.moe import MoE from internlm.model.utils import ( - gather_forward_split_backward, - split_forward_gather_backward, + convert_attn_args_to_kwargs, + convert_attn_kwargs_to_args, + internlm1_mha_pre_load_convert, + internlm1_mha_save_convert, ) from internlm.solver.activation_checkpoint import activation_checkpoint -from internlm.solver.pipeline_utils import partition_uniform 
-from internlm.utils.common import filter_kwargs, get_current_device from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER - -MODEL_TYPE = "INTERNLM_MoE" logger = get_logger(__file__) -RMSNorm = try_import_RMSNorm() -class PackedFlashBaseLayer1D(nn.Module): +class Internlm1MoEDecoder(nn.Module): """ - 1D Packed Flash Base Layer. + InternLM1 MoE Decoder Layer. Args: hidden_size (int): The hidden size of model. 768 by default. @@ -43,18 +39,22 @@ class PackedFlashBaseLayer1D(nn.Module): mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0 by default. drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + max_position_embeddings (int): The maximum position embeddings. 2048 by default. dtype (torch.dtype): Type of data. torch.float by default. layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. layer_idx (int): The index of current layer. 0 by default. + use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. device (Optional[Union[str, torch.device]]): The device will be used. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - use_flash_attn (bool): Whether use flash-attn. True by default. - num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. - moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE - (https://arxiv.org/abs/2201.05596) layer. - moe_type (str): determine which moe impl will be used, default is GShardMoE + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout layers only. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. """ def __init__( @@ -73,12 +73,11 @@ def __init__( residual_in_fp32: bool = False, device: Optional[torch.device] = None, norm_type: str = "rmsnorm", + qk_interleaved: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, num_experts: int = 1, - tp_mode: str = "mtp", mlp_layer_fusion: bool = False, multiple_of: int = 256, ): @@ -87,16 +86,12 @@ def __init__( # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx - self.use_flash_attn = use_flash_attn head_dim = hidden_size // num_attention_heads - self.tp_mode = tp_mode - parallel_mode = ParallelMode.WEIGHT if self.tp_mode == "isp" else ParallelMode.TENSOR + self.mixer = MHA( embed_dim=hidden_size, num_heads=num_attention_heads, - process_group=gpc.get_group(parallel_mode), - sequence_process_group=gpc.get_group(ParallelMode.TENSOR), dropout=attn_drop_rate, max_position_embeddings=max_position_embeddings, softmax_scale=1 / math.sqrt(head_dim), @@ -105,54 +100,51 @@ def __init__( use_dynamic_ntk_rope=use_dynamic_ntk_rope, rotary_emb_dim=head_dim, rotary_emb_scale_base=0, - use_flash_attn=use_flash_attn, device=device, dtype=dtype, - tp_mode=self.tp_mode, + qk_interleaved=qk_interleaved, ) + # Compatible with the name of internlm1 Wqkv linear layer + self.mixer.register_checkpoint_compatibility_hooks(internlm1_mha_pre_load_convert, internlm1_mha_save_convert) + self.dropout1 = nn.Dropout(drop_rate) - if norm_type == "rmsnorm": - self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon) - self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) + self.dropout2 = nn.Dropout(drop_rate) + + self.norm1 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.norm2 = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) self.num_experts = num_experts ep_size = gpc.get_world_size(ParallelMode.EXPERT) if num_experts <= 1: # dense, not MoE - if use_swiglu: - mlp_cls = get_mlp_cls(self.tp_mode) - self.mlp = mlp_cls( - hidden_size, - int(hidden_size * mlp_ratio), - out_features=hidden_size, - process_group=gpc.get_group(parallel_mode), - bias=False, - device=device, - dtype=dtype, - mlp_layer_fusion=mlp_layer_fusion, - sequence_parallel=gpc.config.parallel.sequence_parallel, - multiple_of=multiple_of, - ) + self.mlp = new_feed_forward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + bias=False, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + # TODO: to support more activation functions + activation_type="swiglu" if use_swiglu else "swiglu", + ) else: # replace mlp by MoE module. The expert in MoE is a FeedForward module. - mlp_cls = get_mlp_cls(self.tp_mode) + # mlp_cls = get_mlp_cls(self.tp_mode) self.mlp = MoE( hidden_size, int(hidden_size * mlp_ratio), out_features=hidden_size, num_experts=num_experts, - ep_cls=mlp_cls, ep_group=gpc.get_group(ParallelMode.EXPERT), ep_size=ep_size, device=device, dtype=dtype, ) + # TODO: remove from model package. 
set_fp32_attr_to_module(self.mlp.moe_layer.gate) - self.dropout2 = nn.Dropout(drop_rate) self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm @@ -164,7 +156,7 @@ def reset_parameters(self): for name, param in self.mixer.named_parameters(): if param.ndim == 1: param.data.zero_() - elif "Wqkv" in name: + elif "wqkv" in name: normal_(std=0.006)(param.data) elif self.use_scaled_init: scaled_init_method_normal(sigma=0.006, num_layers=self.layer_idx + 1)(param.data) @@ -186,15 +178,16 @@ def reset_parameters(self): else: normal_(std=0.006 if "fc1" in name else 0.0015)(param.data) - def forward(self, hidden_states, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None): + def forward(self, hidden_states, **kwargs): if self.checkpoint and self.training: - return activation_checkpoint( - self._forward, False, hidden_states, cu_seqlens, indexes, inference_params, max_seqlen - ) # TODO: check whether this will be affected by moe + # TODO: check whether this will be affected by moe + # NOTICE: activation_checkpiont do not support kwargs when use_reentrant = True. + args = convert_attn_kwargs_to_args(kwargs) + return activation_checkpoint(self._forward, False, hidden_states, *args) else: - return self._forward(hidden_states, cu_seqlens, indexes, inference_params, max_seqlen) + return self._forward(hidden_states, **kwargs) - def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_params=None, max_seqlen=None): + def _forward(self, hidden_states, *args, **kwargs): r"""Pass the input through the encoder layer. Args: @@ -203,12 +196,6 @@ def _forward(self, hidden_states=None, cu_seqlens=None, indexes=None, inference_ cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ - mixer_kwargs = { - "cu_seqlens": cu_seqlens, - "max_seqlen": max_seqlen, - "indexes": indexes, - "inference_params": inference_params, - } def _dropout_and_norm_attn(_hidden_states): _dropped = self.dropout1(_hidden_states) @@ -224,6 +211,7 @@ def _dropout_and_norm_attn(_hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) + mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) hidden_states = self.mixer(hidden_states, **mixer_kwargs) def _dropout_and_norm_ffn(_residual, _hidden_states): @@ -241,18 +229,18 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): residual = residual.to(torch.float32) # MLP. - moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) if self.num_experts <= 1: # dense mlp output hidden_states = self.mlp(hidden_states) + moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) else: # MoE output hidden_states, moe_loss, _ = self.mlp(hidden_states) return hidden_states + residual, moe_loss -class PackedFlashInternLm1D(nn.Module): +class Internlm1MoE(nn.Module): """ - 1D Packed Flash InternLm. + InternLM1 MoE. Args: num_layers (int): The number of layer. 12 by default. @@ -262,34 +250,39 @@ class PackedFlashInternLm1D(nn.Module): mlp_ratio (int): The ratio of MLP layers. 4 by default. attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + max_position_embeddings (int): The maximum position embeddings. 2048 by default. dtype (torch.dtype): The type of data. torch.float by default. 
checkpoint (float): The proportion of layers that need to be checkpointed compared to the total number of layers. 0.0 by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + use_dynamic_ntk_rope (bool): Whether to use dynamic ntk rope. False by default. device (Optional[Union[str, torch.device]]): The device will be used. None by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. + qk_interleaved (bool): Whether the odd and even columns of the wq and wk are normally interleaved. + dropout_selective_checkpoint (bool): Whether to selectively checkpoint dropout and norm layers. + use_scaled_init (bool): Whether to use scaled initialization for weights. + use_swiglu (bool): Whether to use SwiGLU activation in the mlp module. num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE (https://arxiv.org/abs/2201.05596) layer. moe_type (str): determine which moe impl will be used, default is GShardMoE + mlp_layer_fusion (bool): Whether to fuse layers in the mlp module for optimization. + multiple_of (int): Ensures mlp dimensions are multiples of this value for efficient hardware utilization. 
""" def __init__( self, - num_layers: int = 12, - hidden_size: int = 768, - num_attention_heads: int = 12, + num_layers: int = 48, + hidden_size: int = 2048, + num_attention_heads: int = 32, vocab_size: int = 50304, - mlp_ratio: int = 4.0, + mlp_ratio: float = 4.0, attn_drop_rate: float = 0.0, drop_rate: float = 0.0, max_position_embeddings: int = 2048, @@ -298,7 +291,6 @@ def __init__( layer_norm_epsilon: float = 1e-5, first: bool = False, last: bool = False, - embed_split_hidden: bool = False, embed_grad_scale: float = 0.1, parallel_output: bool = True, start_layer_idx: int = 0, @@ -306,37 +298,30 @@ def __init__( device: Optional[torch.device] = None, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", + qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, num_experts: bool = 1, + moe_use_residual: bool = False, # pylint: disable=W0613 + moe_type: str = None, # pylint: disable=W0613 mlp_layer_fusion: bool = False, multiple_of: int = 256, ): super().__init__() checkpoint_layer_num = int(num_layers * checkpoint) - self.tp_mode = "mtp" - if isinstance(gpc.config.parallel["tensor"], dict): - self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") - - if is_reward: - head_cls = RewardModelLinear - else: - head_cls = ScaleColumnParallelLinear if first: - self.embedding = Embedding1D( - num_embeddings=vocab_size, embedding_dim=hidden_size, embed_split_hidden=embed_split_hidden - ) + self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.embedding.named_parameters(): normal_(std=0.0052)(param) self.embed_grad_scale = embed_grad_scale self.blocks = nn.ModuleList( [ - PackedFlashBaseLayer1D( + Internlm1MoEDecoder( hidden_size=hidden_size, num_attention_heads=num_attention_heads, mlp_ratio=mlp_ratio, @@ -354,9 +339,8 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, + qk_interleaved=qk_interleaved, num_experts=num_experts, - tp_mode=self.tp_mode, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, ) @@ -364,17 +348,15 @@ def __init__( ] ) if last: - if norm_type == "rmsnorm": - self.norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) - else: - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) - self.head = head_cls( + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon) + self.head = new_linear( + name="head", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, + is_reward=is_reward, weight_scale=embed_grad_scale, ) for _, param in self.head.named_parameters(): @@ -382,7 +364,7 @@ def __init__( self.parallel_output = parallel_output - def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None): + def forward(self, hidden_states=None, input_ids=None, **kwargs): # attention_mask: compute attention on the places where the value is 1 # old condition may fail when use shared embedding if gpc.is_pipeline_first_stage() and input_ids is not None: @@ -391,176 +373,15 @@ def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=N hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() ) - if 
isinstance(cu_seqlens, list): - assert len(cu_seqlens) == 1 - cu_seqlens = cu_seqlens[0].to(hidden_states.device) - - if cu_seqlens is not None: - cu_seqlens = cu_seqlens.squeeze(0) - - if indexes is not None: - assert len(indexes) == 1 - # The indexes are used to indicate the actual position IDs of each token in the packed input. - indexes = indexes[0] - # if the sequence parallel mode is 'isp', the indexes should also be split in sequence dimension. - if gpc.config.parallel.sequence_parallel and self.tp_mode == "isp": - indexes = split_forward_gather_backward(indexes, ParallelMode.TENSOR, dim=0) - - max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() if cu_seqlens is not None else None moe_losses = [] for _, block in enumerate(self.blocks): - hidden_states, mos_loss = block( - hidden_states, - cu_seqlens=cu_seqlens, - indexes=indexes, - inference_params=inference_params, - max_seqlen=max_seqlen, - ) + hidden_states, mos_loss = block(hidden_states, **kwargs) moe_losses.append(mos_loss) if hasattr(self, "norm"): hidden_states = self.norm(hidden_states.float()) if hasattr(self, "head"): - hidden_states = self.head(hidden_states, gather_dim=1, tp_mode=self.tp_mode) + hidden_states = self.head(hidden_states) - if not self.parallel_output and gpc.is_pipeline_last_stage(): - hidden_states = gather_forward_split_backward(hidden_states, ParallelMode.TENSOR, dim=-1) return hidden_states, moe_losses - - -def _build_generic_model_1d(num_layers, num_chunks, **kwargs): - """ - build generic model 1d - - Args: - num_layers (int): The number of layer. - num_chunks (int): The number of partitions in pipeline parallel. - device (Optional[Union[str, torch.device]]): The device will be used. internlm_accelerator.device() by default. - - """ - device = get_current_device() - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - logger.info(f"The layer sharding is {all_parts}.") - - models = [] - - for start, end in parts: - kwargs["num_layers"] = end - start - kwargs["first"] = start == 0 - # If there is no content in the final layer, assign the last layer. 
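Stepping back from the builder removal for a moment: the rewritten Internlm1MoE.forward above now returns the per-block auxiliary losses (moe_losses) alongside the hidden states instead of consuming them locally. How those losses are weighted into the training objective is decided outside this file; purely as an illustration (the coefficient below is a made-up value, not a config key from this patch), the usual pattern looks like:

import torch

lm_loss = torch.tensor(2.30)
moe_losses = [torch.tensor(0.01), torch.tensor(0.02)]  # one entry per decoder block, as returned above
aux_loss_coeff = 0.01                                  # hypothetical weighting, configured elsewhere
total_loss = lm_loss + aux_loss_coeff * torch.stack(moe_losses).sum()
print(total_loss)  # tensor(2.3003)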
- kwargs["last"] = end == num_layers and len(all_parts[-1]) != 0 - kwargs["device"] = device - kwargs["start_layer_idx"] = start - chunk = PackedFlashInternLm1D(**filter_kwargs(PackedFlashInternLm1D.__init__, kwargs)).to(device) - - models.append(chunk) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - - return model - - -@MODEL_INITIALIZER.register_module(module_name=MODEL_TYPE) -def build_model_with_moe_cfg( - num_chunks=1, - checkpoint=0.0, - dtype=torch.float, - embed_split_hidden=False, - num_layers=48, - hidden_size=2048, - vocab_size=50304, - embed_grad_scale=1, - parallel_output=True, - num_attention_heads=32, - max_position_embeddings=2048, - mlp_ratio=4.0, - residual_in_fp32=False, - use_dynamic_ntk_rope=False, - norm_type="rmsnorm", - drop_rate=0, - attn_drop_rate=0, - apply_post_layer_norm=False, # pylint: disable=W0613 - layer_norm_epsilon=1e-5, - is_reward=False, - dropout_selective_checkpoint=True, - use_scaled_init: bool = True, - use_swiglu: bool = True, - use_flash_attn: bool = True, - num_experts: int = 1, - moe_use_residual: bool = False, # pylint: disable=W0613 - moe_type: str = None, # pylint: disable=W0613 - mlp_layer_fusion: bool = False, - multiple_of: int = 256, -): - """ - Build model with config. - - Args: - num_chunks (int): The number of partitions in pipeline parallel. 1 by default. - checkpoint (bool): Whether to use checkpointing to save VRAM. False by default. - dtype (torch.dtype): The type of data. torch.float by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - False by default. - num_layers (int): The number of layer. 48 by default. - hidden_size (int): The size of hidden state. 2048 by default. - vocab_size (int): The size of vocabulary. 50304 by default. - embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. - parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. - num_attention_heads (int): The number of attention head. 32 by default. - mlp_ratio (int): The ratio of MLP layers. 4.0 by default. - residual_in_fp32 (bool): Whether to use residual in fp32. False by default. It cannot be used temporarily - because this parameter requires inconsistent data types to be passed between pipelines, - which requires significant modifications to internlm. - norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - drop_rate (float): The dropout rate of input hidden state. 0 by default. - attn_drop_rate (float): The dropout rate of attention module. 0 by default. - apply_post_layer_norm (bool): Whether to apply post layer norm. False by default. - layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. - is_reward (bool): Whether to use reward model. False by default. - dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default. - use_scaled_init (bool): Whether to use scaled init. True by default. - use_swiglu (bool): Whether to use swiglu. True by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. - num_experts (int): The number of experts. <=1 means dense, >1 means MoE. 1 by default. - moe_use_residual (bool, optional): default=False, make this MoE layer a Residual MoE - (https://arxiv.org/abs/2201.05596) layer. 
- moe_type (str): determine which moe impl will be used, default is GShardMoE - """ - - cfg = dict( - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - checkpoint=checkpoint, - dtype=dtype, - embed_split_hidden=embed_split_hidden, - vocab_size=vocab_size, - embed_grad_scale=embed_grad_scale, - parallel_output=parallel_output, - mlp_ratio=mlp_ratio, - residual_in_fp32=residual_in_fp32, - max_position_embeddings=max_position_embeddings, - use_dynamic_ntk_rope=use_dynamic_ntk_rope, - norm_type=norm_type, - drop_rate=drop_rate, - attn_drop_rate=attn_drop_rate, - layer_norm_epsilon=layer_norm_epsilon, - is_reward=is_reward, - dropout_selective_checkpoint=dropout_selective_checkpoint, - use_scaled_init=use_scaled_init, - use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, - num_experts=num_experts, - mlp_layer_fusion=mlp_layer_fusion, - multiple_of=multiple_of, - ) - - return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg) diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index 2dfea2e8..fa922daa 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -1,23 +1,15 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Tuple +from typing import Optional, Union import torch import torch.nn.functional as F from einops import rearrange from torch import Tensor, nn -from internlm.accelerator import get_accelerator -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.model.ops.fusion_ops_import_helper import try_import_fused_rotary - -from ..utils import gather_forward_split_backward, split_forward_gather_backward - -internlm_accelerator = get_accelerator() - -apply_rotary_emb, apply_rotary_emb_qkv_, apply_rotary_func = None, None, None +from internlm.model.ops.rotary_emb import apply_rotary_emb class Embedding1D(nn.Module): @@ -31,8 +23,6 @@ class Embedding1D(nn.Module): therefore, the embedding vector at :attr:`padding_idx` is not updated during training, i.e. it remains as a fixed "pad". None by default. dtype (Optional[torch.dtype]): Data type None by default. - embed_split_hidden (Optional[Bool]): Whether to split the embed_dim in tensor parallel style. 
- """ def __init__( @@ -42,220 +32,21 @@ def __init__( *args, padding_idx: int = None, dtype: torch.dtype = None, - embed_split_hidden: bool = True, **kwargs, ): super().__init__() self.num_embeddings = num_embeddings self.embed_dim = embedding_dim - self.embed_split_hidden = embed_split_hidden - if self.embed_split_hidden: - self.embed_split_hidden = gpc.tensor_parallel_size > 1 - - split_nums = 1 if not self.embed_split_hidden else gpc.tensor_parallel_size - embed_dim_per_partition = embedding_dim // split_nums - self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs + embed_dim_per_partition = embedding_dim // gpc.tensor_parallel_size self.weight = nn.Parameter(torch.empty((num_embeddings, embed_dim_per_partition), dtype=dtype)) def forward(self, input_: Tensor) -> Tensor: - output = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - - if self.embed_split_hidden: - output = gather_forward_split_backward(output, ParallelMode.TENSOR, dim=-1) - - if gpc.config.parallel.sequence_parallel: - output = split_forward_gather_backward(output, ParallelMode.TENSOR, dim=1) - - return output - - -def _torch_apply_rotary_func( - x1: torch.Tensor, - x2: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - out1: torch.Tensor, - out2: torch.Tensor, - conj: bool = False, -): - assert x1.device == x2.device == cos.device == sin.device, "All inputs must be on the same device" - assert x1.dtype == x2.dtype == cos.dtype == sin.dtype, "All inputs must have the same dtype" - assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes" - assert cos.size() == sin.size(), "Input cos and sin must have the same sizes" - - x1, x2, cos, sin = x1.float(), x2.float(), cos.float(), sin.float() - - if conj: - out1.copy_(x1 * cos + x2 * sin) - out2.copy_(-x1 * sin + x2 * cos) - else: - out1.copy_(x1 * cos - x2 * sin) - out2.copy_(x1 * sin + x2 * cos) - - return out1, out2 - - -class ApplyRotaryEmb(torch.autograd.Function): - """ - ApplyRotaryEmb - """ - - @staticmethod - def forward(ctx, x, cos, sin, interleaved=False): - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead - of 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding to the first rotary_dim of x. 
- """ - _, seqlen, _, headdim = x.shape - rotary_seqlen, rotary_dim = cos.shape - rotary_dim *= 2 - assert rotary_dim <= headdim - assert seqlen <= rotary_seqlen - assert sin.shape == (rotary_seqlen, rotary_dim // 2) - x_ro = x[..., :rotary_dim] - x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2]) - out = torch.empty_like(x) - out_ro = out[..., :rotary_dim] - o1, o2 = out_ro.chunk(2, dim=-1) if not interleaved else (out_ro[..., ::2], out_ro[..., 1::2]) - - apply_rotary_func( - x1, - x2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - o1, - o2, - False, - ) - - if rotary_dim < headdim: - out[..., rotary_dim:].copy_(x[..., rotary_dim:]) - ctx.save_for_backward(cos, sin) - ctx.interleaved = interleaved - return out - - @staticmethod - def backward(ctx, do): - cos, sin = ctx.saved_tensors - _, seqlen, _, headdim = do.shape - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - do_ro = do[..., :rotary_dim] - do1, do2 = do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2]) - dx = torch.empty_like(do) - dx_ro = dx[..., :rotary_dim] - dx1, dx2 = dx_ro.chunk(2, dim=-1) if not ctx.interleaved else (dx_ro[..., ::2], dx_ro[..., 1::2]) - - apply_rotary_func( - do1, - do2, - rearrange(cos[:seqlen], "s d -> s 1 d"), - rearrange(sin[:seqlen], "s d -> s 1 d"), - dx1, - dx2, - True, - ) - if rotary_dim < headdim: - dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) - return dx, None, None, None, None - - -class ApplyRotaryEmbQKV_(torch.autograd.Function): - """ - ApplyRotaryEmbQKV_ - """ - - @staticmethod - def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False): - """ - qkv: (total, 3, nheads, headdim) / (batch_size, seqlen, 3, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) - cos_k, sin_k: (seqlen, rotary_dim / 2), optional - interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of - 1st half and 2nd half (GPT-NeoX style). - rotary_dim must be <= headdim - Apply rotary embedding *inplace* to the first rotary_dim of q and k. - """ - # len(qkv.shape) == 4 means the format of qkv is (total, 3, nheads, headdim) which is packed, - # otherwise the format of qkv is (batch_size, seqlen, 3, nheads, headdim) which is unpacked. - # We handle both packed qkv and unpacked qkv scenario in this class. 
- three = qkv.shape[1] if len(qkv.shape) == 4 else qkv.shape[2] - assert three == 3 - seqlen = None if len(qkv.shape) == 4 else qkv.shape[1] - rotary_seqlen, rotary_dim = cos.shape - if len(qkv.shape) != 4: - assert seqlen <= rotary_seqlen - headdim = qkv.shape[-1] - rotary_dim *= 2 - assert rotary_dim <= headdim - cos_k = cos if cos_k is None else cos_k - sin_k = sin if sin_k is None else sin_k - assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2) - q_ro = qkv[:, 0, :, :rotary_dim] if len(qkv.shape) == 4 else qkv[:, :, 0, :, :rotary_dim] - q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2]) - re_cos = rearrange(cos, "s d -> s 1 d") if len(qkv.shape) == 4 else rearrange(cos[:seqlen], "s d -> s 1 d") - re_sin = rearrange(sin, "s d -> s 1 d") if len(qkv.shape) == 4 else rearrange(sin[:seqlen], "s d -> s 1 d") - - apply_rotary_func(q1, q2, re_cos, re_sin, q1, q2, False) - - k_ro = qkv[:, 1, :, :rotary_dim] if len(qkv.shape) == 4 else qkv[:, :, 1, :, :rotary_dim] - k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2]) - re_cos_k = ( - rearrange(cos_k, "s d -> s 1 d") if len(qkv.shape) == 4 else rearrange(cos_k[:seqlen], "s d -> s 1 d") - ) - re_sin_k = ( - rearrange(sin_k, "s d -> s 1 d") if len(qkv.shape) == 4 else rearrange(sin_k[:seqlen], "s d -> s 1 d") - ) - - apply_rotary_func(k1, k2, re_cos_k, re_sin_k, k1, k2, False) - - ctx.save_for_backward(cos, sin, cos_k, sin_k) - ctx.interleaved = interleaved - return qkv - - @staticmethod - def backward(ctx, dqkv): - cos, sin, cos_k, sin_k = ctx.saved_tensors - seqlen = None if len(dqkv.shape) == 4 else dqkv.shape[1] - rotary_dim = cos.shape[-1] - rotary_dim *= 2 - dq_ro = dqkv[:, 0, :, :rotary_dim] if len(dqkv.shape) == 4 else dqkv[:, :, 0, :, :rotary_dim] - dq1, dq2 = dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2]) - re_cos = rearrange(cos, "s d -> s 1 d") if len(dqkv.shape) == 4 else rearrange(cos[:seqlen], "s d -> s 1 d") - re_sin = rearrange(sin, "s d -> s 1 d") if len(dqkv.shape) == 4 else rearrange(sin[:seqlen], "s d -> s 1 d") - - apply_rotary_func(dq1, dq2, re_cos, re_sin, dq1, dq2, True) - - dk_ro = dqkv[:, 1, :, :rotary_dim] if len(dqkv.shape) == 4 else dqkv[:, :, 1, :, :rotary_dim] - dk1, dk2 = dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2]) - re_cos_k = ( - rearrange(cos_k, "s d -> s 1 d") if len(dqkv.shape) == 4 else rearrange(cos_k[:seqlen], "s d -> s 1 d") - ) - re_sin_k = ( - rearrange(sin_k, "s d -> s 1 d") if len(dqkv.shape) == 4 else rearrange(sin_k[:seqlen], "s d -> s 1 d") - ) - - apply_rotary_func(dk1, dk2, re_cos_k, re_sin_k, dk1, dk2, True) - - return dqkv, None, None, None, None, None - - -apply_rotary_emb, apply_rotary_emb_qkv_, apply_rotary_func = try_import_fused_rotary() -if apply_rotary_emb is None: - apply_rotary_emb = ApplyRotaryEmb.apply -if apply_rotary_emb_qkv_ is None: - apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply -if apply_rotary_func is None: - apply_rotary_func = _torch_apply_rotary_func + return F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) class RotaryEmbedding(torch.nn.Module): @@ -296,12 +87,19 @@ def __init__(self, dim: int, base=10000, scale_base=0, device=None): self._cos_k_cached = None self._sin_k_cached = None - def _update_cos_sin_cache(self, x, indexes): - """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)""" - if not isinstance(indexes, int): 
- seqlen = indexes.max().item() + 1 + def _update_cos_sin_cache( + self, x: torch.Tensor, indexes: Union[int, torch.Tensor] = 0, max_seqlen: Optional[int] = None + ): + """x: (batch, seqlen, nheads, headdim)""" + if max_seqlen is not None: + seqlen = max_seqlen + elif isinstance(indexes, int): + seqlen = indexes + x.shape[1] + 1 else: - seqlen = indexes + 1 # eval_forward + # Note that this statement may cause synchronization between CPU and GPU, + # so it's best to precompute and pass in max_seqlen ahead of time + seqlen = indexes.max().item() + 1 + # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) if seqlen > self._seq_len_cached or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype: @@ -324,54 +122,78 @@ def _update_cos_sin_cache(self, x, indexes): self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype) self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype) - def forward(self, qkv: torch.Tensor, **kwargs): - if kwargs.get("indexes", None) is not None: - return self._forward(qkv, kwargs.pop("indexes")) - if kwargs.get("inference_params", None) is not None: - return self._eval_forward(qkv, seqlen_offset=kwargs.get("inference_params", None).sequence_len_offset) + def _get_slice(self, tensor: torch.Tensor, offsets: Union[int, torch.Tensor] = 0): + if isinstance(offsets, int): + return tensor[offsets:] else: - return self._eval_forward(qkv) + return tensor[offsets] - def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Tensor]: - self._update_cos_sin_cache(qkv, indexes) - if self.scale is None: - return apply_rotary_emb_qkv_(qkv, self._cos_cached[indexes], self._sin_cached[indexes]) - else: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached[indexes], - self._sin_cached[indexes], - self._cos_k_cached[indexes], - self._sin_k_cached[indexes], - ) - - def _eval_forward(self, qkv, seqlen_offset=0): + def _convert_padding( + self, x: torch.Tensor, empties: torch.Tensor, convert_type: str = "left2right", in_place: bool = False + ): + # TODO: impl in_place = True. + assert not in_place, "in_place = True is NYI." + assert convert_type in ("left2right", "right2left"), f"Unknown convert type {convert_type}" + + ret = x.clone() + + for i in range(len(empties)): + if empties[i] == 0: + continue + + if convert_type == "left2right": + ret[i][: -empties[i]] = x[i][empties[i] :] + ret[i][empties[i] :] = x[i][: -empties[i]] + else: # right2left + ret[i][empties[i] :] = x[i][: -empties[i]] + ret[i][: -empties[i]] = x[i][empties[i] :] + + return ret + + def forward( + self, + x: torch.Tensor, + offsets: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + cache_type: str = "query", + interleaved: bool = False, + in_place: bool = False, + left_padding_mask: Optional[torch.Tensor] = None, + ): """ - seqlen_offset: can be used in generation where the qkv being passed in is only the last - token in the batch. + Applies rotary position embeddings to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + offsets (Union[int, torch.Tensor], optional): The sequence offsets for the input. Defaults to 0. + max_seqlen (Optional[int], optional): The maximum sequence length for caching. Defaults to None. + cache_type (str, optional): Specifies whether the cache is for 'query' or 'key'. Defaults to "query". + interleaved (bool, optional): Whether the input tensor is interleaved. Defaults to False. 
+ in_place (bool, optional): Whether the operation should be done in-place. Defaults to False. + left_padding_mask (Optional[torch.Tensor], optional): A mask for left padding. Defaults to None. + + Returns: + torch.Tensor: The tensor with applied rotary position embeddings. """ - self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1]) - if self.scale is None: - return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]) - else: - return apply_rotary_emb_qkv_( - qkv, - self._cos_cached[seqlen_offset:], - self._sin_cached[seqlen_offset:], - self._cos_k_cached[seqlen_offset:], - self._sin_k_cached[seqlen_offset:], - ) - - def _single_forward(self, x, indexes=0): - assert self.scale is None - self._update_cos_sin_cache(x, indexes) - ret = apply_rotary_emb(x, self._cos_cached[indexes], self._sin_cached[indexes]) - return ret + assert cache_type in ("query", "key"), f"Unknown cache type {cache_type}" + assert isinstance(offsets, (int, torch.Tensor)), f"Invalid offsets type {type(offsets)}" - def _single_eval_forward(self, x, seqlen_offset=0): - assert self.scale is None - self._update_cos_sin_cache(x, seqlen_offset + x.shape[1]) - return apply_rotary_emb(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]) + if left_padding_mask is not None: + empties = left_padding_mask[..., -1].sum(dim=-1) + x = self._convert_padding(x, empties, convert_type="left2right", in_place=in_place) + + self._update_cos_sin_cache(x, offsets, max_seqlen) + + cos_cached = self._cos_k_cached if cache_type == "key" and self.scale is not None else self._cos_cached + sin_cached = self._sin_k_cached if cache_type == "key" and self.scale is not None else self._sin_cached + ret = apply_rotary_emb( + x, self._get_slice(cos_cached, offsets), self._get_slice(sin_cached, offsets), interleaved, in_place + ) + + if left_padding_mask is not None: + ret = self._convert_padding(ret, empties, convert_type="right2left", in_place=in_place) + + return ret class LinearRotaryEmbedding(RotaryEmbedding): @@ -390,11 +212,11 @@ def __init__( self.scaling_factor = scaling_factor def _update_cos_sin_cache(self, x, indexes): - """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)""" + """x: (batch, seqlen, nheads, headdim)""" if not isinstance(indexes, int): seqlen = indexes.max().item() + 1 else: - seqlen = indexes + 1 + seqlen = indexes + x.shape[1] + 1 t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype) t = t / self.scaling_factor @@ -457,11 +279,11 @@ def _update(self, seqlen, x): self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype) def _update_cos_sin_cache(self, x, indexes): - """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)""" + """x: (batch, seqlen, nheads, headdim)""" if not isinstance(indexes, int): seqlen = indexes.max().item() + 1 else: - seqlen = indexes + 1 # eval_forward + seqlen = indexes + x.shape[1] + 1 # eval_forward if seqlen <= self.max_position_embeddings: # Reset the tables if the sequence length has changed, # or if we're on a new device (possibly due to tracing for instance) @@ -474,3 +296,22 @@ def _update_cos_sin_cache(self, x, indexes): self._update(seqlen, x) else: self._update(seqlen, x) + + +def new_rotary_embedding( + dim: int, + base=10000, + scale_base=0, + device=None, + max_position_embeddings=2048, + scaling_factor=1.0, + rotary_type: str = "native", +) -> RotaryEmbedding: + assert rotary_type in ("native", "linear_scale", "dynamic_ntk"), f"Unknown rotary type 
{rotary_type}" + + if rotary_type == "linear_scale": + return LinearRotaryEmbedding(dim, base, scale_base, device, max_position_embeddings, scaling_factor) + elif rotary_type == "dynamic_ntk": + return DynamicNTKScalingRotaryEmbedding(dim, base, scale_base, device, max_position_embeddings, scaling_factor) + else: # native + return RotaryEmbedding(dim, base, scale_base, device) diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py new file mode 100644 index 00000000..0d8c4bf8 --- /dev/null +++ b/internlm/model/modules/linear.py @@ -0,0 +1,606 @@ +""" +Linear Modules +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional, Union + +import torch +import torch.distributed as dist +from torch import nn + +from internlm.accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.parallel.shard import ( + get_head_parallel_mode, + get_parallel_strategies_split_mode, + get_tensor_split_parallel_mode, +) +from internlm.model.ops.linear import linear_backward_op, linear_forward_op +from internlm.utils.logger import get_logger + +if TYPE_CHECKING: + from internlm.core.parallel.comm.isp import WPCommunicator + from internlm.core.parallel.comm.tensor import TPCommunicator + +logger = get_logger(__file__) +internlm_accelerator = get_accelerator() + +custom_bwd = internlm_accelerator.return_custom_bwd() +custom_fwd = internlm_accelerator.return_custom_fwd() + + +# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py +class SPFusedDenseFunc(torch.autograd.Function): + "FusedDenseFunc for tensor parallel in flash-attn implementation." + + @staticmethod + @custom_fwd + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + communicator: TPCommunicator, + return_residual=False, + ): + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.communicator = communicator + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + + # parallel strategy-specific communication callback 1-1. + # see more details in the communicator for different parallel strategies. + # we want to kick off the all_gather early, before weight dtype conversion. + total_x, handle_x = communicator.input_hook(x, async_op=True) + + if torch.is_autocast_enabled(): + weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) + bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + weight = weight.contiguous() + + # wait for x has been gathered. + handle_x.wait() + + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *weight.shape) > 65535 * 32: + raise RuntimeError("fused_dense only supports matrix dims <= 2M") + + output = linear_forward_op(total_x, weight, bias) + + # parallel strategy-specific communication callback 2. + # see more details in the communicator for different parallel strategies. 
+        output, _ = communicator.output_hook(output, async_op=False)
+
+        saved_x = None if ctx.compute_weight_gradient is False else total_x if communicator.save_total_input() else x
+        ctx.save_for_backward(saved_x, weight)
+
+        return output if not return_residual else (output, x)
+
+    @staticmethod
+    @custom_bwd
+    def backward(ctx, grad_output, *args):
+        communicator: TPCommunicator = ctx.communicator
+
+        # parallel strategy-specific communication callback 3.
+        # see more details in the communicator for different parallel strategies.
+        grad_output, _ = communicator.grad_output_hook(grad_output, async_op=False)
+        grad_output = grad_output.contiguous()
+
+        if ctx.return_residual:
+            (grad_input,) = args
+            grad_input = grad_input.contiguous()
+
+        x, weight = ctx.saved_tensors
+
+        # parallel strategy-specific communication callback 1-2.
+        # see more details in the communicator for different parallel strategies.
+        if ctx.needs_input_grad[1]:
+            x, handle_x = communicator.input_hook(x, async_op=True, is_forward=False)
+
+        batch_shape = grad_output.shape[:-1]
+        batch_dim = batch_shape.numel()
+        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
+
+        if ctx.needs_input_grad[0]:
+            if not ctx.return_residual:
+                grad_input = linear_forward_op(grad_output, weight.t())
+            else:
+                grad_input = torch.addmm(
+                    grad_input.reshape(batch_dim, grad_input.shape[-1]),
+                    grad_output,
+                    weight,
+                )
+            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
+            # parallel strategy-specific communication callback 4.
+            # see more details in the communicator for different parallel strategies.
+            grad_input, handle_grad_input = communicator.grad_input_hook(grad_input, async_op=True)
+        else:
+            grad_input = None
+
+        # compute gradients for weight and bias if necessary
+        if ctx.needs_input_grad[1]:
+            assert ctx.compute_weight_gradient
+
+            # wait until x has been gathered
+            handle_x.wait()
+
+            x = x.reshape(batch_dim, x.shape[-1])
+            grad_weight, grad_bias = linear_backward_op(x, grad_output, ctx.needs_input_grad[2])
+        else:
+            grad_weight = None
+            grad_bias = grad_output if ctx.needs_input_grad[2] else None
+
+        # wait until grad_input has been gathered
+        handle_grad_input.wait()
+
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None
+
+
+# Q: Should we unify WPFusedDenseFunc and SPFusedDenseFunc, as well as the related communicator interface?
+# A: Currently, WPFusedDenseFunc and SPFusedDenseFunc have significant differences in their computation logic
+# and communication interfaces, so they should not be unified.
+class WPFusedDenseFunc(torch.autograd.Function):
+    "FusedDenseFunc for weight parallel, which is optimized based on the flash-attn implementation."
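+    # In weight parallel, each rank keeps only a shard of the full weight: weight_hook is expected to gather the
+    # complete weight just in time for the local GEMM, and grad_hook to reduce the weight/bias gradients back to
+    # their owning shards (see forward/backward below).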
+ + @staticmethod + @custom_fwd + def forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + module: nn.Module, + communicator: WPCommunicator, + return_residual=False, + ): + ctx.compute_weight_gradient = weight.requires_grad + ctx.return_residual = return_residual + ctx.module = module + ctx.communicator = communicator + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + + total_weight = communicator.weight_hook(weight, module=module) + total_bias = bias if bias is None else communicator.weight_hook(bias, module=module, is_bias=True) + + if torch.is_autocast_enabled(): + total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype()) + if total_bias: + total_bias.to(dtype=torch.get_autocast_gpu_dtype()) + + total_weight = total_weight.contiguous() + batch_shape, n = x.shape[:-1], x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + if min(batch_dim, n, *total_weight.shape) > 65535 * 32: + raise RuntimeError("fused_dense only supports matrix dims <= 2M") + + output = linear_forward_op(x, total_weight, total_bias) + + # release memory + del total_weight + del total_bias + + saved_x = None if ctx.compute_weight_gradient is False else x + ctx.save_for_backward(saved_x, weight) + + return output if not return_residual else (output, x) + + @staticmethod + @custom_bwd + def backward(ctx, grad_output, *args): + module: nn.Module = ctx.module + communicator: WPCommunicator = ctx.communicator + x, weight = ctx.saved_tensors + + grad_output = grad_output.contiguous() + if ctx.return_residual: + (grad_input,) = args + grad_input = grad_input.contiguous() + + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + + total_weight = communicator.weight_hook(weight, module=module) + + # compute weight grad + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + grad_weight, grad_bias = linear_backward_op( + x.reshape(batch_dim, x.shape[-1]), + grad_output, + ctx.needs_input_grad[2], + ) + + grad_weight, grad_weight_sync = communicator.grad_hook( + grad_weight, async_op=True, module=module, is_bias=False + ) + if grad_bias is not None: + grad_bias, grad_bias_sync = communicator.grad_hook( + grad_bias, async_op=True, module=module, is_bias=True + ) + else: + grad_weight = None + grad_bias = grad_output if ctx.needs_input_grad[2] else None + + if ctx.needs_input_grad[0]: + if not ctx.return_residual: + grad_input = linear_forward_op(grad_output, total_weight.t()) + else: + grad_input = torch.addmm( + grad_input.reshape(batch_dim, grad_input.shape[-1]), + grad_output, + total_weight, + ) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + else: + grad_input = None + + del total_weight + + if ctx.needs_input_grad[1]: + grad_weight_sync.wait() + if grad_bias is not None: + grad_bias_sync.wait() + + return grad_input, grad_weight, grad_bias, None, None, None, None + + +def fused_dense_func( + x: torch.Tensor, + weight: torch.Tensor, + communicator: Union[TPCommunicator, WPCommunicator], + module: Optional[nn.Module] = None, + bias: Optional[torch.Tensor] = None, + return_residual: bool = False, +): + if communicator.communication_mode() == "wp": + return WPFusedDenseFunc.apply( + x, + weight, + bias, + module, + communicator, + return_residual, + ) + else: # mtp, msp, and fsp + 
return SPFusedDenseFunc.apply( + x, + weight, + bias, + communicator, + return_residual, + ) + + +class ParallelLinearWithCommExt(nn.Linear): + """ + Parallel linear with commuication extention. + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + split_mode (str): The split mode. It can be "none", "column", or "row". + """ + + # class level communicator variable. + _communicator = None + + @classmethod + def register_cls_communicator(cls, communicator): + cls._communicator = communicator + + def register_communicator(self, communicator): + """ + override the class default communicator for a parallel linear instance + """ + self._communicator = communicator + + def __init__( + self, + in_features: int, + out_features: int, + parallel_mode: ParallelMode, + bias: bool = True, + multiple_of: int = 1, + device: torch.device = None, + dtype: torch.dtype = None, + split_mode: str = "none", + ) -> None: + assert split_mode in ("none", "column", "row"), f"unknown split_mode {split_mode}" + + world_size = gpc.get_world_size(parallel_mode) + rank = gpc.get_local_rank(parallel_mode) + + if split_mode != "none": + split_features = out_features if split_mode == "column" else in_features + multiple = split_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(rank < mod) + + if split_mode == "column": + super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype) + elif split_mode == "row": + super().__init__(local_multiple * multiple_of, out_features, bias=bias, device=device, dtype=dtype) + else: + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + + def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622 + _class_name = self.__class__.__name__ + assert self._communicator is not None, f"{_class_name} should register with a communicator first." + + return fused_dense_func( + input, + self.weight, + communicator=self._communicator, + module=self, + bias=self.bias, + ) + + +class ColumnParallelLinear(ParallelLinearWithCommExt): + """ + ColumnParallelLinear + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + multiple_of: int = 1, + device: torch.device = None, + dtype: torch.dtype = None, + ) -> None: + if out_features % multiple_of: + raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") + + parallel_mode = get_tensor_split_parallel_mode() + super().__init__( + in_features, out_features, parallel_mode, bias=bias, device=device, dtype=dtype, split_mode="column" + ) + + +class RowParallelLinear(ParallelLinearWithCommExt): + """ + RowParallelLinear + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + multiple_of: int = 1, + device: torch.device = None, + dtype: torch.dtype = None, + ) -> None: + if in_features % multiple_of: + raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") + + parallel_mode = get_tensor_split_parallel_mode() + rank = gpc.get_local_rank(parallel_mode) + super().__init__( + in_features, + out_features, + parallel_mode, + bias=bias and rank == 0, + device=device, + dtype=dtype, + split_mode="row", + ) + + +class ScaleColumnParallelLinear(ParallelLinearWithCommExt): + """ + ScaleColumnParallelLinear. + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + weight_scale (int): For training stability. 1 by default. + norm_head (bool): Normalize the output embedding in order to let the calculation of logits not affected by + the norm of embedding. The implementation is referred to baichuan2, + see https://huggingface.co/baichuan-inc/Baichuan2-7B-Base for more information. False by default. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + weight_scale: int = 1, + norm_head: bool = False, + ) -> None: + if norm_head: + logger.info("Notice that norm head is enabled to normalize head weight.") + + parallel_mode = get_tensor_split_parallel_mode(is_head=True) + super().__init__( + in_features, out_features, parallel_mode, bias=bias, device=device, dtype=dtype, split_mode="column" + ) + + self.weight_scale = weight_scale + self.norm_head = norm_head + self.first_eval_flag = True + self.tmp_weight = None + + def forward(self, input): # pylint: disable=W0622 + _class_name = self.__class__.__name__ + assert self._communicator is not None, f"{_class_name} should register with a communicator first." + + if self.weight_scale == 1: + weight = self.weight + else: + weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach() + + if self.norm_head: + if self.training: + if not self.first_eval_flag: + self.first_eval_flag = True + self.tmp_weight = None + # We normalized the output Embedding so that the dot product + # is not affected by the norm of embedding. 
Ref: https://arxiv.org/pdf/2309.10305.pdf + weight = nn.functional.normalize(weight) + else: + if self.first_eval_flag: + # cache l2 norm of head to accelerate infer. + self.first_eval_flag = False + self.tmp_weight = nn.functional.normalize(weight) + + weight = self.tmp_weight + + return fused_dense_func( + input, + weight, + communicator=self._communicator, + module=self, + bias=self.bias, + ) + + +class RewardModelLinear(ScaleColumnParallelLinear): + """ + RewardModelLinear. + + Args: + in_features (int): size of each input sample + out_features (int): size of each output sample + bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False + in the config. + device (Optional[Union[str, torch.device]]): The device will be used. + dtype (Optional[torch.dtype]): The type of data. + weight_scale (int): For training stability. 1 by default. + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + weight_scale: int = 1, + ) -> None: + super().__init__(in_features, out_features, bias, device, dtype, weight_scale) + + # broadcast parameters for reward model head layer. + parallel_mode = get_head_parallel_mode() + process_group = gpc.get_group(parallel_mode) + dist.broadcast(self.weight, gpc.get_ranks_in_group(parallel_mode)[0], process_group) + if bias: + dist.broadcast(self.bias, gpc.get_ranks_in_group(parallel_mode)[0], process_group) + + +def new_linear( + name: str, + in_features: int, + out_features: int, + bias: bool = True, + multiple_of=1, + device=None, + dtype=None, + is_reward: bool = False, + weight_scale: int = 1, + norm_head: bool = False, + **kwargs, +) -> nn.Linear: + + name = str.lower(name) + manual_select_class: Optional[str] = kwargs.get("manual_select_class", None) + + if manual_select_class is not None: + assert manual_select_class in ( + "head", + "column", + "row", + ), f"unknown manual selection {manual_select_class} for creating a linear." + + # use caller manual selection if it is provided. + split_mode = manual_select_class if manual_select_class is not None else get_parallel_strategies_split_mode(name) + + if split_mode == "head": + if is_reward: + return RewardModelLinear( + in_features, + out_features, + bias, + device, + dtype, + weight_scale, + ) + else: + return ScaleColumnParallelLinear( + in_features, + out_features, + bias, + device, + dtype, + weight_scale=weight_scale, + norm_head=norm_head, + ) + elif split_mode == "column": + return ColumnParallelLinear( + in_features, + out_features, + bias, + multiple_of, + device, + dtype, + ) + elif split_mode == "row": + return RowParallelLinear( + in_features, + out_features, + bias, + multiple_of, + device, + dtype, + ) + else: + err_msg = ( + f"Parallel strategies for linear is unsupported, which is named as {name}.\n" + + "Consider use manual_select_class parameter to select a linear class manually." 
+        )
+
+        raise ValueError(err_msg)
diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py
new file mode 100644
index 00000000..3c08fb0a
--- /dev/null
+++ b/internlm/model/modules/mha.py
@@ -0,0 +1,604 @@
+#!/usr/bin/env python
+# -*- encoding: utf-8 -*-
+import math
+from typing import Callable, Dict, Optional
+
+import torch
+from einops import rearrange
+from torch import nn
+from torch.nn import functional as F
+
+from internlm.model.modules.embedding import new_rotary_embedding
+from internlm.model.modules.linear import new_linear
+from internlm.model.modules.utils import update_kv_cache
+from internlm.model.ops.attention import CrossAttention, SelfAttention
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)
+
+
+def _convert_cu_seqlens_for_qksplited(kwargs: Dict):
+    cu_seqlens = kwargs.pop("cu_seqlens", None)
+    max_seqlen = kwargs.pop("max_seqlen", None)
+
+    if cu_seqlens is not None:
+        kwargs["cu_seqlens_q"] = cu_seqlens
+        kwargs["cu_seqlens_k"] = cu_seqlens
+        kwargs["max_seqlen_q"] = max_seqlen
+        kwargs["max_seqlen_k"] = max_seqlen
+
+    return kwargs
+
+
+class MHA(nn.Module):
+    """
+    Multi-head self-attention and cross-attention.
+
+    Args:
+        embed_dim (int): The dimension of the hidden state.
+        num_heads (int): The number of attention heads.
+        max_position_embeddings (int): max position embeddings, 2048 by default.
+        bias (bool): Whether the bias is needed for linears. True by default.
+        dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
+        softmax_scale (float): The temperature to use for the softmax attention.
+        causal (boolean): Whether to apply causal attention mask. False by default.
+        layer_idx (int): The index of current layer. None by default.
+        use_dynamic_ntk_rope (bool): whether to use dynamic ntk rope. False by default.
+        rotary_emb_dim (int): The dimension of Rotary Embedding. 0 by default.
+        rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
+            XPos (Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+        device (Optional[Union[str, torch.device]]): The device will be used.
+        dtype (Optional[torch.dtype]): The type of data.
+        qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk are interleaved. True by default.
+        enable_qkv_fusion (bool): whether the wq, wk and wv linear layers are fused. True by default.
+ """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + max_position_embeddings: int = 2048, + bias: bool = True, + dropout: float = 0.0, + softmax_scale: float = None, + causal: bool = False, + layer_idx: int = None, + use_dynamic_ntk_rope: bool = False, + rotary_emb_dim: int = 0, + rotary_emb_scale_base: int = 0, + rope_base: int = 10000, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + qk_interleaved: Optional[bool] = True, + enable_qkv_fusion: bool = True, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.causal = causal + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = self.embed_dim // num_heads + self.enable_qkv_fusion = enable_qkv_fusion + + self.use_dynamic_ntk_rope = use_dynamic_ntk_rope + self.rotary_emb_dim = rotary_emb_dim + self.max_position_embeddings = max_position_embeddings + self.interleaved = qk_interleaved + + factory_kwargs = {"device": device, "dtype": dtype} + + assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + + if self.rotary_emb_dim > 0: + self.rotary_emb = new_rotary_embedding( + self.rotary_emb_dim, + base=rope_base, + scale_base=rotary_emb_scale_base, + device=device, + max_position_embeddings=max_position_embeddings, + scaling_factor=1.0, + rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native", + ) + + if self.enable_qkv_fusion: + # bias=True is according to https://spaces.ac.cn/archives/9577 + self.wqkv = new_linear("wqkv", embed_dim, 3 * embed_dim, bias, **factory_kwargs) + else: + self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs) + self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) + self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) + + self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + + # output projection always have the bias (for now) + self.out_proj = new_linear("out_proj", embed_dim, embed_dim, bias=True, **factory_kwargs) + + def register_checkpoint_compatibility_hooks( + self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None + ): + # Here we explicitly expose the checkpoint compatibility interface of the module, + # hoping that model developers will make good use of it when adapting. + # Is this interface already meeting all reasonable requirements? 
+ self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True) + self._register_state_dict_hook(pre_save_hook) + + def forward(self, x, inference_params=None, **kwargs): + if inference_params is None: + return self._training(x=x, **kwargs) + else: + return self._inference(x=x, inference_params=inference_params, **kwargs) + + def _training(self, x, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) + """ + # wqkv + if self.enable_qkv_fusion: + qkv = self.wqkv(x) + qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) + + q = qkv[:, :, 0].squeeze(2) + k = qkv[:, :, 1].squeeze(2) + v = qkv[:, :, 2].squeeze(2) + else: + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + + # rotary embedding + indexes = kwargs.pop("indexes", 0) + max_seqlen = kwargs.get("max_seqlen", None) + q = self.rotary_emb( + q, offsets=indexes, cache_type="query", interleaved=self.interleaved, max_seqlen=max_seqlen, in_place=True + ) + k = self.rotary_emb( + k, offsets=indexes, cache_type="key", interleaved=self.interleaved, max_seqlen=max_seqlen, in_place=True + ) + + # self attention + kwargs = _convert_cu_seqlens_for_qksplited(kwargs) + context = self.inner_attn(q, k, v, **kwargs) + + # wo + return self.out_proj(rearrange(context, "b s h d -> b s (h d)")) + + def _convert_unpacked_qkv_to_packed( + self, q: torch.Tensor, kv: torch.Tensor, batch_size: int, attention_mask: torch.Tensor + ): + cu_seqlens = torch.concat( + [ + torch.tensor([0], dtype=torch.int32, device=attention_mask.device), + attention_mask.sum(dim=-1).to(dtype=torch.int32), + ], + dim=0, + ).cumsum(dim=0, dtype=torch.int32) + + cu_seqlens_q = cu_seqlens + cu_seqlens_k = cu_seqlens + + max_seqlen_q = attention_mask.shape[-1] + max_seqlen_k = attention_mask.shape[-1] + + q_packed = ( + q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0) + ) + kv_packed = ( + kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)) + .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1]) + .unsqueeze(0) + ) + + return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k + + def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 + assert inference_params is not None, "inference_params is required for inference" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + attention_mask = inference_params.attention_mask + sequence_len_offset = inference_params.sequence_len_offset + batch_size = x.shape[0] + + # wqkv, output: q, kv + if self.enable_qkv_fusion: + qkv = self.wqkv(x) + qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) + + q = qkv[:, :, 0].squeeze(2) + kv = qkv[:, :, 1:] + else: + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + kv = torch.stack([k, v], dim=2) + + # rotary embedding, output: q, kv + # q shape: [bsz, nheads, head_dim] + # kv shape: [bsz, seqlen, 2, nheads, head_dim] + if self.use_dynamic_ntk_rope: + # update kv cache fisrt when enable dynamic ntk rope. 
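+            # (with dynamic NTK the rotary frequencies depend on the current total length, so the rotary embedding
+            #  is re-applied in place to the whole cached key below, not only to the newly appended token)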
+ kv = update_kv_cache(kv, inference_params, self.layer_idx) + + if sequence_len_offset != 0: + if sequence_len_offset > self.max_position_embeddings: + logger.warning( + "Notice your prompt's length is longer than model's max_position_embeddings: " + f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations." + ) + + if self.rotary_emb_dim > 0: + q = self.rotary_emb( + q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved + ) + k = kv[:, :, 0].squeeze(2) + self.rotary_emb( + k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True + ) # in-place is important + else: + if self.rotary_emb_dim > 0: + q = self.rotary_emb(q, offsets=0, cache_type="query", interleaved=self.interleaved) + k = kv[:, :, 0].squeeze(2) + self.rotary_emb( + k, offsets=0, cache_type="key", interleaved=self.interleaved, in_place=True + ) # in-place is important + else: + assert self.rotary_emb_dim > 0, "You should use rotary_emb." + + k, v = kv[:, :, 0].squeeze(2), kv[:, :, 1].squeeze(2) + + if attention_mask is None: + q = self.rotary_emb(q, offsets=sequence_len_offset, cache_type="query", interleaved=self.interleaved) + k = self.rotary_emb(k, offsets=sequence_len_offset, cache_type="key", interleaved=self.interleaved) + else: + if sequence_len_offset == 0: + q = self.rotary_emb( + q, offsets=0, cache_type="query", interleaved=self.interleaved, left_padding_mask=attention_mask + ) + k = self.rotary_emb( + k, offsets=0, cache_type="key", interleaved=self.interleaved, left_padding_mask=attention_mask + ) + else: + if sequence_len_offset > self.max_position_embeddings: + logger.warning( + "Notice your prompt's length is longer than model's max_position_embeddings: " + f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations." + ) + + empties = attention_mask[..., -1].sum(dim=-1) + indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties + indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties + q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved) + k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved) + + kv = torch.stack([k, v], dim=2) + # update kv cache after rotary embedding when disable dynamic ntk rope. + kv = update_kv_cache(kv, inference_params, self.layer_idx) + + # self-attention + if attention_mask is None: + context = self.inner_cross_attn(q, kv) + else: + if sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) + attn_mask = attention_mask[:, None, ...] 
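+                # combine the upper-triangular causal mask with the padding mask, then keep only its last row as a
+                # per-key validity mask so that q/kv can be packed for the varlen attention call below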
+                attn_mask = torch.logical_or(torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask)
+                attn_mask4flsh = ~attn_mask[:, :, -1, :].view(batch_size, -1)
+
+                output = self.inner_attn(*self._convert_unpacked_qkv_to_packed(q, kv, batch_size, attn_mask4flsh))
+                output = output.to(x.dtype)
+
+                context = torch.zeros_like(q).masked_scatter_(attn_mask4flsh.view(batch_size, -1, 1, 1), output)
+            else:
+                attn_mask = attention_mask[:, -1, :].view(batch_size, 1, 1, -1)
+
+                k, v = torch.chunk(kv, 2, dim=2)
+                k = k.squeeze(2)
+                v = v.squeeze(2)
+                sp = k.shape
+                scores = torch.einsum(
+                    "blhd,bnhd->bhln",
+                    q,
+                    k.reshape(sp[0], sp[1], q.size(2), sp[3]),
+                ) / math.sqrt(q.size(-1))
+                scores = scores.masked_fill(attn_mask, -65000.0)
+                scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
+                context = torch.einsum(
+                    "bhmn,bnhd->bmhd",
+                    scores,
+                    v.reshape(sp[0], sp[1], q.size(2), sp[3]),
+                )
+
+        # wo
+        return self.out_proj(rearrange(context, "b s h d -> b s (h d)"))
+
+
+class GQA(nn.Module):
+    """
+    Grouped-query self-attention and cross-attention.
+
+    Args:
+        embed_dim (int): The dimension of the hidden state.
+        num_heads (int): The number of attention heads.
+        num_kv_heads (int): The number of attention heads for key and value.
+        max_position_embeddings (int): max position embeddings, 2048 by default.
+        bias (bool): Whether the bias is needed for linears. Will be used when initializing QKV matrix and
+            output projection. False by default.
+        dropout (float): The dropout rate for cross attention and self attention. 0.0 by default.
+        softmax_scale (float): The temperature to use for the softmax attention.
+        causal (boolean): Whether to apply causal attention mask. False by default.
+        layer_idx (int): The index of current layer. None by default.
+        use_dynamic_ntk_rope (bool): whether to use dynamic ntk rope. False by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
+        rotary_emb_dim (int): The dimension of Rotary Embedding. 0 by default.
+        rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements
+            XPos (Sun et al., https://arxiv.org/abs/2212.10554). 0 by default.
+        device (Optional[Union[str, torch.device]]): The device will be used.
+        dtype (Optional[torch.dtype]): The type of data.
+        qk_interleaved (Optional[bool]): whether the odd and even columns of wq and wk are interleaved. True by default.
+        enable_qkv_fusion (bool): whether the wq, wk and wv linear layers are fused. True by default.
+ """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + num_kv_heads: int, + max_position_embeddings: int = 2048, + bias: bool = False, + dropout: float = 0.0, + softmax_scale: float = None, + causal: bool = False, + layer_idx: int = None, + use_dynamic_ntk_rope: bool = False, + rope_base: int = 10000, + rotary_emb_dim: int = 0, + rotary_emb_scale_base: int = 0, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + qk_interleaved: Optional[bool] = True, + enable_qkv_fusion: bool = True, + ) -> None: + super().__init__() + self.layer_idx = layer_idx + self.causal = causal + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.q_per_kv = num_heads // num_kv_heads + self.head_dim = self.embed_dim // num_heads + self.kv_dim = self.head_dim * num_kv_heads + self.enable_qkv_fusion = enable_qkv_fusion + + self.use_dynamic_ntk_rope = use_dynamic_ntk_rope + self.rotary_emb_dim = rotary_emb_dim + self.max_position_embeddings = max_position_embeddings + self.interleaved = qk_interleaved + + factory_kwargs = {"device": device, "dtype": dtype} + + assert self.use_dynamic_ntk_rope is False, "Not support dynamic ntk rope yet." + assert self.embed_dim % num_heads == 0, "embedding dim must be divisible by num_heads" + + if self.rotary_emb_dim > 0: + self.rotary_emb = new_rotary_embedding( + self.rotary_emb_dim, + base=rope_base, + scale_base=rotary_emb_scale_base, + device=device, + max_position_embeddings=max_position_embeddings, + scaling_factor=1.0, + rotary_type="dynamic_ntk" if self.use_dynamic_ntk_rope else "native", + ) + + if enable_qkv_fusion: + self.wqkv = new_linear("wqkv", embed_dim, embed_dim + 2 * self.kv_dim, bias, **factory_kwargs) + else: + self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs) + self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) + self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) + + self.inner_attn = SelfAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + self.inner_cross_attn = CrossAttention(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) + + self.wo = new_linear("wo", embed_dim, embed_dim, bias, **factory_kwargs) + + def register_checkpoint_compatibility_hooks( + self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None + ): + # Here we explicitly expose the checkpoint compatibility interface of the module, + # hoping that model developers will make good use of it when adapting. + # Is this interface already meeting all reasonable requirements? 
+ self._register_load_state_dict_pre_hook(pre_load_hook, with_module=True) + self._register_state_dict_hook(pre_save_hook) + + def forward(self, x, inference_params=None, **kwargs): + if inference_params is None: + return self._training(x=x, **kwargs) + else: + return self._inference(x=x, inference_params=inference_params, **kwargs) + + def _training(self, x, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) + """ + # wqkv + if self.enable_qkv_fusion: + qkv = self.wqkv(x) + qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim) + q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]) + q = rearrange(q, "b s h gs d -> b s (h gs) d") + else: + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + + kwargs = _convert_cu_seqlens_for_qksplited(kwargs) + + # rotary embedding + if self.rotary_emb_dim > 0: + indexes = kwargs.pop("indexes", 0) + max_seqlen_q = kwargs.get("max_seqlen_q", None) + max_seqlen_k = kwargs.get("max_seqlen_k", None) + + q = self.rotary_emb( + q, offsets=indexes, max_seqlen=max_seqlen_q, cache_type="query", interleaved=self.interleaved + ) + k = self.rotary_emb( + k, offsets=indexes, max_seqlen=max_seqlen_k, cache_type="key", interleaved=self.interleaved + ) + + kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2) + + # self attention + context = self.inner_attn(q, kv, **kwargs) + + # wo + return self.wo(rearrange(context, "b s h d -> b s (h d)")) + + def _convert_unpacked_qkv_to_packed( + self, q: torch.Tensor, kv: torch.Tensor, batch_size: int, attention_mask: torch.Tensor + ): + cu_seqlens = torch.concat( + [ + torch.tensor([0], dtype=torch.int32, device=attention_mask.device), + attention_mask.sum(dim=-1).to(dtype=torch.int32), + ], + dim=0, + ).cumsum(dim=0, dtype=torch.int32) + + cu_seqlens_q = cu_seqlens + cu_seqlens_k = cu_seqlens + + max_seqlen_q = attention_mask.shape[-1] + max_seqlen_k = attention_mask.shape[-1] + + q_packed = ( + q.masked_select(attention_mask.view(batch_size, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]).unsqueeze(0) + ) + kv_packed = ( + kv.masked_select(attention_mask.view(batch_size, -1, 1, 1, 1)) + .view(-1, kv.shape[-3], kv.shape[-2], kv.shape[-1]) + .unsqueeze(0) + ) + + return q_packed, kv_packed, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k + + def _inference(self, x, inference_params, **kwargs): # pylint: disable=W0613 + assert inference_params is not None, "inference_params is required for inference" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + attention_mask = inference_params.attention_mask + sequence_len_offset = inference_params.sequence_len_offset + window_size = inference_params.window_size + + batch_size = x.shape[0] + + # wqkv, output: q, k, v + if self.enable_qkv_fusion: + qkv = self.wqkv(x) + qkv = rearrange(qkv, "b s (h gs d) -> b s h gs d", gs=self.q_per_kv + 2, d=self.head_dim) + q, k, v = (qkv[..., : self.q_per_kv, :], qkv[..., -2, :], qkv[..., -1, :]) + q = rearrange(q, "b s h gs d -> b s (h gs) d") + else: + q, k, v = self.wq(x), self.wk(x), self.wv(x) + q = rearrange(q, "b s (h d) -> b s h d", d=self.head_dim) + k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) + v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) + + # rotary embedding, output: q, kv + assert self.rotary_emb_dim > 0 + if attention_mask is 
None:
+            raise NotImplementedError(
+                "Inference without an attention mask is not supported here. If you are using a custom generation "
+                "function instead of inference/seq_generator_module.py, be aware that you are changing the "
+                "generation path and may need to implement this branch yourself."
+            )
+        else:
+            if inference_params.sequence_len_offset == 0:
+                q = self.rotary_emb(
+                    q, offsets=0, cache_type="query", interleaved=self.interleaved, left_padding_mask=attention_mask
+                )
+                k = self.rotary_emb(
+                    k, offsets=0, cache_type="key", interleaved=self.interleaved, left_padding_mask=attention_mask
+                )
+            else:
+                empties = attention_mask[..., -1].sum(dim=-1)
+                indexes4q = sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device) - empties
+                indexes4k = sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device) - empties
+                q = self.rotary_emb(q, offsets=indexes4q, cache_type="query", interleaved=self.interleaved)
+                k = self.rotary_emb(k, offsets=indexes4k, cache_type="key", interleaved=self.interleaved)
+
+        kv = torch.stack([k, v], dim=2)
+
+        if window_size is None or window_size > sequence_len_offset:
+            kv = update_kv_cache(kv, inference_params, self.layer_idx)
+        else:  # window_size <= sequence_len_offset
+            assert kv.size(1) == 1, "the kv cache can only be updated with one token at a time here"
+
+            inference_params.key_value_memory_dict[self.layer_idx][
+                :, inference_params.keep_first : inference_params.window_size - 1, ...
+            ] = inference_params.key_value_memory_dict[self.layer_idx][
+                :, -(inference_params.window_size - 1 - inference_params.keep_first) :, ...
+            ].clone()
+            inference_params.real_sequence_len_offset = inference_params.sequence_len_offset
+            inference_params.sequence_len_offset = inference_params.window_size - 1
+
+            kv = update_kv_cache(kv, inference_params, self.layer_idx)
+
+            inference_params.sequence_len_offset = inference_params.real_sequence_len_offset
+
+        # When using FP16, NaN values are likely to appear in the KV cache. Since NaN cannot be
+        # removed by multiplying with 0, it is replaced manually here.
+        kv = torch.where(torch.isnan(kv), 0, kv)
+
+        # attention
+        if attention_mask is None:
+            context = self.inner_cross_attn(q, kv)
+        else:
+            if sequence_len_offset == 0:  # first entry, attnmask (bs*seqlen*seqlen)
+                attn_mask = attention_mask[:, None, ...]
+ attn_mask = torch.logical_or(torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask) + attn_mask4flsh = ~attn_mask[:, :, -1, :].view(batch_size, -1) + + output = self.inner_attn(*self._convert_unpacked_qkv_to_packed(q, kv, batch_size, attn_mask4flsh)) + output = output.to(x.dtype) + + context = torch.zeros_like(q).masked_scatter_(attn_mask4flsh.view(batch_size, -1, 1, 1), output) + + else: + attn_mask = attention_mask[:, -1, :].view(batch_size, 1, 1, -1) + if window_size is not None and window_size <= sequence_len_offset: + attn_mask = torch.concat( + [ + attn_mask[..., : inference_params.keep_first], + attn_mask[..., -(window_size - inference_params.keep_first) :], + ], + dim=-1, + ) + + k, v = torch.chunk(kv, 2, dim=2) + k = k.squeeze(2) + v = v.squeeze(2) + sp = k.shape + expansion = q.size(2) // k.size(2) + scores = torch.einsum( + "blhd,bnhd->bhln", + q, + k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), + ) / math.sqrt(q.size(-1)) + scores = scores.masked_fill(attn_mask, -65000.0) + scores = F.softmax(scores, dim=-1) # bsz x h x L x L + context = torch.einsum( + "bhmn,bnhd->bmhd", + scores, + v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]), + ) + + # wo + return self.wo(rearrange(context, "b s h d -> b s (h d)")) diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index fddc4194..897e1363 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -1,142 +1,60 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from typing import Callable, Dict, Optional +from typing import Dict, Optional import torch from torch import nn -from internlm.model.ops.linear import ( - ColumnParallelLinearTorch, - ISPLinear, - MegatronColumnParallelLinearTorch, - MegatronRowParallelLinearTorch, - RowParallelLinearTorch, -) -from internlm.model.utils import Silu +from internlm.model.modules.linear import new_linear +from internlm.model.modules.utils import Silu +from internlm.utils.logger import get_logger +logger = get_logger(__file__) -class BaseFeedForward(nn.Module): - """ - Base FeedForward in flash implementation. - Args: - in_features (int): size of each input sample - hidden_features (int): size of hidden state of FFN - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. - column_cls (Optional[Callable]): The column parallel class for w1 and w3. None by default. - row_cls (Optional[Callable]): The row parallel class for w2. None by default. - mlp_layer_fusion (Optional[Bool]): Some linears without bias in FFN can be fused to reduce the comm cost of SP. 
- """ +def split_fused_mlp_weight(w1_w3): + w1, w3 = torch.split(w1_w3, w1_w3.shape[0] // 2, dim=0) + return w1, w3 - def __init__( - self, - in_features: int, - hidden_features: int, - out_features: int = None, - process_group: Optional[torch.distributed.ProcessGroup] = None, - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - multiple_of: int = 256, - mlp_layer_fusion: Optional[bool] = False, - sequence_parallel: Optional[bool] = False, - column_cls: Optional[Callable] = None, - row_cls: Optional[Callable] = None, - ): - super().__init__() - self.mlp_layer_fusion = mlp_layer_fusion - hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) - mlp_args = { - "process_group": process_group, - "bias": bias, - "sequence_parallel": sequence_parallel, - "device": device, - "dtype": dtype, - "multiple_of": 1, # TODO: check Column/RowParallelLinearTorch. - } - if not self.mlp_layer_fusion: - # gate_proj - self.w1 = column_cls(in_features, hidden_features, **mlp_args) - # down_proj - self.w2 = row_cls(hidden_features, out_features, **mlp_args) - # up_proj - self.w3 = column_cls(in_features, hidden_features, **mlp_args) - else: - assert bias is False, "Fuesd FeedForward only support bias is False." - # fused gate/up projection - self.fused_w1_w3 = column_cls(in_features, hidden_features * 2, **mlp_args) - # down_proj - self.w2 = row_cls(hidden_features, out_features, **mlp_args) - # TODO: Internal methods could change without a deprecation warning. - self._register_load_state_dict_pre_hook(BaseFeedForward._mlp_pre_load_convert, with_module=True) - self._register_state_dict_hook(BaseFeedForward._mlp_save_convert) +def _mlp_pre_load_convert( + module: "FeedForward", state_dict, prefix: str, *args, **kwargs # pylint: disable=W0613 +) -> None: + w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight" - def forward(self, x): - if not self.mlp_layer_fusion: - w1_o = self.w1(x) - w3_o = self.w3(x) - else: - fussed_out = self.fused_w1_w3(x) - w1_o, w3_o = BaseFeedForward.split_fused_mlp_output(fussed_out) - out = self.w2(Silu(w1_o, w3_o)) - return out + if module.mlp_layer_fusion and fused_name not in state_dict: + w1, w3 = state_dict.pop(w1_name), state_dict.pop(w3_name) + state_dict[fused_name] = torch.cat([w1, w3], dim=0) - @staticmethod - def split_fused_mlp_weight(w1_w3): - w1, w3 = torch.split(w1_w3, w1_w3.shape[0] // 2, dim=0) - return w1, w3 + if not module.mlp_layer_fusion and (w1_name not in state_dict or w3_name not in state_dict): + state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name)) - @staticmethod - def split_fused_mlp_output(w1_w3_out): - w1_o, w3_o = torch.split(w1_w3_out, w1_w3_out.shape[-1] // 2, dim=-1) - return w1_o, w3_o - def _mlp_pre_load_convert(self, state_dict, prefix, *args, **kwargs) -> None: # pylint: disable=W0613 - w1_name = f"{prefix}w1.weight" - w3_name = f"{prefix}w3.weight" - fused_w1_w3_name = f"{prefix}fused_w1_w3.weight" +def _mlp_save_convert(module: "FeedForward", state_dict, prefix: str, *args, **kwargs) -> Dict: # pylint: disable=W0613 + w1_name, w3_name, fused_name = f"{prefix}w1.weight", f"{prefix}w3.weight", f"{prefix}fused_w1_w3.weight" - if self.mlp_layer_fusion and fused_w1_w3_name not in state_dict: - w1, w3 = state_dict.pop(w1_name), state_dict.pop(w3_name) - state_dict[fused_w1_w3_name] = torch.cat([w1, w3], dim=0) - if not self.mlp_layer_fusion and (w1_name not in state_dict or w3_name 
not in state_dict): - state_dict[w1_name], state_dict[w3_name] = self.split_fused_mlp_weight(state_dict.pop(fused_w1_w3_name)) + if module.mlp_layer_fusion: + state_dict[w1_name], state_dict[w3_name] = split_fused_mlp_weight(state_dict.pop(fused_name)) - def _mlp_save_convert(self, state_dict, prefix, *args, **kwargs) -> Dict: # pylint: disable=W0613 - w1_name = f"{prefix}w1.weight" - w3_name = f"{prefix}w3.weight" - fused_w1_w3_name = f"{prefix}fused_w1_w3.weight" + return state_dict - if self.mlp_layer_fusion: - state_dict[w1_name], state_dict[w3_name] = self.split_fused_mlp_weight( - w1_w3=state_dict.pop(fused_w1_w3_name) - ) - return state_dict - - -class FeedForward(BaseFeedForward): +class FeedForward(nn.Module): """ - FeedForward in flash implementation. + Base FeedForward in flash implementation. Args: in_features (int): size of each input sample hidden_features (int): size of hidden state of FFN out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False in the config. device (Optional[Union[str, torch.device]]): The device will be used. dtype (Optional[torch.dtype]): The type of data. multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. + mlp_layer_fusion (Optional[Bool]): Some linears without bias in FFN can be fused to reduce the comm cost of SP. + activation_type (str): the activation function used for feed forward, "swiglu" by default. """ def __init__( @@ -144,125 +62,57 @@ def __init__( in_features: int, hidden_features: int, out_features: int = None, - process_group: Optional[torch.distributed.ProcessGroup] = None, bias: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, multiple_of: int = 256, mlp_layer_fusion: Optional[bool] = False, - sequence_parallel: Optional[bool] = False, + activation_type: str = "swiglu", ): - super().__init__( - in_features, - hidden_features, - out_features, - process_group, - bias, - device, - dtype, - multiple_of, - mlp_layer_fusion, - sequence_parallel, - ColumnParallelLinearTorch, - RowParallelLinearTorch, - ) - + super().__init__() -class MegatronFeedForward(BaseFeedForward): - """ - FeedForward in megatron implementation. + # TODO: support gelu... + assert activation_type in ("swiglu"), f"Unsupported activation type: {activation_type}" - Args: - in_features (int): size of each input sample - hidden_features (int): size of hidden state of FFN - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. 
- """ + self.mlp_layer_fusion = mlp_layer_fusion - def __init__( - self, - in_features: int, - hidden_features: int, - out_features: int = None, - process_group: Optional[torch.distributed.ProcessGroup] = None, - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - multiple_of: int = 256, - mlp_layer_fusion: Optional[bool] = False, - sequence_parallel: Optional[bool] = False, - ): - super().__init__( - in_features, - hidden_features, - out_features, - process_group, - bias, - device, - dtype, - multiple_of, - mlp_layer_fusion, - sequence_parallel, - MegatronColumnParallelLinearTorch, - MegatronRowParallelLinearTorch, - ) + hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) + if self.mlp_layer_fusion: + assert bias is False, "Fuesd FeedForward only support bias is False." -class ISPFeedForward(BaseFeedForward): - """ - FeedForward in ISP. + self.fused_w1_w3 = new_linear("w13", in_features, hidden_features * 2, bias, device=device, dtype=dtype) + self.w2 = new_linear("w2", hidden_features, out_features, bias, device=device, dtype=dtype) - Args: - in_features (int): size of each input sample - hidden_features (int): size of hidden state of FFN - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default. - """ + self._register_load_state_dict_pre_hook(_mlp_pre_load_convert, with_module=True) + self._register_state_dict_hook(_mlp_save_convert) + else: + self.w1 = new_linear("w1", in_features, hidden_features, bias, device=device, dtype=dtype) + self.w2 = new_linear("w2", hidden_features, out_features, bias, device=device, dtype=dtype) + self.w3 = new_linear("w3", in_features, hidden_features, bias, device=device, dtype=dtype) - def __init__( - self, - in_features: int, - hidden_features: int, - out_features: int = None, - process_group: Optional[torch.distributed.ProcessGroup] = None, - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - multiple_of: int = 256, - mlp_layer_fusion: Optional[bool] = False, - sequence_parallel: Optional[bool] = False, - ): - super().__init__( - in_features, - hidden_features, - out_features, - process_group, - bias, - device, - dtype, - multiple_of, - mlp_layer_fusion, - sequence_parallel, - ISPLinear, - ISPLinear, - ) + def forward(self, x): + if not self.mlp_layer_fusion: + w1_o = self.w1(x) + w3_o = self.w3(x) + else: + fussed_out = self.fused_w1_w3(x) + w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1) + out = self.w2(Silu(w1_o, w3_o)) + return out -def get_mlp_cls(tp_mode: str): - if tp_mode in ["mtp", "fsp"]: - mlp_cls = FeedForward - elif tp_mode == "msp": - mlp_cls = MegatronFeedForward - else: - mlp_cls = ISPFeedForward - return mlp_cls +def new_feed_forward( + in_features: int, + hidden_features: int, + out_features: int = None, + bias: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + multiple_of: int = 256, + mlp_layer_fusion: Optional[bool] = False, + activation_type: str = "swiglu", +) -> FeedForward: + return FeedForward( 
+ in_features, hidden_features, out_features, bias, device, dtype, multiple_of, mlp_layer_fusion, activation_type + ) diff --git a/internlm/model/modules/multi_head_attention.py b/internlm/model/modules/multi_head_attention.py deleted file mode 100644 index 9630f077..00000000 --- a/internlm/model/modules/multi_head_attention.py +++ /dev/null @@ -1,866 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import math -import warnings -from typing import Any, Optional, Tuple - -import torch -import torch.distributed as dist -import torch.nn.functional as F -from einops import rearrange, repeat -from torch import Tensor, nn -from torch.nn import Module - -from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import global_context as gpc -from internlm.model.modules.embedding import ( - DynamicNTKScalingRotaryEmbedding, - RotaryEmbedding, -) -from internlm.model.ops.linear import get_linear_cls -from internlm.model.utils import pack_output_after_attn, unpack_qkv_before_attn -from internlm.utils.common import get_current_device - -internlm_accelerator = get_accelerator() - -try: - import torch_npu -except (ImportError, ModuleNotFoundError): - pass - - -def get_gqa_attn_cls(use_flash_attn, tp_mode, causal, softmax_scale, dropout, sequence_process_group): - if use_flash_attn: - device_backend = internlm_accelerator.get_accelerator_backend() - if device_backend == AcceleratorType.GPU: - from flash_attn import flash_attn_varlen_kvpacked_func - from flash_attn.modules.mha import FlashCrossAttention - - inner_attn, inner_cross_attn_cls = flash_attn_varlen_kvpacked_func, FlashCrossAttention - elif device_backend == AcceleratorType.NPU: - from internlm.model.modules.multi_head_attention import ( - AscendFlashSelfAttention, - ) - - inner_attn_cls, inner_cross_attn_cls = AscendFlashSelfAttention, AscendFlashSelfAttention - inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - elif device_backend == AcceleratorType.DIPU: - from deeplink_ext.internevo_ops import ( - FlashCrossAttention, - FlashSelfAttention, - ) - - inner_attn_cls, inner_cross_attn_cls = FlashSelfAttention, FlashCrossAttention - inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - else: - raise NotImplementedError(f"Unsupport device type: {device_backend} for flash attention") - else: - inner_attn_cls, inner_cross_attn_cls = SelfAttention, CrossAttention - inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - - inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - - if tp_mode == "isp": - inner_attn = DistributedAttention(inner_attn, sequence_process_group=sequence_process_group) - inner_cross_attn = DistributedAttention(inner_cross_attn, sequence_process_group=sequence_process_group) - - return inner_attn, inner_cross_attn - - -class AscendFlashSelfAttention(torch.nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. 
- (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__( - self, - causal: bool = True, - softmax_scale: float = None, - attention_dropout: float = 0.0, - ): - super().__init__() - assert rearrange is not None, "Please install einops first, e.g., with pip install einops" - self.causal = causal - self.softmax_scale = softmax_scale - self.shape_order = "BSND" - self.dropout_p = attention_dropout - - if self.causal: - self.sparse_mode = 0 - self.next_tockens = 0 - else: - assert False, "Ascend flash attention unsupport causal=False now!" - - def forward( - self, - qkv=None, - q=None, - k=None, - v=None, - kv=None, - cu_seqlens_q=None, # pylint: disable=W0613 - cu_seqlens_k=None, # pylint: disable=W0613 - max_seqlen_q=None, # pylint: disable=W0613 - max_seqlen_k=None, # pylint: disable=W0613 - dropout_p=0.0, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # pylint: disable=W0613 - alibi_slopes=None, # pylint: disable=W0613 - deterministic=False, - return_attn_probs=False, # pylint: disable=W0613 - attention_mask=None, - ): - if qkv is not None: - assert (q, k, v, kv) == (None, None, None, None) - q = qkv[:, :, 0] - k = qkv[:, :, 1] - v = qkv[:, :, 2] - else: - assert q is not None - if kv is not None: - assert (k, v) == (None, None) - k = kv[:, :, 0] - v = kv[:, :, 1] - else: - assert k is not None and v is not None - - if causal: - assert causal == self.causal - if dropout_p: - assert dropout_p == self.dropout_p - if softmax_scale: - assert softmax_scale == self.softmax_scale - - return self._forward(q, k, v, deterministic=deterministic, attention_mask=attention_mask) - - def _forward( - self, - q, - k, - v, - deterministic: bool = False, - attention_mask: Tensor = None, - actual_seq_qlen: Tensor = None, # pylint: disable=W0613 - actual_seq_kvlen: Tensor = None, # pylint: disable=W0613 - ): - """Implements the multihead softmax attention. - Arguments - --------- - q, k, v: The tensor containing the query, key, and value. 
(B, S, H, D) - """ - assert q.dtype in (torch.bfloat16, torch.float16) - - if len(q.shape) == 5: - q = q.squeeze(dim=2) - k = k.squeeze(dim=2) - v = v.squeeze(dim=2) - - B, S, N, D = q.shape[0], q.shape[1], q.shape[2], q.shape[3] # noqa: F841 # pylint: disable=W0612 - - if self.shape_order == "BSH": - q, k, v = [rearrange(x, "b s h d -> b s (h d)") for x in [q, k, v]] - elif self.shape_order == "SBH": - q, k, v = [rearrange(x, "b s h d -> s b (h d)") for x in [q, k, v]] - elif self.shape_order != "BSND": - raise ValueError("Invalid shape-order: {}, shape-order must be SBH or BSH or BSND".format(self.shape_order)) - - if attention_mask is None: - attention_mask = torch.triu(torch.ones(S, S, device=get_current_device()), 1).bool() - - output = torch_npu.npu_fusion_attention( - query=q, - key=k, - value=v, - head_num=N, - input_layout="BSND", - pse=None, - atten_mask=attention_mask, - scale=self.softmax_scale, - sparse_mode=self.sparse_mode, - pre_tockens=k.shape[1], # Used for sparse calculations, representing the left boundary of the slides window - next_tockens=self.next_tockens, - keep_prob=1 - self.dropout_p, - inner_precise=0 if not deterministic else 2, - )[0] - - if self.shape_order == "BSH": - output = rearrange(output, "b s (h d) -> b s h d", h=N) - elif self.shape_order == "SBH": - output = rearrange(output, "s b (h d) -> b s h d", h=N) - elif self.shape_order != "BSND": - raise ValueError("Invalid shape-order: {}, shape-order must be SBH or BSH or BSND".format(self.shape_order)) - - return output - - -# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py -class _SeqAllToAll(torch.autograd.Function): - "sequence alltoall" - - @staticmethod - def forward(ctx: Any, group: dist.ProcessGroup, input_: Tensor, scatter_idx: int, gather_idx: int) -> Tensor: - ctx.group = group - ctx.scatter_idx = scatter_idx - ctx.gather_idx = gather_idx - - if dist.get_world_size(group) <= 1: - return input_ - - seq_world_size = dist.get_world_size(group) - - input_list = [t.contiguous() for t in torch.tensor_split(input_, seq_world_size, scatter_idx)] - output_list = [torch.empty_like(input_list[0]) for _ in range(seq_world_size)] - # TODO Use all_to_all_single instead - dist.all_to_all(output_list, input_list, group=group) - return torch.cat(output_list, dim=gather_idx).contiguous() - - @staticmethod - def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]: - if dist.get_world_size(ctx.group) <= 1: - return (None, *grad_output, None, None) - - return (None, _SeqAllToAll.apply(ctx.group, *grad_output, ctx.gather_idx, ctx.scatter_idx), None, None) - - -# adpated from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/sequence/layer.py -class DistributedAttention(torch.nn.Module): - """Initialization. 
- - Arguments: - local_attention (Module): local attention with q,k,v - sequence_process_group (ProcessGroup): sequence parallel process group - first_scatter_idx (int): scatter_idx for the first all2all comm - first_gather_idx (int): gather_idx for the first all2all comm - second_scatter_idx (int): scatter_idx for the second all2all comm - second_gather_idx (int): gather_idx for the second all2all comm - """ - - def __init__( - self, - local_attention: Module, - sequence_process_group: dist.ProcessGroup, - ) -> None: - super().__init__() - self.local_attn = local_attention - self.spg = sequence_process_group - self._scatter_gather_idx = {} - - # scatter_gather_idx contains the scatter and gather index for different data packed mode - # key is the data packed mode, which should be in ['qkv', 'kv', 'q', 'output'] - # value is the scatter and gather index in all2all - self._scatter_gather_idx["qkv"] = [2, 0] # qkv shape:[sequence, 3, head, head_dim] - self._scatter_gather_idx["kv"] = [2, 0] # kv shape: [sequence, 2, head, head_dim] - self._scatter_gather_idx["q"] = [1, 0] # q/k/v shape: [sequence, head, head_dim] - self._scatter_gather_idx["output"] = [0, 1] # output shape: [sequence, head, head_dim] - - def forward( - self, qkv: Tensor = None, kv: Tensor = None, q: Tensor = None, k: Tensor = None, v: Tensor = None, **kwargs: Any - ) -> Tensor: - if gpc.is_evaluating is True or gpc.config.data.use_packed_dataset is False: - # when conducting evaluation, the scatter and gather index should add 1. - eval_scatter_gather_idx = {key: [x + 1 for x in value] for key, value in self._scatter_gather_idx.items()} - return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=eval_scatter_gather_idx, **kwargs) - else: - return self._forward(qkv=qkv, kv=kv, q=q, k=k, v=v, scatter_gather=self._scatter_gather_idx, **kwargs) - - def _forward( - self, - qkv: Tensor = None, - kv: Tensor = None, - q: Tensor = None, - k: Tensor = None, - v: Tensor = None, - scatter_gather: dict = None, - **kwargs: Any, - ) -> Tensor: - """forward - - Arguments: - qkv (Tensor): packed qkv input to the layer - kv (Tensor): packed kv input to the layer - q (Tensor): q input to the layer - k (Tensor): k input to the layer - v (Tensor): v input to the layer - args: other args - - Returns: - * output (Tensor): context output - """ - - if qkv is not None: - qkv = _SeqAllToAll.apply(self.spg, qkv, scatter_gather["qkv"][0], scatter_gather["qkv"][1]) - context_layer = self.local_attn(qkv=qkv, **kwargs) - elif kv is not None: - q = _SeqAllToAll.apply(self.spg, q, scatter_gather["q"][0], scatter_gather["q"][1]) - kv = _SeqAllToAll.apply(self.spg, kv, scatter_gather["kv"][0], scatter_gather["kv"][1]) - context_layer = self.local_attn(q=q, kv=kv, **kwargs) - else: - q = _SeqAllToAll.apply(self.spg, q, scatter_gather["q"][0], scatter_gather["q"][1]) - k = _SeqAllToAll.apply(self.spg, k, scatter_gather["q"][0], scatter_gather["q"][1]) - v = _SeqAllToAll.apply(self.spg, v, scatter_gather["q"][0], scatter_gather["q"][1]) - context_layer = self.local_attn(q=q, k=k, v=v, **kwargs) - output = _SeqAllToAll.apply(self.spg, context_layer, scatter_gather["output"][0], scatter_gather["output"][1]) - - # out e.g., [s/p::h] - return output - - -class SelfAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. 
- (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - def forward(self, qkv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, S) - """ - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - causal = self.causal if causal is None else causal - q, k, v = qkv.unbind(dim=2) - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - if key_padding_mask is not None: - padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device) - padding_mask.masked_fill_(key_padding_mask, 0.0) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - if causal: - # "triu_tril_cuda_template" not implemented for 'BFloat16' - # So we have to construct the mask in float - causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + causal_mask.to(dtype=scores.dtype) - attention = torch.softmax(scores, dim=-1, dtype=v.dtype) - attention_drop = self.drop(attention) - output = torch.einsum("bhts,bshd->bthd", attention_drop, v) - return output - - -class CrossAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.drop = nn.Dropout(attention_dropout) - - def forward(self, q, kv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - q: The tensor containing the query. (B, Sq, H, D) - kv: The tensor containing the key and value. (B, Sk, 2, H_k, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. (B, Sk) - """ - batch_size, seqlen_q = q.shape[0], q.shape[1] - causal = self.causal if causal is None else causal - seqlen_k = kv.shape[1] - assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] - if kv.shape[3] != q.shape[2]: # MQA/GQA - kv = repeat(kv, "... hkv d -> ... 
(hkv g) d", g=q.shape[2] // kv.shape[3]) - k, v = kv.unbind(dim=2) - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) - if key_padding_mask is not None: - padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device) - padding_mask.masked_fill_(key_padding_mask, 0.0) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") - if causal: - # causal mask needs to take into account the difference between seqlen_q and seqlen_k - row_idx = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1") - col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long) - sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") - causal_mask = col_idx > row_idx + sk - seqlen_q - scores = scores.masked_fill(causal_mask, -10000.0) - attention = torch.softmax(scores, dim=-1, dtype=v.dtype) - attention_drop = self.drop(attention) - output = torch.einsum("bhts,bshd->bthd", attention_drop, v) - return output - - -def _update_kv_cache(kv, inference_params, layer_idx): - """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" - # Pre-allocate memory for key-values for inference. - num_heads, head_dim = kv.shape[-2:] - if layer_idx not in inference_params.key_value_memory_dict: - kv_cache = torch.empty( - inference_params.max_batch_size, - inference_params.max_sequence_len, - 2, - num_heads, - head_dim, - dtype=kv.dtype, - device=kv.device, - ) - inference_params.key_value_memory_dict[layer_idx] = kv_cache - else: - if not inference_params.fused_ft_kernel: - kv_cache = inference_params.key_value_memory_dict[layer_idx] - else: - # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize) - # where packsize = 4 if fp32, 8 if fp16 or bf16. - # v_cache has shape (b, h, s, headdim) - k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] - kv_cache = None - # Adjust key and value for inference - batch_start = inference_params.batch_size_offset - batch_end = batch_start + kv.shape[0] - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + kv.shape[1] - assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) - assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) - # Copy key and values. - if not inference_params.fused_ft_kernel: - assert kv_cache is not None - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv - kv = kv_cache[batch_start:batch_end, :sequence_end, ...] - return kv - else: - assert inference_params.sequence_len_offset == 0 - # FT kernel requires different layouts for the k_cache and v_cache. - assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] - packsize = 4 if kv.dtype == torch.float32 else 8 - if kv_cache is not None: - kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] 
= kv - k_cache = rearrange( - kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize - ).contiguous() - v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous() - inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache) - else: - k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( - kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize - ) - v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d") - return kv - - -class MHA(nn.Module): - """ - Multi-head self-attention and cross-attention. - - Args: - embed_dim (int): The dimention of hidden state. - num_heads (int): The number of attention heads. - process_group (torch.distributed.ProcessGroup): The group of the current device for `parallel_mode`. - max_position_embeddings (int): max position embeddings, 2048 by default. - dropout (float): The dropout rate for cross attention and self attention. 0.0 by default. - softmax_scale (float): The temperature to use for the softmax attention. - causal (boolean): Whether to apply causal attention mask. False by default. - layer_idx (int): The index of current layer. None by default. - use_dynamic_ntk_rope (bool): whether use dynamic ntk rope, false by default. - rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. - rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements - XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. - rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - process_group: Optional[torch.distributed.ProcessGroup], - sequence_process_group: Optional[torch.distributed.ProcessGroup], - max_position_embeddings: int = 2048, - dropout: float = 0.0, - softmax_scale: float = None, - causal: bool = False, - layer_idx: int = None, - use_dynamic_ntk_rope: bool = False, - rotary_emb_dim: int = 0, - rotary_emb_scale_base: int = 0, - use_flash_attn: bool = True, - rope_base: int = 10000, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - tp_mode: str = "mtp", - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.embed_dim = embed_dim - self.causal = causal - self.layer_idx = layer_idx - self.max_position_embeddings = max_position_embeddings - self.use_dynamic_ntk_rope = use_dynamic_ntk_rope - self.rotary_emb_dim = rotary_emb_dim - self.use_flash_attn = use_flash_attn - self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" - self.head_dim = self.embed_dim // num_heads - self.tp_mode = tp_mode - - if self.rotary_emb_dim > 0: - if self.use_dynamic_ntk_rope: - self.rotary_emb = DynamicNTKScalingRotaryEmbedding( - self.rotary_emb_dim, - base=rope_base, - scale_base=rotary_emb_scale_base, - device=device, - max_position_embeddings=max_position_embeddings, - scaling_factor=1.0, # Currently do not support dynamic scaling. 
- ) - else: - self.rotary_emb = RotaryEmbedding( - self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device - ) - - # notice here should change bias=True - Wqkv_cls = get_linear_cls(self.tp_mode, "column") - self.Wqkv = Wqkv_cls( - embed_dim, - 3 * embed_dim, - process_group, - bias=True, - sequence_parallel=gpc.config.parallel.sequence_parallel, - **factory_kwargs, - ) # according to https://spaces.ac.cn/archives/9577 - - if gpc.config.model.use_flash_attn: - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - from flash_attn.modules.mha import ( - FlashCrossAttention, - FlashSelfAttention, - ) - elif internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: - FlashCrossAttention, FlashSelfAttention = AscendFlashSelfAttention, AscendFlashSelfAttention - elif internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU: - from deeplink_ext.internevo_ops import ( - FlashCrossAttention, - FlashSelfAttention, - ) - - inner_attn_cls = FlashSelfAttention - inner_cross_attn_cls = FlashCrossAttention - else: - inner_attn_cls = SelfAttention - inner_cross_attn_cls = CrossAttention - - self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout) - self.inner_cross_attn = inner_cross_attn_cls( - causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout - ) - if self.tp_mode == "isp": - self.inner_attn = DistributedAttention(self.inner_attn, sequence_process_group=sequence_process_group) - self.inner_cross_attn = DistributedAttention( - self.inner_cross_attn, sequence_process_group=sequence_process_group - ) - - # output projection always have the bias (for now) - out_proj_cls = get_linear_cls(self.tp_mode, "row") - self.out_proj = out_proj_cls( - embed_dim, - embed_dim, - process_group, - bias=True, - sequence_parallel=gpc.config.parallel.sequence_parallel, - **factory_kwargs, - ) - - def forward(self, x, seqlen=None, inference_params=None, **kwargs): - if kwargs.get("indexes", None) is not None: - return self._packed_forward(x=x, inference_params=inference_params, **kwargs) - else: - return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) - - def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). 
- """ - bsz, _, _ = x.shape - qkv = self.Wqkv(x) - if seqlen is None: - qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) - else: - qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim) - - if inference_params is None: - kwargs["inference_params"] = inference_params - qkv = self.rotary_emb(qkv, **kwargs) - if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - if qkv.dtype not in [torch.float16, torch.bfloat16]: - qkv = qkv.to(torch.bfloat16) - context = self.inner_attn(qkv=qkv).to(x.dtype) - else: - context = self.inner_attn(qkv=qkv) - - else: - if self.use_dynamic_ntk_rope: - q = qkv[:, :, 0] - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx) - if inference_params.sequence_len_offset != 0: - # q shape: [bsz, 1, nheads, head_dim] - # kv shape: [bsz, seqlen, 2, nheads, head_dim] - bsz, seq_len, _, nheads, head_dim = kv.shape - q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2) - qkv = torch.cat([q, kv], dim=2) - if self.rotary_emb_dim > 0: - qkv = self.rotary_emb(qkv) - q = qkv[:, [-1], 0] - kv = qkv[:, :, 1:] - else: - if inference_params.sequence_len_offset > self.max_position_embeddings: - warnings.warn( - "Notice your prompt's length is longer than model's max_position_embeddings: " - f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations." - ) - if self.rotary_emb_dim > 0: - kwargs["inference_params"] = inference_params - qkv = self.rotary_emb(qkv, **kwargs) - q = qkv[:, :, 0] - kv = qkv[:, :, 1:] - else: - assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2)) - kv = torch.stack([k, v], dim=2) - assert self.rotary_emb_dim > 0, "You should use rotary_emb." - - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - empties = inference_params.attention_mask[..., -1].sum(dim=-1) - if inference_params.sequence_len_offset == 0: - moved_q = q.clone() - moved_k = k.clone() - for i in range(len(empties)): - if empties[i] != 0: - moved_q[i][: -empties[i]] = q[i][empties[i] :] - moved_k[i][: -empties[i]] = k[i][empties[i] :] - moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0) - moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0) - for i in range(len(empties)): - if empties[i] != 0: - q[i][empties[i] :] = moved_q[i][: -empties[i]] - k[i][empties[i] :] = moved_k[i][: -empties[i]] - else: - q[i] = moved_q[i] - k[i] = moved_k[i] - elif not self.use_dynamic_ntk_rope: - if inference_params.sequence_len_offset > self.max_position_embeddings: - warnings.warn( - "Notice your prompt's length is longer than model's max_position_embeddings: " - f"{self.max_position_embeddings}, may cause deviations in dynamic ntk calculations." 
- ) - q = self.rotary_emb._single_forward( - q, - inference_params.sequence_len_offset - * torch.ones(q.size(0), dtype=torch.int, device=q.device) - - empties, - ) - k = self.rotary_emb._single_forward( - k, - inference_params.sequence_len_offset - * torch.ones(k.size(0), dtype=torch.int, device=k.device) - - empties, - ) - else: - q = self.rotary_emb._single_forward( - q, - inference_params.sequence_len_offset - * torch.ones(q.size(0), dtype=torch.int, device=q.device) - - empties, - ) - moved_k = k.clone() - for i in range(len(empties)): - if empties[i] != 0: - moved_k[i][: -empties[i]] = k[i][empties[i] :] - moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0) - for i in range(len(empties)): - if empties[i] != 0: - k[i][empties[i] :] = moved_k[i][: -empties[i]] - else: - k[i] = moved_k[i] - else: - q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset) - k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset) - - kv = torch.stack([k, v], dim=2) - kv = _update_kv_cache(kv, inference_params, self.layer_idx) - - if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: - if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) - attn_mask = inference_params.attention_mask[:, None, ...] - attn_mask = torch.logical_or( - torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask - ) - attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) - cu_seqlens = torch.concat( - [ - torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), - attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), - ], - dim=0, - ) - cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) - max_seqlen_q = attn_mask4flsh.shape[-1] - max_seqlen_k = attn_mask4flsh.shape[-1] - total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) - total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( - -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] - ) - - if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - if total_q.dtype not in [torch.float16, torch.bfloat16]: - total_q = total_q.to(torch.bfloat16) - if total_kv.dtype not in [torch.float16, torch.bfloat16]: - total_kv = total_kv.to(torch.bfloat16) - - try: - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_func, - ) - except ImportError: - try: - from flash_attn.flash_attn_interface import ( - flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func, - ) - except ImportError: - try: - from flash_attn.flash_attn_interface import ( - flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func, - ) - except ImportError: - raise ImportError("Please check your flash_attn version >= 1.0.5.") - - output = flash_attn_unpadded_func( - total_q, - total_kv, - cu_seqlens, - cu_seqlens, - max_seqlen_q, - max_seqlen_k, - 0.0, - None, - True, - False, - ).to(x.dtype) - else: - attn_scores = torch.matmul(total_q, total_kv.transpose(-2, -1)) / (cu_seqlens**0.5) - attn_weights = F.softmax(attn_scores, dim=-1) - output = torch.matmul(attn_weights, total_kv) - - context = torch.zeros_like(q) - context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) - - else: - attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) - - k, v = torch.chunk(kv, 2, dim=2) - k = k.squeeze(2) - v = v.squeeze(2) - sp = k.shape - scores = torch.einsum( - 
"blhd,bnhd->bhln", - q, - k.reshape(sp[0], sp[1], q.size(2), sp[3]), - ) / math.sqrt(q.size(-1)) - scores = scores.masked_fill(attn_mask, -65000.0) - scores = F.softmax(scores, dim=-1) # bsz x h x L x L - context = torch.einsum( - "bhmn,bnhd->bmhd", - scores, - v.reshape(sp[0], sp[1], q.size(2), sp[3]), - ) - else: - context = self.inner_cross_attn(q, kv, causal=True) - - if seqlen is None: - context = rearrange(context, "b s h d -> b s (h d)") - else: - context = rearrange(context, "b s h d -> (b s) (h d)") - - out = self.out_proj(context) - return out - - def _packed_forward(self, x, inference_params=None, **kwargs): - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. - If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we - split x during sequence parallel, we split the batch * seqlen dimension - (in case batch is small). - """ - qkv = self.Wqkv(x) # bsz x total x hsz - qkv = rearrange( - qkv, "b t (three h d) -> b t three h d", three=3, d=self.head_dim - ) # bsz x total x 3 x n_head x d - qkv = self.rotary_emb(qkv, **kwargs) - - kwargs.pop("indexes") - - # for packed data, batch dimension with a size of 1 should be directly squeezed off. - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - qkv = qkv.squeeze(0) - # since torch_npu only supports fa with no packed data currently, qkv should be unpacked - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - qkv = unpack_qkv_before_attn(qkv, kwargs["cu_seqlens"]) - kwargs.pop("cu_seqlens") - kwargs.pop("max_seqlen") - - if inference_params is None: - if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: - with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): - if qkv.dtype not in [torch.float16, torch.bfloat16]: - qkv = qkv.to(torch.bfloat16) - context = self.inner_attn(qkv=qkv, **kwargs).to(x.dtype) - else: - context = self.inner_attn(qkv=qkv, **kwargs) - - else: - raise RuntimeError("Not support this right now") - - if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: - context = rearrange(context, "s h d -> s (h d)") # recover the shape - context = context.unsqueeze(0) # restore bsz dimension - elif internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]: - context = rearrange(context, "b s h d -> b s (h d)") # recover the shape - context = pack_output_after_attn(context, kwargs["cu_seqlens"]) - - out = self.out_proj(context) - - return out diff --git a/internlm/model/modules/norm.py b/internlm/model/modules/norm.py new file mode 100644 index 00000000..b94cdd43 --- /dev/null +++ b/internlm/model/modules/norm.py @@ -0,0 +1,19 @@ +""" +layer norm modules +""" + +from typing import List, Union + +import torch +from torch import nn + +from internlm.model.ops.norm import RMSNorm + +Shape = Union[int, List[int], torch.Size] + + +def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5): + if norm_type == "rmsnorm": + return RMSNorm(normalized_shape, eps) + else: # default: layernorm + return nn.LayerNorm(normalized_shape, eps) diff --git a/internlm/model/modules/utils.py b/internlm/model/modules/utils.py new file mode 100644 index 00000000..dd86cb1c --- /dev/null +++ b/internlm/model/modules/utils.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch +import torch.nn.functional as F +from einops import rearrange + +from internlm.utils.logger 
import get_logger + +logger = get_logger(__file__) + + +def is_moe_param(param: torch.Tensor) -> bool: + if hasattr(param, "is_expert") and param.is_expert: + return True + return False + + +def Silu(w1_o, w2_o): + return F.silu(w1_o) * w2_o + + +Silu = torch.jit.script(Silu) + + +def update_kv_cache(kv, inference_params, layer_idx): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + # Pre-allocate memory for key-values for inference. + num_heads, head_dim = kv.shape[-2:] + if layer_idx not in inference_params.key_value_memory_dict: + kv_cache = torch.empty( + inference_params.max_batch_size, + inference_params.max_sequence_len, + 2, + num_heads, + head_dim, + dtype=kv.dtype, + device=kv.device, + ) + inference_params.key_value_memory_dict[layer_idx] = kv_cache + else: + if not inference_params.fused_ft_kernel: + kv_cache = inference_params.key_value_memory_dict[layer_idx] + else: + # For FT, k_cache has shape (b, h, headdim / packsize, s, packsize) + # where packsize = 4 if fp32, 8 if fp16 or bf16. + # v_cache has shape (b, h, s, headdim) + k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx] + kv_cache = None + # Adjust key and value for inference + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + sequence_start = inference_params.sequence_len_offset + sequence_end = sequence_start + kv.shape[1] + assert batch_end <= (kv_cache.shape[0] if kv_cache is not None else v_cache.shape[0]) + assert sequence_end <= (kv_cache.shape[1] if kv_cache is not None else v_cache.shape[2]) + # Copy key and values. + if not inference_params.fused_ft_kernel: + assert kv_cache is not None + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + kv = kv_cache[batch_start:batch_end, :sequence_end, ...] + return kv + else: + assert inference_params.sequence_len_offset == 0 + # FT kernel requires different layouts for the k_cache and v_cache. + assert kv.dtype in [torch.float16, torch.bfloat16, torch.float32] + packsize = 4 if kv.dtype == torch.float32 else 8 + if kv_cache is not None: + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] 
= kv + k_cache = rearrange( + kv_cache[:, :, 0], + "b s h (d packsize) -> b h d s packsize", + packsize=packsize, + ).contiguous() + v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous() + inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache) + else: + k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange( + kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize + ) + v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(kv[:, :, 1], "b s h d -> b h s d") + return kv diff --git a/internlm/model/moe/__init__.py b/internlm/model/moe/__init__.py index 9ebcea66..e69de29b 100644 --- a/internlm/model/moe/__init__.py +++ b/internlm/model/moe/__init__.py @@ -1,28 +0,0 @@ -from .gshard_layer import GShardMOELayer -from .moe import MoE - -__all__ = ["MoE", "GShardMOELayer"] - - -try: - from megablocks import ops # noqa # pylint: disable=W0611 -except ModuleNotFoundError: - pass -else: - from internlm.model.moe.megablock.megablock_moe import ( # noqa # pylint: disable=W0611 - MegaBlockMoE, - ) - - __all__ += "MegaBlockMoE" - -try: - import stk # noqa # pylint: disable=W0611 - from megablocks import ops # noqa # pylint: disable=W0611 -except ModuleNotFoundError: - pass -else: - from internlm.model.moe.megablock.megablock_dmoe import ( # noqa # pylint: disable=W0611 - MegaBlockdMoE, - ) - - __all__ += "MegaBlockdMoE" diff --git a/internlm/model/moe/gshard_layer.py b/internlm/model/moe/gshard_layer.py index ee03d781..e84abc88 100644 --- a/internlm/model/moe/gshard_layer.py +++ b/internlm/model/moe/gshard_layer.py @@ -15,9 +15,9 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.model.modules.mlp import new_feed_forward from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.registry import MODEL_INITIALIZER from .base_layer import BaseMoELayer from .utils import all_to_all @@ -436,7 +436,6 @@ def forward( return gate_output -@MODEL_INITIALIZER.register_module(module_name="GShard") class GShardMOELayer(BaseMoELayer): """MOELayer module which implements MixtureOfExperts as described in Gshard_. 
:: @@ -461,7 +460,6 @@ def __init__( hidden_features: int, out_features: int, num_experts: int, - ep_cls: Optional[Callable], ep_group: Optional[torch.distributed.ProcessGroup], ep_size: int, top_k: int = 1, @@ -496,11 +494,10 @@ def __init__( ), torch.nn.ModuleList( [ - ep_cls( + new_feed_forward( in_features, hidden_features, out_features, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, diff --git a/internlm/model/moe/megablock/megablock_dmoe.py b/internlm/model/moe/megablock/megablock_dmoe.py index 88e7e806..8c44baf9 100644 --- a/internlm/model/moe/megablock/megablock_dmoe.py +++ b/internlm/model/moe/megablock/megablock_dmoe.py @@ -1,9 +1,7 @@ from typing import Optional, Tuple import numpy as np -import stk import torch -from megablocks import ops from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc @@ -11,10 +9,15 @@ from internlm.model.moe.megablock.megablock_moe import MegaBlockMoE from internlm.model.moe.megablock.mlp import MegaBlockGroupedFeedForward from internlm.model.moe.megablock.utils import promote_scalar -from internlm.utils.registry import MODEL_INITIALIZER + +try: + import stk + from megablocks import ops +except (ModuleNotFoundError, ImportError): + stk = None + ops = None -@MODEL_INITIALIZER.register_module(module_name="MegaBlock-D") class MegaBlockdMoE(MegaBlockMoE): """ Built on the paper and library Megablocks as described in @@ -30,10 +33,12 @@ class MegaBlockdMoE(MegaBlockMoE): def __init__( # pylint: disable=W0231 self, - hidden_size: int, + in_features: int, + hidden_features: int, + out_features: int, + num_experts: int, ep_group: Optional[torch.distributed.ProcessGroup], ep_size: int, - num_experts: int, top_k: int = 1, parallel_mode: str = "tensor", device: Optional[torch.device] = None, @@ -41,11 +46,12 @@ def __init__( # pylint: disable=W0231 multiple_of: int = 256, ) -> None: assert gpc.expert_parallel_size == 1, "do not support expert parallel" + assert ops is not None and stk is not None, "MegaBlocks not found, please run " '"pip install megablocks".' 
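The MegaBlock layers above defer their stk/megablocks imports and only fail when such a layer is actually constructed. As a rough standalone sketch of that pattern, with a placeholder dependency name that is not part of this diff:

# Minimal sketch of the deferred-import guard used above; "some_optional_dep"
# is a hypothetical name, not a real requirement of this repository.
try:
    import some_optional_dep
except (ModuleNotFoundError, ImportError):
    some_optional_dep = None


class OptionalFeature:
    def __init__(self):
        # The missing dependency surfaces only when the optional feature is used.
        assert some_optional_dep is not None, 'some_optional_dep not found, please run "pip install some_optional_dep".'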
self.top_k = top_k self.num_experts = num_experts tp_size = gpc.get_world_size(ParallelMode.TENSOR) - self.ffn_dim = multiple_of * ((int(hidden_size * gpc.config.model.mlp_ratio) + multiple_of - 1) // multiple_of) + self.ffn_dim = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) assert self.ffn_dim % tp_size == 0 if parallel_mode == "tensor": self.ffn_dim_per_row = self.ffn_dim // tp_size // ep_size @@ -53,10 +59,11 @@ def __init__( # pylint: disable=W0231 self.ffn_dim_per_row = self.ffn_dim // ep_size BaseMoELayer.__init__( # pylint: disable=W0233 self, - torch.nn.Linear(hidden_size, num_experts, bias=False), + torch.nn.Linear(in_features, num_experts, bias=False), MegaBlockGroupedFeedForward( - hidden_size, + in_features, (self.ffn_dim // tp_size) * (num_experts // ep_size), + out_features, parallel_mode, device, dtype, @@ -111,7 +118,7 @@ def sparse_transpose( offsets_t = torch.cat([zero, nnz_per_column]) return column_indices_t, offsets_t, block_offsets_t - def topology(self, x: torch.Tensor, padded_bins: torch.Tensor) -> stk.Matrix: + def topology(self, x: torch.Tensor, padded_bins: torch.Tensor): padded_tokens, _ = x.size() assert padded_tokens % self.blocking == 0 assert self.ffn_dim_per_row % self.blocking == 0 diff --git a/internlm/model/moe/megablock/megablock_moe.py b/internlm/model/moe/megablock/megablock_moe.py index 202a5088..2afbac26 100644 --- a/internlm/model/moe/megablock/megablock_moe.py +++ b/internlm/model/moe/megablock/megablock_moe.py @@ -3,17 +3,19 @@ import numpy as np import torch import torch.nn.functional as F -from megablocks import ops from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.model.moe.base_layer import BaseMoELayer from internlm.model.moe.megablock.mlp import MegaBlockFeedForward from internlm.model.moe.utils import all_to_all -from internlm.utils.registry import MODEL_INITIALIZER + +try: + from megablocks import ops +except (ModuleNotFoundError, ImportError): + ops = None -@MODEL_INITIALIZER.register_module(module_name="MegaBlock") class MegaBlockMoE(BaseMoELayer): """ Built on the paper and library Megablocks as described in @@ -29,10 +31,12 @@ class MegaBlockMoE(BaseMoELayer): def __init__( self, - hidden_size: int, + in_features: int, + hidden_features: int, + out_features: int, + num_experts: int, ep_group: Optional[torch.distributed.ProcessGroup], ep_size: int, - num_experts: int, top_k: int = 1, capacity_factor: float = 1.0, drop_tokens: bool = True, @@ -41,19 +45,21 @@ def __init__( multiple_of: int = 256, ) -> None: assert not gpc.config.parallel.sequence_parallel, "do not support sequence parallel" + assert ops is not None, 'MegaBlocks not found, please run "pip install megablocks".' 
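For reference, the hidden-size rounding shared by FeedForward and the MegaBlock layers rounds hidden_features up to the next multiple of multiple_of. A small standalone check, with values chosen only for illustration:

def _round_up(hidden_features: int, multiple_of: int = 256) -> int:
    # Same arithmetic as the ffn_dim / hidden_features rounding in this diff.
    return multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)


assert _round_up(11000) == 11008  # rounded up to the next multiple of 256
assert _round_up(11008) == 11008  # already aligned, left unchanged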
self.top_k = top_k self.num_experts = num_experts tp_size = gpc.get_world_size(ParallelMode.TENSOR) - self.ffn_dim = multiple_of * ((int(hidden_size * gpc.config.model.mlp_ratio) + multiple_of - 1) // multiple_of) + self.ffn_dim = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) self.capacity_factor = capacity_factor self.drop_tokens = drop_tokens assert self.ffn_dim % tp_size == 0 super().__init__( - torch.nn.Linear(hidden_size, num_experts, bias=False), + torch.nn.Linear(in_features, num_experts, bias=False), MegaBlockFeedForward( - hidden_size, + in_features, self.ffn_dim // tp_size, + out_features, num_experts // ep_size, device, dtype, diff --git a/internlm/model/moe/megablock/mlp.py b/internlm/model/moe/megablock/mlp.py index 3ac8913b..4c68e5cc 100644 --- a/internlm/model/moe/megablock/mlp.py +++ b/internlm/model/moe/megablock/mlp.py @@ -3,13 +3,13 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.model.modules.utils import Silu from internlm.model.moe.megablock.utils import ( act_fn, dsd_nn, sdd_nt, tensor_parallel_bmm, ) -from internlm.model.utils import Silu class MegaBlockFeedForward(nn.Module): @@ -21,6 +21,7 @@ def __init__( self, in_features: int, hidden_features: int, + out_features: int, num_local_experts: int, device=None, dtype=None, @@ -29,8 +30,10 @@ def __init__( # merged expert weights, all of size (ffn_dim * n_experts, model_dim) self.w1 = nn.Parameter(torch.empty(num_local_experts, in_features, hidden_features, device=device, dtype=dtype)) - self.w2 = nn.Parameter(torch.empty(num_local_experts, in_features, hidden_features, device=device, dtype=dtype)) - self.w3 = nn.Parameter(torch.empty(num_local_experts, hidden_features, in_features, device=device, dtype=dtype)) + self.w3 = nn.Parameter(torch.empty(num_local_experts, in_features, hidden_features, device=device, dtype=dtype)) + self.w2 = nn.Parameter( + torch.empty(num_local_experts, hidden_features, out_features, device=device, dtype=dtype) + ) def forward(self, x): # TODO w2 and w3 should swap @@ -51,6 +54,7 @@ def __init__( self, in_features: int, hidden_features: int, + out_features: int, parallel_mode="tensor", device=None, dtype=None, @@ -59,7 +63,7 @@ def __init__( # merged expert weights, all of size (ffn_dim * n_experts, model_dim) self.w1 = nn.Parameter(torch.empty(hidden_features, in_features, device=device, dtype=dtype)) - self.w2 = nn.Parameter(torch.empty(hidden_features, in_features, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.empty(hidden_features, out_features, device=device, dtype=dtype)) self.w3 = nn.Parameter(torch.empty(hidden_features, in_features, device=device, dtype=dtype)) self.parallel_mode = parallel_mode diff --git a/internlm/model/moe/megablock/utils.py b/internlm/model/moe/megablock/utils.py index 2c890e01..857dd8b7 100644 --- a/internlm/model/moe/megablock/utils.py +++ b/internlm/model/moe/megablock/utils.py @@ -1,9 +1,7 @@ -import sys - import torch from internlm.accelerator import get_accelerator -from internlm.model.utils import Silu +from internlm.model.modules.utils import Silu try: import stk @@ -366,26 +364,3 @@ def act_fn(x1, x2, topo): ) return y - - -# check dependency -def check_megablock_installed(): - try: - from megablocks import ops # noqa # pylint: disable=W0611 - except ModuleNotFoundError: - print( - "MegaBlocks not found, please see " - "https://github.com/stanford-futuredata/megablocks/. 
" - "Note that MegaBlocks depends on mosaicml-turbo, which only " - "supports python 3.10.", - flush=True, - ) - sys.exit() - - -def check_stk_installed(): - try: - import stk # noqa # pylint: disable=W0611 - except ModuleNotFoundError: - print("STK not found: please see https://github.com/stanford-futuredata/stk", flush=True) - sys.exit() diff --git a/internlm/model/moe/moe.py b/internlm/model/moe/moe.py index 392cca89..304d8d0a 100644 --- a/internlm/model/moe/moe.py +++ b/internlm/model/moe/moe.py @@ -1,16 +1,29 @@ -from typing import Callable, Optional +from typing import Optional import torch -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.moe.gshard_layer import GShardMOELayer +from internlm.model.moe.megablock.megablock_dmoe import MegaBlockdMoE +from internlm.model.moe.megablock.megablock_moe import MegaBlockMoE from internlm.utils.logger import get_logger -from internlm.utils.registry import MODEL_INITIALIZER # global llm logger logger = get_logger(__file__) +def new_moe_layer(moe_type: str, **kwargs): + if moe_type == "GShard": + return GShardMOELayer(**kwargs) + elif moe_type == "MegaBlock": + return MegaBlockMoE(**kwargs) + elif moe_type == "MegaBlock-D": + return MegaBlockdMoE(**kwargs) + else: + raise ValueError(f"Unsupported model type: {moe_type}") + + class MoE(torch.nn.Module): """Initialize an MoE layer. @@ -38,7 +51,6 @@ def __init__( in_features: int, hidden_features: int, out_features: int, - ep_cls: Optional[Callable], ep_group: Optional[torch.distributed.ProcessGroup], num_experts: int = 1, ep_size=1, @@ -52,27 +64,26 @@ def __init__( if not hasattr(gpc.config, "moe"): gpc.config.moe = dict() - self.moe_layer = MODEL_INITIALIZER.get_module(module_name=gpc.config.model.moe_type)( + self.moe_layer = new_moe_layer( + moe_type=gpc.config.model.moe_type, in_features=in_features, hidden_features=hidden_features, out_features=out_features, num_experts=num_experts, - ep_cls=ep_cls, ep_group=ep_group, ep_size=ep_size, device=device, dtype=dtype, - **(gpc.config.moe) + **(gpc.config.moe), ) # residual network, see https://arxiv.org/pdf/2201.05596.pdf, seems useful for convergence self.use_residual = use_residual if self.use_residual: - self.residual_mlp = ep_cls( + self.residual_mlp = new_feed_forward( in_features=in_features, hidden_features=hidden_features, out_features=out_features, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=dtype, diff --git a/internlm/model/ops/attention.py b/internlm/model/ops/attention.py new file mode 100644 index 00000000..9205652a --- /dev/null +++ b/internlm/model/ops/attention.py @@ -0,0 +1,847 @@ +""" +A simple operator selector, used for compatibility with different platforms such as CUDA and Ascend, +as well as whether to enable flash-attn operator optimization, may be replaced by a more comprehensive +operator compatibility layer in the future. + +This file implements support for the attention operators. 
+""" + +import math +from typing import Callable, Tuple + +import torch +from einops import rearrange, repeat +from torch import nn + +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.core.context import global_context as gpc +from internlm.core.parallel.comm.isp import auto_wrap_distributed_attention +from internlm.model.ops.utils import pack_output_after_attn, unpack_qkv_before_attn +from internlm.utils.common import get_current_device +from internlm.utils.utils import ( + CuSeqlenType, + QKVPackType, + check_attention_argument, + params_dispatch_with_condition, +) + +try: + from torch_npu import npu_fusion_attention as _origin_npu_fixedlen_qkvsplited_func + + is_torch_npu = True +except (ModuleNotFoundError, ImportError): + is_torch_npu = False + +try: + # TODO: add support of deeplink + from deeplink_ext.internevo_ops import FlashCrossAttention, FlashSelfAttention + + del FlashCrossAttention, FlashSelfAttention + + deeplink_flash_attn_impl = True +except (ModuleNotFoundError, ImportError): + deeplink_flash_attn_impl = False + +try: + from flash_attn.flash_attn_interface import ( + flash_attn_func as _flash_fixedlen_qkvsplited_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_kvpacked_func as _flash_fixedlen_kvpacked_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_qkvpacked_func as _flash_fixedlen_qkvpacked_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_func as _flash_varlen_qkvsplited_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_kvpacked_func as _flash_varlen_kvpacked_func, + ) + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_qkvpacked_func as _flash_varlen_qkvpacked_func, + ) + + gpu_flash_attn_impl = True +except (ModuleNotFoundError, ImportError): + gpu_flash_attn_impl = False + +internlm_accelerator = get_accelerator() +device_backend = internlm_accelerator.get_accelerator_backend() + + +def _nyi_attn(func_name, *args, **kwargs): # pylint: disable=W0613 + assert False, f"{func_name} is not yet implemented" + + +# gpu flash attention operators + + +def _flash_float32_compatibility_wrapper(input_idxs: Tuple, flash_func: Callable, *args, **kwargs): + if gpc.config.model.dtype is torch.float32: + inputs = (args[idx] for idx in input_idxs) + input_dtype = inputs[0].dtype + other_args = [args[idx] for idx in range(len(inputs), len(args))] + + with internlm_accelerator.amp.autocast(dtype=torch.bfloat16): + for idx in input_idxs: + if inputs[idx].dtype is torch.float32: + inputs[idx] = inputs[idx].to(torch.bfloat16) + return flash_func(*inputs, *other_args, **kwargs).to(input_dtype) + + return flash_func(*args, **kwargs) + + +def _flash_varlen_qkvpacked_attn( + qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, softmax_scale=None, causal=False +): + # compatible data format: [1, packelen, 3, n_head, headim] + qkv = qkv.squeeze(dim=0) + + # input_idxs: 0: qkv + output = _flash_float32_compatibility_wrapper( + (0), _flash_varlen_qkvpacked_func, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal + ) + + return output.unsqueeze(dim=0) + + +def _flash_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout_p=0.0, softmax_scale=None, causal=False): + # input_idxs: 0: qkv + return _flash_float32_compatibility_wrapper( + (0), _flash_fixedlen_qkvpacked_func, qkv, dropout_p, softmax_scale, causal + ) + + +def _flash_varlen_kvpacked_attn( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + 
max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, +): + # compatible data format: [1, packelen, 3, n_head, headim] + q, kv = q.squeeze(dim=0), kv.squeeze(dim=0) + + # input_idxs: 0: q, 1: kv + output = _flash_float32_compatibility_wrapper( + (0, 1), + _flash_varlen_kvpacked_func, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + ) + + return output.unsqueeze(dim=0) + + +def _flash_fixedlen_kvpacked_attn(q: torch.Tensor, kv: torch.Tensor, dropout_p=0.0, softmax_scale=None, causal=False): + # input_idxs: 0: q, 1: kv + return _flash_float32_compatibility_wrapper( + (0, 1), _flash_fixedlen_kvpacked_func, q, kv, dropout_p, softmax_scale, causal + ) + + +def _flash_varlen_qkvsplited_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, +): + # compatible data format: [1, packelen, 3, n_head, headim] + q, k, v = q.squeeze(dim=0), k.squeeze(dim=0), v.squeeze(dim=0) + + # input_idxs: 0: q, 1: k, 2: v + output = _flash_float32_compatibility_wrapper( + (0, 1, 2), + _flash_varlen_qkvsplited_func, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + ) + + return output.unsqueeze(dim=0) + + +def _flash_fixedlen_qkvsplited_attn(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False): + # input_idxs: 0: q, 1: k, 2: v + return _flash_float32_compatibility_wrapper( + (0, 1, 2), _flash_fixedlen_qkvsplited_func, q, k, v, dropout_p, softmax_scale, causal + ) + + +# npu flash attention operators +# TODO: should we add _flash_float32_compatibility_wrapper support for npu. + + +def _npu_varlen_qkvsplited_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, # pylint: disable=W0613 + max_seqlen_k, # pylint: disable=W0613 + dropout_p=0.0, + softmax_scale=None, + causal=False, +): + # TODO: support npu native varlen flash attention + packed_length = q.size(dim=1) + + q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q) + k = unpack_qkv_before_attn(k, cu_seqlens=cu_seqlens_k) + v = unpack_qkv_before_attn(v, cu_seqlens=cu_seqlens_k) + + output = _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal) + + return pack_output_after_attn(output, cu_seqlens_q, packed_length) + + +def _npu_fixedlen_qkvsplited_attn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + dropout_p: float, + softmax_scale=None, + causal=False, +): + assert causal is True + assert q.dtype in (torch.bfloat16, torch.float16) + + if len(q.shape) == 5: # [batch, seqlen, 1, n_head, headdim] + q, k, v = q.squeeze(dim=2), k.squeeze(dim=2), v.squeeze(dim=2) + + _, seqlen, n_head, _ = q.shape + attention_mask = torch.triu(torch.ones(seqlen, seqlen, device=get_current_device()), 1).bool() + + return _origin_npu_fixedlen_qkvsplited_func( + query=q, + key=k, + value=v, + head_num=n_head, + input_layout="BSND", # If necessary, expose the interface + pse=None, + atten_mask=attention_mask, + scale=softmax_scale, + sparse_mode=0, # If necessary, expose the interface + pre_tockens=seqlen, # Used for sparse calculations, representing the left boundary of the slides window + next_tockens=0, # If necessary, expose the interface + keep_prob=1 - dropout_p, + inner_precise=0, # If necessary, expose the interface + ) + + +def _npu_varlen_qkvpacked_attn( + qkv: torch.Tensor, cu_seqlens, max_seqlen, dropout_p, 
softmax_scale=None, causal=False # pylint: disable=W0613 +): + # TODO: support npu native varlen flash attention + packed_length = qkv.size(dim=1) + + qkv = unpack_qkv_before_attn(qkv, cu_seqlens=cu_seqlens) + + output = _npu_fixedlen_qkvpacked_attn(qkv, dropout_p, softmax_scale, causal) + + return pack_output_after_attn(output, cu_seqlens, packed_length) + + +def _npu_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False): + q, k, v = qkv.unbind(dim=2) + return _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal) + + +def _npu_varlen_kvpacked_attn( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, # pylint: disable=W0613 + max_seqlen_k, # pylint: disable=W0613 + dropout_p=0.0, + softmax_scale=None, + causal=False, +): + # TODO: support npu native varlen flash attention + packed_length = q.size(dim=1) + + q = unpack_qkv_before_attn(q, cu_seqlens=cu_seqlens_q) + kv = unpack_qkv_before_attn(kv, cu_seqlens=cu_seqlens_k) + + output = _npu_fixedlen_kvpacked_attn(q, kv, dropout_p, softmax_scale, causal) + + return pack_output_after_attn(output, cu_seqlens_q, packed_length) + + +def _npu_fixedlen_kvpacked_attn(q: torch.Tensor, kv: torch.Tensor, dropout_p: float, softmax_scale=None, causal=False): + k, v = kv.unbind(dim=2) + k, v = k.squeeze(dim=2), v.squeeze(dim=2) + return _npu_fixedlen_qkvsplited_attn(q, k, v, dropout_p, softmax_scale, causal) + + +# deeplink flash attention operators + + +def _deeplink_varlen_qkvpacked_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_varlen_qkvpacked_attn", *args, **kwargs) + + +def _deeplink_fixedlne_qkvpacked_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_fixedlne_qkvpacked_attn", *args, **kwargs) + + +def _deeplink_varlen_kvpacked_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_varlen_kvpacked_attn", *args, **kwargs) + + +def _deeplink_fixedlen_kvpacked_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_fixedlen_kvpacked_attn", *args, **kwargs) + + +def _deeplink_varlen_qkvsplited_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_varlen_qkvsplited_attn", *args, **kwargs) + + +def _deeplink_fixedlen_qkvsplited_attn(*args, **kwargs): + # TODO: support deeplink version flash attention + _nyi_attn("_deeplink_fixedlen_qkvsplited_attn", *args, **kwargs) + + +# torch attention operators + + +def _torch_varlen_qkvpacked_attn(*args, **kwargs): + _nyi_attn("_torch_varlen_qkvpacked_attn", *args, **kwargs) + + +# adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py +def _torch_fixedlen_qkvpacked_attn(qkv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None): + batch_size, seqlen = qkv.shape[0], qkv.shape[1] + q, k, v = qkv.unbind(dim=2) + + softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + + if key_padding_mask is not None: + padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + + if causal: + # "triu_tril_cuda_template" not implemented for 
'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = dropout(attention) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + + return output + + +def _torch_varlen_kvpacked_attn(*args, **kwargs): + _nyi_attn("_torch_varlen_kvpacked_attn", *args, **kwargs) + + +# adpated from https://github.com/Dao-AILab/flash-attention/blob/v2.2.1/flash_attn/modules/mha.py +def _torch_fixedlen_kvpacked_attn( + q: torch.Tensor, kv: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None +): + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = kv.shape[1] + + assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] + if kv.shape[3] != q.shape[2]: # MQA/GQA + kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3]) + k, v = kv.unbind(dim=2) + softmax_scale = softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if key_padding_mask is not None: + padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + + if causal: + # causal mask needs to take into account the difference between seqlen_q and seqlen_k + row_idx = rearrange(torch.arange(seqlen_q, device=q.device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + causal_mask = col_idx > row_idx + sk - seqlen_q + scores = scores.masked_fill(causal_mask, -10000.0) + + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = dropout(attention) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + + return output + + +def _torch_varlen_qkvsplited_attn(*args, **kwargs): + _nyi_attn("_torch_varlen_qkvsplited_attn", *args, **kwargs) + + +def _torch_fixedlen_qkvsplited_attn( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, dropout, softmax_scale=None, causal=False, key_padding_mask=None +): + kv = torch.stack([k, v], dim=2) + return _torch_fixedlen_kvpacked_attn(q, kv, dropout, softmax_scale, causal, key_padding_mask) + + +@auto_wrap_distributed_attention +class SelfAttention(nn.Module): + """Implements scaled dot-product attention with optional softmax scaling. + + This class implements the scaled dot-product attention mechanism, which can be optionally scaled + by a softmax scaling factor. It supports configurations for causal attention and applies dropout + to the attention scores. + + Arguments: + causal (bool): If True, applies causal attention to mask future tokens. Defaults to False. + softmax_scale (Optional[float]): Scaling factor for attention scores before applying softmax. + Defaults to 1/sqrt(d_keys) where d_keys is the dimension of the keys, computed at runtime. + attention_dropout (float): Dropout rate for attention scores. Defaults to 0.0. 
+ """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = nn.Dropout(attention_dropout) + + if device_backend == AcceleratorType.NPU: + assert self.causal, "Ascend flash attention does not spport causal=False yet!" + + @params_dispatch_with_condition(condition=check_attention_argument) + def forward(self): + """Placeholder for multihead softmax attention implementation. + + This method serves as a placeholder and should not be reached during execution. It is expected + to be overridden by specific implementations for different attention mechanisms. + + Raises: + AssertionError: Always raised to indicate the method should not be called directly. + """ + assert False, "Never arrive here" + + @forward.register(conditions=(str(QKVPackType.QKVPACKED), str(CuSeqlenType.WithOut))) + def _qkv_without_cu_seqlens(self, qkv, softmax_scale=None, causal=None, key_padding_mask=None): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_fixedlen_qkvpacked_attn(qkv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_fixedlen_qkvpacked_attn(qkv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_fixedlne_qkvpacked_attn(qkv, self.dropout.p, softmax_scale, causal) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_fixedlen_qkvpacked_attn(qkv, self.dropout, softmax_scale, causal, key_padding_mask) + + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.WithOut))) + def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_padding_mask=None): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_fixedlen_kvpacked_attn(q, kv, self.dropout, softmax_scale, causal, key_padding_mask) + + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.WithOut))) + def _q_k_v_without_cu_seqlens(self, q, k, v, softmax_scale=None, causal=None, key_padding_mask=None): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return 
_npu_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_fixedlen_qkvsplited_attn(q, k, v, self.dropout, softmax_scale, causal, key_padding_mask) + + @forward.register(conditions=(str(QKVPackType.QKVPACKED), str(CuSeqlenType.With))) + def _qkv_with_cu_seqlens( + self, + qkv, + cu_seqlens, + max_seqlen, + softmax_scale=None, + causal=None, + key_padding_mask=None, + ): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_varlen_qkvpacked_attn(qkv, cu_seqlens, max_seqlen, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_varlen_qkvpacked_attn(qkv, cu_seqlens, max_seqlen, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_varlen_qkvpacked_attn( + qkv, cu_seqlens, max_seqlen, self.dropout.p, softmax_scale, causal + ) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_varlen_qkvpacked_attn( + qkv, cu_seqlens, max_seqlen, self.dropout, softmax_scale, causal, key_padding_mask + ) + + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.With))) + def _q_kv_with_cu_seqlens( + self, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=None, + key_padding_mask=None, + ): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_varlen_kvpacked_attn( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout, + softmax_scale, + causal, + key_padding_mask, + ) + + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.With))) + def _q_k_v_with_cu_seqlens( + self, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=None, + key_padding_mask=None, + ): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and 
gpu_flash_attn_impl: + return _flash_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout, + softmax_scale, + causal, + key_padding_mask, + ) + + +@auto_wrap_distributed_attention +class CrossAttention(nn.Module): + """Implements scaled dot product attention with softmax. + + This class provides the functionality for cross attention mechanism using scaled dot product attention + with optional softmax scaling and dropout for attention weights. + + Arguments: + causal (bool): If True, applies causality to prevent tokens from attending to future tokens. Default is False. + softmax_scale (float, optional): The scaling factor to apply to the dot products before softmax. If None, + it defaults to 1/sqrt(d_keys) where d_keys is the dimension of the keys, computed at runtime. + attention_dropout (float): The dropout rate to apply to the attention. + + Raises: + AssertionError: If `device_backend` is NPU and `causal` is False, since Ascend flash attention does not + support non-causal attention yet. + """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = nn.Dropout(attention_dropout) + + if device_backend == AcceleratorType.NPU: + assert self.causal, "Ascend flash attention does not support causal=False yet!" + + @params_dispatch_with_condition(condition=check_attention_argument) + def forward(self): + """Placeholder for cross attention implementation. + + This method is a placeholder and should not be reached in execution as it is expected to be + overridden by specific implementations for different attention parameters. + + Raises: + AssertionError: Always raised to indicate the method should not be called directly. 
+ """ + assert False, "Never arrive here" + + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.WithOut))) + def _q_kv_without_cu_seqlens(self, q, kv, softmax_scale=None, causal=None, key_padding_mask=None): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_fixedlen_kvpacked_attn(q, kv, self.dropout.p, softmax_scale, causal) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_fixedlen_kvpacked_attn(q, kv, self.dropout, softmax_scale, causal, key_padding_mask) + + @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.WithOut))) + def _q_k_v_without_cu_seqlens(self, q, k, v, softmax_scale=None, causal=None, key_padding_mask=None): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_fixedlen_qkvsplited_attn(q, k, v, self.dropout.p, softmax_scale, causal) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_fixedlen_qkvsplited_attn(q, k, v, self.dropout, softmax_scale, causal, key_padding_mask) + + @forward.register(conditions=(str(QKVPackType.KVPACKED), str(CuSeqlenType.With))) + def _q_kv_with_cu_seqlens( + self, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=None, + key_padding_mask=None, + ): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_varlen_kvpacked_attn( + q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, self.dropout.p, softmax_scale, causal + ) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_varlen_kvpacked_attn( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout, + softmax_scale, + causal, + key_padding_mask, + ) + 
+ @forward.register(conditions=(str(QKVPackType.QKVSPLITED), str(CuSeqlenType.With))) + def _q_k_v_with_cu_seqlens( + self, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale=None, + causal=None, + key_padding_mask=None, + ): + softmax_scale = self.softmax_scale if softmax_scale is None else softmax_scale + causal = self.causal if causal is None else causal + + if gpc.config.model.get("use_flash_attn", False): + if device_backend == AcceleratorType.GPU and gpu_flash_attn_impl: + return _flash_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + elif device_backend == AcceleratorType.NPU and is_torch_npu: + return _npu_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + elif device_backend == AcceleratorType.DIPU and deeplink_flash_attn_impl: + return _deeplink_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout.p, + softmax_scale, + causal, + ) + else: + raise NotImplementedError(f"Unsupported device type: {device_backend} for flash attention") + else: + return _torch_varlen_qkvsplited_attn( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + self.dropout, + softmax_scale, + causal, + key_padding_mask, + ) diff --git a/internlm/model/ops/cross_entropy.py b/internlm/model/ops/cross_entropy.py new file mode 100644 index 00000000..f3fdccf9 --- /dev/null +++ b/internlm/model/ops/cross_entropy.py @@ -0,0 +1,60 @@ +""" +A simple operator selector, used for compatibility with different platforms such as CUDA and Ascend, +as well as whether to enable flash-attn operator optimization, may be replaced by a more comprehensive +operator compatibility layer in the future. + +This file implements support for the cross entropy operators. +""" + +from torch import nn + +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.utils.logger import get_logger + +try: + from flash_attn.losses.cross_entropy import ( + CrossEntropyLoss as FlashCrossEntropyLoss, + ) + + flash_cross_entropy_impl = True +except (ModuleNotFoundError, ImportError): + flash_cross_entropy_impl = False + +logger = get_logger(__file__) +internlm_accelerator = get_accelerator() + + +# TODO: should the ops be implemented in a more unified form? +def new_cross_entropy( + ignore_index: int = -100, + reduction: str = "mean", + label_smoothing: float = 0, + parallel_output: bool = False, + **kwargs, +): + if parallel_output: + assert ( + gpc.config.model.get("use_flash_attn", False) and flash_cross_entropy_impl + ), "Only flash cross entropy supports parallel_output" + assert ( + internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU + ), "flash cross entropy only supports gpu backend" + + return FlashCrossEntropyLoss( + ignore_index=ignore_index, + reduction=reduction, + label_smoothing=label_smoothing, + process_group=gpc.get_group(ParallelMode.TENSOR), + ) + else: + if gpc.is_rank_for_log(): + logger.warning( + "Use nn.CrossEntropyLoss rather than flashattn CrossEntropyLoss. " + "parallel_output must be set false. Please note this!"
+ ) + kwargs.pop("inplace_backward", None) + return nn.CrossEntropyLoss( + ignore_index=ignore_index, reduction=reduction, label_smoothing=label_smoothing, **kwargs + ) diff --git a/internlm/model/ops/fusion_ops_import_helper.py b/internlm/model/ops/fusion_ops_import_helper.py deleted file mode 100644 index f75ff889..00000000 --- a/internlm/model/ops/fusion_ops_import_helper.py +++ /dev/null @@ -1,211 +0,0 @@ -from typing import Callable, Tuple, Union - -import torch -from torch import nn - -from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.context import ParallelMode -from internlm.core.context import global_context as gpc -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - -internlm_accelerator = get_accelerator() - - -# RMSNorm -def try_import_RMSNorm(): - """ - Try import MixFusedRMSNorm from apex, if failed, return our RMSNorm - - """ - try: - device_backend = internlm_accelerator.get_accelerator_backend() - if device_backend == AcceleratorType.DIPU: - from deeplink_ext.internevo_ops import MixedFusedRMSNorm as RMSNorm - - if gpc.is_rank_for_log(): - logger.warning("Use Deeplink MixedFusedRMSNorm, Please note this!") - - return RMSNorm - else: - from apex.normalization.fused_layer_norm import MixedFusedRMSNorm as RMSNorm - - if gpc.is_rank_for_log(): - logger.warning("Use apex MixedFusedRMSNorm, Please note this!") - - return RMSNorm - except (ModuleNotFoundError, ImportError): - if gpc.is_rank_for_log(): - logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!") - from internlm.model.ops.norm import RMSNormTorch as RMSNorm - - return RMSNorm - - -# RotaryEmb -def try_import_fused_rotary() -> Tuple[Union[None, Callable], Union[None, Callable], Union[None, Callable]]: - """try_import_fused_rotary - - Returns: - Tuple[Union[None, Callable], Union[None, Callable], Union[None, Callable]]: - Returns if there is a mixing operator available, otherwise returns None. - """ - try: - device_backend = internlm_accelerator.get_accelerator_backend() - if device_backend is AcceleratorType.GPU: - import rotary_emb - - if gpc.is_rank_for_log(): - logger.warning("Use flash_attn rotary_emb, Please note this!") - - return None, None, rotary_emb.apply_rotary - elif device_backend is AcceleratorType.DIPU: - from deeplink_ext.internevo_ops import ( - ApplyRotaryEmb as DeeplinkApplyRotaryEmb, - ) - from deeplink_ext.internevo_ops import ( - ApplyRotaryEmbQKV_ as DeeplinkApplyRotaryEmbQKV_, - ) - - if gpc.is_rank_for_log(): - logger.warning("Use Deeplink ApplyRotaryEmb, Please note this!") - - return DeeplinkApplyRotaryEmb.apply, DeeplinkApplyRotaryEmbQKV_.apply, None - - except (ModuleNotFoundError, ImportError): - pass - - if gpc.is_rank_for_log(): - logger.warning( - "The torch implementation for apply_rotary is slower" "than flash atten rotary_emb. Please note this!" 
- ) - return None, None, None - - -# CrossEntropyLoss -def internlm_init_CrossEntropyLoss( - parallel_output: bool, reduction="none", label_smoothing=0, inplace_backward=True, process_group=None, **kwargs -): - """ - Try import FlashCrossEntropyLoss from flash_attn, if failed, return our CrossEntropyLoss - - """ - if parallel_output: - try: - if internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU: - from flash_attn.losses.cross_entropy import ( - CrossEntropyLoss as FlashCrossEntropyLoss, - ) - - if process_group is None: - gpc.get_group(ParallelMode.TENSOR) - - if gpc.is_rank_for_log(): - logger.warning("Use flash_attn FlashCrossEntropyLoss, Please note this!") - - return FlashCrossEntropyLoss( - reduction=reduction, - inplace_backward=inplace_backward, - process_group=process_group, - label_smoothing=label_smoothing, - **kwargs, - ) - except (ModuleNotFoundError, ImportError): - pass - - if gpc.is_rank_for_log(): - logger.warning( - "Use nn.CrossEntropyLoss rather than CrossEntropyLoss." - "parallel_output must be set false. Please note this!" - ) - - if "process_group" in kwargs: - kwargs.pop("process_group") - if "inplace_backward" in kwargs: - kwargs.pop("inplace_backward") - - return nn.CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing, **kwargs) - - -# Adamw -def try_import_FusedAdamW(): - """ - Try import FusedAdamW from torch_npu/torch - - """ - adam_extra_kwargs = {} - backend = internlm_accelerator.get_accelerator_backend() - try: - if backend is AcceleratorType.GPU: - if torch.__version__ >= "2.1.0": - adam_extra_kwargs["fused"] = True - - if gpc.is_rank_for_log(): - logger.warning( - "Use fused AdamaW to avoid nan grad norm when " - "model size is larger and use_fp32_norm=True, Please note this!" - ) - return adam_extra_kwargs, torch.optim.AdamW - elif backend is AcceleratorType.NPU: - - if gpc.is_rank_for_log(): - logger.warning( - "Use normal AdamaW, NPU fused_adamw currently has" - "accuracy issues and is not supported yet. Please note this!" - ) - # return adam_extra_kwargs, torch_npu.optim.NpuFusedAdamW - except (ModuleNotFoundError, ImportError): - pass - - if gpc.is_rank_for_log(): - logger.warning("Use torch.optim.AdamW rather than FusedAdamW. Please note this!") - return adam_extra_kwargs, torch.optim.AdamW - - -# scatter_sum -def try_import_scatter_sum(): - """ - Try import scatter_sum from cuda, if failed, return None - - """ - try: - if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.GPU, AcceleratorType.DIPU]: - from torch_scatter import scatter as cuda_scatter - - if gpc.is_rank_for_log(): - logger.warning("Use cuda_scatter. Please note this!") - - return cuda_scatter - - except (ModuleNotFoundError, ImportError): - pass - - if gpc.is_rank_for_log(): - logger.warning("Use vanilla_scatter rather than cuda_scatter. Please note this!") - - return None - - -# FlashAttn -def try_import_linear_bias_wgrad(): - """ - Try import linear_bias_wgrad from flash_attn, if failed, return None - - """ - try: - if internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU: - import fused_dense_lib as fused_dense_cuda - - if gpc.is_rank_for_log(): - logger.warning("Use flash_attn linear_bias_wgrad. Please note this!") - - return fused_dense_cuda.linear_bias_wgrad - - except (ModuleNotFoundError, ImportError): - pass - - if gpc.is_rank_for_log(): - logger.warning("Use linear_bias_wgrad_torch. 
Please note this!") - - return None diff --git a/internlm/model/ops/linear.py b/internlm/model/ops/linear.py index 6afd1e61..eeffddc0 100644 --- a/internlm/model/ops/linear.py +++ b/internlm/model/ops/linear.py @@ -1,396 +1,63 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- +""" +A simple operator selector, used for compatibility with different platforms such as CUDA and Ascend, +as well as whether to enable flash-attn operator optimization, may be replaced by a more comprehensive +operator compatibility layer in the future. -from typing import Optional +This file implements support for the linear layer operators. +""" + +from typing import Optional, Tuple import torch -from torch import nn -from torch.distributed import ProcessGroup +from torch.nn.functional import linear as _torch_linear_forward_op -from internlm.core.context import ParallelMode +from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import global_context as gpc -from internlm.model.utils import ( - all_reduce, - fused_dense_func, - isp_fused_dense_func, - megatron_fused_dense_func, - reduce_scatter, -) -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - - -class BaseScaleColumnParallelLinear(nn.Linear): - """ - Base class for ScaleColumnParallelLinear. - - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. - """ - - def __init__( - self, - in_features: int, - out_features: int, - process_group: Optional[torch.distributed.ProcessGroup], - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - weight_scale: int = 1, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - if out_features % world_size != 0: - raise ValueError(f"out_features ({out_features}) must be divisible by " f"world_size ({world_size})") - super().__init__(in_features, out_features // world_size, bias=bias, device=device, dtype=dtype) - self.process_group = process_group - self.weight_scale = weight_scale - - -class ScaleColumnParallelLinear(BaseScaleColumnParallelLinear): - """ - ScaleColumnParallelLinear in flash implementation. - """ - - def forward(self, input, gather_dim=1, tp_mode: str = "mtp"): # pylint: disable=W0622 - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. 
- if self.weight_scale != 1: - weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach() - else: - weight = self.weight - - _fused_func = fused_dense_func if tp_mode in ["mtp", "fsp", "isp"] else megatron_fused_dense_func - return _fused_func( - input, - weight, - self.bias, - process_group=self.process_group, - sequence_parallel=gpc.config.parallel.sequence_parallel, - gather_dim=gather_dim, - ) - - -class ScaleColumnParallelLinearWithNormHead(BaseScaleColumnParallelLinear): - """ - ScaleColumnParallelLinear for InternLM2. - - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. - norm_head (bool): Normalize the output embedding in order to let the calculation of logits not affected by - the norm of embedding. The implementation is referred to baichuan2, - see https://huggingface.co/baichuan-inc/Baichuan2-7B-Base for more information. False by default. - """ - - def __init__( - self, - in_features: int, - out_features: int, - process_group: Optional[torch.distributed.ProcessGroup], - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - weight_scale: int = 1, - norm_head: bool = False, - ) -> None: - super().__init__( - in_features, out_features, process_group, bias=bias, device=device, dtype=dtype, weight_scale=weight_scale - ) - - self.norm_head = norm_head - if self.norm_head: - logger.info("Notice that norm head is enabled to normalize head weight.") - self.first_eval_flag = True - self.tmp_weight = None - - def forward(self, input, gather_dim=1, tp_mode: str = "mtp"): # pylint: disable=W0622 - if self.weight_scale != 1: - weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach() - else: - weight = self.weight - if self.norm_head: - if self.training: - if not self.first_eval_flag: - self.first_eval_flag = True - self.tmp_weight = None - # We normalized the output Embedding so that the dot product - # is not affected by the norm of embedding. Ref: https://arxiv.org/pdf/2309.10305.pdf - weight = nn.functional.normalize(weight) - else: - if self.first_eval_flag: - # cache l2 norm of head to accelerate infer. - self.first_eval_flag = False - self.tmp_weight = nn.functional.normalize(weight) - - weight = self.tmp_weight - - _fused_func = fused_dense_func if tp_mode in ["mtp", "fsp", "isp"] else megatron_fused_dense_func - return _fused_func( - input, - weight, - self.bias, - process_group=self.process_group, - sequence_parallel=gpc.config.parallel.sequence_parallel, - gather_dim=gather_dim, - ) - - -class RewardModelLinear(BaseScaleColumnParallelLinear): - """ - RewardModelLinear. - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. 
- sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. - """ - - def __init__( - self, - in_features: int, - out_features: int, - process_group: Optional[torch.distributed.ProcessGroup], - bias: bool = True, - device: Optional[torch.device] = None, - dtype: Optional[torch.dtype] = None, - weight_scale: int = 1, - ) -> None: - super().__init__(in_features, out_features, process_group, bias, device, dtype, weight_scale) - torch.distributed.broadcast(self.weight, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group) - if bias: - torch.distributed.broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.TENSOR)[0], process_group) - - def forward(self, input): # pylint: disable=W0622 - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. - if self.weight_scale != 1: - weight = self.weight * self.weight_scale + (1 - self.weight_scale) * self.weight.detach() - else: - weight = self.weight - return fused_dense_func( - input, - weight, - self.bias, - process_group=self.process_group, - sequence_parallel=gpc.config.parallel.sequence_parallel, - ) - - -class ColumnParallelLinearTorch(nn.Linear): - """ - ColumnParallelLinearTorch. - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. - """ - - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - if out_features % multiple_of: - raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") - multiple = out_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - - def forward(self, x, gather_dim=1): - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. 
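For illustration (not part of the patch): the "uneven split" comment in these classes is easiest to see with concrete numbers. The snippet below reproduces the div/mod arithmetic with illustrative values.

# With multiple_of=1, out_features=10 and world_size=4, the first (10 % 4) = 2 ranks
# each get one extra column group, so the per-rank sizes are [3, 3, 2, 2].
multiple_of, out_features, world_size = 1, 10, 4
multiple = out_features // multiple_of
div, mod = divmod(multiple, world_size)
local_sizes = [(div + int(rank < mod)) * multiple_of for rank in range(world_size)]
assert local_sizes == [3, 3, 2, 2] and sum(local_sizes) == out_features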
- return fused_dense_func( - x, - self.weight, - self.bias, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - gather_dim=gather_dim, - ) - - -class MegatronColumnParallelLinearTorch(ColumnParallelLinearTorch): - """ - MegatronColumnParallelLinearTorch - """ - def forward(self, x, gather_dim=1): - # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - # we do an all_gather of x before doing the matmul. - # If not, then the input is already gathered. - return megatron_fused_dense_func( - x, - self.weight, - self.bias, - process_group=self.process_group, - sequence_parallel=self.sequence_parallel, - gather_dim=gather_dim, - ) +try: + from fused_dense_lib import linear_bias_wgrad as _flash_linear_backward_op + flash_attn_impl = True +except (ModuleNotFoundError, ImportError): + flash_attn_impl = False -class RowParallelLinearTorch(nn.Linear): - """ - RowParallelLinearTorch. - Args: - in_features (int): size of each input sample - out_features (int): size of each output sample - process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`. - bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False - in the config. - sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: - we do an all_gather of x before doing the matmul. - If not, then the input is already gathered. - device (Optional[Union[str, torch.device]]): The device will be used. - dtype (Optional[torch.dtype]): The type of data. - weight_scale (int): For training stability. 1 by default. - """ +internlm_accelerator = get_accelerator() - def __init__( - self, - in_features: int, - out_features: int, - process_group: ProcessGroup, - bias: bool = True, - sequence_parallel=True, - multiple_of=1, - device=None, - dtype=None, - ) -> None: - world_size = torch.distributed.get_world_size(process_group) - rank = torch.distributed.get_rank(process_group) - if in_features % multiple_of: - raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") - multiple = in_features // multiple_of - # We want to split @multiple across world_size, but it could be an uneven split - div = multiple // world_size - mod = multiple % world_size - # The first @mod ranks get @div + 1 copies, the rest get @div copies - local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) - # Only rank 0 will have bias - super().__init__( - local_multiple * multiple_of, - out_features, - bias=bias and rank == 0, - device=device, - dtype=dtype, - ) - self.process_group = process_group - self.sequence_parallel = sequence_parallel - def forward(self, x, reduce_dim=1): - """ - We're doing Tensor Parallel with sequence parallelism: we do the matmul and then - a reduce_scatter of the result. 
- """ - out = fused_dense_func(x, self.weight, self.bias) - if self.sequence_parallel: - return reduce_scatter(out, self.process_group, reduce_dim) - else: - return all_reduce(out, self.process_group) +def _select_ops_binding(dtype: torch.dtype, is_cuda: bool = True) -> None: + dtype_eligible = dtype in (torch.float16, torch.bfloat16) or ( + dtype == torch.float32 and torch.is_autocast_enabled() + ) + use_flash_attn = gpc.config.model.get("use_flash_attn", False) + is_gpu_backend = internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU + flash_attn_eligible = flash_attn_impl and dtype_eligible and is_cuda + if use_flash_attn and is_gpu_backend and flash_attn_eligible: + return _torch_linear_forward_op, _flash_linear_backward_op + else: + return _torch_linear_forward_op, _linear_bias_wgrad_torch -class MegatronRowParallelLinearTorch(RowParallelLinearTorch): - """ - MegatronRowParallelLinearTorch. - """ - def forward(self, x, reduce_dim=1): - """ - We're doing Tensor Parallel with sequence parallelism: we do the matmul and then - a reduce_scatter of the result. - """ - out = megatron_fused_dense_func(x, self.weight, self.bias) - if self.sequence_parallel: - return reduce_scatter(out, self.process_group, reduce_dim) - else: - return all_reduce(out, self.process_group) +def _linear_bias_wgrad_torch(_input: torch.Tensor, grad_output: torch.Tensor, has_d_bias: bool): + assert _input.dtype == grad_output.dtype + grad_weight = torch.matmul(grad_output.t(), _input) + grad_bias = grad_output.sum(dim=0) if has_d_bias else None -class ISPLinear(ColumnParallelLinearTorch): - """ - Linear class for isp tensor parallel mode. - """ + return grad_weight, grad_bias - # class level communicator variable. - __communicator = None - @staticmethod - def register_communicator(communicator): - ISPLinear.__communicator = communicator +def linear_forward_op(_input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + _is_cuda = internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU + _forward_op, _ = _select_ops_binding(_input.dtype, _is_cuda) - def forward(self, x): - assert self.__communicator is not None, "ISPLinear should be register with a communicator first." 
+ return _forward_op(_input, weight, bias) - return isp_fused_dense_func( - x, - self.weight, - module=self, - communicator=self.__communicator, - bias=self.bias, - ) +def linear_backward_op( + _input: torch.Tensor, weight: torch.Tensor, has_d_bias: bool +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + _is_cuda = internlm_accelerator.get_accelerator_backend() is AcceleratorType.GPU + _, _backward_op = _select_ops_binding(_input.dtype, _is_cuda) -def get_linear_cls(tp_mode: str, parallel_mode: str): - if parallel_mode == "column": - if tp_mode in ["mtp", "fsp"]: - cls = ColumnParallelLinearTorch - elif tp_mode == "msp": - cls = MegatronColumnParallelLinearTorch - else: - cls = ISPLinear - elif parallel_mode == "row": - if tp_mode in ["mtp", "fsp"]: - cls = RowParallelLinearTorch - elif tp_mode == "msp": - cls = MegatronRowParallelLinearTorch - else: - cls = ISPLinear - return cls + return _backward_op(_input, weight, has_d_bias) diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py index 6598e178..3cd43dab 100644 --- a/internlm/model/ops/norm.py +++ b/internlm/model/ops/norm.py @@ -6,8 +6,36 @@ from torch.nn import init from torch.nn.parameter import Parameter +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.utils.logger import get_logger -def manual_rms_norm(my_input, normalized_shape, weight, eps): +logger = get_logger(__file__) +internlm_accelerator = get_accelerator() + +try: + from apex.normalization.fused_layer_norm import mixed_dtype_fused_rms_norm_affine + + apex_rmsnorm_impl = True +except (ModuleNotFoundError, ImportError): + logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!") + apex_rmsnorm_impl = False + +try: + from deeplink_ext.internevo_ops import MixedFusedRMSNorm + + deeplink_rmsnorm_impl = True +except (ModuleNotFoundError, ImportError): + deeplink_rmsnorm_impl = False + +try: + from torch_npu import npu_rms_norm + + torchnpu_rmsnorm_impl = True +except (ModuleNotFoundError, ImportError): + torchnpu_rmsnorm_impl = False + + +def manual_rms_norm(my_input, weight, normalized_shape, eps): # layer norm should always be calculated in float32 dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1)) variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True) @@ -23,8 +51,36 @@ def manual_rms_norm(my_input, normalized_shape, weight, eps): return weight * my_input -class RMSNormTorch(torch.nn.Module): - """A custom PyTorch module for RMS normalization.""" +class _RMSNorm(torch.nn.Module): + """A generic module for RMS normalization.""" + + def __init__(self, normalized_shape, eps=1e-5): + super().__init__() + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.weight = Parameter(torch.empty(*normalized_shape)) + self.reset_parameters() + + def forward(self, _input: torch.Tensor): + if apex_rmsnorm_impl: + _norm_func = mixed_dtype_fused_rms_norm_affine + else: + _norm_func = manual_rms_norm + + return _norm_func(_input, self.weight, self.normalized_shape, self.eps) + + def reset_parameters(self): + init.ones_(self.weight) + + def extra_repr(self): + return f"{self.normalized_shape}, eps={self.eps}, " + + +class _RMSNormNPU(torch.nn.Module): + """A custom NPU module for RMS normalization.""" def __init__(self, normalized_shape, eps=1e-5): super().__init__() @@ -35,12 +91,26 @@ def __init__(self, normalized_shape, eps=1e-5): self.eps = eps 
self.weight = Parameter(torch.empty(*normalized_shape)) self.reset_parameters() + self.rmsorm_npu_forward = npu_rms_norm def forward(self, _input: torch.Tensor): - return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps) + weight_fp32 = self.weight.to(torch.float32) + input_fp32 = _input.to(torch.float32) + output = self.rmsorm_npu_forward(input_fp32, gamma=weight_fp32, epsilon=self.eps)[0].to(self.weight.dtype) + return output def reset_parameters(self): init.ones_(self.weight) def extra_repr(self): - return "{normalized_shape}, eps={eps}, ".format(**self.__dict__) + return f"{self.normalized_shape}, eps={self.eps}, ".format(**self.__dict__) + + +# TODO: Support deeplink in a more unified manner +backend = internlm_accelerator.get_accelerator_backend() +if backend == AcceleratorType.DIPU and deeplink_rmsnorm_impl: + RMSNorm = MixedFusedRMSNorm +elif backend == AcceleratorType.NPU and torchnpu_rmsnorm_impl: + RMSNorm = _RMSNormNPU +else: + RMSNorm = _RMSNorm diff --git a/internlm/model/ops/rotary_emb.py b/internlm/model/ops/rotary_emb.py new file mode 100644 index 00000000..86c11570 --- /dev/null +++ b/internlm/model/ops/rotary_emb.py @@ -0,0 +1,305 @@ +""" +A simple operator selector, used for compatibility with different platforms such as CUDA and Ascend, +as well as whether to enable flash-attn operator optimization, may be replaced by a more comprehensive +operator compatibility layer in the future. + +This file implements support for the roatry embedding operators. +""" + +from typing import Callable, Tuple + +import torch +from einops import rearrange +from torch import Tensor + +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.core.context import global_context as gpc + +try: + from rotary_emb import apply_rotary as _flash_apply_rotary_func + + flash_rotary_impl = True +except (ModuleNotFoundError, ImportError): + flash_rotary_impl = False + +try: + from deeplink_ext.internlm_ops import ApplyRotaryEmb as DeeplinkApplyRotaryEmb + + deeplink_rotary_impl = True +except (ModuleNotFoundError, ImportError): + deeplink_rotary_impl = False + + +try: + from torch_npu import npu_rotary_mul + + torchnpu_rotary_impl = True +except (ModuleNotFoundError, ImportError): + torchnpu_rotary_impl = False + +internlm_accelerator = get_accelerator() + + +def _rope_to_float32_wrapper(input_idxs: Tuple, rope_func: Callable, *args, **kwargs): + try: + use_fp32_rope = gpc.config.model.get("use_fp32_rope", True) + except AttributeError: + use_fp32_rope = True + + if use_fp32_rope: + inputs = [args[idx] for idx in input_idxs] + input_dtype = inputs[0].dtype + other_args = [args[idx] for idx in range(len(inputs), len(args))] + + for idx in input_idxs: + inputs[idx] = inputs[idx].to(torch.float32) + + res = rope_func(*inputs, *other_args, **kwargs) + if res is not None: + return res.to(input_dtype) + else: + return rope_func(*args, **kwargs) + + +def _torch_apply_rotary_func( + x1: torch.Tensor, + x2: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, + conj: bool = False, +): + # TODO: improve perfermance. 
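+    # The element-wise update below applies the standard 2D rotation used by rotary embeddings:
+    # with conj=False, (out1, out2) = (x1*cos - x2*sin, x1*sin + x2*cos); with conj=True the
+    # inverse rotation is applied, which is what ApplyRotaryEmb.backward relies on.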
+ assert x1.device == x2.device == cos.device == sin.device, "All inputs must be on the same device" + assert x1.dtype == x2.dtype == cos.dtype == sin.dtype, "All inputs must have the same dtype" + assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes" + assert cos.size() == sin.size(), "Input cos and sin must have the same sizes" + + # x1, x2, cos, sin = x1.float(), x2.float(), cos.float(), sin.float() + + if conj: + out1.copy_(x1 * cos + x2 * sin) + out2.copy_(-x1 * sin + x2 * cos) + else: + out1.copy_(x1 * cos - x2 * sin) + out2.copy_(x1 * sin + x2 * cos) + + +def _apply_npu_rotary_mul(x: Tensor, cos: Tensor, sin: Tensor): + """ + Implement RotaryEmbedding rotation position encoding. Support FakeTensor mode. + Ref: https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/80RC1alpha002/ + apiref/fmkadptapi/ptaoplist_000451.html + Args: + x (Tensor): q or k, shape is [B, S, N, D]. + cos (Tensor): cos, shape is [1, S, 1, D]. + sin (Tensor): sin, shape is [1, S, 1, D]. + """ + return npu_rotary_mul(x, cos, sin) + + +def _apply_torch_npu_rotary_mul(x: Tensor, cos: Tensor, sin: Tensor): + """Torch implementation of 'npu_rotary_mul', baseline for unit testing. + + Args: + x (Tensor): q or k, shape is [B, S, N, D]. + cos (Tensor): cos, shape is [1, S, 1, D]. + sin (Tensor): sin, shape is [1, S, 1, D]. + """ + # NOTE: This could probably be moved to Triton. + def rotate_half(_x): + x1, x2 = _x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + # Handle a possible sequence length mismatch in between q and k. + cos = cos[:, : x.shape[1], :, :] + sin = sin[:, : x.shape[1], :, :] + re = (x * cos) + (rotate_half(x) * sin) + + del rotate_half + return re + + +def _select_apply_rotary_func_npu(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, use_fused_rope: bool = False): + if use_fused_rope: + return _rope_to_float32_wrapper((0, 1, 2), _apply_npu_rotary_mul, x, cos, sin) + else: + return _rope_to_float32_wrapper((0, 1, 2), _apply_torch_npu_rotary_mul, x, cos, sin) + + +def rotary_emb_in_rotate_half_style( + x: Tensor, + cos: Tensor, + sin: Tensor, + interleaved=False, + use_fused_rope=False, +): + """The rotary_emb implemented in the rotate_half style is different from the flash_attn's rotary_emb + in that cos and sin require [max_position_embeddings, dim/2] -> [1, max_position_embeddings, 1, dim]. + + Args: + x (Tensor): x, If x is qkv, shape is [B, S, 3, N, D]; If x is q or k, shape is [B, S, N, D]. + cos (Tensor): cos, shape is [S, D//2]. + sin (Tensor): sin, shape is [S, D//2]. + """ + # reformat cos/sin shape. 
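+    # cos/sin arrive as [S, D//2]; duplicating them along the last dim and indexing with
+    # [None, :, None, :] produces [1, S, 1, D], which broadcasts against q/k tensors of shape
+    # [B, S, N, D] in the rotate-half convention.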
+ cos = torch.cat((cos, cos), dim=-1)[None, :, None, :] + sin = torch.cat((sin, sin), dim=-1)[None, :, None, :] + + if len(x.shape) == 5: + q, k, _ = x.unbind(dim=2) + + if interleaved: + q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) + k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) + + q = _select_apply_rotary_func_npu(q, cos, sin, use_fused_rope) + k = _select_apply_rotary_func_npu(k, cos, sin, use_fused_rope) + + if interleaved: + x[:, :, 0, ..., : x.shape[-1] // 2].copy_(q[..., ::2]) + x[:, :, 0, ..., x.shape[-1] // 2 :].copy_(q[..., 1::2]) + + x[:, :, 1, ..., : x.shape[-1] // 2].copy_(k[..., ::2]) + x[:, :, 1, ..., x.shape[-1] // 2 :].copy_(k[..., 1::2]) + else: + x[:, :, 0, ...].copy_(q) + x[:, :, 1, ...].copy_(k) + else: + if interleaved: + x = torch.cat([x[..., ::2], x[..., 1::2]], dim=-1) + x = _select_apply_rotary_func_npu(x, cos, sin, use_fused_rope) + if interleaved: + out = torch.empty_like(x) + out[..., ::2].copy_(x[..., : x.shape[-1] // 2]) + out[..., 1::2].copy_(x[..., x.shape[-1] // 2 :]) + x = out + return x + + +def _select_apply_rotary_func( + x1: torch.Tensor, + x2: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + out1: torch.Tensor, + out2: torch.Tensor, + conj: bool = False, + use_fused_rope: bool = True, +) -> None: + if use_fused_rope and flash_rotary_impl: + _flash_apply_rotary_func(x1, x2, cos, sin, out1, out2, conj) + else: + _rope_to_float32_wrapper((0, 1, 2, 3), _torch_apply_rotary_func, x1, x2, cos, sin, out1, out2, conj) + + +# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py#L35 +class ApplyRotaryEmb(torch.autograd.Function): + """ + ApplyRotaryEmb + """ + + @staticmethod + def forward( + ctx, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + interleaved: bool = False, + in_place: bool = False, + use_fused_rope: bool = True, + ): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. 
+ """ + *_, seqlen, _, head_dim = x.shape + rotary_seqlen, rotary_dim = cos.shape + rotary_dim *= 2 + + assert rotary_dim <= head_dim + assert seqlen <= rotary_seqlen + assert sin.shape == (rotary_seqlen, rotary_dim // 2) + + x_ro = x[..., :rotary_dim] + x1, x2 = (x_ro[..., ::2], x_ro[..., 1::2]) if interleaved else x_ro.chunk(2, dim=-1) + + if in_place: + out, o1, o2 = x, x1, x2 + else: + out = torch.empty_like(x) + out_ro = out[..., :rotary_dim] + o1, o2 = (out_ro[..., ::2], out_ro[..., 1::2]) if interleaved else out_ro.chunk(2, dim=-1) + + _select_apply_rotary_func( + x1, + x2, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + o1, + o2, + False, + use_fused_rope, + ) + + if rotary_dim < head_dim and not in_place: + out[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + ctx.save_for_backward(cos, sin) + ctx.interleaved = interleaved + ctx.in_place = in_place + ctx.use_fused_rope = use_fused_rope + + return out + + @staticmethod + def backward(ctx, do): + cos, sin = ctx.saved_tensors + *_, seqlen, _, head_dim = do.shape + rotary_dim = cos.shape[-1] + rotary_dim *= 2 + + do_ro = do[..., :rotary_dim] + do1, do2 = (do_ro[..., ::2], do_ro[..., 1::2]) if ctx.interleaved else do_ro.chunk(2, dim=-1) + + if ctx.in_place: + dx, dx1, dx2 = do, do1, do2 + else: + dx = torch.empty_like(do) + dx_ro = dx[..., :rotary_dim] + dx1, dx2 = (dx_ro[..., ::2], dx_ro[..., 1::2]) if ctx.interleaved else dx_ro.chunk(2, dim=-1) + + _select_apply_rotary_func( + do1, + do2, + rearrange(cos[:seqlen], "s d -> s 1 d"), + rearrange(sin[:seqlen], "s d -> s 1 d"), + dx1, + dx2, + True, + ctx.use_fused_rope, + ) + + if rotary_dim < head_dim and not ctx.in_place: + dx[..., rotary_dim:].copy_(do[..., rotary_dim:]) + + return dx, None, None, None, None + + +def apply_rotary_emb( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False, in_place: bool = False +): + # TODO: Support deeplink in a more unified manner + use_fused_rope = gpc.config.model.get("use_fused_rope", True) + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.DIPU: + # TODO: to support in_place argument + return DeeplinkApplyRotaryEmb.apply(x, cos, sin, interleaved, use_fused_rope) + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: + return rotary_emb_in_rotate_half_style(x, cos, sin, interleaved, use_fused_rope) + else: + return ApplyRotaryEmb.apply(x, cos, sin, interleaved, in_place) diff --git a/internlm/model/ops/utils.py b/internlm/model/ops/utils.py new file mode 100644 index 00000000..04d068cd --- /dev/null +++ b/internlm/model/ops/utils.py @@ -0,0 +1,48 @@ +""" +Some hepler functions for ops package. 
+""" + +import torch +from torch.nn.utils.rnn import pad_sequence + + +def unpack_qkv_before_attn(cur_input: torch.Tensor, cu_seqlens: torch.Tensor, padding_v: int = 0): + """ + qkv: the shape is (1, packed_length, three, head_num, head_dim) + kv: the shape is (1, packed_length, two, head_num, head_dim) + q/k/v: the shape is (1, packed_length, head_num, head_dim) + + Return: + output: the shape is (micro_bsz, seq_len, three, head_num, head_dim) for qkv + (micro_bsz, seq_len, two, head_num, head_dim) for kv + (micro_bsz, seq_len, head_num, head_dim) for q/k/v + """ + assert cur_input.shape[0] == 1 + cur_input = cur_input.squeeze(0) + + sequences = [] + for i in range(len(cu_seqlens) - 1): + sequences.append(cur_input[cu_seqlens[i] : cu_seqlens[i + 1]]) + + padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_v) + + return padded_sequences + + +def pack_output_after_attn(cur_input: torch.Tensor, cu_seqlens: torch.Tensor, packed_length: int, padding_v: int = 0): + """ + cur_input: the shape is (micro_bsz, seq_len, head_num, head_dim) + + Return: + output: the shape is (1, packed_length, head_num, head_dim) + """ + output_shape = list(cur_input.shape) + output_shape[0] = 1 + output_shape[1] = packed_length + + output = torch.full(output_shape, fill_value=padding_v, device=cur_input.device, dtype=cur_input.dtype) + for i in range(len(cu_seqlens) - 1): + length = cu_seqlens[i + 1] - cu_seqlens[i] + output[0, cu_seqlens[i] : cu_seqlens[i + 1]] = cur_input[i, 0:length] + + return output diff --git a/internlm/utils/registry.py b/internlm/model/registry.py similarity index 70% rename from internlm/utils/registry.py rename to internlm/model/registry.py index 3ac14452..01f02dc1 100644 --- a/internlm/utils/registry.py +++ b/internlm/model/registry.py @@ -1,6 +1,14 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +from typing import Callable + +from internlm.model.modeling_internlm import InternLM1 +from internlm.model.modeling_internlm2 import InternLM2 +from internlm.model.modeling_llama import Llama2 +from internlm.model.modeling_llava import Llava +from internlm.model.modeling_moe import Internlm1MoE + class Registry: """This is a registry class used to register classes and modules so that a universal @@ -12,13 +20,13 @@ class Registry: def __init__(self, name: str): self._name = name - self._registry = dict() + self._registry = {} @property def name(self): return self._name - def register_module(self, module_name: str): + def register_module(self, module_name: str, func: Callable): """Registers a module represented in `module_class`. 
Args: @@ -31,11 +39,7 @@ def register_module(self, module_name: str): assert module_name not in self._registry, f"{module_name} already registered in {self.name}" - def decorator_wrapper(original_func): - self._registry[module_name] = original_func - return original_func - - return decorator_wrapper + self._registry[module_name] = func def get_module(self, module_name: str): """Retrieves a module with name `module_name` and returns the module if it has @@ -68,4 +72,13 @@ def has(self, module_name: str): return found_flag -MODEL_INITIALIZER = Registry("model_initializer") +model_initializer = Registry("model_initializer") +hf_config_initializer = Registry("hf_config_initializer") + + +def register_model_initializer() -> None: + model_initializer.register_module("INTERNLM", InternLM1) + model_initializer.register_module("INTERNLM2_PUBLIC", InternLM2) + model_initializer.register_module("LLAMA2", Llama2) + model_initializer.register_module("INTERNLM_MoE", Internlm1MoE) + model_initializer.register_module("LLAVA", Llava) diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 7fef0f93..c2311007 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -1,715 +1,53 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- +from typing import Any, Dict, List -from typing import Callable, Optional +from internlm.model.modules.mha import MHA -import torch -import torch.distributed as dist -import torch.nn.functional as F -from torch import Tensor -from torch.distributed import ProcessGroup -from torch.nn.utils.rnn import pad_sequence -from internlm.accelerator import get_accelerator -from internlm.core.context import global_context as gpc -from internlm.model.ops.fusion_ops_import_helper import try_import_linear_bias_wgrad -from internlm.utils.logger import get_logger +def internlm1_mha_pre_load_convert( + model: MHA, state_dict: Dict, prefix: str, *args, **kwargs # pylint: disable=W0613 +) -> None: + if f"{prefix}wqkv.weight" not in state_dict and f"{prefix}Wqkv.weight" in state_dict: + state_dict[f"{prefix}wqkv.weight"] = state_dict.pop(f"{prefix}Wqkv.weight") -internlm_accelerator = get_accelerator() + if f"{prefix}wqkv.bias" not in state_dict and f"{prefix}Wqkv.bias" in state_dict: + state_dict[f"{prefix}wqkv.bias"] = state_dict.pop(f"{prefix}Wqkv.bias") -custom_bwd = internlm_accelerator.return_custom_bwd() -custom_fwd = internlm_accelerator.return_custom_fwd() -logger = get_logger(__file__) +def internlm1_mha_save_convert( + model: MHA, state_dict: Dict, prefix: str, *args, **kwargs # pylint: disable=W0613 +) -> None: + state_dict[f"{prefix}Wqkv.weight"] = state_dict.pop(f"{prefix}wqkv.weight") + if f"{prefix}wqkv.bias" in state_dict: + state_dict[f"{prefix}Wqkv.bias"] = state_dict.pop(f"{prefix}wqkv.bias") -def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias): - assert my_input.dtype == grad_output.dtype - grad_weight = torch.matmul(grad_output.t(), my_input) - grad_bias = grad_output.sum(dim=0) if has_d_bias else None - return grad_weight, grad_bias +def convert_attn_kwargs_to_args(kwargs) -> List[Any]: + inference_params = kwargs.get("inference_params", None) + cu_seqlens = kwargs.get("cu_seqlens", None) + indexes = kwargs.get("indexes", None) + max_seqlen = kwargs.get("max_seqlen", None) -linear_bias_wgrad = try_import_linear_bias_wgrad() -is_using_cuda_linear_bias_wgrad = True -if linear_bias_wgrad is None: - linear_bias_wgrad = linear_bias_wgrad_torch - is_using_cuda_linear_bias_wgrad = False + return (inference_params, cu_seqlens, indexes, max_seqlen) -# 
Raw operation, does not support autograd, but does support async -def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False): - input_ = input_.contiguous() - handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op) - return input_, handle +def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]: + if len(args) == 0: + return kwargs + assert len(args) == 4, "args must be generate by convert_attn_kwargs_to_args function" -class ReduceScatterFunc(torch.autograd.Function): - """Reduce scatter the input from the sequence parallel region and concatenate.""" + if args[0] is not None: + assert "inference_params" not in kwargs, "repeated 'inference_params' argument exists both in args and kwargs" + kwargs["inference_params"] = args[0] + if args[1] is not None: + assert "cu_seqlens" not in kwargs, "repeated 'cu_seqlens' argument exists both in args and kwargs" + kwargs["cu_seqlens"] = args[1] + if args[2] is not None: + assert "indexes" not in kwargs, "repeated 'indexes' argument exists both in args and kwargs" + kwargs["indexes"] = args[2] + if args[3] is not None: + assert "max_seqlen" not in kwargs, "repeated 'max_seqlen' argument exists both in args and kwargs" + kwargs["max_seqlen"] = args[3] - @staticmethod - def forward(ctx, input_: Tensor, process_group: ProcessGroup, reduce_dim: int = 0) -> Tensor: - ctx.process_group = process_group - ctx.reduce_dim = reduce_dim - output, _ = reduce_scatter_raw(input_, process_group, reduce_dim=reduce_dim) - return output - - @staticmethod - def backward(ctx, grad_output: Tensor): - gather_dim = ctx.reduce_dim - grad_input, _ = all_gather_raw(grad_output, ctx.process_group, gather_dim=gather_dim) - return grad_input, None, None - - -# Supports autograd, but does not support async -reduce_scatter = ReduceScatterFunc.apply - - -class AllReduceFunc(torch.autograd.Function): - """Gather the input from sequence parallel region and concatenate.""" - - @staticmethod - def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor: - ctx.process_group = process_group - output, _ = all_reduce_raw(input_, process_group) - return output - - @staticmethod - def backward(ctx, grad_output: Tensor): - _ = ctx # avoid lint warning W0613 - return grad_output, None - - -# Supports autograd, but does not support async -all_reduce = AllReduceFunc.apply - - -def _split(input_, parallel_mode, dim=-1): - # skip if only one rank involved - world_size = gpc.get_world_size(parallel_mode) - if world_size == 1: - return input_ - - # Split along last dimension. 
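As a quick illustration of the attention-argument helpers added to internlm/model/utils.py above, the following minimal sketch (assuming the patched internlm package is importable; plain Python values stand in for the real tensors) shows the kwargs -> args -> kwargs round trip through the fixed positional layout (inference_params, cu_seqlens, indexes, max_seqlen):

from internlm.model.utils import (
    convert_attn_args_to_kwargs,
    convert_attn_kwargs_to_args,
)

# Pack the optional attention kwargs into the fixed 4-slot positional layout.
kwargs = {"cu_seqlens": [0, 4, 8], "max_seqlen": 4}
args = convert_attn_kwargs_to_args(kwargs)
assert args == (None, [0, 4, 8], None, 4)

# Unpack the positional layout back into a kwargs dict; None slots are simply dropped.
restored = convert_attn_args_to_kwargs(args, {})
assert restored == {"cu_seqlens": [0, 4, 8], "max_seqlen": 4}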
- dim_size = input_.size(dim) - assert dim_size % world_size == 0, ( - f"The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), " - f"cannot split tensor evenly" - ) - - tensor_list = torch.split(input_, dim_size // world_size, dim=dim) - rank = gpc.get_local_rank(parallel_mode) - output = tensor_list[rank].contiguous() - output = output.detach().clone() - - return output - - -def _gather(input_, parallel_mode, dim=-1): - # skip if only one rank involved - world_size = gpc.get_world_size(parallel_mode) - if world_size == 1: - return input_ - - # all gather - rank = gpc.get_local_rank(parallel_mode) - tensor_list = [torch.empty_like(input_) for _ in range(world_size)] - tensor_list[rank] = input_ - group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode) - dist.all_gather(tensor_list, input_, group=group) - - # concat - output = torch.cat(tensor_list, dim=dim).contiguous() - - return output - - -class _GatherForwardSplitBackward(torch.autograd.Function): - """Gather the input from model parallel region and concatenate. - - Args: - input_: input matrix. - parallel_mode: parallel mode. - dim: dimension - """ - - @staticmethod - def symbolic(input_): - return _gather(input_, parallel_mode=None) - - @staticmethod - def forward(ctx, input_, parallel_mode, dim): - ctx.mode = parallel_mode - ctx.dim = dim - return _gather(input_, parallel_mode, dim) - - @staticmethod - def backward(ctx, grad_output): - return _split(grad_output, ctx.mode, ctx.dim), None, None - - -def gather_forward_split_backward(input_, parallel_mode, dim): - return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim) - - -class _SplitForwardGatherBackward(torch.autograd.Function): - """ - Split the input and keep only the corresponding chuck to the rank. - - Args: - input_: input matrix. - parallel_mode: parallel mode. 
- dim: dimension - """ - - @staticmethod - def symbolic(input_): - return _split(input_, parallel_mode=None) - - @staticmethod - def forward(ctx, input_, parallel_mode, dim): - ctx.mode = parallel_mode - ctx.dim = dim - return _split(input_, parallel_mode, dim) - - @staticmethod - def backward(ctx, grad_output): - return _gather(grad_output, ctx.mode, ctx.dim), None, None - - -def split_forward_gather_backward(input_, parallel_mode, dim): - return _SplitForwardGatherBackward.apply(input_, parallel_mode, dim) - - -def all_gather_raw( - input_: Tensor, - process_group: ProcessGroup, - async_op: bool = False, - gather_dim: int = 0, - memory_pool_allocator: Callable = None, -): - world_size = dist.get_world_size(process_group) - if world_size <= 1: - return input_, None - - if memory_pool_allocator is not None: - output = memory_pool_allocator() - else: - shape = list(input_.shape) - shape[gather_dim] = shape[gather_dim] * world_size - output = torch.empty(shape, dtype=input_.dtype, device=input_.device) - - handle = dist.all_gather_into_tensor(output, input_.contiguous(), group=process_group, async_op=async_op) - return output, handle - - -def reduce_scatter_raw( - input_: Tensor, - process_group: ProcessGroup, - op=dist.ReduceOp.SUM, - async_op: bool = False, - reduce_dim: int = 0, - memory_pool_allocator: Callable = None, -): - world_size = dist.get_world_size(process_group) - assert input_.shape[reduce_dim] % world_size == 0 - - if world_size <= 1: - return input_, None - - shape_list = list(input_.shape) - shape_list[reduce_dim] = shape_list[reduce_dim] // world_size - - if memory_pool_allocator is not None: - output = memory_pool_allocator(tuple(shape_list)) - else: - output = torch.empty( - shape_list, - dtype=input_.dtype, - device=input_.device, - ).contiguous() - - handle = dist.reduce_scatter_tensor(output, input_.contiguous(), op=op, group=process_group, async_op=async_op) - return output, handle - - -# adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py -class FusedDenseFunc(torch.autograd.Function): - "FusedDenseFunc for tensor parallel in flash-attn implementation." - - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight, - bias, - return_residual=False, - process_group=None, - sequence_parallel=True, - gather_dim=0, - dtype_eligible: bool = True, - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather_raw of x before doing the matmul. 
- """ - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - ctx.gather_dim = gather_dim - ctx.dtype_eligible = dtype_eligible - - if ctx.dtype_eligible is False: - global linear_bias_wgrad, is_using_cuda_linear_bias_wgrad - linear_bias_wgrad = linear_bias_wgrad_torch - is_using_cuda_linear_bias_wgrad = False - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim) - else: - total_x = x - - if torch.is_autocast_enabled(): - weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) - bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None - weight = weight.contiguous() - if process_group is not None and sequence_parallel and handle_x is not None: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - output = F.linear(total_x, weight, bias) # pylint: disable=E1102 - if ctx.compute_weight_gradient: - ctx.save_for_backward(x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - gather_dim = ctx.gather_dim - if ctx.compute_weight_gradient: - x, weight = ctx.saved_tensors - if process_group is not None and sequence_parallel: - total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim) - else: - total_x = x - else: - (weight,) = ctx.saved_tensors - total_x = None - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, weight.t()) # pylint: disable=E1102 - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_output, - weight, - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if process_group is not None: - if sequence_parallel: - grad_input, handle_grad_input = reduce_scatter_raw( - grad_input, process_group, async_op=True, reduce_dim=1 - ) - else: - grad_input, handle_grad_input = all_reduce_raw(grad_input, process_group, async_op=True) - else: - grad_input = None - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - if process_group is not None and sequence_parallel and handle_x is not None: - handle_x.wait() - grad_weight, grad_bias = linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), - grad_output, - ctx.needs_input_grad[2], - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - if process_group is not None and ctx.needs_input_grad[0] and handle_grad_input is not None: - handle_grad_input.wait() - return 
grad_input, grad_weight, grad_bias, None, None, None, None, None - - -class MegatronFusedDenseFunc(torch.autograd.Function): - """ - FusedDenseFunc for tensor parallel in megatron implementation. - The diffenrence between the implementation of flash-attn and megatron is that the total_x could be - saved for backward in megatron, so that the all-gather in backward is ommited. - """ - - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight, - bias, - return_residual=False, - process_group=None, - sequence_parallel=True, - gather_dim=0, - dtype_eligible: bool = True, - ): - """ - If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel - with sequence parallelism: we do an all_gather_raw of x before doing the matmul. - """ - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.process_group = process_group - ctx.sequence_parallel = sequence_parallel - ctx.dtype_eligible = dtype_eligible - - if ctx.dtype_eligible is False: - global linear_bias_wgrad, is_using_cuda_linear_bias_wgrad - linear_bias_wgrad = linear_bias_wgrad_torch - is_using_cuda_linear_bias_wgrad = False - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - if process_group is not None and sequence_parallel: - # We want to kick off the all_gather early, before weight dtype conversion - total_x, handle_x = all_gather_raw(x, process_group, async_op=True, gather_dim=gather_dim) - else: - total_x = x - - if torch.is_autocast_enabled(): - weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) - bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None - weight = weight.contiguous() - if process_group is not None and sequence_parallel and handle_x is not None: - handle_x.wait() - batch_shape, n = total_x.shape[:-1], total_x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - output = F.linear(total_x, weight, bias) # pylint: disable=E1102 - if ctx.compute_weight_gradient: - ctx.save_for_backward(total_x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - process_group = ctx.process_group - sequence_parallel = ctx.sequence_parallel - if ctx.compute_weight_gradient: - total_x, weight = ctx.saved_tensors - else: - (weight,) = ctx.saved_tensors - total_x = None - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, weight.t()) # pylint: disable=E1102 - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_output, - weight, - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - if process_group is not None: - if sequence_parallel: - grad_input, handle_grad_input = reduce_scatter_raw( - grad_input, process_group, async_op=True, reduce_dim=1 - ) - else: - grad_input, handle_grad_input = all_reduce_raw(grad_input, process_group, async_op=True) - else: - 
grad_input = None - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - grad_weight, grad_bias = linear_bias_wgrad( - total_x.reshape(batch_dim, total_x.shape[-1]), - grad_output, - ctx.needs_input_grad[2], - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - if process_group is not None and ctx.needs_input_grad[0] and handle_grad_input is not None: - handle_grad_input.wait() - return grad_input, grad_weight, grad_bias, None, None, None, None, None - - -class ISPFusedDenseFunc(torch.autograd.Function): - "FusedDenseFunc for ISP, which is optimized based on flash implementation." - - @staticmethod - @custom_fwd - def forward( - ctx, - x, - weight, - bias, - module, - communicator, - return_residual=False, - dtype_eligible: bool = True, - ): - ctx.compute_weight_gradient = weight.requires_grad - ctx.return_residual = return_residual - ctx.module = module - ctx.communicator = communicator - ctx.dtype_eligible = dtype_eligible - - if ctx.dtype_eligible is False: - global linear_bias_wgrad, is_using_cuda_linear_bias_wgrad - linear_bias_wgrad = linear_bias_wgrad_torch - is_using_cuda_linear_bias_wgrad = False - - if torch.is_autocast_enabled(): - x = x.to(dtype=torch.get_autocast_gpu_dtype()) - x = x.contiguous() - - total_weight = communicator.all_gather(weight, module) - total_bias = bias if bias is None else communicator.all_gather(bias, module, is_bias=True) - - if torch.is_autocast_enabled(): - total_weight = total_weight.to(dtype=torch.get_autocast_gpu_dtype()) - if total_bias: - total_bias.to(dtype=torch.get_autocast_gpu_dtype()) - - total_weight = total_weight.contiguous() - batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = batch_shape.numel() - # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 - if min(batch_dim, n, *total_weight.shape) > 65535 * 32: - raise RuntimeError("fused_dense only supports matrix dims <= 2M") - - output = F.linear(x, total_weight, total_bias) # pylint: disable=E1102 - - # release memory - del total_weight - del total_bias - if ctx.compute_weight_gradient: - ctx.save_for_backward(x, weight) - else: - ctx.save_for_backward(weight) - return output if not return_residual else (output, x) - - @staticmethod - @custom_bwd - def backward(ctx, grad_output, *args): - module = ctx.module - communicator = ctx.communicator - grad_output = grad_output.contiguous() - if ctx.return_residual: - (grad_input,) = args - grad_input = grad_input.contiguous() - - if ctx.compute_weight_gradient: - x, weight = ctx.saved_tensors - else: - x, weight = (None, *ctx.saved_tensors) - - batch_shape = grad_output.shape[:-1] - batch_dim = batch_shape.numel() - grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) - - total_weight = communicator.all_gather(weight, module) - - # compute weight grad - if ctx.needs_input_grad[1]: - assert ctx.compute_weight_gradient - grad_weight, grad_bias = linear_bias_wgrad( - x.reshape(batch_dim, x.shape[-1]), - grad_output, - ctx.needs_input_grad[2], - ) - - grad_weight, grad_weight_sync = communicator.reduce_scatter(grad_weight, module, op=dist.ReduceOp.AVG) - if grad_bias is not None: - grad_bias, grad_bias_sync = communicator.reduce_scatter( - grad_bias, module, op=dist.ReduceOp.AVG, is_bias=True - ) - else: - grad_weight = None - grad_bias = grad_output if ctx.needs_input_grad[2] else None - - if ctx.needs_input_grad[0]: - if not ctx.return_residual: - grad_input = F.linear(grad_output, total_weight.t()) # 
pylint: disable=E1102 - else: - grad_input = torch.addmm( - grad_input.reshape(batch_dim, grad_input.shape[-1]), - grad_output, - total_weight, - ) - grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) - else: - grad_input = None - - del total_weight - - if ctx.needs_input_grad[1]: - if grad_weight_sync: - grad_weight_sync.wait() - if grad_bias is not None and grad_bias_sync is not None: - grad_bias_sync.wait() - - return grad_input, grad_weight, grad_bias, None, None, None, None - - -def fused_dense_func( - x: Tensor, - weight: Tensor, - bias: Optional[Tensor] = None, - return_residual: bool = False, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, - gather_dim: int = 0, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - return FusedDenseFunc.apply( - x, - weight, - bias, - return_residual, - process_group, - sequence_parallel, - gather_dim, - dtype_eligible, - ) - - -def megatron_fused_dense_func( - x: Tensor, - weight: Tensor, - bias: Optional[Tensor] = None, - return_residual: bool = False, - process_group: Optional[ProcessGroup] = None, - sequence_parallel: bool = True, - gather_dim: int = 0, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - return MegatronFusedDenseFunc.apply( - x, - weight, - bias, - return_residual, - process_group, - sequence_parallel, - gather_dim, - dtype_eligible, - ) - - -def isp_fused_dense_func( - x: Tensor, - weight: Tensor, - module, - communicator, - bias: Optional[Tensor] = None, - return_residual: bool = False, -): - dtype_eligible = x.dtype in [torch.float16, torch.bfloat16] or ( - x.dtype == torch.float32 and torch.is_autocast_enabled() - ) - return ISPFusedDenseFunc.apply( - x, - weight, - bias, - module, - communicator, - return_residual, - dtype_eligible, - ) - - -def is_moe_param(param: torch.Tensor) -> bool: - if hasattr(param, "is_expert") and param.is_expert: - return True - return False - - -def Silu(w1_o, w2_o): - return F.silu(w1_o) * w2_o - - -Silu = torch.jit.script(Silu) - - -def unpack_qkv_before_attn(cur_input=None, cu_seqlens=None, padding_v: int = 0): - """ - qkv: the shape is (1, packed_length, three, head_num, head_dim) - kv: the shape is (1, packed_length, two, head_num, head_dim) - q/k/v: the shape is (1, packed_length, head_num, head_dim) - - Return: - output: the shape is (micro_bsz, seq_len, three, head_num, head_dim) for qkv - (micro_bsz, seq_len, two, head_num, head_dim) for kv - (micro_bsz, seq_len, head_num, head_dim) for q/k/v - """ - if cu_seqlens is None or cur_input is None: - raise ValueError("cu_seqlens and cur_input must be provided.") - - assert cur_input.shape[0] == 1 - cur_input = cur_input.squeeze(0) - - sequences = [] - for i in range(len(cu_seqlens) - 1): - sequences.append(cur_input[cu_seqlens[i] : cu_seqlens[i + 1]]) - - padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_v) - - return padded_sequences - - -def pack_output_after_attn(cur_input=None, cu_seqlens=None, padding_v: int = 0): - """ - cur_input: the shape is (micro_bsz, seq_len, hidden_size) - - Return: - output: the shape is (1, packed_length, hidden_size) - """ - if cu_seqlens is None or cur_input is None: - raise ValueError("cu_seqlens and cur_input must be provided.") - - packed_len_ = gpc.config.data.micro_bsz * gpc.config.data.seq_len - output_shape = list(cur_input.shape) - output_shape[0] = 1 - 
output_shape[1] = packed_len_ - - output = torch.full(output_shape, fill_value=padding_v, device=cur_input.device, dtype=cur_input.dtype) - for i in range(len(cu_seqlens) - 1): - length = cu_seqlens[i + 1] - cu_seqlens[i] - output[0, cu_seqlens[i] : cu_seqlens[i + 1]] = cur_input[i, 0:length] - - return output + return kwargs diff --git a/internlm/monitor/__init__.py b/internlm/monitor/__init__.py index 56c8309b..2bcfa2cc 100644 --- a/internlm/monitor/__init__.py +++ b/internlm/monitor/__init__.py @@ -1,8 +1,9 @@ -from .monitor import initialize_monitor_manager, send_alert_message +from .monitor import initialize_monitor_manager, internevo_monitor, send_alert_message from .utils import set_env_var __all__ = [ "send_alert_message", "initialize_monitor_manager", "set_env_var", + "internevo_monitor", ] diff --git a/internlm/monitor/monitor.py b/internlm/monitor/monitor.py index cca9ca44..fc33de62 100644 --- a/internlm/monitor/monitor.py +++ b/internlm/monitor/monitor.py @@ -1,17 +1,59 @@ import fcntl +import logging import os +import shutil import signal import socket import time +import traceback from contextlib import contextmanager +from functools import wraps from threading import Thread +from internlm.accelerator.abstract_accelerator import get_accelerator from internlm.core.context import global_context as gpc from internlm.monitor.alert import send_feishu_msg_with_webhook from internlm.utils.common import SingletonMeta from .utils import get_job_key, set_env_var +logger = logging.getLogger(__file__) +internlm_accelerator = get_accelerator() + + +def internevo_monitor(feishu_alert=True, clean_run=True): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + if feishu_alert: + with initialize_monitor_manager( + job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address + ): + return execute_with_exception_handling(func, *args, **kwargs) + else: + return execute_with_exception_handling(func, *args, **kwargs) + + def execute_with_exception_handling(func, *args, **kwargs): + if not clean_run: + return func(*args, **kwargs) + try: + return func(*args, **kwargs) + except Exception: + hostname = socket.gethostname() + logger.error( + f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", + ) + finally: + devices_per_node = internlm_accelerator.device_count() + local_rank = gpc.get_global_rank() % devices_per_node + if gpc.config.data.use_shm and local_rank == 0: + if os.path.exists(gpc.config.data.shm_path): + shutil.rmtree(gpc.config.data.shm_path) + + return wrapper + + return decorator + def send_alert_message(address: str = None, title: str = None, message: str = None): """ diff --git a/internlm/solver/__init__.py b/internlm/solver/__init__.py index c9fd1244..86661138 100644 --- a/internlm/solver/__init__.py +++ b/internlm/solver/__init__.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from .optimizer import HybridZeroOptimizer +from .optimizer import HybridZeroOptimizer, HybridZeroOptimizer_v2 from .schedulers import Beta2Scheduler, FineTuneCosineAnnealingWarmupLR -__all__ = ["Beta2Scheduler", "FineTuneCosineAnnealingWarmupLR", "HybridZeroOptimizer"] +__all__ = ["Beta2Scheduler", "FineTuneCosineAnnealingWarmupLR", "HybridZeroOptimizer", "HybridZeroOptimizer_v2"] diff --git a/internlm/solver/optimizer/__init__.py b/internlm/solver/optimizer/__init__.py index 7c6a1c64..55070fc3 100644 --- a/internlm/solver/optimizer/__init__.py +++ b/internlm/solver/optimizer/__init__.py @@ -2,6 
+2,7 @@
 # -*- encoding: utf-8 -*-
 from .fsdp_optimizer import FSDPadaptOptimizer
-from .hybrid_zero_optim import HybridZeroOptimizer, reload_zero_fp32_buff
+from .hybrid_zero_optim import HybridZeroOptimizer
+from .hybrid_zero_optim_v2 import HybridZeroOptimizer_v2
-__all__ = ["FSDPadaptOptimizer", "HybridZeroOptimizer", "reload_zero_fp32_buff"]
+__all__ = ["FSDPadaptOptimizer", "HybridZeroOptimizer", "HybridZeroOptimizer_v2"]
diff --git a/internlm/solver/optimizer/compatible_adamw.py b/internlm/solver/optimizer/compatible_adamw.py
new file mode 100644
index 00000000..bca8c274
--- /dev/null
+++ b/internlm/solver/optimizer/compatible_adamw.py
@@ -0,0 +1,52 @@
+from typing import Tuple
+
+import torch
+
+from internlm.accelerator import AcceleratorType, get_accelerator
+from internlm.core.context import global_context as gpc
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)
+internlm_accelerator = get_accelerator()
+
+try:
+    from torch_npu.optim import NpuFusedAdamW
+
+    del NpuFusedAdamW
+
+    npu_adamw_impl = True
+except (ModuleNotFoundError, ImportError):
+    npu_adamw_impl = False
+
+
+# TODO: expose a unified interface to the upper layer that every underlying implementation can support; decide which parameters to keep and which to omit.
+def new_compatible_adamw(params, lr: float = 0.001, betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8):
+    """
+    Return a compatible AdamW instance.
+    """
+    adam_extra_kwargs = {}
+    backend = internlm_accelerator.get_accelerator_backend()
+
+    if backend is AcceleratorType.GPU and torch.__version__ >= "2.1.0":
+        if gpc.is_rank_for_log():
+            logger.warning(
+                "Use fused AdamW to avoid nan grad norm when "
+                "the model size is large and use_fp32_norm=True. Please note this!"
+            )
+        adam_extra_kwargs["fused"] = True
+    elif backend is AcceleratorType.NPU:
+        if gpc.is_rank_for_log():
+            logger.warning(
+                "Use normal AdamW; NPU fused_adamw currently has "
+                "accuracy issues and is not supported yet. Please note this!"
+            )
+        # TODO: support npu version adamw
+    elif backend is AcceleratorType.DIPU:
+        if gpc.is_rank_for_log():
+            logger.warning("Use torch.optim.AdamW rather than deeplink adamw. Please note this!")
+        # TODO: support deeplink version adamw
+    else:
+        if gpc.is_rank_for_log():
+            logger.warning("Use torch.optim.AdamW rather than FusedAdamW. Please note this!")
+
+    return torch.optim.AdamW(params, lr=lr, betas=betas, eps=eps, **adam_extra_kwargs)
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index c4d36be2..5461f922 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -11,7 +11,6 @@
 from torch.optim import Optimizer
 from internlm.accelerator import get_accelerator
-from internlm.core.communication.utils import ParamAsyncBcastHandler
 from internlm.core.context import Config, ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.parallel_context import (
@@ -21,6 +20,7 @@
     IS_TENSOR_ZERO_PARALLEL,
     IS_WEIGHT_ZERO_PARALLEL,
 )
+from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler
 from internlm.monitor import send_alert_message
 from internlm.solver.optimizer.store import (
     BucketStore,
@@ -940,16 +940,14 @@ def load_state_dict(self, states):
         if "zero_devide_optim_plan" in states:
             self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
-
-def reload_zero_fp32_buff(optimizer):
-    # If we use AMP optimizer, we need to update its fp32 buffer as newly loaded weights value.
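For context, a minimal usage sketch of the new_compatible_adamw helper introduced above (a sketch only: it assumes an initialized InternEvo context, since the helper consults gpc and the active accelerator, and the toy module is just a placeholder):

import torch

from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw

# Any iterable of parameters or param groups accepted by torch.optim.AdamW will do.
model = torch.nn.Linear(8, 8).cuda()  # fused AdamW requires CUDA parameters
optimizer = new_compatible_adamw(params=model.parameters(), lr=1e-4, betas=(0.9, 0.95), eps=1e-8)
# On CUDA with torch >= 2.1.0 this returns torch.optim.AdamW(..., fused=True);
# on NPU/DIPU or with older torch it falls back to the unfused torch.optim.AdamW.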
- # Or we must ensure that loading model weights must be done before zero is initialized. - if isinstance(optimizer, HybridZeroOptimizer): - for group_id, param_group in enumerate(optimizer.optim.param_groups): - if optimizer.param_group_has_params[group_id]: + def reload_zero_fp32_buff(self): + # If we use AMP optimizer, we need to update its fp32 buffer as newly loaded weights value. + # Or we must ensure that loading model weights must be done before zero is initialized. + for group_id, param_group in enumerate(self.optim.param_groups): + if self.param_group_has_params[group_id]: # flatten fp16 params have already been updated by 'load_model_checkpoint' - fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group( - optimizer._zero_local_rank[group_id], group_id + fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group( + self._zero_local_rank[group_id], group_id ) # param_group["params"] is fp32 flatten optimizer states of this zero rank. param_group["params"][0].data.copy_(fp16_flat_current_rank.float()) diff --git a/internlm/solver/optimizer/hybrid_zero_optim_v2.py b/internlm/solver/optimizer/hybrid_zero_optim_v2.py new file mode 100644 index 00000000..d231f407 --- /dev/null +++ b/internlm/solver/optimizer/hybrid_zero_optim_v2.py @@ -0,0 +1,929 @@ +# this code is inspired by the DeepSpeed library and implemented with our own design from scratch +import math +from functools import partial +from typing import Dict, List + +import torch +import torch.distributed as dist +from torch.optim import Optimizer + +from internlm.core.context import Config, ParallelMode +from internlm.core.context import global_context as gpc +from internlm.core.context.parallel_context import ( + IS_REPLICA_ZERO_PARALLEL, + IS_TENSOR_DATA_PARALLEL, + IS_TENSOR_EXPERT_DATA_PARALLEL, + IS_TENSOR_ZERO_PARALLEL, + IS_WEIGHT_ZERO_PARALLEL, +) +from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler +from internlm.monitor import send_alert_message +from internlm.solver.optimizer.store import ( + BucketStore_v2, + GradientStore_v2, + ParameterStore_v2, +) +from internlm.solver.optimizer.utils import ( + DynamicGradScaler, + flatten, + reduce_tensor, + release_param_grad, + sync_param, +) +from internlm.utils.common import get_current_device +from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel + +from .base_optimizer import BaseOptimizer +from .utils import compute_norm + + +def calculate_global_norm_from_list(global_norm_groups): + """Compute total from a list of norms""" + total_norm = 0.0 + for norm in global_norm_groups.values(): + total_norm += norm**2.0 + return math.sqrt(total_norm) + + +logger = get_logger(__file__) + + +class HybridZeroOptimizer_v2(BaseOptimizer): + """Optimizer used for ZeRO-1 and ZeRO-2.""" + + def __init__( + self, + optimizer: Optimizer, + grad_scal_cfg: Config = None, + zero_cfg: Config = None, + param_bcast_sync_handler: ParamAsyncBcastHandler = None, + isp_communicator=None, + partition_grad: bool = False, # zero 2 + cpu_offload: bool = False, # cpu offload + master_weights: bool = True, # master weights + ): + if gpc.config.model.dtype is torch.float32: + initial_scale = 1 + else: + initial_scale = grad_scal_cfg.fp16.initial_scale + min_scale = grad_scal_cfg.fp16.min_scale + growth_interval = grad_scal_cfg.fp16.growth_interval + growth_factor = grad_scal_cfg.growth_factor + backoff_factor = grad_scal_cfg.backoff_factor + hysteresis = grad_scal_cfg.hysteresis + 
max_scale = grad_scal_cfg.max_scale + + # Zero related args + self._reduce_bucket_size = zero_cfg.reduce_bucket_size + self._all_gather_size = zero_cfg.all_gather_size + self._clip_grad_norm = zero_cfg.clip_grad_norm + self._overlap_sync_grad = zero_cfg.overlap_sync_grad + self._overlap_sync_param = zero_cfg.overlap_sync_param + self.use_isp = is_using_isp() + + self._param_bcast_sync_handler = param_bcast_sync_handler + + if self._overlap_sync_param: + assert self._param_bcast_sync_handler is not None + + self._isp_communicator = isp_communicator + + super().__init__(optim=optimizer) + + self._dtype = self.optim.param_groups[0]["params"][0].dtype + self._element_size = self.optim.param_groups[0]["params"][0].element_size() + + # stage 2 + self._partition_grads = partition_grad + self._cpu_offload = cpu_offload + + # if process_group is none, will use the default one + self._local_rank = gpc.get_local_rank(ParallelMode.DATA) + self._world_size = gpc.get_world_size(ParallelMode.DATA) + + self._zero_local_rank = [] + self._zero_world_size = [] + self._zero_parallel_mode = [] + + # working and master params for mixed precision training + # master params: params that are splited into the current rank, fp32 params + # working params: the original complete params, fp16 params + self._working_param_groups = dict() + self._master_param_groups_of_current_rank = dict() + + self.grad_scaler = DynamicGradScaler( + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale, + ) + + # master weights copy + self._master_weights = master_weights + + # check argument conflict + self._sanity_checks() + + # ParameterStore_v2 will manage the tensor buffers used for zero + # it will not manage the tensors used by mixed precision training + parallel_mode = ParallelMode.WEIGHT_DATA if self.use_isp else ParallelMode.DATA + self._param_store = ParameterStore_v2(ParallelMode.ZERO1) + self._grad_store = GradientStore_v2(parallel_mode, partition_grad=partition_grad) + self._bucket_store: List[BucketStore_v2] = [] + self._accum_grad_buckets: List[BucketStore_v2] = [] + + self.rank_unique_id = ( + f"gpus-{gpc.get_world_size(ParallelMode.GLOBAL)}_" + + f"wp-{gpc.get_local_rank(ParallelMode.WEIGHT)}_" + + f"tp-{gpc.get_local_rank(ParallelMode.TENSOR)}_" + + f"dp-{gpc.get_local_rank(ParallelMode.DATA)}_" + + f"pp-{gpc.get_local_rank(ParallelMode.PIPELINE)}_" + + f"zo-{gpc.get_local_rank(ParallelMode.ZERO1)}.pt" + ) + + self.zero_1_5 = False + + # iterate over the param group in the optimizer + # partition these param groups for data parallel training + # and add buffers to parameter store for future access + for group_id, param_group in enumerate(self.optim.param_groups): + group_params = [] + for param in param_group["params"]: + if param.requires_grad: + setattr(param, "group_id", group_id) + group_params.append(param) + + param_group["dtype"] = group_params[0].dtype if len(group_params) != 0 else None + + zero_mode = param_group["optimizer_mode"] + self._zero_local_rank.append(gpc.get_local_rank(zero_mode)) + self._zero_world_size.append(gpc.get_world_size(zero_mode)) + self._zero_parallel_mode.append(zero_mode) + + # add the working params to working_param_groups for bookkeeping + self._working_param_groups[group_id] = group_params + master_param_current_rank = self._create_master_param_current_rank(group_id, group_params) + self._master_param_groups_of_current_rank[group_id] = 
master_param_current_rank + + # need to replace the params in the `params` field in the optimizer + # so that when the optimizer calls step(), it only updates the tensors + # managed by this data parallel rank + param_group["params"] = master_param_current_rank + + if self._is_moe_group(param_group): + grad_reduce_mode = ParallelMode.EXPERT_DATA + elif param_group["name"] != "embed_head" and self.use_isp: + grad_reduce_mode = ParallelMode.WEIGHT_DATA + else: + grad_reduce_mode = ParallelMode.DATA + self._bucket_store.append(BucketStore_v2(group_id, grad_reduce_mode, zero_mode=zero_mode)) + self._accum_grad_buckets.append(BucketStore_v2(group_id, grad_reduce_mode, zero_mode=zero_mode)) + + if gpc.get_world_size(grad_reduce_mode) != gpc.get_world_size(zero_mode): + self.zero_1_5 = True + + # initialize communication stream for + # communication-computation overlapping + self._comm_stream = torch.cuda.Stream(priority=0) + + self.skip_grad_reduce = False + + self._attach_reduction_hook() + + @property + def dtype(self): + return self._dtype + + @property + def num_param_groups(self): + return len(self._working_param_groups) + + @property + def loss_scale(self) -> float: + return self.grad_scaler.scale.item() + + def _is_moe_group(self, param_group): + return "moe" in param_group.keys() and param_group["moe"] + + def _wait_reduce_scatter_and_accumulate_grads(self, param): + param_size = param.numel() + + group_id = getattr(param, "group_id") + current_bucket = self._accum_grad_buckets[group_id] + + # check if the bucket is full + # if full, will reduce the grads already in the bucket + # after reduction, the bucket will be empty + if current_bucket.num_elements_in_bucket() + param_size > self._reduce_bucket_size: + self._accum_grads_store_in_bucket(current_bucket) + + # otherwise, add the parameter into bucket. + current_bucket._num_elements_in_bucket += param.numel() + current_bucket._param_list.append(param) + + def _accum_grads_store_in_bucket(self, bucket: BucketStore_v2) -> None: + for _param in bucket.get_param(): + if not hasattr(_param, "isp_reduce_scatter_name"): + continue + + # wait and accumulate gardient. + _key = getattr(_param, "isp_reduce_scatter_name") + _grad, _comm_handle = self._isp_communicator.reduce_scatter_handlers[_key] + _comm_handle.wait() + _param.grad.add_(_grad) + + # release cuda memory. 
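+            # When the memory pool is enabled, hand the reduce-scattered gradient buffer back to
+            # the ISP communicator's pool and clear the handler entry so the buffer can be reused.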
+ if self._isp_communicator.enable_memory_pool: + self._isp_communicator.memory_pool.free_reduce_scatter_memory( + key=tuple(_grad.size()), index=_grad.index + ) + _grad = None + self._isp_communicator.reduce_scatter_handlers[_key] = None + + bucket.reset_all() + + def accumulate_left_grads_after_backward(self): + if self._isp_communicator is None or self._isp_communicator.overlap is False: + return + + for group_id in range(self.num_param_groups): + self._accum_grads_store_in_bucket(self._accum_grad_buckets[group_id]) + + def clip_grad_norm(self, model, max_norm): + # will conduct in the step() + pass + + def _sanity_checks(self): + for param_group in self.optim.param_groups: + group_params = param_group["params"] + for param in group_params: + if not hasattr(param, "skip_zero_check") or param.skip_zero_check is False: + assert ( + param.dtype == self._dtype + ), f"Parameters are expected to have the same dtype `{self._dtype}`, but got `{param.dtype}`" + + def add_attr_for_splited_param(self, origin_param, splited_param_current_rank): + + if hasattr(origin_param, IS_TENSOR_ZERO_PARALLEL): + value = getattr(origin_param, IS_TENSOR_ZERO_PARALLEL) + setattr(splited_param_current_rank, IS_TENSOR_ZERO_PARALLEL, value) + + if hasattr(origin_param, IS_WEIGHT_ZERO_PARALLEL): + value = getattr(origin_param, IS_WEIGHT_ZERO_PARALLEL) + setattr(splited_param_current_rank, IS_WEIGHT_ZERO_PARALLEL, value) + + if hasattr(origin_param, IS_REPLICA_ZERO_PARALLEL): + value = getattr(origin_param, IS_REPLICA_ZERO_PARALLEL) + setattr(splited_param_current_rank, IS_REPLICA_ZERO_PARALLEL, value) + + if hasattr(origin_param, IS_TENSOR_DATA_PARALLEL): + value = getattr(origin_param, IS_TENSOR_DATA_PARALLEL) + setattr(splited_param_current_rank, IS_TENSOR_DATA_PARALLEL, value) + + if hasattr(origin_param, IS_TENSOR_EXPERT_DATA_PARALLEL): + value = getattr(origin_param, IS_TENSOR_EXPERT_DATA_PARALLEL) + setattr(splited_param_current_rank, IS_TENSOR_EXPERT_DATA_PARALLEL, value) + + if hasattr(origin_param, "block_name"): + value = getattr(origin_param, "block_name") + setattr(splited_param_current_rank, "block_name", value) + + def _create_master_param_current_rank(self, group_id, param_list): + # split each param evenly by world size + params_current_rank = [] + device = "cpu" if self._cpu_offload else get_current_device() + zero_world_size = self._zero_world_size[group_id] + + for param in param_list: + padding_size = (zero_world_size - param.numel() % zero_world_size) % zero_world_size + self._param_store.record_param_padding_size(param, padding_size) + + with torch.no_grad(): + if padding_size > 0: + padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size]) + # reset working params' ptr when no master weights + if self._master_weights is False: + param.data = padding_param[: param.numel()].view(param.shape) + else: + padding_param = param.data.view(-1) + + splited_params = padding_param.split(padding_param.numel() // zero_world_size) + splited_params = splited_params[self._zero_local_rank[group_id]] + + # use fp32 when master_weights is True + if self._master_weights is True: + splited_param_current_rank = splited_params.detach().float().to(device) + else: + splited_param_current_rank = splited_params + + self.add_attr_for_splited_param(param, splited_param_current_rank) + + params_current_rank.append(splited_param_current_rank) + self._param_store.link_master_and_working_param(splited_param_current_rank, param) + + return params_current_rank + + ####################### + # Reduction 
Functions # + ####################### + + def _run_reduction(self): + for group_id in range(self.num_param_groups): + current_bucket = self._bucket_store[group_id] + dp_parallel_mode = current_bucket.get_dp_parallel_mode() + reduce_group = gpc.get_group(dp_parallel_mode) + world_size = gpc.get_world_size(dp_parallel_mode) + local_rank = gpc.get_local_rank(dp_parallel_mode) + if current_bucket.num_elements_in_bucket() > 0: + stream = self._comm_stream + # waiting for ops in the default stream finishing + stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(stream): + current_bucket.build_grad_in_bucket(stream) + + flat_grads = current_bucket.get_flatten_grad() + flat_grads /= world_size + + # ready to add other tensors to bucket + current_bucket.reset_num_elements_in_bucket() + group_id = current_bucket.get_param_group_id() + + grad_dtype = flat_grads.dtype + + if not self._partition_grads: + if not self.zero_1_5: + reduce_grads = torch.zeros( + flat_grads.numel() // self._zero_world_size[group_id], + dtype=grad_dtype, + device=get_current_device(), + ) + dist.reduce_scatter_tensor(reduce_grads, flat_grads, group=reduce_group) + + if reduce_grads.dtype != grad_dtype: + reduce_grads = reduce_grads.to(grad_dtype) + + grad_in_bucket_current_rank = current_bucket.get_grad()[local_rank] + self._update_unpartitoned_grad(grad_in_bucket_current_rank, reduce_grads, group_id) + else: + # zero 1.5 + dist.all_reduce(flat_grads, group=reduce_group) + if flat_grads.dtype != grad_dtype: + flat_grads = flat_grads.to(grad_dtype) + flat_grads_per_rank = flat_grads.split( + flat_grads.numel() // self._zero_world_size[group_id] + ) + grad_in_bucket = current_bucket.get_grad() + self._update_unpartitoned_grad(grad_in_bucket.values(), flat_grads_per_rank, group_id) + else: + flat_grads_list = list(flat_grads.split(len(flat_grads) // world_size)) + recieved_grad = torch.zeros_like(flat_grads_list[0]) + dist.reduce_scatter(recieved_grad, flat_grads_list, group=reduce_group) + + if recieved_grad.dtype != grad_dtype: + recieved_grad = recieved_grad.to(grad_dtype) + + grad_in_bucket_current_rank = current_bucket.get_grad()[self._zero_local_rank[group_id]] + self._update_partitoned_grad(grad_in_bucket_current_rank, recieved_grad, group_id, 1) + + current_bucket.reset() + + def _update_unpartitoned_grad(self, origin_grad_list: List, flat_grad_list: List, group_id: int) -> None: + if not self.zero_1_5: + sync_param(flat_grad_list, origin_grad_list) + for grad in origin_grad_list: + param_id = self._bucket_store[group_id].get_param_id_of_grad(grad) + self._add_grad(grad, self._zero_world_size[group_id], group_id, param_id) + else: + for rank, grad_list in enumerate(origin_grad_list): + sync_param(flat_grad_list[rank], grad_list) + for grad in grad_list: + param_id = self._bucket_store[group_id].get_param_id_of_grad(grad) + self._add_grad(grad, self._zero_world_size[group_id], group_id, param_id, rank) + + def _update_partitoned_grad( + self, + origin_grad_list: List, + flat_grad: torch.Tensor, + group_id: int, + partition_num: int, + ) -> None: + sync_param(flat_grad, origin_grad_list) + for grad in origin_grad_list: + param_id = self._bucket_store[group_id].get_param_id_of_grad(grad) + self._add_grad(grad, partition_num, group_id, param_id) + + def _add_grad( + self, + grad: torch.Tensor, + partition_num: int, + group_id: int, + param_id: int, + rank: int = 0, + ) -> None: + if len(self._grad_store.get_partitioned_gradients_by_param_id(group_id, param_id)) < partition_num: + 
self._grad_store.append_gradients_by_param_id(grad, group_id, param_id) + else: + self._grad_store.add_gradients_by_param_id(grad, rank, group_id, param_id) + + def _add_to_bucket(self, param, group_id): + param_size = param.numel() + + # check if the bucket is full + # if full, will reduce the grads already in the bucket + # or got a grad of param from another group + # after reduction, the bucket will be empty + if ( + self._bucket_store[group_id].num_elements_in_bucket() + param_size > self._reduce_bucket_size + or group_id != self._bucket_store[group_id].get_param_group_id() + ): + self._run_reduction() + + padding_size = self._param_store.get_param_padding_size(param) + self._bucket_store[group_id].add_param_grad(param, padding_size) + + ################################ + # torch.optim.Optimizer methods + ################################ + + def backward(self, loss, retain_graph=False): + assert not ( + self._partition_grads and self.skip_grad_reduce + ), "ZeRO2(partition_grads) and no_sync are not compatible" + + loss = self.loss_scale * loss + loss.backward(retain_graph=retain_graph) + + def backward_by_grad(self, tensor, grad): + assert not ( + self._partition_grads and self.skip_grad_reduce + ), "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible" + + torch.autograd.backward(tensor, grad) + + def zero_grad(self, set_to_none=True): + """ + Set parameter gradients to zero. If set_to_none = True, gradient + will be set to None to save memory. + + :param set_to_none: Whether set the gradient to None. Default value is True. + :type set_to_none: bool + """ + for _, param_group in self._working_param_groups.items(): + for param in param_group: + if set_to_none: + param.grad = None + else: + if param.grad is not None: + param.grad.detach() + param.grad.zero_() + for group_id in range(self.num_param_groups): + self._bucket_store[group_id].reset_all() + + #################### + # Update Parameter # + #################### + + def step(self, closure=None): + assert closure is None, "closure is not supported by step()" + + self._reduce_grad(self._partition_grads) + # clear reduced grads + torch.cuda.synchronize() + self.zero_grad() + + # record all grads for unscale and clip + grad_partition_groups = [] + + # sometimes not all params are 'really' working + # for instance, when layer drop, the dropped layer has no grad + # and should not be updated + real_working_params = dict() + real_master_params = dict() + real_master_grads = dict() + total_norms = {} + + for group_id in range(self.num_param_groups): + master_params = self._master_param_groups_of_current_rank[group_id] + real_working_params[group_id] = [] + real_master_params[group_id] = [] + real_master_grads[group_id] = [] + grad_index = 0 if not self.zero_1_5 else self._zero_local_rank[group_id] + + for splited_param in master_params: + working_param = self._param_store.master_to_working_param[id(splited_param)] + # if a working param requires grad and has no grad + # it is not 'really' working, e.g. 
the droped layer + # else the splited grad should be attached to the splited param + grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param)) + if len(grads) > 0: + real_working_params[group_id].append(working_param) + grad = grads[grad_index] + # no need to copy fp32 grad if master_weights is False + if self._master_weights: + grad = grad.to(splited_param.dtype).to(splited_param.device) + splited_param.grad = grad + grad_partition_groups.append(grad) + real_master_params[group_id].append(splited_param) + real_master_grads[group_id].append(splited_param.grad) + + # compute norm + param_group = real_master_params[group_id] + working_grads = real_master_grads[group_id] + + group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default" + group_name = f"{group_id}_{group_name}" + total_norms[group_name] = self._compute_norm( + group_id=group_id, gradients=working_grads, parameters=param_group + ) + + self._grad_store.reset_grads_by_group_id(group_id) + + # update the params in the optimizer + self.optim.param_groups[group_id]["params"] = real_master_params[group_id] + + # check norm + found_inf = False + found_nan = False + + if -1 in total_norms.values(): + found_inf = True + + if -2 in total_norms.values(): + found_nan = True + + if gpc.config.model.dtype is not torch.float32: + self.grad_scaler.update(found_inf) + + # update loss scale if overflow occurs + if found_inf: + if gpc.is_rank_for_log(): + logger.warning("Overflow occurs, please check it.") + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message="Overflow occurs, please check it.", + ) + self._grad_store._grads_of_params = dict() + self.zero_grad() + return False, total_norms + + if found_nan: + if gpc.is_rank_for_log(): + logger.warning("Nan grad norm occurs, please check it.") + send_alert_message( + address=gpc.config.monitor.alert.feishu_alert_address, + message="Nan grad norm occurs, please check it.", + ) + self._grad_store._grads_of_params = dict() + self.zero_grad() + return False, total_norms + + global_norm_groups = {} + if self._clip_grad_norm > 0: + for group_name, norm in total_norms.items(): + global_norm_groups[group_name] = norm**0.5 + + # unscale and clip grads + global_norm_l2 = calculate_global_norm_from_list(global_norm_groups) + self._unscale_and_clip_grads(grad_partition_groups, global_norm_l2) + + # update the parameters + self.optim.step() + + # release the grad + grad_partition_groups = [] + for group_id in range(self.num_param_groups): + release_param_grad(self._master_param_groups_of_current_rank[group_id]) + + # update working partition updated by the current rank + device = get_current_device() + handles = [] + gathered_params_list = [] + working_params_list = [] + master_params_list = [] + for group_id in range(self.num_param_groups): + if self._zero_world_size[group_id] > 1: + master_working_param = self.optim.param_groups[group_id]["params"] + + if len(master_working_param) == 0: + continue + + # do all_gather at fused block granularity + # In this way, param_overlap is available + all_gather_master_params = [] + all_gather_working_params = [] + sum_numel_size = 0 + for idx in range(len(master_working_param)): + working_param = real_working_params[group_id][idx] + block_name = master_working_param[idx].block_name + # for the same block, all params are arranged in consecutive order + # when enter next block, check numel_size to determine whether to execute all_gather + if idx > 0 and 
block_name != master_working_param[idx - 1].block_name: + if sum_numel_size >= self._all_gather_size: + self.all_gather_params( + group_id, + all_gather_master_params, + all_gather_working_params, + gathered_params_list, + working_params_list, + master_params_list, + handles, + device, + ) + all_gather_master_params = [] + all_gather_working_params = [] + sum_numel_size = 0 + + sum_numel_size += master_working_param[idx].numel() * self._element_size + all_gather_master_params.append(master_working_param[idx]) + all_gather_working_params.append(working_param) + + # clear the last fused block + if len(all_gather_master_params) > 0: + self.all_gather_params( + group_id, + all_gather_master_params, + all_gather_working_params, + gathered_params_list, + working_params_list, + master_params_list, + handles, + device, + ) + all_gather_master_params = [] + all_gather_working_params = [] + else: + # if zero_world_size==1, directly update working param with master param + for working_param, master_param in zip(real_working_params[group_id], real_master_params[group_id]): + working_param.data.copy_(master_param.view_as(working_param)) + + if not self._overlap_sync_param: + for gather_idx in range(len(handles)): + handles[gather_idx].wait() + # reorganize gatherd params to update working param + # [[A1, B1], [A2, B2]] -> [[A1.reshape, A2.reshape], [B1.reshape, B2.reshape]] + master_params_all_gather = master_params_list[gather_idx] + gathered_params = gathered_params_list[gather_idx] + all_splited_param_list = [] + offset = 0 + for p in master_params_all_gather: + param_size = p.numel() + all_splited_param = [] + for all_params in gathered_params: + split_params = all_params[offset : offset + param_size].reshape(p.shape) + all_splited_param.append(split_params) + offset += param_size + all_splited_param_list.append(all_splited_param) + + # Update working parameters + for working_param, all_splited_param in zip(working_params_list[gather_idx], all_splited_param_list): + working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].view_as(working_param)) + + for group_id in range(self.num_param_groups): + self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id] + + for group_name, global_norm in global_norm_groups.items(): + global_norm_groups[group_name] = global_norm / float(self.loss_scale) + + return True, global_norm_groups + + def all_gather_params( + self, + group_id, + all_gather_master_params, + all_gather_working_params, + gathered_params_list, + working_params_list, + master_params_list, + handles, + device, + ): + # fuse params to do all_gather + handle, gathered_params = self.gather_fused_params( + all_gather_master_params, + self._zero_world_size[group_id], + gpc.get_group(self._zero_parallel_mode[group_id]), + device, + ) + if self._overlap_sync_param: + self._param_bcast_sync_handler.add_allgather_handle( + handle, + all_gather_master_params, + all_gather_working_params, + gathered_params, + all_gather_working_params[0].block_name, + ) + else: + gathered_params_list.append(gathered_params) + working_params_list.append(all_gather_working_params) + master_params_list.append(all_gather_master_params) + handles.append(handle) + + def gather_fused_params(self, params, world_size, group, device): + # Flatten and concatenate all parameters into a single tensor + flattened_params = torch.cat([p.view(-1) for p in params]).to(device).to(self._dtype) + + # Prepare the buffer for all_gather + gathered_params = [ + 
torch.empty_like(flattened_params, device=device, dtype=self._dtype) for _ in range(world_size) + ] + # Perform the all_gather operation + handle = dist.all_gather(gathered_params, flattened_params, group=group, async_op=True) + + return handle, gathered_params + + def _compute_norm(self, group_id, gradients, parameters): + + if len(parameters) == 0: + return 0 + + norm = 0 + if self._clip_grad_norm > 0: + # this norm is before scaling, it will be very large + norm = compute_norm( + gradients=gradients, parameters=parameters, zero_mode=self._zero_parallel_mode[group_id] + ) + + return norm + + ############################# + # Mixed Precision Utilities # + ############################# + + def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups): + # compute combined scale factor for this group + div_scale = float(self.loss_scale) + if self._clip_grad_norm > 0.0: + # norm is in fact norm*scale + clip = ((total_norm_groups / div_scale) + 1e-6) / self._clip_grad_norm + if clip > 1: + div_scale = clip * div_scale + + for grad in grad_groups_flat: + grad.data.mul_(1.0 / div_scale) + + ############################ + # Gradient Synchronization # + ############################ + + # this method is used to sync gradient manually + def _sync_grad(self): + for group_id in range(self.num_param_groups): + param_group = self._working_param_groups[group_id] + for param in param_group: + if param.requires_grad and param.grad is not None: + self._add_to_bucket(param, group_id) + + self._run_reduction() + + def _reduce_grad(self, partition_grad): + # if not overlapping communication (no reduction hook is attached) when zero1 + # we need to manually reduce these gradients + if not partition_grad and not self._overlap_sync_grad: + self._sync_grad() + else: + self._run_reduction() + + ############## + # State Dict # + ############## + + def state_dict(self) -> Dict: + states = {} + + grad_scaler = self.grad_scaler.state_dict() + states["grad_scaler"] = grad_scaler + optim_states = self.optim.state_dict() + states["base_optim_states"] = optim_states + + master_current_weights = {} + for group_id, params in self._master_param_groups_of_current_rank.items(): + master_current_weights[group_id] = params + states["master_current_weights"] = master_current_weights + + return states + + def load_state_dict(self, states: Dict): + """Load state dict, requires the state_dict be the pytorch form + + Args: + state_dict (dict): A pytorch form state_dict + """ + assert "grad_scaler" in states, "Not found grad_scaler state!" 
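For reference, the fused all-gather path above (gather_fused_params plus the [[A1, B1], [A2, B2]] reorganization loop) can be followed on CPU with toy shapes. The sketch below simulates dist.all_gather with a plain list and invents the parameter names and sizes; it is illustrative only, not part of the patch.

import torch

# Working params as the model holds them; A needs one element of padding to split across 2 zero ranks.
work_A = torch.zeros(5)
work_B = torch.zeros(2, 2)

# Updated master slices held by each zero rank (values are arbitrary).
master = {
    0: {"A": torch.arange(0.0, 3.0), "B": torch.arange(10.0, 12.0)},
    1: {"A": torch.arange(3.0, 6.0), "B": torch.arange(12.0, 14.0)},
}

# Each rank fuses its slices of the block into one flat buffer (as gather_fused_params does) ...
fused = {rank: torch.cat([master[rank]["A"], master[rank]["B"]]) for rank in (0, 1)}
# ... and the all_gather leaves every rank with one fused buffer per rank.
gathered = [fused[0], fused[1]]

# Reorganize [[A0, B0], [A1, B1]] -> [[A0, A1], [B0, B1]], then concatenate and drop the padding.
offset, per_param = 0, {}
for name, slice_numel in (("A", 3), ("B", 2)):
    per_param[name] = [buf[offset : offset + slice_numel] for buf in gathered]
    offset += slice_numel

work_A.copy_(torch.cat(per_param["A"])[: work_A.numel()].view_as(work_A))
work_B.copy_(torch.cat(per_param["B"])[: work_B.numel()].view_as(work_B))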
+ grad_scaler = states["grad_scaler"] + self.grad_scaler.load_state_dict(grad_scaler) + optim_states = states["base_optim_states"] + + if gpc.config.get("only_load_lr", False): + if gpc.is_rank_for_log(): + logger.info("Only load lr in param_groups, skip loading weights in optimizer...") + for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]): + pg1["lr"] = pg2["lr"] + return + + self.optim.load_state_dict(optim_states) + + master_current_weights = states["master_current_weights"] + for group_id, params in master_current_weights.items(): + if len(params) > 0: + self_params = self._master_param_groups_of_current_rank[group_id] + assert len(self_params) == len( + params + ), f"The loaded parameter shape is inconsistent, {self_params.shape} != {params.shape}" + for self_param, param in zip(self_params, params): + self_param.data.copy_(param.data) + + def reload_zero_fp32_buff(self): + for group_id, param_group in enumerate(self.optim.param_groups): + if len(param_group["params"]) > 0: + for master_param in param_group["params"]: + working_param = self._param_store.master_to_working_param[id(master_param)] + padding_size = self._param_store.get_param_padding_size(working_param) + + with torch.no_grad(): + if padding_size > 0: + padding_param = torch.nn.functional.pad(working_param.data.view(-1), [0, padding_size]) + else: + padding_param = working_param.data.view(-1) + + splited_params = padding_param.split(padding_param.numel() // self._zero_world_size[group_id]) + splited_params = splited_params[self._zero_local_rank[group_id]] + splited_param_current_rank = splited_params.detach().float() + + master_param.data.copy_(splited_param_current_rank) + + ################ + # Overlap Hook # + ################ + + def _attach_reduction_hook(self): + # we iterate over the fp16 params + # on each param, we register a hook to its AccumulateGrad object + for group_id in range(self.num_param_groups): + param_group = self._working_param_groups[group_id] + for param in param_group: + # we should not reduce the param in moe + if not param.requires_grad: + continue + + reduce_rank = None + + def _define_and_attach(param, reduce_rank=None): + # pylint: disable=W0640 + def grad_handler(group_id, param): + # if run with no_sync context, would not sync grad when backward + if not self.skip_grad_reduce: + self._add_to_bucket(param, group_id) + + reduce_scatter_checker = partial( + self._wait_reduce_scatter_and_accumulate_grads, + param=param, + ) + + def reduction_layernorm_func(): + handle = reduce_tensor( + param.grad, + dtype=None, + dst_rank=reduce_rank, + parallel_mode=ParallelMode.WEIGHT if self.use_isp else ParallelMode.TENSOR, + ) + handle.wait() + + # define hook for real gradient accumulation. + + def accum_grad_hook(*args): # pylint: disable=W0613 + reduce_scatter_checker() + + # define hook for sequence_parallel + def extra_layernorm_reduce_grad_hook(*args): # pylint: disable=W0613 + if self.skip_grad_reduce is False: + reduction_layernorm_func() + + # the grad of layernorm should be all-reduce across the global process group + # here is the first stage all-reduce in tp/wp process group + # the second stage all-reduce will be processed in reduce_grad_hook + if ( + is_using_sequence_parallel() + and hasattr(param, IS_REPLICA_ZERO_PARALLEL) + and getattr(param, IS_REPLICA_ZERO_PARALLEL) is True + ): + param.register_post_accumulate_grad_hook(extra_layernorm_reduce_grad_hook) + + # we should not only register for parameters which have isp_reduce_scatter_name attr. 
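+ # (that is, with ISP overlap enabled the reduce-scatter hook below is attached to every trainable param in the group, not only those carrying that attribute)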
+ # we must keep up with reduce_grad_hook. + if ( + self._isp_communicator + and self._isp_communicator.overlap + and gpc.config.parallel.weight.size > 1 + ): + param.register_post_accumulate_grad_hook(accum_grad_hook) + + if self._overlap_sync_grad: + param.register_post_accumulate_grad_hook( + partial(grad_handler, group_id) + ) # pylint: disable=W0640 + + _define_and_attach(param, reduce_rank) diff --git a/internlm/solver/optimizer/npu_fused_adamw.py b/internlm/solver/optimizer/npu_fused_adamw.py new file mode 100644 index 00000000..5ae612da --- /dev/null +++ b/internlm/solver/optimizer/npu_fused_adamw.py @@ -0,0 +1,151 @@ +# adpated from https://gitee.com/ascend/AscendSpeed/blob/master/ascendspeed/optimizer/adamw.py +# commit id: c722d00aed8d883f3e92a9d074bf1a41bd589c56 +# pylint: skip-file +# flake8: noqa + +from typing import List, Optional + +import torch +from torch import Tensor +from torch.optim.optimizer import Optimizer + +try: + import torch_npu +except (ModuleNotFoundError, ImportError): + pass + + +def adamw( + params: List[Tensor], + grads: List[Tensor], + exp_avgs: List[Tensor], + exp_avg_sqs: List[Tensor], + max_exp_avg_sqs: List[Tensor], + step: int, + *, + amsgrad: bool, + beta1: float, + beta2: float, + lr: float, + weight_decay: float, + eps: float, + maximize: bool +): + r"""Functional API that performs AdamW algorithm computation. + See :class:`~torch.optim.AdamW` for details. + """ + for i, param in enumerate(params): + grad = grads[i] + exp_avg = exp_avgs[i] + exp_avg_sq = exp_avg_sqs[i] + + # Perform stepweight decay + ## param.mul_(1 - lr * weight_decay) + bias_correction1 = beta1 ** (step - 1) + bias_correction2 = beta2 ** (step - 1) + + param.data, exp_avg, exp_avg_sq = torch_npu.npu_apply_adam_w( + bias_correction1, + bias_correction2, + lr, + weight_decay, + beta1, + beta2, + eps, + grad, + None, + amsgrad, + maximize, + out=(param.data, exp_avg, exp_avg_sq), + ) + + +class AdamW(Optimizer): + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2, amsgrad=False, *, maximize: bool = False + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, maximize=maximize) + super(AdamW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdamW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault("amsgrad", False) + group.setdefault("maximize", False) + + @torch.no_grad() + def step(self, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + state_sums = [] + max_exp_avg_sqs = [] + state_steps = [] + amsgrad = group["amsgrad"] + beta1, beta2 = group["betas"] + + if "step" in group: + group["step"] += 1 + else: + group["step"] = 1 + + for p in group["params"]: + if p.grad is None: + continue + params_with_grad.append(p) + if p.grad.is_sparse: + raise RuntimeError("AdamW does not support sparse gradients") + 
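+ # Collect this param's gradient and lazily create its AdamW state (exp_avg / exp_avg_sq, plus max_exp_avg_sq when amsgrad) before the fused torch_npu update below.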
grads.append(p.grad) + + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + group["step"], + amsgrad=amsgrad, + beta1=beta1, + beta2=beta2, + lr=group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + maximize=group["maximize"], + ) + + return loss diff --git a/internlm/solver/optimizer/store.py b/internlm/solver/optimizer/store.py index aa1c2b88..22adcb3d 100644 --- a/internlm/solver/optimizer/store.py +++ b/internlm/solver/optimizer/store.py @@ -3,6 +3,7 @@ from typing import List +import torch from torch import Tensor from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors @@ -320,3 +321,298 @@ def unflatten_and_copy(self): unflattened_tensor_list = _unflatten_dense_tensors(self._flat_tensor, self._bucket) for old, new in zip(self._bucket, unflattened_tensor_list): old.copy_(new) + + +class BucketStore_v2(BaseStore): + """ + Bucket Store V2 + """ + + def __init__(self, group_id, dp_parallel_mode, zero_mode=ParallelMode.ZERO1): + super().__init__(dp_parallel_mode) + self.zero_world_size = gpc.get_world_size(zero_mode) + self.zero_local_rank = gpc.get_local_rank(zero_mode) + self._dp_parallel_mode = dp_parallel_mode + self._group_id = group_id + self.reset_all() + + def get_param_group_id(self): + return self._group_id + + def get_dp_parallel_mode(self): + return self._dp_parallel_mode + + def reset_all(self) -> None: + # init + self._num_elements_in_bucket = 0 + # mapping gradient slices and parameter + self.grad_to_param_mapping = dict() + + self._grad_in_bucket = dict() + + self._grad_current_rank_for_group = dict() + self._param_list_for_group = dict() + self._padding_size_for_group = dict() + self.grad_to_param_mapping2 = dict() + self.offset_list_for_group = dict() + + self._param_list = [] + self._padding_size = [] + for rank in range(self.zero_world_size): + self._grad_in_bucket[rank] = [] + + # offset_list records number of tensors in the bucket before each reduction + self.offset_list = [0] + + def num_elements_in_bucket(self) -> int: + """Return the total number of elements in bucket + + Returns: + int: the total number of elements in bucket + """ + + return self._num_elements_in_bucket + + def reset_num_elements_in_bucket(self): + """Set the number of elements in bucket to zero.""" + + self._num_elements_in_bucket = 0 + + def add_param_grad(self, param: Tensor, padding_size: int): + """Add a param to bucket and record the padding size of a param for gradient padding + + Args: + group_id (int): The index of a parameter group + param (Tensor): The parameter + padding_size (int): The padding size of the parameter + """ + + self._param_list.append(param) + self._padding_size.append(padding_size) + self._num_elements_in_bucket += param.numel() + padding_size + + # number of tensors in current bucket + self.offset_list[-1] += 1 + + def build_grad_in_bucket(self, comm_stream): + 
"""Organize parameters' gradient(padding and split), follows the parameters' splitting method + + Data structure of self._grad_in_bucket: + { + rank0: [grad0_rank0, grad1_rank0, ...] + rank1: [grad0_rank1, grad1_rank1, ...] + } + """ + + for param, padding_size in zip(self._param_list, self._padding_size): + param.grad.record_stream(comm_stream) + grad = param.grad.clone().detach().flatten() + if padding_size > 0: + with torch.no_grad(): + grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size]) + grad_list = grad.split(grad.numel() // self.zero_world_size) + for rank in range(self.zero_world_size): + grad_current_rank = grad_list[rank].clone().detach() + self.grad_to_param_mapping[id(grad_current_rank)] = id(param) + self._grad_in_bucket[rank].append(grad_current_rank) + param.grad = None + + self.offset_list.append(0) + + def get_grad(self): + """Return the dictionary of gradients slices, of which the keys are ranks + + Returns: + Dict: The dictionary of gradients slices + """ + + return self._grad_in_bucket + + def get_param(self): + return self._param_list + + def get_flatten_grad(self) -> Tensor: + """Return the flattened gradients slices in the bucket, the data organization of the flattened tensor: + [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....] + + Returns: + Tensor: the flattened gradients slices in the bucket + """ + + flat_grad = [] + for grad_list in self._grad_in_bucket.values(): + flat_grad.append(_flatten_dense_tensors(grad_list)) + flat_grad = _flatten_dense_tensors(flat_grad) + return flat_grad + + def get_param_id_of_grad(self, grad: Tensor) -> int: + """Return the id of a parameter which the gradient slice belongs to + + Args: + grad (Tensor): the gradient slice + + Returns: + int: the id of a parameter which the gradient slice belongs to + """ + + return self.grad_to_param_mapping[id(grad)] + + def reset(self): + """Reset the bucket storage after reduction, only release the tensors have been reduced""" + cur_offset = self.offset_list.pop(0) + self._param_list = self._param_list[cur_offset:] + self._padding_size = self._padding_size[cur_offset:] + for _ in range(cur_offset): + del self.grad_to_param_mapping[next(iter(self.grad_to_param_mapping))] + for rank in range(self.zero_world_size): + self._grad_in_bucket[rank] = self._grad_in_bucket[rank][cur_offset:] + + +class GradientStore_v2(BaseStore): + """ + Gradient Store V2 + """ + + def __init__(self, *args, partition_grad: bool = False, zero_mode=ParallelMode.ZERO1): + super().__init__(*args) + """ + self._grads_of_params mapping the parameter and its gradient slices + data structure: + { + group_id:{ + param_id: [grad_rank0, grad_rank1, ...] + } + } + """ + self.zero_world_size = gpc.get_world_size(zero_mode) + self.zero_local_rank = gpc.get_local_rank(zero_mode) + self._grads_of_params = dict() + # for zero2, it's `param_id: [grad_local_rank]` + self._working_index = 0 if partition_grad else self.zero_local_rank + + self.grad_to_param_mapping = dict() + + def get_partitioned_gradients_by_param_id(self, group_id: int, param_id: int) -> List: + """Return list of gradient slices of a specific parameter + + Args: + group_id (int): The index of a parameter group + param_id (int): The id of a parameter + + Returns: + List: the list of gradient slices of a parameter. 
+ """ + + if group_id in self._grads_of_params: + if param_id in self._grads_of_params[group_id]: + return self._grads_of_params[group_id][param_id] + + # the param has no grad, for instance, in layer drop + return [] + + def append_gradients_by_param_id(self, grad: Tensor, group_id: int, param_id: int): + """Append a gradient slice to the parameter's gradient slice list + + Args: + grad (Tensor): The gradient slice to append to list + group_id (int): The index of a parameter group + param_id (int): The id of a parameter + """ + + if group_id not in self._grads_of_params: + self._grads_of_params[group_id] = dict() + if param_id not in self._grads_of_params[group_id]: + self._grads_of_params[group_id][param_id] = [grad] + else: + self._grads_of_params[group_id][param_id].append(grad) + + self.grad_to_param_mapping[id(grad)] = param_id + + def add_gradients_by_param_id(self, grad: Tensor, grad_idx: int, group_id: int, param_id: int): + """Add a gradient slice on an existing slice of the parameter's gradient + Used when no_sync is not activated. + + Args: + grad (Tensor): The split gradient to append to list + grad_idx (int): The index of the existing slice + group_id (int): The index of a parameter group + param_id (int): The id of a parameter + """ + + self._grads_of_params[group_id][param_id][grad_idx].add_(grad) + + def reset_grads_by_group_id(self, group_id: int): + self._grads_of_params[group_id] = dict() + + +class ParameterStore_v2(BaseStore): + """ + Parameter Store V2 + """ + + def __init__(self, dp_parallel_mode): + super().__init__(dp_parallel_mode) + + # record the padding size of each param + self._padding_map = dict() + + # mapping working param and master param + self.master_to_working_param = dict() + self.working_to_master_param = dict() + + self._bucket_reduced_param = {} + self._bucket_reduced_grad = {} + + def record_param_padding_size(self, param: Tensor, padding_size: int): + """Record the padding size of a param + + Args: + param (Tensor): The parameter + padding_size (int): The padding size of the parameter + """ + + self._padding_map[id(param)] = padding_size + + def get_param_padding_size(self, param: Tensor) -> int: + """Return the padding size of the parameter + + Args: + param (Tensor): The parameter + + Returns: + int: the padding size of the parameter + """ + + return self._padding_map[id(param)] + + def link_master_and_working_param(self, master_param: Tensor, working_param: Tensor): + """Mapping master parameter and working parameter + + Args: + master_param (Tensor): The parameter copy in optimizer + working_param (Tensor): The parameter of the model + """ + + self.master_to_working_param[id(master_param)] = working_param + self.working_to_master_param[id(working_param)] = master_param + + def add_reduced_param_for_compute_norm(self, param): + group_id = getattr(param, "group_id") + if group_id not in self._bucket_reduced_param: + self._bucket_reduced_param[group_id] = [] + self._bucket_reduced_grad[group_id] = [] + + self._bucket_reduced_param[group_id].append(param) + self._bucket_reduced_grad[group_id].append(param.grad) + + def get_reduced_param_for_compute_norm(self, group_id=0): + if group_id not in self._bucket_reduced_param: + return [], [] + return ( + self._bucket_reduced_param[group_id], + self._bucket_reduced_grad[group_id], + ) + + def reset_reduced_data_for_compute_norm(self): + self._bucket_reduced_param = {} + self._bucket_reduced_grad = {} diff --git a/internlm/solver/pipeline_utils.py b/internlm/solver/pipeline_utils.py deleted file mode 
100644 index c57765e4..00000000 --- a/internlm/solver/pipeline_utils.py +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - - -def partition_uniform(num_items, pipeline_parallel_size, num_chunks): - assert ( - num_items % num_chunks == 0 - ), "Layer length should be divided by the number of chunks, otherwise parameter method is recomended" - - parts = [[] for _ in range(pipeline_parallel_size)] - partition_items = num_items // num_chunks - for idx in range(num_chunks): - base_idx = idx * partition_items - chunk_size = partition_items // pipeline_parallel_size - left = pipeline_parallel_size - partition_items % pipeline_parallel_size - if chunk_size == 0: - raise ValueError("Some nodes in Pipeline have no requests") - - for p in range(pipeline_parallel_size): - st = base_idx - base_idx += chunk_size + (p >= left) - parts[p].append((st, base_idx)) - - indexes = [] - for _parts in parts: - for s, e in _parts: - indexes.extend(list(range(s, e))) - assert len(indexes) == len(set(indexes)), indexes # should have no duplicates - assert set(indexes) == set(list(range(num_items))), (indexes, num_items) # should have the same indexes as expected - return parts diff --git a/internlm/train/__init__.py b/internlm/train/__init__.py index 9fc111ff..2ad60df0 100644 --- a/internlm/train/__init__.py +++ b/internlm/train/__init__.py @@ -1,24 +1,22 @@ from .pipeline import ( get_scheduler_hooks, - initialize_isp_communicator, initialize_llm_profile, initialize_model, initialize_optimizer, + initialize_parallel_communicator, load_new_batch, record_current_batch_training_metrics, set_fp32_attr_for_model, set_parallel_attr_for_param_groups, - wrap_FSDP_model, ) __all__ = [ "initialize_llm_profile", "initialize_model", - "initialize_isp_communicator", + "initialize_parallel_communicator", "initialize_optimizer", "load_new_batch", "record_current_batch_training_metrics", - "wrap_FSDP_model", "get_scheduler_hooks", "set_parallel_attr_for_param_groups", "set_fp32_attr_for_model", diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 3152af0d..0c515661 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -1,28 +1,15 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import functools import math import time -from typing import Callable, Iterable, List, Optional, Union +from typing import Callable, Iterable, List, Optional, Tuple, TypeVar, Union import torch from torch import nn -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from torch.distributed.fsdp.fully_sharded_data_parallel import ( - BackwardPrefetch, - ShardingStrategy, -) -from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy from torch.utils.data import DataLoader from internlm.accelerator import AcceleratorType, get_accelerator -from internlm.core.communication.isp import ( - ISPCommModelConfig, - ISPCommunicator, - ISPCommunicatorSchedulerHook, -) -from internlm.core.communication.utils import ParamAsyncBcastHandler from internlm.core.context import ( IS_REPLICA_ZERO_PARALLEL, IS_TENSOR_DATA_PARALLEL, @@ -33,37 +20,58 @@ ) from internlm.core.context import global_context as gpc from internlm.core.context.random import set_mode -from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module +from internlm.core.naive_amp import ( + NaiveAMPModel, + set_fp32_attr_to_module, + unwrap_naive_amp, +) +from internlm.core.parallel.comm.isp import ( + 
ISPCommModelConfig, + ISPCommunicator, + ISPCommunicatorSchedulerHook, +) +from internlm.core.parallel.comm.tensor import ( + EmbbedingSequenceParallelCommunicator, + EmbbedingTensorParallelCommunicator, + HeadSequenceParallelCommunicator, + HeadTensorParallelCommunicator, + LinearRole, + MoESequenceParallelCommunicator, + SequenceParallelCommunicator, + TensorParallelCommunicator, +) +from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler from internlm.core.trainer import TrainState -from internlm.data.utils import unpack_data +from internlm.data.utils import unpack_type_ids +from internlm.model.builder import create_model from internlm.model.metrics import SchedulerMetricHook from internlm.model.modules.embedding import Embedding1D -from internlm.model.modules.mlp import FeedForward -from internlm.model.modules.multi_head_attention import MHA +from internlm.model.modules.linear import ( + ColumnParallelLinear, + ParallelLinearWithCommExt, + RewardModelLinear, + RowParallelLinear, + ScaleColumnParallelLinear, +) +from internlm.model.modules.utils import is_moe_param from internlm.model.moe.megablock.mlp import ( MegaBlockFeedForward, MegaBlockGroupedFeedForward, ) from internlm.model.moe.moe import MoE -from internlm.model.ops.fusion_ops_import_helper import ( - try_import_FusedAdamW, - try_import_RMSNorm, -) -from internlm.model.ops.linear import ( - BaseScaleColumnParallelLinear, - ColumnParallelLinearTorch, - ISPLinear, - RewardModelLinear, - RowParallelLinearTorch, - ScaleColumnParallelLinear, -) -from internlm.model.utils import is_moe_param +from internlm.model.ops.norm import RMSNorm +from internlm.model.registry import register_model_initializer from internlm.monitor import set_env_var from internlm.monitor.monitor import monitor_manager as mm -from internlm.solver.optimizer import FSDPadaptOptimizer, HybridZeroOptimizer +from internlm.solver.optimizer import ( + FSDPadaptOptimizer, + HybridZeroOptimizer, + HybridZeroOptimizer_v2, +) +from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw from internlm.solver.schedulers.beta2_scheduler import Beta2Scheduler from internlm.solver.schedulers.lr_scheduler import FineTuneCosineAnnealingWarmupLR -from internlm.train.utils import create_param_groups +from internlm.train.utils import create_param_groups, map_param_block from internlm.utils.common import DummyProfile, SchedulerHook, get_current_device from internlm.utils.logger import get_logger from internlm.utils.megatron_timers import megatron_timer as timer @@ -77,7 +85,6 @@ sync_model_param, sync_model_replica_param_group, ) -from internlm.utils.registry import MODEL_INITIALIZER from internlm.utils.timeout import llm_timeout try: @@ -85,7 +92,6 @@ except (ImportError, ModuleNotFoundError): pass -RMSNorm = try_import_RMSNorm() logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -112,9 +118,8 @@ def _check_module(name, module): setattr(param, IS_REPLICA_ZERO_PARALLEL, True) # embedding and head - embedding_head_cls = (Embedding1D, BaseScaleColumnParallelLinear) - if isinstance(module, embedding_head_cls): + if isinstance(module, (Embedding1D, ScaleColumnParallelLinear)): for param in module.parameters(): if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp(): setattr(param, IS_TENSOR_DATA_PARALLEL, True) @@ -124,7 +129,7 @@ def _check_module(name, module): # for linear module if isinstance( module, - (ColumnParallelLinearTorch, RowParallelLinearTorch, MegaBlockFeedForward, MegaBlockGroupedFeedForward), + 
(ParallelLinearWithCommExt, MegaBlockFeedForward, MegaBlockGroupedFeedForward), ): for param in module.parameters(): if gpc.is_initialized(ParallelMode.EXPERT_DATA) and is_moe_param(param): @@ -138,21 +143,23 @@ def _check_module(name, module): # for vit and vit project if "vision_tower" in name.lower() or "vision_proj" in name.lower(): for param in module.parameters(): - if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp(): - setattr(param, IS_TENSOR_DATA_PARALLEL, True) - elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): - setattr(param, IS_TENSOR_ZERO_PARALLEL, True) + setattr(param, IS_REPLICA_ZERO_PARALLEL, True) - if not isinstance(model, nn.ModuleList): - model = [model] - - for _chunk in model: - if isinstance(_chunk, NaiveAMPModel): - _chunk = _chunk.model + def _check_module_hf(_, module): + # TODO: check parallel attribute for hf model + for param in module.parameters(): + if gpc.is_initialized(ParallelMode.TENSOR) and is_using_isp(): + setattr(param, IS_TENSOR_DATA_PARALLEL, True) + elif gpc.is_initialized(ParallelMode.TENSOR) and not is_using_isp(): + setattr(param, IS_TENSOR_ZERO_PARALLEL, True) + for _chunk in unwrap_naive_amp(model): # set param parallel attribute for name, module in _chunk.named_modules(): - _check_module(name, module) + if gpc.config.model_type == "hf": + _check_module_hf(name, module) + else: + _check_module(name, module) for name, param in _chunk.named_parameters(): assert ( @@ -175,7 +182,11 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f """ if pre_process_func: pre_process_output = pre_process_func() - model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(gpc.config.model)) + + register_model_initializer() + + model = create_model(model_type=gpc.config.model_type, **(gpc.config.model)) + if post_process_func: post_process_func(pre_process_output) @@ -218,47 +229,22 @@ def initialize_model(pre_process_func: Optional[Callable] = None, post_process_f random_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA set_mode(random_mode) - # if fsdp enabled, wrap the model - model = wrap_FSDP_model(model) - return model -def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]): - if gpc.config.parallel.zero1.fsdp and gpc.config.model.use_flash_attn: - from flash_attn.modules.embedding import ParallelGPT2Embeddings - - # set wrap_policy for fsdp wrap - transformer_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - transformer_layer_cls={ - Embedding1D, - ParallelGPT2Embeddings, - MHA, - RMSNorm, - FeedForward, - RewardModelLinear, - ScaleColumnParallelLinear, - }, - ) +_T = TypeVar("_T") - # wrap the model - grp = gpc.get_group(ParallelMode.ZERO1) - model = FSDP( # pylint: disable=unexpected-keyword-arg - module=model, - process_group=grp, - sharding_strategy=ShardingStrategy.FULL_SHARD, - auto_wrap_policy=transformer_wrap_policy, - forward_prefetch=True, - backward_prefetch=BackwardPrefetch.BACKWARD_PRE, - limit_all_gathers=True, - use_orig_params=True, - ) - return model +def _submodule_filter(model: Union[nn.Module, nn.ModuleList], target_cls: Union[_T, Tuple[_T]]) -> Iterable[_T]: + for _chunk in unwrap_naive_amp(model): + for _module in _chunk.modules(): + if not isinstance(_module, target_cls): + continue + yield _module -def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]): + +def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): """ Initialize communicator for isp tensor parallel mode. 
@@ -269,6 +255,8 @@ def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]): An isp communicator for managing comp/comm overlap and memory pool. """ isp_communicator = None + _retain_out_sharded = gpc.config.model.get("parallel_output", True) + if is_using_isp(): isp_communicator = ISPCommunicator( model, @@ -281,8 +269,73 @@ def initialize_isp_communicator(model: Union[nn.Module, nn.ModuleList]): gpc.config.parallel.weight.memory_pool, gpc.get_group(ParallelMode.WEIGHT), ) - # register communicator for isp linear. - ISPLinear.register_communicator(isp_communicator) + # register communicator for isp column parallel linear. + ColumnParallelLinear.register_cls_communicator(isp_communicator) + # row parallel linear will not be used. + RowParallelLinear.register_cls_communicator(None) + _head_communicator = HeadSequenceParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) + _embedding_communicator = EmbbedingSequenceParallelCommunicator(ParallelMode.TENSOR) + + # register communictor for mtp/msp/fsp linear. + + # tensor parallel + if gpc.config.parallel.tensor.mode == "mtp": + ColumnParallelLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN) + ) + RowParallelLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) + ) + _head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) + _embedding_communicator = EmbbedingTensorParallelCommunicator(ParallelMode.TENSOR) + # sequence parallel + if gpc.config.parallel.tensor.mode in ("msp", "fsp"): + save_total_input_as_activation = gpc.config.parallel.tensor.mode == "msp" + + ColumnParallelLinear.register_cls_communicator( + SequenceParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.COLUMN, + save_total_input_as_activation=save_total_input_as_activation, + ) + ) + RowParallelLinear.register_cls_communicator( + SequenceParallelCommunicator( + gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.ROW, + save_total_input_as_activation=save_total_input_as_activation, + ) + ) + + _head_communicator = HeadSequenceParallelCommunicator( + ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation + ) + _embedding_communicator = EmbbedingSequenceParallelCommunicator(ParallelMode.TENSOR) + + # MoE sequence parallel + if gpc.config.model.get("num_experts", 1) > 1: + _column_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN + ) + _row_communicator = TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW + ) + for moe in _submodule_filter(model, MoE): + # 1. the linear in MoE degrades the parallel communication pattern from sp to tp + for column_linear in _submodule_filter(moe, ColumnParallelLinear): + column_linear.register_communicator(_column_communicator) + for row_linear in _submodule_filter(moe, RowParallelLinear): + row_linear.register_communicator(_row_communicator) + # 2. register MoESequenceParallelCommunicator for MoE layer + MoESequenceParallelCommunicator(ParallelMode.TENSOR).register_module_hook(moe) + + # register communitorc for embedding layer. + for embedding in _submodule_filter(model, Embedding1D): + _embedding_communicator.register_module_hook(embedding) + + # register communictor for head layer. 
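+ # The head layers (ScaleColumnParallelLinear / RewardModelLinear) reuse whichever head communicator was selected above: tensor parallel for mtp, sequence parallel for isp/msp/fsp.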
+ ScaleColumnParallelLinear.register_cls_communicator(_head_communicator) + RewardModelLinear.register_cls_communicator(_head_communicator) return isp_communicator @@ -303,17 +356,16 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato zero_cfg = gpc.config.hybrid_zero_optimizer grad_scal_cfg = gpc.config.grad_scaler - params = create_param_groups(model, adam_cfg.weight_decay) + if "use_split_tensor_optim" in zero_cfg and zero_cfg.use_split_tensor_optim: + map_param_block(model) - # TODO(caikun): add DIPU backend adamw - adam_extra_kwargs, internlm_adamw = try_import_FusedAdamW() + params = create_param_groups(model, adam_cfg.weight_decay) - naive_optimizer = internlm_adamw( + naive_optimizer = new_compatible_adamw( params=params, lr=adam_cfg.lr, betas=(adam_cfg.adam_beta1, adam_cfg.adam_beta2), eps=adam_cfg.adam_eps, - **adam_extra_kwargs, ) if ( @@ -336,13 +388,25 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato param_bcast_sync_handler = None if not gpc.config.parallel.zero1.fsdp: - optimizer = HybridZeroOptimizer( - naive_optimizer, - grad_scal_cfg=grad_scal_cfg, - zero_cfg=zero_cfg, - param_bcast_sync_handler=param_bcast_sync_handler, - isp_communicator=isp_communicator, - ) + if ( + "use_split_tensor_optim" not in gpc.config.hybrid_zero_optimizer + or not gpc.config.hybrid_zero_optimizer.use_split_tensor_optim + ): + optimizer = HybridZeroOptimizer( + naive_optimizer, + grad_scal_cfg=grad_scal_cfg, + zero_cfg=zero_cfg, + param_bcast_sync_handler=param_bcast_sync_handler, + isp_communicator=isp_communicator, + ) + else: + optimizer = HybridZeroOptimizer_v2( + naive_optimizer, + grad_scal_cfg=grad_scal_cfg, + zero_cfg=zero_cfg, + param_bcast_sync_handler=param_bcast_sync_handler, + isp_communicator=isp_communicator, + ) else: optimizer = FSDPadaptOptimizer( naive_optimizer, @@ -411,7 +475,8 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai if batch[0].get("type_ids", None) is not None: # if use_packed_dataset is False, we need to unpack type_ids if not gpc.config.data.use_packed_dataset: - batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True) + if gpc.config.data.type != "hf" or gpc.config.model_type != "hf": + batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"]) return batch, train_iter @@ -474,6 +539,7 @@ def record_current_batch_training_metrics( beta2_scheduler, trainer, start_time, + very_begining_time, loss, moe_loss, grad_norm, @@ -500,10 +566,17 @@ def record_current_batch_training_metrics( num_tokens_in_batch = batch[1].nelement() real_num_tokens = math.ceil(acc_perplex.pop("real_token_num") / gpc.get_world_size(ParallelMode.GLOBAL)) - num_samples_in_batch = sum([len(b) - 1 for b in batch[0]["cu_seqlens"]]) - max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]]) - max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) - min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + # TODO: check logic + if gpc.config.data.type == "hf" and gpc.config.model_type == "hf" and not gpc.config.data.use_packed_dataset: + num_samples_in_batch = gpc.config.data.micro_bsz * gpc.config.data.micro_num + max_length_in_batch = batch[0]["attention_mask"].sum(dim=1).max().item() + max_samples_in_batch = gpc.config.data.micro_bsz + min_samples_in_batch = gpc.config.data.micro_bsz + else: + num_samples_in_batch = sum([len(b) - 1 for b in 
batch[0]["cu_seqlens"]]) + max_length_in_batch = max([(b[1:] - b[:-1]).max().item() for b in batch[0]["cu_seqlens"]]) + max_samples_in_batch = max([len(b) - 1 for b in batch[0]["cu_seqlens"]]) + min_samples_in_batch = min([len(b) - 1 for b in batch[0]["cu_seqlens"]]) time_cost = time.time() - start_time tk_per_gpu = round( num_tokens_in_batch * gpc.get_world_size(ParallelMode.DATA) / gpc.get_world_size(ParallelMode.GLOBAL), @@ -512,7 +585,7 @@ def record_current_batch_training_metrics( tgs_statistic = train_state.tgs_statistic tgs_statistic["sum_step"] += 1 tgs_statistic["sum_tg"] += tk_per_gpu - tgs_statistic["sum_time"] += time_cost + tgs_statistic["total_time"] = time.time() - very_begining_time tgs_statistic["sum_last_tg_10"] += tk_per_gpu tgs_statistic["sum_last_time_10"] += time_cost tgs_statistic["sum_last_tg_50"] += tk_per_gpu @@ -543,7 +616,7 @@ def record_current_batch_training_metrics( last_tgs_10 = tgs_statistic["last_tgs_10"] last_tgs_50 = tgs_statistic["last_tgs_50"] - tgs_all = round(tgs_statistic["sum_tg"] / tgs_statistic["sum_time"], 2) + tgs_all = round(tgs_statistic["sum_tg"] / tgs_statistic["total_time"], 2) tgs_avg = round(tgs_statistic["sum_tgs"] / tgs_statistic["sum_step"], 2) tgs_SMA = round(tgs_statistic["SMA_tg_50"] / tgs_statistic["SMA_time_50"], 2) @@ -592,6 +665,8 @@ def record_current_batch_training_metrics( fwd_bwd_time = round(timer("fwd-bwd").elapsed(), 2) infos["fwd_bwd_time"] = fwd_bwd_time + bwd_time = round(timer("bwd").elapsed(), 2) + infos["bwd_time"] = bwd_time for key, value in acc_perplex.items(): infos[key] = value diff --git a/internlm/train/utils.py b/internlm/train/utils.py index 39d088b4..2f303443 100644 --- a/internlm/train/utils.py +++ b/internlm/train/utils.py @@ -1,10 +1,12 @@ from typing import Dict, Tuple import torch +from torch import nn from internlm.core.context.parallel_context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc -from internlm.model.utils import is_moe_param +from internlm.core.naive_amp import unwrap_naive_amp +from internlm.model.modules.utils import is_moe_param from internlm.utils.parallel import is_tensor_data_parallel_parameter, is_using_isp @@ -61,11 +63,11 @@ def split_params_into_different_groups_for_optimizer( if is_tensor_data_parallel_parameter(param): # should not be here if not isp mode new_groups["embed_head"]["params"].append(param) - elif param.dtype == torch.float32: - new_groups["fp32"]["params"].append(param) # moe param means MoE is enabled elif is_moe_param(param): new_groups[param.group_name]["params"].append(param) + elif param.dtype == torch.float32 and gpc.config.model.dtype != torch.float32: + new_groups["fp32"]["params"].append(param) else: origin_params.append(param) @@ -86,3 +88,16 @@ def create_param_groups(model, weight_decay): "weight_decay": weight_decay, } return split_params_into_different_groups_for_optimizer(parameters) + + +def map_param_block(model): + for _chunk in unwrap_naive_amp(model): + for name, children in _chunk.named_children(): + if isinstance(children, nn.ModuleList): + for idx, block in enumerate(children): + block_name = name + f"_{idx}" + for param in block.parameters(): + setattr(param, "block_name", block_name) + else: + for param in children.parameters(): + setattr(param, "block_name", name) diff --git a/internlm/utils/common.py b/internlm/utils/common.py index c5ba65b0..323613d7 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -5,6 +5,7 @@ import inspect import os import random +import threading 
from abc import ABC, abstractmethod from contextlib import contextmanager from datetime import datetime @@ -45,45 +46,17 @@ def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Te return norm -def _move_tensor(element): - if not torch.is_tensor(element): - # we expecte the data type if a list of dictionaries - for idx, item in enumerate(element): - if isinstance(item, dict): - for key, value in item.items(): - assert value.device.type == "cpu" - item[key] = value.to(get_current_device()).detach() - elif isinstance(item, list): - for index, value in enumerate(item): - assert value.device.type == "cpu" - item[index] = value.to(get_current_device()).detach() - elif torch.is_tensor(item): - if item.device.type == "cpu": - element[idx] = item.to(get_current_device()).detach() - else: - assert False, f"{type(item)}, {item}" - else: - assert torch.is_tensor(element), f"element should be of type tensor, but got {type(element)}" - if element.device.type == "cpu": - element = element.to(get_current_device()).detach() - return element - - def move_to_device(data): if isinstance(data, torch.Tensor): - data = data.to(get_current_device()) + if data.device.type == "cpu": + data = data.to(get_current_device()).detach() elif isinstance(data, (list, tuple)): - data_to_return = [] - for element in data: - if isinstance(element, dict): - data_to_return.append({k: _move_tensor(v) for k, v in element.items()}) - else: - data_to_return.append(_move_tensor(element)) - data = data_to_return + data = [move_to_device(x) for x in data] elif isinstance(data, dict): - data = {k: _move_tensor(v) for k, v in data.items()} + data = {k: move_to_device(v) for k, v in data.items()} else: - raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") + # other types like scalar, other params, return the value itself. + return data return data @@ -197,18 +170,27 @@ def __call__(self, batch_count): class SingletonMeta(type): """ - Singleton Meta. + Thread-safe Singleton Meta with double-checked locking. + Reference: https://en.wikipedia.org/wiki/Double-checked_locking """ _instances = {} + _lock = threading.Lock() def __call__(cls, *args, **kwargs): + # First check (without locking) for performance reasons if cls not in cls._instances: - cls._instances[cls] = super().__call__(*args, **kwargs) + # Acquire a lock before proceeding to the second check + with cls._lock: + # Second check with lock held to ensure thread safety + if cls not in cls._instances: + instance = super().__call__(*args, **kwargs) + cls._instances[cls] = instance else: assert ( len(args) == 0 and len(kwargs) == 0 - ), f"{cls.__name__} is a singleton class and a instance has been created." + ), f"{cls.__name__} is a singleton class and an instance has been created." 
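To see what the thread-safe singleton above provides, a small usage sketch: ToyManager is a hypothetical consumer class, while SingletonMeta is the metaclass defined in internlm/utils/common.py above (the same one storage_manager.py now imports).

import threading

from internlm.utils.common import SingletonMeta

class ToyManager(metaclass=SingletonMeta):
    def __init__(self):
        self.handles = {}

seen = []

def build():
    # Constructor arguments are only accepted on the first construction,
    # so concurrent callers simply fetch the shared instance.
    seen.append(ToyManager())

threads = [threading.Thread(target=build) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert all(obj is seen[0] for obj in seen)  # every thread observed the same instance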
+ return cls._instances[cls] @@ -247,10 +229,14 @@ def get_megatron_flops( def enable_pytorch_expandable_segments(): if torch.__version__ >= "2.1.0" and AcceleratorType.GPU == internlm_accelerator.get_accelerator_backend(): - _alloc_setting = "expandable_segments:True" - if os.getenv("PYTORCH_CUDA_ALLOC_CONF", None) is not None: - _alloc_setting = os.getenv("PYTORCH_CUDA_ALLOC_CONF") + "," + _alloc_setting - internlm_accelerator.memory._set_allocator_settings(_alloc_setting) + _expandable_segments_conf = "expandable_segments:True" + _alloc_conf = os.getenv("PYTORCH_CUDA_ALLOC_CONF", None) + if _alloc_conf is None: + _alloc_conf = _expandable_segments_conf + elif "max_split_size_mb" not in _alloc_conf: + _alloc_conf = _alloc_conf + "," + _expandable_segments_conf + + internlm_accelerator.memory._set_allocator_settings(_alloc_conf) else: logger.warning("To support the 'expandable_segments' configuration, please upgrade torch to version 2.1.0.") diff --git a/internlm/utils/parallel.py b/internlm/utils/parallel.py index 4de457cd..1b92974d 100644 --- a/internlm/utils/parallel.py +++ b/internlm/utils/parallel.py @@ -12,9 +12,6 @@ ParallelMode, ) from internlm.core.context import global_context as gpc -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm - -RMSNorm = try_import_RMSNorm() def is_using_sequence_parallel(): @@ -74,7 +71,6 @@ def sync_model_param(model): Args: model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency. """ - sync_moe_param = gpc.is_using_parallel_mode(ParallelMode.EXPERT_DATA) sync_parallel_mode = ParallelMode.WEIGHT_DATA if is_using_isp() else ParallelMode.DATA for param in model.parameters(): diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 563ea69c..6aa1ebd1 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -4,6 +4,8 @@ import multiprocessing import os +from internlm.utils.common import SingletonMeta + if "USE_DILL_PICKLE" in os.environ: import dill @@ -964,23 +966,6 @@ def check_tmp_folder_accessibility(tmp_local_folder: str): raise RuntimeError(error_str) -class SingletonMeta(type): - """ - Singleton Meta. - """ - - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super().__call__(*args, **kwargs) - else: - assert ( - len(args) == 0 and len(kwargs) == 0 - ), f"{cls.__name__} is a singleton class and a instance has been created." - return cls._instances[cls] - - class StorageManager(metaclass=SingletonMeta): """ Storage Manager for saving or loading checkpoint. diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py index 9a30eb26..34766b3b 100644 --- a/internlm/utils/utils.py +++ b/internlm/utils/utils.py @@ -1,5 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +import types from contextlib import contextmanager +from enum import Enum, IntEnum +from functools import update_wrapper +from typing import Callable, Tuple + +import torch @contextmanager @@ -16,3 +22,107 @@ def read_base(): .. 
_tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html#a-pure-python-style-configuration-file-beta # pylint: disable=line-too-long """ # noqa: E501 yield + + +class QKVPackType(IntEnum): + QKVPACKED = 2 + KVPACKED = 3 + QKVSPLITED = 4 + + def __str__(self) -> str: + return str(self.value) + + +class CuSeqlenType(Enum): + With = True + WithOut = False + + def __str__(self) -> str: + return str(self.value) + + +def check_attention_argument(*args, **kwargs) -> str: + # self, qkv, ... + # self, q, kv, .... + # self, q, k, v, ... + # self, qkv, cu_seqlens, max_seqlen, ... + # self, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, ... + # self, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, ... + def __qkv_checker(num_args: int): + if num_args < 2: + return "qkv" in kwargs + else: + # qkv: [batch, seqlen, 3, n_head, headdim] + return len(args[1].shape) == 5 + + def __kv_checker(num_args: int): + if num_args < 3: + return "kv" in kwargs + else: + # kv: [batch, seqlen, 3, n_head, headdim] + return len(args[2].shape) == 5 + + def __cu_seqlens_checker(args, check_idx: int): + num_args = len(args) + if num_args < (check_idx + 1): + if check_idx == 2: + return "cu_seqlens" in kwargs and kwargs["cu_seqlens"] is not None + else: + return "cu_seqlens_q" in kwargs and kwargs["cu_seqlens_q"] is not None + else: + return isinstance(args[check_idx], torch.Tensor) + + if __qkv_checker(len(args)): + # qkv packed, and we should check cu_seqlens with index 2 + qkv_pack_type = int(QKVPackType.QKVPACKED) + elif __kv_checker(len(args)): + # kv packed, and we should check cu_seqlens with index 3 + qkv_pack_type = int(QKVPackType.KVPACKED) + else: + # qkv splited, and we should check cu_seqlens with index 4 + qkv_pack_type = int(QKVPackType.QKVSPLITED) + + with_cu_seqlens = __cu_seqlens_checker(args, qkv_pack_type) + + return str(qkv_pack_type), str(with_cu_seqlens) + + +def params_dispatch_with_condition(condition: Callable, func: Callable = None): + + if func is None: + # create a params dispatch wrapper + return lambda f: params_dispatch_with_condition(condition, f) + + registry = {} + funcname = getattr(func, "__name__", "params_dispatch_with_condition function") + + def dispatch(_type: str) -> Callable: + return registry[_type] + + def register(conditions: Tuple[str], func: Callable = None) -> None: + if func is None: + # create a register wrapper + return lambda f: register(conditions, f) + + _type = "-".join(conditions) + + assert _type not in registry, f"Repeatedly register dispatch functions for pattern {_type}" + + registry[_type] = func + + return func + + def wrapper(*args, **kwargs): + if not args: + raise TypeError(f"{funcname} requires at least " "1 positional argument") + + _type = "-".join(condition(*args, **kwargs)) + + return dispatch(_type)(*args, **kwargs) + + registry[""] = func + wrapper.register = register + wrapper.dispatch = dispatch + wrapper.registry = types.MappingProxyType(registry) + update_wrapper(wrapper, func) + return wrapper diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 8416d49f..595a31bd 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,4 +1,4 @@ -transformers<4.30.0 +transformers sentencepiece numpy tqdm diff --git a/setup.py b/setup.py index c46d89fb..c5dd9f20 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ pwd = os.path.dirname(__file__) def readme(): - with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f: + with open(os.path.join(pwd, 'README.md')) as 
f: content = f.read() return content @@ -17,6 +17,10 @@ def get_version(): content = f.read() return content +def fetch_requirements(path): + with open(path, 'r') as fd: + return [r.strip() for r in fd.readlines() if 'torch-scatter' not in r and not r.startswith('-f ')] + setup( name='InternEvo', version=get_version(), @@ -24,12 +28,14 @@ def get_version(): long_description=readme(), long_description_content_type='text/markdown', packages=find_packages(), + install_requires=[ + fetch_requirements('requirements/runtime.txt'), + 'rotary_emb', + 'xentropy', + ], classifiers=[ - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', 'Intended Audience :: Developers', 'Intended Audience :: Education', 'Intended Audience :: Science/Research', diff --git a/tests/common_fixture.py b/tests/common_fixture.py index 71065523..6f3def53 100644 --- a/tests/common_fixture.py +++ b/tests/common_fixture.py @@ -9,7 +9,7 @@ from internlm.accelerator import get_accelerator from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config -from internlm.data.utils import unpack_data +from internlm.data.utils import unpack_type_ids from internlm.initialize.launch import args_sanity_check internlm_accelerator = get_accelerator() @@ -36,6 +36,7 @@ diag_outlier_ratio=1.1, train_folder=None, valid_folder=None, + num_worker=0, ), model=dict( checkpoint=False, @@ -148,6 +149,6 @@ def load_new_batch(train_dl, train_iter): if batch[0].get("type_ids", None) is not None: # if use_flash_attn is False, we need to unpack type_ids if not gpc.config.model.use_flash_attn: - batch[0]["type_ids"] = unpack_data(batch[0]["type_ids"], batch[0]["cu_seqlens"], is_type_ids=True) + batch[0]["type_ids"] = unpack_type_ids(batch[0]["type_ids"], batch[0]["cu_seqlens"]) return batch, train_iter diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py index 03968cc4..0c5703e7 100644 --- a/tests/test_core/test_pipeline.py +++ b/tests/test_core/test_pipeline.py @@ -3,10 +3,10 @@ import pytest import torch -from internlm.accelerator import AcceleratorType, get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config +from internlm.solver.optimizer.compatible_adamw import new_compatible_adamw from internlm.utils.common import get_current_device from tests.test_core.utils import ( MlpModel, @@ -123,9 +123,15 @@ def exam_pipeline_parallel(args): # pp forward and backward output_list = [] for _ in range(10): - output, _, loss = scheduler.forward_backward_step( - engine, input_list, forward_only=False, return_loss=True, return_output_label=True + res = scheduler.forward_backward_step( + engine, + input_list, + forward_only=False, + return_loss=True, + return_output_label=True, ) + output = res[0] + loss = res[2] output_list.append(output) # engine.step() @@ -135,22 +141,12 @@ def exam_pipeline_parallel(args): torch_xs = torch.tensor(x_list).to(device).to(torch.float32) torch_ys = torch.tensor(y_list).to(device).to(torch.float32) torch_model = MlpModel(0, 32, "torch").to(device) - adam_extra_kwargs = {} - if get_accelerator().get_accelerator_backend() == AcceleratorType.NPU: - import torch_npu - internlm_adamw = torch_npu.optim.NpuFusedAdamW - else: - internlm_adamw = torch.optim.AdamW - if torch.__version__ >= "2.1.0": - adam_extra_kwargs["fused"] = 
True - - torch_optimizer = internlm_adamw( + torch_optimizer = new_compatible_adamw( params=[{"params": torch_model.parameters(), "weight_decay": config.adam.weight_decay}], lr=config.adam.lr, betas=(config.adam.adam_beta1, config.adam.adam_beta2), eps=config.adam.adam_eps, - **adam_extra_kwargs, ) # check only forward logits diff --git a/tests/test_core/utils.py b/tests/test_core/utils.py index 5c767914..5ccaccaf 100644 --- a/tests/test_core/utils.py +++ b/tests/test_core/utils.py @@ -11,13 +11,13 @@ from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler +from internlm.core.parallel.shard import partition_uniform from internlm.core.scheduler import ( InterleavedPipelineScheduler, NonPipelineScheduler, PipelineScheduler, ) from internlm.model.metrics import SchedulerMetricHook -from internlm.solver.pipeline_utils import partition_uniform from internlm.train import initialize_optimizer from internlm.utils.common import get_current_device @@ -41,7 +41,7 @@ def forward( ): # pylint: disable=W0613 if self.model_type != "torch" and self.part[0] != 0: input_ids = hidden_states - + # Simulate Embedding. if self.embedding: if len(input_ids.shape) == 2: diff --git a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py index c346357c..6b310625 100644 --- a/tests/test_data/test_batch_sampler.py +++ b/tests/test_data/test_batch_sampler.py @@ -187,6 +187,7 @@ def test_warmup(use_flash_atten_case, group_case, micro_bsz_case): config.data.gradient_accumulation = config.data.micro_num config.data.rampup_batch_size = group_case[1] config.data.packed_length = micro_bsz_case * config.data.seq_len + config.data.use_shm = False should_sccuess = group_case[2] answer = group_case[3] diff --git a/tests/test_infer/test_generate.py b/tests/test_infer/test_generate.py new file mode 100644 index 00000000..a169c96e --- /dev/null +++ b/tests/test_infer/test_generate.py @@ -0,0 +1,133 @@ +import os + +import pytest +import torch +from sentencepiece import SentencePieceProcessor + +from internlm.apis.inference import SequenceGenerator, batch_tokenize +from internlm.initialize import initialize_distributed_env # noqa: E402 +from internlm.train import initialize_model, initialize_parallel_communicator + + +def set_seed(seed: int = 1024): + import random + + import numpy as np + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def load_and_generate(path, model_type="INTERNLM2_PUBLIC", tokenizer_path=""): + model_cfg = os.path.join(path, "model_config.pt") + model_wt = os.path.join(path, "model_tp0_pp0.pt") + model_config = torch.load(model_cfg) + model_config["apply_post_layer_norm"] = False + if model_config.get("adapt_hf") is not None: + model_config.pop("adapt_hf") + evo_cfg = dict( + model_type=model_type, + model=model_config, + parallel=dict( + zero1=dict(size=1, fsdp=False), + pipeline=dict(size=1, interleaved_overlap=True), + tensor=dict(size=1, mode="mtp"), + sequence_parallel=0, + ), + ) + initialize_distributed_env(evo_cfg, master_port=23574, args_check=False) + + tokenizer = SentencePieceProcessor(tokenizer_path) # pylint: disable=E1121 + + def convert_to_str(output_ids): + output_tokens = output_ids.tolist() + all_output_str = [] + for b in range(len(output_tokens)): + for sent_idx in range(len(output_tokens[b])): + cur_output_tokens = output_tokens[b][sent_idx] + cur_sent = tokenizer.decode(cur_output_tokens) + all_output_str.append(cur_sent) 
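+ # all_output_str is batch-major: every returned sequence for sample 0 comes first, then sample 1, and so on.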
+ return all_output_str + + model = initialize_model() + _ = initialize_parallel_communicator(model) + # Directly get the original model without NativeAMP wrapper. + model = model.model + + state_dict = torch.load(model_wt) + load_info = model.load_state_dict(state_dict, strict=False) + print(load_info) + + sequenece_generator = SequenceGenerator( + decoder=model, + eos_token_id=tokenizer.eos_id(), + pad_token_id=tokenizer.bos_id(), + bos_token_id=tokenizer.bos_id(), + additional_eos_token_list=None, + ) + + test_prompt_0 = "Gold is considered to be a precious metal." + test_prompt_1 = "what is love? someone think it is a feeling, someone think it is a chemical reaction." + test_prompt_2 = "kobe bryant is a basketball player." + + prompt_3 = [ + test_prompt_0, + test_prompt_1, + test_prompt_2, + ] + prompt_2 = [ + test_prompt_0, + test_prompt_1, + ] + + prompt_1 = [test_prompt_0] + + def generate(prompt): + input_ids = batch_tokenize(prompt, tokenizer, pad_token_id=tokenizer.bos_id()).cuda() + generate_kwargs = {} + set_seed() + output_ids = sequenece_generator.generate( + input_ids, + num_return_sequences=generate_kwargs.get("num_return_sequences", 1), + max_length=generate_kwargs.get("max_length", input_ids.shape[1] + 80), + num_beams=generate_kwargs.get("num_beams", 1), + do_sample=generate_kwargs.get("do_sample", False), + temperature=generate_kwargs.get("temperature", 1.0), + top_k=generate_kwargs.get("top_k", 50), + top_p=generate_kwargs.get("top_p", 1.0), + repetition_penalty=generate_kwargs.get("repetition_penalty", 1), + length_penalty=generate_kwargs.get("length_penalty", 1.0), + ) + + all_output_str = convert_to_str(output_ids) + return all_output_str + + output_3 = generate(prompt_3) + output_2 = generate(prompt_2) + output_1 = generate(prompt_1) + + assert output_3[0] == output_2[0] + assert output_3[1] == output_2[1] + assert ( + output_1[0] + == "Gold is considered to be a precious metal. It is a metal that is highly valued for its \ +rarity and beauty. Gold is often used in jewelry, coins, and other decorative items. It is also used in \ +the production of electronics and other high-tech products. Gold is a highly sought-after metal because \ +of its ability to resist corrosion and tarnish.
It is also highly resistant to fire and is a good conductor \ +of heat and electricity.\n" + ) + print("test generate done!") + + +def test_internlm2_1_8B_generate(): + base_model_dir = os.environ.get("qa_data") + if base_model_dir is not None: + model_dir = os.path.join(base_model_dir, "internlm2_1_8B") + tokenizer_path = os.path.join(base_model_dir, "InternLM_CI_assets/v13.model") + if os.path.exists(model_dir) and os.path.exists(tokenizer_path): + load_and_generate(model_dir, tokenizer_path=tokenizer_path) + + +if __name__ == "__main__": + pytest.main(["-s", "-q", "-v", "test_generate.py"]) diff --git a/tests/test_infer/test_trainer_generate.py b/tests/test_infer/test_trainer_generate.py new file mode 100644 index 00000000..3ccbfb54 --- /dev/null +++ b/tests/test_infer/test_trainer_generate.py @@ -0,0 +1,201 @@ +import os + +import pytest +from sentencepiece import SentencePieceProcessor + +import internlm # noqa: E402 +from internlm.apis.inference import SequenceGenerator, batch_tokenize +from internlm.checkpoint import CheckpointManager # noqa: E402 +from internlm.core.context import global_context as gpc # noqa: E402 +from internlm.core.trainer import TrainState, Trainer # noqa: E402 +from internlm.data import build_train_loader_with_data_type # noqa: E402 +from internlm.initialize import initialize_distributed_env # noqa: E402 +from internlm.model.losses import FlashGPTLMLoss # noqa: E402 +from internlm.train import ( # noqa: E402 + get_scheduler_hooks, + initialize_model, + initialize_optimizer, + initialize_parallel_communicator, +) + + +def setup_generator(config, tokenizer): + initialize_distributed_env(config=config) + + model = initialize_model() + isp_communicator = initialize_parallel_communicator(model) + + criterion = FlashGPTLMLoss() + + # initialize the train data loader + train_dl, _ = build_train_loader_with_data_type() + + # initialize and resume train state + train_state = TrainState(gpc.config, train_dl.batch_sampler) + + optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) + + ckpt_manager = CheckpointManager( + ckpt_config=gpc.config.ckpt, + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + train_dl=train_dl, + model_config=gpc.config.model, + feishu_address=gpc.config.monitor.alert.feishu_alert_address, + ) + ckpt_manager.try_resume_training(train_state) + + # initialize trainer + engine, scheduler = internlm.initialize_trainer( + model=model, + optimizer=optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler, + beta2_scheduler=beta2_scheduler, + scheduler_hooks=get_scheduler_hooks(None, optimizer, isp_communicator), + ) + trainer = Trainer(engine, scheduler) + + trainer.schedule.data_process_func = None + + if isinstance(tokenizer, SentencePieceProcessor): + eos_token_id = tokenizer.eos_id() + pad_token_id = tokenizer.eos_id() + bos_token_id = tokenizer.bos_id() + else: + eos_token_id = tokenizer.eos_token_id + pad_token_id = tokenizer.pad_token_id + bos_token_id = tokenizer.bos_token_id + + sequenece_generator = SequenceGenerator( + decoder=trainer, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + additional_eos_token_list=None, + ) + + return sequenece_generator + + +def do_generate(config, tokenizer_path, prompt): + tokenizer = SentencePieceProcessor(tokenizer_path) # pylint: disable=E1121 + + sequenece_generator = setup_generator(config, tokenizer) + input_ids = batch_tokenize(prompt, tokenizer, pad_token_id=tokenizer.bos_id()).cuda() + + generate_kwargs = {} 
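+    # generate_kwargs is intentionally left empty, so every .get() below falls back to its default
+    # (sampling enabled, temperature 1.0, top_k 50, top_p 1.0, max_length 100).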
+ output_ids = sequenece_generator.generate( + input_ids, + num_return_sequences=generate_kwargs.get("num_return_sequences", 1), + max_length=generate_kwargs.get("max_length", 100), + num_beams=generate_kwargs.get("num_beams", 1), + do_sample=generate_kwargs.get("do_sample", True), + temperature=generate_kwargs.get("temperature", 1.0), + top_k=generate_kwargs.get("top_k", 50), + top_p=generate_kwargs.get("top_p", 1.0), + repetition_penalty=generate_kwargs.get("repetition_penalty", 1), + length_penalty=generate_kwargs.get("length_penalty", 1.0), + ) + output_tokens = output_ids.tolist() + all_output_str = [] + for b in range(len(output_tokens)): + for sent_idx in range(len(output_tokens[b])): + cur_output_tokens = output_tokens[b][sent_idx] + cur_sent = tokenizer.decode(cur_output_tokens) + all_output_str.append(cur_sent) + return all_output_str + + +def test_luyou_2B_generate(): + prompt = [ + "user\nHow can I keep flys away from my house\nassistant\n", + "user\nHow can I keep flys away from my house\nassistant\nThe best way is to keep your house clean, " + "and sweep away from where your meals are prepared, since flys tend to seek out food particles.\n" + "user\nAny other advice?\nassistant\n", + ] + + base_model_dir = os.environ.get("qa_data") + if base_model_dir is not None: + config = os.path.join(base_model_dir, "model_configs/Luyou_1B_merged.py") + + tokenizer_path = os.path.join(base_model_dir, "InternLM_CI_assets/v13.model") + if os.path.exists(config) and os.path.exists(tokenizer_path): + all_output_str = do_generate(config, tokenizer_path, prompt) + print("out_str:\n", all_output_str) + assert ( + all_output_str[0][len(prompt[0]) :] + == "There are several things you can do to keep flies away from your house:\n\n\ +1. Keep your home clean: Flies are attracted to food and dirty surfaces. Make sure that your home \ +is well-maintained and" + ) + assert ( + all_output_str[1][len(prompt[1]) :] + == "You can also use plastic baggies to keep any food that is dropped on your porch, \ +patio, or windowsill from attracting flies.\n[UNUSED_TOKEN_145]\nNo[UNUSED_TOKEN_145]\nYou could also \ +use scented candles or diffusers" + ) + + +@pytest.mark.skip("requires 2 gpu") +def test_internlm2_pp2_generate(): + prompt = [ + "user\nHow can I keep flys away from my house\nassistant\n", + "user\nHow can I keep flys away from my house\nassistant\nThe best way is to keep your house clean, " + "and sweep away from where your meals are prepared, since flys tend to seek out food particles.\n" + "user\nAny other advice?\nassistant\n", + ] + + base_model_dir = os.environ.get("qa_data") + if base_model_dir is not None: + config = os.path.join(base_model_dir, "model_configs/Luyou_1B_PP2.py") + tokenizer_path = os.path.join(base_model_dir, "InternLM_CI_assets/v13.model") + if os.path.exists(config) and os.path.exists(tokenizer_path): + all_output_str = do_generate(config, tokenizer_path, prompt) + print("out_str:\n", all_output_str) + assert ( + all_output_str[0][len(prompt[0]) :] + == "There are several things you can do to keep flies away \ +from your house:\n\n1. Keep your home clean: Flies are attracted to food and dirty surfaces.
Make sure that your \ +home is well-maintained and" + ) + assert ( + all_output_str[1][len(prompt[1]) :] + == "You can also use plastic baggies to keep any food that is dropped on your porch, patio, or \ +windowsill from attracting flies.\n[UNUSED_TOKEN_145]\nNo[UNUSED_TOKEN_145]\nYou could also use scented candles \ +or diffusers" + ) + + +@pytest.mark.skip("reduce timecost") +def test_internlm2_7B_tp2(): + prompt = [ + "user\nHow can I keep flys away from my house\nassistant\n", + "user\nHow can I keep flys away from my house\nassistant\nThe best way is to keep your house clean, " + "and sweep away from where your meals are prepared, since flys tend to seek out food particles.\n" + "user\nAny other advice?\nassistant\n", + ] + + base_model_dir = os.environ.get("qa_data") + if base_model_dir is not None: + config = os.path.join(base_model_dir, "model_configs/7B_internlm2.py") + + tokenizer_path = os.path.join(base_model_dir, "InternLM_CI_assets/v13.model") + if os.path.exists(config) and os.path.exists(tokenizer_path): + all_output_str = do_generate(config, tokenizer_path, prompt) + print("out_str:\n", all_output_str) + assert ( + all_output_str[0][len(prompt[0]) :] + == "You can use natural repellants like lavender, vanilla or lemongrass essential oils. \ +Or you can spray essential oil in a spray bottle around doors and windows. Also, using a white vinegar and" + ) + assert ( + all_output_str[1][len(prompt[1]) :] + == "You may want to consider using fly trapped to keep or get rid of the flys if need be. \ +Also wearing indoor protective clothing may be advised as well since they can be dangerous" + ) + + +if __name__ == "__main__": + pytest.main(["-s", "-q", "-v", "test_trainer_generate.py"]) diff --git a/tests/test_model/test_feed_forward.py b/tests/test_model/test_feed_forward.py index e4aab9ec..311f30d7 100644 --- a/tests/test_model/test_feed_forward.py +++ b/tests/test_model/test_feed_forward.py @@ -1,7 +1,7 @@ import pytest import torch -from internlm.model.modules.mlp import BaseFeedForward +from internlm.model.modules.mlp import new_feed_forward, split_fused_mlp_weight from internlm.utils.common import get_current_device SEQ_LEN = 64 @@ -9,20 +9,6 @@ MLP_RATIO = 8 / 3 -class InternLMLinear(torch.nn.Linear): - def __init__( - self, - in_features: int, - out_features: int, - *args, # pylint: disable=W0613 - bias: bool = True, - device=None, - dtype=None, - **kwargs, # pylint: disable=W0613 - ) -> None: - super().__init__(in_features, out_features, bias, device, dtype) - - mlp_args = { "in_features": HIDDEN_SIZE, "hidden_features": int(HIDDEN_SIZE * MLP_RATIO), @@ -30,8 +16,6 @@ def __init__( "bias": False, "device": get_current_device(), "dtype": torch.bfloat16, - "column_cls": InternLMLinear, - "row_cls": InternLMLinear, } @@ -43,13 +27,13 @@ def check_param(a1, a2, b1, b2): def init_mlp(): - mlp_no_fused = BaseFeedForward(**mlp_args) - mlp_fused = BaseFeedForward(mlp_layer_fusion=True, **mlp_args) + mlp_no_fused = new_feed_forward(**mlp_args) + mlp_fused = new_feed_forward(mlp_layer_fusion=True, **mlp_args) for _, param in mlp_fused.named_parameters(): torch.nn.init.normal_(param.data, std=0.02) - w1, w3 = BaseFeedForward.split_fused_mlp_weight(mlp_fused.fused_w1_w3.weight) + w1, w3 = split_fused_mlp_weight(mlp_fused.fused_w1_w3.weight) mlp_no_fused.w1.weight.data = w1.data mlp_no_fused.w3.weight.data = w3.data mlp_no_fused.w2.weight.data = mlp_fused.w2.weight.data @@ -99,7 +83,7 @@ def test_mlp_layer_fusion_loss(): l2.backward() assert torch.allclose(mlp_no_fused.w2.weight.grad, 
mlp_fused.w2.weight.grad, rtol=1e-4, atol=1e-5) - w1_g, w3_g = BaseFeedForward.split_fused_mlp_weight(mlp_fused.fused_w1_w3.weight.grad) + w1_g, w3_g = split_fused_mlp_weight(mlp_fused.fused_w1_w3.weight.grad) assert torch.allclose(mlp_no_fused.w1.weight.grad, w1_g, rtol=1e-4, atol=1e-5) assert torch.allclose(mlp_no_fused.w3.weight.grad, w3_g, rtol=1e-4, atol=1e-5) diff --git a/tests/test_model/test_fused_precision/test_fused_precision.py b/tests/test_model/test_fused_precision/test_fused_precision.py index 54959ecb..d0b79aae 100644 --- a/tests/test_model/test_fused_precision/test_fused_precision.py +++ b/tests/test_model/test_fused_precision/test_fused_precision.py @@ -6,7 +6,8 @@ from torch import nn from internlm.core.naive_amp import NaiveAMPModel, set_fp32_attr_to_module -from internlm.model.modeling_internlm import PackedFlashBaseLayer1D +from internlm.model.modeling_internlm import InternLM1Decoder +from internlm.train.pipeline import initialize_parallel_communicator from internlm.train.utils import create_param_groups from internlm.utils.common import get_current_device from tests.common_fixture import find_free_port @@ -33,7 +34,7 @@ def check_fused_precision(args): # fix seed seed_all(1024) # define model - model = PackedFlashBaseLayer1D( + model = InternLM1Decoder( hidden_size=16, # 768 num_attention_heads=2, # 12 mlp_ratio=2, @@ -58,6 +59,7 @@ def check_fused_precision(args): dtype=torch.half, sync_buffer=False, ) + _ = initialize_parallel_communicator(model) model.model.norm1.register_forward_pre_hook(partial(_pre_forward_hook_for_check)) model.model.norm1.register_forward_hook(partial(_post_forward_hook_for_check)) diff --git a/tests/test_model/test_model_internlm.py b/tests/test_model/test_model_internlm.py index 4ed8f535..c33f188c 100644 --- a/tests/test_model/test_model_internlm.py +++ b/tests/test_model/test_model_internlm.py @@ -11,9 +11,19 @@ from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import Config from internlm.core.context.parallel_context import global_context as gpc -from internlm.model.modeling_internlm import PackedFlashBaseLayer1D -from internlm.model.ops.linear import RewardModelLinear, ScaleColumnParallelLinear -from internlm.model.utils import gather_forward_split_backward +from internlm.core.parallel.comm.tensor import ( + HeadTensorParallelCommunicator, + LinearRole, + TensorParallelCommunicator, +) +from internlm.core.parallel.comm.utils import gather_forward_split_backward +from internlm.model.modeling_internlm import InternLM1Decoder +from internlm.model.modules.linear import ( + ColumnParallelLinear, + RowParallelLinear, + ScaleColumnParallelLinear, + new_linear, +) from internlm.utils.common import get_current_device from tests.common_fixture import find_free_port @@ -101,10 +111,18 @@ def check_block(args): # fix seed seed_all(1024) + ColumnParallelLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.COLUMN) + ) + + RowParallelLinear.register_cls_communicator( + TensorParallelCommunicator(process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW) + ) + # define block blocks = nn.ModuleList( [ - PackedFlashBaseLayer1D( + InternLM1Decoder( hidden_size=4, # 768 num_attention_heads=2, # 12 mlp_ratio=2, @@ -215,9 +233,12 @@ def check_head(args): # fix seed seed_all(1024) + _retain_out_sharded = gpc.config.model.get("parallel_output", True) + _head_comminucator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, 
_retain_out_sharded) + ScaleColumnParallelLinear.register_cls_communicator(_head_comminucator) + # load standard if is_reward: - head_cls = RewardModelLinear standard_result = torch.tensor([[3.5938], [1.0703], [3.6250], [3.6250]], dtype=torch.bfloat16).to(device) standard_grad = torch.tensor( [ @@ -229,7 +250,6 @@ def check_head(args): dtype=torch.bfloat16, ).to(device) else: - head_cls = ScaleColumnParallelLinear standard_result = torch.tensor( [ [3.5938, -2.2188, 2.0312, 3.5625], @@ -250,13 +270,14 @@ ).to(device) # define head - head = head_cls( + head = new_linear( + name="head", in_features=hidden_size, out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, - process_group=gpc.get_group(ParallelMode.TENSOR), bias=False, device=device, dtype=torch.bfloat16, + is_reward=is_reward, weight_scale=embed_grad_scale, ) diff --git a/tests/test_model/test_norm.py b/tests/test_model/test_norm.py index 0f5a3a4c..83861b36 100644 --- a/tests/test_model/test_norm.py +++ b/tests/test_model/test_norm.py @@ -3,13 +3,11 @@ import pytest import torch -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm +from internlm.model.modules.norm import new_layer_norm from internlm.utils.common import get_current_device from tests.common_fixture import find_free_port from tests.test_model.test_model_internlm import build_environment, seed_all -RMSNorm = try_import_RMSNorm() - def check_norm(args): # init @@ -24,7 +22,7 @@ def check_norm(args): seed_all(1024) # define norm - norm = RMSNorm(hidden_size, eps=layer_norm_epsilon) + norm = new_layer_norm(norm_type="rmsnorm", normalized_shape=hidden_size, eps=layer_norm_epsilon) norm = norm.to(device) # create input diff --git a/tests/test_model/test_npu_ops.py b/tests/test_model/test_npu_ops/test_flash_attention.py similarity index 98% rename from tests/test_model/test_npu_ops.py rename to tests/test_model/test_npu_ops/test_flash_attention.py index 7d31bc6d..31a8ba61 100644 --- a/tests/test_model/test_npu_ops.py +++ b/tests/test_model/test_npu_ops/test_flash_attention.py @@ -16,9 +16,6 @@ CrossAttention, SelfAttention, ) -from internlm.model.ops.fusion_ops_import_helper import try_import_RMSNorm - -RMSNorm = try_import_RMSNorm() HEAD_NUM = 32 HIDDEN_SZIE = 4096 @@ -88,6 +85,7 @@ def do_cmp_attn( softmax_scale=softmax_scale, attention_dropout=attention_dropout, ).to(dtype) + # TODO: fix it.
npu_flash_attn = AscendFlashSelfAttention( causal=True, softmax_scale=softmax_scale, diff --git a/tests/test_model/test_npu_ops/test_npu_rmsnorm.py b/tests/test_model/test_npu_ops/test_npu_rmsnorm.py new file mode 100644 index 00000000..adeb37c0 --- /dev/null +++ b/tests/test_model/test_npu_ops/test_npu_rmsnorm.py @@ -0,0 +1,44 @@ +import pytest +import torch + +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.model.ops.norm import _RMSNorm as RMSNormTorch +from internlm.model.ops.norm import _RMSNormNPU as RMSNormNPU +from internlm.utils.common import get_current_device + +internlm_accelerator = get_accelerator() + + +def check_RMSNormNPU(): + device = get_current_device() + input_data = torch.randn(128).to(torch.float32).to(device) + input_data_2 = input_data.clone().detach() + + rmsnorm_torch = RMSNormTorch(128, eps=1e-5).to(torch.bfloat16).to(device) + output_torch = rmsnorm_torch(input_data) + + rmsnorm_npu = RMSNormNPU(128, eps=1e-5).to(torch.bfloat16).to(device) + output_npu = rmsnorm_npu(input_data_2) + + if torch.equal(output_torch, output_npu): + print("RMSNorm check passed: totaly equal", flush=True) + else: + max_diff, index_max_diff = (output_torch - output_npu).abs().max(dim=0) + max_diff = max_diff.item() + index_max_diff = index_max_diff.item() + rtol = max_diff / abs(output_npu[index_max_diff]) + print( + f"The relative error is {rtol}. Between {output_torch[index_max_diff]} and {output_npu[index_max_diff]}", + flush=True, + ) + assert rtol <= 1e-5, f"RMSNorm check failed: The relative error is {rtol}" + print("RMSNorm check passed: allclose", flush=True) + + +def test_RMSNorm(): + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: + check_RMSNormNPU() + + +if __name__ == "__main__": + pytest.main(["-s", "-q", "test_npu_ops.py"]) diff --git a/tests/test_model/test_npu_ops/test_rotary_embed.py b/tests/test_model/test_npu_ops/test_rotary_embed.py new file mode 100644 index 00000000..8fca38ce --- /dev/null +++ b/tests/test_model/test_npu_ops/test_rotary_embed.py @@ -0,0 +1,54 @@ +import pytest +import torch +from torch import nn + +from internlm.accelerator import get_accelerator +from internlm.model.ops.rotary_emb import ( + ApplyRotaryEmb, + rotary_emb_in_rotate_half_style, +) +from internlm.utils.common import get_current_device + +internlm_accelerator = get_accelerator() + + +MICRO_BSZ_LIST = [1, 2] +DTYPE_LIST = [torch.bfloat16, torch.float16] +INTERLEAVED = [True, False] + + +def npu_rope_fwd(B, dtype, interleaved, H=128, N=32, S=4096, rope_base=10000): + device = get_current_device() + # qkv = torch.randn((B, S, 3, N, H), dtype=dtype, device=device) + q = torch.randn((B, S, N, H), dtype=dtype, device=device) + + q = nn.init.normal_(q, mean=0.0, std=1.0) + + inv_freq = 1.0 / (rope_base ** (torch.arange(0, H, 2, device=device, dtype=torch.float32) / H)) + t = torch.arange(S, device=device, dtype=dtype) + freqs = torch.outer(t, inv_freq.to(device=t.device)) + cos, sin = torch.cos(freqs), torch.sin(freqs) + + # Test normal torch. + out1 = ApplyRotaryEmb.apply(q.clone(), cos.clone(), sin.clone(), interleaved, False) + + # Test rotate_half torch. + out2 = rotary_emb_in_rotate_half_style( + x=q.clone(), cos=cos.clone(), sin=sin.clone(), interleaved=interleaved, use_fused_rope=False + ) + + # Test rotate_half torch_npu fused. 
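+    # All three paths (eager ApplyRotaryEmb, rotate_half without fusion, and the fused torch_npu kernel)
+    # are expected to agree within the rtol=1e-4 / atol=1e-5 tolerances asserted below.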
+ out3 = rotary_emb_in_rotate_half_style( + x=q.clone(), cos=cos.clone(), sin=sin.clone(), interleaved=interleaved, use_fused_rope=True + ) + + assert torch.allclose(out1, out2, rtol=1e-4, atol=1e-5) + assert torch.allclose(out2, out3, rtol=1e-4, atol=1e-5) + assert torch.allclose(out1, out3, rtol=1e-4, atol=1e-5) + + +@pytest.mark.parametrize("micro_bsz", MICRO_BSZ_LIST) +@pytest.mark.parametrize("test_dtype", DTYPE_LIST) +@pytest.mark.parametrize("interleaved", INTERLEAVED) +def test_NPU_fa(micro_bsz, test_dtype, interleaved): + npu_rope_fwd(B=micro_bsz, dtype=test_dtype, interleaved=interleaved) diff --git a/tests/test_solver/test_npu_solver.py b/tests/test_solver/test_npu_solver.py new file mode 100644 index 00000000..b3a682b6 --- /dev/null +++ b/tests/test_solver/test_npu_solver.py @@ -0,0 +1,64 @@ +import copy + +import torch +from torch import nn + +from internlm.accelerator import AcceleratorType, get_accelerator +from internlm.solver.optimizer.npu_fused_adamw import AdamW as NPUAdamW +from internlm.utils.common import get_current_device + +internlm_accelerator = get_accelerator() + + +def check_AdamW(): + class MlpModel(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(128, 256) + self.linear2 = nn.Linear(256, 512) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + device = get_current_device() + dtype = torch.bfloat16 + input_data = torch.rand(16, 128, dtype=dtype).to(device) + torch_model = MlpModel().to(dtype).to(get_current_device()) + npu_model = copy.deepcopy(torch_model) + + adamW_torch = torch.optim.AdamW( + params=torch_model.parameters(), + lr=1e-4, + betas=(0.9, 0.95), + eps=1e-8, + ) + + adamW_npu = NPUAdamW( + params=npu_model.parameters(), + lr=1e-4, + betas=(0.9, 0.95), + eps=1e-8, + ) + + adamW_torch.zero_grad() + adamW_npu.zero_grad() + + output_torch = torch_model(input_data) + output_npu = npu_model(input_data) + + output_torch.mean().backward() + output_npu.mean().backward() + + adamW_torch.step() + adamW_npu.step() + + params_zip = zip(list(torch_model.parameters()), list(npu_model.parameters())) + for torch_param, npu_param in params_zip: + assert torch.allclose(torch_param, npu_param, rtol=1e-5, atol=1e-5) + + +def test_AdamW(): + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU: + check_AdamW() diff --git a/tests/test_solver/test_optimizer.py b/tests/test_solver/test_optimizer.py index 0738ddb3..2c7a93c2 100644 --- a/tests/test_solver/test_optimizer.py +++ b/tests/test_solver/test_optimizer.py @@ -11,8 +11,8 @@ import internlm from internlm.accelerator import get_accelerator -from internlm.core.communication.utils import ParamAsyncBcastHandler from internlm.core.context.parallel_context import Config, ParallelMode +from internlm.core.parallel.comm.zero import ParamAsyncBcastHandler from internlm.solver.optimizer import HybridZeroOptimizer from internlm.utils.common import get_current_device diff --git a/tests/test_training/test_forward_output_no_fa.py b/tests/test_training/test_forward_output_no_fa.py index a1f18201..4c36ad87 100644 --- a/tests/test_training/test_forward_output_no_fa.py +++ b/tests/test_training/test_forward_output_no_fa.py @@ -12,11 +12,16 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config +from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type from internlm.initialize.launch import 
args_sanity_check from internlm.model.losses import FlashGPTLMLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook -from internlm.train import initialize_model, initialize_optimizer +from internlm.train import ( + initialize_model, + initialize_optimizer, + initialize_parallel_communicator, +) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -165,6 +170,7 @@ def train_check_output(args): # initialize model model = initialize_model() + _ = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=False, label_smoothing=gpc.config.loss.label_smoothing) @@ -192,15 +198,15 @@ def train_check_output(args): ), ] - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=scheduler_hooks, ) + trainer = Trainer(engine, scheduler) # transfer the train data loader into train data iterator trainer.train() @@ -228,6 +234,7 @@ def train_check_output(args): logger.info("Outputs are totally equal") else: logger.warning("Outputs are not totally equal") + print(f"tensor1: {tensor1}, tensor2: {tensor2}", flush=True) max_diff, index_max_diff = (tensor1 - tensor2).abs().max(dim=0) max_diff = max_diff.item() index_max_diff = index_max_diff.item() diff --git a/tests/test_training/test_load_ckpt_loss.py b/tests/test_training/test_load_ckpt_loss.py index a09191f9..45cd319c 100644 --- a/tests/test_training/test_load_ckpt_loss.py +++ b/tests/test_training/test_load_ckpt_loss.py @@ -29,6 +29,7 @@ ) from internlm.core.trainer import ( # noqa: E402 #pylint: disable=wrong-import-position TrainState, + Trainer, ) from internlm.data import ( # noqa: E402 #pylint: disable=wrong-import-position build_train_loader_with_data_type, @@ -46,6 +47,7 @@ from internlm.train import ( # noqa: E402 #pylint: disable=wrong-import-position initialize_model, initialize_optimizer, + initialize_parallel_communicator, load_new_batch, ) from internlm.utils.common import ( # noqa: E402 #pylint: disable=wrong-import-position @@ -67,7 +69,7 @@ zero1=dict(size=-1, fsdp=False), pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, - tensor=1, + tensor=dict(size=1, mode="mtp"), ), data=dict( seq_len=2048, @@ -218,6 +220,7 @@ def train_model(args): # initialize model model = initialize_model() + _ = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) @@ -263,15 +266,15 @@ def train_model(args): ), ] - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=scheduler_hooks, ) + trainer = Trainer(engine, scheduler) trainer.train() train_iter = iter(train_dl) diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index b997734e..fa8147cd 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -6,19 +6,20 @@ import torch.distributed as dist import internlm +from internlm.accelerator import AcceleratorType, get_accelerator from internlm.checkpoint import CheckpointManager from internlm.core.context import Config, ParallelMode from 
internlm.core.context import global_context as gpc -from internlm.core.trainer import TrainState +from internlm.core.trainer import TrainState, Trainer from internlm.data import build_train_loader_with_data_type from internlm.initialize import initialize_distributed_env from internlm.model.losses import FlashGPTLMLoss from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, - initialize_isp_communicator, initialize_model, initialize_optimizer, + initialize_parallel_communicator, load_new_batch, ) from internlm.utils.common import BatchSkipper, get_current_device, launch_time @@ -44,8 +45,8 @@ 4.616517543792725, ] - cur_loss_list = [] +internlm_accelerator = get_accelerator() def train( @@ -57,8 +58,10 @@ def train( interleaved: bool = False, tp_mode: str = "mtp", enable_sp: bool = False, - enable_ckpt: bool = False, + save_ckpt: bool = False, + load_ckpt: bool = False, model_type: str = "INTERNLM", + optimizer_ver: str = "v1", ): # initialize distributed environment config = Config.from_file(CONFIG_FILE_PATH) @@ -68,22 +71,30 @@ def train( config.data.fixed_random_dataset_seqlen = False config.lr_scheduler.total_steps = TOTAL_STEPS config.model_type = model_type + config.ckpt.load_ckpt_folder = None + config.ckpt.load_ckpt_info = None + config.ckpt.auto_resume = False total_steps = config.data.total_steps skip_batches = config.data.skip_batches label_smoothing = config.loss.label_smoothing + if optimizer_ver == "v2": + config.hybrid_zero_optimizer.new_version = True + config.all_gather_size = 512 * 1024 * 1024 + # update ckpt config if model_type == "INTERNLM" and tp_mode != "isp" and interleaved is False: config.ckpt.load_ckpt_info = dict(path=INTERNLM1_CKPT_PATH, content=("model",), ckpt_type="internlm_test") - if enable_ckpt: + if save_ckpt: config.ckpt.enable_save_ckpt = True config.ckpt.checkpoint_every = 10 config.ckpt.save_ckpt_folder = "local:llm_ckpts/" - config.ckpt.load_ckpt_folder = "local:llm_ckpts/" - config.ckpt.load_ckpt_info["content"] = ("all",) config.ckpt.oss_snapshot_freq = 100 + if load_ckpt: + config.ckpt.load_ckpt_info = dict(path="local:llm_ckpts/10", content=("all",), ckpt_type="internevo") + # update parallel config config.parallel.tensor = dict(size=tp_size, mode=tp_mode) config.parallel.pipeline = dict(size=pp_size) @@ -92,7 +103,18 @@ def train( config.parallel.pipeline = dict(size=pp_size, interleaved_overlap=True) config.model.num_chunks = num_chunks - initialize_distributed_env(config=config) + if tp_mode == "isp" and internlm_accelerator.get_accelerator_backend() in [ + AcceleratorType.NPU, + AcceleratorType.DIPU, + ]: + config.data.use_packed_dataset = False + + if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: + launcher = "slurm" + else: + launcher = "torch" + + initialize_distributed_env(config=config, launcher=launcher) assert hasattr(gpc, "config") and gpc.config is not None # check parallel config @@ -133,7 +155,7 @@ def train( model = initialize_model() # initialize isp communicator - isp_communicator = initialize_isp_communicator(model) + isp_communicator = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing) @@ -171,15 +193,15 @@ def train( ) # initialize trainer - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, 
beta2_scheduler=beta2_scheduler, scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), ) + trainer = Trainer(engine, scheduler) # initialize the batch skipper batch_skipper = BatchSkipper(skip_batches) @@ -235,7 +257,7 @@ def train( ) if gpc.is_rank_for_log(): assert loss is not None and not math.isnan(loss.item()) - global cur_loss_list + global cur_loss_list # pylint: disable=W0602 cur_loss_list.append((loss.item() - moe_loss.item() if moe_loss is not None else loss.item())) timer("fwd-bwd").stop() @@ -292,6 +314,18 @@ def test_training_loss_with_dp4(): check_loss_accuracy() +@pytest.mark.training_4GPU_optimizer_v2 +def test_training_loss_with_dp4_optimizer_v2(): + # model training + train(dp_size=4, optimizer_ver="v2") + + # print loss value + print(f"cur_loss_list: {cur_loss_list}", flush=True) + + check_loss_spike() + check_loss_accuracy() + + @pytest.mark.training_8GPU_4DP2TP def test_training_loss_with_dp4_tp2(): # model training @@ -316,6 +350,18 @@ def test_training_loss_with_dp4_tp2_sp(): check_loss_accuracy() +@pytest.mark.training_8GPU_4DP2TPSP_optimizer_v2 +def test_training_loss_with_dp4_tp2_sp_optimizer_v2(): + # model training + train(dp_size=4, tp_size=2, tp_mode="fsp", enable_sp=True, optimizer_ver="v2") + + # print loss value + print(f"cur_loss_list: {cur_loss_list}", flush=True) + + check_loss_spike() + check_loss_accuracy() + + @pytest.mark.training_8GPU_4DP2PP def test_training_loss_with_dp4_pp2(): # model training @@ -328,6 +374,18 @@ def test_training_loss_with_dp4_pp2(): check_loss_accuracy() +@pytest.mark.training_8GPU_4DP2PP_optimizer_v2 +def test_training_loss_with_dp4_pp2_optimizer_v2(): + # model training + train(dp_size=4, pp_size=2, optimizer_ver="v2") + + # print loss value + print(f"cur_loss_list: {cur_loss_list}", flush=True) + + check_loss_spike() + check_loss_accuracy() + + @pytest.mark.training_8GPU_4DP2PP_InterleavedOverlap def test_training_loss_with_dp4_pp2_interleaved_overlap(): # model training @@ -363,6 +421,18 @@ def test_training_loss_with_dp4_tp2_pp2_msp(): check_loss_accuracy() +@pytest.mark.training_16GPU_4DP2TP2PP_MSP_optimizer_v2 +def test_training_loss_with_dp4_tp2_pp2_msp_optimizer_v2(): + # model training + train(dp_size=4, tp_size=2, pp_size=2, tp_mode="msp", optimizer_ver="v2") + + # print loss value + print(f"cur_loss_list: {cur_loss_list}", flush=True) + + check_loss_spike() + check_loss_accuracy() + + @pytest.mark.training_16GPU_4DP2TP2PP_FSP def test_training_loss_with_dp4_tp2_pp2_fsp(): # model training @@ -409,7 +479,7 @@ def test_training_with_isp_save_ckpt(): CONFIG_FILE_PATH = "./configs/7B_isp_sft.py" # model training save ckpt - train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, enable_ckpt=True) + train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, save_ckpt=True) @pytest.mark.training_8GPU_ISP_LOAD_CKPT @@ -422,7 +492,7 @@ def test_training_with_isp_load_ckpt(): TOTAL_STEPS = 20 # model training load ckpt - train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, enable_ckpt=True) + train(dp_size=4, tp_size=2, wp_size=4, tp_mode="isp", enable_sp=True, load_ckpt=True) @pytest.mark.training_llama2 diff --git a/tests/test_training/test_no_fa_train_temp.py b/tests/test_training/test_no_fa_train_temp.py index 419d08c1..afc1c493 100644 --- a/tests/test_training/test_no_fa_train_temp.py +++ b/tests/test_training/test_no_fa_train_temp.py @@ -6,14 +6,15 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from 
internlm.core.context import global_context as gpc +from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type from internlm.model.losses import FlashGPTLMLoss from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, - initialize_isp_communicator, initialize_model, initialize_optimizer, + initialize_parallel_communicator, ) from internlm.utils.logger import get_logger from tests.common_fixture import ( @@ -54,7 +55,7 @@ def train_check(args): model = initialize_model() # initialize isp communicator - isp_communicator = initialize_isp_communicator(model) + isp_communicator = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) @@ -70,15 +71,15 @@ def train_check(args): dataset_types=dataset_types, ) - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), ) + trainer = Trainer(engine, scheduler) # transfer the train data loader into train data iterator trainer.train() diff --git a/tests/test_training/test_norm_weight.py b/tests/test_training/test_norm_weight.py index e9494f03..98b3093d 100644 --- a/tests/test_training/test_norm_weight.py +++ b/tests/test_training/test_norm_weight.py @@ -9,14 +9,15 @@ from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.trainer import Trainer from internlm.data import build_train_loader_with_data_type from internlm.model.losses import FlashGPTLMLoss from internlm.model.metrics import AccPerplex from internlm.train import ( get_scheduler_hooks, - initialize_isp_communicator, initialize_model, initialize_optimizer, + initialize_parallel_communicator, ) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -74,7 +75,7 @@ def train_check_norm_weight(args): model = initialize_model() # initialize isp communicator - isp_communicator = initialize_isp_communicator(model) + isp_communicator = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) @@ -90,15 +91,15 @@ def train_check_norm_weight(args): dataset_types=dataset_types, ) - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), ) + trainer = Trainer(engine, scheduler) # transfer the train data loader into train data iterator trainer.train() @@ -106,6 +107,8 @@ def train_check_norm_weight(args): train_iter = iter(train_dl) for batch_count in range(total_steps): + if gpc.is_rank_for_log() and batch_count % 100 == 0: + print(f"batch_count: {batch_count}", flush=True) if batch_count % 100 == 0: internlm_accelerator.empty_cache() gc.collect() @@ -180,6 +183,7 @@ def test_check_norm_msp(): pool.join() check_result(result) + print("msp check pass", flush=True) @pytest.mark.check_norm_fsp @@ -195,6 +199,7 @@ def 
test_check_norm_fsp(): pool.join() check_result(result) + print("fsp check pass", flush=True) @pytest.mark.check_norm_isp @@ -210,3 +215,4 @@ def test_check_norm_isp(): pool.join() check_result(result) + print("isp check pass", flush=True) diff --git a/tests/test_training/test_swap_nb_loss_and_gradnorm.py b/tests/test_training/test_swap_nb_loss_and_gradnorm.py index 48f05f47..92f09ada 100644 --- a/tests/test_training/test_swap_nb_loss_and_gradnorm.py +++ b/tests/test_training/test_swap_nb_loss_and_gradnorm.py @@ -14,6 +14,7 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config +from internlm.core.trainer import Trainer from internlm.data import ( build_train_loader_with_data_type, build_valid_loader_with_data_type, @@ -22,7 +23,11 @@ from internlm.initialize.launch import args_sanity_check from internlm.model.losses import FlashGPTLMLoss from internlm.model.metrics import AccPerplex, SchedulerMetricHook -from internlm.train import initialize_model, initialize_optimizer +from internlm.train import ( + initialize_model, + initialize_optimizer, + initialize_parallel_communicator, +) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -266,6 +271,7 @@ def exam_loss(args): # initialize model model = initialize_model() + _ = initialize_parallel_communicator(model) # initialize loss function criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=gpc.config.loss.label_smoothing) @@ -297,15 +303,15 @@ def exam_loss(args): ), ] - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=scheduler_hooks, ) + trainer = Trainer(engine, scheduler) trainer.train() diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index 381a8b7c..4e9ab749 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -20,7 +20,7 @@ from internlm.checkpoint import CheckpointManager # noqa: E402 from internlm.core.context import ParallelMode # noqa: E402 from internlm.core.context import global_context as gpc # noqa: E402 -from internlm.core.trainer import TrainState # noqa: E402 +from internlm.core.trainer import TrainState, Trainer # noqa: E402 from internlm.data import ( # noqa: E402 build_train_loader_with_data_type, build_valid_loader_with_data_type, @@ -38,6 +38,7 @@ initialize_llm_profile, initialize_model, initialize_optimizer, + initialize_parallel_communicator, record_current_batch_training_metrics, ) from internlm.utils.common import ( # noqa: E402 @@ -62,7 +63,12 @@ def check_model_weights(model, ckpt_path, total_equal=False): model1_dict = torch.load(ckpt_path, map_location="cuda") model2_dict = model.state_dict() - for key in model2_dict.keys(): + copy_of_ordered_dict = model2_dict.copy() + + for key in copy_of_ordered_dict.keys(): + if "wqkv" in key: + model2_dict[key.replace("wqkv", "Wqkv")] = model2_dict.pop(key) + key = key.replace("wqkv", "Wqkv") if key not in model1_dict: assert False, f"Error: The key {key} for current model dose not exist in standard ckpt!" 
@@ -83,6 +89,7 @@ def check_model_weights(model, ckpt_path, total_equal=False): def main(args): + very_begining_time = time.time() # init setting skip_batches = gpc.config.data.skip_batches total_steps = gpc.config.data.total_steps @@ -109,6 +116,7 @@ def main(args): # initialize model model = initialize_model() + _ = initialize_parallel_communicator(model) with open(args.config, "r") as f: config_lines = f.readlines() @@ -173,15 +181,15 @@ def main(args): ), ] - trainer, train_dl, _, _ = internlm.initialize_trainer( + engine, scheduler = internlm.initialize_trainer( model=model, optimizer=optimizer, criterion=criterion, - train_dataloader=train_dl, lr_scheduler=lr_scheduler, beta2_scheduler=beta2_scheduler, scheduler_hooks=scheduler_hooks, ) + trainer = Trainer(engine, scheduler) # initialize simple memory profiler if args.profiling: @@ -298,6 +306,7 @@ def main(args): optimizer=optimizer, beta2_scheduler=beta2_scheduler, trainer=trainer, + very_begining_time=very_begining_time, start_time=start_time, loss=loss, moe_loss=moe_loss, diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py index 499692e9..023b085c 100644 --- a/tests/test_utils/common_fixture.py +++ b/tests/test_utils/common_fixture.py @@ -7,9 +7,12 @@ from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config +from internlm.core.naive_amp import NaiveAMPModel +from internlm.model.builder import create_model +from internlm.model.registry import register_model_initializer from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer from internlm.train.utils import create_param_groups -from internlm.utils.storage_manager import SingletonMeta +from internlm.utils.common import SingletonMeta OSS_NAME = os.environ.get("OSS_BUCKET_NAME", None) OSS_IP = os.environ.get("OSS_IP", None) @@ -87,13 +90,8 @@ def init_naive_model(): - # let MODEL_INITIALIZER to work - import internlm.model.modeling_internlm # noqa # pylint: disable=unused-import - import internlm.model.modeling_moe # noqa # pylint: disable=unused-import - from internlm.core.naive_amp import NaiveAMPModel - from internlm.utils.registry import MODEL_INITIALIZER - - model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(init_config.model)) + register_model_initializer() + model = create_model(model_type=gpc.config.model_type, **(init_config.model)) model = NaiveAMPModel( model=model, output_to_fp32=False, diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py index 65325d32..5fe8b3c4 100644 --- a/tests/test_utils/test_model_checkpoint.py +++ b/tests/test_utils/test_model_checkpoint.py @@ -13,7 +13,8 @@ from internlm.core.context.parallel_context import Config from internlm.core.trainer import TrainState from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer -from internlm.utils.storage_manager import SingletonMeta, wait_async_upload_finish +from internlm.utils.common import SingletonMeta +from internlm.utils.storage_manager import wait_async_upload_finish from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import ASYNC_TMP_FOLDER, BOTO_SAVE_PATH, diff --git a/tools/README.md b/tools/README.md index 2b47b1f4..a24040ca 100644 --- a/tools/README.md +++ b/tools/README.md @@ -5,7 +5,7 @@ ├── interface.py # 生成用的接口 ├── internlm_sft_on_moss.py # 在 moss 数据集上进行 SFT 训练的样例 ├── intern_moss_example.py # 在 moss 数据集上进行训练的样例 -├── load_internlm_model.py # 加载 InternLM 原生格式并进行推理的工具 +├── 
load_internlm2_model.py # 加载 InternLM 原生格式并进行推理的工具 ├── openai_api.py # 使用 OpenAI 接口实现的流式部署 ├── pal_inference.py # PAL 范式推理的工具 ├── README_EN.md @@ -141,3 +141,58 @@ if __name__ == "__main__": if hasattr(chunk.choices[0].delta, "content"): print(chunk.choices[0].delta.content, end="", flush=True) ``` + +# load_internlm2_model.py + +加载`InternEvo`框架训练的模型权重并进行推理 + +```bash +torchrun --master_port 12321 --nnodes=1 --node_rank=0 --nproc_per_node=1 --ckpt_dir=[where the internlm2 model weights are stored] --tokenizer_path=tools/tokenizer_internlm2.model tools/load_internlm2_model.py +``` + +LLaMA 7B推理的例子: + +```python + model = initialize_internlm_model( + model_type="LLAMA2", + ckpt_dir=args.ckpt_dir, + model_config=dict( + num_chunks=1, + checkpoint=0.2, + dtype="torch.bfloat16", + embed_split_hidden=True, + num_layers=32, + hidden_size=4096, + vocab_size=32000, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=32, + num_kv_attention_heads=32, + mlp_ratio=2.675, + use_flash_attn=True, + norm_type="rmsnorm", + apply_post_layer_norm=False, + no_bias=True, + layer_norm_epsilon=1e-5, + ), + del_model_prefix=True, + ) + + from sentencepiece import SentencePieceProcessor + + prompt = """<|User|>:{query}\n<|Bot|>:""" + prompt = prompt.replace("{query}", "hello") + # LLaMA tokenizer转换成SentencePieceProcessor 或 此处加载Huggingface Tokenizer,则需额外将generate中调用的decode等方法修改成HF风格 + tokenizer = SentencePieceProcessor(args.tokenizer_path) + generation_config = GenerationConfig() + output_generator = internlm_interactive_generation( + model=model, + tokenizer=tokenizer, + prompt=prompt, + generation_config=generation_config, + additional_eos_token_list=[tokenizer.eos_id()], + ) + + for text in output_generator: + print(text) +``` diff --git a/tools/README_EN.md b/tools/README_EN.md index 63aba410..fe93560d 100644 --- a/tools/README_EN.md +++ b/tools/README_EN.md @@ -6,7 +6,7 @@ This directory provide some tools for model training with the following file str ├── interface.py # interface for generation ├── internlm_sft_on_moss.py # example for SFT training on moss dataset ├── intern_moss_example.py # example for training on moss dataset -├── load_internlm_model.py # tools for loading InternLM checkpoints and generating +├── load_internlm2_model.py # tools for loading InternLM checkpoints and generating ├── openai_api.py # stream deployment with OpenAI APIs ├── pal_inference.py # tools for PAL reasoning ├── README_EN.md diff --git a/tools/load_internlm_model.py b/tools/load_internlm2_model.py similarity index 85% rename from tools/load_internlm_model.py rename to tools/load_internlm2_model.py index 98e6ad53..6f1561b0 100644 --- a/tools/load_internlm_model.py +++ b/tools/load_internlm2_model.py @@ -1,3 +1,4 @@ +import argparse import inspect import logging import os @@ -9,9 +10,8 @@ from internlm.apis.inference import SequenceGenerator from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc -from internlm.initialize.launch import launch_from_torch -from internlm.train import initialize_model -from internlm.utils.registry import MODEL_INITIALIZER +from internlm.initialize.launch import initialize_distributed_env +from internlm.train import initialize_model, initialize_parallel_communicator from internlm.utils.storage_manager import get_fns, init_storage_manager, llm_load from tools.interface import GenerationConfig @@ -102,6 +102,10 @@ def match_fn_signature(func: Callable, args_dict: Dict) -> None: logger.warning(f"These args:{args_set} are popped for 
func:{func.__name__}.") +def use_torchrun_starter(): + return os.getenv("RANK") is not None + + def get_tp_rank() -> int: """Get the tensor parallel rank. This script uses torchrun to initialize the environment, so RANK in the environment variable is the tensor @@ -119,7 +123,7 @@ def get_tp_world_size() -> int: Returns: int: The tensor parallel world size to which the current process belongs. """ - return int(os.environ.get("WORLD_SIZE", 0)) + return int(os.environ.get("WORLD_SIZE", 1)) def initialize_internlm_model( @@ -172,27 +176,33 @@ def initialize_internlm_model( model_config["dtype"] = param_dtype model_config["parallel_output"] = False - match_fn_signature(MODEL_INITIALIZER.get_module(model_type), model_config) + # FIXME: fix it. if gpc.is_rank_for_log(): logger.info(f"model_config: {model_config}.") - launch_from_torch( + + initialize_distributed_env( config=dict( model_type=model_type, model=model_config, parallel=dict( zero1=dict(size=1, fsdp=False), pipeline=dict(size=1, interleaved_overlap=True), - tensor=get_tp_world_size(), + tensor=dict(size=get_tp_world_size(), mode="mtp"), sequence_parallel=0, ), ), + launcher="torch" if use_torchrun_starter() else "slurm", seed=seed, + master_port=23574, + args_check=False, ) - model = initialize_model() # Directly get the origin model without NativeAMP wrapper. + model = initialize_model() + _ = initialize_parallel_communicator(model) model = model.model state_dict = merge_pp_within_tp(ckpt_dir, del_model_prefix=del_model_prefix) + load_info = model.load_state_dict(state_dict, strict=False) logger.info(f"Rank:{gpc.get_local_rank(ParallelMode.TENSOR)}. Load info: {load_info}.") @@ -223,11 +233,11 @@ def internlm_interactive_generation( sequenece_generator = SequenceGenerator( decoder=model, eos_token_id=tokenizer.eos_id(), - pad_token_id=tokenizer.eos_id(), + pad_token_id=tokenizer.bos_id(), bos_token_id=tokenizer.bos_id(), additional_eos_token_list=additional_eos_token_list, ) - additional_eos_token_list = torch.LongTensor(additional_eos_token_list) + additional_eos_token_list = torch.LongTensor(additional_eos_token_list) if additional_eos_token_list else None input_ids = [tokenizer.bos_id()] + tokenizer.encode(prompt) input_ids = torch.LongTensor([input_ids]).to(get_model_device(model)) output_generator = sequenece_generator.streaming_generate( @@ -249,32 +259,48 @@ def internlm_interactive_generation( yield cur_output +def get_default_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt_dir", type=str, help="path to the ckpt file", required=True) + parser.add_argument( + "--tokenizer_path", type=str, default="tools/tokenizer_internlm2.model", help="path to the tokenizer file" + ) + + return parser + + if __name__ == "__main__": + parser = get_default_parser() + args = parser.parse_args() + """ Here is a simple example to generate with origin internlm model architecture. 
Use the following command to run: - >>> torchrun --master_port 12331 --nnodes=1 --node_rank=0 --nproc_per_node=1 tools/load_internlm_model.py + >>> torchrun --master_port 12321 --nnodes=1 --node_rank=0 --nproc_per_node=1 tools/load_internlm2_model.py """ model = initialize_internlm_model( - model_type="INTERNLM", - ckpt_dir="[Please replace this with the directory where the internlm model weights are stored]", + model_type="INTERNLM2_PUBLIC", + ckpt_dir=args.ckpt_dir, model_config=dict( - checkpoint=False, - num_attention_heads=32, + num_chunks=1, + checkpoint=0.2, + dtype="torch.bfloat16", embed_split_hidden=True, - vocab_size=103168, - embed_grad_scale=1, - parallel_output=False, - hidden_size=4096, num_layers=32, - mlp_ratio=8 / 3, - apply_post_layer_norm=False, - dtype="torch.bfloat16", + hidden_size=4096, + vocab_size=92544, + embed_grad_scale=1, + parallel_output=True, + num_attention_heads=32, + num_kv_attention_heads=8, + mlp_ratio=3.5, + use_flash_attn=True, norm_type="rmsnorm", + qk_interleaved=True, + apply_post_layer_norm=False, + no_bias=True, layer_norm_epsilon=1e-5, - use_flash_attn=True, - num_chunks=1, - use_dynamic_ntk_rope=True, + rope_base=1000000, ), del_model_prefix=True, ) @@ -283,15 +309,14 @@ def internlm_interactive_generation( prompt = """<|User|>:{query}\n<|Bot|>:""" prompt = prompt.replace("{query}", "hello") - tokenizer = SentencePieceProcessor("tools/tokenizer_internlm.model") # pylint: disable=E1121 - + tokenizer = SentencePieceProcessor(args.tokenizer_path) # pylint: disable=E1121 generation_config = GenerationConfig() output_generator = internlm_interactive_generation( model=model, tokenizer=tokenizer, prompt=prompt, generation_config=generation_config, - additional_eos_token_list=[103028], + additional_eos_token_list=[tokenizer.eos_id()], ) for text in output_generator: diff --git a/train.py b/train.py index 985ac5de..08534420 100644 --- a/train.py +++ b/train.py @@ -1,313 +1,45 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -import logging -import socket -import time -import traceback -from functools import partial - -import torch.distributed as dist - -import internlm -from internlm.checkpoint import CheckpointManager -from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc +from internlm.core.trainer_builder import TrainerBuilder from internlm.data import ( build_train_loader_with_data_type, build_valid_loader_with_data_type, ) -from internlm.data.train_state import get_train_state -from internlm.eval.evaluation import evaluate_on_val_dls from internlm.initialize import initialize_distributed_env -from internlm.model.losses import FlashGPTLMLoss -from internlm.model.metrics import AccPerplex -from internlm.monitor import initialize_monitor_manager, send_alert_message -from internlm.monitor.monitor import monitor_manager as mm -from internlm.train import ( - get_scheduler_hooks, - initialize_isp_communicator, - initialize_llm_profile, - initialize_model, - initialize_optimizer, - load_new_batch, - record_current_batch_training_metrics, -) -from internlm.utils.common import ( - BatchSkipper, - enable_pytorch_expandable_segments, - get_current_device, - get_megatron_flops, - launch_time, - parse_args, -) -from internlm.utils.gputest import empty_cache_and_diag -from internlm.utils.logger import get_logger -from internlm.utils.megatron_timers import megatron_timer as timer -from internlm.utils.parallel import get_parallel_log_file_name -from internlm.utils.simple_memory_profiler import SimpleMemoryProfiler 
-from internlm.utils.writer import Writer - -# global llm logger -logger = logging.getLogger(__file__) +from internlm.monitor import internevo_monitor +from internlm.train import initialize_model +from internlm.utils.common import parse_args +@internevo_monitor(feishu_alert=True, clean_run=True) def main(args): - enable_pytorch_expandable_segments() - - # init setting - skip_batches = gpc.config.data.skip_batches - total_steps = gpc.config.data.total_steps - valid_every = gpc.config.data.valid_every - label_smoothing = gpc.config.loss.label_smoothing - - get_tflops_func = partial( - get_megatron_flops, - checkpoint=gpc.config.model.checkpoint, - seq_len=gpc.config.data["seq_len"], - hidden_size=gpc.config.model.hidden_size, - num_layers=gpc.config.model.num_layers, - vocab_size=gpc.config.model.vocab_size, - global_batch_size=gpc.config.data.micro_bsz * gpc.config.data.micro_num * gpc.get_world_size(ParallelMode.DATA), - global_world_size=gpc.get_world_size(ParallelMode.GLOBAL), - mlp_ratio=gpc.config.model["mlp_ratio"], - ) - - # get and broadcast current time - current_time = launch_time() - objs = [current_time] - dist.broadcast_object_list(objs, src=0) - current_time = objs[0].replace(":", ".") - global logger - logger = get_logger( - __file__, launch_time=current_time, job_name=gpc.config.JOB_NAME, file_name=get_parallel_log_file_name() - ) - # initialize model model = initialize_model() - # initialize isp communicator - isp_communicator = initialize_isp_communicator(model) - - with open(args.config, "r") as f: - config_lines = f.readlines() - - # initialize loss function - criterion = FlashGPTLMLoss(parallel_output=gpc.config.model.parallel_output, label_smoothing=label_smoothing) - - # initialize the train and validation data loader + # initialize train dataloader train_dl, dataset_types = build_train_loader_with_data_type() - val_dls = build_valid_loader_with_data_type() - - # initialize and resume train state - train_state = get_train_state(train_dl) - - optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model, isp_communicator) - - ckpt_manager = CheckpointManager( - ckpt_config=gpc.config.ckpt, - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - train_dl=train_dl, - model_config=gpc.config.model, - model_config_file="".join(config_lines), - feishu_address=gpc.config.monitor.alert.feishu_alert_address, - ) - - # Loading other persistent training states. - ckpt_manager.try_resume_training(train_state, current_time) - - # initialize customed llm writer - writer = Writer( - job_name=gpc.config.JOB_NAME, - launch_time=current_time, - file_name=get_parallel_log_file_name(), - tensorboard_folder=gpc.config.tensorboard_folder, - resume_tb_folder=train_state.resume_tb_folder, # resume from ckpt. - step_count=train_state.step_count, # resume from ckpt. 
- config=config_lines, - logger=logger, - enable_tb=gpc.config.enable_tb, - queue_max_length=gpc.config.tensorboard.queue_max_length, - total_steps=total_steps, - ) - - # initialize metric for calculating accuracy and perplexity - metric = AccPerplex( - device=get_current_device(), - tp_pg=gpc.get_group(ParallelMode.TENSOR), - dp_pg=gpc.get_group(ParallelMode.DATA), - dataset_types=dataset_types, - ) - - # initialize trainer - trainer, train_dl, _, _ = internlm.initialize_trainer( - model=model, - optimizer=optimizer, - criterion=criterion, - train_dataloader=train_dl, - lr_scheduler=lr_scheduler, - beta2_scheduler=beta2_scheduler, - scheduler_hooks=get_scheduler_hooks(metric, optimizer, isp_communicator), - ) - # initialize simple memory profiler - if args.profiling: - memory_profiler = SimpleMemoryProfiler( - model, - optimizer.optim, - log_folder=f"RUN/{gpc.config.JOB_NAME}/{current_time}/memory_trace/rank{gpc.get_global_rank()}_" - + f"dp{gpc.get_local_rank(ParallelMode.DATA)}_" - + f"wp{gpc.get_local_rank(ParallelMode.WEIGHT)}_" - + f"tp{gpc.get_local_rank(ParallelMode.TENSOR)}", - ) - else: - memory_profiler = None - - # initialize the batch skipper - batch_skipper = BatchSkipper(skip_batches) - - trainer.train() - - # transfer the train data loader into train data iterator - train_iter = iter(train_dl) - - with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof: - # start iterating the train data and begin training - for batch_count in range(train_state.batch_count, total_steps): - empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) - # internlm_accelerator.memory._record_memory_history() - start_time = time.time() - timer("one-batch").start() - - # load batch data - batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state) - - # record the consumed samples in training - train_state.batch_count = batch_count - train_state.num_consumed_samples_in_epoch += len(batch[1]) - if batch_skipper(batch_count): # skip this batch - if gpc.is_rank_for_log(): - logger.info(f"Skip batch count:`{batch_count}`...") - timer("one-batch").stop() - continue - - # zero the grads of parameters - trainer.zero_grad() - # process data - if batch[0].get("type_ids", None) is not None: - metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) - # if batch[0].get("cu_seqlens", None) is not None: - # metric.set_cu_seqlens(cu_seqlens=batch[0].pop("cu_seqlens", None)) - - # do forward and backward - timer("fwd-bwd").start() - - moe_loss = None - if hasattr(gpc.config.model, "num_experts"): - _, _, loss, moe_loss = trainer.execute_schedule( - batch, - forward_only=False, - return_loss=True, - return_output_label=False, - ) - else: - _, _, loss = trainer.execute_schedule( - batch, - forward_only=False, - return_loss=True, - return_output_label=False, - ) - timer("fwd-bwd").stop() - - if isp_communicator and isp_communicator.enable_memory_pool: - isp_communicator.memory_pool.reset_lazy_pools() - - # update parameters, and returns (success_update, grad_norm) - trainer_result = trainer.step() - assert trainer_result is not None - - success_update, grad_norm_groups = trainer_result - if success_update: # update parameters successfully - train_state.step_count += 1 - else: - train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully. 
- if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case - logger.warning(f"Warning: skip parameter update at step {batch_count}.") - send_alert_message( - address=gpc.config.monitor.alert.feishu_alert_address, - message=f"Warning: skip parameter update at step {batch_count}.", - ) - - # calculate and record the training metrics, eg. loss, accuracy and so on. - record_current_batch_training_metrics( - get_tflops_func=get_tflops_func, - logger=logger, - writer=writer, - success_update=success_update, - batch_count=batch_count, - batch=batch, - train_state=train_state, - optimizer=optimizer, - beta2_scheduler=beta2_scheduler, - trainer=trainer, - start_time=start_time, - loss=loss, - moe_loss=moe_loss, - grad_norm=grad_norm_groups, - metric=metric, - ) - - timer("one-batch").stop() - - # evaluate on validation data loaders - if valid_every > 0 and train_state.step_count % valid_every == 0: - evaluate_on_val_dls( - trainer=trainer, - val_dls=val_dls, - writer=writer, - logger=logger, - step_count=train_state.step_count, - ) - - # checkpoint the training states in specific steps, which is determined by the args "checkpoint_every" - # # save batch sampler that tracks the true consumed samples - now_break = ckpt_manager.try_save_checkpoint(train_state) - if now_break: - break - - if memory_profiler is not None: - memory_profiler.step() + # initialize validation dataloader + val_dls = build_valid_loader_with_data_type() - if batch_count % 2 == 0: - prof.step() + # initialize kwargs + kwargs = vars(args) | {"dataset_types": dataset_types} - # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + # build trainer + trainer = TrainerBuilder(model, train_dl, val_dls, **kwargs) - ckpt_manager.wait_async_upload_finish() + # training + trainer.fit() if __name__ == "__main__": args = parse_args() - hostname = socket.gethostname() - # initialize distributed environment + # Initialize distributed environment initialize_distributed_env(config=args.config, launcher=args.launcher, master_port=args.port, seed=args.seed) assert hasattr(gpc, "config") and gpc.config is not None - # initialize monitor manager context - with initialize_monitor_manager( - job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address - ): - try: - main(args) - except Exception: - logger.error( - f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}", - ) - mm.monitor_exception( - alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc() - ) - - # internlm_accelerator.memory._dump_snapshot(f"my_snapshot_{gpc.get_global_rank()}.pickle") + # Run the main function with parsed arguments + main(args) diff --git a/transformers/convert2hf_internlm_moe.py b/transformers/convert2hf_internlm_moe.py new file mode 100644 index 00000000..1800e625 --- /dev/null +++ b/transformers/convert2hf_internlm_moe.py @@ -0,0 +1,410 @@ +# Copyright (c) InternLM. All rights reserved. 
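+# Converts an InternEvo MoE training checkpoint (model_tp{i}_pp{j}.pt shards plus the
+# per-layer, per-expert MoE weight shards) into a HuggingFace InternLMMoEForCausalLM checkpoint.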
+"""
+python transformers/convert2hf_internlm_moe.py --src /path/to/src --tgt /path/to/tgt \
+    --max_shard 2G --max_pos 8192 \
+    --tokenizer /path/to/tokenizer.model \
+"""
+import argparse
+import gc
+import json
+import os
+import re
+import time
+
+import torch
+from datasets import Dataset
+from internlm_model import InternLMTokenizer
+from internlm_moe_model import InternLMMoEConfig, InternLMMoEForCausalLM
+from tqdm import tqdm
+
+from transformers import Trainer, TrainingArguments
+from transformers.modeling_utils import no_init_weights
+
+embedding_key_list = ["embedding.word_embeddings.weight", "embedding.weight", "tok_embeddings.weight", None]
+
+
+def _find_max_tp_pp(names):
+    ckpt_names = []
+    for name in names:
+        if name.startswith("model_t") and not name.endswith("md5"):
+            # _t: avoid conflicting with model_config.pt
+            ckpt_names.append(name)
+
+    max_tp, max_pp = -1, -1
+    for ckpt in ckpt_names:
+        _, tp, pp = os.path.splitext(ckpt)[0].split("_")
+        max_tp = max(max_tp, int(tp[2:]) + 1)
+        max_pp = max(max_pp, int(pp[2:]) + 1)
+
+    return max_tp, max_pp
+
+
+def load_source(src):
+    """
+    Load model_config.pt and model_tp{x}_pp{x}.pt from ``src``.
+
+    :return:
+        - model_config: dict
+        - states: 2-d array. states[i][j] stands for the state_dict of tp_i pp_j
+        - moe_states: 2-d array of the corresponding MoE expert state_dicts
+    """
+
+    # config
+    print("Config loading", flush=True)
+    config_file = os.path.join(src, "model_config.pt")
+    assert os.path.isfile(config_file), f"model_config.pt is not found in: {os.listdir(src)}"
+    model_config = torch.load(config_file)
+    print(model_config)
+    print("Config loaded.", flush=True)
+
+    # checkpoint
+    # find tp pp
+    assert os.path.isdir(src), "not a folder."
+    ckpt_names = os.listdir(src)
+    max_tp, max_pp = _find_max_tp_pp(ckpt_names)
+    num_moe_layer = model_config["num_layers"]
+    num_experts = model_config["num_experts"]
+
+    # 2-d array tp_rank, pp_rank
+    print("Source Checkpoint Loading", flush=True)
+    states = [[None for _ in range(max_pp)] for __ in range(max_tp)]
+    moe_states = [[{} for _ in range(max_pp)] for __ in range(max_tp)]
+    for tp in tqdm(range(max_tp)):
+        for pp in tqdm(range(max_pp)):
+            ckpt_name = os.path.join(src, f"model_tp{tp}_pp{pp}.pt")
+            states[tp][pp] = torch.load(ckpt_name, map_location="cpu")
+            for lay_id in tqdm(range(num_moe_layer)):
+                for expert_id in range(num_experts):
+                    moe_ckpt_name = os.path.join(src, f"model_moe_layer{lay_id}_expert{expert_id}_tp{tp}.pt")
+                    moe_states[tp][pp].update(torch.load(moe_ckpt_name, map_location="cpu"))
+    print("Source Checkpoint Loaded", flush=True)
+    return model_config, states, moe_states
+
+
+def merge(states):
+    """
+    Merge state dicts of pipeline format and shift some layers.
+
+    :return:
+        - states: merged state dicts, one per tp rank
+    """
+    # merge pp
+    merged_states = []
+    print("Pipeline Merging", flush=True)
+    for tp_state in tqdm(states):
+        layer_shift = 0
+        shifted_state = {}
+        # shift key
+        for tp_pp_state in tp_state:
+            _layer_shift = 0
+            keys = list(tp_pp_state.keys())
+            for key in keys:
+                if key.endswith(".inv_freq"):
+                    continue
+                match = re.search(r"\.\d+\.", key)
+                name = key
+                if match is not None:
+                    # layers
+                    s, e = match.span()
+                    layer_idx = int(key[s + 1 : e - 1]) + layer_shift
+                    _layer_shift = max(_layer_shift, int(key[s + 1 : e - 1]))
+                    name = key[:s] + f".{layer_idx}."
+ key[e:] + if name.startswith("model."): + name = name[6:] + shifted_state[name] = tp_pp_state[key] + layer_shift += _layer_shift + 1 + + merged_states.append(shifted_state) + + print("Pipeline Merged", flush=True) + + return merged_states + + +def convert(src, tgt, tokenizer, dtype, max_shard_size, max_pos, topk, rope_scaling): + """ + Convert state_dict to hf format. + + 1. Load and merge state dict + 2. Convert to huggingface + 3. Load tokneizer and save it with ``tokenizer.save_pretrained`` + 4. Load state dict to the model + 5. Call ``model.save_pretrained`` to save checkpoints. + """ + # load states + model_config, src_states, src_moe_states = load_source(src) + states = merge(src_states) + moe_states = merge(src_moe_states) + del src_states + del src_moe_states + + num_shards = len(states) + print("Converting to huggingface format...", flush=True) + + n_heads = model_config["num_attention_heads"] + dim = model_config["hidden_size"] + # n_heads_per_shard = n_heads // num_shards + # dims_per_head = dim // n_heads + intermediate_size = None + + print("Start converting...", flush=True) + state_dict = {} + for layer_i in tqdm(range(model_config["num_layers"])): + wqkvs = [ + states[tp].pop(f"blocks.{layer_i}.mixer.Wqkv.weight").reshape(3, n_heads // num_shards, -1, dim) + for tp in range(num_shards) + ] + bqkvs = [ + states[tp].pop(f"blocks.{layer_i}.mixer.Wqkv.bias").reshape(3, n_heads // num_shards, -1) + for tp in range(num_shards) + ] + state_dict.update( + { + f"model.layers.{layer_i}.input_layernorm.weight": states[0][f"blocks.{layer_i}.norm1.weight"].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ + f"blocks.{layer_i}.norm2.weight" + ].clone(), + } + ) + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat( + [wqkvs[i][0] for i in range(num_shards)], + dim=0, + ).reshape(dim, dim) + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.bias"] = torch.cat( + [bqkvs[i][0] for i in range(num_shards)], + dim=0, + ).reshape(-1) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat( + [wqkvs[i][1] for i in range(num_shards)], + dim=0, + ).reshape(dim, dim) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.bias"] = torch.cat( + [bqkvs[i][1] for i in range(num_shards)], + dim=0, + ).reshape(-1) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [wqkvs[i][2] for i in range(num_shards)], + dim=0, + ).reshape(dim, dim) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.bias"] = torch.cat( + [bqkvs[i][2] for i in range(num_shards)], + dim=0, + ).reshape(-1) + + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [states[i][f"blocks.{layer_i}.mixer.out_proj.weight"] for i in range(num_shards)], dim=1 + ) + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.bias"] = states[0][f"blocks.{layer_i}.mixer.out_proj.bias"] + + state_dict[f"model.layers.{layer_i}.mlp.gate.weight"] = states[0][ + f"blocks.{layer_i}.mlp.moe_layer.gate.wg.weight" + ].clone() + + if model_config["moe_use_residual"]: + state_dict[f"model.layers.{layer_i}.mlp.shared_experts.gate_proj.weight"] = torch.cat( + [moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.residual_mlp.w1.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.mlp.shared_experts.down_proj.weight"] = torch.cat( + [moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.residual_mlp.w3.weight"] for i in range(num_shards)], + dim=1, + ) + 
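+            # Naming note: InternEvo's w1/w2/w3 feed-forward weights are written out as
+            # HF's gate_proj/up_proj/down_proj respectively (w3 is concatenated along dim=1).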
+            state_dict[f"model.layers.{layer_i}.mlp.shared_experts.up_proj.weight"] = torch.cat(
+                [moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.residual_mlp.w2.weight"] for i in range(num_shards)],
+                dim=0,
+            )
+
+        for expert_id in range(model_config["num_experts"]):
+            state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_id}.gate_proj.weight"] = torch.cat(
+                [
+                    moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.experts.wrapped_experts.{expert_id}.w1.weight"]
+                    for i in range(num_shards)
+                ],
+                dim=0,
+            )
+            state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_id}.down_proj.weight"] = torch.cat(
+                [
+                    moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.experts.wrapped_experts.{expert_id}.w3.weight"]
+                    for i in range(num_shards)
+                ],
+                dim=1,
+            )
+            state_dict[f"model.layers.{layer_i}.mlp.experts.{expert_id}.up_proj.weight"] = torch.cat(
+                [
+                    moe_states[i][f"blocks.{layer_i}.mlp.moe_layer.experts.wrapped_experts.{expert_id}.w2.weight"]
+                    for i in range(num_shards)
+                ],
+                dim=0,
+            )
+
+    intermediate_size, _ = state_dict[f"model.layers.{0}.mlp.experts.{0}.gate_proj.weight"].shape
+
+    # embedding
+    for embedding_key in embedding_key_list:
+        if embedding_key in states[0]:
+            break
+    if embedding_key is None:
+        raise KeyError("Cannot find embedding key!")
+    if model_config["embed_split_hidden"]:
+        embed_concat_dim = 1
+        tok_emb_list = [states[i][embedding_key] for i in range(num_shards)]
+    else:
+        embed_concat_dim = 0
+        _, size_1 = states[0][embedding_key].shape
+        embdim_pertp = size_1 // num_shards
+        tok_emb_list = [
+            torch.concat(
+                [
+                    states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)]
+                    for tp in range(num_shards)
+                ],
+                dim=0,
+            )
+            for local_rank in range(num_shards)
+        ]
+    state_dict.update(
+        {
+            "model.norm.weight": states[0]["norm.weight"],
+            "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim),
+            "lm_head.weight": torch.cat([states[i]["head.weight"] for i in range(num_shards)], dim=0),
+        },
+    )
+
+    # initialize model
+    # tokenizer
+    tokenizer = InternLMTokenizer(tokenizer)
+    # config
+    config = InternLMMoEConfig(
+        vocab_size=model_config["vocab_size"],
+        hidden_size=model_config["hidden_size"],
+        intermediate_size=intermediate_size,
+        num_attention_heads=model_config["num_attention_heads"],
+        num_hidden_layers=model_config["num_layers"],
+        rms_norm_eps=model_config["layer_norm_epsilon"],
+        bias=True,
+        rope_theta=model_config.get("rope_base", 10000),
+        rope_scaling=rope_scaling,
+        num_experts=model_config.get("num_experts", 1),
+        num_experts_per_tok=topk,
+        num_shared_experts=1 if model_config["moe_use_residual"] else 0,
+    )
+    # max position embeddings
+    config.max_position_embeddings = max_pos
+    # set bos eos pad to avoid improper generation
+    # since model.generate will create attention_mask
+    # according to pad_token_id and bos_token_id
+    config.bos_token_id = tokenizer.bos_token_id
+    config.eos_token_id = tokenizer.eos_token_id
+    config.pad_token_id = tokenizer.pad_token_id
+
+    # model
+    print("Initializing model...", flush=True)
+    start = time.time()
+    with no_init_weights():
+        model = InternLMMoEForCausalLM._from_config(config, torch_dtype=dtype)
+    print(f"Initializing model takes {time.time() - start}s", flush=True)
+    model.load_state_dict(state_dict)
+
+    # device selection
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    X = torch.zeros((32, 32), dtype=torch.int64).to(device=device)
+    labels = []
+    for i in range(32):
+        labels.append((i + 1) % 32)
+        X[i] = 1
+    labels = torch.tensor(labels)
+    dataset = Dataset.from_dict({"input_ids": X, "labels": X})
+
+    training_args = TrainingArguments(
+        output_dir="./results",  # output directory
+        num_train_epochs=10,  # total number of training epochs
+        per_device_train_batch_size=1,  # batch size per device during training
+        per_device_eval_batch_size=1,  # batch size for evaluation
+        learning_rate=1e-3,  # learning rate
+        save_steps=False,  # do not save checkpoints
+    )
+
+    trainer = Trainer(
+        model=model,  # the instantiated 🤗 Transformers model to be trained
+        args=training_args,  # training arguments, defined above
+        train_dataset=dataset,  # training dataset
+        eval_dataset=dataset,  # evaluation dataset
+    )
+
+    trainer.train()
+    trainer.evaluate()
+
+    del states
+    gc.collect()
+    print(f"Saving model to {tgt}...", flush=True)
+    tokenizer.save_pretrained(tgt)
+    model.save_pretrained(tgt, max_shard_size=max_shard_size)
+
+    # fix auto_map in config
+    with open(os.path.join(tgt, "config.json")) as fp:
+        config_dict = json.load(fp)
+    config_dict["auto_map"]["AutoModel"] = "modeling_internlm.InternLMMoEForCausalLM"
+    with open(os.path.join(tgt, "config.json"), "w") as fp:
+        json.dump(config_dict, fp, indent=2)
+
+
+def convert_tokenizer(src, tgt):
+    assert os.path.isfile(src)
+    tokenizer = InternLMTokenizer(src)
+    tokenizer.save_pretrained(tgt)
+
+
+def get_rope_scaling(args):
+    if args.rotary_type == "origin":
+        return None
+    elif args.rotary_type == "dynamic":
+        return {"type": args.rotary_type, "factor": args.scaling_factor}
+    else:
+        raise NotImplementedError(f"Unknown rope type {args.rotary_type}")
+
+
+def print_args(args):
+    print("-------------- Arguments --------------")
+    print(f"Source Path: {args.src}")
+    print(f"Target Path: {args.tgt}")
+    print(f"Dtype: {args.dtype}")
+    print(f"Max Shard Size: {args.max_shard}")
+    print(f"Max Position Embedding: {args.max_pos}")
+    print(f"Tokenizer Path: {args.tokenizer}")
+    print(f"Rotary Type: {args.rotary_type}")
+    print(f"Scaling Factor: {args.scaling_factor}")
+    print("---------------------------------------")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    # model
+    parser.add_argument("--src", type=str, default=None, help="Input folder")
+    parser.add_argument("--tgt", type=str, help="Output folder")
+    parser.add_argument("--dtype", default="bfloat16", type=str, help="Data type after converting")
+    parser.add_argument("--max_shard", type=str, default="10GB", help="Max size of every sharded checkpoint.")
+    parser.add_argument("--max_pos", type=int, default=4096, help="Max position embedding of model.")
+    # tokenizer
+    parser.add_argument("--tokenizer", type=str, default=None, help="Tokenizer model.")
+    # rope
+    parser.add_argument("--rotary_type", type=str, default="origin", help="Rope type", choices=["origin", "dynamic"])
+    parser.add_argument("--scaling_factor", type=float, default=1.0, help="Scaling factor of dynamic rope.")
+    parser.add_argument("--topk", type=int, default=1, help="top-k experts in MoE.")
+    args = parser.parse_args()
+
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    print_args(args)
+    dtype = getattr(torch, args.dtype)
+    rope_scaling = get_rope_scaling(args)
+
+    assert args.src is not None, "--src is needed!"
+    assert args.tokenizer is not None, "--tokenizer is needed!"
+    assert args.topk is not None, "--topk is needed!"
+ start = time.time() + convert(args.src, args.tgt, args.tokenizer, dtype, args.max_shard, args.max_pos, args.topk, rope_scaling) + print(f"Converting model takes {time.time() - start}s totally", flush=True) diff --git a/transformers/internlm_moe_model/__init__.py b/transformers/internlm_moe_model/__init__.py new file mode 100644 index 00000000..d1c242ee --- /dev/null +++ b/transformers/internlm_moe_model/__init__.py @@ -0,0 +1,9 @@ +from .configuration_internlm_moe import InternLMMoEConfig +from .modeling_internlm_moe import InternLMMoEForCausalLM +from .tokenization_internlm import InternLMTokenizer + +__all__ = [ + "InternLMMoEConfig", + "InternLMMoEForCausalLM", + "InternLMTokenizer", +] diff --git a/transformers/internlm_moe_model/configuration_internlm_moe.py b/transformers/internlm_moe_model/configuration_internlm_moe.py new file mode 100644 index 00000000..e6a0cc62 --- /dev/null +++ b/transformers/internlm_moe_model/configuration_internlm_moe.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/configuration_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" InternLM model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +INTERNLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +# Modified from transformers.model.llama.configuration_llama.LlamaConfig +class InternLMMoEConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`InternLMModel`]. It is used to instantiate + an InternLM model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the InternLM-7B. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the InternLM model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`InternLMModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. 
Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + Example: + ```python + >>> from transformers import InternLMModel, InternLMConfig + >>> # Initializing a InternLM internlm-7b style configuration + >>> configuration = InternLMConfig() + >>> # Initializing a model from the internlm-7b style configuration + >>> model = InternLMModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "internlm" + _auto_class = "AutoConfig" + + def __init__( # pylint: disable=W0102 + self, + vocab_size=103168, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + bias=True, + rotary={"base": 10000, "type": "dynamic"}, # pylint: disable=W0102 + attn_implementation="eager", + num_experts=1, + num_experts_per_tok=1, + num_shared_experts=0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.bias = bias + self.rotary = rotary + self.attn_implementation = attn_implementation + self.num_routed_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.num_shared_experts = num_shared_experts + if self.attn_implementation is None: + self.attn_implementation = "eager" + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/transformers/internlm_moe_model/modeling_internlm_moe.py b/transformers/internlm_moe_model/modeling_internlm_moe.py new file mode 100644 index 00000000..249a6ca3 --- /dev/null +++ b/transformers/internlm_moe_model/modeling_internlm_moe.py @@ -0,0 +1,1421 @@ +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/modeling_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
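+# The model below follows the LLaMA-style InternLM implementation; when more than one routed
+# expert is configured, each decoder block's MLP is replaced by a top-k routed MoE layer with
+# optional shared experts.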
+""" PyTorch InternLM model.""" +import math +import queue +import threading +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) + +try: + from transformers.generation.streamers import BaseStreamer +except: # noqa # pylint: disable=bare-except + BaseStreamer = None + +from .configuration_internlm_moe import InternLMMoEConfig + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "InternLMMoEConfig" + +flash_attn_func, flash_attn_varlen_func = None, None +pad_input, index_first_axis, unpad_input = None, None, None + + +def _import_flash_attn(): + global flash_attn_func, flash_attn_varlen_func + global pad_input, index_first_axis, unpad_input + try: + from flash_attn import flash_attn_func as _flash_attn_func + from flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis as _index_first_axis + from flash_attn.bert_padding import pad_input as _pad_input + from flash_attn.bert_padding import unpad_input as _unpad_input + + flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func + pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input + except ImportError: + raise ImportError("flash_attn is not installed.") + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.llama.modeling_llama._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def _compute_load_balancing_loss(gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2) -> float: + """Calculate the load balancing loss contribution.""" + if gate_logits is None or not isinstance(gate_logits, tuple) or gate_logits[0] is None: + return 0 + moe_losses = [] + for logit in gate_logits: + gates = F.softmax(logit, dim=1) + weight, indices = torch.topk(gates, top_k, dim=1) + num_tokens_per_expert = torch.histc(indices, bins=num_experts, min=0, max=num_experts) + moe_losses.append(torch.dot(num_tokens_per_expert.to(weight.dtype), weight.mean(dim=0))) + + return sum(moe_losses) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->InternLM +class InternLMRMSNorm(nn.Module): + """RMSNorm implemention.""" + + def __init__(self, hidden_size, eps=1e-6): + """ + InternLMRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->InternLM +class InternLMRotaryEmbedding(torch.nn.Module): + """Implement InternLM's rotary embedding. + + Args: + dim (int): Characteristic dimension of each self-attentional head. + max_position_embeddings (int, optional): Model's training length. Defaults to 2048. + base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000. + device (Any, optional): Running device. Defaults to None. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(torch.float32), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(torch.float32), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
+ if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) + return ( + self.cos_cached[:seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->InternLM +class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module): + """Implement InternLM's DyanmicNTK extrapolation method, thereby broadening the model support context to 16K. + + Args: + dim (int): Characteristic dimension of each self-attentional head. + max_position_embeddings (int, optional): Model's training length. Defaults to 2048. + base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000. + device (Any, optional): Running device. Defaults to None. + scaling_factor (float, optional): NTK method extrapolation coefficient. Defaults to 1.0. + """ + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.dim = dim + self.base = base + self.scaling_factor = scaling_factor + + # Build here to make `torch.jit.trace` work. + self.max_position_embeddings = max_position_embeddings + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) + + def _update_cached(self, x, seq_len=None): + self.max_seq_len_cached = max(seq_len, self.max_position_embeddings) + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim)) + else: + inv_freq = self.inv_freq + t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos(), persistent=False) + self.register_buffer("sin_cached", emb.sin(), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
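+        # Dynamic NTK: when seq_len exceeds the trained max_position_embeddings, _update_cached
+        # enlarges the rotary base before rebuilding the cos/sin cache, stretching the frequency
+        # range instead of extrapolating beyond it.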
+ if seq_len <= self.max_position_embeddings: + # Reset the tables if the sequence length has changed, + if self.max_seq_len_cached > self.max_position_embeddings: + self._update_cached(x, seq_len) + else: + self._update_cached(x, seq_len) + + return ( + self.cos_cached[:seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:seq_len, ...].to(dtype=x.dtype), + ) + + +# Copied from transformers.model.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.model.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + if position_ids.size(1) == 1: + q_cos = cos[position_ids].unsqueeze(1).expand(q.shape) + q_sin = sin[position_ids].unsqueeze(1).expand(q.shape) + q_embed = (q * q_cos) + (rotate_half(q) * q_sin) + + position_ids = position_ids.flatten() + 1 + max_length = max(position_ids) + position_ids = torch.stack( + [torch.cat([torch.ones(max_length - w, dtype=torch.long), torch.arange(w)]) for w in position_ids] + ) + k_cos = cos[position_ids].unsqueeze(1).expand(k.shape) + k_sin = sin[position_ids].unsqueeze(1).expand(k.shape) + k_embed = (k * k_cos) + (rotate_half(k) * k_sin) + else: + cos = cos[position_ids].unsqueeze(1) + sin = sin[position_ids].unsqueeze(1) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->InternLM +class InternLMMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +# A mixed expert module containing shared experts. 
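+# Routing (see forward below): gate logits are softmaxed per token, the top-k experts are
+# selected, their outputs are combined with the renormalized top-k weights, and the shared
+# experts (if configured) are applied to every token on top of the routed output.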
+class InternLMMoELayer(nn.Module): + def __init__(self, config: InternLMMoEConfig): + super().__init__() + self.config = config + self.num_shared_experts = config.num_shared_experts + self.num_shared_experts = config.num_shared_experts + self.num_experts_per_tok = config.num_experts_per_tok + self.gate = nn.Linear(config.hidden_size, config.num_routed_experts, bias=False) + self.experts = torch.nn.ModuleList( + [ + InternLMMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + for _ in range(config.num_routed_experts) + ] + ) + if config.num_shared_experts > 0: + intermediate_size = config.intermediate_size * config.num_shared_experts + self.shared_experts = InternLMMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + ) + + def forward(self, x): + orig_shape = x.shape + orig_inputs = x + x = x.view(-1, x.shape[-1]) + # if self.gate.weight.dtype != torch.float32: + # self.gate = self.gate.float() + # x = x.float() + logits = self.gate(x) + gates = F.softmax(logits, dim=1) + weights, indices = torch.topk(gates, self.num_experts_per_tok, dim=1) + weights /= weights.sum(dim=-1, keepdim=True) + flat_indices = indices.view(-1) + + x = x.repeat_interleave(self.num_experts_per_tok, dim=0) + y = torch.empty_like(x) + for i, expert in enumerate(self.experts): + y[flat_indices == i] = expert(x[flat_indices == i]) + y = (y.view(*weights.shape, -1) * weights.unsqueeze(-1)).sum(dim=1) + + if self.config.num_shared_experts > 0: + y = y + self.shared_experts(orig_inputs) + + # moe_loss = self.load_balancing_loss(weights, indices) if self.training else None + return y.view(*orig_shape), logits + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->InternLM +class InternLMAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: InternLMMoEConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias) + self.rotary_emb = self._init_rope() + self.is_causal = True + + def _init_rope(self): + if self.config.rotary["type"] == "origin": + self.rotary_emb = InternLMRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rotary["base"], + ) + elif self.config.rotary["type"] == "dynamic": + self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rotary["base"], + scaling_factor=self.config.rotary.get("scaling_factor", 1.0), + ) + else: + raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').") + return self.rotary_emb + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = 
attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->InternLM +class InternLMFlashAttention2(InternLMAttention): + """ + InternLM flash attention module. This module inherits from `InternLMAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, # pylint: disable=W0613 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # InternLMFlashAttention2 attention does not support output_attentions + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + kv_seq_len = key_states.shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward(query_states, key_states, value_states, attention_mask, q_len) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + causal = self.is_causal and query_length != 1 + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q.to(torch.int64), + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +INTERNLM_ATTENTION_CLASSES = { + "eager": InternLMAttention, + "flash_attention_2": InternLMFlashAttention2, +} + + +# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->InternLM +class InternLMMoEDecoderLayer(nn.Module): + def __init__(self, config: InternLMMoEConfig): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = INTERNLM_ATTENTION_CLASSES[config.attn_implementation](config=config) + + self.mlp = ( + InternLMMoELayer(config) + if config.num_routed_experts > 1 + else InternLMMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + ) + self.input_layernorm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + router_logits = None + if len(hidden_states) == 2: + hidden_states, router_logits = hidden_states + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if output_router_logits: + outputs += (router_logits,) + + return outputs + + +INTERNLM_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`InternLMMoEConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +# Copied from transformers.models.llama.modeling_llama.LlamaPretrainedModel with Llama->InternLM +@add_start_docstrings( + "The bare InternLM Model outputting raw hidden-states without any specific head on top.", + INTERNLM_START_DOCSTRING, +) +class InternLMMoEPreTrainedModel(PreTrainedModel): + config_class = InternLMMoEConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["InternLMMoEDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): # pylint: disable=W0237 + if isinstance(module, InternLMMoEModel): + module.gradient_checkpointing = value + + +INTERNLM_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. 
+ [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or + when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.llama.modeling_llama.LlamaModel with Llama->InternLM +@add_start_docstrings( + "The bare InternLM Model outputting raw hidden-states without any specific head on top.", + INTERNLM_START_DOCSTRING, +) +class InternLMMoEModel(InternLMMoEPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`InternLMDecoderLayer`] + Args: + config: InternLMMoEConfig + """ + + _auto_class = "AutoModel" + + def __init__(self, config: InternLMMoEConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.config = config + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.ModuleList([InternLMMoEDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = InternLMRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = output_router_logits if output_router_logits is not None else False + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.attn_implementation == "flash_attention_2": + _import_flash_attn() + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + 
seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + if self.config.attn_implementation == "flash_attention_2": + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + all_router_logits = () if output_router_logits else None + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->InternLM +class InternLMMoEForCausalLM(InternLMMoEPreTrainedModel): + _auto_class = "AutoModelForCausalLM" + + def __init__(self, config): + super().__init__(config) + self.model = InternLMMoEModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize 
weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Returns: + + Example: + ```python + >>> from transformers import AutoTokenizer, InternLMMoEForCausalLM + >>> model = InternLMMoEForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ``` + + """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = output_router_logits if output_router_logits is not None else False + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + # Keep the auxiliary load-balancing loss in `moe_loss` so it can also be returned to the caller. + moe_loss = None + if output_router_logits: + moe_loss = _compute_load_balancing_loss( + outputs.router_logits if return_dict else outputs[-1], self.num_experts, self.num_experts_per_tok + ) + if labels is not None: + loss += self.router_aux_loss_coef * moe_loss + + if not return_dict: + output = (logits,) + outputs[1:] + if output_router_logits: + output = (moe_loss,) + output + return (loss,) + output if loss is not None else output + + return MoeCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + def build_inputs( + self, tokenizer, query: str, history: List[Tuple[str, str]] = [], meta_instruction="" + ): # pylint: disable=W0102 + if tokenizer.add_bos_token: + prompt = "" + else: + prompt = tokenizer.bos_token + if meta_instruction:
+ prompt += f"""<|System|>:{meta_instruction}\n""" + for record in history: + prompt += f"""<|User|>:{record[0]}\n<|Bot|>:{record[1]}\n""" + prompt += f"""<|User|>:{query}\n<|Bot|>:""" + return tokenizer([prompt], return_tensors="pt") + + @torch.no_grad() + def chat( # pylint: disable=W0102 + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + streamer: Optional[BaseStreamer] = None, + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + meta_instruction: str = "You are an AI assistant whose name is InternLM (书生·浦语).\n" + "- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory " + "(上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n" + "- InternLM (书生·浦语) can understand and communicate fluently in the language chosen by the user " + "such as English and 中文.", + **kwargs, + ): + inputs = self.build_inputs(tokenizer, query, history, meta_instruction) + inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)} + outputs = self.generate( + **inputs, + streamer=streamer, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :] + response = tokenizer.decode(outputs, skip_special_tokens=True) + response = response.split("")[0] + history = history + [(query, response)] + return response, history + + @torch.no_grad() + def stream_chat( # pylint: disable=W0102 + self, + tokenizer, + query: str, + history: List[Tuple[str, str]] = [], + max_new_tokens: int = 1024, + do_sample: bool = True, + temperature: float = 0.8, + top_p: float = 0.8, + **kwargs, + ): + """ + Return a generator in format: (response, history) + Eg. + ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')]) + ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')]) + """ + if BaseStreamer is None: + raise ModuleNotFoundError( + "The version of `transformers` is too low. Please make sure " + "that you have installed `transformers>=4.28.0`." 
+ ) + + response_queue = queue.Queue(maxsize=20) + + class ChatStreamer(BaseStreamer): + def __init__(self, tokenizer) -> None: + super().__init__() + self.tokenizer = tokenizer + self.queue = response_queue + self.query = query + self.history = history + self.response = "" + self.cache = [] + self.received_inputs = False + self.queue.put((self.response, history + [(self.query, self.response)])) + + def put(self, value): + if len(value.shape) > 1 and value.shape[0] > 1: + raise ValueError("ChatStreamer only supports batch size 1") + elif len(value.shape) > 1: + value = value[0] + + if not self.received_inputs: + # The first received value is input_ids, ignore here + self.received_inputs = True + return + + self.cache.extend(value.tolist()) + token = self.tokenizer.decode(self.cache, skip_special_tokens=True) + if "�" in token and len(token) <= 5: + return + if token.strip() != "": + self.response = self.response + token + history = self.history + [(self.query, self.response)] + self.queue.put((self.response, history)) + self.cache = [] + else: + self.end() + + def end(self): + self.queue.put(None) + + def stream_producer(): + return self.chat( + tokenizer=tokenizer, + query=query, + streamer=ChatStreamer(tokenizer=tokenizer), + history=history, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temperature, + top_p=top_p, + **kwargs, + ) + + def consumer(): + producer = threading.Thread(target=stream_producer) + producer.start() + while True: + res = response_queue.get() + if res is None: + return + yield res + + return consumer() + + +@add_start_docstrings( + """ + The InternLM Model transformer with a sequence classification head on top (linear layer). + [`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + INTERNLM_START_DOCSTRING, +) +class InternLMMoEForSequenceClassification(InternLMMoEPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = InternLMMoEModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(INTERNLM_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int # pylint: disable=R1714 + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = 
loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/transformers/internlm_moe_model/tokenization_internlm.py b/transformers/internlm_moe_model/tokenization_internlm.py new file mode 100644 index 00000000..de474559 --- /dev/null +++ b/transformers/internlm_moe_model/tokenization_internlm.py @@ -0,0 +1,237 @@ +# coding=utf-8 +# Copyright (c) The InternLM team and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on transformers/src/transformers/models/llama/tokenization_llama.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for InternLM.""" +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from transformers.tokenization_utils import PreTrainedTokenizer +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = {} + + +# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer -> InternLM2Tokenizer +class InternLMTokenizer(PreTrainedTokenizer): + """ + Construct a InternLM tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. 
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + model_input_names = ["input_ids", "attention_mask"] + _auto_class = "AutoTokenizer" + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token="", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + decode_with_prefix_space=False, + clean_up_tokenization_spaces=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.decode_with_prefix_space = decode_with_prefix_space + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + self._no_prefix_space_tokens = None + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def no_prefix_space_tokens(self): + if self._no_prefix_space_tokens is None: + vocab = self.convert_ids_to_tokens(list(range(self.vocab_size))) + self._no_prefix_space_tokens = {i for i, tok in enumerate(vocab) if not tok.startswith("▁")} + return self._no_prefix_space_tokens + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + @property + def bos_token_id(self) -> Optional[int]: + return self.sp_model.bos_id() + + @property + def eos_token_id(self) -> Optional[int]: + return self.sp_model.eos_id() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """Returns a tokenized string.""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def _maybe_add_prefix_space(self, tokens, decoded): + if tokens and tokens[0] not in self.no_prefix_space_tokens: + return " " + decoded + else: + return decoded + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for token in tokens: + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + out_string = self.clean_up_tokenization(out_string) + out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string) + return out_string[1:] + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. 
+ """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + if self.add_bos_token: + bos_token_ids = [self.bos_token_id] + else: + bos_token_ids = [] + + output = bos_token_ids + token_ids_0 + + if token_ids_1 is not None: + output = output + token_ids_1 + + if self.add_eos_token: + output = output + [self.eos_token_id] + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make + use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + """ + eos = [self.eos_token_id] + + if token_ids_1 is None: + return len(token_ids_0 + eos) * [0] + return len(token_ids_0 + eos + token_ids_1 + eos) * [0] diff --git a/version.txt b/version.txt index 267577d4..be14282b 100644 --- a/version.txt +++ b/version.txt @@ -1 +1 @@ -0.4.1 +0.5.3 diff --git a/web_demo_internlm.py b/web_demo_internlm.py index 8730c0c2..abe0568e 100644 --- a/web_demo_internlm.py +++ b/web_demo_internlm.py @@ -8,7 +8,7 @@ from internlm.accelerator import get_accelerator from tools.interface import GenerationConfig -from tools.load_internlm_model import ( +from tools.load_internlm2_model import ( initialize_internlm_model, internlm_interactive_generation, )