diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index ab327ee28f3636..53115b6e304383 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -28,6 +28,7 @@
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
     SM80OrLater,
+    SM90OrLater,
     TEST_MULTIGPU,
 )
 from torch.testing._internal.common_utils import (
@@ -1522,6 +1523,56 @@ def foo(inp):
         foo_c = torch.compile(foo)
         torch.testing.assert_allclose(foo(inp), foo_c(inp))
 
+    @unittest.skipIf(
+        not config.is_fbcode(),
+        "bfloat16 atomic add is only supported in fbcode today; see #97016",
+    )
+    @skipCUDAIf(
+        not SM90OrLater, "uses bfloat16 atomic add instrs, which require SM >= 90"
+    )
+    def test_atomic_add_bfloat16(self):
+        def f(x, y):
+            return torch.index_select(x, 0, y)
+
+        x = torch.randn(
+            2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
+        )
+        y = torch.ones(713268, dtype=torch.int64, device="cuda")
+        x_ref = x.clone().detach().requires_grad_(True)
+        y_ref = y.clone().detach()
+
+        out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
+        fc = FileCheck()
+        fc.check("tl.atomic_add")
+        fc.run(bw_code)
+
+        self.assertEqual(f(x_ref, y_ref), out)
+
+    @skipCUDAIf(
+        not SM90OrLater, "uses bfloat16 atomic add instrs, which require SM >= 90"
+    )
+    @unittest.skipIf(
+        config.is_fbcode(),
+        "bfloat16 atomic add is supported in fbcode, so we won't fall back",
+    )
+    def test_index_add_fallback(self):
+        def f(x, y):
+            return torch.index_select(x, 0, y)
+
+        x = torch.randn(
+            2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
+        )
+        y = torch.ones(713268, dtype=torch.int64, device="cuda")
+        x_ref = x.clone().detach().requires_grad_(True)
+        y_ref = y.clone().detach()
+
+        out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
+        fc = FileCheck()
+        fc.check("aten.index_add")
+        fc.run(bw_code)
+
+        self.assertEqual(f(x_ref, y_ref), out)
+
     @requires_multigpu()
     def test_not_initializing_wrong_device(self):
         device_stats = torch.cuda.memory_stats("cuda:0")
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 4701f1bdafb471..4a819e5f84ee75 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -17,10 +17,12 @@
 )
 from torch._decomp.decompositions import (
     _grid_sampler_2d as decomp_grid_sampler_2d,
+    _index_add,
     pw_cast_for_opmath,
 )
 from torch._decomp.decompositions_for_rng import extra_random_decomps
 from torch._dynamo.utils import counters
+from torch._environment import is_fbcode
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch._inductor.utils import pad_listlike
 from torch._prims_common import (
@@ -48,6 +50,7 @@
 inductor_decompositions = get_decompositions(
     [
         aten._adaptive_avg_pool2d_backward,
+        aten.index_select,
         aten.addmv,
         aten.arange,
         aten.bitwise_and_,
@@ -58,7 +61,6 @@
         aten.flip,
         aten.gelu,
         aten.hardtanh,
-        aten.index_select,
         aten.lcm,
         aten.leaky_relu,
         aten.linalg_vector_norm,
@@ -101,6 +103,7 @@
     aten._softmax_backward_data,
     aten.clamp_max,
     aten.clamp_min,
+    aten.index_add,  # we conditionally call this decomp
     aten.glu,  # inductor lowers this directly
     aten.select_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
     aten.slice_scatter,  # need to be in the ATen graph in order for it to work with the re-inplacing pass
@@ -173,6 +176,24 @@ def full(
     return NotImplemented
 
 
+@register_decomposition([aten.index_add])
+def index_add(
+    x: torch.Tensor,
+    dim: int,
+    index: torch.Tensor,
+    tensor: torch.Tensor,
+    *,
+    alpha: torch.types.Number = 1,
+) -> torch.Tensor:
+    # If we are not in fbcode and the dtype is bfloat16,
+    # fall back to the eager index_add kernel;
+    # see https://github.com/pytorch/pytorch/issues/137425 for details
+    if not is_fbcode() and x.dtype == torch.bfloat16:
+        return NotImplemented
+    else:
+        return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
+
+
 # Not really sure how to put this into the main library. PrimTorch wants
 # empty_permuted to go to the prim, and typically users don't really want
 # to decompose to empty_strided (but inductor is OK with it, because we are
diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index 2b46b1bffd51a1..7f8f68261911c8 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -1090,7 +1090,7 @@ def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) ->
             ), f"{target} is not an OpOverload"
             base_name = target.name().split(".")[0]
             if base_name in FALLBACK_ALLOW_LIST:
-                make_fallback(target)
+                make_fallback(target, warn=False, override_decomp=True)
             elif config.implicit_fallbacks:
                 error = (
                     MissingOperatorWithDecomp
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 30dbb456e36769..114b4da21dbf8f 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -82,6 +82,7 @@
 FALLBACK_ALLOW_LIST = OrderedSet(
     [
         "torchvision::roi_align",
+        "aten::index_add",
     ]
 )
 
@@ -1940,8 +1941,10 @@ def check_skip_condition(node, parent, is_output):
     return check_skip_condition(node, node, is_output=True)
 
 
-def make_fallback(op, layout_constraint=None, warn=True):
-    assert op not in decompositions, f"both a fallback and a decomp for same op: {op}"
+def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
+    assert (
+        op not in decompositions or override_decomp
+    ), f"both a fallback and a decomp for same op: {op}"
     if (
         warn
         and bool(os.getenv("CI"))
@@ -1951,6 +1954,7 @@ def make_fallback(op, layout_constraint=None, warn=True):
             config.fallback_random
             and op in torch._decomp.decompositions_for_rng.extra_random_decomps
         )
+        and not override_decomp
     ):
         # Note: 'warn' is holdover from when this was a warning, but for ops that previously
         # set warn=False we do not want a CI error.
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 6d6e13ec4f2117..d0dd58d536dca3 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -2017,8 +2017,18 @@ def device_need_guard(device: str):
 
 
 def needs_fallback_due_to_atomic_add_limitations(dtype):
-    # tl.atomic_add does NOT support the following types
-    return dtype in (torch.int64, torch.bool, torch.bfloat16)
+    # tl.atomic_add has bfloat16 support in fbcode,
+    # but not in OSS (https://github.com/pytorch/pytorch/issues/97016);
+    # we will fall back until the change is upstreamed to OSS
+    if (
+        config.is_fbcode()
+        and dtype == torch.bfloat16
+        and torch.cuda.is_available()
+        and torch.cuda.get_device_capability() >= (9, 0)
+    ):
+        return False
+    else:
+        return dtype in OrderedSet([torch.int64, torch.bool, torch.bfloat16])
 
 
 def use_scatter_fallback(
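
Not part of the patch: a minimal sketch of how one might check, on a given build and GPU, which path the new logic takes at runtime. It only uses needs_fallback_due_to_atomic_add_limitations (the helper modified in torch/_inductor/utils.py above) plus standard torch.cuda queries; treat it as an illustration rather than code from this PR.

import torch
from torch._inductor.utils import needs_fallback_due_to_atomic_add_limitations

if torch.cuda.is_available():
    cap = torch.cuda.get_device_capability()
    # With this patch, False is expected only on SM90+ fbcode builds (Triton
    # tl.atomic_add is emitted); everywhere else bfloat16 still returns True,
    # so the backward of index_select falls back to aten.index_add.
    falls_back = needs_fallback_due_to_atomic_add_limitations(torch.bfloat16)
    print(f"capability={cap}, bfloat16 falls back to aten.index_add: {falls_back}")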