Remove old torchao imports, require 0.7.0+ #2513

Merged
merged 1 commit on Mar 19, 2025
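Because the `_torchao_0_7_supported` flag is removed, any out-of-tree code that imported it needs its own gate. Below is a minimal sketch of an equivalent check, assuming `importlib.metadata` (stdlib) and the `packaging` package are available; the helper name `torchao_at_least` is illustrative and not part of this change.

# Illustrative only -- not part of this PR. A possible replacement for code
# that previously imported _torchao_0_7_supported from
# torchtune.training.quantization.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def torchao_at_least(minimum: str = "0.7.0") -> bool:
    """Return True if torchao is installed and its version is >= `minimum`."""
    try:
        return Version(version("torchao")) >= Version(minimum)
    except PackageNotFoundError:
        # torchao is not installed at all.
        return False


# Example usage, mirroring the removed skipif decorators:
# @pytest.mark.skipif(not torchao_at_least("0.7.0"), reason="needs torchao 0.7+")
# def test_something(): ...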
4 changes: 0 additions & 4 deletions tests/recipes/test_qat_lora_finetune_distributed.py
@@ -35,7 +35,6 @@
     safe_torch_load,
     SHARD_FNAME,
 )
-from torchtune.training.quantization import _torchao_0_7_supported


 class TestQATLoRAFinetuneDistributedRecipe:
@@ -63,7 +62,6 @@ def _fetch_expected_loss_values(self, model_type):
         "micro_batch_size, gradient_accumulation_steps, should_compile",
         [(4, 1, True), (1, 4, False)],
     )
-    @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
     def test_loss(
         self,
         micro_batch_size,
@@ -116,7 +114,6 @@ def test_loss(
             ("llama3/8B_qat_lora", "llama3", "tune", False),
         ],
     )
-    @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
     def test_training_state_on_resume(
         self,
         config,
@@ -217,7 +214,6 @@ def test_training_state_on_resume(
         ],
     )
     @gpu_test(gpu_count=2)
-    @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
     def test_save_and_load_merged_weights(
         self, recipe_config, model_type, ckpt_type, tmpdir, monkeypatch
     ):
2 changes: 0 additions & 2 deletions tests/torchtune/modules/peft/test_lora.py
@@ -15,7 +15,6 @@
 from torchtune import training
 from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook
 from torchtune.modules.peft import LoRALinear, QATLoRALinear
-from torchtune.training.quantization import _torchao_0_7_supported
 from torchtune.training.seed import set_seed


@@ -237,7 +236,6 @@ def test_quantized_state_dict(self, dtype):
             lora_linear.weight.quantized_data, lora_linear_reload.weight.quantized_data
         )

-    @pytest.mark.skipif(not _torchao_0_7_supported, reason="needs torchao 0.7+")
     def test_qat_lora_forward(self, inputs, lora_linear, out_dim) -> None:
         lora_linear = lora_linear(use_bias=True, dtype=torch.float32)
         qat_lora_linear = QATLoRALinear.from_lora_linear(lora_linear)
47 changes: 14 additions & 33 deletions torchtune/training/quantization.py
@@ -7,44 +7,26 @@
 from typing import Callable, Optional

 from torch import nn
-from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear


-try:
-    # torchao 0.7+
-    from torchao.dtypes import TensorCoreTiledLayout
-except ImportError:
-    # torchao 0.6 and before
-    from torchao.dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout
+from torchao.dtypes import TensorCoreTiledLayout

 from torchao.quantization import (
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     quantize_,
 )

-try:
-    # torchao 0.7+
-    from torchao.quantization.qat import (
-        Int4WeightOnlyQATQuantizer,
-        Int8DynActInt4WeightQATQuantizer,
-    )
-    from torchao.quantization.qat.linear import (
-        disable_4w_fake_quant,
-        disable_8da4w_fake_quant,
-        enable_4w_fake_quant,
-        enable_8da4w_fake_quant,
-    )
-except ImportError:
-    # torchao 0.6 and before
-    from torchao.quantization.prototype.qat import (
-        disable_4w_fake_quant,
-        disable_8da4w_fake_quant,
-        enable_4w_fake_quant,
-        enable_8da4w_fake_quant,
-        Int4WeightOnlyQATQuantizer,
-        Int8DynActInt4WeightQATQuantizer,
-    )
+from torchao.quantization.qat import (
+    Int4WeightOnlyQATQuantizer,
+    Int8DynActInt4WeightQATQuantizer,
+)
+from torchao.quantization.qat.linear import (
+    disable_4w_fake_quant,
+    disable_8da4w_fake_quant,
+    enable_4w_fake_quant,
+    enable_8da4w_fake_quant,
+)
+from torchtune.modules.peft.lora import LoRALinear, QATLoRALinear


 __all__ = [
@@ -58,11 +40,10 @@
 ]


-_torchao_0_7_supported = True
 try:
     from torchao.quantization import qat  # noqa: F401
-except ImportError:
-    _torchao_0_7_supported = False
+except ImportError as e:
+    raise ValueError("Need torchao version 0.7.0+") from e

 _quantizer_to_mode = {}
 _quantizer_mode_to_disable_fake_quant = {}
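With the fallback import paths gone, importing `torchtune.training.quantization` against an older torchao now fails at import time rather than silently degrading. A rough sketch of the caller-visible behavior, assuming torchao < 0.7.0 is installed (the error message is taken from the hunk above):

# Sketch only: demonstrates the import-time failure mode introduced above.
try:
    from torchtune.training import quantization  # noqa: F401
except ValueError as err:
    # quantization.py raises ValueError("Need torchao version 0.7.0+") when
    # `from torchao.quantization import qat` is unavailable.
    print(f"Upgrade torchao: {err}")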