FP6 quantization end-to-end. #5234

Merged · 33 commits · Mar 8, 2024
Commits
a4562ab
FP6 quantization end-to-end.
JamesTheZ Nov 17, 2023
91bb4d7
Update CUDA kernels and clean codes.
JamesTheZ Jan 26, 2024
1c2131d
Make the quantizer on GPU.
JamesTheZ Jan 26, 2024
1ba45fd
[WIP] Fix the bug of FP16-to-FP6 data packing.
JamesTheZ Jan 26, 2024
ff6c3c3
Add FP6 end-to-end unit tests
arashb Jan 27, 2024
368a763
Refine the FP16-to-FP6 cast logic.
JamesTheZ Jan 29, 2024
6c45a84
Add unit tests for FP6 quantizer
arashb Jan 30, 2024
90b710d
Fix FP16-FP6 cast problems.
JamesTheZ Jan 30, 2024
f8e3acf
Update FP6 kernels.
JamesTheZ Feb 1, 2024
b025c5a
Fix the bug of subnormal FP6 casting and the 2bit/4bit tensor allocat…
JamesTheZ Feb 1, 2024
6ed67f7
Clean code.
JamesTheZ Feb 1, 2024
20b543c
pre-commit
JamesTheZ Feb 1, 2024
c43947a
Deal with the subnormal FP6 and FP16 values and refine the UT.
JamesTheZ Feb 2, 2024
a6d2f2f
Update according to review comments.
JamesTheZ Feb 5, 2024
62a2d49
Fix the CI workflow problem for FP6 end-to-end.
JamesTheZ Feb 6, 2024
118af37
Fix at::nullopt and at::optional conflicts.
JamesTheZ Feb 6, 2024
56eb8b9
Refine split-k setting.
JamesTheZ Feb 23, 2024
0ddbfd1
Remove debug files.
JamesTheZ Mar 4, 2024
35c82f2
Only compile the kernel body for SM >= 8.0.
JamesTheZ Mar 4, 2024
63489d1
Fix the GPU architecture requirement of FP6 kernel.
JamesTheZ Mar 5, 2024
ed00ac9
Update deepspeed/inference/v2/config_v2.py
mrwyattii Mar 5, 2024
b15a1a1
Update deepspeed/inference/v2/config_v2.py
mrwyattii Mar 5, 2024
c2e6ebb
refactor fp6 tests, fix import error
mrwyattii Mar 5, 2024
fb8887c
Update deepspeed/inference/v2/modules/implementations/linear/quantize…
mrwyattii Mar 5, 2024
77f3883
Update requirements.txt
mrwyattii Mar 5, 2024
f6bcdee
revert testing to fix A6000 test
mrwyattii Mar 5, 2024
e1a4ce0
Update pydantic version
loadams Mar 6, 2024
e86611f
fix pydantic import
mrwyattii Mar 6, 2024
7e28144
Fix some review comments.
JamesTheZ Mar 6, 2024
f8454a0
Pin pydantic to latest version
loadams Mar 6, 2024
bed775e
Add the missed torch import.
JamesTheZ Mar 6, 2024
f34312a
Merge branch 'master' into features/rebase-quant-fp6
loadams Mar 7, 2024
4a91788
Merge branch 'master' into features/rebase-quant-fp6
loadams Mar 7, 2024
3 changes: 2 additions & 1 deletion .github/workflows/nv-a6000.yml
@@ -47,7 +47,8 @@ jobs:
- name: Install deepspeed
run: |
python -m pip install docutils==0.18.1 jinja2==3.0 urllib3==1.26.11 ninja
python -m pip install .[dev,1bit,autotuning]
python -m pip install pydantic==1.10.11
python -m pip install .[dev,1bit,autotuning,inf]
ds_report
- name: Python environment
run: |
14 changes: 13 additions & 1 deletion deepspeed/inference/v2/config_v2.py
@@ -3,8 +3,8 @@

# DeepSpeed Team

from typing import Optional
from deepspeed.pydantic_v1 import Field

from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from .ragged import DSStateManagerConfig

@@ -16,6 +16,16 @@ class DeepSpeedTPConfig(DeepSpeedConfigModel):
""" Number of devices to split the model across using tensor parallelism. """


class QuantizationConfig(DeepSpeedConfigModel):
""" Configure tensor parallelism settings """

quantization_mode: Optional[str] = None
""" The quantization mode in string format. The supported modes are as follows:
- 'wf6af16', weight-only quantization with FP6 weight and FP16 activation.
"""
# TODO: may reuse the constants in deepspeed/compression/constants.py


class RaggedInferenceEngineConfig(DeepSpeedConfigModel):
""" Sets parameters for DeepSpeed Inference Engine. """

@@ -29,3 +39,5 @@ class RaggedInferenceEngineConfig(DeepSpeedConfigModel):
"""
Configuration for managing persistent state
"""

quantization: QuantizationConfig = {}
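
Note: with the new `quantization` field, FP6 weight-only quantization can be switched on from the inference config alone. A minimal sketch (only the `quantization` field is set; engine construction around it is omitted):

from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig

# Enable FP6 weights with FP16 activations via the new QuantizationConfig field.
engine_config = RaggedInferenceEngineConfig(quantization={"quantization_mode": "wf6af16"})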
1 change: 1 addition & 0 deletions deepspeed/inference/v2/kernels/core_ops/__init__.py
@@ -8,3 +8,4 @@
from .cuda_layer_norm import *
from .cuda_rms_norm import *
from .gated_activations import *
from .cuda_linear import *
6 changes: 6 additions & 0 deletions deepspeed/inference/v2/kernels/core_ops/core_ops.cpp
@@ -8,6 +8,7 @@

#include "bias_activation.h"
#include "blas.h"
#include "cuda_linear_kernels.h"
#include "gated_activation_kernels.h"
#include "layer_norm.h"
#include "rms_norm.h"
@@ -33,4 +34,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
// rms_norm.h
m.def("rms_norm", &rms_norm, "DeepSpeed rms norm in CUDA");
m.def("rms_pre_norm", &rms_pre_norm, "DeepSpeed rms pre norm in CUDA");

// cuda_linear_kernels.h
m.def("cuda_wf6af16_linear", &cuda_wf6af16_linear, "DeepSpeed Wf6Af16 linear in CUDA");
m.def(
"preprocess_weight", &preprocess_weight, "preprocess the FP16 weight to be 2bit and 4 bit");
}
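
Note: once the extension is built, the two new bindings are exposed on the loaded module next to the existing core ops. A sketch of how they are reached from Python (mirroring what cuda_linear.py below does; call arguments are defined in cuda_linear_kernels.h and elided here):

from deepspeed.ops.op_builder import InferenceCoreBuilder

inf_module = InferenceCoreBuilder().load()
wf6af16_linear = inf_module.cuda_wf6af16_linear   # quantized GEMM entry point
preprocess_weight = inf_module.preprocess_weight  # packs FP16 weights into 2-bit and 4-bit slices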
6 changes: 6 additions & 0 deletions deepspeed/inference/v2/kernels/core_ops/cuda_linear/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from .cuda_linear import *
207 changes: 207 additions & 0 deletions deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py
@@ -0,0 +1,207 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch

from ....inference_utils import DtypeEnum
from ....logging import inference_logger
from deepspeed.ops.op_builder import InferenceCoreBuilder
from ... import DSKernelBase


class CUDAWf6Af16Linear(DSKernelBase):
"""
Wrapper around the CUDA kernel of Wf6Af16 quantized linear.

Performs z = x @ y
"""
supported_dtypes = [DtypeEnum.fp16]

def __init__(self):
self.inf_module = InferenceCoreBuilder().load()
self.inf_module.create_handle()
self.kernel = self.inf_module.cuda_wf6af16_linear
# The split_k_map is profiled on A100-80G GPU for some common shapes.
# It is an array of dictionaries, where the array index is the tokens chunk id.
# The dictionary is the mapping from the output channel to the split-K size.
self.split_k_map = [
{ # tokens: [1, 64]
3072: 18,
4096: 13,
5120: 10,
6144: 9,
8192: 6,
10240: 5,
14336: 7,
28672: 7,
57344: 7
},
{ # tokens: [65:128]
3072: 9,
4096: 6,
5120: 5,
6144: 9,
8192: 3,
10240: 5,
14336: 7,
28672: 7,
57344: 6
},
{ # tokens: [129:192]
3072: 6,
4096: 4,
5120: 7,
6144: 3,
8192: 2,
10240: 5,
14336: 5,
28672: 5,
57344: 4
},
{ # tokens: [193:256]
3072: 9,
4096: 3,
5120: 5,
6144: 2,
8192: 5,
10240: 4,
14336: 8,
28672: 6,
57344: 4
},
{ # tokens: [257:320]
3072: 7,
4096: 5,
5120: 2,
6144: 5,
8192: 4,
10240: 1,
14336: 3,
28672: 3,
57344: 4
},
{ # tokens: [321:384]
3072: 3,
4096: 2,
5120: 5,
6144: 3,
8192: 1,
10240: 8,
14336: 3,
28672: 4,
57344: 3
},
{ # tokens: [385:448]
3072: 5,
4096: 7,
5120: 3,
6144: 5,
8192: 7,
10240: 3,
14336: 1,
28672: 1,
57344: 3
},
{ # tokens: [449:512]
3072: 2,
4096: 5,
5120: 4,
6144: 1,
8192: 5,
10240: 2,
14336: 6,
28672: 4,
57344: 1
},
{ # tokens: [513:576]
3072: 2,
4096: 3,
5120: 1,
6144: 1,
8192: 3,
10240: 3,
14336: 3,
28672: 1,
57344: 1
},
{ # tokens: [577:640]
3072: 5,
4096: 4,
5120: 1,
6144: 4,
8192: 2,
10240: 1,
14336: 1,
28672: 1,
57344: 1
},
{ # tokens: [641:704]
3072: 3,
4096: 1,
5120: 2,
6144: 2,
8192: 1,
10240: 2,
14336: 1,
28672: 1,
57344: 1
},
{ # tokens: [705:768]
3072: 3,
4096: 1,
5120: 3,
6144: 2,
8192: 1,
10240: 1,
14336: 1,
28672: 1,
57344: 1
}
]

def __call__(self, output: torch.Tensor, hidden_states: torch.Tensor, weights_2bit: torch.Tensor,
weights_4bit: torch.Tensor, scale: torch.Tensor, out_channels, tokens, in_channels) -> torch.Tensor:
"""
Matmul kernel of FP6 weight-only quantized linear. All inputs should be contiguous.
It does not support batched-matmul.

Parameters:
output (torch.Tensor): Output tensor. Shape is of [token_number, out_features]
hidden_states (torch.Tensor): Input tensor. Shape is of [token_number, in_features]
weights_2bit (torch.Tensor): Input tensor of the 2-bit slice. Shape is of [out_features*2/8, in_features]
weights_4bit (torch.Tensor): Input tensor of the 4-bit slice. Shape is of [out_features*4/8, in_features]
scale (torch.Tensor): Input tensor. Shape is of [out_features], since the scale is per output channel
out_channels (int): The number of output channels
tokens (int): The number of tokens
in_channels (int): The number of input channels
"""

if out_channels % 256 != 0 or in_channels % 64 != 0:
raise ValueError("The out and in channel should be multiple of 256 and 64 respectively.")

# TODO: add a more general heuristic to determine the split-K.
split_k = -1 # not initialized
if tokens <= 768:
# Try to find the split-K from the pre-profiled map.
tokens_chunk_id = (tokens - 1) // 64
split_k = self.split_k_map[tokens_chunk_id].get(out_channels, -1)
if split_k == -1:
split_k = 1
inference_logger().warning(
f"The split-K setting may be suboptimal for shape {tokens}x{in_channels}x{out_channels}...")

workspace = self.get_workspace(out_channels, tokens, in_channels, split_k, torch.float, hidden_states.device)
self.kernel(output, hidden_states, weights_2bit, weights_4bit, scale, workspace, out_channels, tokens,
in_channels, split_k)

def get_workspace(self, out_channels: int, tokens: int, in_channels: int, split_k: int, dtype,
device) -> torch.Tensor:
"""
Allocate workspace for the kernel. The workspace is used to store the intermediate results of the matmul before
split-K. The split-K size is determined by the size of the matmul.
"""
workspace = torch.empty((split_k, out_channels, tokens), dtype=dtype, device=device)

return workspace
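
Note: a minimal usage sketch of CUDAWf6Af16Linear. It assumes the 2-bit/4-bit weight slices and per-channel scales were already produced by the weight preprocessing path (preprocess_weight in core_ops.cpp); the tensors below are shape placeholders following the docstring above, and the packed-weight dtype is an assumption:

import torch

tokens, in_channels, out_channels = 128, 4096, 4096  # in_channels % 64 == 0, out_channels % 256 == 0

kernel = CUDAWf6Af16Linear()
hidden_states = torch.randn(tokens, in_channels, dtype=torch.float16, device="cuda")
weights_2bit = torch.empty(out_channels * 2 // 8, in_channels, dtype=torch.uint8, device="cuda")
weights_4bit = torch.empty(out_channels * 4 // 8, in_channels, dtype=torch.uint8, device="cuda")
scale = torch.ones(out_channels, dtype=torch.float16, device="cuda")
output = torch.empty(tokens, out_channels, dtype=torch.float16, device="cuda")

# tokens=128 falls in the [65:128] chunk, so split-K comes from the pre-profiled map.
kernel(output, hidden_states, weights_2bit, weights_4bit, scale, out_channels, tokens, in_channels)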