Add support for bfloat16 atomic adds in fbcode (pytorch#143629)
Reland pytorch#141857 and fall back on A100, which doesn't have bfloat16 atomic add instructions.

Pull Request resolved: pytorch#143629
Approved by: https://github.com/eellison
mlazos authored and pytorchmergebot committed Dec 20, 2024
1 parent a3b04d4 commit 8960cb5
Showing 5 changed files with 92 additions and 6 deletions.
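
For context, the user-visible code path this commit touches, distilled from the new test below (a minimal sketch, assuming a CUDA build; which kernel the backward uses depends on the build flavor and GPU architecture):

# Minimal sketch (assumes a CUDA device; distilled from the new test below).
import torch

def f(x, y):
    return torch.index_select(x, 0, y)

x = torch.randn(2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = torch.ones(713268, dtype=torch.int64, device="cuda")

# The backward of index_select scatter-adds gradients back into x. With this
# change, fbcode builds on SM90+ emit tl.atomic_add for bfloat16; OSS builds
# and pre-SM90 GPUs (e.g. A100) fall back to the eager aten.index_add kernel.
out = torch.compile(f)(x, y)
out.sum().backward()
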
51 changes: 51 additions & 0 deletions test/inductor/test_cuda_repro.py
@@ -28,6 +28,7 @@
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
     SM80OrLater,
+    SM90OrLater,
     TEST_MULTIGPU,
 )
 from torch.testing._internal.common_utils import (
@@ -1522,6 +1523,56 @@ def foo(inp):
         foo_c = torch.compile(foo)
         torch.testing.assert_allclose(foo(inp), foo_c(inp))
 
+    @unittest.skipIf(
+        not config.is_fbcode(),
+        "bfloat16 atomic add is only supported in fbcode today #97016",
+    )
+    @skipCUDAIf(
+        not SM90OrLater, "uses bfloat16 atomic add instrs which requires SM >= 90"
+    )
+    def test_atomic_add_bfloat16(self):
+        def f(x, y):
+            return torch.index_select(x, 0, y)
+
+        x = torch.randn(
+            2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
+        )
+        y = torch.ones(713268, dtype=torch.int64, device="cuda")
+        x_ref = x.clone().detach().requires_grad_(True)
+        y_ref = y.clone().detach()
+
+        out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
+        fc = FileCheck()
+        fc.check("tl.atomic_add")
+        fc.run(bw_code)
+
+        self.assertEqual(f(x_ref, y_ref), out)
+
+    @skipCUDAIf(
+        not SM90OrLater, "uses bfloat16 atomic add instrs which requires SM >= 90"
+    )
+    @unittest.skipIf(
+        config.is_fbcode(),
+        "bfloat16 atomic add is supported in fbcode, so we won't fallback",
+    )
+    def test_index_add_fallback(self):
+        def f(x, y):
+            return torch.index_select(x, 0, y)
+
+        x = torch.randn(
+            2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
+        )
+        y = torch.ones(713268, dtype=torch.int64, device="cuda")
+        x_ref = x.clone().detach().requires_grad_(True)
+        y_ref = y.clone().detach()
+
+        out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
+        fc = FileCheck()
+        fc.check("aten.index_add")
+        fc.run(bw_code)
+
+        self.assertEqual(f(x_ref, y_ref), out)
+
     @requires_multigpu()
     def test_not_initializing_wrong_device(self):
         device_stats = torch.cuda.memory_stats("cuda:0")

23 changes: 22 additions & 1 deletion torch/_inductor/decomposition.py
@@ -17,10 +17,12 @@
 )
 from torch._decomp.decompositions import (
     _grid_sampler_2d as decomp_grid_sampler_2d,
+    _index_add,
     pw_cast_for_opmath,
 )
 from torch._decomp.decompositions_for_rng import extra_random_decomps
 from torch._dynamo.utils import counters
+from torch._environment import is_fbcode
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch._inductor.utils import pad_listlike
 from torch._prims_common import (
@@ -48,6 +50,7 @@
 inductor_decompositions = get_decompositions(
     [
         aten._adaptive_avg_pool2d_backward,
+        aten.index_select,
         aten.addmv,
         aten.arange,
         aten.bitwise_and_,
@@ -58,7 +61,6 @@
         aten.flip,
         aten.gelu,
         aten.hardtanh,
-        aten.index_select,
         aten.lcm,
         aten.leaky_relu,
         aten.linalg_vector_norm,
@@ -101,6 +103,7 @@
     aten._softmax_backward_data,
     aten.clamp_max,
     aten.clamp_min,
+    aten.index_add,  # we conditionally call this decomp
     aten.glu,  # inductor lowers this directly
     aten.select_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
     aten.slice_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
@@ -173,6 +176,24 @@ def full(
     return NotImplemented
 
 
+@register_decomposition([aten.index_add])
+def index_add(
+    x: torch.Tensor,
+    dim: int,
+    index: torch.Tensor,
+    tensor: torch.Tensor,
+    *,
+    alpha: torch.types.Number = 1,
+) -> torch.Tensor:
+    # If we are not in fbcode and dtype is bfloat16
+    # fallback to index_add kernel
+    # see https://github.com/pytorch/pytorch/issues/137425 for details
+    if not is_fbcode() and x.dtype == torch.bfloat16:
+        return NotImplemented
+    else:
+        return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
+
+
 # Not really sure how to put this into the main library. PrimTorch wants
 # empty_permuted to go to the prim, and typically users don't really want
 # to decompose to empty_strided (but inductor is OK with it, because we are

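A note on the mechanism: returning NotImplemented from a registered decomposition (as the `full` decomp above also does) leaves the original aten op in the graph, so on OSS bfloat16 inputs aten.index_add survives to reach the fallback path added in graph.py and lowering.py below. Roughly (the helper name is illustrative, not part of this commit):

# Sketch of the routing decision the new decomp encodes (illustrative helper,
# not part of this commit).
import torch
from torch._environment import is_fbcode

def index_add_takes_decomp(x: torch.Tensor) -> bool:
    # Outside fbcode, bfloat16 keeps the eager aten.index_add kernel
    # (the decomp returns NotImplemented); everything else is decomposed.
    return is_fbcode() or x.dtype != torch.bfloat16

print(index_add_takes_decomp(torch.zeros(2, dtype=torch.bfloat16)))
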
2 changes: 1 addition & 1 deletion torch/_inductor/graph.py
@@ -1090,7 +1090,7 @@ def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) ->
         ), f"{target} is not an OpOverload"
         base_name = target.name().split(".")[0]
         if base_name in FALLBACK_ALLOW_LIST:
-            make_fallback(target)
+            make_fallback(target, warn=False, override_decomp=True)
         elif config.implicit_fallbacks:
             error = (
                 MissingOperatorWithDecomp

8 changes: 6 additions & 2 deletions torch/_inductor/lowering.py
@@ -82,6 +82,7 @@
 FALLBACK_ALLOW_LIST = OrderedSet(
     [
         "torchvision::roi_align",
+        "aten::index_add",
     ]
 )
 
@@ -1940,8 +1941,10 @@ def check_skip_condition(node, parent, is_output):
     return check_skip_condition(node, node, is_output=True)
 
 
-def make_fallback(op, layout_constraint=None, warn=True):
-    assert op not in decompositions, f"both a fallback and a decomp for same op: {op}"
+def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
+    assert (
+        op not in decompositions or override_decomp
+    ), f"both a fallback and a decomp for same op: {op}"
     if (
         warn
         and bool(os.getenv("CI"))
@@ -1951,6 +1954,7 @@ def make_fallback(op, layout_constraint=None, warn=True):
             config.fallback_random
             and op in torch._decomp.decompositions_for_rng.extra_random_decomps
         )
+        and not override_decomp
     ):
         # Note: 'warn' is holdover from when this was a warning, but for ops that previously
         # set warn=False we do not want a CI error.

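Why the new override_decomp escape hatch: aten.index_add now legitimately has both a (conditional) decomposition and an allow-listed fallback, which the old assert forbade. A toy illustration of the relaxed check (simplified; not the PR's code):

# Toy model of the relaxed assert (simplified; not the PR's code).
decompositions = {"aten.index_add"}

def make_fallback(op, override_decomp=False):
    assert op not in decompositions or override_decomp, (
        f"both a fallback and a decomp for same op: {op}"
    )
    return f"fallback registered for {op}"

print(make_fallback("aten.index_add", override_decomp=True))  # allowed now
# make_fallback("aten.index_add")  # would still raise AssertionError
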
14 changes: 12 additions & 2 deletions torch/_inductor/utils.py
@@ -2017,8 +2017,18 @@ def device_need_guard(device: str):
 
 
 def needs_fallback_due_to_atomic_add_limitations(dtype):
-    # tl.atomic_add does NOT support the following types
-    return dtype in (torch.int64, torch.bool, torch.bfloat16)
+    # tl.atomic_add has bfloat16 support in fbcode
+    # but not in OSS https://github.com/pytorch/pytorch/issues/97016
+    # we will fallback until the code is upstreamed to OSS
+    if (
+        config.is_fbcode()
+        and dtype == torch.bfloat16
+        and torch.cuda.is_available()
+        and torch.cuda.get_device_capability() >= (9, 0)
+    ):
+        return False
+    else:
+        return dtype in OrderedSet([torch.int64, torch.bool, torch.bfloat16])
 
 
 def use_scatter_fallback(

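The capability gate above is what makes this a reland that skips A100: A100 is SM80, so the bfloat16 atomic-add path only turns on for SM90-class GPUs such as H100. A quick check (assuming a CUDA build):

# Quick check of the capability gate (assumes a CUDA build).
import torch

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability()
    # A100 reports (8, 0) and keeps falling back for bfloat16 atomic adds;
    # (9, 0) and newer can use tl.atomic_add in fbcode builds.
    print(cap, cap >= (9, 0))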
