Merge branch 'main' into pinbump1111
Jack-Khuu authored Dec 13, 2024
2 parents dbb090f + 570aebc commit 9579f18
Showing 6 changed files with 104 additions and 21 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/pull.yml
@@ -1104,3 +1104,41 @@ jobs:
echo "Generate AOTI"
python torchchat.py generate stories110M --aoti-package-path ./model.pt2 --prompt "${PRMT}"
echo "Tests complete."
test-torchao-experimental-mps:
strategy:
matrix:
runner: [macos-m1-stable]
runs-on: ${{matrix.runner}}
steps:
- name: Checkout repo
uses: actions/checkout@v3
with:
submodules: true
- name: Setup Python
uses: actions/setup-python@v2
with:
python-version: 3.10.11
- name: Print machine info
run: |
uname -a
if [ $(uname -s) == Darwin ]; then
sysctl machdep.cpu.brand_string
sysctl machdep.cpu.core_count
fi
- name: Install torchchat
run: |
echo "Intalling pip3 packages"
./install/install_requirements.sh
pip3 list
python3 -c 'import torch;print(f"torch: {torch.__version__, torch.version.git_version}")'
- name: Install torchao-ops-mps
id: install-torchao-ops-mps
run: |
bash torchchat/utils/scripts/build_torchao_ops.sh mps
- name: Run inference
run: |
python torchchat.py download stories110M
export PRMT="Once upon a time in a land far away"
echo "Generate eager"
python torchchat.py generate stories110M --temperature 0 --prompt "${PRMT}" --device mps --dtype float32 --quantize '{"linear:afpwx": {"bitwidth": 3, "groupsize": 32}}'
26 changes: 26 additions & 0 deletions docs/quantization.md
@@ -196,6 +196,32 @@ Note: only the ExecuTorch C++ runner in torchchat when built using the instructi
./cmake-out/et_run llama3_1.pte -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
```

## Experimental TorchAO MPS lowbit kernels

WARNING: These kernels only work on devices with Apple Silicon.

### Use

#### linear:afpwx
The quantization scheme linear:afpwx quantizes only the weights in a groupwise manner with a specified bitwidth and groupsize.
It takes arguments bitwidth (1, 2, 3, 4, 5, 6, 7) and groupsize (32, 64, 128, 256).
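For intuition, here is a rough sketch of what a groupwise weight-only scheme of this shape computes. This is an illustrative approximation only, not the actual afpwx kernels: each group of `groupsize` consecutive weights shares one scale and zero point, and every weight in the group is rounded to an integer in `[0, 2^bitwidth - 1]`.

```
# Illustrative sketch of groupwise weight-only quantization (NOT the
# actual afpwx kernel implementation).
import torch

def quantize_groupwise(w, bitwidth=3, groupsize=32):
    qmax = 2 ** bitwidth - 1
    groups = w.reshape(-1, groupsize)               # one row per group
    w_min = groups.min(dim=1, keepdim=True).values  # per-group zero point
    w_max = groups.max(dim=1, keepdim=True).values
    scale = (w_max - w_min).clamp(min=1e-8) / qmax  # per-group scale
    q = torch.round((groups - w_min) / scale).clamp(0, qmax).to(torch.uint8)
    w_hat = (q.float() * scale + w_min).reshape(w.shape)  # dequantized weights
    return q, scale, w_min, w_hat

w = torch.randn(64, 128)
q, scale, zero, w_hat = quantize_groupwise(w, bitwidth=3, groupsize=32)
print((w - w_hat).abs().max().item())  # error is at most ~scale/2 per weight
```

The smaller the groupsize and the larger the bitwidth, the more closely `w_hat` tracks the original weights, at the cost of storing more scales and more bits per weight.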

### Setup
To use linear:afpwx, you must set up the torchao mps experimental kernels. These will only work on devices with Apple Silicon.
Currently, torchchat can only run them in eager mode.

From the torchchat root directory, run
```
sh torchchat/utils/scripts/build_torchao_ops.sh mps
```
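
If the build succeeds, the compiled MPS ops library should appear under the torchao build tree. Judging from the load path used by `torchchat/utils/quantize.py` in this commit (an assumption, not a documented contract), you can sanity-check it from the torchchat root with:

```
ls torchao-build/cmake-out/lib/libtorchao_ops_mps_aten.dylib
```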

### Examples

#### Eager mode
```
python3 torchchat.py generate stories110M --device mps --dtype float32 --quantize '{"linear:afpwx": {"bitwidth": 4, "groupsize": 256}}' --prompt "Once upon a time," --num-samples 5
```

## Quantization Profiles

Four [sample profiles](https://github.com/pytorch/torchchat/tree/main/torchchat/quant_config/) are included with the torchchat distribution: `cuda.json`, `desktop.json`, `mobile.json`, and `pi5.json`.
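
For example, assuming `--quantize` accepts a path to one of these profile files in place of inline JSON (as suggested elsewhere in the torchchat docs), a run might look like:

```
python3 torchchat.py generate stories110M --quantize torchchat/quant_config/mobile.json --prompt "Once upon a time,"
```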
2 changes: 1 addition & 1 deletion install/.pins/torchao-pin.txt
@@ -1 +1 @@
-2f97b0955953fa1a46594a27f0df2bc48d93e79d
+7d7c14e898eca3fe66138d2a9445755a9270b800
38 changes: 21 additions & 17 deletions torchchat/utils/quantize.py
@@ -63,10 +63,10 @@
def get_named_parameters(func: Callable) -> List[str]:
    # Get the signature of the function
    signature = inspect.signature(func)

    # Extract the parameters from the signature
    parameters = signature.parameters

    # Filter and return named parameters
    named_params = [
        name for name, param in parameters.items()
@@ -80,8 +80,8 @@ def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer:
print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
del q_kwargs[key]
return q_kwargs


#########################################################################
### torchchat quantization API ###

@@ -116,15 +116,18 @@ def quantize_model(
            if not support_tensor_subclass:
                unwrap_tensor_subclass(model)
            continue

        if quantizer in ["linear:a8wxdq", "embedding:wx"]:
            # These quantizers require float32 input weights. Note that after quantization,
            # the weights will no longer be float32, but lowbit integers
            if get_precision() != torch.float32:
                print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
                set_precision(torch.float32)

        # We set global precision from quantize options if it is specified at cli.py:485

        if quantizer == "linear:afpwx" and device != "mps":
            raise RuntimeError("linear:afpwx quantization can only run on mps device!")

        # We set global precision from quantize options if it is specified at cli.py:485
        # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
        precision = get_precision()
@@ -915,10 +918,12 @@ def quantized_model(self) -> nn.Module:
    from torchao_experimental_quant_api import (
        Int8DynActIntxWeightLinearQuantizer,
        IntxWeightEmbeddingQuantizer,
        UIntxWeightOnlyLinearQuantizer,
    )

    quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightLinearQuantizer
    quantizer_class_dict["embedding:wx"] = IntxWeightEmbeddingQuantizer
    quantizer_class_dict["linear:afpwx"] = UIntxWeightOnlyLinearQuantizer

    # Try loading custom op
    try:
@@ -928,15 +933,14 @@
        libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
        torch.ops.load_library(libs[0])
    except Exception as e:
        print("Failed to load torchao ops library with error: ", e)
        print("Slow fallback kernels will be used.")
        print("Unable to load torchao cpu ops library. Slow fallback kernels will be used.")

    try:
        libname = "libtorchao_ops_mps_aten.dylib"
        libpath = f"{torchao_build_path}/cmake-out/lib/{libname}"
        torch.ops.load_library(libpath)
    except Exception as e:
        print("Unable to load torchao mps ops library.")

except Exception as e:
    class ErrorHandler(QuantHandler):
        def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
            global torchao_experimental_load_error
            raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")

    torchao_experimental_load_error = e
    quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
    quantizer_class_dict["embedding:wx"] = ErrorHandler
    print("Unable to import torchao experimental quant_api with error: ", e)
7 changes: 6 additions & 1 deletion torchchat/utils/scripts/build_torchao_ops.sh
@@ -5,12 +5,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

device=${1:-cpu}

if [[ "$device" != "cpu" && "$device" != "mps" ]]; then
echo "Invalid argument: $device. Valid values are 'cpu' or 'mps'." >&2
exit 1
fi

source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"

pushd ${TORCHCHAT_ROOT}
find_cmake_prefix_path
clone_torchao
install_torchao_aten_ops
install_torchao_aten_ops "$device"
popd
14 changes: 12 additions & 2 deletions torchchat/utils/scripts/install_utils.sh
@@ -184,8 +184,18 @@ clone_torchao() {
}

install_torchao_aten_ops() {
  echo "Building torchao custom ops for ATen"
  pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental
  local device=${1:-cpu}

  if [[ "$device" == "cpu" ]]; then
    echo "Building torchao custom ops for ATen"
    pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental
  elif [[ "$device" == "mps" ]]; then
    echo "Building torchao mps custom ops for ATen"
    pushd ${TORCHCHAT_ROOT}/torchao-build/src/ao/torchao/experimental/ops/mps
  else
    echo "Invalid argument: $device. Valid values are 'cpu' or 'mps'." >&2
    return 1
  fi

  CMAKE_OUT_DIR=${TORCHCHAT_ROOT}/torchao-build/cmake-out
  cmake -DCMAKE_PREFIX_PATH=${MY_CMAKE_PREFIX_PATH} \
