Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update algorithm #105

Merged
merged 4 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion csrc/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,11 @@ T __device__ __forceinline__ ADD2(T a, T b) {
template <typename T>
T __device__ __forceinline__ ZERO_VALUE(T a) {
if constexpr (std::is_same<T, __bfloat16>::value) {
return __ushort_as_bfloat16((unsigned short)0x0000U);
#if defined(USE_ROCM)
return __float2bfloat16(0.0f);
#else
return __float2bfloat16_rn(0.0f);
#endif
} else if constexpr (std::is_same<T, float>::value) {
return 0.0f;
} else {
Expand Down
26 changes: 26 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
[project]
name = "VPTQ"
version = "0.0.3"
authors = [
{ name="Yang Wang", email="wyatuestc@gmai.com" },
{ name="Jicheng Wen", email="wejoincy@gmail.com"},
]

description = "VPTQ (Vector Post-Training Quantization) is a novel Post-Training Quantization method."
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Topic :: Software Development :: Libraries",
]

[project.urls]
Homepage = "https://github.com/microsoft/VPTQ"
Issues = "https://github.com/microsoft/VPTQ/issues"

[build-system]
# Should be mirrored in requirements.txt
requires = [
Expand Down Expand Up @@ -65,3 +87,7 @@ include_trailing_comma = true
force_grid_wrap = 0
combine_as_imports = true
ensure_newline_before_comments = true

# ignore pdf files
[tool.setuptools]
packages.find.exclude = ["**/*.pdf"]
4 changes: 4 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def build_cuda_extensions():
arch_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
print(" build for compute capabilities: ==============", compute_capabilities)

# set nvcc threads
nvcc_threads = os.getenv("NVCC_THREADS") or "4"

extra_compile_args = {
"nvcc": [
"-O3",
Expand All @@ -58,6 +61,7 @@ def build_cuda_extensions():
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
f"--threads={nvcc_threads}",
] + arch_flags,
"cxx": ["-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
}
Expand Down
2 changes: 1 addition & 1 deletion vptq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

__version__ = "0.0.2.post1"
__version__ = "0.0.3"
from vptq.layers import AutoModelForCausalLM as AutoModelForCausalLM
4 changes: 2 additions & 2 deletions vptq/layers/vqlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,9 +532,9 @@ def dequant(self):
indices, res_indices = self.unpack_index_tensor(
pack_tensor=self.indices,
index_bits=index_bits,
num_elements=self.in_features,
num_elements=self.group_size,
res_bits=index_res_bits,
num_res_elements=self.in_features,
num_res_elements=self.group_size,
index_dtype=torch.uint16,
)

Expand Down
Loading