Skip to content

Commit

Permalink
Refactored files
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva committed Nov 7, 2024
1 parent 6fd77d5 commit 79a7f52
Show file tree
Hide file tree
Showing 18 changed files with 1,795 additions and 1,511 deletions.
2 changes: 1 addition & 1 deletion test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_to_device(self, apply_quant):

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_register_new_dispatch(self):
from torchao.dtypes.affine_quantized_tensor import (
from torchao.dtypes.affine_quantized_tensor_ops import (
register_aqt_quantized_linear_dispatch,
deregister_aqt_quantized_linear_dispatch,
)
Expand Down
15 changes: 11 additions & 4 deletions torchao/dtypes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .nf4tensor import NF4Tensor, to_nf4
# from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor
from .uint4 import UInt4Tensor
from .uintx import UInt4Tensor
from .affine_quantized_tensor import (
AffineQuantizedTensor,
to_affine_quantized_intx,
Expand All @@ -9,15 +9,22 @@
to_affine_quantized_fpx,
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
PlainAQTTensorImpl,
)
from .affine_quantized_tensor_ops import *
from .utils import (
Layout,
PlainLayout,
SemiSparseLayout,
TensorCoreTiledLayout,
)
from .floatx import (
Float8Layout,
Float8AQTTensorImpl,
)
from .uintx import (
SemiSparseLayout,
TensorCoreTiledLayout,
MarlinSparseLayout,
)

__all__ = [
"NF4Tensor",
"to_nf4",
Expand Down
Loading

0 comments on commit 79a7f52

Please sign in to comment.