Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
ebsmothers committed Oct 11, 2024
1 parent e6b19b2 commit 024b812
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 28 deletions.
4 changes: 2 additions & 2 deletions torchtune/training/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from typing import Callable, Optional

from torchtune.utils._import_guard import _NEW_TENSOR_CORE_TILED_LAYOUT_API
from torchtune.utils._import_guard import _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API

if _NEW_TENSOR_CORE_TILED_LAYOUT_API:
if _USE_NEW_TENSOR_CORE_TILED_LAYOUT_API:
from torchao.dtypes import TensorCoreTiledLayout
else:
from torchao.dtypes import TensorCoreTiledLayoutType as TensorCoreTiledLayout
Expand Down
12 changes: 4 additions & 8 deletions torchtune/utils/_import_guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,8 @@
# LICENSE file in the root directory of this source tree.

import torch
from torchtune.utils._version import (
_get_torchao_version,
_is_fbcode,
_nightly_version_ge,
torch_version_ge,
)
import torchao
from torchtune.utils._version import _is_fbcode, _nightly_version_ge, torch_version_ge

# We can only use flex attention / BlockMask if torch version >= 2.5.0 and GPU is Turing / SM75 and above
_SUPPORTS_FLEX_ATTENTION = (
Expand All @@ -19,9 +15,9 @@
and torch.cuda.get_device_capability() >= (7, 5)
)

torchao_version = _get_torchao_version()
torchao_version = torchao.__version__

_NEW_TENSOR_CORE_TILED_LAYOUT_API = not _is_fbcode() and (
_USE_NEW_TENSOR_CORE_TILED_LAYOUT_API = not _is_fbcode() and (
("dev" not in torchao_version and torchao_version >= "0.6.0")
or ("dev" in torchao_version and _nightly_version_ge(torchao_version, "2024-10-10"))
)
18 changes: 0 additions & 18 deletions torchtune/utils/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,9 @@
# LICENSE file in the root directory of this source tree.

from datetime import datetime
from typing import Optional

import torch

import torchao


def torch_version_ge(version: str) -> bool:
"""
Expand Down Expand Up @@ -47,18 +44,3 @@ def _nightly_version_ge(ao_version_str: str, date: str) -> bool:
ao_version_str.split("+")[0].split("dev")[1], "%Y%m%d"
)
return ao_datetime >= datetime.strptime(date, "%Y-%m-%d")


def _get_torchao_version() -> Optional[str]:
"""
Get torchao version.
Checks:
1) is_fbcode, then
2) torchao.__version__ (only defined for torchao >= 0.3.0), then
"""
if _is_fbcode():
return None
else:
return torchao.__version__

0 comments on commit 024b812

Please sign in to comment.