[NVFuser] Upstream push 0714
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Code changes include:

- codegen improvements:
  1. Indexing refactor -> Remove reference tensor in predicate indexing logic
  2. MMA Rfactor support for cross-warp and cross-CTA split on K dimension
  3. Grouping grid allreduces across iterations
  4. Swizzle op formulation for non-affine swizzles
  5. Use scheduler_utils to cache inputs and outputs in schedulePointwise
- scheduler refactor
  1. New compute at interface
- transformation propagation refactor on MaxInfoSpanningTree
  1. Added sibling path that is required to generate consistent replay for some cases where `MaxInfoSpanningTree` is used with a selector.
  2. Optimization to skip Transform propagator
  3. SpanningTreePrinter for debugging
- parser update
  1. Fixes `div`
  2. Added `_to_copy` (see the sketch after this list)
  3. Broadcast in dim with expand to support expanding to concrete size
  4. Dropout prob extremal patch
- executor patch on caching strides for output allocation
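
As referenced above, the `_to_copy` parser support is what lets AMP-style casts stay inside a fusion. The sketch below mirrors the `test_to_copy` case added in this PR; the `fuser2` selection and the warm-up run count are illustrative assumptions, and a CUDA build with nvfuser is assumed.

```python
# Sketch mirroring the added test_to_copy: aten::_to_copy (as emitted by AMP
# casts) is now parsed into a fusion. Assumes a CUDA build with nvfuser.
import torch

def t(x, dtype: torch.dtype):
    return torch.ops.aten._to_copy(x, dtype=dtype)

t_jit = torch.jit.script(t)
x = torch.randn(4, 2, device="cuda")

with torch.jit.fuser("fuser2"):      # select nvfuser
    for dtype in [torch.float16, torch.bool, torch.float64]:
        for _ in range(3):           # profiling runs before the fused kernel executes
            t_jit(x, dtype)
```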

Commits were squashed to work around (WAR) the GitHub API.
Commits that are actually in this PR from the devel branch:

```
3b87896 Fix allocation of work buffers and `fused_reduction::ParallelReduce` with unswitch (csarofeen#1818)
4cae122 schedulePointwise cleanup: - computeAt + InlinePropagator (csarofeen#1815)
3df9742 Use scheduler_utils to cache inputs and outputs in schedulePointwise (csarofeen#1811)
03180aa improve broadcast resolution (csarofeen#1792)
bee6c69 bug fix (csarofeen#1819)
4413c8f Support PYTORCH_NVFUSER_DUMP=transform_propagator (csarofeen#1812)
de6b7ca Fix negative position in InlinePropagator (csarofeen#1813)
10a996c Remove redundant check in schedulePointwise (csarofeen#1810)
acd5ed4 Swizzle op formulation for non-affine swizzles (csarofeen#1441)
3ed8330 Kernel args patch to show zero_init buffer (csarofeen#1809)
037a75a Dropout prob extremal patch (csarofeen#1804)
282c429 spam nvrtc options (csarofeen#1783)
3ba6a5f Broadcast in dim with expand (csarofeen#1794)
fd4be12 remove dead indexing code (csarofeen#1806)
fa4e6a4 Check siblings in getMaxPosAll (csarofeen#1805)
025c840 Grouping grid allreduces across iterations (csarofeen#1755)
37c579e Temporarily disable test requring large shared memory. (csarofeen#1802)
5f375d0 More cleanup on InlinePropagator (csarofeen#1800)
8d384da Indexing refactor stage 2 : Remove reference tensor in predicate indexing logic (csarofeen#1784)
f008140 MMA Rfactor support for cross-warp and cross-CTA split on K dimension (csarofeen#1554)
76b3cca Add parsing support for `_to_copy` to handle AMP casts. (csarofeen#1756)
ef04f6c Coding style cleanups (csarofeen#1798)
38c7f3c InlinePropagator please don't replay (csarofeen#1797)
3f2c263 validateDomain in TransformPropagator (csarofeen#1796)
c077085 Use TransformPropagatorWithCheck in many tests (csarofeen#1795)
d0d0908 Some further cleanup for the new computeAt interface (csarofeen#1793)
45f5203 Fix TransformReplay::getMatchedLeafPosWithoutReplay* (csarofeen#1791)
28cbaf9 New compute at interface (csarofeen#1743)
635ebfc Add SpanningTreePrinter (csarofeen#1786)
59f3c32 Output allocate patch (csarofeen#1790)
fe93bf5 Transform propagator skip replay when possible (csarofeen#1782)
ebf23a5 Fix isIntegralType error msg (csarofeen#1789)
0c82ecf Disable register reuse across serial broadcast ops (csarofeen#1787)
33a824d Adding sibling path for MaxInfoSpanningTree (csarofeen#1776)
86f46aa Fix div(Val, TensorView) (csarofeen#1778)
d3de227 Fix FusionMaxRootDomainInfoSpanningTreePrintTwice_CUDA (csarofeen#1781)
ecc7a87 Extend mma dimension and layout checking to support strided batched matmul and tensor contractions (csarofeen#1761)
```

[ghstack-poisoned]
jjsjann123 committed Jul 21, 2022
1 parent 2fb2740 commit f686d82
Showing 76 changed files with 6,375 additions and 2,049 deletions.
3 changes: 2 additions & 1 deletion build_variables.bzl
@@ -32,6 +32,7 @@ libtorch_nvfuser_runtime_sources = [
"torch/csrc/jit/codegen/cuda/runtime/helpers.cu",
"torch/csrc/jit/codegen/cuda/runtime/index_utils.cu",
"torch/csrc/jit/codegen/cuda/runtime/tensorcore.cu",
"torch/csrc/jit/codegen/cuda/runtime/swizzle.cu",
"torch/csrc/jit/codegen/cuda/runtime/memory.cu",
"torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu",
"torch/csrc/jit/codegen/cuda/runtime/tensor.cu",
@@ -643,6 +644,7 @@ libtorch_cuda_core_sources = [
"torch/csrc/autograd/functions/comm.cpp",
"torch/csrc/jit/codegen/cuda/arith.cpp",
"torch/csrc/jit/codegen/cuda/compute_at.cpp",
"torch/csrc/jit/codegen/cuda/inline_propagator.cpp",
"torch/csrc/jit/codegen/cuda/compute_at_map.cpp",
"torch/csrc/jit/codegen/cuda/codegen.cpp",
"torch/csrc/jit/codegen/cuda/contiguity.cpp",
@@ -658,7 +660,6 @@ libtorch_cuda_core_sources = [
"torch/csrc/jit/codegen/cuda/grouped_reduction.cpp",
"torch/csrc/jit/codegen/cuda/index_compute.cpp",
"torch/csrc/jit/codegen/cuda/lower_index_compute.cpp",
"torch/csrc/jit/codegen/cuda/index_reference_replay.cpp",
"torch/csrc/jit/codegen/cuda/instrumentation.cpp",
"torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp",
"torch/csrc/jit/codegen/cuda/ir_builder.cpp",
99 changes: 83 additions & 16 deletions test/test_jit_cuda_fuser.py
@@ -83,6 +83,11 @@ def is_pre_volta():

TEST_BF16 = RUN_NVFUSER and torch.cuda.is_bf16_supported()

TEST_LARGE_TENSOR = RUN_NVFUSER
if RUN_NVFUSER:
torch.ones(1).cuda() # initialize cuda context
TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 12e9

class CudaFuserTestOptions():
def __init__(self):
self.old_cpu_fuse = torch._C._jit_can_fuse_on_cpu()
@@ -184,23 +189,27 @@ def tearDown(self):
self.cuda_fuser_options.restore()
super(TestCudaFuser, self).tearDown()

def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1):
torch.cuda.manual_seed_all(123)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(123)
def _run_helper(self, jit_op, op, *args, check_stride=False, num_fusion=1, check_runs=1):
seed = 123
torch.cuda.manual_seed_all(seed)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(123)
o = op(*args)

if type(jit_o) is torch.Tensor:
jit_o = [jit_o, ]
o = [o, ]
for i in range(check_runs):
torch.cuda.manual_seed_all(seed + i)
jit_o = jit_op(*args)
torch.cuda.manual_seed_all(seed + i)
o = op(*args)

if type(jit_o) is torch.Tensor:
jit_o = [jit_o, ]
o = [o, ]

for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
if check_stride:
self.assertEqual(oo.stride(), jit_oo.stride())

for oo, jit_oo in zip(o, jit_o):
self.assertEqual(oo.dtype, jit_oo.dtype)
self.assertEqual(oo, jit_oo)
if check_stride:
self.assertEqual(oo.stride(), jit_oo.stride())
self.assertGraphContainsExactly(jit_op.graph_for(*args), FUSION_GUARD, num_fusion, consider_subgraphs=True)

def _run_training_helper(self, jit_op, op, grads, *args):
@@ -2563,13 +2572,14 @@ def t(x: torch.Tensor, p: float, train: bool):

self._run_helper(t_jit, t, x, 0.15, False)

@unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_dropout_train_nograd_fusion(self):
dtype = torch.float
device = "cuda"
x = torch.randn([10, 4, 8], dtype=dtype, device=device)
x = torch.randn([64, 128, 1024], dtype=dtype, device=device)

def t(x: torch.Tensor, p: float, train: bool):
o = torch.nn.functional.dropout(x, p, training=train)
@@ -2578,7 +2588,8 @@ def t(x: torch.Tensor, p: float, train: bool):

t_jit = torch.jit.script(t)

self._run_helper(t_jit, t, x, 0.0, True)
self._run_helper(t_jit, t, x, 0.0, True, check_runs=20)
self._run_helper(t_jit, t, x, 1.0, True, check_runs=20)

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
@@ -4391,6 +4402,33 @@ def t(x):
t_jit = torch.jit.script(t)
self._run_helper(t_jit, t, x)


@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_to_copy(self):
x = torch.randn(4, 2, device="cuda")

with nvfuser_singleton_fusion(True):
def t(x, dtype : torch.dtype):
o = torch.ops.aten._to_copy(x, dtype=dtype)
return o

t.__disable_jit_function_caching__ = True

t_jit = torch.jit.script(t)
for dtype in [torch.float16, torch.bool, torch.float64]:
self._run_helper(t_jit, t, x, dtype)

def t_none(x):
with torch.jit.strict_fusion():
o = torch.ops.aten._to_copy(x, dtype=None)
return o

t_jit_none = torch.jit.script(t_none)
self._run_helper(t_jit_none, t_none, x)


@unittest.skipIf(ALIAS_TEST_DISABLED, "skipping this test since reshape is disabled now")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
@@ -4752,6 +4790,35 @@ def t(x):
jit_t = torch.jit.script(t)
self._run_helper(jit_t, t, x)

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
def test_issue_1785(self):
class Fusion(torch.nn.Module):
def __init__(self):
super(Fusion, self).__init__()

def forward(self, x, a, b):
out = torch.mul(x.unsqueeze(-1), a)
out = out + b
return out

x = torch.randn(1024, 192, 3, device='cuda')
a = torch.randn(3, 128, device='cuda')
b = torch.randn(3, 128, device='cuda')

model = Fusion()
jit_model = torch.jit.script(model)

with torch.jit.fuser('fuser2'):
for _ in range(4):
out_ref = model(x, a, b)
out_jit = jit_model(x, a, b)

out_ref = model(x, a, b)
out_jit = jit_model(x, a, b)
self.assertTrue(self._compare("comparing output failed", out_ref, out_jit, 1e-5))

@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING,
"Requires fusion optimization pass to be effective")
2 changes: 1 addition & 1 deletion torch/csrc/jit/codegen/cuda/README.md
@@ -187,7 +187,7 @@ There're a few debug dump that could be turned on via environment variables. Loo
1. `dump_eff_bandwidth`: print out effective bandwidth of each generated kernel. This naively measure the kernel time divided by I/O buffer size and is a good/simple metric of performance for bandwidth bound kernels
2. `cuda_kernel`: print out generated cuda kernels
3. `launch_param`: print out launch config of generated kernels
4. `print_args`: print out input output tensors of executed codegen kernels
4. `kernel_args`: print out the input/output/buffer tensors of all executed codegen kernels; for buffers, we indicate whether they are zero-initialized, which hints at an extra kernel that fills the tensor before the codegen kernels run.
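
A usage sketch for enabling this dump follows (an assumption-laden example, not part of the README diff: the script body is illustrative and a CUDA build with nvfuser is assumed; the equivalent shell form is `PYTORCH_NVFUSER_DUMP=kernel_args python script.py`):

```python
# Usage sketch: enable the kernel_args dump before the first fusion runs.
import os
os.environ["PYTORCH_NVFUSER_DUMP"] = "kernel_args"

import torch

@torch.jit.script
def fn(x):
    return x.relu() + 1.0

with torch.jit.fuser("fuser2"):      # select nvfuser, as in the tests above
    for _ in range(5):               # let the profiling executor reach the fused kernel
        fn(torch.randn(8, device="cuda"))
```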

### FAQs

25 changes: 24 additions & 1 deletion torch/csrc/jit/codegen/cuda/arith.cpp
@@ -458,7 +458,6 @@ TensorView* unaryOp(
}

NVFUSER_DEFINE_UNARY_OP(set, Set)
NVFUSER_DEFINE_UNARY_OP(randlike, RandLike)
NVFUSER_DEFINE_UNARY_OP(ceil, Ceil)
NVFUSER_DEFINE_UNARY_OP(floor, Floor)
NVFUSER_DEFINE_UNARY_OP(frac, Frac)
@@ -469,6 +468,30 @@ NVFUSER_DEFINE_UNARY_OP(silu, Silu)
NVFUSER_DEFINE_UNARY_OP(trunc, Trunc)
#undef NVFUSER_DEFINE_UNARY_OP

Val* randlike(Val* v) {
TORCH_CHECK(
isFloatingPointType(v->dtype()),
"input must have floating point type, but got ",
v->dtype());
auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
return where(
eq(rand_vals, IrBuilder::create<Double>(1.0)),
IrBuilder::create<Double>(0.0),
rand_vals);
}

TensorView* randlike(TensorView* v) {
TORCH_CHECK(
isFloatingPointType(v->dtype()),
"input must have floating point type, but got ",
v->dtype());
auto rand_vals = unaryOp(UnaryOpType::RandLike, v);
return where(
eq(rand_vals, IrBuilder::create<Double>(1.0)),
IrBuilder::create<Double>(0.0),
rand_vals);
}
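
The `where(eq(rand_vals, 1.0), 0.0, rand_vals)` step folds values of exactly 1.0 back to 0.0 so `randlike` stays in the half-open interval [0, 1), consistent with the dropout prob extremal patch and the new p = 0.0 / p = 1.0 dropout runs in the test changes above. A minimal Python sketch of the same remapping (illustrative only; `torch.rand_like` merely stands in for the codegen RandLike op):

```python
import torch

def randlike_half_open(v: torch.Tensor) -> torch.Tensor:
    # Fold values equal to 1.0 back to 0.0 so the result stays in [0, 1).
    r = torch.rand_like(v)  # stand-in for the codegen RNG values
    return torch.where(r == 1.0, torch.zeros_like(r), r)
```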

Val* bitwise_not(Val* v) {
TORCH_CHECK(
isIntegralType(v->dtype()) || v->dtype() == DataType::Bool,
(Diffs for the remaining changed files are not shown.)
