diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 43d8d52a..8dd8bec7 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -1,26 +1,65 @@ -name: "Publish to PyPI" +name: Build and Publish to PyPI + on: - release: - types: - - published + push: + tags: + - "v*.*.*" + jobs: - build-n-publish: - name: Build and publish + build: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['36', '37', '38', '39'] + steps: - - uses: actions/checkout@master - - uses: actions/setup-python@v2 - with: - python-version: '3.7' - architecture: 'x64' - - name: Run build script - run: | - pip install twine --user - pip install wheel - pip install torch==1.11.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - python setup.py sdist --format=gztar - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@master - with: - user: __token__ - password: ${{ secrets.pypi_password }} \ No newline at end of file + - name: Checkout code + uses: actions/checkout@v3 + + - name: Login to DockerHub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Pull Docker image + run: docker pull maydomine/bmtrain-manylinux:cu110 + + - name: Run Docker image and execute script + run: | + version=${{ matrix.python-version }} + docker run -e CUDACXX=/usr/local/cuda/bin/nvcc -e PATH="/workspace/cmake-3.26.4-linux-x86_64/bin:/opt/rh/devtoolset-7/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-7/root/usr/lib64:/opt/rh/devtoolset-7/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i maydomine/bmtrain-manylinux:cu110 /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/python setup.py bdist_wheel -d ./wheel/;/opt/python/cp${version}*/bin/python setup.py sdist -d ./sdist/;for file in wheel/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done" + + - name: Archive distribution files + uses: actions/upload-artifact@v2 + with: + name: dist + path: | + sdist/*.tar.gz + wheel/*.whl + + publish: + needs: build + runs-on: ubuntu-latest + steps: + - name: Set Up the Python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Install twine + run: python -m pip install twine + + - name: Download distribution files + uses: actions/download-artifact@v2 + with: + name: dist + path: dist + + - name: Publish to PyPI + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + cd dist + python -m twine upload sdist/*.tar.gz wheel/*.whl diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..dc3c593e --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,71 @@ +name: Publish release in Github + +on: + push: + tags: + - "v*.*.*" + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['36', '37', '38', '39'] + + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Login to DockerHub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Pull Docker image + run: docker pull maydomine/bmtrain-manylinux:cu110 + + - name: Run Docker image and execute script + run: | + version=${{ matrix.python-version }} + docker run -e CUDACXX=/usr/local/cuda/bin/nvcc -e 
PATH="/workspace/cmake-3.26.4-linux-x86_64/bin:/opt/rh/devtoolset-7/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-7/root/usr/lib64:/opt/rh/devtoolset-7/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i maydomine/bmtrain-manylinux:cu110 /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/python setup.py bdist_wheel -d ./wheel/;/opt/python/cp${version}*/bin/python setup.py sdist -d ./sdist/;for file in wheel/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done" + + - name: Archive distribution files + uses: actions/upload-artifact@v2 + with: + name: dist + path: | + sdist/*.tar.gz + wheel/*.whl + + publish: + needs: build + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set Up the Python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Download distribution files + uses: actions/download-artifact@v2 + with: + name: dist + path: dist + + - name: Upload Distribution Files + uses: softprops/action-gh-release@v1 + with: + body_path: "Release.txt" + files: | + dist/sdist/*.tar.gz + dist/wheel/*.whl + prerelease: false + token: ${{ secrets.RELEASE_TOKEN }} + release_tag: ${{ steps.create_release.outputs.tag }} + github_token: ${{ secrets.GITHUB_TOKEN }} + env: + GITHUB_REPOSITORY: MayDomine/BMTrain diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 00000000..455f7dbd --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,62 @@ +cmake_minimum_required(VERSION 3.18) +project(bmtrain) +enable_language(C) +enable_language(CXX) +set(CMAKE_CUDA_ARCHITECTURES "61;62;70;72;75;80") +enable_language(CUDA) +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED True) +set(CMAKE_CUDA_STANDARD 14) +set(CMAKE_CUDA_STANDARD_REQUIRED True) + +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_62,code=sm_62 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_72,code=sm_72 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80") + +if(NOT DEFINED ENV{BUILD_DOCKER_ENV} OR "$ENV{BUILD_DOCKER_ENV}" STREQUAL "0") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_86,code=sm_86") +endif() + +set(CMAKE_BUILD_RPATH $ORIGIN) +set(CMAKE_INSTALL_RPATH $ORIGIN) +set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/) + +find_package(NCCL REQUIRED) +find_package(Python ${PYTHON_VERSION} EXACT COMPONENTS Interpreter Development.Module REQUIRED) +message (STATUS "Python_EXECUTABLE: ${Python_EXECUTABLE}") +execute_process(COMMAND ${Python_EXECUTABLE} "-c" + "import pybind11; print(pybind11.get_cmake_dir())" + OUTPUT_VARIABLE PYBIND11_CMAKE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE) +message (STATUS "PYBIND11_CMAKE_DIR: ${PYB +IND11_CMAKE_DIR}") +list(APPEND CMAKE_PREFIX_PATH ${PYBIND11_CMAKE_DIR}) +find_package(pybind11 REQUIRED) + +message (STATUS "CMAKE_INSTALL_RPATH: ${CMAKE_INSTALL_RPATH}") + +file(GLOB_RECURSE SOURCES "csrc/*.cpp") +file(GLOB_RECURSE CUDA_SOURCES "csrc/cuda/*.cu") + +set(AVX_FLAGS "${AVX_FLAGS} -march=native") + +pybind11_add_module(C ${SOURCES} ${CUDA_SOURCES}) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${AVX_FLAGS}") + +target_link_libraries(C PRIVATE + "-Wl,-Bsymbolic" + "-Wl,-Bsymbolic-functions" + ${NCCL_LIBRARIES} +) +target_include_directories(C PRIVATE ${NCCL_INCLUDE_DIRS}) +target_compile_definitions(C + PRIVATE VERSION_INFO=${EXAMPLE_VERSION_INFO}) + +set_target_properties(C PROPERTIES 
CUDA_ARCHITECTURES "61;62;70;72;75;80") + +target_include_directories(C + PRIVATE "csrc/include" + PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} +) + + + diff --git a/Release.txt b/Release.txt new file mode 100644 index 00000000..8597044f --- /dev/null +++ b/Release.txt @@ -0,0 +1,3 @@ +# BMTrain New Version Release v0.2.3 +- easier to install (without torch dependency while compiling) +- compatible with torch 2.0 \ No newline at end of file diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 5269673b..8647b7ff 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -9,6 +9,7 @@ from .checkpointing import ScopedTensorInspectorContext from . import debug import copy +import inspect # the flag is used to control the zero level , 0 means normal zero3 , 1 means forward without release parameter ,2 means backward without gather parameter @@ -491,6 +492,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if key in state_dict: # load here input_param = state_dict[key] + if input_param.__class__.__name__ == "DistributedTensorWrapper": + input_param = input_param.broadcast() if input_param.shape != it["shape"]: error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' 'the shape in current model is {}.' @@ -617,10 +620,14 @@ def init_parameters(self): torch.tensor([], dtype=d_dtype, device=d_device).set_(tmp_tensor.storage(), offset_st, (offset_end - offset_st,))[:] del tmp_tensor - def _named_members(self, get_members_fn, prefix='', recurse=True): + def _named_members(self, get_members_fn, prefix='', recurse=True, **kwargs): r"""Helper method for yielding various names + members of modules.""" - return self._module._named_members(get_members_fn, prefix, recurse) + #compitibity with torch 2.0 + if "remove_duplicate" in inspect.signature(torch.nn.Module._named_members).parameters and "remove_duplicate" not in kwargs: + kwargs['remove_duplicate'] = True + return self._module._named_members(get_members_fn, prefix, recurse, **kwargs) + def named_modules(self, memo = None, prefix: str = '', remove_duplicate: bool = True): r"""Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself. 
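The `_named_members` change above probes the installed torch's signature so the wrapper keeps working across torch 1.x and 2.0. A minimal sketch of that compatibility pattern, assuming nothing beyond what the hunk shows (the wrapper name below is illustrative, not part of the patch):

```python
import inspect
import torch

def _named_members_compat(module, get_members_fn, prefix="", recurse=True, **kwargs):
    # Forward `remove_duplicate` only when the installed torch accepts it
    # (torch >= 1.12 / 2.0); older versions reject the keyword.
    accepted = inspect.signature(torch.nn.Module._named_members).parameters
    if "remove_duplicate" in accepted:
        kwargs.setdefault("remove_duplicate", True)
    else:
        kwargs.pop("remove_duplicate", None)
    return module._named_members(get_members_fn, prefix, recurse, **kwargs)
```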
diff --git a/bmtrain/distributed/ops.py b/bmtrain/distributed/ops.py index 6124622d..ef69659a 100644 --- a/bmtrain/distributed/ops.py +++ b/bmtrain/distributed/ops.py @@ -20,22 +20,29 @@ def send_activations(hidden_state, next_rank, comm): send_meta(hidden_state, next_rank, comm) ncclSend(hidden_state.storage(), next_rank, comm) + def recv_activations(prev_rank, comm): dtype, shape = recv_meta(prev_rank, comm) hidden_state = torch.empty(shape, dtype=dtype, device="cuda") ncclRecv(hidden_state.storage(), prev_rank, comm) return hidden_state + def send_meta(x, next_rank, comm): - meta = [len(x.size()), DTYPE_LIST.index(x.dtype)] + list(x.size()) - meta_data = torch.tensor(data=meta, device=x.device, dtype=torch.long) + meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) + meta_data[0] = len(x.size()) + meta_data[1] = DTYPE_LIST.index(x.dtype) + meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int) + meta_data = meta_data.contiguous() ncclSend(meta_data.storage(), next_rank, comm) + def recv_meta(prev_rank, comm): - meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.long) + meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) ncclRecv(meta_data.storage(), prev_rank, comm) n_dims = meta_data[0].item() dtype = DTYPE_LIST[meta_data[1].item()] shape = meta_data[2:n_dims+2].tolist() return dtype,shape + class OpBroadcast(torch.autograd.Function): @staticmethod def forward(ctx, src, root, comm = None): diff --git a/bmtrain/init.py b/bmtrain/init.py index b98064b7..5c3006d2 100644 --- a/bmtrain/init.py +++ b/bmtrain/init.py @@ -149,11 +149,13 @@ def __init__(self,config): self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1 self.tails = self.pp_group[self.pipe_idx, self.stage_id:].tolist() self.heads = self.pp_group[self.pipe_idx, :self.stage_id + 1].tolist() + def get_group_id(self,group_name): if group_name == "pipe": return self.pipe_idx elif group_name == "zero": return self.zero_idx + def get_group_rank(self,group_name): if group_name == "pipe": return self.stage_id diff --git a/bmtrain/layer.py b/bmtrain/layer.py index c32018db..ebbef815 100644 --- a/bmtrain/layer.py +++ b/bmtrain/layer.py @@ -82,6 +82,8 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, key = prefix + name if key in state_dict: input_param = state_dict[key] + if input_param.__class__.__name__ == "DistributedTensorWrapper": + input_param = input_param.broadcast() # This is used to avoid copying uninitialized parameters into # non-lazy modules, since they dont have the hook to do the checks # in such case, it will error when accessing the .shape attribute. diff --git a/bmtrain/loss/_function.py b/bmtrain/loss/_function.py new file mode 100644 index 00000000..658ef242 --- /dev/null +++ b/bmtrain/loss/_function.py @@ -0,0 +1,86 @@ + +from .. 
import C +import torch +CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda +def has_inf_nan(g_fp16: torch.Tensor, out: torch.Tensor) -> None: + assert g_fp16.dtype == torch.float16, "g_fp16 must be a half tensor" + assert out.dtype == torch.uint8, "out must be a uint8 tensor" + assert CHECK_INPUT(g_fp16), "g_fp16 must be contiguous and on cuda" + assert CHECK_INPUT(out), "out must be contiguous and on cuda" + mid = torch.zeros(1024, device=out.device, dtype=out.dtype) + stream = torch.cuda.current_stream().cuda_stream + C.has_nan_inf_launcher(g_fp16.numel(), g_fp16.data_ptr(), mid.data_ptr(), out.data_ptr(), stream) + + + +def cross_entropy_forward(m: int, n: int, input: torch.Tensor, target: torch.Tensor, + softmax: torch.Tensor, output: torch.Tensor, ignore_index: int) -> None: + CHECK_INPUT(input) + CHECK_INPUT(target) + CHECK_INPUT(softmax) + CHECK_INPUT(output) + assert input.dtype == torch.float16, "input must be a half tensor" + assert target.dtype == torch.int32, "target must be an int tensor" + assert softmax.dtype == torch.float16, "softmax must be a half tensor" + assert output.dtype == torch.float32, "output must be a float tensor" + assert input.numel() == softmax.numel(), "input and softmax must have the same number of elements" + assert target.numel() == output.numel(), "target and output must have the same number of elements" + input_ptr = input.data_ptr() + target_ptr = target.data_ptr() + softmax_ptr = softmax.data_ptr() + output_ptr = output.data_ptr() + cuda_stream = torch.cuda.current_stream().cuda_stream + C.cross_entropy_forward_launcher(m, n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index, cuda_stream) + +def cross_entropy_backward(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor, + softmax: torch.Tensor, grad_input: torch.Tensor, ignore_index: int) -> None: + CHECK_INPUT(grad_output) + CHECK_INPUT(target) + CHECK_INPUT(softmax) + CHECK_INPUT(grad_input) + assert grad_output.dtype == torch.float32, "grad_output must be a float tensor" + assert target.dtype == torch.int32, "target must be an int tensor" + assert softmax.dtype == torch.float16, "softmax must be a half tensor" + assert grad_input.dtype == torch.float16, "grad_input must be a half tensor" + assert grad_input.numel() == softmax.numel(), "grad_input and softmax must have the same number of elements" + assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements" + grad_output_ptr = grad_output.data_ptr() + target_ptr = target.data_ptr() + softmax_ptr = softmax.data_ptr() + grad_input_ptr = grad_input.data_ptr() + cuda_stream = torch.cuda.current_stream().cuda_stream + C.cross_entropy_backward_launcher(m, n, grad_output_ptr, target_ptr, softmax_ptr, grad_input_ptr, ignore_index, cuda_stream) + +def cross_entropy_forward_inplace(m: int, n: int, x: torch.Tensor, target: torch.Tensor, + output: torch.Tensor, ignore_index: int) -> None: + CHECK_INPUT(x) + CHECK_INPUT(target) + CHECK_INPUT(output) + assert x.dtype == torch.float16, "x must be a half tensor" + assert target.dtype == torch.int32, "target must be an int tensor" + assert output.dtype == torch.float32, "output must be a float tensor" + assert target.numel() == output.numel(), "target and output must have the same number of elements" + cuda_stream = torch.cuda.current_stream().cuda_stream + x_ptr = x.data_ptr() + output_ptr = output.data_ptr() + target_ptr = target.data_ptr() + output_ptr = output.data_ptr() + + C.cross_entropy_forward_inplace_launcher(m, n, x_ptr, 
target_ptr, output_ptr, ignore_index, cuda_stream) + +def cross_entropy_backward_inplace(m: int, n: int, grad_output: torch.Tensor, target: torch.Tensor, + x: torch.Tensor, ignore_index: int) -> None: + CHECK_INPUT(grad_output) + CHECK_INPUT(target) + CHECK_INPUT(x) + assert grad_output.dtype == torch.float32, "grad_output must be a float tensor" + assert target.dtype == torch.int32, "target must be an int tensor" + assert x.dtype == torch.float16, "x must be a half tensor" + assert target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements" + cuda_stream = torch.cuda.current_stream().cuda_stream + grad_output_ptr = grad_output.data_ptr() + target_ptr = target.data_ptr() + x_ptr = x.data_ptr() + + C.cross_entropy_backward_inplace_launcher(m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream) + diff --git a/bmtrain/loss/cross_entropy.py b/bmtrain/loss/cross_entropy.py index 1fb37181..160ef421 100644 --- a/bmtrain/loss/cross_entropy.py +++ b/bmtrain/loss/cross_entropy.py @@ -1,7 +1,6 @@ from typing import Optional import torch -from . import _cuda as C - +from . import _function as F class OpFusedCrossEntropy(torch.autograd.Function): """ CrossEntropy dim = 1 @@ -11,7 +10,7 @@ def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int): assert x.ndim == 2 softmax = torch.empty(x.size(), device=x.device, dtype=x.dtype) out = torch.empty(x.size(0), device=x.device, dtype=torch.float) - C.f_cross_entropy_forward( + F.cross_entropy_forward( x.size(0), x.size(1), x, target, softmax, out, @@ -25,7 +24,7 @@ def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int): def backward(ctx, grad_output : torch.Tensor): grad_output = grad_output.contiguous() softmax, target = ctx.saved_tensors - C.f_cross_entropy_backward_inplace( + F.cross_entropy_backward_inplace( softmax.size(0), softmax.size(1), grad_output, target, softmax, @@ -41,7 +40,7 @@ class OpFusedCrossEntropyInplace(torch.autograd.Function): def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int): assert x.ndim == 2 out = torch.empty(x.size(0), device=x.device, dtype=torch.float) - C.f_cross_entropy_forward_inplace( + F.cross_entropy_forward_inplace( x.size(0), x.size(1), x, target, out, @@ -55,7 +54,7 @@ def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int): def backward(ctx, grad_output : torch.Tensor): grad_output = grad_output.contiguous() softmax, target = ctx.saved_tensors - C.f_cross_entropy_backward_inplace( + F.cross_entropy_backward_inplace( softmax.size(0), softmax.size(1), grad_output, target, softmax, diff --git a/bmtrain/nccl/__init__.py b/bmtrain/nccl/__init__.py index b6b2de91..0f4129d5 100644 --- a/bmtrain/nccl/__init__.py +++ b/bmtrain/nccl/__init__.py @@ -1,7 +1,7 @@ from typing_extensions import Literal import torch -from . import _C as C +from .. import C from .enums import * class NCCLCommunicator: diff --git a/bmtrain/optim/_function.py b/bmtrain/optim/_function.py new file mode 100644 index 00000000..ee4b04a7 --- /dev/null +++ b/bmtrain/optim/_function.py @@ -0,0 +1,78 @@ + +from .. 
import C +import torch +CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda +def adam_cpu(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp32: torch.Tensor, + v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, + weight_decay: float, step: int) -> None: + assert param_fp32.is_contiguous(), "param_fp32 must be contiguous" + assert param_fp16.is_contiguous(), "param_fp16 must be contiguous" + assert g_fp16.is_contiguous(), "g_fp16 must be contiguous" + assert m_fp32.is_contiguous(), "m_fp32 must be contiguous" + assert v_fp32.is_contiguous(), "v_fp32 must be contiguous" + assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor" + assert param_fp16.dtype == torch.float16, "param_fp16 must be float16 tensor" + assert g_fp16.dtype == torch.float16, "g_fp16 must be float16 tensor" + assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor" + assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" + assert param_fp32.device == torch.device("cpu"), "param_fp32 must be a cpu tensor" + assert param_fp16.device == torch.device("cpu"), "param_fp16 must be a cpu tensor" + assert g_fp16.device == torch.device("cpu"), "g_fp16 must be a cpu tensor" + assert m_fp32.device == torch.device("cpu"), "m_fp32 must be a cpu tensor" + assert v_fp32.device == torch.device("cpu"), "v_fp32 must be a cpu tensor" + assert param_fp32.numel() == param_fp16.numel(), "param_fp32 and param_fp16 must have the same number of elements" + assert param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 must have the same number of elements" + assert param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements" + assert param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements" + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + C.adam_cpu_launcher( + param_fp32.numel(), + param_fp32.data_ptr(), + param_fp16.data_ptr(), + g_fp16.data_ptr(), + m_fp32.data_ptr(), + v_fp32.data_ptr(), + beta1, beta2, + eps, lr, + scale, + weight_decay, + bias_correction1, + bias_correction2, + ) + +def adam(param_fp32: torch.Tensor, param_fp16: torch.Tensor, g_fp16: torch.Tensor, m_fp16: torch.Tensor, + v_fp32: torch.Tensor, beta1: float, beta2: float, eps: float, lr: float, scale: float, + weight_decay: float, step: int) -> None: + assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda" + assert CHECK_INPUT(param_fp16), "param_fp16 must be contiguous and on cuda" + assert CHECK_INPUT(g_fp16), "g_fp16 must be contiguous and on cuda" + assert CHECK_INPUT(m_fp16), "m_fp32 must be contiguous and on cuda" + assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda" + assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor" + assert param_fp16.dtype == torch.float16, "param_fp16 must be float16 tensor" + assert g_fp16.dtype == torch.float16, "g_fp16 must be float16 tensor" + assert m_fp16.dtype == torch.float16, "m_fp16 must be float16 tensor" + assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" + assert param_fp32.numel() == param_fp16.numel(), "param_fp32 and param_fp16 must have the same number of elements" + assert param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 must have the same number of elements" + assert param_fp32.numel() == m_fp16.numel(), "param_fp32 and m_fp32 must have the same number of elements" + assert param_fp32.numel() == 
v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements" + bias_correction1 = 1 - beta1 ** step + bias_correction2 = 1 - beta2 ** step + stream = torch.cuda.current_stream().cuda_stream + C.adam_launcher( + param_fp32.numel(), + param_fp32.data_ptr(), + param_fp16.data_ptr(), + g_fp16.data_ptr(), + m_fp16.data_ptr(), + v_fp32.data_ptr(), + beta1, beta2, + eps, lr, + scale, + weight_decay, + bias_correction1, + bias_correction2, + stream + ) diff --git a/bmtrain/optim/adam.py b/bmtrain/optim/adam.py index 689b151e..b63a4f51 100644 --- a/bmtrain/optim/adam.py +++ b/bmtrain/optim/adam.py @@ -1,10 +1,11 @@ import torch from ..global_var import config -import torch.optim._functional as F -from . import _cuda as C +from . import _function as F +import torch.optim._functional +from .. import C from .. import nccl import inspect - +from ..utils import check_torch_version from copy import deepcopy from itertools import chain from collections import defaultdict @@ -87,7 +88,7 @@ def step(self, closure=None, scale=1): grad = p.grad if p.dtype == torch.half: - C.f_adam( + F.adam( state["_param_fp32"], # fp32 p, # fp16 grad, # fp16 @@ -102,15 +103,15 @@ def step(self, closure=None, scale=1): ) else: other_kwargs = {} - if 'maximize' in inspect.signature(F.adam).parameters: + if 'maximize' in inspect.signature(torch.optim._functional.adam).parameters: other_kwargs['maximize'] = False - F.adam( + torch.optim._functional.adam( [p], [grad / scale], [state['exp_avg']], [state["exp_avg_sq"]], [], - [state["step"]] if int(torch.__version__.split('.')[1]) < 12 + [state["step"]] if check_torch_version("1.12.0") < 0 else [torch.tensor(state["step"])], amsgrad=False, beta1=group['betas'][0], @@ -176,4 +177,8 @@ def update_group(group, new_group): return new_group param_groups = [ update_group(g, ng) for g, ng in zip(groups, saved_groups)] - self.__setstate__({'state': state, 'param_groups': param_groups}) \ No newline at end of file + self.__setstate__({'state': state, 'param_groups': param_groups}) + + #TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu + def zero_grad(self, set_to_none: bool = False): + super().zero_grad(set_to_none=set_to_none) \ No newline at end of file diff --git a/bmtrain/optim/adam_offload.py b/bmtrain/optim/adam_offload.py index 1b337268..e33219bf 100644 --- a/bmtrain/optim/adam_offload.py +++ b/bmtrain/optim/adam_offload.py @@ -1,11 +1,9 @@ import torch from ..global_var import config -from . import _cpu as C -from . import _cuda as G +from . import _function as F from .. 
import nccl -import torch.optim._functional as F import inspect - +from ..utils import check_torch_version from copy import deepcopy from itertools import chain from collections import defaultdict @@ -107,7 +105,7 @@ def step(self, closure=None, scale=1): grad = -state["_grad_fp16"] else: grad = state["_grad_fp16"] - C.f_adam_cpu( + F.adam_cpu( state["_param_fp32"].view(-1), state["_param_fp16"].view(-1), grad.view(-1), @@ -128,15 +126,15 @@ def step(self, closure=None, scale=1): else: grad = state["_grad_fp32"] other_kwargs = {} - if 'maximize' in inspect.signature(F.adam).parameters: + if 'maximize' in inspect.signature(torch.optim._functional.adam).parameters: other_kwargs['maximize'] = False - F.adam( + torch.optim._functional.adam( [state["_param_fp32"]], [grad], [state["exp_avg"]], [state["exp_avg_sq"]], [], - [state["step"]] if int(torch.__version__.split('.')[1]) < 12 + [state["step"]] if check_torch_version("1.12.0") < 0 else [torch.tensor(state["step"])], amsgrad=False, beta1=beta1, @@ -252,4 +250,9 @@ def cut_states(state): return { 'state': packed_state, 'param_groups': param_groups, - } \ No newline at end of file + } + + #TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu + def zero_grad(self, set_to_none: bool = False): + super().zero_grad(set_to_none=set_to_none) + diff --git a/bmtrain/optim/optim_manager.py b/bmtrain/optim/optim_manager.py index bb145748..23eaa868 100644 --- a/bmtrain/optim/optim_manager.py +++ b/bmtrain/optim/optim_manager.py @@ -1,7 +1,6 @@ from typing import Optional, Union, List, Dict, Tuple import torch - -from . import _cuda as G +from ..loss._function import has_inf_nan from ..utils import print_rank from ..lr_scheduler.warmup import WarmupLRScheduler from .. import nccl @@ -13,7 +12,7 @@ def check_overflow(param_groups): for group in param_groups: for p in group['params']: if p.grad is not None and p.dtype == torch.half: # TODO support other types - G.f_has_inf_nan(p.grad, has_inf_or_nan) + has_inf_nan(p.grad, has_inf_or_nan) if "comm" in config: nccl.allReduce(has_inf_or_nan.storage(), has_inf_or_nan.storage(), "max", config["comm"]) @@ -81,6 +80,7 @@ def add_optimizer( self.lr_schedulers.append(lr_scheduler) def scale_loss(self, loss : torch.Tensor) -> torch.Tensor: + return loss * (self.loss_scale / config['world_size']) # loss scale def backward(self, loss : torch.Tensor): @@ -101,7 +101,7 @@ def zero_grad(self): This is a helper function to call optimizer.zero_grad() """ for optimizer in self.optimizers: - optimizer.zero_grad() + optimizer.zero_grad(set_to_none=False) def step(self): """ @@ -122,10 +122,10 @@ def step(self): break if has_overflow: print_rank("Gradient overflow, change scale from %lf to %lf" % (self.loss_scale, self.loss_scale / self.loss_scale_factor)) - self._justify_scale(self.loss_scale / self.loss_scale_factor) - self.zero_grad() + with torch.no_grad(): + self._justify_scale(self.loss_scale / self.loss_scale_factor) + self.zero_grad() return - for optimizer, lr_scheduler in zip(self.optimizers, self.lr_schedulers): if hasattr(optimizer, "_bmtrain_optimizer") and optimizer._bmtrain_optimizer: optimizer.step(scale=self.loss_scale) diff --git a/bmtrain/store.py b/bmtrain/store.py index c0914f9a..0e0fd7ca 100644 --- a/bmtrain/store.py +++ b/bmtrain/store.py @@ -136,6 +136,51 @@ def broadcast_object(obj, comm, src = 0): return obj # Must be a Mapping after pytorch 1.12.0 +class DistributedTensorWrapper: + def __init__(self, tensor, shape=None): + self._dtype = tensor.dtype + self._device = 
tensor.device + self.shape = shape + self.tensor = tensor + + def broadcast(self): + output_param = torch.empty(self.shape, dtype=self._dtype, device="cuda") + if config['rank'] == 0: + input_param = self.tensor + if input_param.is_cuda: + input_param = input_param.clone().contiguous() + else: + input_param = input_param.cuda().contiguous() + + nccl.broadcast( + input_param.storage(), + output_param.storage(), + 0, + config['comm'] + ) + else: + nccl.broadcast( + output_param.storage(), + output_param.storage(), + 0, + config['comm'] + ) + return output_param + + def copy(self): + return self.tensor + + def __getattribute__(self, name): + if name == "tensor" or name == "shape": + return object.__getattribute__(self, name) + else: + try: + return object.__getattribute__(self, name) + except AttributeError: + pass + + return getattr(self.tensor, name) + class DistributedStateDictWrapper(Mapping): def __init__(self, state_dict : Dict) -> None: self._state_dict = state_dict @@ -165,30 +210,12 @@ def __getitem__(self, key : str): dtype_idx = tmp_shape[1].item() shape_list = torch.Size(tmp_shape[2: 2 + shape_list_size].tolist()) - output_param = torch.empty(shape_list, dtype=DTYPE_LIST[dtype_idx], device="cuda") - - if config['rank'] == 0: - input_param : torch.Tensor = self._state_dict[key] - if input_param.is_cuda: - input_param = input_param.clone().contiguous() - else: - input_param = input_param.cuda().contiguous() - - nccl.broadcast( - input_param.storage(), - output_param.storage(), - 0, - config['comm'] - ) + if config['rank'] != 0: + return DistributedTensorWrapper(torch.tensor([], dtype=DTYPE_LIST[dtype_idx], device="cuda"), shape=shape_list) else: - nccl.broadcast( - output_param.storage(), - output_param.storage(), - 0, - config['comm'] - ) + return DistributedTensorWrapper(self._state_dict[key], shape=shape_list) + - return output_param def copy(self): return self diff --git a/bmtrain/utils.py b/bmtrain/utils.py index 2df9a3a5..391b5a8f 100644 --- a/bmtrain/utils.py +++ b/bmtrain/utils.py @@ -5,6 +5,18 @@ ALIGN = 4 ROW_WIDTH = 60 +def check_torch_version(version_str): + """ + Checks if the current torch version is greater than or equal to the given version. + version_str (str): The version to compare with, in the format of "x.y.z" ,and the func will convert it into a int value of x*100+y*10+z. 
+ """ + version_int_arr = [int(v) for v in version_str.split(".")] + + version_int = version_int_arr[0] * 10000 + version_int_arr[1] * 100 + version_int_arr[2] + torch_version = torch.__version__.split("+")[0] + current_version_int_arr = [int(v) for v in torch_version.split(".")] + current_version_int = current_version_int_arr[0] * 10000 + current_version_int_arr[1] * 100 + current_version_int_arr[2] + return current_version_int - version_int def round_up(x, d): return (x + d - 1) // d * d diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake new file mode 100644 index 00000000..2af8e3b9 --- /dev/null +++ b/cmake/FindNCCL.cmake @@ -0,0 +1,100 @@ +list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) +if(DEFINED ENV{NCCL_ROOT_DIR}) + set(NCCL_ROOT_DIR $ENV{NCCL_ROOT_DIR}) + set(NCCL_INCLUDE_DIR "${NCCL_ROOT_DIR}/include" CACHE PATH "Folder contains NVIDIA NCCL headers") + set(NCCL_LIB_DIR "${NCCL_ROOT_DIR}/lib" CACHE PATH "Folder contains NVIDIA NCCL libraries") +else() + set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers") + set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries") +endif() + +# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +if(NOT NCCL_INCLUDE_DIR OR NOT NCCL_LIB_DIR) + execute_process( + COMMAND python -c "import nvidia.nccl;import os; print(os.path.dirname(nvidia.nccl.__file__))" + OUTPUT_VARIABLE NCCL_PIP_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + list(APPEND NCCL_ROOT $ENV{NCCL_PIP_DIR}) + if(NOT NCCL_INCLUDE_DIR) + set(NCCL_INCLUDE_DIR "${NCCL_PIP_DIR}/include") + endif() + if(NOT NCCL_LIB_DIR) + set(NCCL_LIB_DIR "${NCCL_PIP_DIR}/lib") + endif() + find_library(NCCL_LIBRARIES + NAMES ${NCCL_LIBNAME} + HINTS ${NCCL_LIB_DIR}) +endif() + +list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT}) +find_path(NCCL_INCLUDE_DIRS + NAMES nccl.h + HINTS ${NCCL_INCLUDE_DIR}) + + + +if (USE_STATIC_NCCL) + MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.") + SET(NCCL_LIBNAME "nccl_static") + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +else() + SET(NCCL_LIBNAME "nccl") + + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + message(STATUS "NCCL version: ${NCCL_VERSION}") + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + else() + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.2" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() + +endif() + +find_library(NCCL_LIBRARIES + NAMES ${NCCL_LIBNAME} + HINTS ${NCCL_LIB_DIR}) + + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES) + +if(NCCL_FOUND) # obtaining NCCL version and some sanity checks + set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h") + message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...") + set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) + list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS}) + include(CheckCXXSymbolExists) + check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) + + if (NCCL_VERSION_DEFINED) + set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") + file(WRITE ${file} " + #include + #include + int main() + { + std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' 
<< NCCL_PATCH << std::endl; + int x; + ncclGetVersion(&x); + return x == NCCL_VERSION_CODE; + } +") + try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} + RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}" + LINK_LIBRARIES ${NCCL_LIBRARIES}) + if (NOT NCCL_VERSION_MATCHED) + message(FATAL_ERROR "Found NCCL header version and library version do not match! \ +(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") + endif() + message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") + else() + # message(STATUS "NCCL version < 2.3.5-5") + endif () + set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) + + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) +endif() diff --git a/csrc/adam_cpu.cpp b/csrc/adam_cpu.cpp deleted file mode 100644 index b8685b60..00000000 --- a/csrc/adam_cpu.cpp +++ /dev/null @@ -1,189 +0,0 @@ -#include -#include -#include -#include - -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") - -#if defined(__AVX512F__) - -#pragma message "Using AVX512" -#define __AVX512__ 1 - -#elif defined(__AVX__) and defined(__FMA__) and defined(__F16C__) - -#pragma message "Using AVX256" -#define __AVX256__ 1 - -#endif - -void adam_cpu_launcher( - int n, - float* param_fp32, - at::Half* param_fp16, - at::Half* g_fp16, - float* m_fp32, - float* v_fp32, - float beta1, float beta2, - float eps, float lr, - float scale, - float weight_decay, - float bias_correction1, - float bias_correction2 -) { -#if defined(__AVX512__) - auto avx_beta1 = _mm512_set1_ps(beta1); - auto avx_beta2 = _mm512_set1_ps(beta2); - auto avx_beta1_1 = _mm512_set1_ps(1 - beta1); - auto avx_beta2_1 = _mm512_set1_ps(1 - beta2); - auto avx_eps = _mm512_set1_ps(eps); - auto avx_neg_lr = _mm512_set1_ps(-lr); - auto avx_scale = _mm512_set1_ps(scale); - auto avx_weight_decay = _mm512_set1_ps(weight_decay); - auto avx_bias_correction1 = _mm512_set1_ps(bias_correction1); - auto avx_bias_correction2 = _mm512_set1_ps(bias_correction2); - int64_t span = 16; -#elif defined(__AVX256__) - auto avx_beta1 = _mm256_set1_ps(beta1); - auto avx_beta2 = _mm256_set1_ps(beta2); - auto avx_beta1_1 = _mm256_set1_ps(1 - beta1); - auto avx_beta2_1 = _mm256_set1_ps(1 - beta2); - auto avx_eps = _mm256_set1_ps(eps); - auto avx_neg_lr = _mm256_set1_ps(-lr); - auto avx_scale = _mm256_set1_ps(scale); - auto avx_weight_decay = _mm256_set1_ps(weight_decay); - auto avx_bias_correction1 = _mm256_set1_ps(bias_correction1); - auto avx_bias_correction2 = _mm256_set1_ps(bias_correction2); - int64_t span = 8; -#else - int64_t span = 1; -#endif - - at::parallel_for(0, n, 0, [&](int64_t start, int64_t end) { - for (int64_t j = start; j < end; j += span) { -#if defined(__AVX256__) or defined(__AVX512__) - if (j + span > end) { -#else - if (true) { -#endif - // No AVX or n is not alinged - for (int64_t i = j; i < end; i++) { - float g = c10::detail::fp16_ieee_to_fp32_value(g_fp16[i].x) / scale; - float m = m_fp32[i]; - float v = v_fp32[i]; - float p = param_fp32[i]; - m = beta1 * m + (1 - beta1) * g; - v = beta2 * v + (1 - beta2) * g * g; - p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; - param_fp32[i] = p; - param_fp16[i] = at::Half(p); - m_fp32[i] = m; - v_fp32[i] = v; - } - break; // must break here - } else { - // use AVX here -#if 
defined(__AVX512__) - auto g = _mm512_div_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)&g_fp16[j])), avx_scale); - auto m = _mm512_loadu_ps(&m_fp32[j]); - auto v = _mm512_loadu_ps(&v_fp32[j]); - auto p = _mm512_loadu_ps(¶m_fp32[j]); - m = _mm512_fmadd_ps(avx_beta1, m, _mm512_mul_ps(avx_beta1_1, g)); - v = _mm512_fmadd_ps(avx_beta2, v, _mm512_mul_ps(avx_beta2_1, _mm512_mul_ps(g, g))); - p = _mm512_fmadd_ps(avx_neg_lr, _mm512_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p - p = _mm512_fmadd_ps( - avx_neg_lr, - _mm512_div_ps( - _mm512_div_ps(m, avx_bias_correction1), // m / bias_correction1 - _mm512_add_ps( - _mm512_sqrt_ps(_mm512_div_ps(v, avx_bias_correction2)), - avx_eps - ) // sqrt(v / bias_correction2) + eps - ), - p - ); // p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - _mm512_storeu_ps(¶m_fp32[j], p); - _mm256_storeu_si256((__m256i*)¶m_fp16[j], _mm512_cvtps_ph(p, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - _mm512_storeu_ps(&m_fp32[j], m); - _mm512_storeu_ps(&v_fp32[j], v); -#elif defined(__AVX256__) - auto g = _mm256_div_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)&g_fp16[j])), avx_scale); - auto m = _mm256_loadu_ps(&m_fp32[j]); - auto v = _mm256_loadu_ps(&v_fp32[j]); - auto p = _mm256_loadu_ps(¶m_fp32[j]); - m = _mm256_fmadd_ps(avx_beta1, m, _mm256_mul_ps(avx_beta1_1, g)); - v = _mm256_fmadd_ps(avx_beta2, v, _mm256_mul_ps(avx_beta2_1, _mm256_mul_ps(g, g))); - p = _mm256_fmadd_ps(avx_neg_lr, _mm256_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p - p = _mm256_fmadd_ps( - avx_neg_lr, - _mm256_div_ps( - _mm256_div_ps(m, avx_bias_correction1), // m / bias_correction1 - _mm256_add_ps(_mm256_sqrt_ps(_mm256_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps - ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps) - p - ); // p = p - lr * m / bias_correction1 / (sqrt(v / bias_correction2) + eps) - _mm256_storeu_ps(¶m_fp32[j], p); - _mm_storeu_si128((__m128i*)¶m_fp16[j], _mm256_cvtps_ph(p, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); - _mm256_storeu_ps(&m_fp32[j], m); - _mm256_storeu_ps(&v_fp32[j], v); -#endif - } - } - - }); -} - -void F_adam_cpu( - const torch::Tensor ¶m_fp32, - const torch::Tensor ¶m_fp16, - const torch::Tensor &g_fp16, - const torch::Tensor &m_fp32, - const torch::Tensor &v_fp32, - float beta1, float beta2, - float eps, float lr, - float scale, - float weight_decay, - int64_t step -) { - CHECK_CONTIGUOUS(param_fp32); - CHECK_CONTIGUOUS(param_fp16); - CHECK_CONTIGUOUS(g_fp16); - CHECK_CONTIGUOUS(m_fp32); - CHECK_CONTIGUOUS(v_fp32); - AT_ASSERTM(param_fp32.dtype() == torch::kFloat, "param_fp32 must be a float tensor"); - AT_ASSERTM(param_fp16.dtype() == torch::kHalf, "param_fp16 must be a half tensor"); - AT_ASSERTM(g_fp16.dtype() == torch::kHalf, "g_fp16 must be a half tensor"); - AT_ASSERTM(m_fp32.dtype() == torch::kFloat, "m_fp32 must be a float tensor"); - AT_ASSERTM(v_fp32.dtype() == torch::kFloat, "v_fp32 must be a float tensor"); - AT_ASSERTM(param_fp32.is_cpu(), "param_fp32 must be a cpu tensor"); - AT_ASSERTM(param_fp16.is_cpu(), "param_fp16 must be a cpu tensor"); - AT_ASSERTM(g_fp16.is_cpu(), "g_fp16 must be a cpu tensor"); - AT_ASSERTM(m_fp32.is_cpu(), "m_fp32 must be a cpu tensor"); - AT_ASSERTM(v_fp32.is_cpu(), "v_fp32 must be a cpu tensor"); - AT_ASSERTM(param_fp32.numel() == param_fp16.numel(), "param_fp32 and param_fp16 must have the same number of elements"); - AT_ASSERTM(param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 
must have the same number of elements"); - AT_ASSERTM(param_fp32.numel() == m_fp32.numel(), "param_fp32 and m_fp32 must have the same number of elements"); - AT_ASSERTM(param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements"); - - float bias_correction1 = 1 - powf(beta1, step); - float bias_correction2 = 1 - powf(beta2, step); - - adam_cpu_launcher( - param_fp32.numel(), - param_fp32.data_ptr(), - param_fp16.data_ptr(), - g_fp16.data_ptr(), - m_fp32.data_ptr(), - v_fp32.data_ptr(), - beta1, beta2, - eps, lr, - scale, - weight_decay, - bias_correction1, - bias_correction2 - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("f_adam_cpu", &F_adam_cpu, "adam function cpu"); -} diff --git a/csrc/adam_cuda.cpp b/csrc/adam_cuda.cpp deleted file mode 100644 index 2c102d94..00000000 --- a/csrc/adam_cuda.cpp +++ /dev/null @@ -1,58 +0,0 @@ -#include -#include - -void adam_launcher(const torch::Tensor ¶m_fp32, const torch::Tensor ¶m_fp16, const torch::Tensor &g_fp16, const torch::Tensor &m_fp16, const torch::Tensor &v_fp32, float beta1, float beta2, float eps, float lr, float scale, float weight_decay, float bias_correction1, float bias_correction2); -void has_nan_inf_launcher(const torch::Tensor &g_fp16, torch::Tensor mid, torch::Tensor out); - -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) - -void F_adam( - const torch::Tensor ¶m_fp32, - const torch::Tensor ¶m_fp16, - const torch::Tensor &g_fp16, - const torch::Tensor &m_fp16, - const torch::Tensor &v_fp32, - float beta1, float beta2, - float eps, float lr, - float scale, - float weight_decay, - int64_t step -) { - CHECK_INPUT(param_fp32); - CHECK_INPUT(param_fp16); - CHECK_INPUT(g_fp16); - CHECK_INPUT(m_fp16); - CHECK_INPUT(v_fp32); - AT_ASSERTM(param_fp32.dtype() == torch::kFloat, "param_fp32 must be a float tensor"); - AT_ASSERTM(param_fp16.dtype() == torch::kHalf, "param_fp16 must be a half tensor"); - AT_ASSERTM(g_fp16.dtype() == torch::kHalf, "g_fp16 must be a half tensor"); - AT_ASSERTM(m_fp16.dtype() == torch::kHalf, "m_fp16 must be a half tensor"); - AT_ASSERTM(v_fp32.dtype() == torch::kFloat, "v_fp32 must be a float tensor"); - AT_ASSERTM(param_fp32.numel() == param_fp16.numel(), "param_fp32 and param_fp16 must have the same number of elements"); - AT_ASSERTM(param_fp32.numel() == g_fp16.numel(), "param_fp32 and g_fp16 must have the same number of elements"); - AT_ASSERTM(param_fp32.numel() == m_fp16.numel(), "param_fp32 and m_fp16 must have the same number of elements"); - AT_ASSERTM(param_fp32.numel() == v_fp32.numel(), "param_fp32 and v_fp32 must have the same number of elements"); - - float bias_correction1 = 1 - powf(beta1, step); - float bias_correction2 = 1 - powf(beta2, step); - - adam_launcher(param_fp32, param_fp16, g_fp16, m_fp16, v_fp32, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); -} - -void F_has_inf_nan(const torch::Tensor &g_fp16, torch::Tensor &out) { - CHECK_INPUT(g_fp16); - CHECK_INPUT(out); - AT_ASSERTM(g_fp16.dtype() == torch::kHalf, "g_fp16 must be a half tensor"); - AT_ASSERTM(out.dtype() == torch::kUInt8, "out must be a uint8 tensor"); - - torch::Tensor mid = out.new_zeros({1024}); - - has_nan_inf_launcher(g_fp16, mid, out); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("f_adam", &F_adam, "adam function"); - m.def("f_has_inf_nan", &F_has_inf_nan, 
"has inf or nan"); -} diff --git a/csrc/bind.cpp b/csrc/bind.cpp new file mode 100644 index 00000000..8324ba52 --- /dev/null +++ b/csrc/bind.cpp @@ -0,0 +1,25 @@ +#include "include/bind.hpp" + +PYBIND11_MODULE(C, m) { + m.def("has_nan_inf_launcher",&has_nan_inf_launcher,"has nan inf"); + m.def("adam_launcher", &adam_launcher, "adam function cpu"); + m.def("adam_cpu_launcher", &adam_cpu_launcher, "adam function cpu"); + m.def("cross_entropy_forward_launcher", &cross_entropy_forward_launcher, "cross entropy forward"); + m.def("cross_entropy_backward_launcher", &cross_entropy_backward_launcher, "cross entropy backward"); + m.def("cross_entropy_forward_inplace_launcher", &cross_entropy_forward_inplace_launcher, "cross entropy forward inplace"); + m.def("cross_entropy_backward_inplace_launcher", &cross_entropy_backward_inplace_launcher, "cross entropy backward inplace"); + m.def("ncclGetUniqueId", &pyNCCLGetUniqueID, "nccl get unique ID"); + m.def("ncclCommInitRank", &pyNCCLCommInitRank, "nccl init rank"); + m.def("ncclCommDestroy", &pyNCCLCommDestroy, "nccl delete rank"); + m.def("ncclAllGather", &pyNCCLAllGather, "nccl all gather"); + m.def("ncclAllReduce", &pyNCCLAllReduce, "nccl all reduce"); + m.def("ncclBroadcast", &pyNCCLBroadcast, "nccl broadcast"); + m.def("ncclReduce", &pyNCCLReduce, "nccl reduce"); + m.def("ncclReduceScatter", &pyNCCLReduceScatter, "nccl reduce scatter"); + m.def("ncclGroupStart", &pyNCCLGroupStart, "nccl group start"); + m.def("ncclGroupEnd", &pyNCCLGroupEnd, "nccl group end"); + m.def("ncclSend",&pyNCCLSend,"nccl send"); + m.def("ncclRecv",&pyNCCLRecv,"nccl recv"); + m.def("ncclCommCount",&pyNCCLCommCount,"nccl comm count"); + m.def("ncclCommUserRank",&pyNCCLCommUserRank,"nccl comm user rank"); +} diff --git a/csrc/cross_entropy_loss.cpp b/csrc/cross_entropy_loss.cpp deleted file mode 100644 index 40adef36..00000000 --- a/csrc/cross_entropy_loss.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include -#include - -void cross_entropy_forward_launcher(int32_t m, int32_t n, const torch::Tensor &input, const torch::Tensor &target, torch::Tensor &softmax, torch::Tensor &output, int32_t ignore_index); -void cross_entropy_backward_launcher(int32_t m, int32_t n, const torch::Tensor &grad_output, const torch::Tensor &target, const torch::Tensor &softmax, torch::Tensor &grad_input, int32_t ignore_index); -void cross_entropy_forward_inplace_launcher(int32_t m, int32_t n, torch::Tensor &x, const torch::Tensor &target, torch::Tensor &output, int32_t ignore_index); -void cross_entropy_backward_inplace_launcher(int32_t m, int32_t n, const torch::Tensor &grad_output, const torch::Tensor &target, torch::Tensor &x, int32_t ignore_index); - -#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") -#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); - -void F_cross_entropy_forward( - int32_t m, int32_t n, - const torch::Tensor &input, // (m, n) - const torch::Tensor &target, // (m) - torch::Tensor &softmax, // (m, n) - torch::Tensor &output, // (m) - int32_t ignore_index -) { - CHECK_INPUT(input); - CHECK_INPUT(target); - CHECK_INPUT(softmax); - CHECK_INPUT(output); - AT_ASSERTM(input.dtype() == torch::kHalf, "input must be a half tensor"); - AT_ASSERTM(target.dtype() == torch::kInt, "target must be a int tensor"); - AT_ASSERTM(softmax.dtype() == torch::kHalf, "softmax must be a half tensor"); - AT_ASSERTM(output.dtype() == torch::kFloat, "output must be a float tensor"); - 
AT_ASSERTM(input.numel() == softmax.numel(), "input and softmax must have the same number of elements"); - AT_ASSERTM(target.numel() == output.numel(), "target and output must have the same number of elements"); - - cross_entropy_forward_launcher(m, n, input, target, softmax, output, ignore_index); -} - -void F_cross_entropy_backward( - int32_t m, int32_t n, - const torch::Tensor &grad_output, // (m) - const torch::Tensor &target, // (m) - const torch::Tensor &softmax, // (m, n) - torch::Tensor &grad_input, // (m, n) - int32_t ignore_index -) { - CHECK_INPUT(grad_output); - CHECK_INPUT(target); - CHECK_INPUT(softmax); - CHECK_INPUT(grad_input); - AT_ASSERTM(grad_output.dtype() == torch::kFloat, "grad_output must be a float tensor"); - AT_ASSERTM(target.dtype() == torch::kInt, "target must be a int tensor"); - AT_ASSERTM(softmax.dtype() == torch::kHalf, "softmax must be a half tensor"); - AT_ASSERTM(grad_input.dtype() == torch::kHalf, "grad_input must be a half tensor"); - AT_ASSERTM(grad_input.numel() == softmax.numel(), "grad_input and softmax must have the same number of elements"); - AT_ASSERTM(target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements"); - - cross_entropy_backward_launcher(m, n, grad_output, target, softmax, grad_input, ignore_index); -} - -void F_cross_entropy_forward_inplace( - int32_t m, int32_t n, - torch::Tensor &x, // (m, n) - const torch::Tensor &target, // (m) - torch::Tensor &output, // (m) - int32_t ignore_index -) { - CHECK_INPUT(x); - CHECK_INPUT(target); - CHECK_INPUT(output); - AT_ASSERTM(x.dtype() == torch::kHalf, "x must be a half tensor"); - AT_ASSERTM(target.dtype() == torch::kInt, "target must be a int tensor"); - AT_ASSERTM(output.dtype() == torch::kFloat, "output must be a float tensor"); - AT_ASSERTM(target.numel() == output.numel(), "target and output must have the same number of elements"); - - cross_entropy_forward_inplace_launcher(m, n, x, target, output, ignore_index); -} - -void F_cross_entropy_backward_inplace( - int32_t m, int32_t n, - const torch::Tensor &grad_output, // (m) - const torch::Tensor &target, // (m) - torch::Tensor &x, // (m, n) - int32_t ignore_index -) { - CHECK_INPUT(grad_output); - CHECK_INPUT(target); - CHECK_INPUT(x); - AT_ASSERTM(grad_output.dtype() == torch::kFloat, "grad_output must be a float tensor"); - AT_ASSERTM(target.dtype() == torch::kInt, "target must be a int tensor"); - AT_ASSERTM(x.dtype() == torch::kHalf, "x must be a half tensor"); - AT_ASSERTM(target.numel() == grad_output.numel(), "target and grad_output must have the same number of elements"); - - cross_entropy_backward_inplace_launcher(m, n, grad_output, target, x, ignore_index); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("f_cross_entropy_forward", &F_cross_entropy_forward, "cross entropy forward"); - m.def("f_cross_entropy_backward", &F_cross_entropy_backward, "cross entropy backward"); - m.def("f_cross_entropy_forward_inplace", &F_cross_entropy_forward_inplace, "cross entropy forward inplace"); - m.def("f_cross_entropy_backward_inplace", &F_cross_entropy_backward_inplace, "cross entropy backward inplace"); -} diff --git a/csrc/cuda/adam.cu b/csrc/cuda/adam_cuda.cu similarity index 62% rename from csrc/cuda/adam.cu rename to csrc/cuda/adam_cuda.cu index c7bff5ca..0ab55934 100644 --- a/csrc/cuda/adam.cu +++ b/csrc/cuda/adam_cuda.cu @@ -1,6 +1,5 @@ #include -#include -#include +#include namespace { // blocks , threads @@ -37,28 +36,28 @@ __global__ void adam_fp32_accum( } void adam_launcher( - 
const torch::Tensor ¶m_fp32, - const torch::Tensor ¶m_fp16, - const torch::Tensor &g_fp16, - const torch::Tensor &m_fp16, - const torch::Tensor &v_fp32, + int n, + std::uintptr_t param_fp32, + std::uintptr_t param_fp16, + std::uintptr_t g_fp16, + std::uintptr_t m_fp16, + std::uintptr_t v_fp32, float beta1, float beta2, float eps, float lr, float scale, float weight_decay, float bias_correction1, - float bias_correction2 + float bias_correction2, + uintptr_t stream ) { - int32_t n = param_fp32.numel(); if (n <= 0) return; - auto g_ptr = reinterpret_cast(g_fp16.data_ptr()); - auto m_ptr = reinterpret_cast(m_fp16.data_ptr()); - auto v_ptr = v_fp32.data_ptr(); - auto param_ptr = param_fp32.data_ptr(); - auto param_h_ptr = reinterpret_cast(param_fp16.data_ptr()); + auto g_ptr = reinterpret_cast(g_fp16); + auto m_ptr = reinterpret_cast(m_fp16); + auto param_h_ptr = reinterpret_cast(param_fp16); + auto param_fp32_ptr = reinterpret_cast(param_fp32); + auto v_fp32_ptr = reinterpret_cast(v_fp32); int32_t threads = 1024; dim3 block_size = dim3(threads, 1, 1); dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); - auto stream = at::cuda::getCurrentCUDAStream(); - adam_fp32_accum<<>>(n, g_ptr, m_ptr, v_ptr, param_ptr, param_h_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + adam_fp32_accum<<(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, param_h_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); } \ No newline at end of file diff --git a/csrc/cuda/cross_entropy.cu b/csrc/cuda/cross_entropy.cu index 574ebde0..c0b742ac 100644 --- a/csrc/cuda/cross_entropy.cu +++ b/csrc/cuda/cross_entropy.cu @@ -1,7 +1,7 @@ #include -#include -#include #include "reduce.cuh" +#include +#include namespace { // blocks , threads<1024> @@ -134,64 +134,74 @@ __global__ void cross_entropy_backward_inplace( void cross_entropy_forward_launcher( int32_t m, int32_t n, - const torch::Tensor &input, - const torch::Tensor &target, - torch::Tensor &softmax, - torch::Tensor &output, - int32_t ignore_index + std::uintptr_t input, + std::uintptr_t target, + std::uintptr_t softmax, + std::uintptr_t output, + int32_t ignore_index, + std::uintptr_t stream ) { - auto input_ptr = reinterpret_cast(input.data_ptr()); - auto target_ptr = target.data_ptr(); - auto softmax_ptr = reinterpret_cast(softmax.data_ptr()); - auto output_ptr = output.data_ptr(); + auto input_ptr = reinterpret_cast(input); + auto target_ptr = reinterpret_cast(target); + auto softmax_ptr = reinterpret_cast(softmax); + auto output_ptr = reinterpret_cast(output); int32_t threads = 1024; - auto stream = at::cuda::getCurrentCUDAStream(); - cross_entropy_forward<<>>(n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index); + cross_entropy_forward<<(stream)>>>(n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index); } void cross_entropy_backward_launcher( int32_t m, int32_t n, - const torch::Tensor &grad_output, - const torch::Tensor &target, - const torch::Tensor &softmax, - torch::Tensor &grad_input, - int32_t ignore_index + std::uintptr_t grad_output, + std::uintptr_t target, + std::uintptr_t softmax, + std::uintptr_t grad_input, + int32_t ignore_index, + std::uintptr_t stream ) { - auto output_ptr = grad_output.data_ptr(); - auto target_ptr = target.data_ptr(); - auto softmax_ptr = reinterpret_cast(softmax.data_ptr()); - auto input_ptr = reinterpret_cast(grad_input.data_ptr()); + // auto output_ptr = grad_output.data_ptr(); + auto output_ptr = reinterpret_cast(grad_output); 
+ // auto target_ptr = target.data_ptr(); + auto target_ptr = reinterpret_cast(target); + auto softmax_ptr = reinterpret_cast(softmax); + auto input_ptr = reinterpret_cast(grad_input); int32_t threads = 1024; - auto stream = at::cuda::getCurrentCUDAStream(); - cross_entropy_backward<<>>(n, output_ptr, target_ptr, softmax_ptr, input_ptr, ignore_index); + cross_entropy_backward<<(stream)>>>(n, output_ptr, target_ptr, softmax_ptr, input_ptr, ignore_index); } void cross_entropy_forward_inplace_launcher( int32_t m, int32_t n, - torch::Tensor &x, - const torch::Tensor &target, - torch::Tensor &output, - int32_t ignore_index + std::uintptr_t x, + std::uintptr_t target, + std::uintptr_t output, + int32_t ignore_index, + std::uintptr_t stream ) { - auto x_ptr = reinterpret_cast(x.data_ptr()); - auto target_ptr = target.data_ptr(); - auto output_ptr = output.data_ptr(); + // auto x_ptr = reinterpret_cast(x.data_ptr()); + auto x_ptr = reinterpret_cast(x); + // auto target_ptr = target.data_ptr(); + auto target_ptr = reinterpret_cast(target); + // auto output_ptr = output.data_ptr(); + auto output_ptr = reinterpret_cast(output); int32_t threads = 1024; - auto stream = at::cuda::getCurrentCUDAStream(); - cross_entropy_forward_inplace<<>>(n, x_ptr, target_ptr, output_ptr, ignore_index); + // auto stream = at::cuda::getCurrentCUDAStream(); + cross_entropy_forward_inplace<<(stream)>>>(n, x_ptr, target_ptr, output_ptr, ignore_index); } void cross_entropy_backward_inplace_launcher( int32_t m, int32_t n, - const torch::Tensor &grad_output, - const torch::Tensor &target, - torch::Tensor &x, - int32_t ignore_index + std::uintptr_t grad_output, + std::uintptr_t target, + std::uintptr_t x, + int32_t ignore_index, + std::uintptr_t stream ) { - auto output_ptr = grad_output.data_ptr(); - auto target_ptr = target.data_ptr(); - auto x_ptr = reinterpret_cast(x.data_ptr()); + // auto output_ptr = grad_output.data_ptr(); + auto output_ptr = reinterpret_cast(grad_output); + // auto target_ptr = target.data_ptr(); + auto target_ptr = reinterpret_cast(target); + // auto x_ptr = reinterpret_cast(x.data_ptr()); + auto x_ptr = reinterpret_cast(x); int32_t threads = 1024; - auto stream = at::cuda::getCurrentCUDAStream(); - cross_entropy_backward_inplace<<>>(n, output_ptr, target_ptr, x_ptr, ignore_index); + // auto stream = at::cuda::getCurrentCUDAStream(); + cross_entropy_backward_inplace<<(stream)>>>(n, output_ptr, target_ptr, x_ptr, ignore_index); } \ No newline at end of file diff --git a/csrc/cuda/has_inf_nan.cu b/csrc/cuda/has_inf_nan.cu index c8bcca1a..b0e906ff 100644 --- a/csrc/cuda/has_inf_nan.cu +++ b/csrc/cuda/has_inf_nan.cu @@ -1,6 +1,5 @@ #include -#include -#include +#include namespace{ __inline__ __device__ bool isnan_(half v) { @@ -68,21 +67,21 @@ __global__ void bmt_has_nan_inf_2( } void has_nan_inf_launcher( - const torch::Tensor &g_fp16, - torch::Tensor mid, - torch::Tensor out + int32_t n, + std::uintptr_t g_fp16, + std::uintptr_t mid, + std::uintptr_t out, + std::uintptr_t stream ) { - int n = g_fp16.numel(); if (n <= 0) return; - auto g_ptr = reinterpret_cast(g_fp16.data_ptr()); - auto mid_ptr = mid.data_ptr(); - auto stream = at::cuda::getCurrentCUDAStream(); - + auto g_ptr = reinterpret_cast(g_fp16); + auto mid_ptr = reinterpret_cast(mid); + auto out_ptr = reinterpret_cast(out); int32_t threads = 1024; dim3 block_size = dim3(threads, 1, 1); dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1); - bmt_has_nan_inf_1<<>>(n, 
g_ptr, mid_ptr); - bmt_has_nan_inf_2<<<1, block_size, 0, stream.stream()>>>(mid_ptr, out.data_ptr()); + bmt_has_nan_inf_1<<(stream)>>>(n, g_ptr, mid_ptr); + bmt_has_nan_inf_2<<<1, block_size, 0, reinterpret_cast(stream)>>>(mid_ptr, out_ptr); } \ No newline at end of file diff --git a/csrc/cuda/reduce.cuh b/csrc/cuda/reduce.cuh index 8179443f..095e8593 100644 --- a/csrc/cuda/reduce.cuh +++ b/csrc/cuda/reduce.cuh @@ -1,6 +1,4 @@ #include -#include -#include namespace { const int WARP_SZ = 32; diff --git a/csrc/include/adam_cpu.hpp b/csrc/include/adam_cpu.hpp new file mode 100644 index 00000000..b070c7a0 --- /dev/null +++ b/csrc/include/adam_cpu.hpp @@ -0,0 +1,332 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "cpu_info.h" +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") + + +inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32 = {w}; + return fp32.as_value; +} + +inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32 = {f}; + return fp32.as_bits; +} + +template +inline void parallel_for(int64_t begin, int64_t end, int64_t grain_size, const F& f) { + // Number of iterations + int64_t numiter = end - begin; + + // Number of threads to use + int64_t num_threads = 1; // Default to serial execution + + if (grain_size > 0) { + num_threads = std::max(numiter / grain_size, static_cast(1)); + } + else{ + cpu_set_t cpu_set; + CPU_ZERO(&cpu_set); + sched_getaffinity(0, sizeof(cpu_set), &cpu_set); + num_threads = CPU_COUNT(&cpu_set); + grain_size = std::max(numiter / num_threads, static_cast(1)); + + } + + // Check if parallel execution is feasible + if (num_threads > 1) { + py::gil_scoped_release release; // Release the GIL + std::vector threads(num_threads); + for (int64_t t = 0; t < num_threads; ++t) { + threads[t] = std::thread([&, t]() { + int64_t left = begin + t * grain_size; + int64_t right = std::min(begin + (t + 1) * grain_size, end); + f(left, right); + }); + } + for (auto& thread : threads) { + thread.join(); + } + } else { + // If not feasible or grain_size is 0, perform the operation serially + f(begin, end); + } +} + + + +inline uint16_t fp16_ieee_from_fp32_value(float f) { + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + uint32_t scale_to_inf_bits = (uint32_t) 239 << 23; + uint32_t scale_to_zero_bits = (uint32_t) 17 << 23; + float scale_to_inf_val, scale_to_zero_val; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = (uint32_t)fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = (uint32_t)fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast( + (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? 
UINT16_C(0x7E00) : nonsign) + ); + } + +inline float fp16_ieee_to_fp32_value(uint16_t h) { + + const uint32_t w = (uint32_t)h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = UINT32_C(0xE0) << 23; + const float exp_scale = 0x1.0p-112f; + const float normalized_value = + fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = + fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = + sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +void adam_cpu_0( + int64_t n, + float* param_fp32_ptr, + uint16_t* param_fp16_ptr, + uint16_t* g_fp16_ptr, + float* m_fp32_ptr, + float* v_fp32_ptr, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +){ + int64_t span = 1; + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } + break; // must break here + } + }); +} + +static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( + int64_t n, + float* param_fp32_ptr, + uint16_t* param_fp16_ptr, + uint16_t* g_fp16_ptr, + float* m_fp32_ptr, + float* v_fp32_ptr, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +){ + auto avx_beta1 = _mm256_set1_ps(beta1); + auto avx_beta2 = _mm256_set1_ps(beta2); + auto avx_beta1_1 = _mm256_set1_ps(1 - beta1); + auto avx_beta2_1 = _mm256_set1_ps(1 - beta2); + auto avx_eps = _mm256_set1_ps(eps); + auto avx_neg_lr = _mm256_set1_ps(-lr); + auto avx_scale = _mm256_set1_ps(scale); + auto avx_weight_decay = _mm256_set1_ps(weight_decay); + auto avx_bias_correction1 = _mm256_set1_ps(bias_correction1); + auto avx_bias_correction2 = _mm256_set1_ps(bias_correction2); + int64_t span = 8; + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + for (int64_t j = start; j < end; j += span) { + if (j + span > end) { + for (int64_t i = j; i < end; i++) { + float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } + break; // must break here + } else { + auto g = _mm256_div_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)&g_fp16_ptr[j])), avx_scale); + auto m = _mm256_loadu_ps(&m_fp32_ptr[j]); + auto v = _mm256_loadu_ps(&v_fp32_ptr[j]); + auto p = _mm256_loadu_ps(¶m_fp32_ptr[j]); + m = _mm256_fmadd_ps(avx_beta1, m, _mm256_mul_ps(avx_beta1_1, g)); 
+ v = _mm256_fmadd_ps(avx_beta2, v, _mm256_mul_ps(avx_beta2_1, _mm256_mul_ps(g, g))); + p = _mm256_fmadd_ps(avx_neg_lr, _mm256_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p + p = _mm256_fmadd_ps( + avx_neg_lr, + _mm256_div_ps( + _mm256_div_ps(m, avx_bias_correction1), // m / bias_correction1 + _mm256_add_ps(_mm256_sqrt_ps(_mm256_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps + ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + p + ); // p = p - lr * m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + _mm256_storeu_ps(¶m_fp32_ptr[j], p); + _mm_storeu_si128((__m128i*)¶m_fp16_ptr[j], _mm256_cvtps_ph(p, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + _mm256_storeu_ps(&m_fp32_ptr[j], m); + _mm256_storeu_ps(&v_fp32_ptr[j], v); + } + }}); +} + +static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( + int64_t n, + float* param_fp32_ptr, + uint16_t* param_fp16_ptr, + uint16_t* g_fp16_ptr, + float* m_fp32_ptr, + float* v_fp32_ptr, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +){ + auto avx_beta1 = _mm512_set1_ps(beta1); + auto avx_beta2 = _mm512_set1_ps(beta2); + auto avx_beta1_1 = _mm512_set1_ps(1 - beta1); + auto avx_beta2_1 = _mm512_set1_ps(1 - beta2); + auto avx_eps = _mm512_set1_ps(eps); + auto avx_neg_lr = _mm512_set1_ps(-lr); + auto avx_scale = _mm512_set1_ps(scale); + auto avx_weight_decay = _mm512_set1_ps(weight_decay); + auto avx_bias_correction1 = _mm512_set1_ps(bias_correction1); + auto avx_bias_correction2 = _mm512_set1_ps(bias_correction2); + int64_t span = 16; + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + for (int64_t j = start; j < end; j += span) { + if (j + span > end) { + for (int64_t i = j; i < end; i++) { + float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } + break; // must break here + }else{ + auto g = _mm512_div_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)&g_fp16_ptr[j])), avx_scale); + auto m = _mm512_loadu_ps(&m_fp32_ptr[j]); + auto v = _mm512_loadu_ps(&v_fp32_ptr[j]); + auto p = _mm512_loadu_ps(¶m_fp32_ptr[j]); + m = _mm512_fmadd_ps(avx_beta1, m, _mm512_mul_ps(avx_beta1_1, g)); + v = _mm512_fmadd_ps(avx_beta2, v, _mm512_mul_ps(avx_beta2_1, _mm512_mul_ps(g, g))); + p = _mm512_fmadd_ps(avx_neg_lr, _mm512_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p + p = _mm512_fmadd_ps( + avx_neg_lr, + _mm512_div_ps( + _mm512_div_ps(m, avx_bias_correction1), // m / bias_correction1 + _mm512_add_ps( + _mm512_sqrt_ps(_mm512_div_ps(v, avx_bias_correction2)), + avx_eps + ) // sqrt(v / bias_correction2) + eps + ), + p + ); // p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + _mm512_storeu_ps(¶m_fp32_ptr[j], p); + _mm256_storeu_si256((__m256i*)¶m_fp16_ptr[j], _mm512_cvtps_ph(p, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + _mm512_storeu_ps(&m_fp32_ptr[j], m); + _mm512_storeu_ps(&v_fp32_ptr[j], v); + } + } + }); +} + + + + +void adam_cpu_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_fp16, + std::uintptr_t g_fp16, + std::uintptr_t m_fp32, 
+ std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +) { + + auto param_fp32_ptr = reinterpret_cast(param_fp32); + auto m_fp32_ptr = reinterpret_cast(m_fp32); + auto v_fp32_ptr = reinterpret_cast(v_fp32); + auto param_fp16_ptr = reinterpret_cast(param_fp16); + auto g_fp16_ptr = reinterpret_cast(g_fp16); + int cpu_level = get_cpu_level(); + if (cpu_level == 0 ){ + adam_cpu_0(n, param_fp32_ptr, param_fp16_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + }else if(cpu_level == 1){ + adam_cpu_1(n, param_fp32_ptr, param_fp16_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + }else{ + adam_cpu_2(n, param_fp32_ptr, param_fp16_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + } +} + + diff --git a/csrc/include/bind.hpp b/csrc/include/bind.hpp new file mode 100644 index 00000000..0929de91 --- /dev/null +++ b/csrc/include/bind.hpp @@ -0,0 +1,55 @@ +#include +#include "nccl.hpp" +#include "adam_cpu.hpp" + +void has_nan_inf_launcher(int32_t n,std::uintptr_t g_fp16,std::uintptr_t mid,std::uintptr_t out,std::uintptr_t stream); + +void cross_entropy_backward_launcher( + int32_t m, int32_t n, + std::uintptr_t grad_output, + std::uintptr_t target, + std::uintptr_t softmax, + std::uintptr_t grad_input, + int32_t ignore_index, + std::uintptr_t stream +); +void cross_entropy_backward_inplace_launcher( + int32_t m, int32_t n, + std::uintptr_t grad_output, + std::uintptr_t target, + std::uintptr_t x, + int32_t ignore_index, + std::uintptr_t stream +); + void cross_entropy_forward_inplace_launcher( + int32_t m, int32_t n, + std::uintptr_t x, + std::uintptr_t target, + std::uintptr_t output, + int32_t ignore_index, + std::uintptr_t stream +); +void cross_entropy_forward_launcher( + int32_t m, int32_t n, + std::uintptr_t input, + std::uintptr_t target, + std::uintptr_t softmax, + std::uintptr_t output, + int32_t ignore_index, + std::uintptr_t stream +); +void adam_launcher( + int n, + std::uintptr_t param_fp32, + std::uintptr_t param_fp16, + std::uintptr_t g_fp16, + std::uintptr_t m_fp16, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2, + uintptr_t stream +); \ No newline at end of file diff --git a/csrc/include/cpu_info.h b/csrc/include/cpu_info.h new file mode 100644 index 00000000..53ed48f8 --- /dev/null +++ b/csrc/include/cpu_info.h @@ -0,0 +1,38 @@ +#include + +static void cpuid(int info[4], int InfoType){ + __cpuid_count(InfoType, 0, info[0], info[1], info[2], info[3]); +} + +int get_cpu_level() { + // SIMD: 128-bit + bool HW_F16C; + + // SIMD: 256-bit + bool HW_AVX; + bool HW_FMA; + + // SIMD: 512-bit + bool HW_AVX512F; // AVX512 Foundation + + int info[4]; + cpuid(info, 0); + int nIds = info[0]; + + // Detect Features + if (nIds >= 0x00000001){ + cpuid(info,0x00000001); + HW_AVX = (info[2] & ((int)1 << 28)) != 0; + HW_FMA = (info[2] & ((int)1 << 12)) != 0; + HW_F16C = (info[2] & ((int)1 << 29)) != 0; + } + if (nIds >= 0x00000007){ + cpuid(info,0x00000007); + HW_AVX512F = (info[1] & ((int)1 << 16)) != 0; + } + + int ret = 0; + if (HW_AVX && HW_FMA && HW_F16C) ret = 1; + if (HW_AVX512F) ret = 2; + return ret; +} diff --git a/csrc/include/nccl.h b/csrc/include/nccl.h 
deleted file mode 100644 index 4bdda76a..00000000 --- a/csrc/include/nccl.h +++ /dev/null @@ -1,359 +0,0 @@ -/************************************************************************* - * Copyright (c) 2015-2021, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#ifndef NCCL_H_ -#define NCCL_H_ - -#include -#include -#if CUDART_VERSION >= 11000 -#include -#endif - -#define NCCL_MAJOR 2 -#define NCCL_MINOR 11 -#define NCCL_PATCH 4 -#define NCCL_SUFFIX "" - -#define NCCL_VERSION_CODE 21104 -#define NCCL_VERSION(X,Y,Z) (((X) <= 2 && (Y) <= 8) ? (X) * 1000 + (Y) * 100 + (Z) : (X) * 10000 + (Y) * 100 + (Z)) - -#ifdef __cplusplus -extern "C" { -#endif - -/* Opaque handle to communicator */ -typedef struct ncclComm* ncclComm_t; - -#define NCCL_UNIQUE_ID_BYTES 128 -typedef struct { char internal[NCCL_UNIQUE_ID_BYTES]; } ncclUniqueId; - -/* Error type */ -typedef enum { ncclSuccess = 0, - ncclUnhandledCudaError = 1, - ncclSystemError = 2, - ncclInternalError = 3, - ncclInvalidArgument = 4, - ncclInvalidUsage = 5, - ncclNumResults = 6 } ncclResult_t; - -/* Return the NCCL_VERSION_CODE of the NCCL library in the supplied integer. - * This integer is coded with the MAJOR, MINOR and PATCH level of the - * NCCL library - */ -ncclResult_t ncclGetVersion(int *version); -ncclResult_t pncclGetVersion(int *version); - -/* Generates an Id to be used in ncclCommInitRank. ncclGetUniqueId should be - * called once and the Id should be distributed to all ranks in the - * communicator before calling ncclCommInitRank. */ -ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId); -ncclResult_t pncclGetUniqueId(ncclUniqueId* uniqueId); - -/* Creates a new communicator (multi thread/process version). - * rank must be between 0 and nranks-1 and unique within a communicator clique. - * Each rank is associated to a CUDA device, which has to be set before calling - * ncclCommInitRank. - * ncclCommInitRank implicitly syncronizes with other ranks, so it must be - * called by different threads/processes or use ncclGroupStart/ncclGroupEnd. */ -ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); -ncclResult_t pncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank); - -/* Creates a clique of communicators (single process version). - * This is a convenience function to create a single-process communicator clique. - * Returns an array of ndev newly initialized communicators in comm. - * comm should be pre-allocated with size at least ndev*sizeof(ncclComm_t). - * If devlist is NULL, the first ndev CUDA devices are used. - * Order of devlist defines user-order of processors within the communicator. */ -ncclResult_t ncclCommInitAll(ncclComm_t* comm, int ndev, const int* devlist); -ncclResult_t pncclCommInitAll(ncclComm_t* comm, int ndev, const int* devlist); - -/* Frees resources associated with communicator object, but waits for any operations - * that might still be running on the device. */ -ncclResult_t ncclCommDestroy(ncclComm_t comm); -ncclResult_t pncclCommDestroy(ncclComm_t comm); - -/* Frees resources associated with communicator object and aborts any operations - * that might still be running on the device. */ -ncclResult_t ncclCommAbort(ncclComm_t comm); -ncclResult_t pncclCommAbort(ncclComm_t comm); - -/* Returns a human-readable error message. 
*/ -const char* ncclGetErrorString(ncclResult_t result); -const char* pncclGetErrorString(ncclResult_t result); - -/* Checks whether the comm has encountered any asynchronous errors */ -ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); -ncclResult_t pncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); - -/* Gets the number of ranks in the communicator clique. */ -ncclResult_t ncclCommCount(const ncclComm_t comm, int* count); -ncclResult_t pncclCommCount(const ncclComm_t comm, int* count); - -/* Returns the cuda device number associated with the communicator. */ -ncclResult_t ncclCommCuDevice(const ncclComm_t comm, int* device); -ncclResult_t pncclCommCuDevice(const ncclComm_t comm, int* device); - -/* Returns the user-ordered "rank" associated with the communicator. */ -ncclResult_t ncclCommUserRank(const ncclComm_t comm, int* rank); -ncclResult_t pncclCommUserRank(const ncclComm_t comm, int* rank); - -/* Reduction operation selector */ -typedef enum { ncclNumOps_dummy = 5 } ncclRedOp_dummy_t; -typedef enum { ncclSum = 0, - ncclProd = 1, - ncclMax = 2, - ncclMin = 3, - ncclAvg = 4, - /* ncclNumOps: The number of built-in ncclRedOp_t values. Also - * serves as the least possible value for dynamic ncclRedOp_t's - * as constructed by ncclRedOpCreate*** functions. */ - ncclNumOps = 5, - /* ncclMaxRedOp: The largest valid value for ncclRedOp_t. - * It is defined to be the largest signed value (since compilers - * are permitted to use signed enums) that won't grow - * sizeof(ncclRedOp_t) when compared to previous NCCL versions to - * maintain ABI compatibility. */ - ncclMaxRedOp = 0x7fffffff>>(32-8*sizeof(ncclRedOp_dummy_t)) - } ncclRedOp_t; - -/* Data types */ -typedef enum { ncclInt8 = 0, ncclChar = 0, - ncclUint8 = 1, - ncclInt32 = 2, ncclInt = 2, - ncclUint32 = 3, - ncclInt64 = 4, - ncclUint64 = 5, - ncclFloat16 = 6, ncclHalf = 6, - ncclFloat32 = 7, ncclFloat = 7, - ncclFloat64 = 8, ncclDouble = 8, -#if defined(__CUDA_BF16_TYPES_EXIST__) - ncclBfloat16 = 9, - ncclNumTypes = 10 -#else - ncclNumTypes = 9 -#endif -} ncclDataType_t; - -/* ncclScalarResidence_t: Location and dereferencing logic for scalar arguments. */ -typedef enum { - /* ncclScalarDevice: The scalar is in device-visible memory and will be - * dereferenced while the collective is running. */ - ncclScalarDevice = 0, - - /* ncclScalarHostImmediate: The scalar is in host-visible memory and will be - * dereferenced before the ncclRedOpCreate***() function returns. */ - ncclScalarHostImmediate = 1 -} ncclScalarResidence_t; - -/* - * ncclRedOpCreatePreMulSum - * - * Creates a new reduction operator which pre-multiplies input values by a given - * scalar locally before reducing them with peer values via summation. For use - * only with collectives launched against *comm* and *datatype*. The - * *residence* argument indicates how/when the memory pointed to by *scalar* - * will be dereferenced. Upon return, the newly created operator's handle - * is stored in *op*. - */ -ncclResult_t ncclRedOpCreatePreMulSum(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm); -ncclResult_t pncclRedOpCreatePreMulSum(ncclRedOp_t *op, void *scalar, ncclDataType_t datatype, ncclScalarResidence_t residence, ncclComm_t comm); - -/* - * ncclRedOpDestroy - * - * Destroys the reduction operator *op*. The operator must have been created by - * ncclRedOpCreatePreMul with the matching communicator *comm*. 
An operator may be - * destroyed as soon as the last NCCL function which is given that operator returns. - */ -ncclResult_t ncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm); -ncclResult_t pncclRedOpDestroy(ncclRedOp_t op, ncclComm_t comm); - -/* - * Collective communication operations - * - * Collective communication operations must be called separately for each - * communicator in a communicator clique. - * - * They return when operations have been enqueued on the CUDA stream. - * - * Since they may perform inter-CPU synchronization, each call has to be done - * from a different thread or process, or need to use Group Semantics (see - * below). - */ - -/* - * Reduce - * - * Reduces data arrays of length count in sendbuff into recvbuff using op - * operation. - * recvbuff may be NULL on all calls except for root device. - * root is the rank (not the CUDA device) where data will reside after the - * operation is complete. - * - * In-place operation will happen if sendbuff == recvbuff. - */ -ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, - ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream); -ncclResult_t pncclReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, - ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream); - -/* - * (deprecated) Broadcast (in-place) - * - * Copies count values from root to all other devices. - * root is the rank (not the CUDA device) where data resides before the - * operation is started. - * - * This operation is implicitely in place. - */ -ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root, - ncclComm_t comm, cudaStream_t stream); -ncclResult_t pncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root, - ncclComm_t comm, cudaStream_t stream); - -/* - * Broadcast - * - * Copies count values from root to all other devices. - * root is the rank (not the CUDA device) where data resides before the - * operation is started. - * - * In-place operation will happen if sendbuff == recvbuff. - */ -ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root, - ncclComm_t comm, cudaStream_t stream); -ncclResult_t pncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype, int root, - ncclComm_t comm, cudaStream_t stream); - -/* - * All-Reduce - * - * Reduces data arrays of length count in sendbuff using op operation, and - * leaves identical copies of result on each recvbuff. - * - * In-place operation will happen if sendbuff == recvbuff. - */ -ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, - ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream); -ncclResult_t pncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, - ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream); - -/* - * Reduce-Scatter - * - * Reduces data in sendbuff using op operation and leaves reduced result - * scattered over the devices so that recvbuff on rank i will contain the i-th - * block of the result. - * Assumes sendcount is equal to nranks*recvcount, which means that sendbuff - * should have a size of at least nranks*recvcount elements. - * - * In-place operations will happen if recvbuff == sendbuff + rank * recvcount. 
- */ -ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff, - size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, - cudaStream_t stream); -ncclResult_t pncclReduceScatter(const void* sendbuff, void* recvbuff, - size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, - cudaStream_t stream); - -/* - * All-Gather - * - * Each device gathers sendcount values from other GPUs into recvbuff, - * receiving data from rank i at offset i*sendcount. - * Assumes recvcount is equal to nranks*sendcount, which means that recvbuff - * should have a size of at least nranks*sendcount elements. - * - * In-place operations will happen if sendbuff == recvbuff + rank * sendcount. - */ -ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, - ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); -ncclResult_t pncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount, - ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream); - -/* - * Send - * - * Send data from sendbuff to rank peer. - * - * Rank peer needs to call ncclRecv with the same datatype and the same count from this - * rank. - * - * This operation is blocking for the GPU. If multiple ncclSend and ncclRecv operations - * need to progress concurrently to complete, they must be fused within a ncclGroupStart/ - * ncclGroupEnd section. - */ -ncclResult_t ncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, - ncclComm_t comm, cudaStream_t stream); -ncclResult_t pncclSend(const void* sendbuff, size_t count, ncclDataType_t datatype, int peer, - ncclComm_t comm, cudaStream_t stream); - -/* - * Receive - * - * Receive data from rank peer into recvbuff. - * - * Rank peer needs to call ncclSend with the same datatype and the same count to this - * rank. - * - * This operation is blocking for the GPU. If multiple ncclSend and ncclRecv operations - * need to progress concurrently to complete, they must be fused within a ncclGroupStart/ - * ncclGroupEnd section. - */ -ncclResult_t pncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer, - ncclComm_t comm, cudaStream_t stream); -ncclResult_t ncclRecv(void* recvbuff, size_t count, ncclDataType_t datatype, int peer, - ncclComm_t comm, cudaStream_t stream); - -/* - * Group semantics - * - * When managing multiple GPUs from a single thread, and since NCCL collective - * calls may perform inter-CPU synchronization, we need to "group" calls for - * different ranks/devices into a single call. - * - * Grouping NCCL calls as being part of the same collective operation is done - * using ncclGroupStart and ncclGroupEnd. ncclGroupStart will enqueue all - * collective calls until the ncclGroupEnd call, which will wait for all calls - * to be complete. Note that for collective communication, ncclGroupEnd only - * guarantees that the operations are enqueued on the streams, not that - * the operation is effectively done. - * - * Both collective communication and ncclCommInitRank can be used in conjunction - * of ncclGroupStart/ncclGroupEnd, but not together. - * - * Group semantics also allow to fuse multiple operations on the same device - * to improve performance (for aggregated collective calls), or to permit - * concurrent progress of multiple send/receive operations. - */ - -/* - * Group Start - * - * Start a group call. All calls to NCCL until ncclGroupEnd will be fused into - * a single NCCL operation. Nothing will be started on the CUDA stream until - * ncclGroupEnd. 
- */ -ncclResult_t ncclGroupStart(); -ncclResult_t pncclGroupStart(); - -/* - * Group End - * - * End a group call. Start a fused NCCL operation consisting of all calls since - * ncclGroupStart. Operations on the CUDA stream depending on the NCCL operations - * need to be called after ncclGroupEnd. - */ -ncclResult_t ncclGroupEnd(); -ncclResult_t pncclGroupEnd(); - - -#ifdef __cplusplus -} // end extern "C" -#endif - -#endif // end include guard diff --git a/csrc/nccl.cpp b/csrc/include/nccl.hpp similarity index 81% rename from csrc/nccl.cpp rename to csrc/include/nccl.hpp index 2d0ac8de..bba0278b 100644 --- a/csrc/nccl.cpp +++ b/csrc/include/nccl.hpp @@ -1,7 +1,9 @@ +#include +#include +#include -#include -#include "include/nccl.h" -#include +namespace py = pybind11; +#include void checkNCCLStatus(ncclResult_t result) { if (result == ncclSuccess) return; @@ -184,19 +186,3 @@ int pyNCCLCommUserRank( checkNCCLStatus(ncclCommUserRank(reinterpret_cast(comm),&rank)); return rank; } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("ncclGetUniqueId", &pyNCCLGetUniqueID, "nccl get unique ID"); - m.def("ncclCommInitRank", &pyNCCLCommInitRank, "nccl init rank"); - m.def("ncclCommDestroy", &pyNCCLCommDestroy, "nccl delete rank"); - m.def("ncclAllGather", &pyNCCLAllGather, "nccl all gather"); - m.def("ncclAllReduce", &pyNCCLAllReduce, "nccl all reduce"); - m.def("ncclBroadcast", &pyNCCLBroadcast, "nccl broadcast"); - m.def("ncclReduce", &pyNCCLReduce, "nccl reduce"); - m.def("ncclReduceScatter", &pyNCCLReduceScatter, "nccl reduce scatter"); - m.def("ncclGroupStart", &pyNCCLGroupStart, "nccl group start"); - m.def("ncclGroupEnd", &pyNCCLGroupEnd, "nccl group end"); - m.def("ncclSend",&pyNCCLSend,"nccl send"); - m.def("ncclRecv",&pyNCCLRecv,"nccl recv"); - m.def("ncclCommCount",&pyNCCLCommCount,"nccl comm count"); - m.def("ncclCommUserRank",&pyNCCLCommUserRank,"nccl comm user rank"); -} diff --git a/csrc/include/nccl_net.h b/csrc/include/nccl_net.h deleted file mode 100644 index 389c1eaa..00000000 --- a/csrc/include/nccl_net.h +++ /dev/null @@ -1,124 +0,0 @@ -/************************************************************************* - * Copyright (c) 2017-2021, NVIDIA CORPORATION. All rights reserved. - * - * See LICENSE.txt for license information - ************************************************************************/ - -#ifndef NCCL_NET_H_ -#define NCCL_NET_H_ - -#include "nccl.h" -#include - -#define NCCL_NET_HANDLE_MAXSIZE 64 - -#define NCCL_PTR_HOST 0x1 -#define NCCL_PTR_CUDA 0x2 - -// Maximum number of requests per comm object -#define NCCL_NET_MAX_REQUESTS 8 - -typedef enum {NCCL_LOG_NONE=0, NCCL_LOG_VERSION=1, NCCL_LOG_WARN=2, NCCL_LOG_INFO=3, NCCL_LOG_ABORT=4, NCCL_LOG_TRACE=5} ncclDebugLogLevel; -typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCCL_GRAPH=32, NCCL_TUNING=64, NCCL_ENV=128, NCCL_ALLOC=256, NCCL_ALL=~0} ncclDebugLogSubSys; - -typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...); - -typedef struct { - char* name; // Used mostly for logging. - char* pciPath; // Path to the PCI device in /sys. - uint64_t guid; // Unique identifier for the NIC chip. Important for - // cards with multiple PCI functions (Physical or virtual). - int ptrSupport; // NCCL_PTR_HOST or NCCL_PTR_HOST|NCCL_PTR_CUDA - int speed; // Port speed in Mbps. - int port; // Port number. 
- int maxComms; // Maximum number of comms we can create -}ncclNetProperties_v4_t; - -typedef ncclNetProperties_v4_t ncclNetProperties_t; - -typedef struct { - // Name of the network (mainly for logs) - const char* name; - // Initialize the network. - ncclResult_t (*init)(ncclDebugLogger_t logFunction); - // Return the number of adapters. - ncclResult_t (*devices)(int* ndev); - // Get various device properties. - ncclResult_t (*getProperties)(int dev, ncclNetProperties_v4_t* props); - // Create a receiving object and provide a handle to connect to it. The - // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged - // between ranks to create a connection. - ncclResult_t (*listen)(int dev, void* handle, void** listenComm); - // Connect to a handle and return a sending comm object for that peer. - ncclResult_t (*connect)(int dev, void* handle, void** sendComm); - // Finalize connection establishment after remote peer has called connectHandle - ncclResult_t (*accept)(void* listenComm, void** recvComm); - // Register/Deregister memory. Comm can be either a sendComm or a recvComm. - // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. - ncclResult_t (*regMr)(void* comm, void* data, int size, int type, void** mhandle); - ncclResult_t (*deregMr)(void* comm, void* mhandle); - // Asynchronous send to a peer. - // May return request == NULL if the call cannot be performed (or would block) - ncclResult_t (*isend)(void* sendComm, void* data, int size, void* mhandle, void** request); - // Asynchronous recv from a peer. - // May return request == NULL if the call cannot be performed (or would block) - ncclResult_t (*irecv)(void* recvComm, void* data, int size, void* mhandle, void** request); - // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is - // visible to the GPU - ncclResult_t (*iflush)(void* recvComm, void* data, int size, void* mhandle, void** request); - // Test whether a request is complete. If size is not NULL, it returns the - // number of bytes sent/received. - ncclResult_t (*test)(void* request, int* done, int* size); - // Close and free send/recv comm objects - ncclResult_t (*closeSend)(void* sendComm); - ncclResult_t (*closeRecv)(void* recvComm); - ncclResult_t (*closeListen)(void* listenComm); -} ncclNet_v4_t; - -typedef ncclNet_v4_t ncclNet_t; - -#define NCCL_PLUGIN_SYMBOL ncclNetPlugin_v4 - -typedef struct { - // Name of the collective network (mainly for logs) - const char* name; - // Initialize the collective network. - ncclResult_t (*init)(ncclDebugLogger_t logFunction); - // Return the number of adapters capable of doing collective operations. - // If ndev returns 0, all other functions might be set to NULL. - ncclResult_t (*devices)(int* ndev); - // Get various device properties. - ncclResult_t (*getProperties)(int dev, ncclNetProperties_v4_t* props); - // Create a receiving object and provide a handle to connect to it. The - // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged - // between ranks to create connections. - ncclResult_t (*listen)(int dev, void* handle, void** listenComm); - // Create a group for collective operations. handles have been created - // using listen() above. rank indicates caller's rank in the collective network. - ncclResult_t (*connect)(void* handles[], int nranks, int rank, void* listenComm, void** collComm); - // Returns whether a reduction operation on a data type is supported. - // 1 for supported, 0 otherwise. 
- ncclResult_t (*reduceSupport)(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported); - // Register/Deregister memory. Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. - ncclResult_t (*regMr)(void* collComm, void* data, int size, int type, void** mhandle); - ncclResult_t (*deregMr)(void* collComm, void* mhandle); - // Performs an asynchronous allreduce operation on the collective group. - // May return request == NULL if the call cannot be performed (or would block). - ncclResult_t (*iallreduce)(void* collComm, void* sendData, void* recvData, int count, - ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request); - // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is - // visible to the GPU - ncclResult_t (*iflush)(void* collComm, void* data, int size, void* mhandle, void** request); - // Test whether a request is complete. If size is not NULL, it returns the - // number of bytes sent/received. - ncclResult_t (*test)(void* request, int* done, int* size); - // Close and free collective comm objects - ncclResult_t (*closeColl)(void* collComm); - ncclResult_t (*closeListen)(void* listenComm); -} ncclCollNet_v4_t; - -typedef ncclCollNet_v4_t ncclCollNet_t; - -#define NCCL_COLLNET_PLUGIN_SYMBOL ncclCollNetPlugin_v4 - -#endif // end include guard diff --git a/setup.py b/setup.py index 7c29a10f..16eb535d 100644 --- a/setup.py +++ b/setup.py @@ -1,88 +1,116 @@ -from setuptools import setup, find_packages -import torch -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension import os +from setuptools.command.build_ext import build_ext +from setuptools import setup, find_packages, Extension +import setuptools +import warnings +import sys +import subprocess +COMMON_NVCC_FLAGS = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + '--expt-relaxed-constexpr' +] +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) -def get_avx_flags(): - if os.environ.get("BMT_AVX256", "").lower() in ["1", "true", "on"]: - return ["-mavx", "-mfma", "-mf16c"] - elif os.environ.get("BMT_AVX512", "").lower() in ["1", "true", "on"]: - return ["-mavx512f"] - else: - return ["-march=native"] - -def get_device_cc(): - try: - CC_SET = set() - for i in range(torch.cuda.device_count()): - CC_SET.add(torch.cuda.get_device_capability(i)) - - if len(CC_SET) == 0: - return None - - ret = "" - for it in CC_SET: - if len(ret) > 0: - ret = ret + " " - ret = ret + ("%d.%d" % it) - return ret - except RuntimeError: - return None - -avx_flag = get_avx_flags() -device_cc = get_device_cc() -if device_cc is None: - if not torch.cuda.is_available(): - os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0+PTX") - else: - if torch.version.cuda.startswith("10"): - os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5+PTX") + +def is_ninja_available(): + r''' + Returns ``True`` if the `ninja `_ build system is + available on the system, ``False`` otherwise. 
+ ''' + with open(os.devnull, 'wb') as devnull: + try: + subprocess.check_call('ninja --version'.split(), stdout=devnull) + except OSError: + return False else: - if not torch.version.cuda.startswith("11.0"): - os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0 8.6+PTX") - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", "6.0 6.1 7.0 7.5 8.0+PTX") -else: - os.environ["TORCH_CUDA_ARCH_LIST"] = os.environ.get("TORCH_CUDA_ARCH_LIST", device_cc) - -ext_modules = [] - -if os.environ.get("GITHUB_ACTIONS", "false") == "false": - ext_modules = [ - CUDAExtension('bmtrain.nccl._C', [ - 'csrc/nccl.cpp', - ], include_dirs=["csrc/nccl/build/include"], extra_compile_args={}), - CUDAExtension('bmtrain.optim._cuda', [ - 'csrc/adam_cuda.cpp', - 'csrc/cuda/adam.cu', - 'csrc/cuda/has_inf_nan.cu' - ], extra_compile_args={}), - CppExtension("bmtrain.optim._cpu", [ - "csrc/adam_cpu.cpp", - ], extra_compile_args=[ - '-fopenmp', - *avx_flag - ], extra_link_args=['-lgomp']), - CUDAExtension('bmtrain.loss._cuda', [ - 'csrc/cross_entropy_loss.cpp', - 'csrc/cuda/cross_entropy.cu', - ], extra_compile_args={}), - ] -else: - ext_modules = [] + return True + + +class CMakeBuild(build_ext): + + def build_extension(self, ext): + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + + # required for auto-detection & inclusion of auxiliary "native" libs + if not extdir.endswith(os.path.sep): + extdir += os.path.sep + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + # CMake lets you override the generator - we need to check this. + # Can be set with Conda-Build, for example. + cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python. + cmake_args = [ + f"-DCMAKE_CXX_STANDARD=14", + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + f"-DPYTHON_VERSION={sys.version_info.major}.{sys.version_info.minor}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + ] + build_args = [] + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + cmake_args += [f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"] + + + if not cmake_generator or cmake_generator == "Ninja": + try: + import ninja # noqa: F401 + + ninja_executable_path = os.path.join(ninja.BIN_DIR, "ninja") + cmake_args += [ + "-GNinja", + f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", + ] + except ImportError: + pass + + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. 
+ build_args += [f"-j{self.parallel}"] + + build_temp = os.path.join(self.build_temp, ext.name) + if not os.path.exists(build_temp): + os.makedirs(build_temp) + + cmake_args += ["-DPython_ROOT_DIR=" + os.path.dirname(os.path.dirname(sys.executable))] + subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp) + subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp) + +ext_modules = [ + CMakeExtension("bmtrain.C"), +] setup( name='bmtrain', - version='0.2.2', + version='0.2.3', author="Guoyang Zeng", author_email="qbjooo@qq.com", description="A toolkit for training big models", packages=find_packages(), install_requires=[ "numpy", + "nvidia-nccl-cu11>=2.14.3" + ], + setup_requires=[ + "nvidia-nccl-cu11>=2.14.3" ], ext_modules=ext_modules, cmdclass={ - 'build_ext': BuildExtension + 'build_ext': CMakeBuild })
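
Note on the calling convention these hunks introduce: the CUDA launchers now take an element count, raw std::uintptr_t addresses, and an explicit stream handle instead of torch::Tensor references, so csrc/ no longer needs libtorch at build time; the Python side is expected to pass tensor.data_ptr() and the current stream handle (for example torch.cuda.current_stream().cuda_stream). A minimal host-side sketch of the same pattern, using a hypothetical name (copy_half_launcher) that is not part of this diff:

// Hypothetical illustration (not from this diff) of the uintptr_t + stream
// convention used by the reworked launchers in csrc/cuda/*.cu.
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Copies n half elements from `src` to `dst` on the given stream. A real
// BMTrain launcher would launch a kernel here instead of cudaMemcpyAsync.
void copy_half_launcher(
    int n,
    std::uintptr_t src,      // address from tensor.data_ptr() on the Python side
    std::uintptr_t dst,      // likewise
    std::uintptr_t stream    // e.g. torch.cuda.current_stream().cuda_stream
) {
    if (n <= 0) return;
    auto src_ptr = reinterpret_cast<const half *>(src);
    auto dst_ptr = reinterpret_cast<half *>(dst);
    auto cu_stream = reinterpret_cast<cudaStream_t>(stream);
    cudaMemcpyAsync(dst_ptr, src_ptr, sizeof(half) * static_cast<size_t>(n),
                    cudaMemcpyDeviceToDevice, cu_stream);
}

The trade-off is that the compiled module depends only on the CUDA runtime and NCCL, so one wheel can serve several torch builds, while the binding layer (not ATen) becomes responsible for keeping the pointed-to tensors alive, contiguous, and on the right device for the duration of the asynchronous call.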
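The PYBIND11_MODULE block removed from csrc/nccl.cpp implies that the module definition now lives in a separate translation unit that includes csrc/include/bind.hpp and is built by CMake as the single bmtrain.C extension (pybind11_add_module(C ...)). That file is not part of this patch; a hypothetical sketch of what it could look like, restricted to functions actually declared in bind.hpp, adam_cpu.hpp and nccl.hpp:

// Hypothetical sketch only; the real module source is not shown in this diff.
#include <pybind11/pybind11.h>
#include "bind.hpp"

PYBIND11_MODULE(C, m) {  // target name "C" matches pybind11_add_module(C ...) in CMakeLists.txt
    // NCCL wrappers kept from the old module (declared in nccl.hpp)
    m.def("ncclGetUniqueId", &pyNCCLGetUniqueID, "nccl get unique ID");
    m.def("ncclCommInitRank", &pyNCCLCommInitRank, "nccl init rank");
    m.def("ncclAllReduce", &pyNCCLAllReduce, "nccl all reduce");
    // New raw-pointer launchers (declared in bind.hpp / adam_cpu.hpp)
    m.def("has_nan_inf_launcher", &has_nan_inf_launcher, "check a half tensor for nan/inf");
    m.def("adam_launcher", &adam_launcher, "fused adam with fp32 master weights (GPU)");
    m.def("adam_cpu_launcher", &adam_cpu_launcher, "adam with fp32 state (CPU, AVX dispatch)");
    m.def("cross_entropy_forward_launcher", &cross_entropy_forward_launcher, "fused cross entropy forward");
    m.def("cross_entropy_backward_launcher", &cross_entropy_backward_launcher, "fused cross entropy backward");
}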
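For reference, adam_cpu_0, adam_cpu_1, adam_cpu_2 and the CUDA adam_fp32_accum path launched in csrc/cuda/adam.cu all apply the same update: the fp16 gradient is unscaled by the loss scale, the moments and master parameters are kept in fp32, weight decay is decoupled, and the updated parameter is written back to fp16. In the notation of the function arguments:

\[
g = \frac{\mathrm{fp32}(g_{\mathrm{fp16}})}{\mathrm{scale}}, \qquad
m \leftarrow \beta_1 m + (1-\beta_1)\, g, \qquad
v \leftarrow \beta_2 v + (1-\beta_2)\, g^2,
\]
\[
p \leftarrow p \;-\; \mathrm{lr}\cdot\frac{m / \mathrm{bias\_correction1}}{\sqrt{v / \mathrm{bias\_correction2}} + \mathrm{eps}} \;-\; \mathrm{lr}\cdot\mathrm{weight\_decay}\cdot p,
\]

where bias_correction1 and bias_correction2 are supplied by the caller (presumably \(1-\beta_1^t\) and \(1-\beta_2^t\), as in standard Adam).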
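csrc/include/cpu_info.h decides at runtime which of the three CPU Adam kernels adam_cpu_launcher dispatches to: get_cpu_level() returns 0 for the scalar path (adam_cpu_0), 1 when AVX, FMA and F16C are available (adam_cpu_1), and 2 when AVX-512F is available (adam_cpu_2). A small stand-alone check of that dispatch, assuming csrc/include is on the include path (cpu_info.h relies on the GCC/Clang <cpuid.h> intrinsic __cpuid_count):

// Prints which Adam CPU path get_cpu_level() would select on this machine.
#include <cstdio>
#include "cpu_info.h"

int main() {
    const int level = get_cpu_level();
    // 0 -> scalar adam_cpu_0, 1 -> AVX/FMA/F16C adam_cpu_1, 2 -> AVX-512F adam_cpu_2
    std::printf("cpu level: %d\n", level);
    return 0;
}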