
Commit

Update on "Wean off of PYBIND in favor of torch.ops.load_library"
Pave the path toward a Python-agnostic ao by removing the dependency on PYBIND (which is not Python-agnostic).

Concretely, what happened in this PR?
- no more PYBIND, so no more init.cpp
- for all non-CUDA platforms, ao no longer has custom C++ extensions, and thus ao is a pure Python lib
- as a result, we skip auditwheel (which is only needed on Linux) for all non-CUDA wheel builds
- no more PYBIND also means no more torchao._C, which we're replacing with a torch.ops.load_library call

This PR should introduce no failures. The next PR will target the wheel process so that it outputs only one wheel that works for every Python version.
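For context, the load_library pattern this PR moves toward looks roughly like the sketch below. The glob and file name are assumptions for illustration, not copied from the diff: C++ operators registered via TORCH_LIBRARY in the shared object become available under torch.ops.torchao.* once the library is loaded, with no PYBIND module in between.

# Sketch (hypothetical paths): replace the PYBIND-built torchao._C import with a
# plain shared-library load; TORCH_LIBRARY-registered ops then resolve via torch.ops.
from pathlib import Path

import torch

for so_file in Path(__file__).parent.glob("**/*.so"):  # e.g. the torchao ops .so bundled in the wheel
    torch.ops.load_library(str(so_file))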




[ghstack-poisoned]
janeyx99 committed Nov 12, 2024
2 parents 903408a + 793a953 commit 2ea7b4c
Showing 75 changed files with 2,450 additions and 1,269 deletions.
ruff.toml (18 changes: 9 additions & 9 deletions)
@@ -3,17 +3,17 @@
# To add a new path: Simply add it to the 'include' list.
# Example: To lint all files in every subfolder of 'test', add "test/**/*"
include = [
"torchao/float8/inference.py",
"torchao/float8/float8_utils.py",
"torchao/dtypes/nf4tensor.py",
"test/dtypes/test_nf4.py",
"torchao/float8/float8_tensor.py",
"torchao/quantization/linear_activation_weight_observer.py",
"test/quantization/test_observer.py",
"test/dtypes/test_affine_quantized_float.py",
"torchao/quantization/weight_tensor_linear_activation_quantization.py",
"torchao/float8/**/*.py",
"torchao/quantization/**/*.py",
"torchao/dtypes/**/*.py",
"torchao/sparsity/**/*.py",
"torchao/prototype/low_bit_optim/**.py",
"test/quantization/test_observer.py",
"test/dtypes/test_affine_quantized_float.py",
"test/dtypes/test_nf4.py",
"test/prototype/low_bit_optim/**.py",
"torchao/utils.py",

]

lint.ignore = ["E731"]
test/dtypes/test_nf4.py (6 changes: 2 additions & 4 deletions)
@@ -170,11 +170,10 @@ def test_load_from_state_dicts(self, dtype: torch.dtype):
assert base_mod.param.block_size == 32
assert base_mod.param.scaler_block_size == 2

@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
input_tensor = torch.rand(64, device="cuda", dtype=dtype)
input_tensor = torch.rand(64, dtype=dtype)
base_mod = self.TestMod(input_tensor, 32, 2)
state_dict = base_mod.state_dict()
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
@@ -184,11 +183,10 @@ def test_load_from_nf4_same_meta(self, dtype: torch.dtype):
assert other_mod.param.block_size == 32
assert other_mod.param.scaler_block_size == 2

@unittest.skipIf(not torch.cuda.is_available(), "Need cuda for test")
@parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32])
def test_load_from_nf4_diff_meta(self, dtype: torch.dtype):
"""Tests loading to and from different module state dicts"""
input_tensor = torch.rand(128, device="cuda", dtype=dtype)
input_tensor = torch.rand(128, dtype=dtype)
base_mod = self.TestMod(input_tensor, 32, 2)
state_dict = base_mod.state_dict()
saved_state_dict = self.save_state_dict_to_buffer(state_dict)
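Dropping device="cuda" above means the NF4 state-dict round-trip tests also run on CPU. A rough sketch of the CPU flow they exercise, using torchao's to_nf4 helper with the same 32/2 block parameters as the test (shapes and helper usage here are illustrative):

# Sketch: NF4 quantization round-trip on CPU, mirroring the test's parameters.
import torch
from torchao.dtypes.nf4tensor import to_nf4

weight = torch.rand(64, dtype=torch.bfloat16)               # no device="cuda" required anymore
nf4_weight = to_nf4(weight, block_size=32, scaler_block_size=2)
restored = nf4_weight.get_original_weight()                 # dequantize back to bfloat16
assert restored.shape == weight.shape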
test/float8/test_base.py (2 changes: 1 addition & 1 deletion)
@@ -632,7 +632,7 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
with pytest.raises(
RuntimeError,
match=re.escape(
"Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41."
"Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41)."
),
):
a_fp8 @ b_fp8
test/integration/test_integration.py (50 changes: 30 additions & 20 deletions)
@@ -24,6 +24,7 @@
int4_weight_only,
int8_weight_only,
int8_dynamic_activation_int8_weight,
int8_dynamic_activation_int4_weight,
quantize_,
_replace_with_custom_fn_if_matches_filter,
)
@@ -137,6 +138,12 @@ def _int4wo_api(mod):
else:
change_linear_weights_to_int4_woqtensors(mod)

def _int8da_int4w_api(mod):
quantize_(mod, int8_dynamic_activation_int4_weight(), set_inductor_config=False)
if not TORCH_VERSION_AT_LEAST_2_5:
unwrap_tensor_subclass(mod)


# TODO: use this to reduce the number of tests
TENSOR_SUBCLASS_APIS = [
_int8wo_api,
@@ -781,7 +788,7 @@ def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
@@ -973,11 +980,11 @@ def test_weight_only_groupwise_embedding_quant(self):
group_size = 64
m = nn.Embedding(4096, 128)
input = torch.randint(0, 4096, (1, 6))

quantize_(m, int8_weight_only(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding))
y_q = m(input)
y_ref = m.weight.dequantize()[input]

sqnr = compute_error(y_ref, y_q)

self.assertGreater(sqnr, 45.0)
@@ -1486,22 +1493,22 @@ def forward(self, x):



@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
@unittest.skip("AOTI tests are failing right now")
class TestAOTI(unittest.TestCase):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
)
@run_supported_device_dtype
def test_aoti(self, api, test_device, test_dtype):
if not TORCH_VERSION_AT_LEAST_2_4:
self.skipTest("aoti compatibility requires 2.4+.")

print(f"TestAOTI: {api}, {test_device}, {test_dtype}")
logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")
if api is change_linear_weights_to_int8_dqtensors and test_device == "cuda":
self.skipTest(f"{api} in {test_device} is not support for aoti compilation yet")

if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")
if test_device == "cuda" and torch.cuda.is_available() and test_dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
self.skipTest("Need CUDA and SM80+ available.")


logger.info(f"TestAOTI: {api}, {test_device}, {test_dtype}")

m, k, n = 32, 64, 32

@@ -1525,29 +1532,30 @@ def forward(self, x):
ref_f = model(x)

api(model)
unwrap_tensor_subclass(model)

# running model
model(x)

# make sure it compiles
torch._inductor.config.mixed_mm_choice = "triton"

example_inputs = (x,)
torch._export.aot_compile(model, example_inputs)
torch._inductor.aoti_compile_and_package(torch.export.export(model, example_inputs), example_inputs)


@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "requires 2.5+.")
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
class TestExport(unittest.TestCase):
@parameterized.expand(
list(itertools.product(TENSOR_SUBCLASS_APIS, COMMON_DEVICES, COMMON_DTYPES)),
list(itertools.product(TENSOR_SUBCLASS_APIS + [_int8da_int4w_api], COMMON_DEVICES, COMMON_DTYPES)),
)
@run_supported_device_dtype
def test_export(self, api, test_device, test_dtype):
if not TORCH_VERSION_AT_LEAST_2_4:
self.skipTest("aoti compatibility requires 2.4+.")
if test_device == "cuda" and torch.cuda.is_available() and test_dtype == torch.bfloat16 and torch.cuda.get_device_capability() < (8, 0):
self.skipTest("Need CUDA and SM80+ available.")

logger.info(f"TestExport: {api}, {test_device}, {test_dtype}")

if test_dtype != torch.bfloat16:
self.skipTest(f"{api} in {test_dtype} is not support for aoti compilation yet")

m, k, n = 32, 64, 32

class test_model(nn.Module):
Expand All @@ -1570,6 +1578,7 @@ def forward(self, x):
ref_f = model(x)

api(model)
unwrap_tensor_subclass(model)

# running model
ref = model(x)
Expand All @@ -1585,10 +1594,11 @@ def forward(self, x):
model = torch._export.capture_pre_autograd_graph(model, example_inputs)
after_export = model(x)
self.assertTrue(torch.equal(after_export, ref))
if api is _int8da_int8w_api:
if api is _int8da_int4w_api:
targets = [n.target for n in model.graph.nodes]
self.assertTrue(torch.ops.quant.choose_qparams_affine.default in targets)
self.assertTrue(torch.ops.quant.quantize_affine.default in targets)
self.assertFalse(torch.ops.aten.narrow.default in targets)



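The new _int8da_int4w_api entry above is exercised through export, and the added assertions check that the quant.quantize_affine op shows up in the captured graph while no aten.narrow call leaks into it. A condensed sketch of that flow (model size, dtype, and device are illustrative; the real test parametrizes them and uses capture_pre_autograd_graph on older torch):

# Sketch: export a model quantized with int8 dynamic activation / int4 weight and
# inspect the lowered graph, similar in spirit to the updated TestExport case.
import torch
import torch.nn as nn

from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_
from torchao.utils import unwrap_tensor_subclass

model = nn.Sequential(nn.Linear(64, 32, bias=False)).to(torch.bfloat16).eval()
x = torch.randn(32, 64, dtype=torch.bfloat16)

quantize_(model, int8_dynamic_activation_int4_weight(), set_inductor_config=False)
unwrap_tensor_subclass(model)  # flatten tensor subclasses into plain parameters for export

ep = torch.export.export(model, (x,))
targets = [n.target for n in ep.graph.nodes]
assert torch.ops.aten.narrow.default not in targets  # padding must not leak into the graph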
torchao/dtypes/affine_quantized_tensor.py (13 changes: 8 additions & 5 deletions)
@@ -238,10 +238,13 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
self.zero_point_domain,
output_dtype=output_dtype,
)
# need to return to original shape if tensor was padded
# in preprocessing
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
if isinstance(self._layout, TensorCoreTiledLayout):
# need to return to original shape if tensor was padded
# in preprocessing
# TODO: we could add an API for this if there are more use cases
# (e.g. dequant_post_process) in TensorImpl or Layout
for dim, dim_size in enumerate(self.shape):
dq = dq.narrow(dim, 0, dim_size)
return dq

@staticmethod
@@ -1698,7 +1701,7 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias):
output_dtype = input_tensor.dtype
y = y.to(output_dtype)
if bias is not None:
y += bias
y = y + bias
return y


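The dequantize change above scopes the shape-restoring narrow() calls to TensorCoreTiledLayout, the layout whose preprocessing pads its inputs. As a standalone illustration of what that trimming loop does (plain torch, not torchao's API):

# Sketch: trimming a padded tensor back to its logical shape with narrow(),
# one dimension at a time, the same way the dequantize() loop above does.
import torch

orig_shape = (3, 41)
padded = torch.zeros(8, 48)               # e.g. padded up to hardware-friendly multiples
padded[:3, :41] = torch.randn(3, 41)

dq = padded
for dim, dim_size in enumerate(orig_shape):
    dq = dq.narrow(dim, 0, dim_size)      # keep only the first dim_size entries along dim

assert dq.shape == torch.Size(orig_shape)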
torchao/dtypes/nf4tensor.py (6 changes: 6 additions & 0 deletions)
@@ -10,6 +10,8 @@
from torch._prims_common import make_contiguous_strides_for
from torch.distributed.device_mesh import DeviceMesh

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

aten = torch.ops.aten

c10d_functional = torch.ops.c10d_functional
@@ -1043,3 +1045,7 @@ def nf4_constructor(
quantized_data,
nf4,
)


if TORCH_VERSION_AT_LEAST_2_5:
torch.serialization.add_safe_globals([NF4Tensor])
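Registering NF4Tensor via add_safe_globals means checkpoints containing NF4 tensors can be loaded with weights_only=True on torch 2.5+, mirroring what the float8 code already does for Float8Tensor. A rough sketch of the round trip (buffer and block sizes are illustrative):

# Sketch: save and reload an NF4 tensor with the safer weights_only=True path,
# which works once NF4Tensor has been added to torch's safe globals (torch >= 2.5).
import io

import torch
from torchao.dtypes.nf4tensor import to_nf4

buffer = io.BytesIO()
torch.save({"weight": to_nf4(torch.rand(64, dtype=torch.bfloat16), 32, 2)}, buffer)
buffer.seek(0)

loaded = torch.load(buffer, weights_only=True)
assert loaded["weight"].shape == (64,)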
@@ -28,7 +28,7 @@
* This file is generated by gen_metal_shader_lib.py
*/
#ifdef ATEN
#ifdef USE_ATEN
using namespace at::native::mps;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
torchao/experimental/kernels/mps/src/lowbit.h (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@
#include <fstream>
#include <sstream>

#ifdef ATEN
#ifdef USE_ATEN
#include <ATen/native/mps/OperationUtils.h>
using namespace at::native::mps;
inline void finalize_block(MPSStream* mpsStream) {}
torchao/experimental/ops/mps/setup.py (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@
name="torchao_mps_ops",
sources=["register.mm"],
include_dirs=[os.getenv("TORCHAO_ROOT")],
extra_compile_args=["-DATEN=1"],
extra_compile_args=["-DUSE_ATEN=1"],
),
],
cmdclass={"build_ext": BuildExtension},
torchao/float8/__init__.py (6 changes: 2 additions & 4 deletions)
@@ -11,7 +11,7 @@
Float8LinearConfig,
ScalingType,
)
from torchao.float8.float8_linear import Float8Linear, WeightWithDelayedFloat8CastTensor
from torchao.float8.float8_linear import WeightWithDelayedFloat8CastTensor
from torchao.float8.float8_linear_utils import (
convert_to_float8_training,
linear_requires_sync,
@@ -23,12 +23,10 @@
LinearMMConfig,
ScaledMMConfig,
)
from torchao.float8.inference import Float8MMConfig
from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

from torchao.float8.inference import Float8MMConfig
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5


if TORCH_VERSION_AT_LEAST_2_5:
# Needed to load Float8Tensor with weights_only = True
from torch.serialization import add_safe_globals
torchao/float8/config.py (37 changes: 23 additions & 14 deletions)
@@ -13,6 +13,7 @@

logger: logging.Logger = logging.getLogger()


class ScalingType(enum.Enum):
DELAYED = "delayed"
DYNAMIC = "dynamic"
@@ -71,8 +72,10 @@ def __post_init__(self):
self.static_scale is not None
), "static_scale must be specified for static scaling"
if self.scaling_granularity is ScalingGranularity.AXISWISE:
assert self.scaling_type is ScalingType.DYNAMIC, \
"only dynamic scaling type is supported for axiswise scaling granularity"
assert (
self.scaling_type is ScalingType.DYNAMIC
), "only dynamic scaling type is supported for axiswise scaling granularity"


@dataclass(frozen=True)
class DelayedScalingConfig:
@@ -226,7 +229,7 @@ class Float8LinearConfig:

# If True, we only use fp8-all-gather to reduce the communication cost.
# The gemm computation is still done in the original precision.
# `cast_config_weight` is used to decide how to cast the weight to fp8,
# `cast_config_weight` is used to decide how to cast the weight to fp8,
# other casting configs will be ignored.
use_fp8_all_gather_only: bool = False

@@ -238,16 +241,23 @@ def __post_init__(self):
# to work.
# Source of hack: https://stackoverflow.com/a/65959419/
if self.cast_config_input_for_grad_weight is None:
object.__setattr__(self, "cast_config_input_for_grad_weight", self.cast_config_input)
object.__setattr__(
self, "cast_config_input_for_grad_weight", self.cast_config_input
)
if self.cast_config_weight_for_grad_input is None:
object.__setattr__(self, "cast_config_weight_for_grad_input", self.cast_config_weight)
object.__setattr__(
self, "cast_config_weight_for_grad_input", self.cast_config_weight
)
if self.cast_config_grad_output_for_grad_weight is None:
object.__setattr__(self, "cast_config_grad_output_for_grad_weight", self.cast_config_grad_output)
object.__setattr__(
self,
"cast_config_grad_output_for_grad_weight",
self.cast_config_grad_output,
)

# float8 all-gather only supports tensorwise, in the future may support blockwise
if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE:
assert not self.enable_fsdp_float8_all_gather, \
f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
assert not self.enable_fsdp_float8_all_gather, f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"

# save some characters in the compatibility checks below
cc_i = self.cast_config_input
@@ -266,12 +276,13 @@ def __post_init__(self):
):
is_disabled_1 = cc1.scaling_type is ScalingType.DISABLED
is_disabled_2 = cc1.scaling_type is ScalingType.DISABLED
assert is_disabled_1 == is_disabled_2, \
f"incompatible operand precision for {gemm_name}"

assert (
is_disabled_1 == is_disabled_2
), f"incompatible operand precision for {gemm_name}"

if self.use_fp8_all_gather_only:
assert self.enable_fsdp_float8_all_gather, "use_fp8_all_gather_only requires enable_fsdp_float8_all_gather to be True"

# See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
if (
self.enable_fsdp_float8_all_gather
Expand All @@ -280,7 +291,6 @@ def __post_init__(self):
logger.warning(
"When using FSDP, it's recommended to enable config.force_recompute_fp8_weight_in_bwd."
)



# Pre-made recipes for common configurations
@@ -328,7 +338,6 @@ def recipe_name_to_linear_config(
)

elif recipe_name is Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP:

# lw's recipe for a modification on all-axiswise:
#
# output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
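A side note on the __post_init__ edits above: Float8LinearConfig is a frozen dataclass, so its derived defaults are filled in through object.__setattr__ (the hack the inline comment and Stack Overflow link refer to). A minimal standalone sketch of that pattern, with generic names rather than torchao's:

# Sketch: filling in a derived default on a frozen dataclass inside __post_init__.
# Plain assignment would raise FrozenInstanceError, so the workaround is to go
# through object.__setattr__, the same pattern Float8LinearConfig relies on.
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class CastConfigPair:
    primary: str
    for_grad_weight: Optional[str] = None

    def __post_init__(self):
        if self.for_grad_weight is None:
            object.__setattr__(self, "for_grad_weight", self.primary)


cfg = CastConfigPair("dynamic")
assert cfg.for_grad_weight == "dynamic"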
torchao/float8/distributed_utils.py (1 change: 0 additions & 1 deletion)
@@ -6,7 +6,6 @@
from typing import Any

import torch

from fairscale.nn.model_parallel.initialize import get_model_parallel_group

# from float8_tensor import Float8Tensor