Commit

Merge remote-tracking branch 'origin/main' into aqt_refactor
jainapurva committed Nov 12, 2024
2 parents 84b4d38 + 2ba1a61 commit babdd34
Showing 30 changed files with 347 additions and 223 deletions.
7 changes: 4 additions & 3 deletions ruff.toml
@@ -4,12 +4,13 @@
 # Example: To lint all files in every subfolder of 'test', add "test/**/*"
 include = [
     "torchao/float8/**/*.py",
-    "test/dtypes/test_nf4.py",
     "torchao/quantization/**/*.py",
-    "test/quantization/test_observer.py",
-    "test/dtypes/test_affine_quantized_float.py",
     "torchao/dtypes/**/*.py",
     "torchao/sparsity/**/*.py",
     "torchao/prototype/low_bit_optim/**.py",
+    "test/quantization/test_observer.py",
+    "test/dtypes/test_affine_quantized_float.py",
+    "test/dtypes/test_nf4.py",
+    "test/prototype/low_bit_optim/**.py",
 ]

@@ -28,7 +28,7 @@
  * This file is generated by gen_metal_shader_lib.py
  */
-#ifdef ATEN
+#ifdef USE_ATEN
 using namespace at::native::mps;
 #else
 #include <torchao/experimental/kernels/mps/src/OperationUtils.h>
2 changes: 1 addition & 1 deletion torchao/experimental/kernels/mps/src/lowbit.h
@@ -17,7 +17,7 @@
 #include <fstream>
 #include <sstream>
 
-#ifdef ATEN
+#ifdef USE_ATEN
 #include <ATen/native/mps/OperationUtils.h>
 using namespace at::native::mps;
 inline void finalize_block(MPSStream* mpsStream) {}
2 changes: 1 addition & 1 deletion torchao/experimental/ops/mps/setup.py
@@ -16,7 +16,7 @@
             name="torchao_mps_ops",
             sources=["register.mm"],
             include_dirs=[os.getenv("TORCHAO_ROOT")],
-            extra_compile_args=["-DATEN=1"],
+            extra_compile_args=["-DUSE_ATEN=1"],
         ),
     ],
     cmdclass={"build_ext": BuildExtension},
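
The ATEN → USE_ATEN rename spans the C++ headers above and this build script: the macro is only ever defined by the build, and the #ifdef USE_ATEN blocks choose between ATen's MPS OperationUtils and torchao's standalone header. A minimal sketch of that wiring, with a hypothetical TORCHAO_MPS_USE_ATEN toggle added for illustration (the real setup.py always defines the macro):

# Sketch only, not part of this commit: the hypothetical TORCHAO_MPS_USE_ATEN toggle
# illustrates how the renamed macro selects a code path at compile time.
import os

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

use_aten = os.getenv("TORCHAO_MPS_USE_ATEN", "1") == "1"  # hypothetical switch

setup(
    name="torchao_mps_ops",
    ext_modules=[
        CppExtension(
            name="torchao_mps_ops",
            sources=["register.mm"],
            include_dirs=[os.getenv("TORCHAO_ROOT")],
            # with the define, lowbit.h includes <ATen/native/mps/OperationUtils.h>;
            # without it, the standalone torchao OperationUtils.h is used instead
            extra_compile_args=["-DUSE_ATEN=1"] if use_aten else [],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
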
15 changes: 9 additions & 6 deletions torchao/sparsity/__init__.py
@@ -4,20 +4,23 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .wanda import WandaSparsifier  # noqa: F403
-from .utils import PerChannelNormObserver  # noqa: F403
+from torchao.quantization.quant_api import (
+    int8_dynamic_activation_int8_semi_sparse_weight,
+)
 
 from .sparse_api import (
     apply_fake_sparsity,
-    sparsify_,
     semi_sparse_weight,
-    int8_dynamic_activation_int8_semi_sparse_weight
+    sparsify_,
 )
+from .utils import PerChannelNormObserver  # noqa: F403
+from .wanda import WandaSparsifier  # noqa: F403
 
 __all__ = [
     "WandaSparsifier",
     "PerChannelNormObserver",
     "apply_fake_sparsity",
-    "sparsify_"
+    "sparsify_",
     "semi_sparse_weight",
-    "int8_dynamic_activation_int8_semi_sparse_weight"
+    "int8_dynamic_activation_int8_semi_sparse_weight",
 ]
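
The reordered imports leave the package's public API unchanged; everything listed in __all__ is still importable from torchao.sparsity. A minimal usage sketch (not part of this commit; argument conventions may differ across torchao versions):

import torch

from torchao.sparsity import semi_sparse_weight, sparsify_

# Toy model; 2:4 semi-structured sparse kernels expect fp16/bf16 weights on CUDA.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).half().cuda()

# Swap the weights of matching Linear modules for semi-structured sparse tensors.
sparsify_(
    model,
    semi_sparse_weight(),
    filter_fn=lambda module, fqn: isinstance(module, torch.nn.Linear),
)
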
152 changes: 76 additions & 76 deletions torchao/sparsity/marlin/__init__.py
@@ -1,11 +1,11 @@
+from typing import Tuple
+
 import torch
-from typing import Tuple, Dict, List
 
 import torchao.sparsity.marlin.utils as utils
 from torchao.sparsity.marlin.utils import const
 from torchao.sparsity.utils import mask_creator
 
 
 __all__ = [
     "inject_24",
     "marlin_24_workspace",
@@ -14,11 +14,13 @@
 ]
 
 
-def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor, torch.Tensor]:
+def inject_24(
+    w: torch.Tensor, size_k: int, size_n: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Injects 2:4 sparsity into a weight tensor. The sparsity is applied in a 2:4 ratio, where for every
     group of 4 weights, 2 will be pruned based on their value. The mask will be created based on the
     ranked weight values.
     Args:
         w (torch.Tensor): The weight tensor to inject sparsity into.
         size_k (int): The number of input features.
@@ -32,33 +34,35 @@ def inject_24(w: torch.Tensor, size_k: int, size_n: int) -> Tuple[torch.Tensor,
 
 
 def marlin_24_workspace(
-        out_features: int,
-        min_thread_n: int = const.MIN_THREAD_N,
-        max_parallel: int = const.MAX_PARALLEL
-    ) -> torch.Tensor:
-    """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks
+    out_features: int,
+    min_thread_n: int = const.MIN_THREAD_N,
+    max_parallel: int = const.MAX_PARALLEL,
+) -> torch.Tensor:
+    """Creates a workspace for marlin 2:4 quantization. The workspace is used to coordinate the locks
     during the execution of the kernel.
     Args:
         out_features (int): The number of output features.
         min_thread_n (int, optional): The minimum number of threads per block. Defaults to `MARLIN_24_MIN_THREAD_N`.
         max_parallel (int, optional): The maximum number of parallel threads. Defaults to `MARLIN_24_MAX_PARALLEL`.
     Returns:
         torch.Tensor: The workspace tensor fully initialized with zeros.
     """
-    assert (out_features % min_thread_n == 0), f"out_features = {out_features}, min_thread_n = {min_thread_n}"
-    max_workspace_size = ((out_features // min_thread_n) * max_parallel)
+    assert (
+        out_features % min_thread_n == 0
+    ), f"out_features = {out_features}, min_thread_n = {min_thread_n}"
+    max_workspace_size = (out_features // min_thread_n) * max_parallel
     return torch.zeros(max_workspace_size, dtype=torch.int, device="cuda")
 
 
 def pack_to_marlin_24(
-        q_w_24: torch.Tensor,
-        scales: torch.Tensor,
-        num_bits: int,
-        group_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    q_w_24: torch.Tensor,
+    scales: torch.Tensor,
+    num_bits: int,
+    group_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """Packs the quantized weights and scales into the marlin 2:4 format.
     Args:
         q_w_24 (torch.Tensor): The quantized weight tensor with 2:4 sparsity applied.
         scales (torch.Tensor): The scale tensor.
@@ -89,13 +93,13 @@ def pack_to_marlin_24(
 
 
 def unpack_from_marlin_24(
-        q_w_24_comp: torch.Tensor,
-        scales: torch.Tensor,
-        meta: torch.Tensor,
-        original_shape: torch.Size,
-        group_size: int,
-        num_bits: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    q_w_24_comp: torch.Tensor,
+    scales: torch.Tensor,
+    meta: torch.Tensor,
+    original_shape: torch.Size,
+    group_size: int,
+    num_bits: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """Unpacks the quantized weights and scales from the marlin 2:4 format.
     Args:
         q_w_24_comp (torch.Tensor): The packed quantized weights.
@@ -109,10 +113,8 @@ def unpack_from_marlin_24(
     """
     in_features, out_features = original_shape
 
-    # Unpacks the scales
-    unpacked_scales = _from_marlin_scale(
-        scales, *original_shape, group_size, num_bits
-    )
+    # Unpacks the scales
+    unpacked_scales = _from_marlin_scale(scales, *original_shape, group_size, num_bits)
 
     in_features_comp = in_features // 2
 
@@ -130,14 +132,11 @@ def unpack_from_marlin_24(
 
 
 def _compress_quantized_24_weight(
-        q_24: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0
+    q_24: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Compresses the quantized weights to a 2:4 sparse format. Normalizes the weights over 0
     before compressing them.
     Args:
         q_24 (torch.Tensor): The quantized weight tensor.
         size_k (int): The number of input features.
@@ -168,14 +167,10 @@ def _compress_quantized_24_weight(
 
 
 def _decompress_quantized_24_weight(
-        q_24_comp: torch.Tensor,
-        meta: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    q_24_comp: torch.Tensor, meta: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
     """Decompresses the quantized weights from a 2:4 sparse format and restores the original shape.
     Args:
         q_24_comp (torch.Tensor): The compressed quantized weight tensor in 2:4 sparse format.
         meta (torch.Tensor): The meta tensor.
@@ -210,13 +205,13 @@ def _decompress_quantized_24_weight(
 
 
 def _to_marlin_weights(
-        q_w: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int,
-    ) -> torch.Tensor:
+    q_w: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
     """Converts a quantized and 2:4 sparse format weight tensor to the marlin 2:4 format.
     Args:
         q_w (torch.Tensor): The quantized weight tensor in 2:4 sparse format.
         size_k (int): The number of input features.
@@ -236,7 +231,11 @@ def _to_marlin_weights(
     # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32
     # does not support rshift_cpu.
     q_w = q_w.cpu().to(torch.int64)
-    q_packed = torch.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=torch.int64, device=q_w.device)
+    q_packed = torch.zeros(
+        (q_w.shape[0], q_w.shape[1] // pack_factor),
+        dtype=torch.int64,
+        device=q_w.device,
+    )
     for i in range(pack_factor):
         q_packed |= q_w[:, i::pack_factor] << (num_bits * i)
 
@@ -245,13 +244,10 @@ def _to_marlin_weights(
 
 
 def _from_marlin_weights(
-        q_packed: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    q_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
     """Converts a weight tensor in the marlin 2:4 format to a regular quantized 2:4 sparse format.
     Args:
         q_packed (torch.Tensor): The weight tensor in the marlin 2:4 format.
         size_k (int): The number of input features.
@@ -269,52 +265,54 @@ def _from_marlin_weights(
     # Original implementation uses numpy + uint32 but we need to use int64 because torch.uint32
     # does not support rshift_cpu.
     q_packed = q_packed.cpu().to(torch.int64)
-    q_w_unpacked = torch.zeros((q_packed.shape[0], q_packed.shape[1] * pack_factor), dtype=torch.int64, device=q_packed.device)
+    q_w_unpacked = torch.zeros(
+        (q_packed.shape[0], q_packed.shape[1] * pack_factor),
+        dtype=torch.int64,
+        device=q_packed.device,
+    )
     for i in range(pack_factor):
-        q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & ((1 << num_bits) - 1)
+        q_w_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & (
+            (1 << num_bits) - 1
+        )
 
     q_w_unpacked = q_w_unpacked.to(orig_device, dtype=torch.int32)
 
-    q_w_comp = utils.reverse_marlin_permute_weights(q_w_unpacked, size_k, size_n, perm_24)
+    q_w_comp = utils.reverse_marlin_permute_weights(
+        q_w_unpacked, size_k, size_n, perm_24
+    )
     return q_w_comp
 
 
 def _to_marlin_scales(
-        scales: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        group_size: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    scales: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+) -> torch.Tensor:
     """Converts a scale tensor to the format necessary for marlin.
     Args:
         scales (torch.Tensor): The scale tensor.
         size_k (int): The number of input features.
         size_n (int): The number of output features.
         group_size (int): The group size that was applied during quantization.
         num_bits (int): The number of bits used for quantization.
     Returns:
         torch.Tensor: The scale tensor in the marlin format.
     """
     _, scale_perm_24, scale_perm_single_24 = utils.get_perms_24(num_bits)
     if group_size < size_k and group_size != -1:
         scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24]
     else:
-        scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24]
+        scales = scales.reshape((-1, len(scale_perm_single_24)))[
+            :, scale_perm_single_24
+        ]
     scales = scales.reshape((-1, size_n)).contiguous()
     return scales
 
 
 def _from_marlin_scale(
-        scales: torch.Tensor,
-        size_k: int,
-        size_n: int,
-        group_size: int,
-        num_bits: int
-    ) -> torch.Tensor:
+    scales: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+) -> torch.Tensor:
     """Converts a scale tensor from the marlin format to their original format.
     Args:
         scales (torch.Tensor): The scale tensor in the marlin format.
         size_k (int): The number of input features.
@@ -329,5 +327,7 @@ def _from_marlin_scale(
         scales = scales.reshape((-1, len(scale_perm_24)))[:, scale_perm_24]
         return scales.reshape((size_k // group_size, size_n))
     else:
-        scales = scales.reshape((-1, len(scale_perm_single_24)))[:, scale_perm_single_24]
-        return scales.reshape((1, -1))
+        scales = scales.reshape((-1, len(scale_perm_single_24)))[
+            :, scale_perm_single_24
+        ]
+        return scales.reshape((1, -1))
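
For reference, the 2:4 pattern that inject_24 relies on keeps the 2 largest-magnitude values in every group of 4 weights and zeroes the other 2. A self-contained sketch of that idea in plain PyTorch (not taken from mask_creator; grouping along the last dimension here is an illustrative choice):

import torch

def mask_24(w: torch.Tensor) -> torch.Tensor:
    """Boolean mask keeping the 2 largest-magnitude entries in each group of 4."""
    rows, cols = w.shape
    groups = w.abs().reshape(rows, cols // 4, 4)
    top2 = groups.topk(2, dim=-1).indices          # positions of the survivors
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, top2, True)
    return mask.reshape(rows, cols)

w = torch.randn(8, 16)
w_24 = w * mask_24(w)
assert mask_24(w).reshape(-1, 4).sum(dim=-1).eq(2).all()  # 2 nonzeros per group of 4
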
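The _to_marlin_weights / _from_marlin_weights pair reformatted above packs pack_factor = 32 // num_bits quantized values into each 32-bit word (8 values when num_bits = 4) and recovers them with the inverse shifts. A stripped-down round trip of just that packing loop, with the marlin-specific permutation (utils.get_perms_24) deliberately left out:

import torch

num_bits = 4
pack_factor = 32 // num_bits  # 8 four-bit values per 32-bit word

q_w = torch.randint(0, 2**num_bits, (4, 16), dtype=torch.int64)

# Pack: column c lands in word c // pack_factor, shifted by (c % pack_factor) * num_bits.
q_packed = torch.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=torch.int64)
for i in range(pack_factor):
    q_packed |= q_w[:, i::pack_factor] << (num_bits * i)

# Unpack: shift each field back down and mask off everything above num_bits.
q_unpacked = torch.zeros_like(q_w)
for i in range(pack_factor):
    q_unpacked[:, i::pack_factor] = (q_packed >> (num_bits * i)) & ((1 << num_bits) - 1)

assert torch.equal(q_w, q_unpacked)
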